# test_shift_unit.py
from coreblocks.arch import Funct3, Funct7, OpType
from coreblocks.func_blocks.fu.shift_unit import ShiftUnitFn, ShiftUnitComponent

from test.func_blocks.fu.functional_common import ExecFn, FunctionalUnitTestCase


class TestShiftUnit(FunctionalUnitTestCase[ShiftUnitFn.Fn]):
    """Functional-unit tests for ``ShiftUnitComponent`` built with ``zbb_enable=True``.

    Every operation listed in ``ops`` (logical/arithmetic shifts plus the
    ROL/ROR rotations) is checked against the pure-Python reference model
    in ``compute_result``.
    """

    func_unit = ShiftUnitComponent(zbb_enable=True)
    # NOTE(review): presumably tells the shared test harness not to force the
    # immediate operand to zero -- confirm in FunctionalUnitTestCase.
    zero_imm = False

    # Maps each unit function to the (op type, funct3, funct7) encoding that
    # selects it. Entry order is kept stable in case the harness iterates it.
    ops = {
        ShiftUnitFn.Fn.SLL: ExecFn(OpType.SHIFT, Funct3.SLL),
        ShiftUnitFn.Fn.SRL: ExecFn(OpType.SHIFT, Funct3.SR, Funct7.SL),
        ShiftUnitFn.Fn.SRA: ExecFn(OpType.SHIFT, Funct3.SR, Funct7.SA),
        ShiftUnitFn.Fn.ROL: ExecFn(OpType.BIT_ROTATION, Funct3.ROL, Funct7.ROL),
        ShiftUnitFn.Fn.ROR: ExecFn(OpType.BIT_ROTATION, Funct3.ROR, Funct7.ROR),
    }

    @staticmethod
    def compute_result(i1: int, i2: int, i_imm: int, pc: int, fn: ShiftUnitFn.Fn, xlen: int) -> dict[str, int]:
        """Reference model for the shift unit.

        Args:
            i1: First source operand (the value being shifted/rotated).
            i2: Second source operand; used as the shift amount when ``i_imm`` is 0.
            i_imm: Immediate operand; when non-zero it replaces ``i2``.
            pc: Program counter; not used by any shift operation.
            fn: Which shift-unit function to model.
            xlen: Register width in bits.

        Returns:
            ``{"result": value}`` where ``value`` is truncated to ``xlen`` bits.
            Functions not handled below yield a result of 0.
        """
        amount_source = i2 if not i_imm else i_imm

        width_mask = (1 << xlen) - 1
        # Only the low bits of the operand select the shift distance
        # (assumes xlen is a power of two).
        amount = amount_source & (xlen - 1)

        if fn == ShiftUnitFn.Fn.SLL:
            value = i1 << amount
        elif fn == ShiftUnitFn.Fn.SRL:
            value = i1 >> amount
        elif fn == ShiftUnitFn.Fn.SRA:
            sign_bit = (i1 >> (xlen - 1)) & 1
            # Fill the xlen bits above the operand with ones when negative, so
            # Python's right shift pulls copies of the sign bit into the result.
            widened = (i1 | (width_mask << xlen)) if sign_bit else i1
            value = widened >> amount
        elif fn == ShiftUnitFn.Fn.ROL:
            value = (i1 << amount) | (i1 >> (xlen - amount))
        elif fn == ShiftUnitFn.Fn.ROR:
            # Bits shifted out on the right re-enter from the top; anything that
            # spills past xlen is dropped by the final mask.
            value = (i1 >> amount) | (i1 << (xlen - amount))
        else:
            value = 0

        return {"result": value & width_mask}

    def test_fu(self):
        """Run the harness's standard functional-unit test."""
        self.run_standard_fu_test()

    def test_pipeline(self):
        """Run the harness's standard test with ``pipeline_test=True``."""
        self.run_standard_fu_test(pipeline_test=True)