test/func_blocks/csr/test_csr.py
test_csr.py
  1  from amaranth import *
  2  import random
  3  
  4  from transactron.lib import Adapter
  5  from transactron.core.tmodule import TModule
  6  from coreblocks.func_blocks.csr.csr_unit import CSRUnit
  7  from coreblocks.priv.csr.csr_register import CSRRegister
  8  from coreblocks.priv.csr.csr_instances import CSRInstances
  9  from coreblocks.params import GenParams
 10  from coreblocks.arch import Funct3, ExceptionCause, OpType, CSRAddress
 11  from coreblocks.params.configurations import test_core_config
 12  from coreblocks.interface.layouts import ExceptionRegisterLayouts, RetirementLayouts, FetchLayouts
 13  from coreblocks.interface.keys import (
 14      AsyncInterruptInsertSignalKey,
 15      UnsafeInstructionResolvedKey,
 16      ExceptionReportKey,
 17      InstructionPrecommitKey,
 18      CSRInstancesKey,
 19  )
 20  from coreblocks.arch.isa_consts import PrivilegeLevel
 21  from transactron.lib.adapters import AdapterTrans
 22  from transactron.utils.dependencies import DependencyContext
 23  
 24  from transactron.testing import *
 25  
 26  
 27  class CSRUnitTestCircuit(Elaboratable):
 28      def __init__(self, gen_params: GenParams, csr_count: int, only_legal=True):
 29          self.gen_params = gen_params
 30          self.csr_count = csr_count
 31          self.only_legal = only_legal
 32  
 33      def elaborate(self, platform):
 34          m = Module()
 35  
 36          m.submodules.precommit = self.precommit = TestbenchIO(
 37              Adapter(
 38                  i=self.gen_params.get(RetirementLayouts).precommit_in,
 39                  o=self.gen_params.get(RetirementLayouts).precommit_out,
 40                  nonexclusive=True,
 41                  combiner=lambda m, args, runs: args[0],
 42              ).set(with_validate_arguments=True)
 43          )
 44          m.submodules.exception_report = self.exception_report = TestbenchIO(
 45              Adapter(i=self.gen_params.get(ExceptionRegisterLayouts).report)
 46          )
 47          DependencyContext.get().add_dependency(InstructionPrecommitKey(), self.precommit.adapter.iface)
 48          DependencyContext.get().add_dependency(ExceptionReportKey(), lambda: self.exception_report.adapter.iface)
 49  
 50          m.submodules.dut = self.dut = CSRUnit(self.gen_params)
 51  
 52          m.submodules.select = self.select = TestbenchIO(AdapterTrans.create(self.dut.select))
 53          m.submodules.insert = self.insert = TestbenchIO(AdapterTrans.create(self.dut.insert))
 54          m.submodules.update = self.update = TestbenchIO(AdapterTrans.create(self.dut.update[0]))
 55          m.submodules.accept = self.accept = TestbenchIO(AdapterTrans.create(self.dut.get_result))
 56          m.submodules.fetch_resume = self.fetch_resume = TestbenchIO(Adapter(i=self.gen_params.get(FetchLayouts).resume))
 57          m.submodules.csr_instances = self.csr_instances = CSRInstances(self.gen_params)
 58          m.submodules.priv_io = self.priv_io = TestbenchIO(
 59              AdapterTrans.create(self.csr_instances.m_mode.priv_mode.write)
 60          )
 61          m.submodules.mcounteren_io = self.mcounteren_io = TestbenchIO(
 62              AdapterTrans.create(self.csr_instances.m_mode.mcounteren.write)
 63          )
 64          if self.gen_params.supervisor_mode:
 65              m.submodules.scounteren_io = self.scounteren_io = TestbenchIO(
 66                  AdapterTrans.create(self.csr_instances.s_mode.scounteren.write)
 67              )
 68          DependencyContext.get().add_dependency(AsyncInterruptInsertSignalKey(), Signal())
 69          DependencyContext.get().add_dependency(CSRInstancesKey(), self.csr_instances)
 70          DependencyContext.get().add_dependency(UnsafeInstructionResolvedKey(), self.fetch_resume.adapter.iface)
 71  
 72          self.csr = {}
 73  
 74          def make_csr(number: int):
 75              csr = CSRRegister(csr_number=number, gen_params=self.gen_params)
 76              self.csr[number] = csr
 77              m.submodules += csr
 78  
 79          # simple test not using external r/w functionality of csr
 80          for i in range(self.csr_count):
 81              make_csr(i)
 82  
 83          if not self.only_legal:
 84              make_csr(0xCC0)  # read-only csr
 85              make_csr(0x7FE)  # machine mode only
 86  
 87          return m
 88  
 89  
 90  class TestCSRUnit(TestCaseWithSimulator):
 91      def gen_expected_out(self, sim: TestbenchContext, op: Funct3, rd: int, rs1: int, operand_val: int, csr: int):
 92          exp_read = {"rp_dst": rd, "result": sim.get(self.dut.csr[csr].value)}
 93          rs1_val = {"rp_s1": rs1, "value": operand_val}
 94  
 95          exp_write = {}
 96          if op == Funct3.CSRRW or op == Funct3.CSRRWI:
 97              exp_write = {"csr": csr, "value": operand_val}
 98          elif (op == Funct3.CSRRC and rs1) or op == Funct3.CSRRCI:
 99              exp_write = {"csr": csr, "value": exp_read["result"] & ~operand_val}
