/ src / ARMeilleure / Instructions / InstEmitSimdShift32.cs
InstEmitSimdShift32.cs
  1  using ARMeilleure.Decoders;
  2  using ARMeilleure.IntermediateRepresentation;
  3  using ARMeilleure.State;
  4  using ARMeilleure.Translation;
  5  using System;
  6  using System.Diagnostics;
  7  using System.Reflection;
  8  using static ARMeilleure.Instructions.InstEmitHelper;
  9  using static ARMeilleure.Instructions.InstEmitSimdHelper;
 10  using static ARMeilleure.Instructions.InstEmitSimdHelper32;
 11  using static ARMeilleure.IntermediateRepresentation.Operand.Factory;
 12  
 13  namespace ARMeilleure.Instructions
 14  {
 15      static partial class InstEmit32
 16      {
 17          public static void Vqrshrn(ArmEmitterContext context)
 18          {
 19              OpCode32SimdShImm op = (OpCode32SimdShImm)context.CurrOp;
 20  
 21              EmitRoundShrImmSaturatingNarrowOp(context, op.U ? ShrImmSaturatingNarrowFlags.VectorZxZx : ShrImmSaturatingNarrowFlags.VectorSxSx);
 22          }
 23  
 24          public static void Vqrshrun(ArmEmitterContext context)
 25          {
 26              EmitRoundShrImmSaturatingNarrowOp(context, ShrImmSaturatingNarrowFlags.VectorSxZx);
 27          }
 28  
 29          public static void Vqshrn(ArmEmitterContext context)
 30          {
 31              OpCode32SimdShImm op = (OpCode32SimdShImm)context.CurrOp;
 32  
 33              EmitShrImmSaturatingNarrowOp(context, op.U ? ShrImmSaturatingNarrowFlags.VectorZxZx : ShrImmSaturatingNarrowFlags.VectorSxSx);
 34          }
 35  
 36          public static void Vqshrun(ArmEmitterContext context)
 37          {
 38              EmitShrImmSaturatingNarrowOp(context, ShrImmSaturatingNarrowFlags.VectorSxZx);
 39          }
 40  
 41          public static void Vrshr(ArmEmitterContext context)
 42          {
 43              EmitRoundShrImmOp(context, accumulate: false);
 44          }
 45  
 46          public static void Vrshrn(ArmEmitterContext context)
 47          {
 48              EmitRoundShrImmNarrowOp(context, signed: false);
 49          }
 50  
 51          public static void Vrsra(ArmEmitterContext context)
 52          {
 53              EmitRoundShrImmOp(context, accumulate: true);
 54          }
 55  
 56          public static void Vshl(ArmEmitterContext context)
 57          {
 58              OpCode32SimdShImm op = (OpCode32SimdShImm)context.CurrOp;
 59  
 60              EmitVectorUnaryOpZx32(context, (op1) => context.ShiftLeft(op1, Const(op.Shift)));
 61          }
 62  
 63          public static void Vshl_I(ArmEmitterContext context)
 64          {
 65              OpCode32SimdReg op = (OpCode32SimdReg)context.CurrOp;
 66  
 67              if (op.U)
 68              {
 69                  EmitVectorBinaryOpZx32(context, (op1, op2) => EmitShlRegOp(context, op2, op1, op.Size, true));
 70              }
 71              else
 72              {
 73                  EmitVectorBinaryOpSx32(context, (op1, op2) => EmitShlRegOp(context, op2, op1, op.Size, false));
 74              }
 75          }
 76  
 77          public static void Vshll(ArmEmitterContext context)
 78          {
 79              OpCode32SimdShImmLong op = (OpCode32SimdShImmLong)context.CurrOp;
 80  
 81              Operand res = context.VectorZero();
 82  
 83              int elems = op.GetBytesCount() >> op.Size;
 84  
 85              for (int index = 0; index < elems; index++)
 86              {
 87                  Operand me = EmitVectorExtract32(context, op.Qm, op.Im + index, op.Size, !op.U);
 88  
 89                  if (op.Size == 2)
 90                  {
 91                      if (op.U)
 92                      {
 93                          me = context.ZeroExtend32(OperandType.I64, me);
 94                      }
 95                      else
 96                      {
 97                          me = context.SignExtend32(OperandType.I64, me);
 98                      }
 99                  }
100  
101                  me = context.ShiftLeft(me, Const(op.Shift));
102  
103                  res = EmitVectorInsert(context, res, me, index, op.Size + 1);
104              }
105  
106              context.Copy(GetVecA32(op.Qd), res);
107          }
108  
109          public static void Vshll2(ArmEmitterContext context)
110          {
111              OpCode32Simd op = (OpCode32Simd)context.CurrOp;
112  
113              Operand res = context.VectorZero();
114  
115              int elems = op.GetBytesCount() >> op.Size;
116  
117              for (int index = 0; index < elems; index++)
118              {
119                  Operand me = EmitVectorExtract32(context, op.Qm, op.Im + index, op.Size, !op.U);
120  
121                  if (op.Size == 2)
122                  {
123                      if (op.U)
124                      {
125                          me = context.ZeroExtend32(OperandType.I64, me);
126                      }
127                      else
128                      {
129                          me = context.SignExtend32(OperandType.I64, me);
130                      }
131                  }
132  
133                  me = context.ShiftLeft(me, Const(8 << op.Size));
134  
135                  res = EmitVectorInsert(context, res, me, index, op.Size + 1);
136              }
137  
138              context.Copy(GetVecA32(op.Qd), res);
139          }
140  
141          public static void Vshr(ArmEmitterContext context)
142          {
143              OpCode32SimdShImm op = (OpCode32SimdShImm)context.CurrOp;
144              int shift = GetImmShr(op);
145              int maxShift = (8 << op.Size) - 1;
146  
147              if (op.U)
148              {
149                  EmitVectorUnaryOpZx32(context, (op1) => (shift > maxShift) ? Const(op1.Type, 0) : context.ShiftRightUI(op1, Const(shift)));
150              }
151              else
152              {
153                  EmitVectorUnaryOpSx32(context, (op1) => context.ShiftRightSI(op1, Const(Math.Min(maxShift, shift))));
154              }
155          }
156  
157          public static void Vshrn(ArmEmitterContext context)
158          {
159              OpCode32SimdShImm op = (OpCode32SimdShImm)context.CurrOp;
160              int shift = GetImmShr(op);
161  
162              EmitVectorUnaryNarrowOp32(context, (op1) => context.ShiftRightUI(op1, Const(shift)));
163          }
164  
165          public static void Vsli_I(ArmEmitterContext context)
166          {
167              OpCode32SimdShImm op = (OpCode32SimdShImm)context.CurrOp;
168              int shift = op.Shift;
169              int eSize = 8 << op.Size;
170  
171              ulong mask = shift != 0 ? ulong.MaxValue >> (64 - shift) : 0UL;
172  
173              Operand res = GetVec(op.Qd);
174  
175              int elems = op.GetBytesCount() >> op.Size;
176  
177              for (int index = 0; index < elems; index++)
178              {
179                  Operand me = EmitVectorExtractZx(context, op.Qm, op.Im + index, op.Size);
180  
181                  Operand neShifted = context.ShiftLeft(me, Const(shift));
182  
183                  Operand de = EmitVectorExtractZx(context, op.Qd, op.Id + index, op.Size);
184  
185                  Operand deMasked = context.BitwiseAnd(de, Const(mask));
186  
187                  Operand e = context.BitwiseOr(neShifted, deMasked);
188  
189                  res = EmitVectorInsert(context, res, e, op.Id + index, op.Size);
190              }
191  
192              context.Copy(GetVec(op.Qd), res);
193          }
194  
195          public static void Vsra(ArmEmitterContext context)
196          {
197              OpCode32SimdShImm op = (OpCode32SimdShImm)context.CurrOp;
198              int shift = GetImmShr(op);
199              int maxShift = (8 << op.Size) - 1;
200  
201              if (op.U)
202              {
203                  EmitVectorImmBinaryQdQmOpZx32(context, (op1, op2) =>
204                  {
205                      Operand shiftRes = shift > maxShift ? Const(op2.Type, 0) : context.ShiftRightUI(op2, Const(shift));
206  
207                      return context.Add(op1, shiftRes);
208                  });
209              }
210              else
211              {
212                  EmitVectorImmBinaryQdQmOpSx32(context, (op1, op2) => context.Add(op1, context.ShiftRightSI(op2, Const(Math.Min(maxShift, shift)))));
213              }
214          }
215  
216          public static void EmitRoundShrImmOp(ArmEmitterContext context, bool accumulate)
217          {
218              OpCode32SimdShImm op = (OpCode32SimdShImm)context.CurrOp;
219              int shift = GetImmShr(op);
220              long roundConst = 1L << (shift - 1);
221  
222              if (op.U)
223              {
224                  if (op.Size < 2)
225                  {
226                      EmitVectorUnaryOpZx32(context, (op1) =>
227                      {
228                          op1 = context.Add(op1, Const(op1.Type, roundConst));
229  
230                          return context.ShiftRightUI(op1, Const(shift));
231                      }, accumulate);
232                  }
233                  else if (op.Size == 2)
234                  {
235                      EmitVectorUnaryOpZx32(context, (op1) =>
236                      {
237                          op1 = context.ZeroExtend32(OperandType.I64, op1);
238                          op1 = context.Add(op1, Const(op1.Type, roundConst));
239  
240                          return context.ConvertI64ToI32(context.ShiftRightUI(op1, Const(shift)));
241                      }, accumulate);
242                  }
243                  else /* if (op.Size == 3) */
244                  {
245                      EmitVectorUnaryOpZx32(context, (op1) => EmitShrImm64(context, op1, signed: false, roundConst, shift), accumulate);
246                  }
247              }
248              else
249              {
250                  if (op.Size < 2)
251                  {
252                      EmitVectorUnaryOpSx32(context, (op1) =>
253                      {
254                          op1 = context.Add(op1, Const(op1.Type, roundConst));
255  
256                          return context.ShiftRightSI(op1, Const(shift));
257                      }, accumulate);
258                  }
259                  else if (op.Size == 2)
260                  {
261                      EmitVectorUnaryOpSx32(context, (op1) =>
262                      {
263                          op1 = context.SignExtend32(OperandType.I64, op1);
264                          op1 = context.Add(op1, Const(op1.Type, roundConst));
265  
266                          return context.ConvertI64ToI32(context.ShiftRightSI(op1, Const(shift)));
267                      }, accumulate);
268                  }
269                  else /* if (op.Size == 3) */
270                  {
271                      EmitVectorUnaryOpZx32(context, (op1) => EmitShrImm64(context, op1, signed: true, roundConst, shift), accumulate);
272                  }
273              }
274          }
275  
276          private static void EmitRoundShrImmNarrowOp(ArmEmitterContext context, bool signed)
277          {
278              OpCode32SimdShImm op = (OpCode32SimdShImm)context.CurrOp;
279  
280              int shift = GetImmShr(op);
281              long roundConst = 1L << (shift - 1);
282  
283              EmitVectorUnaryNarrowOp32(context, (op1) =>
284              {
285                  if (op.Size <= 1)
286                  {
287                      op1 = context.Add(op1, Const(op1.Type, roundConst));
288                      op1 = signed ? context.ShiftRightSI(op1, Const(shift)) : context.ShiftRightUI(op1, Const(shift));
289                  }
290                  else /* if (op.Size == 2 && round) */
291                  {
292                      op1 = EmitShrImm64(context, op1, signed, roundConst, shift); // shift <= 32
293                  }
294  
295                  return op1;
296              }, signed);
297          }
298  
299          private static Operand EmitShlRegOp(ArmEmitterContext context, Operand op, Operand shiftLsB, int size, bool unsigned)
300          {
301              if (shiftLsB.Type == OperandType.I64)
302              {
303                  shiftLsB = context.ConvertI64ToI32(shiftLsB);
304              }
305  
306              shiftLsB = context.SignExtend8(OperandType.I32, shiftLsB);
307              Debug.Assert((uint)size < 4u);
308  
309              Operand negShiftLsB = context.Negate(shiftLsB);
310  
311              Operand isPositive = context.ICompareGreaterOrEqual(shiftLsB, Const(0));
312  
313              Operand shl = context.ShiftLeft(op, shiftLsB);
314              Operand shr = unsigned ? context.ShiftRightUI(op, negShiftLsB) : context.ShiftRightSI(op, negShiftLsB);
315  
316              Operand res = context.ConditionalSelect(isPositive, shl, shr);
317  
318              if (unsigned)
319              {
320                  Operand isOutOfRange = context.BitwiseOr(
321                      context.ICompareGreaterOrEqual(shiftLsB, Const(8 << size)),
322                      context.ICompareGreaterOrEqual(negShiftLsB, Const(8 << size)));
323  
324                  return context.ConditionalSelect(isOutOfRange, Const(op.Type, 0), res);
325              }
326              else
327              {
328                  Operand isOutOfRange0 = context.ICompareGreaterOrEqual(shiftLsB, Const(8 << size));
329                  Operand isOutOfRangeN = context.ICompareGreaterOrEqual(negShiftLsB, Const(8 << size));
330  
331                  // Also zero if shift is too negative, but value was positive.
332                  isOutOfRange0 = context.BitwiseOr(isOutOfRange0, context.BitwiseAnd(isOutOfRangeN, context.ICompareGreaterOrEqual(op, Const(op.Type, 0))));
333  
334                  Operand min = (op.Type == OperandType.I64) ? Const(-1L) : Const(-1);
335  
336                  return context.ConditionalSelect(isOutOfRange0, Const(op.Type, 0), context.ConditionalSelect(isOutOfRangeN, min, res));
337              }
338          }
339  
340          [Flags]
341          private enum ShrImmSaturatingNarrowFlags
342          {
343              Scalar = 1 << 0,
344              SignedSrc = 1 << 1,
345              SignedDst = 1 << 2,
346  
347              Round = 1 << 3,
348  
349              ScalarSxSx = Scalar | SignedSrc | SignedDst,
350              ScalarSxZx = Scalar | SignedSrc,
351              ScalarZxZx = Scalar,
352  
353              VectorSxSx = SignedSrc | SignedDst,
354              VectorSxZx = SignedSrc,
355              VectorZxZx = 0,
356          }
357  
358          private static void EmitRoundShrImmSaturatingNarrowOp(ArmEmitterContext context, ShrImmSaturatingNarrowFlags flags)
359          {
360              EmitShrImmSaturatingNarrowOp(context, ShrImmSaturatingNarrowFlags.Round | flags);
361          }
362  
363          private static void EmitShrImmSaturatingNarrowOp(ArmEmitterContext context, ShrImmSaturatingNarrowFlags flags)
364          {
365              OpCode32SimdShImm op = (OpCode32SimdShImm)context.CurrOp;
366  
367              bool scalar = (flags & ShrImmSaturatingNarrowFlags.Scalar) != 0;
368              bool signedSrc = (flags & ShrImmSaturatingNarrowFlags.SignedSrc) != 0;
369              bool signedDst = (flags & ShrImmSaturatingNarrowFlags.SignedDst) != 0;
370              bool round = (flags & ShrImmSaturatingNarrowFlags.Round) != 0;
371  
372              if (scalar)
373              {
374                  // TODO: Support scalar operation.
375                  throw new NotImplementedException();
376              }
377  
378              int shift = GetImmShr(op);
379              long roundConst = 1L << (shift - 1);
380  
381              EmitVectorUnaryNarrowOp32(context, (op1) =>
382              {
383                  if (op.Size <= 1 || !round)
384                  {
385                      if (round)
386                      {
387                          op1 = context.Add(op1, Const(op1.Type, roundConst));
388                      }
389  
390                      op1 = signedSrc ? context.ShiftRightSI(op1, Const(shift)) : context.ShiftRightUI(op1, Const(shift));
391                  }
392                  else /* if (op.Size == 2 && round) */
393                  {
394                      op1 = EmitShrImm64(context, op1, signedSrc, roundConst, shift); // shift <= 32
395                  }
396  
397                  return EmitSatQ(context, op1, 8 << op.Size, signedSrc, signedDst);
398              }, signedSrc);
399          }
400  
401          private static int GetImmShr(OpCode32SimdShImm op)
402          {
403              return (8 << op.Size) - op.Shift; // Shr amount is flipped.
404          }
405  
406          // dst64 = (Int(src64, signed) + roundConst) >> shift;
407          private static Operand EmitShrImm64(
408              ArmEmitterContext context,
409              Operand value,
410              bool signed,
411              long roundConst,
412              int shift)
413          {
414              MethodInfo info = signed
415                  ? typeof(SoftFallback).GetMethod(nameof(SoftFallback.SignedShrImm64))
416                  : typeof(SoftFallback).GetMethod(nameof(SoftFallback.UnsignedShrImm64));
417  
418              return context.Call(info, value, Const(roundConst), Const(shift));
419          }
420  
421          private static Operand EmitSatQ(ArmEmitterContext context, Operand value, int eSize, bool signedSrc, bool signedDst)
422          {
423              Debug.Assert(eSize <= 32);
424  
425              long intMin = signedDst ? -(1L << (eSize - 1)) : 0;
426              long intMax = signedDst ? (1L << (eSize - 1)) - 1 : (1L << eSize) - 1;
427  
428              Operand gt = signedSrc
429                  ? context.ICompareGreater(value, Const(value.Type, intMax))
430                  : context.ICompareGreaterUI(value, Const(value.Type, intMax));
431  
432              Operand lt = signedSrc
433                  ? context.ICompareLess(value, Const(value.Type, intMin))
434                  : context.ICompareLessUI(value, Const(value.Type, intMin));
435  
436              value = context.ConditionalSelect(gt, Const(value.Type, intMax), value);
437              value = context.ConditionalSelect(lt, Const(value.Type, intMin), value);
438  
439              Operand lblNoSat = Label();
440  
441              context.BranchIfFalse(lblNoSat, context.BitwiseOr(gt, lt));
442  
443              SetFpFlag(context, FPState.QcFlag, Const(1));
444  
445              context.MarkLabel(lblNoSat);
446  
447              return value;
448          }
449      }
450  }