// src/Ryujinx.Graphics.Shader/Translation/Transforms/GeometryToCompute.cs
  1  using Ryujinx.Graphics.Shader.IntermediateRepresentation;
  2  using Ryujinx.Graphics.Shader.Translation.Optimizations;
  3  using System.Collections.Generic;
  4  
  5  using static Ryujinx.Graphics.Shader.IntermediateRepresentation.OperandHelper;
  6  
  7  namespace Ryujinx.Graphics.Shader.Translation.Transforms
  8  {
  9      class GeometryToCompute : ITransformPass
 10      {
 11          public static bool IsEnabled(IGpuAccessor gpuAccessor, ShaderStage stage, TargetLanguage targetLanguage, FeatureFlags usedFeatures)
 12          {
 13              return usedFeatures.HasFlag(FeatureFlags.VtgAsCompute);
 14          }
 15  
 16          public static LinkedListNode<INode> RunPass(TransformContext context, LinkedListNode<INode> node)
 17          {
 18              if (context.Definitions.Stage != ShaderStage.Geometry)
 19              {
 20                  return node;
 21              }
 22  
 23              Operation operation = (Operation)node.Value;
 24  
 25              LinkedListNode<INode> newNode = node;
 26  
 27              switch (operation.Inst)
 28              {
 29                  case Instruction.EmitVertex:
 30                      newNode = GenerateEmitVertex(context.Definitions, context.ResourceManager, node);
 31                      break;
 32                  case Instruction.EndPrimitive:
 33                      newNode = GenerateEndPrimitive(context.Definitions, context.ResourceManager, node);
 34                      break;
 35                  case Instruction.Load:
 36                      if (operation.StorageKind == StorageKind.Input)
 37                      {
 38                          IoVariable ioVariable = (IoVariable)operation.GetSource(0).Value;
 39  
 40                          if (TryGetOffset(context.ResourceManager, operation, StorageKind.Input, out int inputOffset))
 41                          {
 42                              Operand primVertex = ioVariable == IoVariable.UserDefined
 43                                  ? operation.GetSource(2)
 44                                  : operation.GetSource(1);
 45  
 46                              Operand vertexElemOffset = GenerateVertexOffset(context.ResourceManager, node, inputOffset, primVertex);
 47  
 48                              newNode = node.List.AddBefore(node, new Operation(
 49                                  Instruction.Load,
 50                                  StorageKind.StorageBuffer,
 51                                  operation.Dest,
 52                                  new[] { Const(context.ResourceManager.Reservations.VertexOutputStorageBufferBinding), Const(0), vertexElemOffset }));
 53                          }
 54                          else
 55                          {
 56                              switch (ioVariable)
 57                              {
 58                                  case IoVariable.InvocationId:
 59                                      newNode = GenerateInvocationId(node, operation.Dest);
 60                                      break;
 61                                  case IoVariable.PrimitiveId:
 62                                      newNode = GeneratePrimitiveId(context.ResourceManager, node, operation.Dest);
 63                                      break;
 64                                  case IoVariable.GlobalId:
 65                                  case IoVariable.SubgroupEqMask:
 66                                  case IoVariable.SubgroupGeMask:
 67                                  case IoVariable.SubgroupGtMask:
 68                                  case IoVariable.SubgroupLaneId:
 69                                  case IoVariable.SubgroupLeMask:
 70                                  case IoVariable.SubgroupLtMask:
 71                                      // Those are valid or expected for geometry shaders.
 72                                      break;
 73                                  default:
 74                                      context.GpuAccessor.Log($"Invalid input \"{ioVariable}\".");
 75                                      break;
 76                              }
 77                          }
 78                      }
 79                      else if (operation.StorageKind == StorageKind.Output)
 80                      {
 81                          if (TryGetOffset(context.ResourceManager, operation, StorageKind.Output, out int outputOffset))
 82                          {
 83                              newNode = node.List.AddBefore(node, new Operation(
 84                                  Instruction.Load,
 85                                  StorageKind.LocalMemory,
 86                                  operation.Dest,
 87                                  new[] { Const(context.ResourceManager.LocalVertexDataMemoryId), Const(outputOffset) }));
 88                          }
 89                          else
 90                          {
 91                              context.GpuAccessor.Log($"Invalid output \"{(IoVariable)operation.GetSource(0).Value}\".");
 92                          }
 93                      }
 94                      break;
 95                  case Instruction.Store:
 96                      if (operation.StorageKind == StorageKind.Output)
 97                      {
 98                          if (TryGetOffset(context.ResourceManager, operation, StorageKind.Output, out int outputOffset))
 99                          {
100                              Operand value = operation.GetSource(operation.SourcesCount - 1);
101  
102                              newNode = node.List.AddBefore(node, new Operation(
103                                  Instruction.Store,
104                                  StorageKind.LocalMemory,
105                                  (Operand)null,
106                                  new[] { Const(context.ResourceManager.LocalVertexDataMemoryId), Const(outputOffset), value }));
107                          }
108                          else
109                          {
110                              context.GpuAccessor.Log($"Invalid output \"{(IoVariable)operation.GetSource(0).Value}\".");
111                          }
112                      }
113                      break;
114              }
115  
116              if (newNode != node)
117              {
118                  Utils.DeleteNode(node, operation);
119              }
120  
121              return newNode;
122          }
123  
124          private static LinkedListNode<INode> GenerateEmitVertex(ShaderDefinitions definitions, ResourceManager resourceManager, LinkedListNode<INode> node)
125          {
126              int vbOutputBinding = resourceManager.Reservations.GeometryVertexOutputStorageBufferBinding;
127              int ibOutputBinding = resourceManager.Reservations.GeometryIndexOutputStorageBufferBinding;
128              int stride = resourceManager.Reservations.OutputSizePerInvocation;
129  
130              Operand outputPrimVertex = IncrementLocalMemory(node, resourceManager.LocalGeometryOutputVertexCountMemoryId);
131              Operand baseVertexOffset = GenerateBaseOffset(
132                  resourceManager,
133                  node,
134                  definitions.MaxOutputVertices * definitions.ThreadsPerInputPrimitive,
135                  definitions.ThreadsPerInputPrimitive);
136              Operand outputBaseVertex = Local();
137              node.List.AddBefore(node, new Operation(Instruction.Add, outputBaseVertex, new[] { baseVertexOffset, outputPrimVertex }));
138  
139              Operand outputPrimIndex = IncrementLocalMemory(node, resourceManager.LocalGeometryOutputIndexCountMemoryId);
140              Operand baseIndexOffset = GenerateBaseOffset(
141                  resourceManager,
142                  node,
143                  definitions.GetGeometryOutputIndexBufferStride(),
144                  definitions.ThreadsPerInputPrimitive);
145              Operand outputBaseIndex = Local();
146              node.List.AddBefore(node, new Operation(Instruction.Add, outputBaseIndex, new[] { baseIndexOffset, outputPrimIndex }));
147  
148              node.List.AddBefore(node, new Operation(
149                  Instruction.Store,
150                  StorageKind.StorageBuffer,
151                  null,
152                  new[] { Const(ibOutputBinding), Const(0), outputBaseIndex, outputBaseVertex }));
153  
154              Operand baseOffset = Local();
155              node.List.AddBefore(node, new Operation(Instruction.Multiply, baseOffset, new[] { outputBaseVertex, Const(stride) }));
156  
157              LinkedListNode<INode> newNode = node;
158  
159              for (int offset = 0; offset < stride; offset++)
160              {
161                  Operand vertexOffset;
162  
163                  if (offset > 0)
164                  {
165                      vertexOffset = Local();
166                      node.List.AddBefore(node, new Operation(Instruction.Add, vertexOffset, new[] { baseOffset, Const(offset) }));
167                  }
168                  else
169                  {
170                      vertexOffset = baseOffset;
171                  }
172  
173                  Operand value = Local();
174                  node.List.AddBefore(node, new Operation(
175                      Instruction.Load,
176                      StorageKind.LocalMemory,
177                      value,
178                      new[] { Const(resourceManager.LocalVertexDataMemoryId), Const(offset) }));
179  
180                  newNode = node.List.AddBefore(node, new Operation(
181                      Instruction.Store,
182                      StorageKind.StorageBuffer,
183                      null,
184                      new[] { Const(vbOutputBinding), Const(0), vertexOffset, value }));
185              }
186  
187              return newNode;
188          }
189  
190          private static LinkedListNode<INode> GenerateEndPrimitive(ShaderDefinitions definitions, ResourceManager resourceManager, LinkedListNode<INode> node)
191          {
192              int ibOutputBinding = resourceManager.Reservations.GeometryIndexOutputStorageBufferBinding;
193  
194              Operand outputPrimIndex = IncrementLocalMemory(node, resourceManager.LocalGeometryOutputIndexCountMemoryId);
195              Operand baseIndexOffset = GenerateBaseOffset(
196                  resourceManager,
197                  node,
198                  definitions.GetGeometryOutputIndexBufferStride(),
199                  definitions.ThreadsPerInputPrimitive);
200              Operand outputBaseIndex = Local();
201              node.List.AddBefore(node, new Operation(Instruction.Add, outputBaseIndex, new[] { baseIndexOffset, outputPrimIndex }));
202  
203              return node.List.AddBefore(node, new Operation(
204                  Instruction.Store,
205                  StorageKind.StorageBuffer,
206                  null,
207                  new[] { Const(ibOutputBinding), Const(0), outputBaseIndex, Const(-1) }));
208          }
209  
210          private static Operand GenerateBaseOffset(ResourceManager resourceManager, LinkedListNode<INode> node, int stride, int threadsPerInputPrimitive)
211          {
212              Operand primitiveId = Local();
213              GeneratePrimitiveId(resourceManager, node, primitiveId);
214  
215              Operand baseOffset = Local();
216              node.List.AddBefore(node, new Operation(Instruction.Multiply, baseOffset, new[] { primitiveId, Const(stride) }));
217  
218              Operand invocationId = Local();
219              GenerateInvocationId(node, invocationId);
220  
221              Operand invocationOffset = Local();
222              node.List.AddBefore(node, new Operation(Instruction.Multiply, invocationOffset, new[] { invocationId, Const(stride / threadsPerInputPrimitive) }));
223  
224              Operand combinedOffset = Local();
225              node.List.AddBefore(node, new Operation(Instruction.Add, combinedOffset, new[] { baseOffset, invocationOffset }));
226  
227              return combinedOffset;
228          }
229  
230          private static Operand IncrementLocalMemory(LinkedListNode<INode> node, int memoryId)
231          {
232              Operand oldValue = Local();
233              node.List.AddBefore(node, new Operation(
234                  Instruction.Load,
235                  StorageKind.LocalMemory,
236                  oldValue,
237                  new[] { Const(memoryId) }));
238  
239              Operand newValue = Local();
240              node.List.AddBefore(node, new Operation(Instruction.Add, newValue, new[] { oldValue, Const(1) }));
241  
242              node.List.AddBefore(node, new Operation(Instruction.Store, StorageKind.LocalMemory, null, new[] { Const(memoryId), newValue }));
243  
244              return oldValue;
245          }
246  
247          private static Operand GenerateVertexOffset(
248              ResourceManager resourceManager,
249              LinkedListNode<INode> node,
250              int elementOffset,
251              Operand primVertex)
252          {
253              int vertexInfoCbBinding = resourceManager.Reservations.VertexInfoConstantBufferBinding;
254  
255              Operand vertexCount = Local();
256              node.List.AddBefore(node, new Operation(
257                  Instruction.Load,
258                  StorageKind.ConstantBuffer,
259                  vertexCount,
260                  new[] { Const(vertexInfoCbBinding), Const((int)VertexInfoBufferField.VertexCounts), Const(0) }));
261  
262              Operand primInputVertex = Local();
263              node.List.AddBefore(node, new Operation(
264                  Instruction.Load,
265                  StorageKind.LocalMemory,
266                  primInputVertex,
267                  new[] { Const(resourceManager.LocalTopologyRemapMemoryId), primVertex }));
268  
269              Operand instanceIndex = Local();
270              node.List.AddBefore(node, new Operation(
271                  Instruction.Load,
272                  StorageKind.Input,
273                  instanceIndex,
274                  new[] { Const((int)IoVariable.GlobalId), Const(1) }));
275  
276              Operand baseVertex = Local();
277              node.List.AddBefore(node, new Operation(Instruction.Multiply, baseVertex, new[] { instanceIndex, vertexCount }));
278  
279              Operand vertexIndex = Local();
280              node.List.AddBefore(node, new Operation(Instruction.Add, vertexIndex, new[] { baseVertex, primInputVertex }));
281  
282              Operand vertexBaseOffset = Local();
283              node.List.AddBefore(node, new Operation(
284                  Instruction.Multiply,
285                  vertexBaseOffset,
286                  new[] { vertexIndex, Const(resourceManager.Reservations.InputSizePerInvocation) }));
287  
288              Operand vertexElemOffset;
289  
290              if (elementOffset != 0)
291              {
292                  vertexElemOffset = Local();
293  
294                  node.List.AddBefore(node, new Operation(Instruction.Add, vertexElemOffset, new[] { vertexBaseOffset, Const(elementOffset) }));
295              }
296              else
297              {
298                  vertexElemOffset = vertexBaseOffset;
299              }
300  
301              return vertexElemOffset;
302          }
303  
304          private static LinkedListNode<INode> GeneratePrimitiveId(ResourceManager resourceManager, LinkedListNode<INode> node, Operand dest)
305          {
306              int vertexInfoCbBinding = resourceManager.Reservations.VertexInfoConstantBufferBinding;
307  
308              Operand vertexCount = Local();
309              node.List.AddBefore(node, new Operation(
310                  Instruction.Load,
311                  StorageKind.ConstantBuffer,
312                  vertexCount,
313                  new[] { Const(vertexInfoCbBinding), Const((int)VertexInfoBufferField.VertexCounts), Const(0) }));
314  
315              Operand vertexIndex = Local();
316              node.List.AddBefore(node, new Operation(
317                  Instruction.Load,
318                  StorageKind.Input,
319                  vertexIndex,
320                  new[] { Const((int)IoVariable.GlobalId), Const(0) }));
321  
322              Operand instanceIndex = Local();
323              node.List.AddBefore(node, new Operation(
324                  Instruction.Load,
325                  StorageKind.Input,
326                  instanceIndex,
327                  new[] { Const((int)IoVariable.GlobalId), Const(1) }));
328  
329              Operand baseVertex = Local();
330              node.List.AddBefore(node, new Operation(Instruction.Multiply, baseVertex, new[] { instanceIndex, vertexCount }));
331  
332              return node.List.AddBefore(node, new Operation(Instruction.Add, dest, new[] { baseVertex, vertexIndex }));
333          }
334  
335          private static LinkedListNode<INode> GenerateInvocationId(LinkedListNode<INode> node, Operand dest)
336          {
337              return node.List.AddBefore(node, new Operation(
338                  Instruction.Load,
339                  StorageKind.Input,
340                  dest,
341                  new[] { Const((int)IoVariable.GlobalId), Const(2) }));
342          }
343  
344          private static bool TryGetOffset(ResourceManager resourceManager, Operation operation, StorageKind storageKind, out int outputOffset)
345          {
346              bool isStore = operation.Inst == Instruction.Store;
347  
348              IoVariable ioVariable = (IoVariable)operation.GetSource(0).Value;
349  
350              bool isValidOutput;
351  
352              if (ioVariable == IoVariable.UserDefined)
353              {
354                  int lastIndex = operation.SourcesCount - (isStore ? 2 : 1);
355  
356                  int location = operation.GetSource(1).Value;
357                  int component = operation.GetSource(lastIndex).Value;
358  
359                  isValidOutput = resourceManager.Reservations.TryGetOffset(storageKind, location, component, out outputOffset);
360              }
361              else
362              {
363                  if (ResourceReservations.IsVectorOrArrayVariable(ioVariable))
364                  {
365                      int component = operation.GetSource(operation.SourcesCount - (isStore ? 2 : 1)).Value;
366  
367                      isValidOutput = resourceManager.Reservations.TryGetOffset(storageKind, ioVariable, component, out outputOffset);
368                  }
369                  else
370                  {
371                      isValidOutput = resourceManager.Reservations.TryGetOffset(storageKind, ioVariable, out outputOffset);
372                  }
373              }
374  
375              return isValidOutput;
376          }
377      }
378  }