"""Tests for the frontend ``Decode`` stage (test_decode_stage.py)."""

from typing import Optional

import pytest
from transactron.testing import (
    TestCaseWithSimulator,
    SimpleTestCircuit,
    TestbenchContext,
    data_const_to_dict,
)

from coreblocks.frontend.decoder.decode_stage import Decode
from coreblocks.params import GenParams
from coreblocks.arch import OpType, Funct3, Funct7
from coreblocks.params.configurations import test_core_config


def mk_test(
    op_type: OpType,
    funct3: Funct3 = Funct3(0),
    funct7: Funct7 = Funct7(0),
    rl_dst: int = 0,
    rl_s1: int = 0,
    rl_s2: int = 0,
    imm: Optional[int] = None,
    csr: Optional[int] = None,
) -> dict:
    """Build the expected ``Decode`` output for one test case.

    The returned dict has the same nested layout that
    ``TestDecode.decode_proc`` compares against ``data_const_to_dict(decoded)``.

    Args:
        op_type: Expected ``exec_fn.op_type``.
        funct3: Expected ``exec_fn.funct3`` (defaults to ``Funct3(0)``).
        funct7: Expected ``exec_fn.funct7`` (defaults to ``Funct7(0)``).
        rl_dst: Expected ``regs_l.rl_dst``.
        rl_s1: Expected ``regs_l.rl_s1``.
        rl_s2: Expected ``regs_l.rl_s2``.
        imm: Expected ``imm``, or ``None`` to accept whatever the decoder
            produces.
        csr: Expected ``csr``, or ``None`` to accept whatever the decoder
            produces.

    Returns:
        A dict with ``exec_fn``, ``regs_l``, ``imm``, ``pc`` and ``csr`` keys;
        ``pc`` is always expected to be 0.
    """
    return {
        "exec_fn": {"op_type": op_type, "funct3": funct3, "funct7": funct7},
        "regs_l": {"rl_dst": rl_dst, "rl_s1": rl_s1, "rl_s2": rl_s2},
        "imm": imm,
        "pc": 0,
        "csr": csr,
    }


# Each entry is (instruction word, access_fault flag, expected decoder output).
tests = [
    # addi x4, x5, 42
    (0x02A28213, 0, mk_test(op_type=OpType.ARITHMETIC, funct3=Funct3.ADD, rl_dst=4, rl_s1=5, imm=42, csr=42)),
    # add x1, x2, x3
    (
        0x003100B3,
        0,
        mk_test(op_type=OpType.ARITHMETIC, funct3=Funct3.ADD, funct7=Funct7.ADD, rl_dst=1, rl_s1=2, rl_s2=3),
    ),
    # All-zero word: expected to decode as an illegal-instruction exception.
    (0x00000000, 0, mk_test(op_type=OpType.EXCEPTION, funct3=Funct3._EILLEGALINSTR)),
    # Same addi word as above, but reported with an access fault.
    (0x02A28213, 1, mk_test(op_type=OpType.EXCEPTION, funct3=Funct3._EINSTRACCESSFAULT)),
]


class TestDecode(TestCaseWithSimulator):
    """Drive ``Decode`` through a ``SimpleTestCircuit`` and check every entry of ``tests``."""

    @pytest.fixture(autouse=True)
    def setup(self, fixture_initialize_testing_env):
        """Create the ``Decode`` stage under test (start_pc=24) wrapped in a ``SimpleTestCircuit``."""
        self.gen_params = GenParams(test_core_config.replace(start_pc=24))

        # NOTE(review): the second argument is a stub callable returning an
        # empty tuple; confirm what Decode expects it to provide.
        self.decode = Decode(self.gen_params, lambda m: ())

        self.m = SimpleTestCircuit(self.decode)

    async def decode_proc(self, sim: TestbenchContext):
        """Call the circuit's ``decode`` method for each test case and compare the result.

        Expected ``imm``/``csr`` values of ``None`` act as wildcards: they are
        replaced by the decoder's own output before the comparison.
        """
        for instr, access_fault, data in tests:
            decoded = await self.m.decode.call(sim, instr=instr, access_fault=access_fault)
            # Work on a copy so the shared module-level expectations are never mutated.
            expected = dict(data)
            if expected["csr"] is None:
                expected["csr"] = decoded.csr
            if expected["imm"] is None:
                expected["imm"] = decoded.imm
            assert (
                data_const_to_dict(decoded) == expected
            ), f"decode mismatch for instr={instr:#010x}, access_fault={access_fault}"

    def test(self):
        """Run ``decode_proc`` as a testbench against the simulated circuit."""
        with self.run_simulation(self.m) as sim:
            sim.add_testbench(self.decode_proc)