/ src / Ryujinx.Graphics.Shader / Instructions / InstEmitFlowControl.cs
InstEmitFlowControl.cs
  1  using Ryujinx.Graphics.Shader.Decoders;
  2  using Ryujinx.Graphics.Shader.IntermediateRepresentation;
  3  using Ryujinx.Graphics.Shader.Translation;
  4  using System.Collections.Generic;
  5  using System.Linq;
  6  using static Ryujinx.Graphics.Shader.IntermediateRepresentation.OperandHelper;
  7  
  8  namespace Ryujinx.Graphics.Shader.Instructions
  9  {
 10      static partial class InstEmit
 11      {
 12          public static void Bra(EmitterContext context)
 13          {
 14              context.GetOp<InstBra>();
 15  
 16              EmitBranch(context, context.CurrBlock.Successors[^1].Address);
 17          }
 18  
 19          public static void Brk(EmitterContext context)
 20          {
 21              context.GetOp<InstBrk>();
 22  
 23              EmitBrkContSync(context);
 24          }
 25  
 26          public static void Brx(EmitterContext context)
 27          {
 28              InstBrx op = context.GetOp<InstBrx>();
 29              InstOp currOp = context.CurrOp;
 30              int startIndex = context.CurrBlock.HasNext() ? 1 : 0;
 31  
 32              if (context.CurrBlock.Successors.Count <= startIndex)
 33              {
 34                  context.TranslatorContext.GpuAccessor.Log($"Failed to find targets for BRX instruction at 0x{currOp.Address:X}.");
 35                  return;
 36              }
 37  
 38              int offset = (int)currOp.GetAbsoluteAddress();
 39  
 40              Operand address = context.IAdd(Register(op.SrcA, RegisterType.Gpr), Const(offset));
 41  
 42              var targets = context.CurrBlock.Successors.Skip(startIndex);
 43  
 44              bool allTargetsSinglePred = true;
 45              int total = context.CurrBlock.Successors.Count - startIndex;
 46              int count = 0;
 47  
 48              foreach (var target in targets.OrderBy(x => x.Address))
 49              {
 50                  if (++count < total && (target.Predecessors.Count > 1 || target.Address <= context.CurrBlock.Address))
 51                  {
 52                      allTargetsSinglePred = false;
 53                      break;
 54                  }
 55              }
 56  
 57              if (allTargetsSinglePred)
 58              {
 59                  // Chain blocks, each target block will check if the BRX target address
 60                  // matches its own address, if not, it jumps to the next target which will do the same check,
 61                  // until it reaches the last possible target, which executed unconditionally.
 62                  // We can only do this if the BRX block is the only predecessor of all target blocks.
 63                  // Additionally, this is not supported for blocks located before the current block,
 64                  // since it will be too late to insert a label, but this is something that can be improved
 65                  // in the future if necessary.
 66  
 67                  var sortedTargets = targets.OrderBy(x => x.Address);
 68  
 69                  Block currentTarget = null;
 70                  ulong firstTargetAddress = 0;
 71  
 72                  foreach (Block nextTarget in sortedTargets)
 73                  {
 74                      if (currentTarget != null)
 75                      {
 76                          if (currentTarget.Address != nextTarget.Address)
 77                          {
 78                              context.SetBrxTarget(currentTarget.Address, address, (int)currentTarget.Address, nextTarget.Address);
 79                          }
 80                      }
 81                      else
 82                      {
 83                          firstTargetAddress = nextTarget.Address;
 84                      }
 85  
 86                      currentTarget = nextTarget;
 87                  }
 88  
 89                  context.Branch(context.GetLabel(firstTargetAddress));
 90              }
 91              else
 92              {
 93                  // Emit the branches sequentially.
 94                  // This generates slightly worse code, but should work for all cases.
 95  
 96                  var sortedTargets = targets.OrderByDescending(x => x.Address);
 97                  ulong lastTargetAddress = ulong.MaxValue;
 98  
 99                  count = 0;
100  
101                  foreach (Block target in sortedTargets)
102                  {
103                      Operand label = context.GetLabel(target.Address);
104  
105                      if (++count < total)
106                      {
107                          if (target.Address != lastTargetAddress)
108                          {
109                              context.BranchIfTrue(label, context.ICompareEqual(address, Const((int)target.Address)));
110                          }
111  
112                          lastTargetAddress = target.Address;
113                      }
114                      else
115                      {
116                          context.Branch(label);
117                      }
118                  }
119              }
120          }
121  
122          public static void Cal(EmitterContext context)
123          {
124              context.GetOp<InstCal>();
125  
126              DecodedFunction function = context.Program.GetFunctionByAddress(context.CurrOp.GetAbsoluteAddress());
127  
128              if (function.IsCompilerGenerated)
129              {
130                  switch (function.Type)
131                  {
132                      case FunctionType.BuiltInFSIBegin:
133                          context.FSIBegin();
134                          break;
135                      case FunctionType.BuiltInFSIEnd:
136                          context.FSIEnd();
137                          break;
138                  }
139              }
140              else
141              {
142                  context.Call(function.Id, false);
143              }
144          }
145  
146          public static void Cont(EmitterContext context)
147          {
148              context.GetOp<InstCont>();
149  
150              EmitBrkContSync(context);
151          }
152  
153          public static void Exit(EmitterContext context)
154          {
155              InstExit op = context.GetOp<InstExit>();
156  
157              if (context.IsNonMain)
158              {
159                  context.TranslatorContext.GpuAccessor.Log("Invalid exit on non-main function.");
160                  return;
161              }
162  
163              if (op.Ccc == Ccc.T)
164              {
165                  if (context.PrepareForReturn())
166                  {
167                      context.Return();
168                  }
169              }
170              else
171              {
172                  Operand cond = GetCondition(context, op.Ccc, IrConsts.False);
173  
174                  // If the condition is always false, we don't need to do anything.
175                  if (cond.Type != OperandType.Constant || cond.Value != IrConsts.False)
176                  {
177                      Operand lblSkip = Label();
178                      context.BranchIfFalse(lblSkip, cond);
179  
180                      if (context.PrepareForReturn())
181                      {
182                          context.Return();
183                      }
184  
185                      context.MarkLabel(lblSkip);
186                  }
187              }
188          }
189  
190          public static void Kil(EmitterContext context)
191          {
192              context.GetOp<InstKil>();
193  
194              context.Discard();
195          }
196  
197          public static void Pbk(EmitterContext context)
198          {
199              context.GetOp<InstPbk>();
200  
201              EmitPbkPcntSsy(context);
202          }
203  
204          public static void Pcnt(EmitterContext context)
205          {
206              context.GetOp<InstPcnt>();
207  
208              EmitPbkPcntSsy(context);
209          }
210  
211          public static void Ret(EmitterContext context)
212          {
213              context.GetOp<InstRet>();
214  
215              if (context.IsNonMain)
216              {
217                  context.Return();
218              }
219              else
220              {
221                  context.TranslatorContext.GpuAccessor.Log("Invalid return on main function.");
222              }
223          }
224  
225          public static void Ssy(EmitterContext context)
226          {
227              context.GetOp<InstSsy>();
228  
229              EmitPbkPcntSsy(context);
230          }
231  
232          public static void Sync(EmitterContext context)
233          {
234              context.GetOp<InstSync>();
235  
236              EmitBrkContSync(context);
237          }
238  
239          private static void EmitPbkPcntSsy(EmitterContext context)
240          {
241              var consumers = context.CurrBlock.PushOpCodes.First(x => x.Op.Address == context.CurrOp.Address).Consumers;
242  
243              foreach (KeyValuePair<Block, Operand> kv in consumers)
244              {
245                  Block consumerBlock = kv.Key;
246                  Operand local = kv.Value;
247  
248                  int id = consumerBlock.SyncTargets[context.CurrOp.Address].PushOpId;
249  
250                  context.Copy(local, Const(id));
251              }
252          }
253  
254          private static void EmitBrkContSync(EmitterContext context)
255          {
256              var targets = context.CurrBlock.SyncTargets;
257  
258              if (targets.Count == 1)
259              {
260                  // If we have only one target, then the SSY/PBK is basically
261                  // a branch, we can produce better codegen for this case.
262                  EmitBranch(context, targets.Values.First().PushOpInfo.Op.GetAbsoluteAddress());
263              }
264              else
265              {
266                  // TODO: Support CC here as well (condition).
267                  foreach (SyncTarget target in targets.Values)
268                  {
269                      PushOpInfo pushOpInfo = target.PushOpInfo;
270  
271                      Operand label = context.GetLabel(pushOpInfo.Op.GetAbsoluteAddress());
272                      Operand local = pushOpInfo.Consumers[context.CurrBlock];
273  
274                      context.BranchIfTrue(label, context.ICompareEqual(local, Const(target.PushOpId)));
275                  }
276              }
277          }
278  
279          private static void EmitBranch(EmitterContext context, ulong address)
280          {
281              InstOp op = context.CurrOp;
282              InstConditional opCond = new(op.RawOpCode);
283  
284              // If we're branching to the next instruction, then the branch
285              // is useless and we can ignore it.
286              if (address == op.Address + 8)
287              {
288                  return;
289              }
290  
291              Operand label = context.GetLabel(address);
292  
293              Operand pred = Register(opCond.Pred, RegisterType.Predicate);
294  
295              if (opCond.Ccc != Ccc.T)
296              {
297                  Operand cond = GetCondition(context, opCond.Ccc);
298  
299                  if (opCond.Pred == RegisterConsts.PredicateTrueIndex)
300                  {
301                      pred = cond;
302                  }
303                  else if (opCond.PredInv)
304                  {
305                      pred = context.BitwiseAnd(context.BitwiseNot(pred), cond);
306                  }
307                  else
308                  {
309                      pred = context.BitwiseAnd(pred, cond);
310                  }
311  
312                  context.BranchIfTrue(label, pred);
313              }
314              else if (opCond.Pred == RegisterConsts.PredicateTrueIndex)
315              {
316                  context.Branch(label);
317              }
318              else if (opCond.PredInv)
319              {
320                  context.BranchIfFalse(label, pred);
321              }
322              else
323              {
324                  context.BranchIfTrue(label, pred);
325              }
326          }
327      }
328  }