/ src / Ryujinx.Tests / Memory / PartialUnmaps.cs
PartialUnmaps.cs
  1  using ARMeilleure.Signal;
  2  using ARMeilleure.Translation;
  3  using NUnit.Framework;
  4  using Ryujinx.Common.Memory.PartialUnmaps;
  5  using Ryujinx.Cpu;
  6  using Ryujinx.Cpu.Jit;
  7  using Ryujinx.Memory;
  8  using Ryujinx.Memory.Tracking;
  9  using System;
 10  using System.Collections.Generic;
 11  using System.Diagnostics.CodeAnalysis;
 12  using System.Runtime.CompilerServices;
 13  using System.Runtime.InteropServices;
 14  using System.Threading;
 15  
 16  namespace Ryujinx.Tests.Memory
 17  {
 18      [TestFixture]
 19      internal class PartialUnmaps
 20      {
 21          private static Translator _translator;
 22  
 23          private static (MemoryBlock virt, MemoryBlock mirror, MemoryEhMeilleure exceptionHandler) GetVirtual(ulong asSize)
 24          {
 25              MemoryAllocationFlags asFlags = MemoryAllocationFlags.Reserve | MemoryAllocationFlags.ViewCompatible;
 26  
 27              var addressSpace = new MemoryBlock(asSize, asFlags);
 28              var addressSpaceMirror = new MemoryBlock(asSize, asFlags);
 29  
 30              var tracking = new MemoryTracking(new MockVirtualMemoryManager(asSize, 0x1000), 0x1000);
 31              var exceptionHandler = new MemoryEhMeilleure(addressSpace, addressSpaceMirror, tracking);
 32  
 33              return (addressSpace, addressSpaceMirror, exceptionHandler);
 34          }
 35  
 36          private static int CountThreads(ref PartialUnmapState state)
 37          {
 38              int count = 0;
 39  
 40              ref var ids = ref state.LocalCounts.ThreadIds;
 41  
 42              for (int i = 0; i < ids.Length; i++)
 43              {
 44                  if (ids[i] != 0)
 45                  {
 46                      count++;
 47                  }
 48              }
 49  
 50              return count;
 51          }
 52  
 53          private static void EnsureTranslator()
 54          {
 55              // Create a translator, as one is needed to register the signal handler or emit methods.
 56              _translator ??= new Translator(new JitMemoryAllocator(), new MockMemoryManager(), true);
 57          }
 58  
 59          [Test]
 60          // Memory aliasing tests fail on CI at the moment.
 61          [Platform(Exclude = "MacOsX")]
 62          public void PartialUnmap([Values] bool readOnly)
 63          {
 64              // Set up an address space to test partial unmapping.
 65              // Should register the signal handler to deal with this on Windows.
 66              ulong vaSize = 0x100000;
 67  
 68              // The first 0x100000 is mapped to start. It is replaced from the center with the 0x200000 mapping.
 69              var backing = new MemoryBlock(vaSize * 2, MemoryAllocationFlags.Mirrorable);
 70  
 71              (MemoryBlock unusedMainMemory, MemoryBlock memory, MemoryEhMeilleure exceptionHandler) = GetVirtual(vaSize * 2);
 72  
 73              EnsureTranslator();
 74  
 75              ref var state = ref PartialUnmapState.GetRef();
 76  
 77              Thread testThread = null;
 78              bool shouldAccess = true;
 79  
 80              try
 81              {
 82                  // Globally reset the struct for handling partial unmap races.
 83                  PartialUnmapState.Reset();
 84                  bool error = false;
 85  
 86                  // Create a large mapping.
 87                  memory.MapView(backing, 0, 0, vaSize);
 88  
 89                  if (readOnly)
 90                  {
 91                      memory.Reprotect(0, vaSize, MemoryPermission.Read);
 92                  }
 93  
 94                  if (readOnly)
 95                  {
 96                      // Write a value to the physical memory, then try to read it repeately from virtual.
 97                      // It should not change.
 98                      testThread = new Thread(() =>
 99                      {
100                          int i = 12345;
101                          backing.Write(vaSize - 0x1000, i);
102  
103                          while (shouldAccess)
104                          {
105                              if (memory.Read<int>(vaSize - 0x1000) != i)
106                              {
107                                  error = true;
108                                  shouldAccess = false;
109                              }
110                          }
111                      });
112                  }
113                  else
114                  {
115                      // Repeatedly write and check the value on the last page of the mapping on another thread.
116                      testThread = new Thread(() =>
117                      {
118                          int i = 0;
119                          while (shouldAccess)
120                          {
121                              memory.Write(vaSize - 0x1000, i);
122                              if (memory.Read<int>(vaSize - 0x1000) != i)
123                              {
124                                  error = true;
125                                  shouldAccess = false;
126                              }
127  
128                              i++;
129                          }
130                      });
131                  }
132  
133                  testThread.Start();
134  
135                  // Create a smaller mapping, covering the larger mapping.
136                  // Immediately try to write to the part of the larger mapping that did not change.
137                  // Do this a lot, with the smaller mapping gradually increasing in size. Should not crash, data should not be lost.
138  
139                  ulong pageSize = 0x1000;
140                  int mappingExpandCount = (int)(vaSize / (pageSize * 2)) - 1;
141                  ulong vaCenter = vaSize / 2;
142  
143                  for (int i = 1; i <= mappingExpandCount; i++)
144                  {
145                      ulong start = vaCenter - (pageSize * (ulong)i);
146                      ulong size = pageSize * (ulong)i * 2;
147  
148                      ulong startPa = start + vaSize;
149  
150                      memory.MapView(backing, startPa, start, size);
151                  }
152  
153                  // On Windows, this should put unmap counts on the thread local map.
154                  if (OperatingSystem.IsWindows())
155                  {
156                      // One thread should be present on the thread local map. Trimming should remove it.
157                      Assert.AreEqual(1, CountThreads(ref state));
158                  }
159  
160                  shouldAccess = false;
161                  testThread.Join();
162  
163                  Assert.False(error);
164  
165                  string test = null;
166  
167                  try
168                  {
169                      test.IndexOf('1');
170                  }
171                  catch (NullReferenceException)
172                  {
173                      // This shouldn't freeze.
174                  }
175  
176                  if (OperatingSystem.IsWindows())
177                  {
178                      state.TrimThreads();
179  
180                      Assert.AreEqual(0, CountThreads(ref state));
181                  }
182  
183                  /*
184                  * Use this to test invalid access. Can't put this in the test suite unfortunately as invalid access crashes the test process.
185                  * memory.Reprotect(vaSize - 0x1000, 0x1000, MemoryPermission.None);
186                  * //memory.UnmapView(backing, vaSize - 0x1000, 0x1000);
187                  * memory.Read<int>(vaSize - 0x1000);
188                  */
189              }
190              finally
191              {
192                  // In case something failed, we want to ensure the test thread is dead before disposing of the memory.
193                  shouldAccess = false;
194                  testThread?.Join();
195  
196                  exceptionHandler.Dispose();
197                  unusedMainMemory.Dispose();
198                  memory.Dispose();
199                  backing.Dispose();
200              }
201          }
202  
203          [Test]
204          // Memory aliasing tests fail on CI at the moment.
205          [Platform(Exclude = "MacOsX")]
206          public unsafe void PartialUnmapNative()
207          {
208  
209              // Set up an address space to test partial unmapping.
210              // Should register the signal handler to deal with this on Windows.
211              ulong vaSize = 0x100000;
212  
213              // The first 0x100000 is mapped to start. It is replaced from the center with the 0x200000 mapping.
214              var backing = new MemoryBlock(vaSize * 2, MemoryAllocationFlags.Mirrorable);
215  
216              (MemoryBlock mainMemory, MemoryBlock unusedMirror, MemoryEhMeilleure exceptionHandler) = GetVirtual(vaSize * 2);
217  
218              EnsureTranslator();
219  
220              ref var state = ref PartialUnmapState.GetRef();
221  
222              // Create some state to be used for managing the native writing loop.
223              int stateSize = Unsafe.SizeOf<NativeWriteLoopState>();
224              var statePtr = Marshal.AllocHGlobal(stateSize);
225              Unsafe.InitBlockUnaligned((void*)statePtr, 0, (uint)stateSize);
226  
227              ref NativeWriteLoopState writeLoopState = ref Unsafe.AsRef<NativeWriteLoopState>((void*)statePtr);
228              writeLoopState.Running = 1;
229              writeLoopState.Error = 0;
230  
231              try
232              {
233                  // Globally reset the struct for handling partial unmap races.
234                  PartialUnmapState.Reset();
235  
236                  // Create a large mapping.
237                  mainMemory.MapView(backing, 0, 0, vaSize);
238  
239                  var writeFunc = TestMethods.GenerateDebugNativeWriteLoop();
240                  IntPtr writePtr = mainMemory.GetPointer(vaSize - 0x1000, 4);
241  
242                  Thread testThread = new(() =>
243                  {
244                      writeFunc(statePtr, writePtr);
245                  });
246  
247                  testThread.Start();
248  
249                  // Create a smaller mapping, covering the larger mapping.
250                  // Immediately try to write to the part of the larger mapping that did not change.
251                  // Do this a lot, with the smaller mapping gradually increasing in size. Should not crash, data should not be lost.
252  
253                  ulong pageSize = 0x1000;
254                  int mappingExpandCount = (int)(vaSize / (pageSize * 2)) - 1;
255                  ulong vaCenter = vaSize / 2;
256  
257                  for (int i = 1; i <= mappingExpandCount; i++)
258                  {
259                      ulong start = vaCenter - (pageSize * (ulong)i);
260                      ulong size = pageSize * (ulong)i * 2;
261  
262                      ulong startPa = start + vaSize;
263  
264                      mainMemory.MapView(backing, startPa, start, size);
265                  }
266  
267                  writeLoopState.Running = 0;
268                  testThread.Join();
269  
270                  Assert.False(writeLoopState.Error != 0);
271              }
272              finally
273              {
274                  Marshal.FreeHGlobal(statePtr);
275  
276                  exceptionHandler.Dispose();
277                  mainMemory.Dispose();
278                  unusedMirror.Dispose();
279                  backing.Dispose();
280              }
281          }
282  
283          [Test]
284          // Only test in Windows, as this is only used on Windows and uses Windows APIs for trimming.
285          [Platform("Win")]
286          [SuppressMessage("Interoperability", "CA1416: Validate platform compatibility")]
287          public void ThreadLocalMap()
288          {
289              PartialUnmapState.Reset();
290              ref var state = ref PartialUnmapState.GetRef();
291  
292              bool running = true;
293              var testThread = new Thread(() =>
294              {
295                  PartialUnmapState.GetRef().RetryFromAccessViolation();
296                  while (running)
297                  {
298                      Thread.Sleep(1);
299                  }
300              });
301  
302              testThread.Start();
303              Thread.Sleep(200);
304  
305              Assert.AreEqual(1, CountThreads(ref state));
306  
307              // Trimming should not remove the thread as it's still active.
308              state.TrimThreads();
309              Assert.AreEqual(1, CountThreads(ref state));
310  
311              running = false;
312  
313              testThread.Join();
314  
315              // Should trim now that it's inactive.
316              state.TrimThreads();
317              Assert.AreEqual(0, CountThreads(ref state));
318          }
319  
320          [Test]
321          // Only test in Windows, as this is only used on Windows and uses Windows APIs for trimming.
322          [Platform("Win")]
323          public unsafe void ThreadLocalMapNative()
324          {
325              EnsureTranslator();
326  
327              PartialUnmapState.Reset();
328  
329              ref var state = ref PartialUnmapState.GetRef();
330  
331              fixed (void* localMap = &state.LocalCounts)
332              {
333                  var getOrReserve = TestMethods.GenerateDebugThreadLocalMapGetOrReserve((IntPtr)localMap);
334  
335                  for (int i = 0; i < ThreadLocalMap<int>.MapSize; i++)
336                  {
337                      // Should obtain the index matching the call #.
338                      Assert.AreEqual(i, getOrReserve(i + 1, i));
339  
340                      // Check that this and all previously reserved thread IDs and struct contents are intact.
341                      for (int j = 0; j <= i; j++)
342                      {
343                          Assert.AreEqual(j + 1, state.LocalCounts.ThreadIds[j]);
344                          Assert.AreEqual(j, state.LocalCounts.Structs[j]);
345                      }
346                  }
347  
348                  // Trying to reserve again when the map is full should return -1.
349                  Assert.AreEqual(-1, getOrReserve(200, 0));
350  
351                  for (int i = 0; i < ThreadLocalMap<int>.MapSize; i++)
352                  {
353                      // Should obtain the index matching the call #, as it already exists.
354                      Assert.AreEqual(i, getOrReserve(i + 1, -1));
355  
356                      // The struct should not be reset to -1.
357                      Assert.AreEqual(i, state.LocalCounts.Structs[i]);
358                  }
359  
360                  // Clear one of the ids as if it were freed.
361                  state.LocalCounts.ThreadIds[13] = 0;
362  
363                  // GetOrReserve should now obtain and return 13.
364                  Assert.AreEqual(13, getOrReserve(300, 301));
365                  Assert.AreEqual(300, state.LocalCounts.ThreadIds[13]);
366                  Assert.AreEqual(301, state.LocalCounts.Structs[13]);
367              }
368          }
369  
370          [Test]
371          public void NativeReaderWriterLock()
372          {
373              var rwLock = new NativeReaderWriterLock();
374              var threads = new List<Thread>();
375  
376              int value = 0;
377  
378              bool running = true;
379              bool error = false;
380              int readersAllowed = 1;
381  
382              for (int i = 0; i < 5; i++)
383              {
384                  var readThread = new Thread(() =>
385                  {
386                      int count = 0;
387                      while (running)
388                      {
389                          rwLock.AcquireReaderLock();
390  
391                          int originalValue = Volatile.Read(ref value);
392  
393                          count++;
394  
395                          // Spin a bit.
396                          for (int i = 0; i < 100; i++)
397                          {
398                              if (Volatile.Read(ref readersAllowed) == 0)
399                              {
400                                  error = true;
401                                  running = false;
402                              }
403                          }
404  
405                          // Should not change while the lock is held.
406                          if (Volatile.Read(ref value) != originalValue)
407                          {
408                              error = true;
409                              running = false;
410                          }
411  
412                          rwLock.ReleaseReaderLock();
413                      }
414                  });
415  
416                  threads.Add(readThread);
417              }
418  
419              for (int i = 0; i < 2; i++)
420              {
421                  var writeThread = new Thread(() =>
422                  {
423                      int count = 0;
424                      while (running)
425                      {
426                          rwLock.AcquireReaderLock();
427                          rwLock.UpgradeToWriterLock();
428  
429                          Thread.Sleep(2);
430                          count++;
431  
432                          Interlocked.Exchange(ref readersAllowed, 0);
433  
434                          for (int i = 0; i < 10; i++)
435                          {
436                              Interlocked.Increment(ref value);
437                          }
438  
439                          Interlocked.Exchange(ref readersAllowed, 1);
440  
441                          rwLock.DowngradeFromWriterLock();
442                          rwLock.ReleaseReaderLock();
443  
444                          Thread.Sleep(1);
445                      }
446                  });
447  
448                  threads.Add(writeThread);
449              }
450  
451              foreach (var thread in threads)
452              {
453                  thread.Start();
454              }
455  
456              Thread.Sleep(1000);
457  
458              running = false;
459  
460              foreach (var thread in threads)
461              {
462                  thread.Join();
463              }
464  
465              Assert.False(error);
466          }
467      }
468  }