HybridAllocator.cs
using ARMeilleure.IntermediateRepresentation;
using ARMeilleure.Translation;
using System;
using System.Diagnostics;
using System.Numerics;
using System.Runtime.CompilerServices;
using static ARMeilleure.IntermediateRepresentation.Operand.Factory;
using static ARMeilleure.IntermediateRepresentation.Operation.Factory;

namespace ARMeilleure.CodeGen.RegisterAllocators
{
    /// <summary>
    /// Register allocator that works one basic block at a time. A local whose definitions and uses were all
    /// recorded in a single block gets a register when one is free in that block; every other local gets a
    /// stack slot and is accessed through spill temporaries, with Fill/Spill operations inserted around
    /// each access.
    /// </summary>
    class HybridAllocator : IRegisterAllocator
    {
        /// <summary>
        /// Per-block facts gathered by the first pass of <see cref="RunPass"/>.
        /// </summary>
        private readonly struct BlockInfo
        {
            /// <summary>True when the block contains at least one Call instruction.</summary>
            public bool HasCall { get; }

            /// <summary>Bit mask of integer registers that appear explicitly as a destination in the block.</summary>
            public int IntFixedRegisters { get; }

            /// <summary>Bit mask of non-integer registers that appear explicitly as a destination in the block.</summary>
            public int VecFixedRegisters { get; }

            public BlockInfo(bool hasCall, int intFixedRegisters, int vecFixedRegisters)
            {
                HasCall = hasCall;
                IntFixedRegisters = intFixedRegisters;
                VecFixedRegisters = vecFixedRegisters;
            }
        }

        /// <summary>
        /// Allocation state tracked for each local variable.
        /// </summary>
        private struct LocalInfo
        {
            /// <summary>Total number of occurrences of the local (assignments plus uses, see UsesCount).</summary>
            public int Uses { get; set; }

            /// <summary>Number of occurrences already processed by the allocation pass.</summary>
            public int UsesAllocated { get; set; }

            /// <summary>Sequence number of the operation for which <see cref="Temp"/> was last selected.</summary>
            public int Sequence { get; set; }

            /// <summary>Spill temporary register last selected for this local.</summary>
            public Operand Temp { get; set; }

            /// <summary>Register assigned to the local, or default when the local lives on the stack.</summary>
            public Operand Register { get; set; }

            /// <summary>Constant operand holding the stack offset allocated for the local, when spilled.</summary>
            public Operand SpillOffset { get; set; }

            /// <summary>Operand type of the local.</summary>
            public OperandType Type { get; }

            // Smallest and largest block index passed to SetBlockIndex so far (-1 means none yet).
            private int _first;
            private int _last;

            /// <summary>True when every block index recorded for this local is the same.</summary>
            public readonly bool IsBlockLocal => _first == _last;

            public LocalInfo(OperandType type, int uses, int blkIndex)
            {
                Uses = uses;
                Type = type;

                UsesAllocated = 0;
                Sequence = 0;
                Temp = default;
                Register = default;
                SpillOffset = default;

                _first = -1;
                _last = -1;

                SetBlockIndex(blkIndex);
            }

            /// <summary>
            /// Records that the local appears in the block with the given index, widening the
            /// [first, last] range of block indices if needed.
            /// </summary>
            public void SetBlockIndex(int blkIndex)
            {
                if (_first == -1 || blkIndex < _first)
                {
                    _first = blkIndex;
                }

                if (_last == -1 || blkIndex > _last)
                {
                    _last = blkIndex;
                }
            }
        }

        // Number of distinct spill temporaries reserved per register class for each block.
        private const int MaxIROperands = 4;
        // The "visited" state is stored in the MSB of the local's value.
        private const ulong VisitedMask = 1ul << 63;

        private BlockInfo[] _blockInfo;
        private LocalInfo[] _localInfo;

        /// <summary>
        /// Returns true when the visited bit is set on the local's value.
        /// </summary>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static bool IsVisited(Operand local)
        {
            Debug.Assert(local.Kind == OperandKind.LocalVariable);

            return (local.GetValueUnsafe() & VisitedMask) != 0;
        }

        /// <summary>
        /// Sets the visited bit on the local's value.
        /// </summary>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static void SetVisited(Operand local)
        {
            Debug.Assert(local.Kind == OperandKind.LocalVariable);

            local.GetValueUnsafe() |= VisitedMask;
        }

