/ src / Ryujinx.Graphics.Nvdec.Vp9 / Detokenize.cs
Detokenize.cs
  1  using Ryujinx.Common.Memory;
  2  using Ryujinx.Graphics.Nvdec.Vp9.Dsp;
  3  using Ryujinx.Graphics.Nvdec.Vp9.Types;
  4  using Ryujinx.Graphics.Video;
  5  using System;
  6  using System.Diagnostics;
  7  using System.Runtime.InteropServices;
  8  using static Ryujinx.Graphics.Nvdec.Vp9.Dsp.InvTxfm;
  9  
 10  namespace Ryujinx.Graphics.Nvdec.Vp9
 11  {
 12      internal static class Detokenize
 13      {
 14          private const int EobContextNode = 0;
 15          private const int ZeroContextNode = 1;
 16          private const int OneContextNode = 2;
 17  
 18          private static int GetCoefContext(ReadOnlySpan<short> neighbors, ReadOnlySpan<byte> tokenCache, int c)
 19          {
 20              const int MaxNeighbors = 2;
 21  
 22              return (1 + tokenCache[neighbors[MaxNeighbors * c + 0]] + tokenCache[neighbors[MaxNeighbors * c + 1]]) >> 1;
 23          }
 24  
 25          private static int ReadCoeff(
 26              ref Reader r,
 27              ReadOnlySpan<byte> probs,
 28              int n,
 29              ref ulong value,
 30              ref int count,
 31              ref uint range)
 32          {
 33              int i, val = 0;
 34              for (i = 0; i < n; ++i)
 35              {
 36                  val = (val << 1) | r.ReadBool(probs[i], ref value, ref count, ref range);
 37              }
 38  
 39              return val;
 40          }
 41  
 42          private static int DecodeCoefs(
 43              ref MacroBlockD xd,
 44              PlaneType type,
 45              Span<int> dqcoeff,
 46              TxSize txSize,
 47              ref Array2<short> dq,
 48              int ctx,
 49              ReadOnlySpan<short> scan,
 50              ReadOnlySpan<short> nb,
 51              ref Reader r)
 52          {
 53              ref Vp9BackwardUpdates counts = ref xd.Counts.Value;
 54              int maxEob = 16 << ((int)txSize << 1);
 55              ref Vp9EntropyProbs fc = ref xd.Fc.Value;
 56              int refr = xd.Mi[0].Value.IsInterBlock() ? 1 : 0;
 57              int band, c = 0;
 58              ref Array6<Array6<Array3<byte>>> coefProbs = ref fc.CoefProbs[(int)txSize][(int)type][refr];
 59              Span<byte> tokenCache = stackalloc byte[32 * 32];
 60              ReadOnlySpan<byte> bandTranslate = Luts.GetBandTranslate(txSize);
 61              int dqShift = (txSize == TxSize.Tx32x32) ? 1 : 0;
 62              int v;
 63              short dqv = dq[0];
 64              ReadOnlySpan<byte> cat6Prob = (xd.Bd == 12)
 65                  ? Luts.Vp9Cat6ProbHigh12
 66                  : (xd.Bd == 10) ? Luts.Vp9Cat6ProbHigh12[2..] : Luts.Vp9Cat6Prob;
 67              int cat6Bits = (xd.Bd == 12) ? 18 : (xd.Bd == 10) ? 16 : 14;
 68              // Keep value, range, and count as locals.  The compiler produces better
 69              // results with the locals than using r directly.
 70              ulong value = r.Value;
 71              uint range = r.Range;
 72              int count = r.Count;
 73  
 74              while (c < maxEob)
 75              {
 76                  int val = -1;
 77                  band = bandTranslate[0];
 78                  bandTranslate = bandTranslate[1..];
 79                  ref Array3<byte> prob = ref coefProbs[band][ctx];
 80                  if (!xd.Counts.IsNull)
 81                  {
 82                      ++counts.EobBranch[(int)txSize][(int)type][refr][band][ctx];
 83                  }
 84  
 85                  if (r.ReadBool(prob[EobContextNode], ref value, ref count, ref range) == 0)
 86                  {
 87                      if (!xd.Counts.IsNull)
 88                      {
 89                          ++counts.Coef[(int)txSize][(int)type][refr][band][ctx][Constants.EobModelToken];
 90                      }
 91  
 92                      break;
 93                  }
 94  
 95                  while (r.ReadBool(prob[ZeroContextNode], ref value, ref count, ref range) == 0)
 96                  {
 97                      if (!xd.Counts.IsNull)
 98                      {
 99                          ++counts.Coef[(int)txSize][(int)type][refr][band][ctx][Constants.ZeroToken];
100                      }
101  
102                      dqv = dq[1];
103                      tokenCache[scan[c]] = 0;
104                      ++c;
105                      if (c >= maxEob)
106                      {
107                          r.Value = value;
108                          r.Range = range;
109                          r.Count = count;
110  
111                          return c; // Zero tokens at the end (no eob token)
112                      }
113                      ctx = GetCoefContext(nb, tokenCache, c);
114                      band = bandTranslate[0];
115                      bandTranslate = bandTranslate[1..];
116                      prob = ref coefProbs[band][ctx];
117                  }
118  
119                  if (r.ReadBool(prob[OneContextNode], ref value, ref count, ref range) != 0)
120                  {
121                      ReadOnlySpan<byte> p = Luts.Vp9Pareto8Full[prob[Constants.PivotNode] - 1];
122                      if (!xd.Counts.IsNull)
123                      {
124                          ++counts.Coef[(int)txSize][(int)type][refr][band][ctx][Constants.TwoToken];
125                      }
126  
127                      if (r.ReadBool(p[0], ref value, ref count, ref range) != 0)
128                      {
129                          if (r.ReadBool(p[3], ref value, ref count, ref range) != 0)
130                          {
131                              tokenCache[scan[c]] = 5;
132                              if (r.ReadBool(p[5], ref value, ref count, ref range) != 0)
133                              {
134                                  if (r.ReadBool(p[7], ref value, ref count, ref range) != 0)
135                                  {
136                                      val = Constants.Cat6MinVal + ReadCoeff(ref r, cat6Prob, cat6Bits, ref value, ref count, ref range);
137                                  }
138                                  else
139                                  {
140                                      val = Constants.Cat5MinVal + ReadCoeff(ref r, Luts.Vp9Cat5Prob, 5, ref value, ref count, ref range);
141                                  }
142                              }
143                              else if (r.ReadBool(p[6], ref value, ref count, ref range) != 0)
144                              {
145                                  val = Constants.Cat4MinVal + ReadCoeff(ref r, Luts.Vp9Cat4Prob, 4, ref value, ref count, ref range);
146                              }
147                              else
148                              {
149                                  val = Constants.Cat3MinVal + ReadCoeff(ref r, Luts.Vp9Cat3Prob, 3, ref value, ref count, ref range);
150                              }
151                          }
152                          else
153                          {
154                              tokenCache[scan[c]] = 4;
155                              if (r.ReadBool(p[4], ref value, ref count, ref range) != 0)
156                              {
157                                  val = Constants.Cat2MinVal + ReadCoeff(ref r, Luts.Vp9Cat2Prob, 2, ref value, ref count, ref range);
158                              }
159                              else
160                              {
161                                  val = Constants.Cat1MinVal + ReadCoeff(ref r, Luts.Vp9Cat1Prob, 1, ref value, ref count, ref range);
162                              }
163                          }
164                          // Val may use 18-bits
165                          v = (int)(((long)val * dqv) >> dqShift);
166                      }
167                      else
168                      {
169                          if (r.ReadBool(p[1], ref value, ref count, ref range) != 0)
170                          {
171                              tokenCache[scan[c]] = 3;
172                              v = ((3 + r.ReadBool(p[2], ref value, ref count, ref range)) * dqv) >> dqShift;
173                          }
174                          else
175                          {
176                              tokenCache[scan[c]] = 2;
177                              v = (2 * dqv) >> dqShift;
178                          }
179                      }
180                  }
181                  else
182                  {
183                      if (!xd.Counts.IsNull)
184                      {
185                          ++counts.Coef[(int)txSize][(int)type][refr][band][ctx][Constants.OneToken];
186                      }
187  
188                      tokenCache[scan[c]] = 1;
189                      v = dqv >> dqShift;
190                  }
191                  dqcoeff[scan[c]] = (int)HighbdCheckRange(r.ReadBool(128, ref value, ref count, ref range) != 0 ? -v : v, xd.Bd);
192                  ++c;
193                  ctx = GetCoefContext(nb, tokenCache, c);
194                  dqv = dq[1];
195              }
196  
197              r.Value = value;
198              r.Range = range;
199              r.Count = count;
200  
201              return c;
202          }
203  
204          private static void GetCtxShift(ref MacroBlockD xd, ref int ctxShiftA, ref int ctxShiftL, int x, int y, uint txSizeInBlocks)
205          {
206              if (xd.MaxBlocksWide != 0)
207              {
208                  if (txSizeInBlocks + x > xd.MaxBlocksWide)
209                  {
210                      ctxShiftA = (int)(txSizeInBlocks - (xd.MaxBlocksWide - x)) * 8;
211                  }
212              }
213              if (xd.MaxBlocksHigh != 0)
214              {
215                  if (txSizeInBlocks + y > xd.MaxBlocksHigh)
216                  {
217                      ctxShiftL = (int)(txSizeInBlocks - (xd.MaxBlocksHigh - y)) * 8;
218                  }
219              }
220          }
221  
222          private static PlaneType GetPlaneType(int plane)
223          {
224              return (PlaneType)(plane > 0 ? 1 : 0);
225          }
226  
227          public static int DecodeBlockTokens(
228              ref TileWorkerData twd,
229              int plane,
230              Luts.ScanOrder sc,
231              int x,
232              int y,
233              TxSize txSize,
234              int segId)
235          {
236              ref Reader r = ref twd.BitReader;
237              ref MacroBlockD xd = ref twd.Xd;
238              ref MacroBlockDPlane pd = ref xd.Plane[plane];
239              ref Array2<short> dequant = ref pd.SegDequant[segId];
240              int eob;
241              Span<sbyte> a = pd.AboveContext.AsSpan()[x..];
242              Span<sbyte> l = pd.LeftContext.AsSpan()[y..];
243              int ctx;
244              int ctxShiftA = 0;
245              int ctxShiftL = 0;
246  
247              switch (txSize)
248              {
249                  case TxSize.Tx4x4:
250                      ctx = a[0] != 0 ? 1 : 0;
251                      ctx += l[0] != 0 ? 1 : 0;
252                      eob = DecodeCoefs(
253                          ref xd,
254                          GetPlaneType(plane),
255                          pd.DqCoeff.AsSpan(),
256                          txSize,
257                          ref dequant,
258                          ctx,
259                          sc.Scan,
260                          sc.Neighbors,
261                          ref r);
262                      a[0] = l[0] = (sbyte)(eob > 0 ? 1 : 0);
263                      break;
264                  case TxSize.Tx8x8:
265                      GetCtxShift(ref xd, ref ctxShiftA, ref ctxShiftL, x, y, 1 << (int)TxSize.Tx8x8);
266                      ctx = MemoryMarshal.Cast<sbyte, ushort>(a)[0] != 0 ? 1 : 0;
267                      ctx += MemoryMarshal.Cast<sbyte, ushort>(l)[0] != 0 ? 1 : 0;
268                      eob = DecodeCoefs(
269                          ref xd,
270                          GetPlaneType(plane),
271                          pd.DqCoeff.AsSpan(),
272                          txSize,
273                          ref dequant,
274                          ctx,
275                          sc.Scan,
276                          sc.Neighbors,
277                          ref r);
278                      MemoryMarshal.Cast<sbyte, ushort>(a)[0] = (ushort)((eob > 0 ? 0x0101 : 0) >> ctxShiftA);
279                      MemoryMarshal.Cast<sbyte, ushort>(l)[0] = (ushort)((eob > 0 ? 0x0101 : 0) >> ctxShiftL);
280                      break;
281                  case TxSize.Tx16x16:
282                      GetCtxShift(ref xd, ref ctxShiftA, ref ctxShiftL, x, y, 1 << (int)TxSize.Tx16x16);
283                      ctx = MemoryMarshal.Cast<sbyte, uint>(a)[0] != 0 ? 1 : 0;
284                      ctx += MemoryMarshal.Cast<sbyte, uint>(l)[0] != 0 ? 1 : 0;
285                      eob = DecodeCoefs(
286                          ref xd,
287                          GetPlaneType(plane),
288                          pd.DqCoeff.AsSpan(),
289                          txSize,
290                          ref dequant,
291                          ctx,
292                          sc.Scan,
293                          sc.Neighbors,
294                          ref r);
295                      MemoryMarshal.Cast<sbyte, uint>(a)[0] = (uint)((eob > 0 ? 0x01010101 : 0) >> ctxShiftA);
296                      MemoryMarshal.Cast<sbyte, uint>(l)[0] = (uint)((eob > 0 ? 0x01010101 : 0) >> ctxShiftL);
297                      break;
298                  case TxSize.Tx32x32:
299                      GetCtxShift(ref xd, ref ctxShiftA, ref ctxShiftL, x, y, 1 << (int)TxSize.Tx32x32);
300                      // NOTE: Casting to ulong here is safe because the default memory
301                      // alignment is at least 8 bytes and the Tx32x32 is aligned on 8 byte
302                      // boundaries.
303                      ctx = MemoryMarshal.Cast<sbyte, ulong>(a)[0] != 0 ? 1 : 0;
304                      ctx += MemoryMarshal.Cast<sbyte, ulong>(l)[0] != 0 ? 1 : 0;
305                      eob = DecodeCoefs(
306                          ref xd,
307                          GetPlaneType(plane),
308                          pd.DqCoeff.AsSpan(),
309                          txSize,
310                          ref dequant,
311                          ctx,
312                          sc.Scan,
313                          sc.Neighbors,
314                          ref r);
315                      MemoryMarshal.Cast<sbyte, ulong>(a)[0] = (eob > 0 ? 0x0101010101010101UL : 0) >> ctxShiftA;
316                      MemoryMarshal.Cast<sbyte, ulong>(l)[0] = (eob > 0 ? 0x0101010101010101UL : 0) >> ctxShiftL;
317                      break;
318                  default:
319                      Debug.Assert(false, "Invalid transform size.");
320                      eob = 0;
321                      break;
322              }
323  
324              return eob;
325          }
326      }
327  }