Ssa.cs
  1  using Ryujinx.Graphics.Shader.Decoders;
  2  using Ryujinx.Graphics.Shader.IntermediateRepresentation;
  3  using System.Collections.Generic;
  4  using static Ryujinx.Graphics.Shader.IntermediateRepresentation.OperandHelper;
  5  
  6  namespace Ryujinx.Graphics.Shader.Translation
  7  {
  8      static class Ssa
  9      {
 10          private const int GprsAndPredsCount = RegisterConsts.GprsCount + RegisterConsts.PredsCount;
 11  
 12          private class DefMap
 13          {
 14              private readonly Dictionary<Register, Operand> _map;
 15  
 16              private readonly long[] _phiMasks;
 17  
 18              public DefMap()
 19              {
 20                  _map = new Dictionary<Register, Operand>();
 21  
 22                  _phiMasks = new long[(RegisterConsts.TotalCount + 63) / 64];
 23              }
 24  
 25              public bool TryAddOperand(Register reg, Operand operand)
 26              {
 27                  return _map.TryAdd(reg, operand);
 28              }
 29  
 30              public bool TryGetOperand(Register reg, out Operand operand)
 31              {
 32                  return _map.TryGetValue(reg, out operand);
 33              }
 34  
 35              public bool AddPhi(Register reg)
 36              {
 37                  int key = GetKeyFromRegister(reg);
 38  
 39                  int index = key / 64;
 40                  int bit = key & 63;
 41  
 42                  long mask = 1L << bit;
 43  
 44                  if ((_phiMasks[index] & mask) != 0)
 45                  {
 46                      return false;
 47                  }
 48  
 49                  _phiMasks[index] |= mask;
 50  
 51                  return true;
 52              }
 53  
 54              public bool HasPhi(Register reg)
 55              {
 56                  int key = GetKeyFromRegister(reg);
 57  
 58                  int index = key / 64;
 59                  int bit = key & 63;
 60  
 61                  return (_phiMasks[index] & (1L << bit)) != 0;
 62              }
 63          }
 64  
 65          private class LocalDefMap
 66          {
 67              private readonly Operand[] _map;
 68              private readonly int[] _uses;
 69              public int UseCount { get; private set; }
 70  
 71              public LocalDefMap()
 72              {
 73                  _map = new Operand[RegisterConsts.TotalCount];
 74                  _uses = new int[RegisterConsts.TotalCount];
 75              }
 76  
 77              public Operand Get(int key)
 78              {
 79                  return _map[key];
 80              }
 81  
 82              public void Add(int key, Operand operand)
 83              {
 84                  if (_map[key] == null)
 85                  {
 86                      _uses[UseCount++] = key;
 87                  }
 88  
 89                  _map[key] = operand;
 90              }
 91  
 92              public Operand GetUse(int index, out int key)
 93              {
 94                  key = _uses[index];
 95  
 96                  return _map[key];
 97              }
 98  
 99              public void Clear()
100              {
101                  for (int i = 0; i < UseCount; i++)
102                  {
103                      _map[_uses[i]] = null;
104                  }
105  
106                  UseCount = 0;
107              }
108          }
109  
110          private readonly struct Definition
111          {
112              public BasicBlock Block { get; }
113              public Operand Local { get; }
114  
115              public Definition(BasicBlock block, Operand local)
116              {
117                  Block = block;
118                  Local = local;
119              }
120          }
121  
122          public static void Rename(BasicBlock[] blocks)
123          {
124              DefMap[] globalDefs = new DefMap[blocks.Length];
125              LocalDefMap localDefs = new();
126  
127              for (int blkIndex = 0; blkIndex < blocks.Length; blkIndex++)
128              {
129                  globalDefs[blkIndex] = new DefMap();
130              }
131  
132              Queue<BasicBlock> dfPhiBlocks = new();
133  
134              // First pass, get all defs and locals uses.
135              for (int blkIndex = 0; blkIndex < blocks.Length; blkIndex++)
136              {
137                  Operand RenameLocal(Operand operand)
138                  {
139                      if (operand != null && operand.Type == OperandType.Register)
140                      {
141                          Operand local = localDefs.Get(GetKeyFromRegister(operand.GetRegister()));
142  
143                          operand = local ?? operand;
144                      }
145  
146                      return operand;
147                  }
148  
149                  BasicBlock block = blocks[blkIndex];
150  
151                  LinkedListNode<INode> node = block.Operations.First;
152  
153                  while (node != null)
154                  {
155                      if (node.Value is Operation operation)
156                      {
157                          for (int index = 0; index < operation.SourcesCount; index++)
158                          {
159                              operation.SetSource(index, RenameLocal(operation.GetSource(index)));
160                          }
161  
162                          for (int index = 0; index < operation.DestsCount; index++)
163                          {
164                              Operand dest = operation.GetDest(index);
165  
166                              if (dest != null && dest.Type == OperandType.Register)
167                              {
168                                  Operand local = Local();
169  
170                                  localDefs.Add(GetKeyFromRegister(dest.GetRegister()), local);
171  
172                                  operation.SetDest(index, local);
173                              }
174                          }
175                      }
176  
177                      node = node.Next;
178                  }
179  
180                  int localUses = localDefs.UseCount;
181                  for (int index = 0; index < localUses; index++)
182                  {
183                      Operand local = localDefs.GetUse(index, out int key);
184  
185                      Register reg = GetRegisterFromKey(key);
186  
187                      globalDefs[block.Index].TryAddOperand(reg, local);
188  
189                      dfPhiBlocks.Enqueue(block);
190  
191                      while (dfPhiBlocks.TryDequeue(out BasicBlock dfPhiBlock))
192                      {
193                          foreach (BasicBlock domFrontier in dfPhiBlock.DominanceFrontiers)
194                          {
195                              if (globalDefs[domFrontier.Index].AddPhi(reg))
196                              {
197                                  dfPhiBlocks.Enqueue(domFrontier);
198                              }
199                          }
200                      }
201                  }
202  
203                  localDefs.Clear();
204              }
205  
206              // Second pass, rename variables with definitions on different blocks.
207              for (int blkIndex = 0; blkIndex < blocks.Length; blkIndex++)
208              {
209                  BasicBlock block = blocks[blkIndex];
210  
211                  Operand RenameGlobal(Operand operand)
212                  {
213                      if (operand != null && operand.Type == OperandType.Register)
214                      {
215                          int key = GetKeyFromRegister(operand.GetRegister());
216  
217                          Operand local = localDefs.Get(key);
218  
219                          if (local != null)
220                          {
221                              return local;
222                          }
223  
224                          operand = FindDefinitionForCurr(globalDefs, block, operand.GetRegister());
225  
226                          localDefs.Add(key, operand);
227                      }
228  
229                      return operand;
230                  }
231  
232                  for (LinkedListNode<INode> node = block.Operations.First; node != null; node = node.Next)
233                  {
234                      if (node.Value is Operation operation)
235                      {
236                          for (int index = 0; index < operation.SourcesCount; index++)
237                          {
238                              operation.SetSource(index, RenameGlobal(operation.GetSource(index)));
239                          }
240                      }
241                  }
242  
243                  if (blkIndex < blocks.Length - 1)
244                  {
245                      localDefs.Clear();
246                  }
247              }
248          }
249  
250          private static Operand FindDefinitionForCurr(DefMap[] globalDefs, BasicBlock current, Register reg)
251          {
252              if (globalDefs[current.Index].HasPhi(reg))
253              {
254                  return InsertPhi(globalDefs, current, reg);
255              }
256  
257              if (current != current.ImmediateDominator)
258              {
259                  return FindDefinition(globalDefs, current.ImmediateDominator, reg).Local;
260              }
261  
262              return Undef();
263          }
264  
265          private static Definition FindDefinition(DefMap[] globalDefs, BasicBlock current, Register reg)
266          {
267              foreach (BasicBlock block in SelfAndImmediateDominators(current))
268              {
269                  DefMap defMap = globalDefs[block.Index];
270  
271                  if (defMap.TryGetOperand(reg, out Operand lastDef))
272                  {
273                      return new Definition(block, lastDef);
274                  }
275  
276                  if (defMap.HasPhi(reg))
277                  {
278                      return new Definition(block, InsertPhi(globalDefs, block, reg));
279                  }
280              }
281  
282              return new Definition(current, Undef());
283          }
284  
285          private static IEnumerable<BasicBlock> SelfAndImmediateDominators(BasicBlock block)
286          {
287              while (block != block.ImmediateDominator)
288              {
289                  yield return block;
290  
291                  block = block.ImmediateDominator;
292              }
293  
294              yield return block;
295          }
296  
297          private static Operand InsertPhi(DefMap[] globalDefs, BasicBlock block, Register reg)
298          {
299              // This block has a Phi that has not been materialized yet, but that
300              // would define a new version of the variable we're looking for. We need
301              // to materialize the Phi, add all the block/operand pairs into the Phi, and
302              // then use the definition from that Phi.
303              Operand local = Local();
304  
305              PhiNode phi = new(local);
306  
307              AddPhi(block, phi);
308  
309              globalDefs[block.Index].TryAddOperand(reg, local);
310  
311              foreach (BasicBlock predecessor in block.Predecessors)
312              {
313                  Definition def = FindDefinition(globalDefs, predecessor, reg);
314  
315                  phi.AddSource(def.Block, def.Local);
316              }
317  
318              return local;
319          }
320  
321          private static void AddPhi(BasicBlock block, PhiNode phi)
322          {
323              LinkedListNode<INode> node = block.Operations.First;
324  
325              if (node != null)
326              {
327                  while (node.Next?.Value is PhiNode)
328                  {
329                      node = node.Next;
330                  }
331              }
332  
333              if (node?.Value is PhiNode)
334              {
335                  block.Operations.AddAfter(node, phi);
336              }
337              else
338              {
339                  block.Operations.AddFirst(phi);
340              }
341          }
342  
343          private static int GetKeyFromRegister(Register reg)
344          {
345              if (reg.Type == RegisterType.Gpr)
346              {
347                  return reg.Index;
348              }
349              else if (reg.Type == RegisterType.Predicate)
350              {
351                  return RegisterConsts.GprsCount + reg.Index;
352              }
353              else /* if (reg.Type == RegisterType.Flag) */
354              {
355                  return GprsAndPredsCount + reg.Index;
356              }
357          }
358  
359          private static Register GetRegisterFromKey(int key)
360          {
361              if (key < RegisterConsts.GprsCount)
362              {
363                  return new Register(key, RegisterType.Gpr);
364              }
365              else if (key < GprsAndPredsCount)
366              {
367                  return new Register(key - RegisterConsts.GprsCount, RegisterType.Predicate);
368              }
369              else /* if (key < RegisterConsts.TotalCount) */
370              {
371                  return new Register(key - GprsAndPredsCount, RegisterType.Flag);
372              }
373          }
374      }
375  }