"""Tests for the CSR functional unit (``CSRUnit``) and for ``CSRRegister`` (test_csr.py)."""

import random

from amaranth import *

from transactron.lib import Adapter
from transactron.core.tmodule import TModule
from coreblocks.func_blocks.csr.csr_unit import CSRUnit
from coreblocks.priv.csr.csr_register import CSRRegister
from coreblocks.priv.csr.csr_instances import CSRInstances
from coreblocks.params import GenParams
from coreblocks.arch import Funct3, ExceptionCause, OpType, CSRAddress
from coreblocks.params.configurations import test_core_config
from coreblocks.interface.layouts import ExceptionRegisterLayouts, RetirementLayouts, FetchLayouts
from coreblocks.interface.keys import (
    AsyncInterruptInsertSignalKey,
    UnsafeInstructionResolvedKey,
    ExceptionReportKey,
    InstructionPrecommitKey,
    CSRInstancesKey,
)
from coreblocks.arch.isa_consts import PrivilegeLevel
from transactron.lib.adapters import AdapterTrans
from transactron.utils.dependencies import DependencyContext

from transactron.testing import *


class CSRUnitTestCircuit(Elaboratable):
    """Test harness around a ``CSRUnit``.

    Provides ``TestbenchIO`` mocks for the methods the unit depends on (precommit,
    exception report, fetch resume), ``TestbenchIO`` wrappers for the unit's own
    methods, and plain ``CSRRegister`` instances numbered ``0 .. csr_count - 1``.
    When ``only_legal`` is False, CSRs 0xCC0 and 0x7FE are added as well so that
    illegal accesses can be exercised.
    """

    def __init__(self, gen_params: GenParams, csr_count: int, only_legal=True):
        self.gen_params = gen_params
        self.csr_count = csr_count
        self.only_legal = only_legal

    def elaborate(self, platform):
        m = Module()

        m.submodules.precommit = self.precommit = TestbenchIO(
            Adapter(
                i=self.gen_params.get(RetirementLayouts).precommit_in,
                o=self.gen_params.get(RetirementLayouts).precommit_out,
                nonexclusive=True,
                combiner=lambda m, args, runs: args[0],
            ).set(with_validate_arguments=True)
        )
        m.submodules.exception_report = self.exception_report = TestbenchIO(
            Adapter(i=self.gen_params.get(ExceptionRegisterLayouts).report)
        )
        DependencyContext.get().add_dependency(InstructionPrecommitKey(), self.precommit.adapter.iface)
        DependencyContext.get().add_dependency(ExceptionReportKey(), lambda: self.exception_report.adapter.iface)

        m.submodules.dut = self.dut = CSRUnit(self.gen_params)

        m.submodules.select = self.select = TestbenchIO(AdapterTrans.create(self.dut.select))
        m.submodules.insert = self.insert = TestbenchIO(AdapterTrans.create(self.dut.insert))
        m.submodules.update = self.update = TestbenchIO(AdapterTrans.create(self.dut.update[0]))
        m.submodules.accept = self.accept = TestbenchIO(AdapterTrans.create(self.dut.get_result))
        m.submodules.fetch_resume = self.fetch_resume = TestbenchIO(Adapter(i=self.gen_params.get(FetchLayouts).resume))
        m.submodules.csr_instances = self.csr_instances = CSRInstances(self.gen_params)

        # Write ports used by tests to set up the privilege / counter-enable context.
        m.submodules.priv_io = self.priv_io = TestbenchIO(
            AdapterTrans.create(self.csr_instances.m_mode.priv_mode.write)
        )
        m.submodules.mcounteren_io = self.mcounteren_io = TestbenchIO(
            AdapterTrans.create(self.csr_instances.m_mode.mcounteren.write)
        )
        if self.gen_params.supervisor_mode:
            # scounteren only exists when the S-mode CSR instances are present.
            m.submodules.scounteren_io = self.scounteren_io = TestbenchIO(
                AdapterTrans.create(self.csr_instances.s_mode.scounteren.write)
            )

        DependencyContext.get().add_dependency(AsyncInterruptInsertSignalKey(), Signal())
        DependencyContext.get().add_dependency(CSRInstancesKey(), self.csr_instances)
        DependencyContext.get().add_dependency(UnsafeInstructionResolvedKey(), self.fetch_resume.adapter.iface)

        # CSR number -> CSRRegister; read by the testbenches to compute expected values.
        self.csr = {}

        def make_csr(number: int):
            csr = CSRRegister(csr_number=number, gen_params=self.gen_params)
            self.csr[number] = csr
            m.submodules += csr

        # simple test not using external r/w functionality of csr
        for i in range(self.csr_count):
            make_csr(i)

        if not self.only_legal:
            make_csr(0xCC0)  # read-only csr
            make_csr(0x7FE)  # machine mode only

        return m


