Ssa.cs
using Ryujinx.Graphics.Shader.Decoders;
using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using System.Collections.Generic;
using static Ryujinx.Graphics.Shader.IntermediateRepresentation.OperandHelper;

namespace Ryujinx.Graphics.Shader.Translation
{
    /// <summary>
    /// Replaces register operands in a set of basic blocks with single-assignment operands,
    /// adding phi nodes where a register is given a phi by the dominance frontier walk.
    /// </summary>
    static class Ssa
    {
        // Register keys are laid out as GPRs first, then predicates, then everything else
        // (flags); this is the key at which that last group starts.
        private const int GprsAndPredsCount = RegisterConsts.GprsCount + RegisterConsts.PredsCount;

        /// <summary>
        /// Per-block bookkeeping: the operand recorded as the definition of each register,
        /// and a bit set of the registers for which the block was flagged as needing a phi.
        /// </summary>
        private class DefMap
        {
            private readonly Dictionary<Register, Operand> _map = new();

            // One bit per register key, 64 keys per array element.
            private readonly long[] _phiMasks = new long[(RegisterConsts.TotalCount + 63) / 64];

            /// <summary>Records <paramref name="operand"/> for <paramref name="reg"/> unless one is already recorded.</summary>
            /// <returns>True when the entry was added, false when the register already had one.</returns>
            public bool TryAddOperand(Register reg, Operand operand) => _map.TryAdd(reg, operand);

            /// <summary>Looks up the operand recorded for <paramref name="reg"/>.</summary>
            /// <returns>True when an entry exists.</returns>
            public bool TryGetOperand(Register reg, out Operand operand) => _map.TryGetValue(reg, out operand);

            /// <summary>Flags <paramref name="reg"/> as needing a phi in this block.</summary>
            /// <returns>True when the flag was newly set, false when it was already set.</returns>
            public bool AddPhi(Register reg)
            {
                int key = GetKeyFromRegister(reg);

                int word = key / 64;
                long bitMask = 1L << (key & 63);

                bool newlyFlagged = (_phiMasks[word] & bitMask) == 0;

                if (newlyFlagged)
                {
                    _phiMasks[word] |= bitMask;
                }

                return newlyFlagged;
            }

            /// <summary>Tells whether <paramref name="reg"/> was flagged with <see cref="AddPhi"/>.</summary>
            public bool HasPhi(Register reg)
            {
                int key = GetKeyFromRegister(reg);

                long bitMask = 1L << (key & 63);

                return (_phiMasks[key / 64] & bitMask) != 0;
            }
        }

        /// <summary>
        /// Array-backed map from register key to operand that also remembers which keys
        /// were filled, so it can be enumerated and reset without scanning the whole array.
        /// </summary>
        private class LocalDefMap
        {
            private readonly Operand[] _map = new Operand[RegisterConsts.TotalCount];

            // Keys in the order they were first added; only the first UseCount entries are valid.
            private readonly int[] _uses = new int[RegisterConsts.TotalCount];

            /// <summary>Number of distinct keys added since the last <see cref="Clear"/>.</summary>
            public int UseCount { get; private set; }

            /// <summary>Returns the operand stored for <paramref name="key"/>, or null when there is none.</summary>
            public Operand Get(int key) => _map[key];

            /// <summary>Stores <paramref name="operand"/> for <paramref name="key"/>, overwriting any previous value.</summary>
            public void Add(int key, Operand operand)
            {
                bool slotIsEmpty = _map[key] == null;

                if (slotIsEmpty)
                {
                    // First time this key is seen since the last Clear: remember it.
                    _uses[UseCount++] = key;
                }

                _map[key] = operand;
            }

            /// <summary>Returns the <paramref name="index"/>-th remembered key and the operand currently stored for it.</summary>
            public Operand GetUse(int index, out int key)
            {
                key = _uses[index];

                return _map[key];
            }

            /// <summary>Empties every slot that was filled and resets <see cref="UseCount"/> to zero.</summary>
            public void Clear()
            {
                while (UseCount > 0)
                {
                    _map[_uses[--UseCount]] = null;
                }
            }
        }

        /// <summary>An operand together with the block it was found in.</summary>
        private readonly struct Definition
        {
            public BasicBlock Block { get; }
            public Operand Local { get; }

            public Definition(BasicBlock block, Operand local) => (Block, Local) = (block, local);
        }

