/ src / Ryujinx.Graphics.Shader / StructuredIr / StructuredProgramContext.cs
StructuredProgramContext.cs
  1  using Ryujinx.Graphics.Shader.IntermediateRepresentation;
  2  using Ryujinx.Graphics.Shader.Translation;
  3  using System.Collections.Generic;
  4  using System.Linq;
  5  using System.Numerics;
  6  using static Ryujinx.Graphics.Shader.StructuredIr.AstHelper;
  7  
  8  namespace Ryujinx.Graphics.Shader.StructuredIr
  9  {
 10      class StructuredProgramContext
 11      {
        // Blocks that were consumed as the tail of a do-while loop. LookForIfStatements
        // does not emit a goto for the branch of a block found in this set.
        private HashSet<BasicBlock> _loopTails;

        // Enclosing scopes saved by NewBlock, restored by EnterBlock once the
        // end index of the current scope is reached.
        private Stack<(AstBlock Block, int CurrEndIndex, int LoopEndIndex)> _blockStack;

        // Maps each IR local variable operand to the single AST operand created for it.
        private Dictionary<Operand, AstOperand> _localsMap;

        // Keyed by jump target block index: the assignment that resets the goto
        // condition temporary of that target to false.
        private Dictionary<int, AstAssignment> _gotoTempAsgs;

        // Goto statements recorded for branches that could not be turned into an "if" or loop.
        private List<GotoStatement> _gotos;

        // AST block that AddNode currently appends to.
        private AstBlock _currBlock;

        // Basic block index at which the current AST block ends.
        private int _currEndIndex;
        // End index of the innermost enclosing do-while block, or the value passed as
        // blocksCount to EnterFunction when no loop is open.
        private int _loopEndIndex;

        /// <summary>Function being built; replaced on every call to EnterFunction.</summary>
        public StructuredFunction CurrentFunction { get; private set; }

        /// <summary>Receives the IO definitions and the finished functions.</summary>
        public StructuredProgramInfo Info { get; }

        public ShaderDefinitions Definitions { get; }
        public ResourceManager ResourceManager { get; }
        public bool DebugMode { get; }
 34  
 35          public StructuredProgramContext(
 36              AttributeUsage attributeUsage,
 37              ShaderDefinitions definitions,
 38              ResourceManager resourceManager,
 39              bool debugMode)
 40          {
 41              Info = new StructuredProgramInfo();
 42  
 43              Definitions = definitions;
 44              ResourceManager = resourceManager;
 45              DebugMode = debugMode;
 46  
 47              if (definitions.GpPassthrough)
 48              {
 49                  int passthroughAttributes = attributeUsage.PassthroughAttributes;
 50                  while (passthroughAttributes != 0)
 51                  {
 52                      int index = BitOperations.TrailingZeroCount(passthroughAttributes);
 53  
 54                      Info.IoDefinitions.Add(new IoDefinition(StorageKind.Input, IoVariable.UserDefined, index));
 55  
 56                      passthroughAttributes &= ~(1 << index);
 57                  }
 58  
 59                  Info.IoDefinitions.Add(new IoDefinition(StorageKind.Input, IoVariable.Position));
 60                  Info.IoDefinitions.Add(new IoDefinition(StorageKind.Input, IoVariable.PointSize));
 61                  Info.IoDefinitions.Add(new IoDefinition(StorageKind.Input, IoVariable.ClipDistance));
 62              }
 63          }
 64  
 65          public void EnterFunction(
 66              int blocksCount,
 67              string name,
 68              AggregateType returnType,
 69              AggregateType[] inArguments,
 70              AggregateType[] outArguments)
 71          {
 72              _loopTails = new HashSet<BasicBlock>();
 73  
 74              _blockStack = new Stack<(AstBlock, int, int)>();
 75  
 76              _localsMap = new Dictionary<Operand, AstOperand>();
 77  
 78              _gotoTempAsgs = new Dictionary<int, AstAssignment>();
 79  
 80              _gotos = new List<GotoStatement>();
 81  
 82              _currBlock = new AstBlock(AstBlockType.Main);
 83  
 84              _currEndIndex = blocksCount;
 85              _loopEndIndex = blocksCount;
 86  
 87              CurrentFunction = new StructuredFunction(_currBlock, name, returnType, inArguments, outArguments);
 88          }
 89  
 90          public void LeaveFunction()
 91          {
 92              Info.Functions.Add(CurrentFunction);
 93          }
 94  
 95          public void EnterBlock(BasicBlock block)
 96          {
 97              while (_currEndIndex == block.Index)
 98              {
 99                  (_currBlock, _currEndIndex, _loopEndIndex) = _blockStack.Pop();
100              }
101  
102              if (_gotoTempAsgs.TryGetValue(block.Index, out AstAssignment gotoTempAsg))
103              {
104                  AddGotoTempReset(block, gotoTempAsg);
105              }
106  
107              LookForDoWhileStatements(block);
108          }
109  
110          public void LeaveBlock(BasicBlock block, Operation branchOp)
111          {
112              LookForIfStatements(block, branchOp);
113          }
114  
115          private void LookForDoWhileStatements(BasicBlock block)
116          {
117              // Check if we have any predecessor whose index is greater than the
118              // current block, this indicates a loop.
119              bool done = false;
120  
121              foreach (BasicBlock predecessor in block.Predecessors.OrderByDescending(x => x.Index))
122              {
123                  // If not a loop, break.
124                  if (predecessor.Index < block.Index)
125                  {
126                      break;
127                  }
128  
129                  // Check if we can create a do-while loop here (only possible if the loop end
130                  // falls inside the current scope), if not add a goto instead.
131                  if (predecessor.Index < _currEndIndex && !done)
132                  {
133                      // Create do-while loop block. We must avoid inserting a goto at the end
134                      // of the loop later, when the tail block is processed. So we add the predecessor
135                      // to a list of loop tails to prevent it from being processed later.
136                      Operation branchOp = (Operation)predecessor.GetLastOp();
137  
138                      NewBlock(AstBlockType.DoWhile, branchOp, predecessor.Index + 1);
139  
140                      _loopTails.Add(predecessor);
141  
142                      done = true;
143                  }
144                  else
145                  {
146                      // Failed to create loop. Since this block is the loop head, we reset the
147                      // goto condition variable here. The variable is always reset on the jump
148                      // target, and this block is the jump target for some loop.
149                      AddGotoTempReset(block, GetGotoTempAsg(block.Index));
150  
151                      break;
152                  }
153              }
154          }
155  
156          private void LookForIfStatements(BasicBlock block, Operation branchOp)
157          {
158              if (block.Branch == null)
159              {
160                  return;
161              }
162  
163              // We can only enclose the "if" when the branch lands before
164              // the end of the current block. If the current enclosing block
165              // is not a loop, then we can also do so if the branch lands
166              // right at the end of the current block. When it is a loop,
167              // this is not valid as the loop condition would be evaluated,
168              // and it could erroneously jump back to the start of the loop.
169              bool inRange =
170                  block.Branch.Index < _currEndIndex ||
171                 (block.Branch.Index == _currEndIndex && block.Branch.Index < _loopEndIndex);
172  
173              bool isLoop = block.Branch.Index <= block.Index;
174  
175              if (inRange && !isLoop)
176              {
177                  NewBlock(AstBlockType.If, branchOp, block.Branch.Index);
178              }
179              else if (!_loopTails.Contains(block))
180              {
181                  AstAssignment gotoTempAsg = GetGotoTempAsg(block.Branch.Index);
182  
183                  // We use DoWhile type here, as the condition should be true for
184                  // unconditional branches, or it should jump if the condition is true otherwise.
185                  IAstNode cond = GetBranchCond(AstBlockType.DoWhile, branchOp);
186  
187                  AddNode(Assign(gotoTempAsg.Destination, cond));
188  
189                  AstOperation branch = new(branchOp.Inst);
190  
191                  AddNode(branch);
192  
193                  GotoStatement gotoStmt = new(branch, gotoTempAsg, isLoop);
194  
195                  _gotos.Add(gotoStmt);
196              }
197          }
198  
199          private AstAssignment GetGotoTempAsg(int index)
200          {
201              if (_gotoTempAsgs.TryGetValue(index, out AstAssignment gotoTempAsg))
202              {
203                  return gotoTempAsg;
204              }
205  
206              AstOperand gotoTemp = NewTemp(AggregateType.Bool);
207  
208              gotoTempAsg = Assign(gotoTemp, Const(IrConsts.False));
209  
210              _gotoTempAsgs.Add(index, gotoTempAsg);
211  
212              return gotoTempAsg;
213          }
214  
215          private void AddGotoTempReset(BasicBlock block, AstAssignment gotoTempAsg)
216          {
217              // If it was already added, we don't need to add it again.
218              if (gotoTempAsg.Parent != null)
219              {
220                  return;
221              }
222  
223              AddNode(gotoTempAsg);
224  
225              // For block 0, we don't need to add the extra "reset" at the beginning,
226              // because it is already the first node to be executed on the shader,
227              // so it is reset to false by the "local" assignment anyway.
228              if (block.Index != 0)
229              {
230                  CurrentFunction.MainBlock.AddFirst(Assign(gotoTempAsg.Destination, Const(IrConsts.False)));
231              }
232          }
233  
234          private void NewBlock(AstBlockType type, Operation branchOp, int endIndex)
235          {
236              NewBlock(type, GetBranchCond(type, branchOp), endIndex);
237          }
238  
239          private void NewBlock(AstBlockType type, IAstNode cond, int endIndex)
240          {
241              AstBlock childBlock = new(type, cond);
242  
243              AddNode(childBlock);
244  
245              _blockStack.Push((_currBlock, _currEndIndex, _loopEndIndex));
246  
247              _currBlock = childBlock;
248              _currEndIndex = endIndex;
249  
250              if (type == AstBlockType.DoWhile)
251              {
252                  _loopEndIndex = endIndex;
253              }
254          }
255  
256          private IAstNode GetBranchCond(AstBlockType type, Operation branchOp)
257          {
258              IAstNode cond;
259  
260              if (branchOp.Inst == Instruction.Branch)
261              {
262                  // If the branch is not conditional, the condition is a constant.
263                  // For if it's false (always jump over, if block never executed).
264                  // For loops it's always true (always loop).
265                  cond = Const(type == AstBlockType.If ? IrConsts.False : IrConsts.True);
266              }
267              else
268              {
269                  cond = GetOperand(branchOp.GetSource(0));
270  
271                  Instruction invInst = type == AstBlockType.If
272                      ? Instruction.BranchIfTrue
273                      : Instruction.BranchIfFalse;
274  
275                  if (branchOp.Inst == invInst)
276                  {
277                      cond = new AstOperation(Instruction.LogicalNot, cond);
278                  }
279              }
280  
281              return cond;
282          }
283  
284          public void AddNode(IAstNode node)
285          {
286              _currBlock.Add(node);
287          }
288  
289          public GotoStatement[] GetGotos()
290          {
291              return _gotos.ToArray();
292          }
293  
294          public AstOperand NewTemp(AggregateType type)
295          {
296              AstOperand newTemp = Local(type);
297  
298              CurrentFunction.Locals.Add(newTemp);
299  
300              return newTemp;
301          }
302  
303          public IAstNode GetOperandOrCbLoad(Operand operand)
304          {
305              if (operand.Type == OperandType.ConstantBuffer)
306              {
307                  int cbufSlot = operand.GetCbufSlot();
308                  int cbufOffset = operand.GetCbufOffset();
309  
310                  int binding = ResourceManager.GetConstantBufferBinding(cbufSlot);
311                  int vecIndex = cbufOffset >> 2;
312                  int elemIndex = cbufOffset & 3;
313  
314                  ResourceManager.SetUsedConstantBufferBinding(binding);
315  
316                  IAstNode[] sources = new IAstNode[]
317                  {
318                      new AstOperand(OperandType.Constant, binding),
319                      new AstOperand(OperandType.Constant, 0),
320                      new AstOperand(OperandType.Constant, vecIndex),
321                      new AstOperand(OperandType.Constant, elemIndex),
322                  };
323  
324                  return new AstOperation(Instruction.Load, StorageKind.ConstantBuffer, false, sources, sources.Length);
325              }
326  
327              return GetOperand(operand);
328          }
329  
330          public AstOperand GetOperand(Operand operand)
331          {
332              if (operand == null)
333              {
334                  return null;
335              }
336  
337              if (operand.Type != OperandType.LocalVariable)
338              {
339                  return new AstOperand(operand);
340              }
341  
342              if (!_localsMap.TryGetValue(operand, out AstOperand astOperand))
343              {
344                  astOperand = new AstOperand(operand);
345  
346                  _localsMap.Add(operand, astOperand);
347  
348                  CurrentFunction.Locals.Add(astOperand);
349              }
350  
351              return astOperand;
352          }
353      }
354  }