/ src / Ryujinx.Memory / Tracking / MemoryTracking.cs
MemoryTracking.cs
  1  using Ryujinx.Common.Pools;
  2  using Ryujinx.Memory.Range;
  3  using System.Collections.Generic;
  4  
  5  namespace Ryujinx.Memory.Tracking
  6  {
  7      /// <summary>
  8      /// Manages memory tracking for a given virutal/physical memory block.
  9      /// </summary>
 10      public class MemoryTracking
 11      {
 12          private readonly IVirtualMemoryManager _memoryManager;
 13          private readonly InvalidAccessHandler _invalidAccessHandler;
 14  
 15          // Only use these from within the lock.
 16          private readonly NonOverlappingRangeList<VirtualRegion> _virtualRegions;
 17          // Guest virtual regions are a subset of the normal virtual regions, with potentially different protection
 18          // and expanded area of effect on platforms that don't support misaligned page protection.
 19          private readonly NonOverlappingRangeList<VirtualRegion> _guestVirtualRegions;
 20  
 21          private readonly int _pageSize;
 22  
 23          private readonly bool _singleByteGuestTracking;
 24  
 25          /// <summary>
 26          /// This lock must be obtained when traversing or updating the region-handle hierarchy.
 27          /// It is not required when reading dirty flags.
 28          /// </summary>
 29          internal object TrackingLock = new();
 30  
 31          /// <summary>
 32          /// Create a new tracking structure for the given "physical" memory block,
 33          /// with a given "virtual" memory manager that will provide mappings and virtual memory protection.
 34          /// </summary>
 35          /// <remarks>
 36          /// If <paramref name="singleByteGuestTracking" /> is true, the memory manager must also support protection on partially
 37          /// unmapped regions without throwing exceptions or dropping protection on the mapped portion.
 38          /// </remarks>
 39          /// <param name="memoryManager">Virtual memory manager</param>
 40          /// <param name="pageSize">Page size of the virtual memory space</param>
 41          /// <param name="invalidAccessHandler">Method to call for invalid memory accesses</param>
 42          /// <param name="singleByteGuestTracking">True if the guest only signals writes for the first byte</param>
 43          public MemoryTracking(
 44              IVirtualMemoryManager memoryManager,
 45              int pageSize,
 46              InvalidAccessHandler invalidAccessHandler = null,
 47              bool singleByteGuestTracking = false)
 48          {
 49              _memoryManager = memoryManager;
 50              _pageSize = pageSize;
 51              _invalidAccessHandler = invalidAccessHandler;
 52              _singleByteGuestTracking = singleByteGuestTracking;
 53  
 54              _virtualRegions = new NonOverlappingRangeList<VirtualRegion>();
 55              _guestVirtualRegions = new NonOverlappingRangeList<VirtualRegion>();
 56          }
 57  
 58          private (ulong address, ulong size) PageAlign(ulong address, ulong size)
 59          {
 60              ulong pageMask = (ulong)_pageSize - 1;
 61              ulong rA = address & ~pageMask;
 62              ulong rS = ((address + size + pageMask) & ~pageMask) - rA;
 63              return (rA, rS);
 64          }
 65  
 66          /// <summary>
 67          /// Indicate that a virtual region has been mapped, and which physical region it has been mapped to.
 68          /// Should be called after the mapping is complete.
 69          /// </summary>
 70          /// <param name="va">Virtual memory address</param>
 71          /// <param name="size">Size to be mapped</param>
 72          public void Map(ulong va, ulong size)
 73          {
 74              // A mapping may mean we need to re-evaluate each VirtualRegion's affected area.
 75              // Find all handles that overlap with the range, we need to recalculate their physical regions
 76  
 77              lock (TrackingLock)
 78              {
 79                  ref var overlaps = ref ThreadStaticArray<VirtualRegion>.Get();
 80  
 81                  for (int type = 0; type < 2; type++)
 82                  {
 83                      NonOverlappingRangeList<VirtualRegion> regions = type == 0 ? _virtualRegions : _guestVirtualRegions;
 84  
 85                      int count = regions.FindOverlapsNonOverlapping(va, size, ref overlaps);
 86  
 87                      for (int i = 0; i < count; i++)
 88                      {
 89                          VirtualRegion region = overlaps[i];
 90  
 91                          // If the region has been fully remapped, signal that it has been mapped again.
 92                          bool remapped = _memoryManager.IsRangeMapped(region.Address, region.Size);
 93                          if (remapped)
 94                          {
 95                              region.SignalMappingChanged(true);
 96                          }
 97  
 98                          region.UpdateProtection();
 99                      }
100                  }
101              }
102          }
103  
104          /// <summary>
105          /// Indicate that a virtual region has been unmapped.
106          /// Should be called before the unmapping is complete.
107          /// </summary>
108          /// <param name="va">Virtual memory address</param>
109          /// <param name="size">Size to be unmapped</param>
110          public void Unmap(ulong va, ulong size)
111          {
112              // An unmapping may mean we need to re-evaluate each VirtualRegion's affected area.
113              // Find all handles that overlap with the range, we need to notify them that the region was unmapped.
114  
115              lock (TrackingLock)
116              {
117                  ref var overlaps = ref ThreadStaticArray<VirtualRegion>.Get();
118  
119                  for (int type = 0; type < 2; type++)
120                  {
121                      NonOverlappingRangeList<VirtualRegion> regions = type == 0 ? _virtualRegions : _guestVirtualRegions;
122  
123                      int count = regions.FindOverlapsNonOverlapping(va, size, ref overlaps);
124  
125                      for (int i = 0; i < count; i++)
126                      {
127                          VirtualRegion region = overlaps[i];
128  
129                          region.SignalMappingChanged(false);
130                      }
131                  }
132              }
133          }
134  
135          /// <summary>
136          /// Alter a tracked memory region to properly capture unaligned accesses.
137          /// For most memory manager modes, this does nothing.
138          /// </summary>
139          /// <param name="address">Original region address</param>
140          /// <param name="size">Original region size</param>
141          /// <returns>A new address and size for tracking unaligned accesses</returns>
142          internal (ulong newAddress, ulong newSize) GetUnalignedSafeRegion(ulong address, ulong size)
143          {
144              if (_singleByteGuestTracking)
145              {
146                  // The guest only signals the first byte of each memory access with the current memory manager.
147                  // To catch unaligned access properly, we need to also protect the page before the address.
148  
149                  // Assume that the address and size are already aligned.
150  
151                  return (address - (ulong)_pageSize, size + (ulong)_pageSize);
152              }
153              else
154              {
155                  return (address, size);
156              }
157          }
158  
159          /// <summary>
160          /// Get a list of virtual regions that a handle covers.
161          /// </summary>
162          /// <param name="va">Starting virtual memory address of the handle</param>
163          /// <param name="size">Size of the handle's memory region</param>
164          /// <param name="guest">True if getting handles for guest protection, false otherwise</param>
165          /// <returns>A list of virtual regions within the given range</returns>
166          internal List<VirtualRegion> GetVirtualRegionsForHandle(ulong va, ulong size, bool guest)
167          {
168              List<VirtualRegion> result = new();
169              NonOverlappingRangeList<VirtualRegion> regions = guest ? _guestVirtualRegions : _virtualRegions;
170              regions.GetOrAddRegions(result, va, size, (va, size) => new VirtualRegion(this, va, size, guest));
171  
172              return result;
173          }
174  
175          /// <summary>
176          /// Remove a virtual region from the range list. This assumes that the lock has been acquired.
177          /// </summary>
178          /// <param name="region">Region to remove</param>
179          internal void RemoveVirtual(VirtualRegion region)
180          {
181              if (region.Guest)
182              {
183                  _guestVirtualRegions.Remove(region);
184              }
185              else
186              {
187                  _virtualRegions.Remove(region);
188              }
189          }
190  
191          /// <summary>
192          /// Obtains a memory tracking handle for the given virtual region, with a specified granularity. This should be disposed when finished with.
193          /// </summary>
194          /// <param name="address">CPU virtual address of the region</param>
195          /// <param name="size">Size of the region</param>
196          /// <param name="handles">Handles to inherit state from or reuse. When none are present, provide null</param>
197          /// <param name="granularity">Desired granularity of write tracking</param>
198          /// <param name="id">Handle ID</param>
199          /// <param name="flags">Region flags</param>
200          /// <returns>The memory tracking handle</returns>
201          public MultiRegionHandle BeginGranularTracking(ulong address, ulong size, IEnumerable<IRegionHandle> handles, ulong granularity, int id, RegionFlags flags = RegionFlags.None)
202          {
203              return new MultiRegionHandle(this, address, size, handles, granularity, id, flags);
204          }
205  
206          /// <summary>
207          /// Obtains a smart memory tracking handle for the given virtual region, with a specified granularity. This should be disposed when finished with.
208          /// </summary>
209          /// <param name="address">CPU virtual address of the region</param>
210          /// <param name="size">Size of the region</param>
211          /// <param name="granularity">Desired granularity of write tracking</param>
212          /// <param name="id">Handle ID</param>
213          /// <returns>The memory tracking handle</returns>
214          public SmartMultiRegionHandle BeginSmartGranularTracking(ulong address, ulong size, ulong granularity, int id)
215          {
216              (address, size) = PageAlign(address, size);
217  
218              return new SmartMultiRegionHandle(this, address, size, granularity, id);
219          }
220  
221          /// <summary>
222          /// Obtains a memory tracking handle for the given virtual region. This should be disposed when finished with.
223          /// </summary>
224          /// <param name="address">CPU virtual address of the region</param>
225          /// <param name="size">Size of the region</param>
226          /// <param name="id">Handle ID</param>
227          /// <param name="flags">Region flags</param>
228          /// <returns>The memory tracking handle</returns>
229          public RegionHandle BeginTracking(ulong address, ulong size, int id, RegionFlags flags = RegionFlags.None)
230          {
231              var (paAddress, paSize) = PageAlign(address, size);
232  
233              lock (TrackingLock)
234              {
235                  bool mapped = _memoryManager.IsRangeMapped(address, size);
236                  RegionHandle handle = new(this, paAddress, paSize, address, size, id, flags, mapped);
237  
238                  return handle;
239              }
240          }
241  
242          /// <summary>
243          /// Obtains a memory tracking handle for the given virtual region. This should be disposed when finished with.
244          /// </summary>
245          /// <param name="address">CPU virtual address of the region</param>
246          /// <param name="size">Size of the region</param>
247          /// <param name="bitmap">The bitmap owning the dirty flag for this handle</param>
248          /// <param name="bit">The bit of this handle within the dirty flag</param>
249          /// <param name="id">Handle ID</param>
250          /// <param name="flags">Region flags</param>
251          /// <returns>The memory tracking handle</returns>
252          internal RegionHandle BeginTrackingBitmap(ulong address, ulong size, ConcurrentBitmap bitmap, int bit, int id, RegionFlags flags = RegionFlags.None)
253          {
254              var (paAddress, paSize) = PageAlign(address, size);
255  
256              lock (TrackingLock)
257              {
258                  bool mapped = _memoryManager.IsRangeMapped(address, size);
259                  RegionHandle handle = new(this, paAddress, paSize, address, size, bitmap, bit, id, flags, mapped);
260  
261                  return handle;
262              }
263          }
264  
265          /// <summary>
266          /// Signal that a virtual memory event happened at the given location.
267          /// The memory event is assumed to be triggered by guest code.
268          /// </summary>
269          /// <param name="address">Virtual address accessed</param>
270          /// <param name="size">Size of the region affected in bytes</param>
271          /// <param name="write">Whether the region was written to or read</param>
272          /// <returns>True if the event triggered any tracking regions, false otherwise</returns>
273          public bool VirtualMemoryEvent(ulong address, ulong size, bool write)
274          {
275              return VirtualMemoryEvent(address, size, write, precise: false, exemptId: null, guest: true);
276          }
277  
278          /// <summary>
279          /// Signal that a virtual memory event happened at the given location.
280          /// This can be flagged as a precise event, which will avoid reprotection and call special handlers if possible.
281          /// A precise event has an exact address and size, rather than triggering on page granularity.
282          /// </summary>
283          /// <param name="address">Virtual address accessed</param>
284          /// <param name="size">Size of the region affected in bytes</param>
285          /// <param name="write">Whether the region was written to or read</param>
286          /// <param name="precise">True if the access is precise, false otherwise</param>
287          /// <param name="exemptId">Optional ID that of the handles that should not be signalled</param>
288          /// <param name="guest">True if the access is from the guest, false otherwise</param>
289          /// <returns>True if the event triggered any tracking regions, false otherwise</returns>
290          public bool VirtualMemoryEvent(ulong address, ulong size, bool write, bool precise, int? exemptId = null, bool guest = false)
291          {
292              // Look up the virtual region using the region list.
293              // Signal up the chain to relevant handles.
294  
295              bool shouldThrow = false;
296  
297              lock (TrackingLock)
298              {
299                  ref var overlaps = ref ThreadStaticArray<VirtualRegion>.Get();
300  
301                  NonOverlappingRangeList<VirtualRegion> regions = guest ? _guestVirtualRegions : _virtualRegions;
302  
303                  int count = regions.FindOverlapsNonOverlapping(address, size, ref overlaps);
304  
305                  if (count == 0 && !precise)
306                  {
307                      if (_memoryManager.IsRangeMapped(address, size))
308                      {
309                          // TODO: There is currently the possibility that a page can be protected after its virtual region is removed.
310                          // This code handles that case when it happens, but it would be better to find out how this happens.
311                          _memoryManager.TrackingReprotect(address & ~(ulong)(_pageSize - 1), (ulong)_pageSize, MemoryPermission.ReadAndWrite, guest);
312                          return true; // This memory _should_ be mapped, so we need to try again.
313                      }
314                      else
315                      {
316                          shouldThrow = true;
317                      }
318                  }
319                  else
320                  {
321                      if (guest && _singleByteGuestTracking)
322                      {
323                          // Increase the access size to trigger handles with misaligned accesses.
324                          size += (ulong)_pageSize;
325                      }
326  
327                      for (int i = 0; i < count; i++)
328                      {
329                          VirtualRegion region = overlaps[i];
330  
331                          if (precise)
332                          {
333                              region.SignalPrecise(address, size, write, exemptId);
334                          }
335                          else
336                          {
337                              region.Signal(address, size, write, exemptId);
338                          }
339                      }
340                  }
341              }
342  
343              if (shouldThrow)
344              {
345                  _invalidAccessHandler?.Invoke(address);
346  
347                  // We can't continue - it's impossible to remove protection from the page.
348                  // Even if the access handler wants us to continue, we wouldn't be able to.
349                  throw new InvalidMemoryRegionException();
350              }
351  
352              return true;
353          }
354  
355          /// <summary>
356          /// Reprotect a given virtual region. The virtual memory manager will handle this.
357          /// </summary>
358          /// <param name="region">Region to reprotect</param>
359          /// <param name="permission">Memory permission to protect with</param>
360          /// <param name="guest">True if the protection is for guest access, false otherwise</param>
361          internal void ProtectVirtualRegion(VirtualRegion region, MemoryPermission permission, bool guest)
362          {
363              _memoryManager.TrackingReprotect(region.Address, region.Size, permission, guest);
364          }
365  
366          /// <summary>
367          /// Returns the number of virtual regions currently being tracked.
368          /// Useful for tests and metrics.
369          /// </summary>
370          /// <returns>The number of virtual regions</returns>
371          public int GetRegionCount()
372          {
373              lock (TrackingLock)
374              {
375                  return _virtualRegions.Count;
376              }
377          }
378      }
379  }