/ src / ARMeilleure / Signal / NativeSignalHandlerGenerator.cs
NativeSignalHandlerGenerator.cs
  1  using ARMeilleure.IntermediateRepresentation;
  2  using ARMeilleure.Translation;
  3  using System;
  4  using System.Runtime.InteropServices;
  5  using static ARMeilleure.IntermediateRepresentation.Operand.Factory;
  6  
  7  namespace ARMeilleure.Signal
  8  {
  9      public static class NativeSignalHandlerGenerator
 10      {
 11          public const int MaxTrackedRanges = 8;
 12  
 13          private const int StructAddressOffset = 0;
 14          private const int StructWriteOffset = 4;
 15          private const int UnixOldSigaction = 8;
 16          private const int UnixOldSigaction3Arg = 16;
 17          private const int RangeOffset = 20;
 18  
 19          private const int EXCEPTION_CONTINUE_SEARCH = 0;
 20          private const int EXCEPTION_CONTINUE_EXECUTION = -1;
 21  
 22          private const uint EXCEPTION_ACCESS_VIOLATION = 0xc0000005;
 23  
 24          private static Operand EmitGenericRegionCheck(EmitterContext context, IntPtr signalStructPtr, Operand faultAddress, Operand isWrite, int rangeStructSize)
 25          {
 26              Operand inRegionLocal = context.AllocateLocal(OperandType.I32);
 27              context.Copy(inRegionLocal, Const(0));
 28  
 29              Operand endLabel = Label();
 30  
 31              for (int i = 0; i < MaxTrackedRanges; i++)
 32              {
 33                  ulong rangeBaseOffset = (ulong)(RangeOffset + i * rangeStructSize);
 34  
 35                  Operand nextLabel = Label();
 36  
 37                  Operand isActive = context.Load(OperandType.I32, Const((ulong)signalStructPtr + rangeBaseOffset));
 38  
 39                  context.BranchIfFalse(nextLabel, isActive);
 40  
 41                  Operand rangeAddress = context.Load(OperandType.I64, Const((ulong)signalStructPtr + rangeBaseOffset + 4));
 42                  Operand rangeEndAddress = context.Load(OperandType.I64, Const((ulong)signalStructPtr + rangeBaseOffset + 12));
 43  
 44                  // Is the fault address within this tracked region?
 45                  Operand inRange = context.BitwiseAnd(
 46                      context.ICompare(faultAddress, rangeAddress, Comparison.GreaterOrEqualUI),
 47                      context.ICompare(faultAddress, rangeEndAddress, Comparison.LessUI));
 48  
 49                  // Only call tracking if in range.
 50                  context.BranchIfFalse(nextLabel, inRange, BasicBlockFrequency.Cold);
 51  
 52                  Operand offset = context.Subtract(faultAddress, rangeAddress);
 53  
 54                  // Call the tracking action, with the pointer's relative offset to the base address.
 55                  Operand trackingActionPtr = context.Load(OperandType.I64, Const((ulong)signalStructPtr + rangeBaseOffset + 20));
 56  
 57                  context.Copy(inRegionLocal, Const(0));
 58  
 59                  Operand skipActionLabel = Label();
 60  
 61                  // Tracking action should be non-null to call it, otherwise assume false return.
 62                  context.BranchIfFalse(skipActionLabel, trackingActionPtr);
 63                  Operand result = context.Call(trackingActionPtr, OperandType.I64, offset, Const(1UL), isWrite);
 64                  context.Copy(inRegionLocal, context.ICompareNotEqual(result, Const(0UL)));
 65  
 66                  GenerateFaultAddressPatchCode(context, faultAddress, result);
 67  
 68                  context.MarkLabel(skipActionLabel);
 69  
 70                  // If the tracking action returns false or does not exist, it might be an invalid access due to a partial overlap on Windows.
 71                  if (OperatingSystem.IsWindows())
 72                  {
 73                      context.BranchIfTrue(endLabel, inRegionLocal);
 74  
 75                      context.Copy(inRegionLocal, WindowsPartialUnmapHandler.EmitRetryFromAccessViolation(context));
 76                  }
 77  
 78                  context.Branch(endLabel);
 79  
 80                  context.MarkLabel(nextLabel);
 81              }
 82  
 83              context.MarkLabel(endLabel);
 84  
 85              return context.Copy(inRegionLocal);
 86          }
 87  
 88          private static Operand GenerateUnixFaultAddress(EmitterContext context, Operand sigInfoPtr)
 89          {
 90              ulong structAddressOffset = OperatingSystem.IsMacOS() ? 24ul : 16ul; // si_addr
 91              return context.Load(OperandType.I64, context.Add(sigInfoPtr, Const(structAddressOffset)));
 92          }
 93  
 94          private static Operand GenerateUnixWriteFlag(EmitterContext context, Operand ucontextPtr)
 95          {
 96              if (OperatingSystem.IsMacOS())
 97              {
 98                  const ulong McontextOffset = 48; // uc_mcontext
 99                  Operand ctxPtr = context.Load(OperandType.I64, context.Add(ucontextPtr, Const(McontextOffset)));
100  
101                  if (RuntimeInformation.ProcessArchitecture == Architecture.Arm64)
102                  {
103                      const ulong EsrOffset = 8; // __es.__esr
104                      Operand esr = context.Load(OperandType.I64, context.Add(ctxPtr, Const(EsrOffset)));
105                      return context.BitwiseAnd(esr, Const(0x40ul));
106                  }
107                  else if (RuntimeInformation.ProcessArchitecture == Architecture.X64)
108                  {
109                      const ulong ErrOffset = 4; // __es.__err
110                      Operand err = context.Load(OperandType.I64, context.Add(ctxPtr, Const(ErrOffset)));
111                      return context.BitwiseAnd(err, Const(2ul));
112                  }
113              }
114              else if (OperatingSystem.IsLinux())
115              {
116                  if (RuntimeInformation.ProcessArchitecture == Architecture.Arm64)
117                  {
118                      Operand auxPtr = context.AllocateLocal(OperandType.I64);
119  
120                      Operand loopLabel = Label();
121                      Operand successLabel = Label();
122  
123                      const ulong AuxOffset = 464; // uc_mcontext.__reserved
124                      const uint EsrMagic = 0x45535201;
125  
126                      context.Copy(auxPtr, context.Add(ucontextPtr, Const(AuxOffset)));
127  
128                      context.MarkLabel(loopLabel);
129  
130                      // _aarch64_ctx::magic
131                      Operand magic = context.Load(OperandType.I32, auxPtr);
132                      // _aarch64_ctx::size
133                      Operand size = context.Load(OperandType.I32, context.Add(auxPtr, Const(4ul)));
134  
135                      context.BranchIf(successLabel, magic, Const(EsrMagic), Comparison.Equal);
136  
137                      context.Copy(auxPtr, context.Add(auxPtr, context.ZeroExtend32(OperandType.I64, size)));
138  
139                      context.Branch(loopLabel);
140  
141                      context.MarkLabel(successLabel);
142  
143                      // esr_context::esr
144                      Operand esr = context.Load(OperandType.I64, context.Add(auxPtr, Const(8ul)));
145                      return context.BitwiseAnd(esr, Const(0x40ul));
146                  }
147                  else if (RuntimeInformation.ProcessArchitecture == Architecture.X64)
148                  {
149                      const int ErrOffset = 192; // uc_mcontext.gregs[REG_ERR]
150                      Operand err = context.Load(OperandType.I64, context.Add(ucontextPtr, Const(ErrOffset)));
151                      return context.BitwiseAnd(err, Const(2ul));
152                  }
153              }
154  
155              throw new PlatformNotSupportedException();
156          }
157  
158          public static byte[] GenerateUnixSignalHandler(IntPtr signalStructPtr, int rangeStructSize)
159          {
160              EmitterContext context = new();
161  
162              // (int sig, SigInfo* sigInfo, void* ucontext)
163              Operand sigInfoPtr = context.LoadArgument(OperandType.I64, 1);
164              Operand ucontextPtr = context.LoadArgument(OperandType.I64, 2);
165  
166              Operand faultAddress = GenerateUnixFaultAddress(context, sigInfoPtr);
167              Operand writeFlag = GenerateUnixWriteFlag(context, ucontextPtr);
168  
169              Operand isWrite = context.ICompareNotEqual(writeFlag, Const(0L)); // Normalize to 0/1.
170  
171              Operand isInRegion = EmitGenericRegionCheck(context, signalStructPtr, faultAddress, isWrite, rangeStructSize);
172  
173              Operand endLabel = Label();
174  
175              context.BranchIfTrue(endLabel, isInRegion);
176  
177              Operand unixOldSigaction = context.Load(OperandType.I64, Const((ulong)signalStructPtr + UnixOldSigaction));
178              Operand unixOldSigaction3Arg = context.Load(OperandType.I64, Const((ulong)signalStructPtr + UnixOldSigaction3Arg));
179              Operand threeArgLabel = Label();
180  
181              context.BranchIfTrue(threeArgLabel, unixOldSigaction3Arg);
182  
183              context.Call(unixOldSigaction, OperandType.None, context.LoadArgument(OperandType.I32, 0));
184              context.Branch(endLabel);
185  
186              context.MarkLabel(threeArgLabel);
187  
188              context.Call(unixOldSigaction,
189                  OperandType.None,
190                  context.LoadArgument(OperandType.I32, 0),
191                  sigInfoPtr,
192                  context.LoadArgument(OperandType.I64, 2)
193                  );
194  
195              context.MarkLabel(endLabel);
196  
197              context.Return();
198  
199              ControlFlowGraph cfg = context.GetControlFlowGraph();
200  
201              OperandType[] argTypes = new OperandType[] { OperandType.I32, OperandType.I64, OperandType.I64 };
202  
203              return Compiler.Compile(cfg, argTypes, OperandType.None, CompilerOptions.HighCq, RuntimeInformation.ProcessArchitecture).Code;
204          }
205  
206          public static byte[] GenerateWindowsSignalHandler(IntPtr signalStructPtr, int rangeStructSize)
207          {
208              EmitterContext context = new();
209  
210              // (ExceptionPointers* exceptionInfo)
211              Operand exceptionInfoPtr = context.LoadArgument(OperandType.I64, 0);
212              Operand exceptionRecordPtr = context.Load(OperandType.I64, exceptionInfoPtr);
213  
214              // First thing's first - this catches a number of exceptions, but we only want access violations.
215              Operand validExceptionLabel = Label();
216  
217              Operand exceptionCode = context.Load(OperandType.I32, exceptionRecordPtr);
218  
219              context.BranchIf(validExceptionLabel, exceptionCode, Const(EXCEPTION_ACCESS_VIOLATION), Comparison.Equal);
220  
221              context.Return(Const(EXCEPTION_CONTINUE_SEARCH)); // Don't handle this one.
222  
223              context.MarkLabel(validExceptionLabel);
224  
225              // Next, read the address of the invalid access, and whether it is a write or not.
226  
227              Operand structAddressOffset = context.Load(OperandType.I32, Const((ulong)signalStructPtr + StructAddressOffset));
228              Operand structWriteOffset = context.Load(OperandType.I32, Const((ulong)signalStructPtr + StructWriteOffset));
229  
230              Operand faultAddress = context.Load(OperandType.I64, context.Add(exceptionRecordPtr, context.ZeroExtend32(OperandType.I64, structAddressOffset)));
231              Operand writeFlag = context.Load(OperandType.I64, context.Add(exceptionRecordPtr, context.ZeroExtend32(OperandType.I64, structWriteOffset)));
232  
233              Operand isWrite = context.ICompareNotEqual(writeFlag, Const(0L)); // Normalize to 0/1.
234  
235              Operand isInRegion = EmitGenericRegionCheck(context, signalStructPtr, faultAddress, isWrite, rangeStructSize);
236  
237              Operand endLabel = Label();
238  
239              // If the region check result is false, then run the next vectored exception handler.
240  
241              context.BranchIfTrue(endLabel, isInRegion);
242  
243              context.Return(Const(EXCEPTION_CONTINUE_SEARCH));
244  
245              context.MarkLabel(endLabel);
246  
247              // Otherwise, return to execution.
248  
249              context.Return(Const(EXCEPTION_CONTINUE_EXECUTION));
250  
251              // Compile and return the function.
252  
253              ControlFlowGraph cfg = context.GetControlFlowGraph();
254  
255              OperandType[] argTypes = new OperandType[] { OperandType.I64 };
256  
257              return Compiler.Compile(cfg, argTypes, OperandType.I32, CompilerOptions.HighCq, RuntimeInformation.ProcessArchitecture).Code;
258          }
259  
260          private static void GenerateFaultAddressPatchCode(EmitterContext context, Operand faultAddress, Operand newAddress)
261          {
262              if (RuntimeInformation.ProcessArchitecture == Architecture.Arm64)
263              {
264                  if (SupportsFaultAddressPatchingForHostOs())
265                  {
266                      Operand lblSkip = Label();
267  
268                      context.BranchIf(lblSkip, faultAddress, newAddress, Comparison.Equal);
269  
270                      Operand ucontextPtr = context.LoadArgument(OperandType.I64, 2);
271                      Operand pcCtxAddress = default;
272                      ulong baseRegsOffset = 0;
273  
274                      if (OperatingSystem.IsLinux())
275                      {
276                          pcCtxAddress = context.Add(ucontextPtr, Const(440UL));
277                          baseRegsOffset = 184UL;
278                      }
279                      else if (OperatingSystem.IsMacOS() || OperatingSystem.IsIOS())
280                      {
281                          ucontextPtr = context.Load(OperandType.I64, context.Add(ucontextPtr, Const(48UL)));
282  
283                          pcCtxAddress = context.Add(ucontextPtr, Const(272UL));
284                          baseRegsOffset = 16UL;
285                      }
286  
287                      Operand pc = context.Load(OperandType.I64, pcCtxAddress);
288  
289                      Operand reg = GetAddressRegisterFromArm64Instruction(context, pc);
290                      Operand reg64 = context.ZeroExtend32(OperandType.I64, reg);
291                      Operand regCtxAddress = context.Add(ucontextPtr, context.Add(context.ShiftLeft(reg64, Const(3)), Const(baseRegsOffset)));
292                      Operand regAddress = context.Load(OperandType.I64, regCtxAddress);
293  
294                      Operand addressDelta = context.Subtract(regAddress, faultAddress);
295  
296                      context.Store(regCtxAddress, context.Add(newAddress, addressDelta));
297  
298                      context.MarkLabel(lblSkip);
299                  }
300              }
301          }
302  
303          private static Operand GetAddressRegisterFromArm64Instruction(EmitterContext context, Operand pc)
304          {
305              Operand inst = context.Load(OperandType.I32, pc);
306              Operand reg = context.AllocateLocal(OperandType.I32);
307  
308              Operand isSysInst = context.ICompareEqual(context.BitwiseAnd(inst, Const(0xFFF80000)), Const(0xD5080000));
309  
310              Operand lblSys = Label();
311              Operand lblEnd = Label();
312  
313              context.BranchIfTrue(lblSys, isSysInst, BasicBlockFrequency.Cold);
314  
315              context.Copy(reg, context.BitwiseAnd(context.ShiftRightUI(inst, Const(5)), Const(0x1F)));
316              context.Branch(lblEnd);
317  
318              context.MarkLabel(lblSys);
319              context.Copy(reg, context.BitwiseAnd(inst, Const(0x1F)));
320  
321              context.MarkLabel(lblEnd);
322  
323              return reg;
324          }
325  
326          public static bool SupportsFaultAddressPatchingForHost()
327          {
328              return SupportsFaultAddressPatchingForHostArch() && SupportsFaultAddressPatchingForHostOs();
329          }
330  
331          private static bool SupportsFaultAddressPatchingForHostArch()
332          {
333              return RuntimeInformation.ProcessArchitecture == Architecture.Arm64;
334          }
335  
336          private static bool SupportsFaultAddressPatchingForHostOs()
337          {
338              return OperatingSystem.IsLinux() || OperatingSystem.IsMacOS() || OperatingSystem.IsIOS();
339          }
340      }
341  }