NativeSignalHandlerGenerator.cs
using ARMeilleure.IntermediateRepresentation;
using ARMeilleure.Translation;
using System;
using System.Runtime.InteropServices;
using static ARMeilleure.IntermediateRepresentation.Operand.Factory;

namespace ARMeilleure.Signal
{
    /// <summary>
    /// Builds native fault handlers with the ARMeilleure JIT: a Unix signal handler and a Windows
    /// vectored exception handler. Both check whether the faulting address lies inside one of the
    /// tracked ranges described by a native configuration struct and, if so, call that range's
    /// tracking action.
    /// </summary>
    /// <remarks>
    /// Configuration struct layout, as read by the generated code:
    /// <list type="bullet">
    /// <item><description>+0 (32-bit): offset of the fault address inside the Windows exception record.</description></item>
    /// <item><description>+4 (32-bit): offset of the write flag inside the Windows exception record.</description></item>
    /// <item><description>+8 (pointer): previously installed Unix signal handler.</description></item>
    /// <item><description>+16 (32-bit): non-zero when that previous handler takes three arguments.</description></item>
    /// <item><description>+20: <see cref="MaxTrackedRanges"/> range entries, <c>rangeStructSize</c> bytes apart.</description></item>
    /// </list>
    /// </remarks>
    public static class NativeSignalHandlerGenerator
    {
        /// <summary>Number of range entries scanned by the generated handlers.</summary>
        public const int MaxTrackedRanges = 8;

        // Offsets inside the configuration struct.
        private const int StructAddressOffset = 0;
        private const int StructWriteOffset = 4;
        private const int UnixOldSigaction = 8;
        private const int UnixOldSigaction3Arg = 16;
        private const int RangeOffset = 20;

        // Offsets inside each range entry.
        private const int RangeActiveOffset = 0;      // 32-bit, non-zero when the entry is in use
        private const int RangeAddressOffset = 4;     // 64-bit, inclusive start address
        private const int RangeEndAddressOffset = 12; // 64-bit, exclusive end address
        private const int RangeActionOffset = 20;     // 64-bit, tracking action function pointer (may be null)

        // Win32 vectored exception handler return codes, and the only exception code we handle.
        private const int EXCEPTION_CONTINUE_SEARCH = 0;
        private const int EXCEPTION_CONTINUE_EXECUTION = -1;

        private const uint EXCEPTION_ACCESS_VIOLATION = 0xc0000005;