        /// <summary>
        /// Returns a reference to the <see cref="LocalInfo"/> slot of a local that has already been visited.
        /// </summary>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private ref LocalInfo GetLocalInfo(Operand local)
        {
            Debug.Assert(local.Kind == OperandKind.LocalVariable);
            Debug.Assert(IsVisited(local), "Local variable not visited. Used before defined?");

            // The cast to uint keeps only the low 32 bits of the value (dropping the visited bit). That is
            // presumably the 1-based number stored by NumberLocal in RunPass; hence the "- 1".
            return ref _localInfo[(uint)local.GetValueUnsafe() - 1];
        }

        /// <summary>
        /// Runs the allocator over every block of the control flow graph, rewriting local variable operands
        /// into registers and inserting Fill/Spill operations for locals kept on the stack.
        /// </summary>
        /// <param name="cfg">Control flow graph whose operations are rewritten in place.</param>
        /// <param name="stackAlloc">Stack allocator used to reserve a slot for each spilled local.</param>
        /// <param name="regMasks">Available and caller-saved register masks for both register classes.</param>
        /// <returns>
        /// The masks of integer and vector registers that were used, plus the total stack size reported by
        /// <paramref name="stackAlloc"/>.
        /// </returns>
        public AllocationResult RunPass(ControlFlowGraph cfg, StackAllocator stackAlloc, RegisterMasks regMasks)
        {
            int intUsedRegisters = 0;
            int vecUsedRegisters = 0;

            int intFreeRegisters = regMasks.IntAvailableRegisters;
            int vecFreeRegisters = regMasks.VecAvailableRegisters;

            _blockInfo = new BlockInfo[cfg.Blocks.Count];
            _localInfo = new LocalInfo[cfg.Blocks.Count * 3];

            int localInfoCount = 0;

            // First pass (reverse post-order): number every local at its first definition, record the blocks
            // in which each local appears, and collect the per-block call/fixed-register information.
            for (int index = cfg.PostOrderBlocks.Length - 1; index >= 0; index--)
            {
                BasicBlock block = cfg.PostOrderBlocks[index];

                int intFixedRegisters = 0;
                int vecFixedRegisters = 0;

                bool hasCall = false;

                for (Operation node = block.Operations.First; node != default; node = node.ListNext)
                {
                    if (node.Instruction == Instruction.Call)
                    {
                        hasCall = true;
                    }

                    foreach (Operand source in node.SourcesUnsafe)
                    {
                        if (source.Kind == OperandKind.LocalVariable)
                        {
                            GetLocalInfo(source).SetBlockIndex(block.Index);
                        }
                        else if (source.Kind == OperandKind.Memory)
                        {
                            MemoryOperand memOp = source.GetMemory();

                            if (memOp.BaseAddress != default)
                            {
                                GetLocalInfo(memOp.BaseAddress).SetBlockIndex(block.Index);
                            }

                            if (memOp.Index != default)
                            {
                                GetLocalInfo(memOp.Index).SetBlockIndex(block.Index);
                            }
                        }
                    }

                    foreach (Operand dest in node.DestinationsUnsafe)
                    {
                        if (dest.Kind == OperandKind.LocalVariable)
                        {
                            if (IsVisited(dest))
                            {
                                GetLocalInfo(dest).SetBlockIndex(block.Index);
                            }
                            else
                            {
                                // First definition seen for this local: give it the next 1-based number.
                                dest.NumberLocal(++localInfoCount);

                                if (localInfoCount > _localInfo.Length)
                                {
                                    Array.Resize(ref _localInfo, localInfoCount * 2);
                                }

                                SetVisited(dest);
                                GetLocalInfo(dest) = new LocalInfo(dest.Type, UsesCount(dest), block.Index);
                            }
                        }
                        else if (dest.Kind == OperandKind.Register)
                        {
                            // Registers written explicitly in this block must not be handed out to locals.
                            if (dest.Type.IsInteger())
                            {
                                intFixedRegisters |= 1 << dest.GetRegister().Index;
                            }
                            else
                            {
                                vecFixedRegisters |= 1 << dest.GetRegister().Index;
                            }
                        }
                    }
                }

                _blockInfo[block.Index] = new BlockInfo(hasCall, intFixedRegisters, vecFixedRegisters);
            }

            // Incremented once per processed operation; lets a local reuse the spill temp already selected
            // for it within the same operation.
            int sequence = 0;

            // Second pass (reverse post-order): assign registers / stack slots and rewrite the operands.
            for (int index = cfg.PostOrderBlocks.Length - 1; index >= 0; index--)
            {
                BasicBlock block = cfg.PostOrderBlocks[index];

                ref BlockInfo blkInfo = ref _blockInfo[block.Index];

                int intLocalFreeRegisters = intFreeRegisters & ~blkInfo.IntFixedRegisters;
                int vecLocalFreeRegisters = vecFreeRegisters & ~blkInfo.VecFixedRegisters;

                // In blocks containing a call, caller-saved registers are withheld from locals (presumably
                // because the call may clobber them) and are offered first as spill temporaries.
                int intCallerSavedRegisters = blkInfo.HasCall ? regMasks.IntCallerSavedRegisters : 0;
                int vecCallerSavedRegisters = blkInfo.HasCall ? regMasks.VecCallerSavedRegisters : 0;

                int intSpillTempRegisters = SelectSpillTemps(
                    intCallerSavedRegisters & ~blkInfo.IntFixedRegisters,
                    intLocalFreeRegisters);
                int vecSpillTempRegisters = SelectSpillTemps(
                    vecCallerSavedRegisters & ~blkInfo.VecFixedRegisters,
                    vecLocalFreeRegisters);

                intLocalFreeRegisters &= ~(intSpillTempRegisters | intCallerSavedRegisters);
                vecLocalFreeRegisters &= ~(vecSpillTempRegisters | vecCallerSavedRegisters);

                for (Operation node = block.Operations.First; node != default; node = node.ListNext)
                {
                    // Spill temps already taken by the sources of the current operation.
                    int intLocalUse = 0;
                    int vecLocalUse = 0;

                    // Resolves a source local to its register, or to a spill temp preceded by a Fill.
                    Operand AllocateRegister(Operand local)
                    {
                        ref LocalInfo info = ref GetLocalInfo(local);

                        info.UsesAllocated++;

                        Debug.Assert(info.UsesAllocated <= info.Uses);

                        if (info.Register != default)
                        {
                            // Last occurrence of the local: its register becomes free again for this block.
                            if (info.UsesAllocated == info.Uses)
                            {
                                Register reg = info.Register.GetRegister();

                                if (local.Type.IsInteger())
                                {
                                    intLocalFreeRegisters |= 1 << reg.Index;
                                }
                                else
                                {
                                    vecLocalFreeRegisters |= 1 << reg.Index;
                                }
                            }

                            return info.Register;
                        }
                        else
                        {
                            Operand temp = info.Temp;

                            // Pick a new temp unless one was already selected for this very operation.
                            if (temp == default || info.Sequence != sequence)
                            {
                                temp = local.Type.IsInteger()
                                    ? GetSpillTemp(local, intSpillTempRegisters, ref intLocalUse)
                                    : GetSpillTemp(local, vecSpillTempRegisters, ref vecLocalUse);

                                info.Sequence = sequence;
                                info.Temp = temp;
                            }

                            Operation fillOp = Operation(Instruction.Fill, temp, info.SpillOffset);

                            block.Operations.AddBefore(node, fillOp);

                            return temp;
                        }
                    }

                    bool folded = false;

                    // If operation is a copy of a local and that local is living on the stack, we turn the copy into
                    // a fill, instead of inserting a fill before it.
                    if (node.Instruction == Instruction.Copy)
                    {
                        Operand source = node.GetSource(0);

                        if (source.Kind == OperandKind.LocalVariable)
                        {
                            ref LocalInfo info = ref GetLocalInfo(source);

                            if (info.Register == default)
                            {
                                Operation fillOp = Operation(Instruction.Fill, node.Destination, info.SpillOffset);

                                block.Operations.AddBefore(node, fillOp);
                                block.Operations.Remove(node);

                                node = fillOp;

                                folded = true;
                            }
                        }
                    }

                    if (!folded)
                    {
                        foreach (ref Operand source in node.SourcesUnsafe)
                        {
                            if (source.Kind == OperandKind.LocalVariable)
                            {
                                source = AllocateRegister(source);
                            }
                            else if (source.Kind == OperandKind.Memory)
                            {
                                MemoryOperand memOp = source.GetMemory();

                                if (memOp.BaseAddress != default)
                                {
                                    memOp.BaseAddress = AllocateRegister(memOp.BaseAddress);
                                }

                                if (memOp.Index != default)
                                {
                                    memOp.Index = AllocateRegister(memOp.Index);
                                }
                            }
                        }
                    }

                    // Spill temps already taken by the destinations of the current operation.
                    int intLocalAsg = 0;
                    int vecLocalAsg = 0;

                    foreach (ref Operand dest in node.DestinationsUnsafe)
                    {
                        if (dest.Kind != OperandKind.LocalVariable)
                        {
                            continue;
                        }

                        ref LocalInfo info = ref GetLocalInfo(dest);

                        // First occurrence processed for this local: decide between a register and the stack.
                        if (info.UsesAllocated == 0)
                        {
                            int mask = dest.Type.IsInteger()
                                ? intLocalFreeRegisters
                                : vecLocalFreeRegisters;

                            if (info.IsBlockLocal && mask != 0)
                            {
                                int selectedReg = BitOperations.TrailingZeroCount(mask);

                                info.Register = Register(selectedReg, info.Type.ToRegisterType(), info.Type);

                                if (dest.Type.IsInteger())
                                {
                                    intLocalFreeRegisters &= ~(1 << selectedReg);
                                    intUsedRegisters |= 1 << selectedReg;
                                }
                                else
                                {
                                    vecLocalFreeRegisters &= ~(1 << selectedReg);
                                    vecUsedRegisters |= 1 << selectedReg;
                                }
                            }
                            else
                            {
                                info.Register = default;
                                info.SpillOffset = Const(stackAlloc.Allocate(dest.Type.GetSizeInBytes()));
                            }
                        }

                        info.UsesAllocated++;

                        Debug.Assert(info.UsesAllocated <= info.Uses);

                        if (info.Register != default)
                        {
                            dest = info.Register;
                        }
                        else
                        {
                            Operand temp = info.Temp;

                            // Pick a new temp unless one was already selected for this very operation.
                            if (temp == default || info.Sequence != sequence)
                            {
                                temp = dest.Type.IsInteger()
                                    ? GetSpillTemp(dest, intSpillTempRegisters, ref intLocalAsg)
                                    : GetSpillTemp(dest, vecSpillTempRegisters, ref vecLocalAsg);

                                info.Sequence = sequence;
                                info.Temp = temp;
                            }

                            dest = temp;

                            Operation spillOp = Operation(Instruction.Spill, default, info.SpillOffset, temp);

                            block.Operations.AddAfter(node, spillOp);

                            // Continue from the inserted spill so it is not processed as a regular operation
                            // and any further spill is added after it.
                            node = spillOp;
                        }
                    }

                    sequence++;

                    intUsedRegisters |= intLocalAsg | intLocalUse;
                    vecUsedRegisters |= vecLocalAsg | vecLocalUse;
                }
            }

            return new AllocationResult(intUsedRegisters, vecUsedRegisters, stackAlloc.TotalSize);
        }

