# test_lsu_atomic_wrapper.py
1 from collections import deque 2 import random 3 import pytest 4 from amaranth import * 5 from amaranth.utils import ceil_log2 6 from transactron import Method, TModule 7 from transactron.lib import Adapter, AdapterTrans 8 from transactron.testing import MethodMock, SimpleTestCircuit, TestCaseWithSimulator, TestbenchIO, def_method_mock 9 10 from coreblocks.arch.isa_consts import Funct3, Funct7 11 from coreblocks.arch.optypes import OpType 12 from coreblocks.func_blocks.fu.lsu.lsu_atomic_wrapper import LSUAtomicWrapper 13 from coreblocks.func_blocks.interface.func_protocols import FuncUnit 14 from coreblocks.interface.layouts import FuncUnitLayouts 15 from coreblocks.params.configurations import test_core_config 16 from coreblocks.params.genparams import GenParams 17 18 19 class FuncUnitMock(FuncUnit, Elaboratable): 20 def __init__(self, gen_params: GenParams): 21 layouts = gen_params.get(FuncUnitLayouts) 22 23 self.issue = Method(i=layouts.issue) 24 self.push_result = Method(i=layouts.push_result) 25 26 self.issue_tb = TestbenchIO(Adapter.create(self.issue)) 27 self.push_result_tb = TestbenchIO(AdapterTrans.create(self.push_result)) 28 29 def elaborate(self, platform): 30 m = TModule() 31 32 m.submodules.issue_tb = self.issue_tb 33 m.submodules.push_result_tb = self.push_result_tb 34 35 return m 36 37 38 class TestLSUAtomicWrapper(TestCaseWithSimulator): 39 @pytest.fixture(autouse=True) 40 def setup(self, fixture_initialize_testing_env): 41 random.seed(1258) 42 self.inst_cnt = 255 43 self.gen_params = GenParams(test_core_config.replace(rob_entries_bits=ceil_log2(self.inst_cnt))) 44 self.lsu = FuncUnitMock(self.gen_params) 45 self.dut = SimpleTestCircuit(LSUAtomicWrapper(self.gen_params, self.lsu)) 46 47 self.mem_cell = 0 48 self.instr_q = deque() 49 self.results = {} 50 self.lsu_res_q = deque() 51 self.lsu_except_q = deque() 52 53 self.generate_instrs(self.inst_cnt) 54 55 @def_method_mock(lambda self: self.lsu.issue_tb, enable=lambda _: random.random() < 0.9) 56 
def lsu_issue_mock(self, arg): 57 @MethodMock.effect 58 def _(): 59 res = 0 60 addr = arg["s1_val"] + arg["imm"] 61 62 exc = self.lsu_except_q[0] 63 self.lsu_except_q.popleft() 64 assert addr == exc["addr"] 65 66 if not exc["exception"]: 67 match arg["exec_fn"]["op_type"]: 68 case OpType.STORE: 69 self.mem_cell = arg["s2_val"] 70 case OpType.LOAD: 71 res = self.mem_cell 72 case _: 73 assert False 74 75 self.lsu_res_q.append( 76 {"rob_id": arg["rob_id"], "result": res, "rp_dst": arg["rp_dst"], "exception": exc["exception"]} 77 ) 78 79 @def_method_mock( 80 lambda self: self.lsu.push_result_tb, enable=lambda self: random.random() < 0.9 and len(self.lsu_res_q) > 0 81 ) 82 def lsu_accept_mock(self, arg): 83 res = self.lsu_res_q[0] 84 85 @MethodMock.effect 86 def _(): 87 self.lsu_res_q.popleft() 88 89 return res 90 91 def generate_instrs(self, cnt): 92 generation_mem_cell = 0 93 generation_reservation_valid = 0 94 for i in range(cnt): 95 optype = random.choice([OpType.LOAD, OpType.STORE, OpType.ATOMIC_MEMORY_OP, OpType.ATOMIC_LR_SC]) 96 funct7 = 0 97 98 imm = random.randint(0, 1) 99 s1_val = random.randrange(0, 2**self.gen_params.isa.xlen - 1) 100 s2_val = random.randrange(0, 2**self.gen_params.isa.xlen) 101 rp_dst = random.randrange(0, 2**self.gen_params.phys_regs_bits) 102 103 exception = 0 104 result = 0 105 106 if optype == OpType.ATOMIC_MEMORY_OP: 107 funct7 = random.choice( 108 [ 109 Funct7.AMOSWAP, 110 Funct7.AMOADD, 111 Funct7.AMOMAXU, 112 Funct7.AMOMIN, 113 Funct7.AMOXOR, 114 Funct7.AMOOR, 115 Funct7.AMOAND, 116 Funct7.AMOMAX, 117 Funct7.AMOMINU, 118 ] 119 ) 120 121 exception = random.random() < 0.3 122 exception_on_load = exception and random.random() < 0.5 123 self.lsu_except_q.append({"addr": s1_val, "exception": exception_on_load}) 124 125 if not exception: 126 result = generation_mem_cell 127 128 def twos(x): 129 if x & (1 << (self.gen_params.isa.xlen - 1)): 130 x ^= (1 << self.gen_params.isa.xlen) - 1 131 x += 1 132 x *= -1 133 return x 134 135 match 
funct7: 136 case Funct7.AMOSWAP: 137 generation_mem_cell = s2_val 138 case Funct7.AMOADD: 139 generation_mem_cell += s2_val 140 generation_mem_cell %= 2**self.gen_params.isa.xlen 141 case Funct7.AMOAND: 142 generation_mem_cell &= s2_val 143 case Funct7.AMOOR: 144 generation_mem_cell |= s2_val 145 case Funct7.AMOXOR: 146 generation_mem_cell ^= s2_val 147 case Funct7.AMOMIN: 148 generation_mem_cell = ( 149 generation_mem_cell if twos(generation_mem_cell) < twos(s2_val) else s2_val 150 ) 151 case Funct7.AMOMAX: 152 generation_mem_cell = ( 153 generation_mem_cell if twos(generation_mem_cell) > twos(s2_val) else s2_val 154 ) 155 case Funct7.AMOMINU: 156 generation_mem_cell = min(generation_mem_cell, s2_val) 157 case Funct7.AMOMAXU: 158 generation_mem_cell = max(generation_mem_cell, s2_val) 159 160 if not exception_on_load: 161 self.lsu_except_q.append({"addr": s1_val, "exception": exception}) 162 elif optype == OpType.ATOMIC_LR_SC: 163 is_load = random.random() < 0.5 164 exception = random.random() < 0.3 165 sc_fail = False 166 if is_load: 167 funct7 = Funct7.LR 168 generation_reservation_valid = not exception 169 result = generation_mem_cell if not exception else 0 170 else: 171 funct7 = Funct7.SC 172 if generation_reservation_valid: 173 if not exception: 174 generation_mem_cell = s2_val 175 else: 176 sc_fail = True 177 exception = 0 178 179 result = 0 if generation_reservation_valid or exception else 1 180 generation_reservation_valid = 0 181 182 if not sc_fail: 183 self.lsu_except_q.append({"addr": s1_val, "exception": exception}) 184 185 elif optype == OpType.LOAD: 186 result = generation_mem_cell 187 self.lsu_except_q.append({"addr": s1_val + imm, "exception": 0}) 188 elif optype == OpType.STORE: 189 generation_mem_cell = s2_val 190 result = 0 191 self.lsu_except_q.append({"addr": s1_val + imm, "exception": 0}) 192 193 exec_fn = {"op_type": optype, "funct3": Funct3.W, "funct7": funct7} 194 rob_id = i 195 instr = { 196 "rp_dst": rp_dst, 197 "rob_id": rob_id, 198 
"exec_fn": exec_fn, 199 "s1_val": s1_val, 200 "s2_val": s2_val, 201 "imm": imm, 202 "pc": 0, 203 } 204 self.instr_q.append(instr) 205 self.results[rob_id] = {"rob_id": rob_id, "rp_dst": rp_dst, "result": result, "exception": exception} 206 207 async def issue_process(self, sim): 208 while self.instr_q: 209 await self.dut.issue.call(sim, self.instr_q[0]) 210 self.instr_q.popleft() 211 await self.random_wait_geom(sim, 0.9) 212 213 async def accept_process(self, sim): 214 for _ in range(self.inst_cnt): 215 res = await self.dut.push_result.call(sim) 216 expected = self.results[res["rob_id"]] 217 assert res["rp_dst"] == expected["rp_dst"] 218 assert res["exception"] == expected["exception"] 219 assert res["result"] == expected["result"] 220 221 await self.random_wait_geom(sim, 0.9) 222 223 def test_randomized(self): 224 with self.run_simulation(self.dut, max_cycles=700) as sim: 225 sim.add_testbench(self.issue_process) 226 sim.add_testbench(self.accept_process)