        /// <summary>
        /// Renames every register operand used or written by the operations of
        /// <paramref name="blocks"/>. Destinations get a fresh operand from <c>Local()</c>;
        /// sources are pointed at the matching definition, a phi, or <c>Undef()</c>.
        /// </summary>
        /// <param name="blocks">
        /// Blocks to process. Each block's <c>Index</c> is used to index an array of
        /// <c>blocks.Length</c> entries, so it is presumably the block's position in this
        /// array (not verifiable from this file).
        /// </param>
        public static void Rename(BasicBlock[] blocks)
        {
            int blockCount = blocks.Length;

            DefMap[] globalDefs = new DefMap[blockCount];

            for (int i = 0; i < blockCount; i++)
            {
                globalDefs[i] = new DefMap();
            }

            LocalDefMap localDefs = new();
            Queue<BasicBlock> pendingBlocks = new();

            // Swaps a register source for the definition made earlier in the same block, if any.
            Operand RenameLocal(Operand operand)
            {
                if (operand == null || operand.Type != OperandType.Register)
                {
                    return operand;
                }

                return localDefs.Get(GetKeyFromRegister(operand.GetRegister())) ?? operand;
            }

            // First pass: rename definitions and same-block uses, then flag phis on the
            // dominance frontiers of every block that defines a register.
            foreach (BasicBlock block in blocks)
            {
                for (LinkedListNode<INode> node = block.Operations.First; node != null; node = node.Next)
                {
                    if (node.Value is not Operation operation)
                    {
                        continue;
                    }

                    // Sources are renamed before destinations so that an operation reading
                    // and writing the same register still reads the previous definition.
                    for (int srcIndex = 0; srcIndex < operation.SourcesCount; srcIndex++)
                    {
                        Operand source = operation.GetSource(srcIndex);

                        operation.SetSource(srcIndex, RenameLocal(source));
                    }

                    for (int dstIndex = 0; dstIndex < operation.DestsCount; dstIndex++)
                    {
                        Operand dest = operation.GetDest(dstIndex);

                        if (dest == null || dest.Type != OperandType.Register)
                        {
                            continue;
                        }

                        Operand fresh = Local();

                        localDefs.Add(GetKeyFromRegister(dest.GetRegister()), fresh);

                        operation.SetDest(dstIndex, fresh);
                    }
                }

                int definedCount = localDefs.UseCount;

                for (int useIndex = 0; useIndex < definedCount; useIndex++)
                {
                    // The map holds the latest operand stored for the key, i.e. the last
                    // definition made in this block.
                    Operand lastDef = localDefs.GetUse(useIndex, out int key);

                    Register reg = GetRegisterFromKey(key);

                    globalDefs[block.Index].TryAddOperand(reg, lastDef);

                    // Worklist walk: a frontier block that is flagged for the first time
                    // has its own frontiers visited too.
                    pendingBlocks.Enqueue(block);

                    while (pendingBlocks.TryDequeue(out BasicBlock current))
                    {
                        foreach (BasicBlock frontier in current.DominanceFrontiers)
                        {
                            if (globalDefs[frontier.Index].AddPhi(reg))
                            {
                                pendingBlocks.Enqueue(frontier);
                            }
                        }
                    }
                }

                localDefs.Clear();
            }

            // Second pass: the register sources still present are the ones with no definition
            // earlier in their own block; resolve them through the per-block definition maps.
            for (int blkIndex = 0; blkIndex < blockCount; blkIndex++)
            {
                BasicBlock block = blocks[blkIndex];

                // localDefs is reused here as a per-block cache of already resolved registers.
                Operand RenameGlobal(Operand operand)
                {
                    if (operand == null || operand.Type != OperandType.Register)
                    {
                        return operand;
                    }

                    int key = GetKeyFromRegister(operand.GetRegister());

                    Operand cached = localDefs.Get(key);

                    if (cached != null)
                    {
                        return cached;
                    }

                    Operand resolved = FindDefinitionForCurr(globalDefs, block, operand.GetRegister());

                    localDefs.Add(key, resolved);

                    return resolved;
                }

                for (LinkedListNode<INode> node = block.Operations.First; node != null; node = node.Next)
                {
                    if (node.Value is not Operation operation)
                    {
                        continue;
                    }

                    for (int srcIndex = 0; srcIndex < operation.SourcesCount; srcIndex++)
                    {
                        Operand source = operation.GetSource(srcIndex);

                        operation.SetSource(srcIndex, RenameGlobal(source));
                    }
                }

                // The cache is not read again after the last block, so the final reset is skipped.
                bool isLastBlock = blkIndex == blockCount - 1;

                if (!isLastBlock)
                {
                    localDefs.Clear();
                }
            }
        }

        /// <summary>
        /// Resolves <paramref name="reg"/> for a use inside <paramref name="current"/> that has
        /// no earlier definition in that block: a phi of the block itself if one is flagged,
        /// otherwise whatever the immediate dominator chain provides, otherwise <c>Undef()</c>.
        /// </summary>
        private static Operand FindDefinitionForCurr(DefMap[] globalDefs, BasicBlock current, Register reg)
        {
            if (globalDefs[current.Index].HasPhi(reg))
            {
                return InsertPhi(globalDefs, current, reg);
            }

            // A block that is its own immediate dominator ends the chain.
            return current == current.ImmediateDominator
                ? Undef()
                : FindDefinition(globalDefs, current.ImmediateDominator, reg).Local;
        }

