test/func_blocks/fu/functional_common.py
functional_common.py
  1  from dataclasses import asdict, dataclass
  2  from itertools import product
  3  import random
  4  import pytest
  5  from collections import deque
  6  from typing import Generic, TypeVar
  7  
  8  from amaranth import Elaboratable, Signal
  9  
 10  from coreblocks.params import GenParams
 11  from coreblocks.params.configurations import test_core_config
 12  from coreblocks.priv.csr.csr_instances import CSRInstances
 13  from transactron.testing.functions import data_const_to_dict
 14  from transactron.utils.dependencies import DependencyContext
 15  from coreblocks.params.fu_params import FunctionalComponentParams
 16  from coreblocks.arch import Funct3, Funct7
 17  from coreblocks.interface.keys import AsyncInterruptInsertSignalKey, ExceptionReportKey, CSRInstancesKey
 18  from coreblocks.interface.layouts import ExceptionRegisterLayouts
 19  from coreblocks.arch.optypes import OpType
 20  from transactron.lib import Adapter
 21  from transactron.testing import (
 22      RecordIntDict,
 23      TestbenchIO,
 24      TestCaseWithSimulator,
 25      SimpleTestCircuit,
 26      ProcessContext,
 27      TestbenchContext,
 28  )
 29  from transactron.utils import ModuleConnector
 30  
 31  
 32  class FunctionalTestCircuit(Elaboratable):
 33      """
 34      Common circuit for testing functional modules which are using @see{FuncUnitLayouts}.
 35  
 36      Parameters
 37      ----------
 38      gen: GenParams
 39          Core generation parameters.
 40      func_unit : FunctionalComponentParams
 41          Class of functional unit to be tested.
 42      """
 43  
 44  
 45  @dataclass
 46  class ExecFn:
 47      op_type: OpType
 48      funct3: Funct3 = Funct3.ADD
 49      funct7: Funct7 = Funct7.ADD
 50  
 51  
 52  _T = TypeVar("_T")
 53  
 54  
 55  class FunctionalUnitTestCase(TestCaseWithSimulator, Generic[_T]):
 56      """
 57      Common test unit for testing functional modules which are using @see{FuncUnitLayouts}.
 58      For example of usage see @see{MultiplierUnitTest}.
 59  
 60      Attributes
 61      ----------
 62      operations: dict[_T, ExecFn]
 63          List of operations performed by this unit.
 64      func_unit: FunctionalComponentParams
 65          Unit parameters for the unit instantiated.
 66      number_of_tests: int
 67          Number of random tests to be performed per operation.
 68      seed: int
 69          Seed for generating random tests.
 70      zero_imm: bool
 71          Whether to set 'imm' to 0 or not in case 2nd operand comes from 's2_val'
 72      core_config: CoreConfiguration
 73          Core generation parameters.
 74      """
 75  
 76      ops: dict[_T, ExecFn]
 77      func_unit: FunctionalComponentParams
 78      number_of_tests = 50
 79      seed = 40
 80      zero_imm = True
 81      core_config = test_core_config
 82  
 83      @staticmethod
 84      def compute_result(i1: int, i2: int, i_imm: int, pc: int, fn: _T, xlen: int) -> dict[str, int]:
 85          """
 86          Computes expected results.
 87  
 88          Parameters
 89          ----------
 90          i1: int
 91              First argument value.
 92          i2: int
 93              Second argument value.
 94          i_imm: int
 95              Immediate value.
 96          pc: int
 97              Program counter value.
 98          fn: _T
 99              Function to execute.
100          xlen: int
101              Architecture bit width.
102          """
103          raise NotImplementedError
104  
105      @pytest.fixture(autouse=True)
106      def setup(self, fixture_initialize_testing_env):
107          self.gen_params = GenParams(test_core_config)
108  
109          self.report_mock = TestbenchIO(Adapter(i=self.gen_params.get(ExceptionRegisterLayouts).report))
110          self.csrs = CSRInstances(self.gen_params)
111  
112          DependencyContext.get().add_dependency(ExceptionReportKey(), lambda: self.report_mock.adapter.iface)
113          DependencyContext.get().add_dependency(AsyncInterruptInsertSignalKey(), Signal())
114          DependencyContext.get().add_dependency(CSRInstancesKey(), self.csrs)
115  
116          self.m = SimpleTestCircuit(self.func_unit.get_module(self.gen_params), exclude=["increment_counter"])
117          self.circ = ModuleConnector(dut=self.m, report_mock=self.report_mock, csrs=self.csrs)
118  
119          random.seed(self.seed)
120          self.requests = deque[RecordIntDict]()
121          self.responses = deque[RecordIntDict]()
122          self.exceptions = deque[RecordIntDict]()
123  
124          max_int = 2**self.gen_params.isa.xlen - 1
125          functions = list(self.ops.keys())
126  
127          for op, _ in product(functions, range(self.number_of_tests)):
128              data1 = random.randint(0, max_int)
129              data2 = random.randint(0, max_int)
130              data_imm = random.randint(0, max_int)
131              data2_is_imm = random.randint(0, 1)
132              rob_id = random.randint(0, 2**self.gen_params.rob_entries_bits - 1)
133              rp_dst = random.randint(0, 2**self.gen_params.phys_regs_bits - 1)
134              exec_fn = self.ops[op]
135              pc = random.randint(0, max_int) & ~0b11
136              results = self.compute_result(data1, data2, data_imm, pc, op, self.gen_params.isa.xlen)
137  
138              self.requests.append(
139                  {
140                      "s1_val": data1,
141                      "s2_val": 0 if data2_is_imm and self.zero_imm else data2,
142                      "rob_id": rob_id,
143                      "exec_fn": asdict(exec_fn),
144                      "rp_dst": rp_dst,
145                      "imm": data_imm if not self.zero_imm else data2 if data2_is_imm else 0,
146                      "pc": pc,
147                  }
148              )
149  
150              cause = None
151              if "exception" in results:
152                  cause = results["exception"]
153                  self.exceptions.append(
154                      {
155                          "rob_id": rob_id,
156                          "cause": cause,
157                          "pc": results.setdefault("exception_pc", pc),
158                          "mtval": results.setdefault("mtval", 0),
159                      }
160                  )
161  
162                  results.pop("exception")
163                  results.pop("exception_pc")
164                  results.pop("mtval")
165  
166              self.responses.append({"rob_id": rob_id, "rp_dst": rp_dst, "exception": int(cause is not None)} | results)
167  
168      async def consumer(self, sim: TestbenchContext):
169          while self.responses:
170              expected = self.responses.pop()
171              result = await self.m.push_result.call(sim)
172              assert expected == data_const_to_dict(result)
173              await self.random_wait(sim, self.max_wait)
174  
175      async def producer(self, sim: TestbenchContext):
176          while self.requests:
177              req = self.requests.pop()
178              await self.m.issue.call(sim, req)
179              await self.random_wait(sim, self.max_wait)
180  
181      async def exception_consumer(self, sim: TestbenchContext):
182          # This is a background testbench so that extra calls can be detected reliably
183          with sim.critical():
184              while self.exceptions:
185                  expected = self.exceptions.pop()
186                  result = await self.report_mock.call(sim)
187                  assert expected == data_const_to_dict(result)
188                  await self.random_wait(sim, self.max_wait)
189  
190          # keep partialy dependent tests from hanging up and detect extra calls
191          result = await self.report_mock.call(sim)
192          assert not True, "unexpected report call"
193  
194      async def pipeline_verifier(self, sim: ProcessContext):
195          async for *_, ready, en, done in sim.tick().sample(
196              self.m.issue.adapter.iface.ready, self.m.issue.adapter.en, self.m.issue.adapter.done
197          ):
198              assert ready
199              assert en == done
200  
201      def run_standard_fu_test(self, pipeline_test=False):
202          if pipeline_test:
203              self.max_wait = 0
204          else:
205              self.max_wait = 10
206  
207          with self.run_simulation(self.circ) as sim:
208              sim.add_testbench(self.producer)
209              sim.add_testbench(self.consumer)
210              sim.add_testbench(self.exception_consumer, background=True)
211              if pipeline_test:
212                  sim.add_process(self.pipeline_verifier)