/ src / Ryujinx.Graphics.Gpu / Shader / ShaderCacheHashTable.cs
ShaderCacheHashTable.cs
  1  using Ryujinx.Graphics.Gpu.Memory;
  2  using Ryujinx.Graphics.Gpu.Shader.HashTable;
  3  using Ryujinx.Graphics.Shader;
  4  using System;
  5  using System.Collections.Generic;
  6  
  7  namespace Ryujinx.Graphics.Gpu.Shader
  8  {
  9      /// <summary>
 10      /// Holds already cached code for a guest shader.
 11      /// </summary>
 12      struct CachedGraphicsGuestCode
 13      {
 14          public byte[] VertexACode;
 15          public byte[] VertexBCode;
 16          public byte[] TessControlCode;
 17          public byte[] TessEvaluationCode;
 18          public byte[] GeometryCode;
 19          public byte[] FragmentCode;
 20  
 21          /// <summary>
 22          /// Gets the guest code of a shader stage by its index.
 23          /// </summary>
 24          /// <param name="stageIndex">Index of the shader stage</param>
 25          /// <returns>Guest code, or null if not present</returns>
 26          public readonly byte[] GetByIndex(int stageIndex)
 27          {
 28              return stageIndex switch
 29              {
 30                  1 => TessControlCode,
 31                  2 => TessEvaluationCode,
 32                  3 => GeometryCode,
 33                  4 => FragmentCode,
 34                  _ => VertexBCode,
 35              };
 36          }
 37      }
 38  
 39      /// <summary>
 40      /// Graphics shader cache hash table.
 41      /// </summary>
 42      class ShaderCacheHashTable
 43      {
 44          /// <summary>
 45          /// Shader ID cache.
 46          /// </summary>
 47          private struct IdCache
 48          {
 49              private PartitionedHashTable<int> _cache;
 50              private int _id;
 51  
 52              /// <summary>
 53              /// Initializes the state.
 54              /// </summary>
 55              public void Initialize()
 56              {
 57                  _cache = new PartitionedHashTable<int>();
 58                  _id = 0;
 59              }
 60  
 61              /// <summary>
 62              /// Adds guest code to the cache.
 63              /// </summary>
 64              /// <remarks>
 65              /// If the code was already cached, it will just return the existing ID.
 66              /// </remarks>
 67              /// <param name="code">Code to add</param>
 68              /// <returns>Unique ID for the guest code</returns>
 69              public int Add(byte[] code)
 70              {
 71                  int id = ++_id;
 72                  int cachedId = _cache.GetOrAdd(code, id);
 73                  if (cachedId != id)
 74                  {
 75                      --_id;
 76                  }
 77  
 78                  return cachedId;
 79              }
 80  
 81              /// <summary>
 82              /// Tries to find cached guest code.
 83              /// </summary>
 84              /// <param name="dataAccessor">Code accessor used to read guest code to find a match on the hash table</param>
 85              /// <param name="id">ID of the guest code, if found</param>
 86              /// <param name="data">Cached guest code, if found</param>
 87              /// <returns>True if found, false otherwise</returns>
 88              public readonly bool TryFind(IDataAccessor dataAccessor, out int id, out byte[] data)
 89              {
 90                  return _cache.TryFindItem(dataAccessor, out id, out data);
 91              }
 92          }
 93  
 94          /// <summary>
 95          /// Guest code IDs of the guest shaders that when combined forms a single host program.
 96          /// </summary>
 97          private struct IdTable : IEquatable<IdTable>
 98          {
 99              public int VertexAId;
100              public int VertexBId;
101              public int TessControlId;
102              public int TessEvaluationId;
103              public int GeometryId;
104              public int FragmentId;
105  
106              public readonly override bool Equals(object obj)
107              {
108                  return obj is IdTable other && Equals(other);
109              }
110  
111              public readonly bool Equals(IdTable other)
112              {
113                  return other.VertexAId == VertexAId &&
114                         other.VertexBId == VertexBId &&
115                         other.TessControlId == TessControlId &&
116                         other.TessEvaluationId == TessEvaluationId &&
117                         other.GeometryId == GeometryId &&
118                         other.FragmentId == FragmentId;
119              }
120  
121              public readonly override int GetHashCode()
122              {
123                  return HashCode.Combine(VertexAId, VertexBId, TessControlId, TessEvaluationId, GeometryId, FragmentId);
124              }
125          }
126  
127          private IdCache _vertexACache;
128          private IdCache _vertexBCache;
129          private IdCache _tessControlCache;
130          private IdCache _tessEvaluationCache;
131          private IdCache _geometryCache;
132          private IdCache _fragmentCache;
133  
134          private readonly Dictionary<IdTable, ShaderSpecializationList> _shaderPrograms;
135  
136          /// <summary>
137          /// Creates a new graphics shader cache hash table.
138          /// </summary>
139          public ShaderCacheHashTable()
140          {
141              _vertexACache.Initialize();
142              _vertexBCache.Initialize();
143              _tessControlCache.Initialize();
144              _tessEvaluationCache.Initialize();
145              _geometryCache.Initialize();
146              _fragmentCache.Initialize();
147  
148              _shaderPrograms = new Dictionary<IdTable, ShaderSpecializationList>();
149          }
150  
151          /// <summary>
152          /// Adds a program to the cache.
153          /// </summary>
154          /// <param name="program">Program to be added</param>
155          public void Add(CachedShaderProgram program)
156          {
157              IdTable idTable = new();
158  
159              foreach (var shader in program.Shaders)
160              {
161                  if (shader == null)
162                  {
163                      continue;
164                  }
165  
166                  if (shader.Info != null)
167                  {
168                      switch (shader.Info.Stage)
169                      {
170                          case ShaderStage.Vertex:
171                              idTable.VertexBId = _vertexBCache.Add(shader.Code);
172                              break;
173                          case ShaderStage.TessellationControl:
174                              idTable.TessControlId = _tessControlCache.Add(shader.Code);
175                              break;
176                          case ShaderStage.TessellationEvaluation:
177                              idTable.TessEvaluationId = _tessEvaluationCache.Add(shader.Code);
178                              break;
179                          case ShaderStage.Geometry:
180                              idTable.GeometryId = _geometryCache.Add(shader.Code);
181                              break;
182                          case ShaderStage.Fragment:
183                              idTable.FragmentId = _fragmentCache.Add(shader.Code);
184                              break;
185                      }
186                  }
187                  else
188                  {
189                      idTable.VertexAId = _vertexACache.Add(shader.Code);
190                  }
191              }
192  
193              if (!_shaderPrograms.TryGetValue(idTable, out ShaderSpecializationList specList))
194              {
195                  specList = new ShaderSpecializationList();
196                  _shaderPrograms.Add(idTable, specList);
197              }
198  
199              specList.Add(program);
200          }
201  
202          /// <summary>
203          /// Tries to find a cached program.
204          /// </summary>
205          /// <remarks>
206          /// Even if false is returned, <paramref name="guestCode"/> might still contain cached guest code.
207          /// This can be used to avoid additional allocations for guest code that was already cached.
208          /// </remarks>
209          /// <param name="channel">GPU channel</param>
210          /// <param name="poolState">Texture pool state</param>
211          /// <param name="graphicsState">Graphics state</param>
212          /// <param name="addresses">Guest addresses of the shaders to find</param>
213          /// <param name="program">Cached host program for the given state, if found</param>
214          /// <param name="guestCode">Cached guest code, if any found</param>
215          /// <returns>True if a cached host program was found, false otherwise</returns>
216          public bool TryFind(
217              GpuChannel channel,
218              ref GpuChannelPoolState poolState,
219              ref GpuChannelGraphicsState graphicsState,
220              ShaderAddresses addresses,
221              out CachedShaderProgram program,
222              out CachedGraphicsGuestCode guestCode)
223          {
224              var memoryManager = channel.MemoryManager;
225              IdTable idTable = new();
226              guestCode = new CachedGraphicsGuestCode();
227  
228              program = null;
229  
230              bool found = TryGetId(_vertexACache, memoryManager, addresses.VertexA, out idTable.VertexAId, out guestCode.VertexACode);
231              found &= TryGetId(_vertexBCache, memoryManager, addresses.VertexB, out idTable.VertexBId, out guestCode.VertexBCode);
232              found &= TryGetId(_tessControlCache, memoryManager, addresses.TessControl, out idTable.TessControlId, out guestCode.TessControlCode);
233              found &= TryGetId(_tessEvaluationCache, memoryManager, addresses.TessEvaluation, out idTable.TessEvaluationId, out guestCode.TessEvaluationCode);
234              found &= TryGetId(_geometryCache, memoryManager, addresses.Geometry, out idTable.GeometryId, out guestCode.GeometryCode);
235              found &= TryGetId(_fragmentCache, memoryManager, addresses.Fragment, out idTable.FragmentId, out guestCode.FragmentCode);
236  
237              if (found && _shaderPrograms.TryGetValue(idTable, out ShaderSpecializationList specList))
238              {
239                  return specList.TryFindForGraphics(channel, ref poolState, ref graphicsState, out program);
240              }
241  
242              return false;
243          }
244  
245          /// <summary>
246          /// Tries to get the ID of a single cached shader stage.
247          /// </summary>
248          /// <param name="idCache">ID cache of the stage</param>
249          /// <param name="memoryManager">GPU memory manager</param>
250          /// <param name="baseAddress">Base address of the shader</param>
251          /// <param name="id">ID, if found</param>
252          /// <param name="data">Cached guest code, if found</param>
253          /// <returns>True if a cached shader is found, false otherwise</returns>
254          private static bool TryGetId(IdCache idCache, MemoryManager memoryManager, ulong baseAddress, out int id, out byte[] data)
255          {
256              if (baseAddress == 0)
257              {
258                  id = 0;
259                  data = null;
260                  return true;
261              }
262  
263              ShaderCodeAccessor codeAccessor = new(memoryManager, baseAddress);
264              return idCache.TryFind(codeAccessor, out id, out data);
265          }
266  
267          /// <summary>
268          /// Gets all programs that have been added to the table.
269          /// </summary>
270          /// <returns>Programs added to the table</returns>
271          public IEnumerable<CachedShaderProgram> GetPrograms()
272          {
273              foreach (var specList in _shaderPrograms.Values)
274              {
275                  foreach (var program in specList)
276                  {
277                      yield return program;
278                  }
279              }
280          }
281      }
282  }