Decoder.cs
  1  using Ryujinx.Common.Memory;
  2  using Ryujinx.Graphics.Nvdec.Vp9.Common;
  3  using Ryujinx.Graphics.Nvdec.Vp9.Types;
  4  using Ryujinx.Graphics.Video;
  5  using System;
  6  
  7  namespace Ryujinx.Graphics.Nvdec.Vp9
  8  {
  9      public sealed class Decoder : IVp9Decoder
 10      {
 11          public bool IsHardwareAccelerated => false;
 12  
 13          private readonly MemoryAllocator _allocator = new();
 14  
 15          public ISurface CreateSurface(int width, int height) => new Surface(width, height);
 16  
 17          private static ReadOnlySpan<byte> LiteralToFilter => new byte[]
 18          {
 19              Constants.EightTapSmooth,
 20              Constants.EightTap,
 21              Constants.EightTapSharp,
 22              Constants.Bilinear,
 23          };
 24  
 25          public unsafe bool Decode(
 26              ref Vp9PictureInfo pictureInfo,
 27              ISurface output,
 28              ReadOnlySpan<byte> bitstream,
 29              ReadOnlySpan<Vp9MvRef> mvsIn,
 30              Span<Vp9MvRef> mvsOut)
 31          {
 32              Vp9Common cm = new()
 33              {
 34                  FrameType = pictureInfo.IsKeyFrame ? FrameType.KeyFrame : FrameType.InterFrame,
 35                  IntraOnly = pictureInfo.IntraOnly,
 36  
 37                  Width = output.Width,
 38                  Height = output.Height,
 39                  SubsamplingX = 1,
 40                  SubsamplingY = 1,
 41  
 42                  UsePrevFrameMvs = pictureInfo.UsePrevInFindMvRefs,
 43  
 44                  RefFrameSignBias = pictureInfo.RefFrameSignBias,
 45  
 46                  BaseQindex = pictureInfo.BaseQIndex,
 47                  YDcDeltaQ = pictureInfo.YDcDeltaQ,
 48                  UvAcDeltaQ = pictureInfo.UvAcDeltaQ,
 49                  UvDcDeltaQ = pictureInfo.UvDcDeltaQ,
 50              };
 51  
 52              cm.Mb.Lossless = pictureInfo.Lossless;
 53              cm.Mb.Bd = 8;
 54  
 55              cm.TxMode = (TxMode)pictureInfo.TransformMode;
 56  
 57              cm.AllowHighPrecisionMv = pictureInfo.AllowHighPrecisionMv;
 58  
 59              cm.InterpFilter = (byte)pictureInfo.InterpFilter;
 60  
 61              if (cm.InterpFilter != Constants.Switchable)
 62              {
 63                  cm.InterpFilter = LiteralToFilter[cm.InterpFilter];
 64              }
 65  
 66              cm.ReferenceMode = (ReferenceMode)pictureInfo.ReferenceMode;
 67  
 68              cm.CompFixedRef = pictureInfo.CompFixedRef;
 69              cm.CompVarRef = pictureInfo.CompVarRef;
 70  
 71              cm.Log2TileCols = pictureInfo.Log2TileCols;
 72              cm.Log2TileRows = pictureInfo.Log2TileRows;
 73  
 74              cm.Seg.Enabled = pictureInfo.SegmentEnabled;
 75              cm.Seg.UpdateMap = pictureInfo.SegmentMapUpdate;
 76              cm.Seg.TemporalUpdate = pictureInfo.SegmentMapTemporalUpdate;
 77              cm.Seg.AbsDelta = (byte)pictureInfo.SegmentAbsDelta;
 78              cm.Seg.FeatureMask = pictureInfo.SegmentFeatureEnable;
 79              cm.Seg.FeatureData = pictureInfo.SegmentFeatureData;
 80  
 81              cm.Lf.ModeRefDeltaEnabled = pictureInfo.ModeRefDeltaEnabled;
 82              cm.Lf.RefDeltas = pictureInfo.RefDeltas;
 83              cm.Lf.ModeDeltas = pictureInfo.ModeDeltas;
 84  
 85              cm.Fc = new Ptr<Vp9EntropyProbs>(ref pictureInfo.Entropy);
 86              cm.Counts = new Ptr<Vp9BackwardUpdates>(ref pictureInfo.BackwardUpdateCounts);
 87  
 88              cm.FrameRefs[0].Buf = (Surface)pictureInfo.LastReference;
 89              cm.FrameRefs[1].Buf = (Surface)pictureInfo.GoldenReference;
 90              cm.FrameRefs[2].Buf = (Surface)pictureInfo.AltReference;
 91              cm.Mb.CurBuf = (Surface)output;
 92  
 93              cm.Mb.SetupBlockPlanes(1, 1);
 94  
 95              int tileCols = 1 << pictureInfo.Log2TileCols;
 96              int tileRows = 1 << pictureInfo.Log2TileRows;
 97  
 98              // Video usually have only 4 columns, so more threads won't make a difference for those.
 99              // Try to not take all CPU cores for video decoding.
100              int maxThreads = Math.Min(4, Environment.ProcessorCount / 2);
101  
102              cm.AllocTileWorkerData(_allocator, tileCols, tileRows, maxThreads);
103              cm.AllocContextBuffers(_allocator, output.Width, output.Height);
104              cm.InitContextBuffers();
105              cm.SetupSegmentationDequant();
106              cm.SetupScaleFactors();
107  
108              SetMvs(ref cm, mvsIn);
109  
110              fixed (byte* dataPtr = bitstream)
111              {
112                  try
113                  {
114                      if (maxThreads > 1 && tileRows == 1 && tileCols > 1)
115                      {
116                          DecodeFrame.DecodeTilesMt(ref cm, new ArrayPtr<byte>(dataPtr, bitstream.Length), maxThreads);
117                      }
118                      else
119                      {
120                          DecodeFrame.DecodeTiles(ref cm, new ArrayPtr<byte>(dataPtr, bitstream.Length));
121                      }
122                  }
123                  catch (InternalErrorException)
124                  {
125                      return false;
126                  }
127              }
128  
129              GetMvs(ref cm, mvsOut);
130  
131              cm.FreeTileWorkerData(_allocator);
132              cm.FreeContextBuffers(_allocator);
133  
134              return true;
135          }
136  
137          private static void SetMvs(ref Vp9Common cm, ReadOnlySpan<Vp9MvRef> mvs)
138          {
139              if (mvs.Length > cm.PrevFrameMvs.Length)
140              {
141                  throw new ArgumentException($"Size mismatch, expected: {cm.PrevFrameMvs.Length}, but got: {mvs.Length}.");
142              }
143  
144              for (int i = 0; i < mvs.Length; i++)
145              {
146                  ref var mv = ref cm.PrevFrameMvs[i];
147  
148                  mv.Mv[0].Row = mvs[i].Mvs[0].Row;
149                  mv.Mv[0].Col = mvs[i].Mvs[0].Col;
150                  mv.Mv[1].Row = mvs[i].Mvs[1].Row;
151                  mv.Mv[1].Col = mvs[i].Mvs[1].Col;
152  
153                  mv.RefFrame[0] = (sbyte)mvs[i].RefFrames[0];
154                  mv.RefFrame[1] = (sbyte)mvs[i].RefFrames[1];
155              }
156          }
157  
158          private static void GetMvs(ref Vp9Common cm, Span<Vp9MvRef> mvs)
159          {
160              if (mvs.Length > cm.CurFrameMvs.Length)
161              {
162                  throw new ArgumentException($"Size mismatch, expected: {cm.CurFrameMvs.Length}, but got: {mvs.Length}.");
163              }
164  
165              for (int i = 0; i < mvs.Length; i++)
166              {
167                  ref var mv = ref cm.CurFrameMvs[i];
168  
169                  mvs[i].Mvs[0].Row = mv.Mv[0].Row;
170                  mvs[i].Mvs[0].Col = mv.Mv[0].Col;
171                  mvs[i].Mvs[1].Row = mv.Mv[1].Row;
172                  mvs[i].Mvs[1].Col = mv.Mv[1].Col;
173  
174                  mvs[i].RefFrames[0] = mv.RefFrame[0];
175                  mvs[i].RefFrames[1] = mv.RefFrame[1];
176              }
177          }
178  
179          public void Dispose() => _allocator.Dispose();
180      }
181  }