/ src / ARMeilleure / CodeGen / RegisterAllocators / HybridAllocator.cs
HybridAllocator.cs
  1  using ARMeilleure.IntermediateRepresentation;
  2  using ARMeilleure.Translation;
  3  using System;
  4  using System.Diagnostics;
  5  using System.Numerics;
  6  using System.Runtime.CompilerServices;
  7  using static ARMeilleure.IntermediateRepresentation.Operand.Factory;
  8  using static ARMeilleure.IntermediateRepresentation.Operation.Factory;
  9  
 10  namespace ARMeilleure.CodeGen.RegisterAllocators
 11  {
 12      class HybridAllocator : IRegisterAllocator
 13      {
 14          private readonly struct BlockInfo
 15          {
 16              public bool HasCall { get; }
 17  
 18              public int IntFixedRegisters { get; }
 19              public int VecFixedRegisters { get; }
 20  
 21              public BlockInfo(bool hasCall, int intFixedRegisters, int vecFixedRegisters)
 22              {
 23                  HasCall = hasCall;
 24                  IntFixedRegisters = intFixedRegisters;
 25                  VecFixedRegisters = vecFixedRegisters;
 26              }
 27          }
 28  
 29          private struct LocalInfo
 30          {
 31              public int Uses { get; set; }
 32              public int UsesAllocated { get; set; }
 33              public int Sequence { get; set; }
 34              public Operand Temp { get; set; }
 35              public Operand Register { get; set; }
 36              public Operand SpillOffset { get; set; }
 37              public OperandType Type { get; }
 38  
 39              private int _first;
 40              private int _last;
 41  
 42              public readonly bool IsBlockLocal => _first == _last;
 43  
 44              public LocalInfo(OperandType type, int uses, int blkIndex)
 45              {
 46                  Uses = uses;
 47                  Type = type;
 48  
 49                  UsesAllocated = 0;
 50                  Sequence = 0;
 51                  Temp = default;
 52                  Register = default;
 53                  SpillOffset = default;
 54  
 55                  _first = -1;
 56                  _last = -1;
 57  
 58                  SetBlockIndex(blkIndex);
 59              }
 60  
 61              public void SetBlockIndex(int blkIndex)
 62              {
 63                  if (_first == -1 || blkIndex < _first)
 64                  {
 65                      _first = blkIndex;
 66                  }
 67  
 68                  if (_last == -1 || blkIndex > _last)
 69                  {
 70                      _last = blkIndex;
 71                  }
 72              }
 73          }
 74  
        // Number of spill temporary registers that SelectSpillTemps reserves on each call
        // (RunPass calls it once for integer and once for vector registers per block).
        // NOTE(review): the name suggests this matches the maximum operand count of an IR operation; confirm.
        private const int MaxIROperands = 4;
        // The "visited" state is stored in the MSB of the local's value.
        private const ulong VisitedMask = 1ul << 63;

        // Per-block information filled by RunPass, indexed by BasicBlock.Index.
        private BlockInfo[] _blockInfo;
        // Per-local information, indexed by the low 32 bits of the local's value minus one (see GetLocalInfo).
        private LocalInfo[] _localInfo;
 81  
 82          [MethodImpl(MethodImplOptions.AggressiveInlining)]
 83          private static bool IsVisited(Operand local)
 84          {
 85              Debug.Assert(local.Kind == OperandKind.LocalVariable);
 86  
 87              return (local.GetValueUnsafe() & VisitedMask) != 0;
 88          }
 89  
 90          [MethodImpl(MethodImplOptions.AggressiveInlining)]
 91          private static void SetVisited(Operand local)
 92          {
 93              Debug.Assert(local.Kind == OperandKind.LocalVariable);
 94  
 95              local.GetValueUnsafe() |= VisitedMask;
 96          }
 97  
 98          [MethodImpl(MethodImplOptions.AggressiveInlining)]
 99          private ref LocalInfo GetLocalInfo(Operand local)
100          {
101              Debug.Assert(local.Kind == OperandKind.LocalVariable);
102              Debug.Assert(IsVisited(local), "Local variable not visited. Used before defined?");
103  
104              return ref _localInfo[(uint)local.GetValueUnsafe() - 1];
105          }
106  
        /// <summary>
        /// Replaces the local variable operands of every operation in <paramref name="cfg"/> with registers.
        /// </summary>
        /// <remarks>
        /// Two passes are made over the blocks, both walking <c>cfg.PostOrderBlocks</c> from the last entry
        /// to the first. The first pass numbers the locals and gathers per-block and per-local information.
        /// The second pass rewrites the operands, and inserts Fill and Spill operations for the locals that
        /// are kept on the stack.
        /// </remarks>
        /// <param name="cfg">Control flow graph whose operations are rewritten in place</param>
        /// <param name="stackAlloc">Allocator used to reserve stack space for locals that get no register</param>
        /// <param name="regMasks">Masks with the available and the caller saved registers</param>
        /// <returns>
        /// Result built from the masks of integer and vector registers that were used, and from the total size
        /// reported by <paramref name="stackAlloc"/>
        /// </returns>
        public AllocationResult RunPass(ControlFlowGraph cfg, StackAllocator stackAlloc, RegisterMasks regMasks)
        {
            int intUsedRegisters = 0;
            int vecUsedRegisters = 0;

            int intFreeRegisters = regMasks.IntAvailableRegisters;
            int vecFreeRegisters = regMasks.VecAvailableRegisters;

            _blockInfo = new BlockInfo[cfg.Blocks.Count];
            // Initial capacity only, the array is resized below when more locals are found.
            _localInfo = new LocalInfo[cfg.Blocks.Count * 3];

            int localInfoCount = 0;

            // First pass: number the locals, record the range of blocks where each one appears, and
            // collect the call and fixed register information of each block.
            for (int index = cfg.PostOrderBlocks.Length - 1; index >= 0; index--)
            {
                BasicBlock block = cfg.PostOrderBlocks[index];

                int intFixedRegisters = 0;
                int vecFixedRegisters = 0;

                bool hasCall = false;

                for (Operation node = block.Operations.First; node != default; node = node.ListNext)
                {
                    if (node.Instruction == Instruction.Call)
                    {
                        hasCall = true;
                    }

                    // Every local that is read on this block has its block range extended, including
                    // the base address and index of memory operands.
                    foreach (Operand source in node.SourcesUnsafe)
                    {
                        if (source.Kind == OperandKind.LocalVariable)
                        {
                            GetLocalInfo(source).SetBlockIndex(block.Index);
                        }
                        else if (source.Kind == OperandKind.Memory)
                        {
                            MemoryOperand memOp = source.GetMemory();

                            if (memOp.BaseAddress != default)
                            {
                                GetLocalInfo(memOp.BaseAddress).SetBlockIndex(block.Index);
                            }

                            if (memOp.Index != default)
                            {
                                GetLocalInfo(memOp.Index).SetBlockIndex(block.Index);
                            }
                        }
                    }

                    foreach (Operand dest in node.DestinationsUnsafe)
                    {
                        if (dest.Kind == OperandKind.LocalVariable)
                        {
                            if (IsVisited(dest))
                            {
                                GetLocalInfo(dest).SetBlockIndex(block.Index);
                            }
                            else
                            {
                                // First time this local is written: give it the next number, starting at 1.
                                // NOTE(review): NumberLocal presumably stores the number on the operand value
                                // that GetLocalInfo reads back; confirm against Operand.
                                dest.NumberLocal(++localInfoCount);

                                if (localInfoCount > _localInfo.Length)
                                {
                                    Array.Resize(ref _localInfo, localInfoCount * 2);
                                }

                                // Must be marked before GetLocalInfo is called, as it asserts the visited state.
                                SetVisited(dest);
                                GetLocalInfo(dest) = new LocalInfo(dest.Type, UsesCount(dest), block.Index);
                            }
                        }
                        else if (dest.Kind == OperandKind.Register)
                        {
                            // Registers written directly by the block are recorded as fixed, one bit per
                            // register index.
                            if (dest.Type.IsInteger())
                            {
                                intFixedRegisters |= 1 << dest.GetRegister().Index;
                            }
                            else
                            {
                                vecFixedRegisters |= 1 << dest.GetRegister().Index;
                            }
                        }
                    }
                }

                _blockInfo[block.Index] = new BlockInfo(hasCall, intFixedRegisters, vecFixedRegisters);
            }

            // Incremented once per operation processed on the second pass, used to tell whether the
            // temp stored on a local was chosen for the operation currently being processed.
            int sequence = 0;

            // Second pass: replace the local operands with registers or spill temps.
            for (int index = cfg.PostOrderBlocks.Length - 1; index >= 0; index--)
            {
                BasicBlock block = cfg.PostOrderBlocks[index];

                ref BlockInfo blkInfo = ref _blockInfo[block.Index];

                // The fixed registers of the block are never handed out to locals.
                int intLocalFreeRegisters = intFreeRegisters & ~blkInfo.IntFixedRegisters;
                int vecLocalFreeRegisters = vecFreeRegisters & ~blkInfo.VecFixedRegisters;

                // Caller saved registers only matter on blocks that have a call.
                int intCallerSavedRegisters = blkInfo.HasCall ? regMasks.IntCallerSavedRegisters : 0;
                int vecCallerSavedRegisters = blkInfo.HasCall ? regMasks.VecCallerSavedRegisters : 0;

                // Spill temps are taken from the non-fixed caller saved registers first, and then from
                // the free registers of the block.
                int intSpillTempRegisters = SelectSpillTemps(
                    intCallerSavedRegisters & ~blkInfo.IntFixedRegisters,
                    intLocalFreeRegisters);
                int vecSpillTempRegisters = SelectSpillTemps(
                    vecCallerSavedRegisters & ~blkInfo.VecFixedRegisters,
                    vecLocalFreeRegisters);

                // Neither the spill temps nor the caller saved registers can be assigned to locals.
                intLocalFreeRegisters &= ~(intSpillTempRegisters | intCallerSavedRegisters);
                vecLocalFreeRegisters &= ~(vecSpillTempRegisters | vecCallerSavedRegisters);

                for (Operation node = block.Operations.First; node != default; node = node.ListNext)
                {
                    // Spill temps already taken by the sources of the current operation.
                    int intLocalUse = 0;
                    int vecLocalUse = 0;

                    // Returns the operand that replaces a read of the local: the register assigned to it,
                    // or a spill temp that is loaded by a Fill inserted before the current operation.
                    Operand AllocateRegister(Operand local)
                    {
                        ref LocalInfo info = ref GetLocalInfo(local);

                        info.UsesAllocated++;

                        Debug.Assert(info.UsesAllocated <= info.Uses);

                        if (info.Register != default)
                        {
                            // This was the last use, so the register goes back to the free set.
                            if (info.UsesAllocated == info.Uses)
                            {
                                Register reg = info.Register.GetRegister();

                                if (local.Type.IsInteger())
                                {
                                    intLocalFreeRegisters |= 1 << reg.Index;
                                }
                                else
                                {
                                    vecLocalFreeRegisters |= 1 << reg.Index;
                                }
                            }

                            return info.Register;
                        }
                        else
                        {
                            Operand temp = info.Temp;

                            // A new temp is picked unless one was already chosen for this same operation.
                            if (temp == default || info.Sequence != sequence)
                            {
                                temp = local.Type.IsInteger()
                                    ? GetSpillTemp(local, intSpillTempRegisters, ref intLocalUse)
                                    : GetSpillTemp(local, vecSpillTempRegisters, ref vecLocalUse);

                                info.Sequence = sequence;
                                info.Temp = temp;
                            }

                            Operation fillOp = Operation(Instruction.Fill, temp, info.SpillOffset);

                            block.Operations.AddBefore(node, fillOp);

                            return temp;
                        }
                    }

                    bool folded = false;

                    // If operation is a copy of a local and that local is living on the stack, we turn the copy into
                    // a fill, instead of inserting a fill before it.
                    if (node.Instruction == Instruction.Copy)
                    {
                        Operand source = node.GetSource(0);

                        if (source.Kind == OperandKind.LocalVariable)
                        {
                            ref LocalInfo info = ref GetLocalInfo(source);

                            if (info.Register == default)
                            {
                                Operation fillOp = Operation(Instruction.Fill, node.Destination, info.SpillOffset);

                                block.Operations.AddBefore(node, fillOp);
                                block.Operations.Remove(node);

                                // The fill takes the place of the copy, its destination is handled below.
                                node = fillOp;

                                folded = true;
                            }
                        }
                    }

                    // The sources of a folded copy were already dealt with by the fill.
                    if (!folded)
                    {
                        foreach (ref Operand source in node.SourcesUnsafe)
                        {
                            if (source.Kind == OperandKind.LocalVariable)
                            {
                                source = AllocateRegister(source);
                            }
                            else if (source.Kind == OperandKind.Memory)
                            {
                                MemoryOperand memOp = source.GetMemory();

                                if (memOp.BaseAddress != default)
                                {
                                    memOp.BaseAddress = AllocateRegister(memOp.BaseAddress);
                                }

                                if (memOp.Index != default)
                                {
                                    memOp.Index = AllocateRegister(memOp.Index);
                                }
                            }
                        }
                    }

                    // Spill temps already taken by the destinations of the current operation.
                    int intLocalAsg = 0;
                    int vecLocalAsg = 0;

                    foreach (ref Operand dest in node.DestinationsUnsafe)
                    {
                        if (dest.Kind != OperandKind.LocalVariable)
                        {
                            continue;
                        }

                        ref LocalInfo info = ref GetLocalInfo(dest);

                        // Nothing was counted for this local yet, so this is where its location is decided.
                        if (info.UsesAllocated == 0)
                        {
                            int mask = dest.Type.IsInteger()
                                ? intLocalFreeRegisters
                                : vecLocalFreeRegisters;

                            // Only locals confined to a single block get a register, and only while one
                            // is free. The lowest free register is the one selected.
                            if (info.IsBlockLocal && mask != 0)
                            {
                                int selectedReg = BitOperations.TrailingZeroCount(mask);

                                info.Register = Register(selectedReg, info.Type.ToRegisterType(), info.Type);

                                if (dest.Type.IsInteger())
                                {
                                    intLocalFreeRegisters &= ~(1 << selectedReg);
                                    intUsedRegisters |= 1 << selectedReg;
                                }
                                else
                                {
                                    vecLocalFreeRegisters &= ~(1 << selectedReg);
                                    vecUsedRegisters |= 1 << selectedReg;
                                }
                            }
                            else
                            {
                                // All other locals live on the stack.
                                info.Register = default;
                                info.SpillOffset = Const(stackAlloc.Allocate(dest.Type.GetSizeInBytes()));
                            }
                        }

                        info.UsesAllocated++;

                        Debug.Assert(info.UsesAllocated <= info.Uses);

                        if (info.Register != default)
                        {
                            dest = info.Register;
                        }
                        else
                        {
                            Operand temp = info.Temp;

                            // A new temp is picked unless one was already chosen for this same operation.
                            if (temp == default || info.Sequence != sequence)
                            {
                                temp = dest.Type.IsInteger()
                                    ? GetSpillTemp(dest, intSpillTempRegisters, ref intLocalAsg)
                                    : GetSpillTemp(dest, vecSpillTempRegisters, ref vecLocalAsg);

                                info.Sequence = sequence;
                                info.Temp = temp;
                            }

                            dest = temp;

                            // The result is written to the temp, and stored on the stack right after.
                            Operation spillOp = Operation(Instruction.Spill, default, info.SpillOffset, temp);

                            block.Operations.AddAfter(node, spillOp);

                            // Moving to the spill makes the outer loop continue after it, and makes any
                            // further spill of this operation be inserted after this one.
                            node = spillOp;
                        }
                    }

                    sequence++;

                    // The spill temps picked for this operation also count as used registers.
                    intUsedRegisters |= intLocalAsg | intLocalUse;
                    vecUsedRegisters |= vecLocalAsg | vecLocalUse;
                }
            }

            return new AllocationResult(intUsedRegisters, vecUsedRegisters, stackAlloc.TotalSize);
        }
407  
408          private static int SelectSpillTemps(int mask0, int mask1)
409          {
410              int selection = 0;
411              int count = 0;
412  
413              while (count < MaxIROperands && mask0 != 0)
414              {
415                  int mask = mask0 & -mask0;
416  
417                  selection |= mask;
418  
419                  mask0 &= ~mask;
420  
421                  count++;
422              }
423  
424              while (count < MaxIROperands && mask1 != 0)
425              {
426                  int mask = mask1 & -mask1;
427  
428                  selection |= mask;
429  
430                  mask1 &= ~mask;
431  
432                  count++;
433              }
434  
435              Debug.Assert(count == MaxIROperands, "No enough registers for spill temps.");
436  
437              return selection;
438          }
439  
440          private static Operand GetSpillTemp(Operand local, int freeMask, ref int useMask)
441          {
442              int selectedReg = BitOperations.TrailingZeroCount(freeMask & ~useMask);
443  
444              useMask |= 1 << selectedReg;
445  
446              return Register(selectedReg, local.Type.ToRegisterType(), local.Type);
447          }
448  
449          private static int UsesCount(Operand local)
450          {
451              return local.AssignmentsCount + local.UsesCount;
452          }
453      }
454  }