        /// <summary>
        /// Selects up to <see cref="MaxIROperands"/> distinct registers to serve as spill temporaries,
        /// taking the lowest set bits of <paramref name="mask0"/> first and then, if more are needed,
        /// the lowest set bits of <paramref name="mask1"/>.
        /// </summary>
        /// <param name="mask0">Preferred candidate registers.</param>
        /// <param name="mask1">Fallback candidate registers.</param>
        /// <returns>Bit mask of the selected registers.</returns>
        private static int SelectSpillTemps(int mask0, int mask1)
        {
            int selection = 0;
            int count = 0;

            while (count < MaxIROperands && mask0 != 0)
            {
                // Isolate the lowest set bit.
                int mask = mask0 & -mask0;

                selection |= mask;

                mask0 &= ~mask;

                count++;
            }

            // The two masks may overlap. Drop registers already selected from mask0 so that the same
            // register is not counted twice, which would yield fewer distinct temps than reported.
            mask1 &= ~selection;

            while (count < MaxIROperands && mask1 != 0)
            {
                int mask = mask1 & -mask1;

                selection |= mask;

                mask1 &= ~mask;

                count++;
            }

            Debug.Assert(count == MaxIROperands, "Not enough registers for spill temps.");

            return selection;
        }

        /// <summary>
        /// Picks the lowest-numbered register of <paramref name="freeMask"/> that is not yet set in
        /// <paramref name="useMask"/>, marks it as used and returns it typed like <paramref name="local"/>.
        /// </summary>
        /// <param name="local">Local whose type determines the register type.</param>
        /// <param name="freeMask">Bit mask of the spill temporaries available for the register class.</param>
        /// <param name="useMask">Bit mask of the temporaries already taken; updated with the selection.</param>
        /// <returns>Register operand for the selected spill temporary.</returns>
        private static Operand GetSpillTemp(Operand local, int freeMask, ref int useMask)
        {
            // With no candidate left, TrailingZeroCount returns 32 and the shift below wraps around,
            // producing an invalid register; catch that in debug builds.
            Debug.Assert((freeMask & ~useMask) != 0, "No spill temp register left for this operation.");

            int selectedReg = BitOperations.TrailingZeroCount(freeMask & ~useMask);

            useMask |= 1 << selectedReg;

            return Register(selectedReg, local.Type.ToRegisterType(), local.Type);
        }

        /// <summary>
        /// Returns the total number of occurrences of a local: its assignments plus its uses.
        /// </summary>
        private static int UsesCount(Operand local)
        {
            return local.AssignmentsCount + local.UsesCount;
        }
    }
}