class TestCSRUnit(TestCaseWithSimulator):
    """Randomized, illegal-access and counter-enable tests for ``CSRUnit``."""

    def gen_expected_out(self, sim: TestbenchContext, op: Funct3, rd: int, rs1: int, operand_val: int, csr: int):
        """Compute the expected outcome of one CSR instruction from the CSR's current value.

        Returns a dict with:
          - ``exp_read``: destination register and the value the instruction should read,
          - ``exp_write``: CSR number and the value the CSR should hold afterwards,
          - ``rs1``: source register number and the operand value.

        CSRRC/CSRRS with ``rs1 == 0`` are expected to leave the CSR unchanged.
        """
        current = sim.get(self.dut.csr[csr].value)

        if op == Funct3.CSRRW or op == Funct3.CSRRWI:
            new_value = operand_val
        elif (op == Funct3.CSRRC and rs1) or op == Funct3.CSRRCI:
            new_value = current & ~operand_val
        elif (op == Funct3.CSRRS and rs1) or op == Funct3.CSRRSI:
            new_value = current | operand_val
        else:
            new_value = current

        return {
            "exp_read": {"rp_dst": rd, "result": current},
            "exp_write": {"csr": csr, "value": new_value},
            "rs1": {"rp_s1": rs1, "value": operand_val},
        }

    def generate_instruction(self, sim: TestbenchContext):
        """Draw a random CSR instruction targeting one of the circuit's CSRs.

        Returns ``{"instr": <rs_data for insert>, "exp": <gen_expected_out result>}``.
        With probability 0.2 a register operand is passed directly in ``s1_val``
        (``rp_s1 = 0``). Otherwise ``rp_s1`` names the register whose value
        ``process_test`` later delivers through ``update``.
        """
        ops = [
            Funct3.CSRRW,
            Funct3.CSRRC,
            Funct3.CSRRS,
            Funct3.CSRRWI,
            Funct3.CSRRCI,
            Funct3.CSRRSI,
        ]

        op = random.choice(ops)
        imm_op = op == Funct3.CSRRWI or op == Funct3.CSRRCI or op == Funct3.CSRRSI

        rd = random.randint(0, 15)
        rs1 = 0 if imm_op else random.randint(0, 15)
        imm = random.randint(0, 2**5 - 1)
        rs1_val = random.randint(0, 2**self.gen_params.isa.xlen - 1) if rs1 else 0
        operand_val = imm if imm_op else rs1_val
        csr = random.choice(list(self.dut.csr.keys()))

        exp = self.gen_expected_out(sim, op, rd, rs1, operand_val, csr)

        value_available = random.random() < 0.2

        return {
            "instr": {
                "exec_fn": {"op_type": OpType.CSR_IMM if imm_op else OpType.CSR_REG, "funct3": op, "funct7": 0},
                "rp_s1": 0 if value_available or imm_op else rs1,
                "rp_s1_reg": rs1,
                "s1_val": exp["rs1"]["value"] if value_available and not imm_op else 0,
                "rp_dst": rd,
                "imm": imm,
                "csr": csr,
            },
            "exp": exp,
        }

    async def _precommit_and_accept(self, sim: TestbenchContext, sampled: TestbenchIO):
        """Allow precommit, collect the unit's result, and sample ``sampled`` in the same call.

        Returns ``(accept_result, sampled_result)``. Precommit is disabled again afterwards.
        """
        await self.random_wait_geom(sim)
        # TODO: this is a hack, a real method mock should be used
        for _, r in self.dut.precommit.adapter.validators:  # type: ignore
            sim.set(r, 1)
        self.dut.precommit.call_init(sim, side_fx=1)  # TODO: sensible precommit handling

        await self.random_wait_geom(sim)
        res, sampled_res = await CallTrigger(sim).call(self.dut.accept).sample(sampled).until_done()
        self.dut.precommit.disable(sim)
        return res, sampled_res

    @staticmethod
    def _assert_illegal_instruction_report(report, rob_id: int):
        """Check that ``report`` is an ILLEGAL_INSTRUCTION exception report for ``rob_id``."""
        assert report is not None
        report_dict = data_const_to_dict(report)
        report_dict.pop("mtval")  # mtval tested in mtval.asm test
        assert {"rob_id": rob_id, "cause": ExceptionCause.ILLEGAL_INSTRUCTION, "pc": 0} == report_dict

    async def process_test(self, sim: TestbenchContext):
        self.dut.fetch_resume.enable(sim)
        self.dut.exception_report.enable(sim)
        for _ in range(self.cycles):
            await self.random_wait_geom(sim)

            op = self.generate_instruction(sim)

            await self.dut.select.call(sim)

            await self.dut.insert.call(sim, rs_data=op["instr"])

            await self.random_wait_geom(sim)
            if op["exp"]["rs1"]["rp_s1"]:
                await self.dut.update.call(sim, reg_id=op["exp"]["rs1"]["rp_s1"], reg_val=op["exp"]["rs1"]["value"])

            res, resume_res = await self._precommit_and_accept(sim, self.dut.fetch_resume)

            assert res is not None and resume_res is not None
            assert res.rp_dst == op["exp"]["exp_read"]["rp_dst"]
            if op["exp"]["exp_read"]["rp_dst"]:
                assert res.result == op["exp"]["exp_read"]["result"]
            assert sim.get(self.dut.csr[op["exp"]["exp_write"]["csr"]].value) == op["exp"]["exp_write"]["value"]
            assert res.exception == 0

    def test_randomized(self):
        self.gen_params = GenParams(test_core_config)
        random.seed(8)

        self.cycles = 256
        self.csr_count = 16

        self.dut = CSRUnitTestCircuit(self.gen_params, self.csr_count)

        with self.run_simulation(self.dut) as sim:
            sim.add_testbench(self.process_test)

    exception_csr_numbers = [
        0xCC0,  # read_only
        0xFFF,  # nonexistent
        0x7FE,  # missing priv
    ]

    # Each case reads a counter CSR (CSRRS with rs1 = 0) at privilege level "priv" after
    # programming mcounteren / scounteren, and states whether an illegal-instruction
    # exception is expected. Bit layout presumably follows the RISC-V privileged spec
    # (bit 0 = CY, bit 1 = TM, bit 2 = IR) -- TODO confirm against CSRInstances.
    counteren_exception_cases = [
        {
            "priv": PrivilegeLevel.SUPERVISOR,
            "csr": CSRAddress.CYCLE,
            "mcounteren": 0b000,
            "scounteren": 0b111,
            "expect_exception": True,
        },
        {
            "priv": PrivilegeLevel.USER,
            "csr": CSRAddress.TIME,
            "mcounteren": 0b010,
            "scounteren": 0b000,
            "expect_exception": True,
        },
        {
            "priv": PrivilegeLevel.USER,
            "csr": CSRAddress.CYCLE,
            "mcounteren": 0b001,
            "scounteren": 0b001,
            "expect_exception": False,
        },
        {
            "priv": PrivilegeLevel.MACHINE,
            "csr": CSRAddress.CYCLE,
            "mcounteren": 0b000,
            "scounteren": 0b000,
            "expect_exception": False,
        },
    ]

    async def process_exception_test(self, sim: TestbenchContext):
        self.dut.fetch_resume.enable(sim)
        self.dut.exception_report.enable(sim)
        for csr in self.exception_csr_numbers:
            # 0x7FE is machine-mode only, so it is accessed from user mode;
            # the other CSRs are accessed from machine mode.
            priv = PrivilegeLevel.USER if csr == 0x7FE else PrivilegeLevel.MACHINE
            await self.dut.priv_io.call(sim, data=priv)

            await self.random_wait_geom(sim)

            await self.dut.select.call(sim)

            rob_id = random.randrange(2**self.gen_params.rob_entries_bits)
            await self.dut.insert.call(
                sim,
                rs_data={
                    "exec_fn": {"op_type": OpType.CSR_REG, "funct3": Funct3.CSRRW, "funct7": 0},
                    "rp_s1": 0,
                    "rp_s1_reg": 1,
                    "s1_val": 1,
                    "rp_dst": 2,
                    "imm": 0,
                    "csr": csr,
                    "rob_id": rob_id,
                },
            )

            res, report = await self._precommit_and_accept(sim, self.dut.exception_report)

            assert res is not None
            assert res.exception == 1
            self._assert_illegal_instruction_report(report, rob_id)

    def test_exception(self):
        self.gen_params = GenParams(test_core_config)
        random.seed(9)

        self.dut = CSRUnitTestCircuit(self.gen_params, 0, only_legal=False)

        with self.run_simulation(self.dut) as sim:
            sim.add_testbench(self.process_exception_test)

    async def process_counteren_access_test(self, sim: TestbenchContext):
        self.dut.fetch_resume.enable(sim)
        self.dut.exception_report.enable(sim)

        for idx, case in enumerate(self.counteren_exception_cases):
            await self.dut.priv_io.call(sim, data=case["priv"])
            await self.dut.mcounteren_io.call(sim, data=case["mcounteren"])
            if self.gen_params.supervisor_mode:
                await self.dut.scounteren_io.call(sim, data=case["scounteren"])

            await self.random_wait_geom(sim)
            await self.dut.select.call(sim)

            # Distinct rob_id per case so the exception report can be matched to it.
            rob_id = idx + 100
            await self.dut.insert.call(
                sim,
                rs_data={
                    "exec_fn": {
                        "op_type": OpType.CSR_REG,
                        "funct3": Funct3.CSRRS,
                        "funct7": 0,
                    },
                    "rp_s1": 0,
                    "rp_s1_reg": 0,
                    "s1_val": 0,
                    "rp_dst": 2,
                    "imm": 0,
                    "csr": case["csr"],
                    "rob_id": rob_id,
                },
            )

            res, report = await self._precommit_and_accept(sim, self.dut.exception_report)

            assert res is not None
            assert res.exception == int(case["expect_exception"])

            if case["expect_exception"]:
                self._assert_illegal_instruction_report(report, rob_id)
            else:
                assert report is None

    def test_counteren_access(self):
        self.gen_params = GenParams(test_core_config.replace(supervisor_mode=True, user_mode=True))
        random.seed(10)

        self.dut = CSRUnitTestCircuit(self.gen_params, 0, only_legal=False)

        with self.run_simulation(self.dut) as sim:
            sim.add_testbench(self.process_counteren_access_test)


