functional_common.py
from dataclasses import asdict, dataclass
from itertools import product
import random
import pytest
from collections import deque
from typing import Generic, TypeVar

from amaranth import Elaboratable, Signal

from coreblocks.params import GenParams
from coreblocks.params.configurations import test_core_config
from coreblocks.priv.csr.csr_instances import CSRInstances
from transactron.testing.functions import data_const_to_dict
from transactron.utils.dependencies import DependencyContext
from coreblocks.params.fu_params import FunctionalComponentParams
from coreblocks.arch import Funct3, Funct7
from coreblocks.interface.keys import AsyncInterruptInsertSignalKey, ExceptionReportKey, CSRInstancesKey
from coreblocks.interface.layouts import ExceptionRegisterLayouts
from coreblocks.arch.optypes import OpType
from transactron.lib import Adapter
from transactron.testing import (
    RecordIntDict,
    TestbenchIO,
    TestCaseWithSimulator,
    SimpleTestCircuit,
    ProcessContext,
    TestbenchContext,
)
from transactron.utils import ModuleConnector


class FunctionalTestCircuit(Elaboratable):
    """
    Common circuit for testing functional modules which are using @see{FuncUnitLayouts}.

    Parameters
    ----------
    gen: GenParams
        Core generation parameters.
    func_unit : FunctionalComponentParams
        Class of functional unit to be tested.
    """

    # NOTE(review): this class defines no __init__ or elaborate here, so the
    # parameters documented above are not actually accepted. Confirm whether it
    # is still used anywhere or can be removed.


@dataclass
class ExecFn:
    """Operation descriptor sent (via `asdict`) as the 'exec_fn' field of an issue request."""

    op_type: OpType
    funct3: Funct3 = Funct3.ADD
    funct7: Funct7 = Funct7.ADD


_T = TypeVar("_T")