        /// <summary>
        /// Emits a scan over the tracked ranges. The first active range with
        /// <c>start &lt;= faultAddress &lt; end</c> gets its tracking action called (when non-null) and
        /// the scan stops; a non-zero return from the action marks the fault as handled.
        /// </summary>
        /// <param name="context">Emitter receiving the generated IR.</param>
        /// <param name="signalStructPtr">Address of the native configuration struct.</param>
        /// <param name="faultAddress">I64 operand holding the faulting address.</param>
        /// <param name="isWrite">Operand holding 1 for a write access and 0 for a read.</param>
        /// <param name="rangeStructSize">Stride in bytes between consecutive range entries.</param>
        /// <returns>An I32 operand that is non-zero when the fault was handled.</returns>
        /// <remarks>
        /// The tracking action is called as <c>action(faultAddress - start, 1, isWrite)</c>. The constant 1
        /// is presumably an access size (TODO confirm). Its return value is also handed to
        /// <see cref="GenerateFaultAddressPatchCode"/> as the address the access should be redirected to.
        /// On Windows, a fault that lands in a range but is not handled is retried through
        /// <c>WindowsPartialUnmapHandler.EmitRetryFromAccessViolation</c>.
        /// </remarks>
        private static Operand EmitGenericRegionCheck(EmitterContext context, IntPtr signalStructPtr, Operand faultAddress, Operand isWrite, int rangeStructSize)
        {
            Operand inRegionLocal = context.AllocateLocal(OperandType.I32);
            context.Copy(inRegionLocal, Const(0));

            Operand endLabel = Label();

            for (int i = 0; i < MaxTrackedRanges; i++)
            {
                ulong rangeBase = (ulong)signalStructPtr + (ulong)(RangeOffset + i * rangeStructSize);

                Operand nextLabel = Label();

                // Skip inactive entries.
                Operand isActive = context.Load(OperandType.I32, Const(rangeBase + RangeActiveOffset));

                context.BranchIfFalse(nextLabel, isActive);

                Operand rangeAddress = context.Load(OperandType.I64, Const(rangeBase + RangeAddressOffset));
                Operand rangeEndAddress = context.Load(OperandType.I64, Const(rangeBase + RangeEndAddressOffset));

                // Is the fault address within this tracked region? (end is exclusive)
                Operand inRange = context.BitwiseAnd(
                    context.ICompare(faultAddress, rangeAddress, Comparison.GreaterOrEqualUI),
                    context.ICompare(faultAddress, rangeEndAddress, Comparison.LessUI));

                // Only call tracking if in range.
                context.BranchIfFalse(nextLabel, inRange, BasicBlockFrequency.Cold);

                Operand offset = context.Subtract(faultAddress, rangeAddress);

                // Call the tracking action, with the pointer's relative offset to the base address.
                Operand trackingActionPtr = context.Load(OperandType.I64, Const(rangeBase + RangeActionOffset));

                context.Copy(inRegionLocal, Const(0));

                Operand skipActionLabel = Label();

                // Tracking action should be non-null to call it, otherwise assume false return.
                context.BranchIfFalse(skipActionLabel, trackingActionPtr);
                Operand result = context.Call(trackingActionPtr, OperandType.I64, offset, Const(1UL), isWrite);
                context.Copy(inRegionLocal, context.ICompareNotEqual(result, Const(0UL)));

                // On hosts that support it, redirect the faulting access to the returned address.
                GenerateFaultAddressPatchCode(context, faultAddress, result);

                context.MarkLabel(skipActionLabel);

                // If the tracking action returns false or does not exist, it might be an invalid access due to a partial overlap on Windows.
                if (OperatingSystem.IsWindows())
                {
                    context.BranchIfTrue(endLabel, inRegionLocal);

                    context.Copy(inRegionLocal, WindowsPartialUnmapHandler.EmitRetryFromAccessViolation(context));
                }

                // A matching range was found; later ranges are not examined.
                context.Branch(endLabel);

                context.MarkLabel(nextLabel);
            }

            context.MarkLabel(endLabel);

            return context.Copy(inRegionLocal);
        }

        /// <summary>
        /// Emits a load of <c>siginfo_t.si_addr</c> (the faulting address) from <paramref name="sigInfoPtr"/>.
        /// </summary>
        /// <param name="context">Emitter receiving the generated IR.</param>
        /// <param name="sigInfoPtr">I64 operand holding the <c>siginfo_t*</c> passed to the handler.</param>
        /// <returns>An I64 operand holding the faulting address.</returns>
        private static Operand GenerateUnixFaultAddress(EmitterContext context, Operand sigInfoPtr)
        {
            // si_addr is at offset 24 in the Darwin siginfo_t, and at offset 16 in the Linux layout.
            ulong structAddressOffset = OperatingSystem.IsMacOS() ? 24ul : 16ul;
            return context.Load(OperandType.I64, context.Add(sigInfoPtr, Const(structAddressOffset)));
        }

