test/func_blocks/fu/test_shift_unit.py
test_shift_unit.py
 1  from coreblocks.arch import Funct3, Funct7, OpType
 2  from coreblocks.func_blocks.fu.shift_unit import ShiftUnitFn, ShiftUnitComponent
 3  
 4  from test.func_blocks.fu.functional_common import ExecFn, FunctionalUnitTestCase
 5  
 6  
 7  class TestShiftUnit(FunctionalUnitTestCase[ShiftUnitFn.Fn]):
 8      func_unit = ShiftUnitComponent(zbb_enable=True)
 9      zero_imm = False
10  
11      ops = {
12          ShiftUnitFn.Fn.SLL: ExecFn(OpType.SHIFT, Funct3.SLL),
13          ShiftUnitFn.Fn.SRL: ExecFn(OpType.SHIFT, Funct3.SR, Funct7.SL),
14          ShiftUnitFn.Fn.SRA: ExecFn(OpType.SHIFT, Funct3.SR, Funct7.SA),
15          ShiftUnitFn.Fn.ROL: ExecFn(OpType.BIT_ROTATION, Funct3.ROL, Funct7.ROL),
16          ShiftUnitFn.Fn.ROR: ExecFn(OpType.BIT_ROTATION, Funct3.ROR, Funct7.ROR),
17      }
18  
19      @staticmethod
20      def compute_result(i1: int, i2: int, i_imm: int, pc: int, fn: ShiftUnitFn.Fn, xlen: int) -> dict[str, int]:
21          val2 = i_imm if i_imm else i2
22  
23          mask = (1 << xlen) - 1
24          res = 0
25          shamt = val2 & (xlen - 1)
26  
27          match fn:
28              case ShiftUnitFn.Fn.SLL:
29                  res = i1 << shamt
30              case ShiftUnitFn.Fn.SRA:
31                  if i1 & 2 ** (xlen - 1) != 0:
32                      res = (((1 << xlen) - 1) << xlen | i1) >> shamt
33                  else:
34                      res = i1 >> shamt
35              case ShiftUnitFn.Fn.SRL:
36                  res = i1 >> shamt
37              case ShiftUnitFn.Fn.ROR:
38                  res = (i1 >> shamt) | (i1 << (xlen - shamt))
39              case ShiftUnitFn.Fn.ROL:
40                  res = (i1 << shamt) | (i1 >> (xlen - shamt))
41          return {"result": res & mask}
42  
43      def test_fu(self):
44          self.run_standard_fu_test()
45  
46      def test_pipeline(self):
47          self.run_standard_fu_test(pipeline_test=True)