class FunctionalUnitTestCase(TestCaseWithSimulator, Generic[_T]):
    """
    Common test unit for testing functional modules which are using @see{FuncUnitLayouts}.
    For example of usage see @see{MultiplierUnitTest}.

    Attributes
    ----------
    ops: dict[_T, ExecFn]
        Mapping from each operation under test to the `ExecFn` sent with its requests.
    func_unit: FunctionalComponentParams
        Unit parameters for the unit instantiated.
    number_of_tests: int
        Number of random tests to be performed per operation.
    seed: int
        Seed for generating random tests.
    zero_imm: bool
        If True, each request carries its random 2nd operand either in 's2_val'
        (with 'imm' set to 0) or in 'imm' (with 's2_val' set to 0); the choice is
        random. If False, 's2_val' and 'imm' always carry independent random values.
    core_config: CoreConfiguration
        Core configuration used to build the generation parameters.
        Subclasses may override it.
    """

    ops: dict[_T, ExecFn]
    func_unit: FunctionalComponentParams
    number_of_tests = 50
    seed = 40
    zero_imm = True
    core_config = test_core_config

    @staticmethod
    def compute_result(i1: int, i2: int, i_imm: int, pc: int, fn: _T, xlen: int) -> dict[str, int]:
        """
        Computes expected results.

        Parameters
        ----------
        i1: int
            First argument value.
        i2: int
            Second argument value.
        i_imm: int
            Immediate value.
        pc: int
            Program counter value.
        fn: _T
            Function to execute.
        xlen: int
            Architecture bit width.

        Returns
        -------
        dict[str, int]
            Expected fields of the unit's response. If the key 'exception' is
            present, its value is the expected exception cause. The optional
            keys 'exception_pc' (defaults to `pc`) and 'mtval' (defaults to 0)
            complete the expected exception report. These three keys are moved
            out of the expected response into the expected report.
        """
        raise NotImplementedError

    @pytest.fixture(autouse=True)
    def setup(self, fixture_initialize_testing_env):
        # Use the class attribute (not the module default) so that subclasses
        # overriding `core_config` are actually tested with their configuration.
        self.gen_params = GenParams(self.core_config)

        self.report_mock = TestbenchIO(Adapter(i=self.gen_params.get(ExceptionRegisterLayouts).report))
        self.csrs = CSRInstances(self.gen_params)

        DependencyContext.get().add_dependency(ExceptionReportKey(), lambda: self.report_mock.adapter.iface)
        DependencyContext.get().add_dependency(AsyncInterruptInsertSignalKey(), Signal())
        DependencyContext.get().add_dependency(CSRInstancesKey(), self.csrs)

        self.m = SimpleTestCircuit(self.func_unit.get_module(self.gen_params), exclude=["increment_counter"])
        self.circ = ModuleConnector(dut=self.m, report_mock=self.report_mock, csrs=self.csrs)

        random.seed(self.seed)
        self.requests = deque[RecordIntDict]()
        self.responses = deque[RecordIntDict]()
        self.exceptions = deque[RecordIntDict]()

        max_int = 2**self.gen_params.isa.xlen - 1
        functions = list(self.ops)

        for op, _ in product(functions, range(self.number_of_tests)):
            data1 = random.randint(0, max_int)
            data2 = random.randint(0, max_int)
            data_imm = random.randint(0, max_int)
            data2_is_imm = random.randint(0, 1)
            rob_id = random.randint(0, 2**self.gen_params.rob_entries_bits - 1)
            rp_dst = random.randint(0, 2**self.gen_params.phys_regs_bits - 1)
            exec_fn = self.ops[op]
            # Clear the two lowest bits so the PC is 4-byte aligned.
            pc = random.randint(0, max_int) & ~0b11
            results = self.compute_result(data1, data2, data_imm, pc, op, self.gen_params.isa.xlen)

            self.requests.append(
                {
                    "s1_val": data1,
                    "s2_val": 0 if data2_is_imm and self.zero_imm else data2,
                    "rob_id": rob_id,
                    "exec_fn": asdict(exec_fn),
                    "rp_dst": rp_dst,
                    # With zero_imm, the 2nd operand goes to 'imm' when chosen as immediate
                    # (and 'imm' is 0 otherwise); without it, 'imm' is an independent value.
                    "imm": data_imm if not self.zero_imm else data2 if data2_is_imm else 0,
                    "pc": pc,
                }
            )

            cause = None
            if "exception" in results:
                cause = results["exception"]
                self.exceptions.append(
                    {
                        "rob_id": rob_id,
                        "cause": cause,
                        "pc": results.setdefault("exception_pc", pc),
                        "mtval": results.setdefault("mtval", 0),
                    }
                )

                # These keys belong to the exception report, not to the unit's response.
                results.pop("exception")
                results.pop("exception_pc")
                results.pop("mtval")

            self.responses.append({"rob_id": rob_id, "rp_dst": rp_dst, "exception": int(cause is not None)} | results)

    async def consumer(self, sim: TestbenchContext):
        """Checks every response from `push_result` against the expected responses."""
        # `requests` and `responses` are both consumed with pop(), so the
        # corresponding entries stay paired.
        while self.responses:
            expected = self.responses.pop()
            result = await self.m.push_result.call(sim)
            assert expected == data_const_to_dict(result)
            await self.random_wait(sim, self.max_wait)

    async def producer(self, sim: TestbenchContext):
        """Issues every generated request to the unit under test."""
        while self.requests:
            req = self.requests.pop()
            await self.m.issue.call(sim, req)
            await self.random_wait(sim, self.max_wait)

    async def exception_consumer(self, sim: TestbenchContext):
        """Checks the expected exception reports, then fails on any extra report."""
        # This is a background testbench so that extra calls can be detected reliably
        with sim.critical():
            while self.exceptions:
                expected = self.exceptions.pop()
                result = await self.report_mock.call(sim)
                assert expected == data_const_to_dict(result)
                await self.random_wait(sim, self.max_wait)

        # Keep partially dependent tests from hanging up and detect extra calls:
        # reaching the assert means the unit reported an unexpected exception.
        await self.report_mock.call(sim)
        assert False, "unexpected report call"

    async def pipeline_verifier(self, sim: ProcessContext):
        """Checks on every tick that `issue` is ready and that en == done."""
        async for *_, ready, en, done in sim.tick().sample(
            self.m.issue.adapter.iface.ready, self.m.issue.adapter.en, self.m.issue.adapter.done
        ):
            assert ready
            assert en == done

    def run_standard_fu_test(self, pipeline_test=False):
        """
        Runs the producer, consumer and exception checker against the unit.

        Parameters
        ----------
        pipeline_test: bool
            If True, no random waits are inserted between calls, and
            `pipeline_verifier` additionally checks issue readiness every cycle.
        """
        if pipeline_test:
            self.max_wait = 0
        else:
            self.max_wait = 10

        with self.run_simulation(self.circ) as sim:
            sim.add_testbench(self.producer)
            sim.add_testbench(self.consumer)
            sim.add_testbench(self.exception_consumer, background=True)
            if pipeline_test:
                sim.add_process(self.pipeline_verifier)