class TestCSRRegister(TestCaseWithSimulator):
    """Tests for ``CSRRegister``: write/fu-write/fu-read interplay, filter/map hooks and ``read_comb``."""

    async def randomized_process_test(self, sim: TestbenchContext):
        # always enabled
        self.dut.read.enable(sim)

        previous_data = 0
        for _ in range(self.cycles):
            write = False
            fu_write = False
            fu_read = False
            exp_write_data = None

            if random.random() < 0.9:
                write = True
                exp_write_data = random.randint(0, 2**self.gen_params.isa.xlen - 1)
                self.dut.write.call_init(sim, data=exp_write_data)

            if random.random() < 0.3:
                fu_write = True
                # fu_write has priority over csr write, but it doesn't overwrite ro bits
                write_arg = random.randint(0, 2**self.gen_params.isa.xlen - 1)
                exp_write_data = (write_arg & ~self.ro_mask) | (
                    (exp_write_data if exp_write_data is not None else previous_data) & self.ro_mask
                )
                self.dut._fu_write.call_init(sim, data=write_arg)

            if random.random() < 0.2:
                fu_read = True
                self.dut._fu_read.call_init(sim)

            await sim.tick()

            exp_read_data = exp_write_data if fu_write or write else previous_data

            if fu_read:  # in CSRUnit this call is called before write and returns previous result
                assert data_const_to_dict(self.dut._fu_read.get_call_result(sim)) == {"data": exp_read_data}

            read_result = self.dut.read.get_call_result(sim)
            assert read_result is not None
            assert data_const_to_dict(read_result) == {
                "data": exp_read_data,
                "read": int(fu_read),
                "written": int(fu_write),
            }
            previous_data = read_result.data

            self.dut._fu_read.disable(sim)
            self.dut._fu_write.disable(sim)
            self.dut.write.disable(sim)

    def test_randomized(self):
        self.gen_params = GenParams(test_core_config)
        random.seed(42)

        self.cycles = 200
        self.ro_mask = 0b101

        self.dut = SimpleTestCircuit(CSRRegister(0, self.gen_params, ro_bits=self.ro_mask))

        with self.run_simulation(self.dut) as sim:
            sim.add_testbench(self.randomized_process_test)

    async def filtermap_process_test(self, sim: TestbenchContext):
        prev_value = 0
        for _ in range(50):
            write_arg = random.randrange(0, 2**34)

            await self.dut._fu_write.call(sim, data=write_arg)
            output = (await self.dut._fu_read.call(sim))["data"]

            # Model of test_filtermap's register: bit 0 of the argument gates the write,
            # bit 1 adds 3, read-only bit 32 is forced to 0, and fu_read_map shifts left
            # by one within the 34-bit width. Without a write the output is unchanged.
            expected = prev_value
            if write_arg & 1:
                expected = write_arg
                if write_arg & 2:
                    expected += 3

                expected &= ~(2**32)

                expected <<= 1
                expected &= 2**34 - 1

            assert output == expected

            prev_value = output

    def test_filtermap(self):
        gen_params = GenParams(test_core_config)

        def write_filtermap(m: TModule, v: Value):
            # Returns (write_enable, value): write only when bit 0 is set; add 3 when bit 1 is set.
            res = Signal(34)
            write = Signal()
            m.d.comb += res.eq(v)
            with m.If(v & 1):
                m.d.comb += write.eq(1)
            with m.If(v & 2):
                m.d.comb += res.eq(v + 3)
            return (write, res)

        random.seed(4325)

        self.dut = SimpleTestCircuit(
            CSRRegister(
                None,
                gen_params,
                width=34,
                ro_bits=(1 << 32),
                fu_read_map=lambda _, v: v << 1,
                fu_write_filtermap=write_filtermap,
            ),
        )

        with self.run_simulation(self.dut) as sim:
            sim.add_testbench(self.filtermap_process_test)

    async def comb_process_test(self, sim: TestbenchContext):
        self.dut.read.enable(sim)
        self.dut.read_comb.enable(sim)
        self.dut._fu_read.enable(sim)

        # read_comb sees the fu write in the same cycle; _fu_read still returns init=0xAB.
        self.dut._fu_write.call_init(sim, data=0xFFFF)
        while self.dut._fu_write.get_call_result(sim) is None:
            await sim.tick()
        assert self.dut.read_comb.get_call_result(sim).data == 0xFFFF
        assert self.dut._fu_read.get_call_result(sim).data == 0xAB
        await sim.tick()
        # Low nibble is read-only (ro_bits=0b1111), so it keeps 0xB from init=0xAB.
        assert self.dut.read.get_call_result(sim)["data"] == 0xFFFB
        assert self.dut._fu_read.get_call_result(sim)["data"] == 0xFFFB
        await sim.tick()

        # Simultaneous fu write and CSR write; with fu_write_priority=False the CSR write is kept.
        self.dut._fu_write.call_init(sim, data=0x0FFF)
        self.dut.write.call_init(sim, data=0xAAAA)
        while self.dut._fu_write.get_call_result(sim) is None or self.dut.write.get_call_result(sim) is None:
            await sim.tick()
        assert data_const_to_dict(self.dut.read_comb.get_call_result(sim)) == {"data": 0x0FFF, "read": 1, "written": 1}
        await sim.tick()
        assert self.dut._fu_read.get_call_result(sim).data == 0xAAAA
        await sim.tick()

        # single cycle
        self.dut._fu_write.call_init(sim, data=0x0BBB)
        while self.dut._fu_write.get_call_result(sim) is None:
            await sim.tick()
        update_val = self.dut.read_comb.get_call_result(sim).data | 0xD000
        self.dut.write.call_init(sim, data=update_val)
        while self.dut.write.get_call_result(sim) is None:
            await sim.tick()
        await sim.tick()
        assert self.dut._fu_read.get_call_result(sim).data == 0xDBBB

    def test_comb(self):
        gen_params = GenParams(test_core_config)

        random.seed(4326)

        self.dut = SimpleTestCircuit(CSRRegister(None, gen_params, ro_bits=0b1111, fu_write_priority=False, init=0xAB))

        with self.run_simulation(self.dut) as sim:
            sim.add_testbench(self.comb_process_test)