        /// <summary>
        /// Searches <paramref name="current"/> and then its immediate dominators for a
        /// definition of <paramref name="reg"/>. A recorded operand wins over a flagged phi
        /// in the same block; a flagged phi is materialized on demand.
        /// </summary>
        /// <returns>
        /// The definition and the block it belongs to, or <c>Undef()</c> paired with
        /// <paramref name="current"/> when nothing is found.
        /// </returns>
        private static Definition FindDefinition(DefMap[] globalDefs, BasicBlock current, Register reg)
        {
            foreach (BasicBlock candidate in SelfAndImmediateDominators(current))
            {
                DefMap candidateDefs = globalDefs[candidate.Index];

                if (candidateDefs.TryGetOperand(reg, out Operand recorded))
                {
                    return new Definition(candidate, recorded);
                }

                if (candidateDefs.HasPhi(reg))
                {
                    Operand phiLocal = InsertPhi(globalDefs, candidate, reg);

                    return new Definition(candidate, phiLocal);
                }
            }

            return new Definition(current, Undef());
        }

        /// <summary>
        /// Yields <paramref name="block"/>, then each successive immediate dominator, ending
        /// with (and including) the first block that is its own immediate dominator.
        /// </summary>
        private static IEnumerable<BasicBlock> SelfAndImmediateDominators(BasicBlock block)
        {
            BasicBlock cursor = block;

            for (; cursor != cursor.ImmediateDominator; cursor = cursor.ImmediateDominator)
            {
                yield return cursor;
            }

            yield return cursor;
        }

        /// <summary>
        /// Materializes the phi that <paramref name="block"/> was flagged to need for
        /// <paramref name="reg"/>: creates the node, places it in the block, and gives it one
        /// source per predecessor.
        /// </summary>
        /// <returns>The operand defined by the new phi.</returns>
        private static Operand InsertPhi(DefMap[] globalDefs, BasicBlock block, Register reg)
        {
            Operand phiDest = Local();

            PhiNode phiNode = new(phiDest);

            AddPhi(block, phiNode);

            // Record the phi's operand before resolving the sources: a lookup that reaches this
            // block again while the predecessors are being walked then gets a recorded definition
            // from TryGetOperand instead of materializing another phi.
            globalDefs[block.Index].TryAddOperand(reg, phiDest);

            foreach (BasicBlock predecessor in block.Predecessors)
            {
                Definition incoming = FindDefinition(globalDefs, predecessor, reg);

                phiNode.AddSource(incoming.Block, incoming.Local);
            }

            return phiDest;
        }

        /// <summary>
        /// Places <paramref name="phi"/> in <paramref name="block"/>: after the phi nodes that
        /// already sit at the start of the operation list, or at the very front when the list
        /// does not start with phi nodes.
        /// </summary>
        private static void AddPhi(BasicBlock block, PhiNode phi)
        {
            LinkedListNode<INode> anchor = block.Operations.First;

            // Advance while the following node is a phi; stops on the last one of the run.
            while (anchor?.Next?.Value is PhiNode)
            {
                anchor = anchor.Next;
            }

            if (anchor?.Value is PhiNode)
            {
                block.Operations.AddAfter(anchor, phi);
            }
            else
            {
                block.Operations.AddFirst(phi);
            }
        }

        /// <summary>
        /// Maps a register to its flat key: GPRs use their index, predicates follow the GPRs,
        /// and any other register type (flags) follows the predicates.
        /// </summary>
        private static int GetKeyFromRegister(Register reg)
        {
            return reg.Type switch
            {
                RegisterType.Gpr => reg.Index,
                RegisterType.Predicate => RegisterConsts.GprsCount + reg.Index,
                _ /* RegisterType.Flag */ => GprsAndPredsCount + reg.Index,
            };
        }

        /// <summary>Inverse of <see cref="GetKeyFromRegister"/>: rebuilds the register a flat key stands for.</summary>
        private static Register GetRegisterFromKey(int key)
        {
            if (key < RegisterConsts.GprsCount)
            {
                return new Register(key, RegisterType.Gpr);
            }

            if (key < GprsAndPredsCount)
            {
                return new Register(key - RegisterConsts.GprsCount, RegisterType.Predicate);
            }

            // Remaining keys (expected to be below RegisterConsts.TotalCount) are flags.
            return new Register(key - GprsAndPredsCount, RegisterType.Flag);
        }
    }
}