100          elif (op == Funct3.CSRRS and rs1) or op == Funct3.CSRRSI:
101              exp_write = {"csr": csr, "value": exp_read["result"] | operand_val}
102          else:
103              exp_write = {"csr": csr, "value": sim.get(self.dut.csr[csr].value)}
104  
105          return {"exp_read": exp_read, "exp_write": exp_write, "rs1": rs1_val}
106  
107      def generate_instruction(self, sim: TestbenchContext):
108          ops = [
109              Funct3.CSRRW,
110              Funct3.CSRRC,
111              Funct3.CSRRS,
112              Funct3.CSRRWI,
113              Funct3.CSRRCI,
114              Funct3.CSRRSI,
115          ]
116  
117          op = random.choice(ops)
118          imm_op = op == Funct3.CSRRWI or op == Funct3.CSRRCI or op == Funct3.CSRRSI
119  
120          rd = random.randint(0, 15)
121          rs1 = 0 if imm_op else random.randint(0, 15)
122          imm = random.randint(0, 2**5 - 1)
123          rs1_val = random.randint(0, 2**self.gen_params.isa.xlen - 1) if rs1 else 0
124          operand_val = imm if imm_op else rs1_val
125          csr = random.choice(list(self.dut.csr.keys()))
126  
127          exp = self.gen_expected_out(sim, op, rd, rs1, operand_val, csr)
128  
129          value_available = random.random() < 0.2
130  
131          return {
132              "instr": {
133                  "exec_fn": {"op_type": OpType.CSR_IMM if imm_op else OpType.CSR_REG, "funct3": op, "funct7": 0},
134                  "rp_s1": 0 if value_available or imm_op else rs1,
135                  "rp_s1_reg": rs1,
136                  "s1_val": exp["rs1"]["value"] if value_available and not imm_op else 0,
137                  "rp_dst": rd,
138                  "imm": imm,
139                  "csr": csr,
140              },
141              "exp": exp,
142          }
143  
144      async def process_test(self, sim: TestbenchContext):
145          self.dut.fetch_resume.enable(sim)
146          self.dut.exception_report.enable(sim)
147          for _ in range(self.cycles):
148              await self.random_wait_geom(sim)
149  
150              op = self.generate_instruction(sim)
151  
152              await self.dut.select.call(sim)
153  
154              await self.dut.insert.call(sim, rs_data=op["instr"])
155  
156              await self.random_wait_geom(sim)
157              if op["exp"]["rs1"]["rp_s1"]:
158                  await self.dut.update.call(sim, reg_id=op["exp"]["rs1"]["rp_s1"], reg_val=op["exp"]["rs1"]["value"])
159  
160              await self.random_wait_geom(sim)
161              # TODO: this is a hack, a real method mock should be used
162              for _, r in self.dut.precommit.adapter.validators:  # type: ignore
163                  sim.set(r, 1)
164              self.dut.precommit.call_init(sim, side_fx=1)  # TODO: sensible precommit handling
165  
166              await self.random_wait_geom(sim)
167              res, resume_res = await CallTrigger(sim).call(self.dut.accept).sample(self.dut.fetch_resume).until_done()
168              self.dut.precommit.disable(sim)
169  
170              assert res is not None and resume_res is not None
171              assert res.rp_dst == op["exp"]["exp_read"]["rp_dst"]
172              if op["exp"]["exp_read"]["rp_dst"]:
173                  assert res.result == op["exp"]["exp_read"]["result"]
174              assert sim.get(self.dut.csr[op["exp"]["exp_write"]["csr"]].value) == op["exp"]["exp_write"]["value"]
175              assert res.exception == 0
176  
177      def test_randomized(self):
178          self.gen_params = GenParams(test_core_config)
179          random.seed(8)
180  
181          self.cycles = 256
182          self.csr_count = 16
183  
184          self.dut = CSRUnitTestCircuit(self.gen_params, self.csr_count)
185  
186          with self.run_simulation(self.dut) as sim:
187              sim.add_testbench(self.process_test)
188  
189      exception_csr_numbers = [
190          0xCC0,  # read_only
191          0xFFF,  # nonexistent
192          0x7FE,  # missing priv
193      ]
194  
195      counteren_exception_cases = [
196          {
197              "priv": PrivilegeLevel.SUPERVISOR,
198              "csr": CSRAddress.CYCLE,
199              "mcounteren": 0b000,
200              "scounteren": 0b111,
201              "expect_exception": True,
202          },
203          {
204              "priv": PrivilegeLevel.USER,
205              "csr": CSRAddress.TIME,
206              "mcounteren": 0b010,
207              "scounteren": 0b000,
208              "expect_exception": True,
209          },
210          {
211              "priv": PrivilegeLevel.USER,
212              "csr": CSRAddress.CYCLE,
213              "mcounteren": 0b001,
214              "scounteren": 0b001,
215              "expect_exception": False,
216          },
217          {
218              "priv": PrivilegeLevel.MACHINE,
219              "csr": CSRAddress.CYCLE,
220              "mcounteren": 0b000,
221              "scounteren": 0b000,
222              "expect_exception": False,
223          },
224      ]
225  
226      async def process_exception_test(self, sim: TestbenchContext):
227          self.dut.fetch_resume.enable(sim)
228          self.dut.exception_report.enable(sim)
229          for csr in self.exception_csr_numbers:
230              if csr == 0x7FE:
231                  await self.dut.priv_io.call(sim, data=PrivilegeLevel.USER)
232              else:
233                  await self.dut.priv_io.call(sim, data=PrivilegeLevel.MACHINE)
234  
235              await self.random_wait_geom(sim)
236  
237              await self.dut.select.call(sim)
238  
239              rob_id = random.randrange(2**self.gen_params.rob_entries_bits)
240              await self.dut.insert.call(
241                  sim,
242                  rs_data={
243                      "exec_fn": {"op_type": OpType.CSR_REG, "funct3": Funct3.CSRRW, "funct7": 0},
244                      "rp_s1": 0,
245                      "rp_s1_reg": 1,
246                      "s1_val": 1,
247                      "rp_dst": 2,
248                      "imm": 0,
249                      "csr": csr,
250                      "rob_id": rob_id,
251                  },
252              )
253  
254              await self.random_wait_geom(sim)
255              # TODO: this is a hack, a real method mock should be used
256              for _, r in self.dut.precommit.adapter.validators:  # type: ignore
257                  sim.set(r, 1)
258              self.dut.precommit.call_init(sim, side_fx=1)
259  
260              await self.random_wait_geom(sim)
261              res, report = await CallTrigger(sim).call(self.dut.accept).sample(self.dut.exception_report).until_done()
262              self.dut.precommit.disable(sim)
263  
264              assert res["exception"] == 1
265              assert report is not None
266              report_dict = data_const_to_dict(report)
267              report_dict.pop("mtval")  # mtval tested in mtval.asm test
268              assert {"rob_id": rob_id, "cause": ExceptionCause.ILLEGAL_INSTRUCTION, "pc": 0} == report_dict
269  
270      def test_exception(self):
271          self.gen_params = GenParams(test_core_config)
272          random.seed(9)
273  
274          self.dut = CSRUnitTestCircuit(self.gen_params, 0, only_legal=False)
275  
276          with self.run_simulation(self.dut) as sim:
277              sim.add_testbench(self.process_exception_test)
278  
279      async def process_counteren_access_test(self, sim: TestbenchContext):
280          self.dut.fetch_resume.enable(sim)
281          self.dut.exception_report.enable(sim)
282  
283          for idx, case in enumerate(self.counteren_exception_cases):
284              await self.dut.priv_io.call(sim, data=case["priv"])
285              await self.dut.mcounteren_io.call(sim, data=case["mcounteren"])
286              if self.gen_params.supervisor_mode:
287                  await self.dut.scounteren_io.call(sim, data=case["scounteren"])
288  
289              await self.random_wait_geom(sim)
290              await self.dut.select.call(sim)
291  
292              rob_id = idx + 100
293              await self.dut.insert.call(
294                  sim,
295                  rs_data={
296                      "exec_fn": {
297                          "op_type": OpType.CSR_REG,
298                          "funct3": Funct3.CSRRS,
299                          "funct7": 0,
300                      },
301                      "rp_s1": 0,
302                      "rp_s1_reg": 0,
303                      "s1_val": 0,
304                      "rp_dst": 2,
305                      "imm": 0,
306                      "csr": case["csr"],
307                      "rob_id": rob_id,
308                  },
309              )
310  
311              await self.random_wait_geom(sim)
312              for _, r in self.dut.precommit.adapter.validators:  # type: ignore
313                  sim.set(r, 1)
314              self.dut.precommit.call_init(sim, side_fx=1)
315  
316              await self.random_wait_geom(sim)
317              res, report = await CallTrigger(sim).call(self.dut.accept).sample(self.dut.exception_report).until_done()
318              self.dut.precommit.disable(sim)
319  
320              assert res is not None
321              assert res.exception == int(case["expect_exception"])
322  
323              if case["expect_exception"]:
324                  assert report is not None
325                  report_dict = data_const_to_dict(report)
326                  report_dict.pop("mtval")
327                  assert {"rob_id": rob_id, "cause": ExceptionCause.ILLEGAL_INSTRUCTION, "pc": 0} == report_dict
328              else:
329                  assert report is None
330  
331      def test_counteren_access(self):
332          self.gen_params = GenParams(test_core_config.replace(supervisor_mode=True, user_mode=True))
333          random.seed(10)
334  
335          self.dut = CSRUnitTestCircuit(self.gen_params, 0, only_legal=False)
336  
337          with self.run_simulation(self.dut) as sim:
338              sim.add_testbench(self.process_counteren_access_test)
339  
340  
class TestCSRRegister(TestCaseWithSimulator):
    """Unit tests of `CSRRegister` wrapped in a `SimpleTestCircuit`."""

    async def randomized_process_test(self, sim: TestbenchContext):
        """Each cycle, randomly call ``write``, ``_fu_write`` and ``_fu_read``, then check the read paths.

        ``_fu_write`` takes priority over ``write`` but never changes the bits in ``self.ro_mask``.
        """
        # always enabled
        self.dut.read.enable(sim)

        previous_data = 0
        for _ in range(self.cycles):
            write = False
            fu_write = False
            fu_read = False
            exp_write_data = None

            if random.random() < 0.9:
                write = True
                exp_write_data = random.randint(0, 2**self.gen_params.isa.xlen - 1)
                self.dut.write.call_init(sim, data=exp_write_data)

            if random.random() < 0.3:
                fu_write = True
                # fu_write has priority over csr write, but it doesn't overwrite ro bits
                write_arg = random.randint(0, 2**self.gen_params.isa.xlen - 1)
                base = exp_write_data if exp_write_data is not None else previous_data
                exp_write_data = (write_arg & ~self.ro_mask) | (base & self.ro_mask)
                self.dut._fu_write.call_init(sim, data=write_arg)

            if random.random() < 0.2:
                fu_read = True
                self.dut._fu_read.call_init(sim)

            await sim.tick()

            exp_read_data = exp_write_data if fu_write or write else previous_data

            if fu_read:  # in CSRUnit this call is called before write and returns previous result
                assert data_const_to_dict(self.dut._fu_read.get_call_result(sim)) == {"data": exp_read_data}

            # fetch once and check for None before converting
            read_result = self.dut.read.get_call_result(sim)
            assert read_result is not None
            assert data_const_to_dict(read_result) == {
                "data": exp_read_data,
                "read": int(fu_read),
                "written": int(fu_write),
            }
            previous_data = read_result.data

            self.dut._fu_read.disable(sim)
            self.dut._fu_write.disable(sim)
            self.dut.write.disable(sim)

    def test_randomized(self):
        self.gen_params = GenParams(test_core_config)
        random.seed(42)

        self.cycles = 200
        self.ro_mask = 0b101

        self.dut = SimpleTestCircuit(CSRRegister(0, self.gen_params, ro_bits=self.ro_mask))

        with self.run_simulation(self.dut) as sim:
            sim.add_testbench(self.randomized_process_test)

    async def filtermap_process_test(self, sim: TestbenchContext):
        """Write random values through ``_fu_write`` and check ``_fu_read`` against a software model.

        The model mirrors the register built in ``test_filtermap``:
        - a write only happens when bit 0 is set;
        - bit 1 adds 3 to the written value;
        - bit 32 is read-only and modelled as 0;
        - reads return the value shifted left by one, truncated to 34 bits.
        """
        prev_value = 0
        for _ in range(50):
            write_val = random.randrange(0, 2**34)

            await self.dut._fu_write.call(sim, data=write_val)
            read_val = (await self.dut._fu_read.call(sim))["data"]

            expected = prev_value
            if write_val & 1:
                expected = write_val
                if write_val & 2:
                    expected += 3

                expected &= ~(2**32)

                expected <<= 1
                expected &= 2**34 - 1

            assert read_val == expected

            prev_value = read_val

    def test_filtermap(self):
        gen_params = GenParams(test_core_config)

        def write_filtermap(m: TModule, v: Value):
            # returns (write enable, value to write): write only when bit 0 is set, add 3 when bit 1 is set
            res = Signal(34)
            write = Signal()
            m.d.comb += res.eq(v)
            with m.If(v & 1):
                m.d.comb += write.eq(1)
            with m.If(v & 2):
                m.d.comb += res.eq(v + 3)
            return (write, res)

        random.seed(4325)

        self.dut = SimpleTestCircuit(
            CSRRegister(
                None,
                gen_params,
                width=34,
                ro_bits=(1 << 32),
                fu_read_map=lambda _, v: v << 1,
                fu_write_filtermap=write_filtermap,
            ),
        )

        with self.run_simulation(self.dut) as sim:
            sim.add_testbench(self.filtermap_process_test)

    async def comb_process_test(self, sim: TestbenchContext):
        """Cycle-exact checks of ``read_comb`` against ``read`` / ``_fu_read`` (register built in ``test_comb``)."""
        self.dut.read.enable(sim)
        self.dut.read_comb.enable(sim)
        self.dut._fu_read.enable(sim)

        # read_comb shows the fu_write data in the same cycle; _fu_read still returns init (0xAB)
        self.dut._fu_write.call_init(sim, data=0xFFFF)
        while self.dut._fu_write.get_call_result(sim) is None:
            await sim.tick()
        assert self.dut.read_comb.get_call_result(sim).data == 0xFFFF
        assert self.dut._fu_read.get_call_result(sim).data == 0xAB
        await sim.tick()
        # the low nibble is read-only (ro_bits=0b1111), so it keeps init's 0xB
        assert self.dut.read.get_call_result(sim)["data"] == 0xFFFB
        assert self.dut._fu_read.get_call_result(sim)["data"] == 0xFFFB
        await sim.tick()

        # fu_write and write in the same cycle; with fu_write_priority=False the plain write wins
        self.dut._fu_write.call_init(sim, data=0x0FFF)
        self.dut.write.call_init(sim, data=0xAAAA)
        while self.dut._fu_write.get_call_result(sim) is None or self.dut.write.get_call_result(sim) is None:
            await sim.tick()
        assert data_const_to_dict(self.dut.read_comb.get_call_result(sim)) == {"data": 0x0FFF, "read": 1, "written": 1}
        await sim.tick()
        assert self.dut._fu_read.get_call_result(sim).data == 0xAAAA
        await sim.tick()

        # single cycle
        self.dut._fu_write.call_init(sim, data=0x0BBB)
        while self.dut._fu_write.get_call_result(sim) is None:
            await sim.tick()
        update_val = self.dut.read_comb.get_call_result(sim).data | 0xD000
        self.dut.write.call_init(sim, data=update_val)
        while self.dut.write.get_call_result(sim) is None:
            await sim.tick()
        await sim.tick()
        assert self.dut._fu_read.get_call_result(sim).data == 0xDBBB

    def test_comb(self):
        gen_params = GenParams(test_core_config)

        random.seed(4326)

        self.dut = SimpleTestCircuit(CSRRegister(None, gen_params, ro_bits=0b1111, fu_write_priority=False, init=0xAB))

        with self.run_simulation(self.dut) as sim:
            sim.add_testbench(self.comb_process_test)