/ src / Ryujinx.Graphics.Shader / Translation / HelperFunctionManager.cs
HelperFunctionManager.cs
  1  using Ryujinx.Graphics.Shader.IntermediateRepresentation;
  2  using System;
  3  using System.Collections.Generic;
  4  using static Ryujinx.Graphics.Shader.IntermediateRepresentation.OperandHelper;
  5  
  6  namespace Ryujinx.Graphics.Shader.Translation
  7  {
  8      class HelperFunctionManager
  9      {
 10          private readonly List<Function> _functionList;
 11          private readonly Dictionary<int, int> _functionIds;
 12          private readonly ShaderStage _stage;
 13  
 14          public HelperFunctionManager(List<Function> functionList, ShaderStage stage)
 15          {
 16              _functionList = functionList;
 17              _functionIds = new Dictionary<int, int>();
 18              _stage = stage;
 19          }
 20  
 21          public int AddFunction(Function function)
 22          {
 23              int functionId = _functionList.Count;
 24              _functionList.Add(function);
 25  
 26              return functionId;
 27          }
 28  
 29          public int GetOrCreateFunctionId(HelperFunctionName functionName)
 30          {
 31              if (_functionIds.TryGetValue((int)functionName, out int functionId))
 32              {
 33                  return functionId;
 34              }
 35  
 36              Function function = GenerateFunction(functionName);
 37              functionId = AddFunction(function);
 38              _functionIds.Add((int)functionName, functionId);
 39  
 40              return functionId;
 41          }
 42  
 43          public int GetOrCreateFunctionId(HelperFunctionName functionName, int id)
 44          {
 45              int key = (int)functionName | (id << 16);
 46  
 47              if (_functionIds.TryGetValue(key, out int functionId))
 48              {
 49                  return functionId;
 50              }
 51  
 52              Function function = GenerateFunction(functionName, id);
 53              functionId = AddFunction(function);
 54              _functionIds.Add(key, functionId);
 55  
 56              return functionId;
 57          }
 58  
 59          public int GetOrCreateShuffleFunctionId(HelperFunctionName functionName, int subgroupSize)
 60          {
 61              if (_functionIds.TryGetValue((int)functionName, out int functionId))
 62              {
 63                  return functionId;
 64              }
 65  
 66              Function function = GenerateShuffleFunction(functionName, subgroupSize);
 67              functionId = AddFunction(function);
 68              _functionIds.Add((int)functionName, functionId);
 69  
 70              return functionId;
 71          }
 72  
 73          private Function GenerateFunction(HelperFunctionName functionName)
 74          {
 75              return functionName switch
 76              {
 77                  HelperFunctionName.ConvertDoubleToFloat => GenerateConvertDoubleToFloatFunction(),
 78                  HelperFunctionName.ConvertFloatToDouble => GenerateConvertFloatToDoubleFunction(),
 79                  HelperFunctionName.TexelFetchScale => GenerateTexelFetchScaleFunction(),
 80                  HelperFunctionName.TextureSizeUnscale => GenerateTextureSizeUnscaleFunction(),
 81                  _ => throw new ArgumentException($"Invalid function name {functionName}"),
 82              };
 83          }
 84  
 85          private static Function GenerateConvertDoubleToFloatFunction()
 86          {
 87              EmitterContext context = new();
 88  
 89              Operand valueLow = Argument(0);
 90              Operand valueHigh = Argument(1);
 91  
 92              Operand mantissaLow = context.BitwiseAnd(valueLow, Const(((1 << 22) - 1)));
 93              Operand mantissa = context.ShiftRightU32(valueLow, Const(22));
 94  
 95              mantissa = context.BitwiseOr(mantissa, context.ShiftLeft(context.BitwiseAnd(valueHigh, Const(0xfffff)), Const(10)));
 96              mantissa = context.BitwiseOr(mantissa, context.ConditionalSelect(mantissaLow, Const(1), Const(0)));
 97  
 98              Operand exp = context.BitwiseAnd(context.ShiftRightU32(valueHigh, Const(20)), Const(0x7ff));
 99              Operand sign = context.ShiftRightS32(valueHigh, Const(31));
100  
101              Operand resultSign = context.ShiftLeft(sign, Const(31));
102  
103              Operand notZero = context.BitwiseOr(mantissa, exp);
104  
105              Operand lblNotZero = Label();
106  
107              context.BranchIfTrue(lblNotZero, notZero);
108  
109              context.Return(resultSign);
110  
111              context.MarkLabel(lblNotZero);
112  
113              Operand notNaNOrInf = context.ICompareNotEqual(exp, Const(0x7ff));
114  
115              mantissa = context.BitwiseOr(mantissa, Const(0x40000000));
116              exp = context.ISubtract(exp, Const(0x381));
117  
118              // Note: Overflow cases are not handled here and might produce incorrect results.
119  
120              Operand roundBits = context.BitwiseAnd(mantissa, Const(0x7f));
121              Operand roundBitsXor64 = context.BitwiseExclusiveOr(roundBits, Const(0x40));
122              mantissa = context.ShiftRightU32(context.IAdd(mantissa, Const(0x40)), Const(7));
123              mantissa = context.BitwiseAnd(mantissa, context.ConditionalSelect(roundBitsXor64, Const(~0), Const(~1)));
124  
125              exp = context.ConditionalSelect(mantissa, exp, Const(0));
126              exp = context.ConditionalSelect(notNaNOrInf, exp, Const(0xff));
127  
128              Operand result = context.IAdd(context.IAdd(mantissa, context.ShiftLeft(exp, Const(23))), resultSign);
129  
130              context.Return(result);
131  
132              return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, "ConvertDoubleToFloat", true, 2, 0);
133          }
134  
135          private static Function GenerateConvertFloatToDoubleFunction()
136          {
137              EmitterContext context = new();
138  
139              Operand value = Argument(0);
140  
141              Operand mantissa = context.BitwiseAnd(value, Const(0x7fffff));
142              Operand exp = context.BitwiseAnd(context.ShiftRightU32(value, Const(23)), Const(0xff));
143              Operand sign = context.ShiftRightS32(value, Const(31));
144  
145              Operand notNaNOrInf = context.ICompareNotEqual(exp, Const(0xff));
146              Operand expNotZero = context.ICompareNotEqual(exp, Const(0));
147              Operand notDenorm = context.BitwiseOr(expNotZero, context.ICompareEqual(mantissa, Const(0)));
148  
149              exp = context.IAdd(exp, Const(0x380));
150  
151              Operand shiftDist = context.ISubtract(Const(32), context.FindMSBU32(mantissa));
152              Operand normExp = context.ISubtract(context.ISubtract(Const(1), shiftDist), Const(1));
153              Operand normMant = context.ShiftLeft(mantissa, shiftDist);
154  
155              exp = context.ConditionalSelect(notNaNOrInf, exp, Const(0x7ff));
156              exp = context.ConditionalSelect(notDenorm, exp, normExp);
157              mantissa = context.ConditionalSelect(expNotZero, mantissa, normMant);
158  
159              Operand resultLow = context.ShiftLeft(mantissa, Const(29));
160              Operand resultHigh = context.ShiftRightU32(mantissa, Const(3));
161  
162              resultHigh = context.IAdd(resultHigh, context.ShiftLeft(exp, Const(20)));
163              resultHigh = context.IAdd(resultHigh, context.ShiftLeft(sign, Const(31)));
164  
165              context.Copy(Argument(1), resultLow);
166              context.Copy(Argument(2), resultHigh);
167              context.Return();
168  
169              return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, "ConvertFloatToDouble", false, 1, 2);
170          }
171  
172          private static Function GenerateFunction(HelperFunctionName functionName, int id)
173          {
174              return functionName switch
175              {
176                  HelperFunctionName.SharedAtomicMaxS32 => GenerateSharedAtomicSigned(id, isMin: false),
177                  HelperFunctionName.SharedAtomicMinS32 => GenerateSharedAtomicSigned(id, isMin: true),
178                  HelperFunctionName.SharedStore8 => GenerateSharedStore8(id),
179                  HelperFunctionName.SharedStore16 => GenerateSharedStore16(id),
180                  _ => throw new ArgumentException($"Invalid function name {functionName}"),
181              };
182          }
183  
184          private static Function GenerateSharedAtomicSigned(int id, bool isMin)
185          {
186              EmitterContext context = new();
187  
188              Operand wordOffset = Argument(0);
189              Operand value = Argument(1);
190  
191              Operand result = GenerateSharedAtomicCasLoop(context, wordOffset, id, (memValue) =>
192              {
193                  return isMin
194                      ? context.IMinimumS32(memValue, value)
195                      : context.IMaximumS32(memValue, value);
196              });
197  
198              context.Return(result);
199  
200              return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, $"SharedAtomic{(isMin ? "Min" : "Max")}_{id}", true, 2, 0);
201          }
202  
203          private static Function GenerateSharedStore8(int id)
204          {
205              return GenerateSharedStore(id, 8);
206          }
207  
208          private static Function GenerateSharedStore16(int id)
209          {
210              return GenerateSharedStore(id, 16);
211          }
212  
213          private static Function GenerateSharedStore(int id, int bitSize)
214          {
215              EmitterContext context = new();
216  
217              Operand offset = Argument(0);
218              Operand value = Argument(1);
219  
220              Operand wordOffset = context.ShiftRightU32(offset, Const(2));
221              Operand bitOffset = GetBitOffset(context, offset);
222  
223              GenerateSharedAtomicCasLoop(context, wordOffset, id, (memValue) =>
224              {
225                  return context.BitfieldInsert(memValue, value, bitOffset, Const(bitSize));
226              });
227  
228              context.Return();
229  
230              return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, $"SharedStore{bitSize}_{id}", false, 2, 0);
231          }
232  
233          private static Function GenerateShuffleFunction(HelperFunctionName functionName, int subgroupSize)
234          {
235              return functionName switch
236              {
237                  HelperFunctionName.Shuffle => GenerateShuffle(subgroupSize),
238                  HelperFunctionName.ShuffleDown => GenerateShuffleDown(subgroupSize),
239                  HelperFunctionName.ShuffleUp => GenerateShuffleUp(subgroupSize),
240                  HelperFunctionName.ShuffleXor => GenerateShuffleXor(subgroupSize),
241                  _ => throw new ArgumentException($"Invalid function name {functionName}"),
242              };
243          }
244  
245          private static Function GenerateShuffle(int subgroupSize)
246          {
247              EmitterContext context = new();
248  
249              Operand value = Argument(0);
250              Operand index = Argument(1);
251              Operand mask = Argument(2);
252  
253              Operand clamp = context.BitwiseAnd(mask, Const(0x1f));
254              Operand segMask = context.BitwiseAnd(context.ShiftRightU32(mask, Const(8)), Const(0x1f));
255              Operand minThreadId = context.BitwiseAnd(GenerateLoadSubgroupLaneId(context, subgroupSize), segMask);
256              Operand maxThreadId = context.BitwiseOr(context.BitwiseAnd(clamp, context.BitwiseNot(segMask)), minThreadId);
257              Operand srcThreadId = context.BitwiseOr(context.BitwiseAnd(index, context.BitwiseNot(segMask)), minThreadId);
258              Operand valid = context.ICompareLessOrEqualUnsigned(srcThreadId, maxThreadId);
259  
260              context.Copy(Argument(3), valid);
261  
262              Operand result = context.Shuffle(value, GenerateSubgroupShuffleIndex(context, srcThreadId, subgroupSize));
263  
264              context.Return(context.ConditionalSelect(valid, result, value));
265  
266              return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, "Shuffle", true, 3, 1);
267          }
268  
269          private static Function GenerateShuffleDown(int subgroupSize)
270          {
271              EmitterContext context = new();
272  
273              Operand value = Argument(0);
274              Operand index = Argument(1);
275              Operand mask = Argument(2);
276  
277              Operand clamp = context.BitwiseAnd(mask, Const(0x1f));
278              Operand segMask = context.BitwiseAnd(context.ShiftRightU32(mask, Const(8)), Const(0x1f));
279              Operand laneId = GenerateLoadSubgroupLaneId(context, subgroupSize);
280              Operand minThreadId = context.BitwiseAnd(laneId, segMask);
281              Operand maxThreadId = context.BitwiseOr(context.BitwiseAnd(clamp, context.BitwiseNot(segMask)), minThreadId);
282              Operand srcThreadId = context.IAdd(laneId, index);
283              Operand valid = context.ICompareLessOrEqualUnsigned(srcThreadId, maxThreadId);
284  
285              context.Copy(Argument(3), valid);
286  
287              Operand result = context.Shuffle(value, GenerateSubgroupShuffleIndex(context, srcThreadId, subgroupSize));
288  
289              context.Return(context.ConditionalSelect(valid, result, value));
290  
291              return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, "ShuffleDown", true, 3, 1);
292          }
293  
294          private static Function GenerateShuffleUp(int subgroupSize)
295          {
296              EmitterContext context = new();
297  
298              Operand value = Argument(0);
299              Operand index = Argument(1);
300              Operand mask = Argument(2);
301  
302              Operand segMask = context.BitwiseAnd(context.ShiftRightU32(mask, Const(8)), Const(0x1f));
303              Operand laneId = GenerateLoadSubgroupLaneId(context, subgroupSize);
304              Operand minThreadId = context.BitwiseAnd(laneId, segMask);
305              Operand srcThreadId = context.ISubtract(laneId, index);
306              Operand valid = context.ICompareGreaterOrEqual(srcThreadId, minThreadId);
307  
308              context.Copy(Argument(3), valid);
309  
310              Operand result = context.Shuffle(value, GenerateSubgroupShuffleIndex(context, srcThreadId, subgroupSize));
311  
312              context.Return(context.ConditionalSelect(valid, result, value));
313  
314              return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, "ShuffleUp", true, 3, 1);
315          }
316  
317          private static Function GenerateShuffleXor(int subgroupSize)
318          {
319              EmitterContext context = new();
320  
321              Operand value = Argument(0);
322              Operand index = Argument(1);
323              Operand mask = Argument(2);
324  
325              Operand clamp = context.BitwiseAnd(mask, Const(0x1f));
326              Operand segMask = context.BitwiseAnd(context.ShiftRightU32(mask, Const(8)), Const(0x1f));
327              Operand laneId = GenerateLoadSubgroupLaneId(context, subgroupSize);
328              Operand minThreadId = context.BitwiseAnd(laneId, segMask);
329              Operand maxThreadId = context.BitwiseOr(context.BitwiseAnd(clamp, context.BitwiseNot(segMask)), minThreadId);
330              Operand srcThreadId = context.BitwiseExclusiveOr(laneId, index);
331              Operand valid = context.ICompareLessOrEqualUnsigned(srcThreadId, maxThreadId);
332  
333              context.Copy(Argument(3), valid);
334  
335              Operand result = context.Shuffle(value, GenerateSubgroupShuffleIndex(context, srcThreadId, subgroupSize));
336  
337              context.Return(context.ConditionalSelect(valid, result, value));
338  
339              return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, "ShuffleXor", true, 3, 1);
340          }
341  
342          private static Operand GenerateLoadSubgroupLaneId(EmitterContext context, int subgroupSize)
343          {
344              if (subgroupSize <= 32)
345              {
346                  return context.Load(StorageKind.Input, IoVariable.SubgroupLaneId);
347              }
348  
349              return context.BitwiseAnd(context.Load(StorageKind.Input, IoVariable.SubgroupLaneId), Const(0x1f));
350          }
351  
352          private static Operand GenerateSubgroupShuffleIndex(EmitterContext context, Operand srcThreadId, int subgroupSize)
353          {
354              if (subgroupSize <= 32)
355              {
356                  return srcThreadId;
357              }
358  
359              return context.BitwiseOr(
360                  context.BitwiseAnd(context.Load(StorageKind.Input, IoVariable.SubgroupLaneId), Const(0x60)),
361                  srcThreadId);
362          }
363  
364          private Function GenerateTexelFetchScaleFunction()
365          {
366              EmitterContext context = new();
367  
368              Operand input = Argument(0);
369              Operand samplerIndex = Argument(1);
370              Operand index = GetScaleIndex(context, samplerIndex);
371  
372              Operand scale = context.Load(StorageKind.ConstantBuffer, 0, Const((int)SupportBufferField.RenderScale), index);
373  
374              Operand scaleIsOne = context.FPCompareEqual(scale, ConstF(1f));
375              Operand lblScaleNotOne = Label();
376  
377              context.BranchIfFalse(lblScaleNotOne, scaleIsOne);
378              context.Return(input);
379              context.MarkLabel(lblScaleNotOne);
380  
381              int inArgumentsCount;
382  
383              if (_stage == ShaderStage.Fragment)
384              {
385                  Operand scaleIsLessThanZero = context.FPCompareLess(scale, ConstF(0f));
386                  Operand lblScaleGreaterOrEqualZero = Label();
387  
388                  context.BranchIfFalse(lblScaleGreaterOrEqualZero, scaleIsLessThanZero);
389  
390                  Operand negScale = context.FPNegate(scale);
391                  Operand inputScaled = context.FPMultiply(context.IConvertS32ToFP32(input), negScale);
392                  Operand fragCoordX = context.Load(StorageKind.Input, IoVariable.FragmentCoord, null, Const(0));
393                  Operand fragCoordY = context.Load(StorageKind.Input, IoVariable.FragmentCoord, null, Const(1));
394                  Operand fragCoord = context.ConditionalSelect(Argument(2), fragCoordY, fragCoordX);
395                  Operand inputBias = context.FPModulo(fragCoord, negScale);
396                  Operand inputWithBias = context.FPAdd(inputScaled, inputBias);
397  
398                  context.Return(context.FP32ConvertToS32(inputWithBias));
399                  context.MarkLabel(lblScaleGreaterOrEqualZero);
400  
401                  inArgumentsCount = 3;
402              }
403              else
404              {
405                  inArgumentsCount = 2;
406              }
407  
408              Operand inputScaled2 = context.FPMultiply(context.IConvertS32ToFP32(input), scale);
409  
410              context.Return(context.FP32ConvertToS32(inputScaled2));
411  
412              return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, "TexelFetchScale", true, inArgumentsCount, 0);
413          }
414  
415          private Function GenerateTextureSizeUnscaleFunction()
416          {
417              EmitterContext context = new();
418  
419              Operand input = Argument(0);
420              Operand samplerIndex = Argument(1);
421              Operand index = GetScaleIndex(context, samplerIndex);
422  
423              Operand scale = context.FPAbsolute(context.Load(StorageKind.ConstantBuffer, 0, Const((int)SupportBufferField.RenderScale), index));
424  
425              Operand scaleIsOne = context.FPCompareEqual(scale, ConstF(1f));
426              Operand lblScaleNotOne = Label();
427  
428              context.BranchIfFalse(lblScaleNotOne, scaleIsOne);
429              context.Return(input);
430              context.MarkLabel(lblScaleNotOne);
431  
432              Operand inputUnscaled = context.FPDivide(context.IConvertS32ToFP32(input), scale);
433  
434              context.Return(context.FP32ConvertToS32(inputUnscaled));
435  
436              return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, "TextureSizeUnscale", true, 2, 0);
437          }
438  
439          private Operand GetScaleIndex(EmitterContext context, Operand index)
440          {
441              switch (_stage)
442              {
443                  case ShaderStage.Vertex:
444                      Operand fragScaleCount = context.Load(StorageKind.ConstantBuffer, 0, Const((int)SupportBufferField.FragmentRenderScaleCount));
445                      return context.IAdd(Const(1), context.IAdd(index, fragScaleCount));
446                  default:
447                      return context.IAdd(Const(1), index);
448              }
449          }
450  
451          public static Operand GetBitOffset(EmitterContext context, Operand offset)
452          {
453              return context.ShiftLeft(context.BitwiseAnd(offset, Const(3)), Const(3));
454          }
455  
456          private static Operand GenerateSharedAtomicCasLoop(EmitterContext context, Operand wordOffset, int id, Func<Operand, Operand> opCallback)
457          {
458              Operand lblLoopHead = Label();
459  
460              context.MarkLabel(lblLoopHead);
461  
462              Operand oldValue = context.Load(StorageKind.SharedMemory, id, wordOffset);
463              Operand newValue = opCallback(oldValue);
464  
465              Operand casResult = context.AtomicCompareAndSwap(StorageKind.SharedMemory, id, wordOffset, oldValue, newValue);
466  
467              Operand casFail = context.ICompareNotEqual(casResult, oldValue);
468  
469              context.BranchIfTrue(lblLoopHead, casFail);
470  
471              return oldValue;
472          }
473  
474      }
475  }