        /// <summary>
        /// Emits code that extracts the "access was a write" bit from the signal's machine context.
        /// </summary>
        /// <param name="context">Emitter receiving the generated IR.</param>
        /// <param name="ucontextPtr">I64 operand holding the <c>ucontext_t*</c> passed to the handler.</param>
        /// <returns>
        /// An I64 operand that is non-zero for a write access. On Arm64 this is ESR bit 6 (WnR); on X64 it is
        /// bit 1 of the page-fault error code. The value is masked, not normalized to 0/1.
        /// </returns>
        /// <exception cref="PlatformNotSupportedException">The host OS/architecture pair is not supported.</exception>
        private static Operand GenerateUnixWriteFlag(EmitterContext context, Operand ucontextPtr)
        {
            if (OperatingSystem.IsMacOS())
            {
                const ulong McontextOffset = 48; // uc_mcontext
                Operand ctxPtr = context.Load(OperandType.I64, context.Add(ucontextPtr, Const(McontextOffset)));

                if (RuntimeInformation.ProcessArchitecture == Architecture.Arm64)
                {
                    const ulong EsrOffset = 8; // __es.__esr
                    Operand esr = context.Load(OperandType.I64, context.Add(ctxPtr, Const(EsrOffset)));
                    return context.BitwiseAnd(esr, Const(0x40ul));
                }
                else if (RuntimeInformation.ProcessArchitecture == Architecture.X64)
                {
                    const ulong ErrOffset = 4; // __es.__err
                    Operand err = context.Load(OperandType.I64, context.Add(ctxPtr, Const(ErrOffset)));
                    return context.BitwiseAnd(err, Const(2ul));
                }
            }
            else if (OperatingSystem.IsLinux())
            {
                if (RuntimeInformation.ProcessArchitecture == Architecture.Arm64)
                {
                    // Walk the _aarch64_ctx record list in uc_mcontext.__reserved to find the ESR record.
                    const ulong AuxOffset = 464; // uc_mcontext.__reserved
                    const uint EsrMagic = 0x45535201;

                    Operand auxPtr = context.AllocateLocal(OperandType.I64);
                    Operand esrLocal = context.AllocateLocal(OperandType.I64);

                    Operand loopLabel = Label();
                    Operand foundLabel = Label();
                    Operand doneLabel = Label();

                    // If no ESR record is present, report the access as a read.
                    context.Copy(esrLocal, Const(0ul));
                    context.Copy(auxPtr, context.Add(ucontextPtr, Const(AuxOffset)));

                    context.MarkLabel(loopLabel);

                    // _aarch64_ctx::magic
                    Operand magic = context.Load(OperandType.I32, auxPtr);
                    // _aarch64_ctx::size
                    Operand size = context.Load(OperandType.I32, context.Add(auxPtr, Const(4ul)));

                    context.BranchIf(foundLabel, magic, Const(EsrMagic), Comparison.Equal);

                    // A zero-sized record (the list's null terminator) ends the scan; advancing by 0 would
                    // otherwise spin forever inside the signal handler.
                    context.BranchIfFalse(doneLabel, size);

                    context.Copy(auxPtr, context.Add(auxPtr, context.ZeroExtend32(OperandType.I64, size)));

                    context.Branch(loopLabel);

                    context.MarkLabel(foundLabel);

                    // esr_context::esr
                    context.Copy(esrLocal, context.Load(OperandType.I64, context.Add(auxPtr, Const(8ul))));

                    context.MarkLabel(doneLabel);

                    return context.BitwiseAnd(esrLocal, Const(0x40ul));
                }
                else if (RuntimeInformation.ProcessArchitecture == Architecture.X64)
                {
                    // ulong so the constant is I64, matching the I64 pointer it is added to.
                    const ulong ErrOffset = 192; // uc_mcontext.gregs[REG_ERR]
                    Operand err = context.Load(OperandType.I64, context.Add(ucontextPtr, Const(ErrOffset)));
                    return context.BitwiseAnd(err, Const(2ul));
                }
            }

            throw new PlatformNotSupportedException();
        }

