/ src / Ryujinx.Graphics.Shader / StructuredIr / StructuredProgram.cs
StructuredProgram.cs
  1  using Ryujinx.Graphics.Shader.IntermediateRepresentation;
  2  using Ryujinx.Graphics.Shader.Translation;
  3  using System;
  4  using System.Collections.Generic;
  5  using System.Diagnostics;
  6  using System.Numerics;
  7  
  8  namespace Ryujinx.Graphics.Shader.StructuredIr
  9  {
 10      static class StructuredProgram
 11      {
 12          // TODO: Eventually it should be possible to specify the parameter types for the function instead of using S32 for everything.
 13          private const AggregateType FuncParameterType = AggregateType.S32;
 14  
 15          public static StructuredProgramInfo MakeStructuredProgram(
 16              IReadOnlyList<Function> functions,
 17              AttributeUsage attributeUsage,
 18              ShaderDefinitions definitions,
 19              ResourceManager resourceManager,
 20              TargetLanguage targetLanguage,
 21              bool debugMode)
 22          {
 23              StructuredProgramContext context = new(attributeUsage, definitions, resourceManager, debugMode);
 24  
 25              for (int funcIndex = 0; funcIndex < functions.Count; funcIndex++)
 26              {
 27                  Function function = functions[funcIndex];
 28  
 29                  BasicBlock[] blocks = function.Blocks;
 30  
 31                  AggregateType returnType = function.ReturnsValue ? FuncParameterType : AggregateType.Void;
 32  
 33                  AggregateType[] inArguments = new AggregateType[function.InArgumentsCount];
 34                  AggregateType[] outArguments = new AggregateType[function.OutArgumentsCount];
 35  
 36                  for (int i = 0; i < inArguments.Length; i++)
 37                  {
 38                      inArguments[i] = FuncParameterType;
 39                  }
 40  
 41                  for (int i = 0; i < outArguments.Length; i++)
 42                  {
 43                      outArguments[i] = FuncParameterType;
 44                  }
 45  
 46                  context.EnterFunction(blocks.Length, function.Name, returnType, inArguments, outArguments);
 47  
 48                  PhiFunctions.Remove(blocks);
 49  
 50                  for (int blkIndex = 0; blkIndex < blocks.Length; blkIndex++)
 51                  {
 52                      BasicBlock block = blocks[blkIndex];
 53  
 54                      context.EnterBlock(block);
 55  
 56                      for (LinkedListNode<INode> opNode = block.Operations.First; opNode != null; opNode = opNode.Next)
 57                      {
 58                          Operation operation = (Operation)opNode.Value;
 59  
 60                          if (IsBranchInst(operation.Inst))
 61                          {
 62                              context.LeaveBlock(block, operation);
 63                          }
 64                          else
 65                          {
 66                              AddOperation(context, operation, targetLanguage, functions);
 67                          }
 68                      }
 69                  }
 70  
 71                  GotoElimination.Eliminate(context.GetGotos());
 72  
 73                  AstOptimizer.Optimize(context);
 74  
 75                  context.LeaveFunction();
 76              }
 77  
 78              return context.Info;
 79          }
 80  
 81          private static void AddOperation(StructuredProgramContext context, Operation operation, TargetLanguage targetLanguage, IReadOnlyList<Function> functions)
 82          {
 83              Instruction inst = operation.Inst;
 84              StorageKind storageKind = operation.StorageKind;
 85  
 86              if (inst == Instruction.Load || inst == Instruction.Store)
 87              {
 88                  if (storageKind.IsInputOrOutput())
 89                  {
 90                      IoVariable ioVariable = (IoVariable)operation.GetSource(0).Value;
 91                      bool isOutput = storageKind.IsOutput();
 92                      int location = 0;
 93                      int component = 0;
 94  
 95                      if (context.Definitions.HasPerLocationInputOrOutput(ioVariable, isOutput))
 96                      {
 97                          location = operation.GetSource(1).Value;
 98  
 99                          if (operation.SourcesCount > 2 &&
100                              operation.GetSource(2).Type == OperandType.Constant &&
101                              context.Definitions.HasPerLocationInputOrOutputComponent(ioVariable, location, operation.GetSource(2).Value, isOutput))
102                          {
103                              component = operation.GetSource(2).Value;
104                          }
105                      }
106  
107                      context.Info.IoDefinitions.Add(new IoDefinition(storageKind, ioVariable, location, component));
108                  }
109                  else if (storageKind == StorageKind.ConstantBuffer && operation.GetSource(0).Type == OperandType.Constant)
110                  {
111                      context.ResourceManager.SetUsedConstantBufferBinding(operation.GetSource(0).Value);
112                  }
113              }
114  
115              bool vectorDest = IsVectorDestInst(inst);
116  
117              int sourcesCount = operation.SourcesCount;
118              int outDestsCount = operation.DestsCount != 0 && !vectorDest ? operation.DestsCount - 1 : 0;
119  
120              IAstNode[] sources = new IAstNode[sourcesCount + outDestsCount];
121  
122              if (inst == Instruction.Call && targetLanguage == TargetLanguage.Spirv)
123              {
124                  // SPIR-V requires that all function parameters are copied to a local variable before the call
125                  // (or at least that's what the Khronos compiler does).
126  
127                  // First one is the function index.
128                  Operand funcIndexOperand = operation.GetSource(0);
129                  Debug.Assert(funcIndexOperand.Type == OperandType.Constant);
130                  int funcIndex = funcIndexOperand.Value;
131  
132                  sources[0] = new AstOperand(OperandType.Constant, funcIndex);
133  
134                  int inArgsCount = functions[funcIndex].InArgumentsCount;
135  
136                  // Remaining ones are parameters, copy them to a temp local variable.
137                  for (int index = 1; index < operation.SourcesCount; index++)
138                  {
139                      IAstNode source = context.GetOperandOrCbLoad(operation.GetSource(index));
140  
141                      if (index - 1 < inArgsCount)
142                      {
143                          AstOperand argTemp = context.NewTemp(FuncParameterType);
144                          context.AddNode(new AstAssignment(argTemp, source));
145                          sources[index] = argTemp;
146                      }
147                      else
148                      {
149                          sources[index] = source;
150                      }
151                  }
152              }
153              else
154              {
155                  for (int index = 0; index < operation.SourcesCount; index++)
156                  {
157                      sources[index] = context.GetOperandOrCbLoad(operation.GetSource(index));
158                  }
159              }
160  
161              for (int index = 0; index < outDestsCount; index++)
162              {
163                  AstOperand oper = context.GetOperand(operation.GetDest(1 + index));
164  
165                  oper.VarType = InstructionInfo.GetSrcVarType(inst, sourcesCount + index);
166  
167                  sources[sourcesCount + index] = oper;
168              }
169  
170              AstTextureOperation GetAstTextureOperation(TextureOperation texOp)
171              {
172                  return new AstTextureOperation(
173                      inst,
174                      texOp.Type,
175                      texOp.Format,
176                      texOp.Flags,
177                      texOp.Set,
178                      texOp.Binding,
179                      texOp.SamplerSet,
180                      texOp.SamplerBinding,
181                      texOp.Index,
182                      sources);
183              }
184  
185              int componentsCount = BitOperations.PopCount((uint)operation.Index);
186  
187              if (vectorDest && componentsCount > 1)
188              {
189                  AggregateType destType = InstructionInfo.GetDestVarType(inst);
190  
191                  IAstNode source;
192  
193                  if (operation is TextureOperation texOp)
194                  {
195                      if (texOp.Inst == Instruction.ImageLoad)
196                      {
197                          destType = texOp.Format.GetComponentType();
198                      }
199  
200                      source = GetAstTextureOperation(texOp);
201                  }
202                  else
203                  {
204                      source = new AstOperation(
205                          inst,
206                          operation.StorageKind,
207                          operation.ForcePrecise,
208                          operation.Index,
209                          sources,
210                          operation.SourcesCount);
211                  }
212  
213                  AggregateType destElemType = destType;
214  
215                  switch (componentsCount)
216                  {
217                      case 2:
218                          destType |= AggregateType.Vector2;
219                          break;
220                      case 3:
221                          destType |= AggregateType.Vector3;
222                          break;
223                      case 4:
224                          destType |= AggregateType.Vector4;
225                          break;
226                  }
227  
228                  AstOperand destVec = context.NewTemp(destType);
229  
230                  context.AddNode(new AstAssignment(destVec, source));
231  
232                  for (int i = 0; i < operation.DestsCount; i++)
233                  {
234                      AstOperand dest = context.GetOperand(operation.GetDest(i));
235                      AstOperand index = new(OperandType.Constant, i);
236  
237                      dest.VarType = destElemType;
238  
239                      context.AddNode(new AstAssignment(dest, new AstOperation(Instruction.VectorExtract, StorageKind.None, false, new[] { destVec, index }, 2)));
240                  }
241              }
242              else if (operation.Dest != null)
243              {
244                  AstOperand dest = context.GetOperand(operation.Dest);
245  
246                  // If all the sources are bool, it's better to use short-circuiting
247                  // logical operations, rather than forcing a cast to int and doing
248                  // a bitwise operation with the value, as it is likely to be used as
249                  // a bool in the end.
250                  if (IsBitwiseInst(inst) && AreAllSourceTypesEqual(sources, AggregateType.Bool))
251                  {
252                      inst = GetLogicalFromBitwiseInst(inst);
253                  }
254  
255                  bool isCondSel = inst == Instruction.ConditionalSelect;
256                  bool isCopy = inst == Instruction.Copy;
257  
258                  if (isCondSel || isCopy)
259                  {
260                      AggregateType type = GetVarTypeFromUses(operation.Dest);
261  
262                      if (isCondSel && type == AggregateType.FP32)
263                      {
264                          inst |= Instruction.FP32;
265                      }
266  
267                      dest.VarType = type;
268                  }
269                  else
270                  {
271                      dest.VarType = InstructionInfo.GetDestVarType(inst);
272                  }
273  
274                  IAstNode source;
275  
276                  if (operation is TextureOperation texOp)
277                  {
278                      if (texOp.Inst == Instruction.ImageLoad)
279                      {
280                          dest.VarType = texOp.Format.GetComponentType();
281                      }
282  
283                      source = GetAstTextureOperation(texOp);
284                  }
285                  else if (!isCopy)
286                  {
287                      source = new AstOperation(
288                          inst,
289                          operation.StorageKind,
290                          operation.ForcePrecise,
291                          operation.Index,
292                          sources,
293                          operation.SourcesCount);
294                  }
295                  else
296                  {
297                      source = sources[0];
298                  }
299  
300                  context.AddNode(new AstAssignment(dest, source));
301              }
302              else if (operation.Inst == Instruction.Comment)
303              {
304                  context.AddNode(new AstComment(((CommentNode)operation).Comment));
305              }
306              else if (operation is TextureOperation texOp)
307              {
308                  AstTextureOperation astTexOp = GetAstTextureOperation(texOp);
309  
310                  context.AddNode(astTexOp);
311              }
312              else
313              {
314                  context.AddNode(new AstOperation(
315                      inst,
316                      operation.StorageKind,
317                      operation.ForcePrecise,
318                      operation.Index,
319                      sources,
320                      operation.SourcesCount));
321              }
322  
323              // Those instructions needs to be emulated by using helper functions,
324              // because they are NVIDIA specific. Those flags helps the backend to
325              // decide which helper functions are needed on the final generated code.
326              switch (operation.Inst)
327              {
328                  case Instruction.MultiplyHighS32:
329                      context.Info.HelperFunctionsMask |= HelperFunctionsMask.MultiplyHighS32;
330                      break;
331                  case Instruction.MultiplyHighU32:
332                      context.Info.HelperFunctionsMask |= HelperFunctionsMask.MultiplyHighU32;
333                      break;
334                  case Instruction.SwizzleAdd:
335                      context.Info.HelperFunctionsMask |= HelperFunctionsMask.SwizzleAdd;
336                      break;
337                  case Instruction.FSIBegin:
338                  case Instruction.FSIEnd:
339                      context.Info.HelperFunctionsMask |= HelperFunctionsMask.FSI;
340                      break;
341              }
342          }
343  
344          private static AggregateType GetVarTypeFromUses(Operand dest)
345          {
346              HashSet<Operand> visited = new();
347  
348              Queue<Operand> pending = new();
349  
350              bool Enqueue(Operand operand)
351              {
352                  if (visited.Add(operand))
353                  {
354                      pending.Enqueue(operand);
355  
356                      return true;
357                  }
358  
359                  return false;
360              }
361  
362              Enqueue(dest);
363  
364              while (pending.TryDequeue(out Operand operand))
365              {
366                  foreach (INode useNode in operand.UseOps)
367                  {
368                      if (useNode is not Operation operation)
369                      {
370                          continue;
371                      }
372  
373                      if (operation.Inst == Instruction.Copy)
374                      {
375                          if (operation.Dest.Type == OperandType.LocalVariable)
376                          {
377                              if (Enqueue(operation.Dest))
378                              {
379                                  break;
380                              }
381                          }
382                          else
383                          {
384                              return OperandInfo.GetVarType(operation.Dest.Type);
385                          }
386                      }
387                      else
388                      {
389                          for (int index = 0; index < operation.SourcesCount; index++)
390                          {
391                              if (operation.GetSource(index) == operand)
392                              {
393                                  return InstructionInfo.GetSrcVarType(operation.Inst, index);
394                              }
395                          }
396                      }
397                  }
398              }
399  
400              return AggregateType.S32;
401          }
402  
403          private static bool AreAllSourceTypesEqual(IAstNode[] sources, AggregateType type)
404          {
405              foreach (IAstNode node in sources)
406              {
407                  if (node is not AstOperand operand)
408                  {
409                      return false;
410                  }
411  
412                  if (operand.VarType != type)
413                  {
414                      return false;
415                  }
416              }
417  
418              return true;
419          }
420  
421          private static bool IsVectorDestInst(Instruction inst)
422          {
423              return inst switch
424              {
425                  Instruction.ImageLoad or
426                  Instruction.TextureSample => true,
427                  _ => false,
428              };
429          }
430  
431          private static bool IsBranchInst(Instruction inst)
432          {
433              return inst switch
434              {
435                  Instruction.Branch or
436                  Instruction.BranchIfFalse or
437                  Instruction.BranchIfTrue => true,
438                  _ => false,
439              };
440          }
441  
442          private static bool IsBitwiseInst(Instruction inst)
443          {
444              return inst switch
445              {
446                  Instruction.BitwiseAnd or
447                  Instruction.BitwiseExclusiveOr or
448                  Instruction.BitwiseNot or
449                  Instruction.BitwiseOr => true,
450                  _ => false,
451              };
452          }
453  
454          private static Instruction GetLogicalFromBitwiseInst(Instruction inst)
455          {
456              return inst switch
457              {
458                  Instruction.BitwiseAnd => Instruction.LogicalAnd,
459                  Instruction.BitwiseExclusiveOr => Instruction.LogicalExclusiveOr,
460                  Instruction.BitwiseNot => Instruction.LogicalNot,
461                  Instruction.BitwiseOr => Instruction.LogicalOr,
462                  _ => throw new ArgumentException($"Unexpected instruction \"{inst}\"."),
463              };
464          }
465      }
466  }