        /// <summary>
        /// Generates a native Unix signal handler with the signature
        /// <c>void handler(int sig, siginfo_t* info, void* ucontext)</c>.
        /// </summary>
        /// <param name="signalStructPtr">Address of the native configuration struct read by the handler.</param>
        /// <param name="rangeStructSize">Size in bytes of each tracked range entry.</param>
        /// <returns>The compiled machine code of the handler.</returns>
        /// <remarks>
        /// A fault handled by a range's tracking action simply returns. Any other signal is forwarded to the
        /// previously installed handler. That handler is called with one argument or with all three,
        /// depending on the 32-bit flag at offset 16 of the configuration struct.
        /// </remarks>
        public static byte[] GenerateUnixSignalHandler(IntPtr signalStructPtr, int rangeStructSize)
        {
            EmitterContext context = new();

            // (int sig, SigInfo* sigInfo, void* ucontext)
            Operand sig = context.LoadArgument(OperandType.I32, 0);
            Operand sigInfoPtr = context.LoadArgument(OperandType.I64, 1);
            Operand ucontextPtr = context.LoadArgument(OperandType.I64, 2);

            Operand faultAddress = GenerateUnixFaultAddress(context, sigInfoPtr);
            Operand writeFlag = GenerateUnixWriteFlag(context, ucontextPtr);

            Operand isWrite = context.ICompareNotEqual(writeFlag, Const(0L)); // Normalize to 0/1.

            Operand isInRegion = EmitGenericRegionCheck(context, signalStructPtr, faultAddress, isWrite, rangeStructSize);

            Operand endLabel = Label();

            context.BranchIfTrue(endLabel, isInRegion);

            // Not handled here: chain to the previously installed handler.
            // NOTE(review): the stored handler is called unconditionally. If it can be SIG_DFL/SIG_IGN
            // (not a callable address), this jumps to an invalid address; confirm the installer prevents that.
            Operand unixOldSigaction = context.Load(OperandType.I64, Const((ulong)signalStructPtr + UnixOldSigaction));

            // The three-argument flag is 32 bits wide: the range table starts only 4 bytes later, so a
            // 64-bit load would also pick up the first range's active flag.
            Operand unixOldSigaction3Arg = context.Load(OperandType.I32, Const((ulong)signalStructPtr + UnixOldSigaction3Arg));
            Operand threeArgLabel = Label();

            context.BranchIfTrue(threeArgLabel, unixOldSigaction3Arg);

            context.Call(unixOldSigaction, OperandType.None, sig);
            context.Branch(endLabel);

            context.MarkLabel(threeArgLabel);

            context.Call(unixOldSigaction, OperandType.None, sig, sigInfoPtr, ucontextPtr);

            context.MarkLabel(endLabel);

            context.Return();

            ControlFlowGraph cfg = context.GetControlFlowGraph();

            OperandType[] argTypes = new OperandType[] { OperandType.I32, OperandType.I64, OperandType.I64 };

            return Compiler.Compile(cfg, argTypes, OperandType.None, CompilerOptions.HighCq, RuntimeInformation.ProcessArchitecture).Code;
        }

        /// <summary>
        /// Generates a native Windows vectored exception handler with the signature
        /// <c>LONG handler(EXCEPTION_POINTERS* exceptionInfo)</c>.
        /// </summary>
        /// <param name="signalStructPtr">Address of the native configuration struct read by the handler.</param>
        /// <param name="rangeStructSize">Size in bytes of each tracked range entry.</param>
        /// <returns>The compiled machine code of the handler.</returns>
        /// <remarks>
        /// Only <c>EXCEPTION_ACCESS_VIOLATION</c> is considered. The fault address and write flag are read
        /// from the exception record at the offsets stored at +0 and +4 of the configuration struct. The
        /// handler returns <c>EXCEPTION_CONTINUE_EXECUTION</c> when the fault was handled, and
        /// <c>EXCEPTION_CONTINUE_SEARCH</c> otherwise.
        /// </remarks>
        public static byte[] GenerateWindowsSignalHandler(IntPtr signalStructPtr, int rangeStructSize)
        {
            EmitterContext context = new();

            // (ExceptionPointers* exceptionInfo)
            Operand exceptionInfoPtr = context.LoadArgument(OperandType.I64, 0);
            // EXCEPTION_POINTERS.ExceptionRecord is the first field.
            Operand exceptionRecordPtr = context.Load(OperandType.I64, exceptionInfoPtr);

            // This handler sees many exception types; only access violations are of interest.
            Operand validExceptionLabel = Label();

            // EXCEPTION_RECORD.ExceptionCode is the first field.
            Operand exceptionCode = context.Load(OperandType.I32, exceptionRecordPtr);

            context.BranchIf(validExceptionLabel, exceptionCode, Const(EXCEPTION_ACCESS_VIOLATION), Comparison.Equal);

            context.Return(Const(EXCEPTION_CONTINUE_SEARCH)); // Don't handle this one.

            context.MarkLabel(validExceptionLabel);

            // Next, read the address of the invalid access, and whether it is a write or not.
            Operand structAddressOffset = context.Load(OperandType.I32, Const((ulong)signalStructPtr + StructAddressOffset));
            Operand structWriteOffset = context.Load(OperandType.I32, Const((ulong)signalStructPtr + StructWriteOffset));

            Operand faultAddress = context.Load(OperandType.I64, context.Add(exceptionRecordPtr, context.ZeroExtend32(OperandType.I64, structAddressOffset)));
            Operand writeFlag = context.Load(OperandType.I64, context.Add(exceptionRecordPtr, context.ZeroExtend32(OperandType.I64, structWriteOffset)));

            Operand isWrite = context.ICompareNotEqual(writeFlag, Const(0L)); // Normalize to 0/1.

            Operand isInRegion = EmitGenericRegionCheck(context, signalStructPtr, faultAddress, isWrite, rangeStructSize);

            Operand endLabel = Label();

            // If the region check result is false, then run the next vectored exception handler.
            context.BranchIfTrue(endLabel, isInRegion);

            context.Return(Const(EXCEPTION_CONTINUE_SEARCH));

            context.MarkLabel(endLabel);

            // Otherwise, return to execution.
            context.Return(Const(EXCEPTION_CONTINUE_EXECUTION));

            // Compile and return the function.
            ControlFlowGraph cfg = context.GetControlFlowGraph();

            OperandType[] argTypes = new OperandType[] { OperandType.I64 };

            return Compiler.Compile(cfg, argTypes, OperandType.I32, CompilerOptions.HighCq, RuntimeInformation.ProcessArchitecture).Code;
        }

        /// <summary>
        /// On Arm64 Linux/macOS/iOS, emits code that redirects the faulting access to
        /// <paramref name="newAddress"/>. It does so by rewriting the saved base register of the faulting
        /// instruction in the signal context, so the resumed instruction addresses
        /// <c>newAddress + (baseReg - faultAddress)</c>. Emits nothing on other hosts.
        /// </summary>
        /// <param name="context">
        /// Emitter. Argument 2 of the function being generated is read as the <c>ucontext_t*</c>, so this
        /// only makes sense inside the Unix handler.
        /// </param>
        /// <param name="faultAddress">I64 operand holding the faulting address.</param>
        /// <param name="newAddress">
        /// I64 operand returned by the tracking action. Zero ("not handled") and a value equal to
        /// <paramref name="faultAddress"/> both leave the context untouched.
        /// </param>
        private static void GenerateFaultAddressPatchCode(EmitterContext context, Operand faultAddress, Operand newAddress)
        {
            if (!SupportsFaultAddressPatchingForHost())
            {
                return;
            }

            Operand lblSkip = Label();

            // A zero result means the fault was not handled. Keep the original register state so the
            // previously installed handler sees an unmodified context.
            context.BranchIfFalse(lblSkip, newAddress);
            context.BranchIf(lblSkip, faultAddress, newAddress, Comparison.Equal);

            Operand ctxBase = context.LoadArgument(OperandType.I64, 2);
            Operand pcCtxAddress;
            ulong baseRegsOffset;

            if (OperatingSystem.IsLinux())
            {
                // uc_mcontext.pc, and uc_mcontext.regs[0]; index 31 lands on the saved sp.
                pcCtxAddress = context.Add(ctxBase, Const(440UL));
                baseRegsOffset = 184UL;
            }
            else
            {
                // macOS/iOS: uc_mcontext (offset 48) is a pointer. Then __ss.__pc, and __ss.__x[0]; index 31
                // lands on the saved __sp.
                ctxBase = context.Load(OperandType.I64, context.Add(ctxBase, Const(48UL)));

                pcCtxAddress = context.Add(ctxBase, Const(272UL));
                baseRegsOffset = 16UL;
            }

            Operand pc = context.Load(OperandType.I64, pcCtxAddress);

            // Locate the saved copy of the faulting instruction's address register.
            Operand reg = GetAddressRegisterFromArm64Instruction(context, pc);
            Operand reg64 = context.ZeroExtend32(OperandType.I64, reg);
            Operand regCtxAddress = context.Add(ctxBase, context.Add(context.ShiftLeft(reg64, Const(3)), Const(baseRegsOffset)));
            Operand regAddress = context.Load(OperandType.I64, regCtxAddress);

            // Preserve any displacement between the base register and the faulting address.
            Operand addressDelta = context.Subtract(regAddress, faultAddress);

            context.Store(regCtxAddress, context.Add(newAddress, addressDelta));

            context.MarkLabel(lblSkip);
        }

        /// <summary>
        /// Emits code that decodes the Arm64 instruction at <paramref name="pc"/> and returns the number of
        /// the register holding the memory address. For SYS-class instructions
        /// (<c>(inst &amp; 0xFFF80000) == 0xD5080000</c>, e.g. DC ZVA) this is Rt in bits [4:0]. Otherwise it is
        /// Rn in bits [9:5], the base register of a load/store.
        /// </summary>
        /// <param name="context">Emitter receiving the generated IR.</param>
        /// <param name="pc">I64 operand holding the address of the faulting instruction.</param>
        /// <returns>An I32 operand holding a register number in the range 0..31.</returns>
        private static Operand GetAddressRegisterFromArm64Instruction(EmitterContext context, Operand pc)
        {
            Operand inst = context.Load(OperandType.I32, pc);
            Operand reg = context.AllocateLocal(OperandType.I32);

            Operand isSysInst = context.ICompareEqual(context.BitwiseAnd(inst, Const(0xFFF80000)), Const(0xD5080000));

            Operand lblSys = Label();
            Operand lblEnd = Label();

            context.BranchIfTrue(lblSys, isSysInst, BasicBlockFrequency.Cold);

            // Load/store: base register Rn in bits [9:5].
            context.Copy(reg, context.BitwiseAnd(context.ShiftRightUI(inst, Const(5)), Const(0x1F)));
            context.Branch(lblEnd);

            context.MarkLabel(lblSys);
            // SYS: operand register Rt in bits [4:0].
            context.Copy(reg, context.BitwiseAnd(inst, Const(0x1F)));

            context.MarkLabel(lblEnd);

            return reg;
        }

        /// <summary>
        /// Whether the generated Unix handler can redirect faulting accesses by patching the signal
        /// context. True only on Arm64 running Linux, macOS or iOS.
        /// </summary>
        public static bool SupportsFaultAddressPatchingForHost()
        {
            return SupportsFaultAddressPatchingForHostArch() && SupportsFaultAddressPatchingForHostOs();
        }

        /// <summary>Whether the process architecture supports fault address patching (Arm64 only).</summary>
        private static bool SupportsFaultAddressPatchingForHostArch()
        {
            return RuntimeInformation.ProcessArchitecture == Architecture.Arm64;
        }

        /// <summary>Whether the host OS has a known signal-context layout for patching (Linux, macOS, iOS).</summary>
        private static bool SupportsFaultAddressPatchingForHostOs()
        {
            return OperatingSystem.IsLinux() || OperatingSystem.IsMacOS() || OperatingSystem.IsIOS();
        }
    }
}