test_checkpointing.py
import random
import pytest
from amaranth.lib.enum import auto
from collections import deque
from enum import Enum

from coreblocks.arch import OpType
from coreblocks.params import GenParams
from coreblocks.params.configurations import test_core_config
from transactron.testing import CallTrigger, MethodMock, TestCaseWithSimulator, def_method_mock

from test.scheduler.test_scheduler import SchedulerTestCircuit, MockedBlockComponent


class TestSchedulerCheckpointing(TestCaseWithSimulator):
    """Randomized test of the scheduler's tag/checkpoint rollback handling."""

    @pytest.mark.parametrize("tag_bits, checkpoint_count", [(2, 3), (5, 8)])
    def test_randomized(self, tag_bits: int, checkpoint_count: int):
        """Feed a random arithmetic/branch stream through the scheduler and check rollback behaviour.

        The instruction generator keeps a software register-alias model (``frat``) and marks every
        instruction as being on the "correct" or the "wrong" speculation path. Randomly chosen branches
        are declared mispredicted; resolving them triggers ``dut.rollback``. The test asserts that:

        * correct-path instructions reach the RS mocks with the ``rp_s1`` predicted by the model,
        * instructions retiring under an active tag do so in generation order (their ``imm`` field
          carries a per-path sequence number).

        Parameters
        ----------
        tag_bits
            Forwarded to the core configuration; also the modulus (``2**tag_bits``) used when tracking
            the retiring tag.
        checkpoint_count
            Forwarded to the core configuration.
        """
        gen_params = GenParams(
            test_core_config.replace(
                func_units_config=(
                    MockedBlockComponent({OpType.ARITHMETIC}, rs_entries=4),
                    MockedBlockComponent({OpType.BRANCH}, rs_entries=4),
                ),
                tag_bits=tag_bits,
                checkpoint_count=checkpoint_count,
                allow_partial_extensions=True,
            )
        )

        dut = SchedulerTestCircuit(gen_params)

        # Branches inserted into the branch RS that were not resolved yet. Kept as a list in insertion
        # order (not a set) so that, together with the fixed seed below, the resolution order does not
        # depend on per-process hash randomization of str/Enum values.
        branch_in_flight = []

        instr_cnt = 512
        # Expected `rp_s1` values of correct-path instructions, in generation order, per RS.
        exp_rs_branch = deque()
        exp_rs_arith = deque()

        # Sequence numbers stored in the `imm` field; wrong-path ids start at 0x8000 so they never
        # collide with correct-path ids for the instruction count used here.
        correct_path_id = 0
        wrong_path_id = 0x8000
        on_correct_path = True
        # Model of the next physical register handed out; mirrors `free_rf_process` below.
        free_rp = 1
        # Snapshot of `frat` taken when the generator leaves the correct path.
        frat_to_restore = []

        # Rollback information attached to the next generated instruction.
        rollback_tag = 0
        rollback_tag_v = False

        random.seed(42)

        class BranchEncoding(Enum):
            # Generated on the correct path, not marked as mispredicted.
            CORRECT_PATH_OK = auto()
            # Generated on the correct path and marked as mispredicted: the generator switches to the
            # wrong path until this branch is resolved.
            CORRECT_PATH_MISPRED_EXIT = auto()
            # Generated on the wrong path, not marked as mispredicted.
            WRONG_PATH_OK = auto()
            # Generated on the wrong path and marked as mispredicted: rolls back only if its tag is
            # still active when it is resolved.
            WRONG_PATH_WITH_ROLLBACK = auto()

        # Per-RS record, in generation order, of what each instruction was (consumed by the RS mocks).
        in_order_branch_encoding = deque()
        in_order_arith_encoding = deque()
        # Software model of the logical -> physical register mapping (presumably the frontend RAT).
        frat = [0 for _ in range(2**gen_params.isa.reg_cnt_log)]

        def get_instr():
            """Generate one random instruction and update the reference model accordingly."""
            nonlocal correct_path_id, wrong_path_id, on_correct_path, free_rp, frat_to_restore, rollback_tag_v
            # NOTE: the order of the `random` calls below is part of the seeded scenario.
            is_branch = random.randint(0, 1)

            rd = random.randrange(0, 4)
            rs = random.randrange(0, 4)
            instr = {
                "exec_fn": {"op_type": OpType.BRANCH if is_branch else OpType.ARITHMETIC},
                "imm": correct_path_id if on_correct_path else wrong_path_id,
                "regs_l": {
                    "rl_dst": rd,
                    "rl_s1": rs,
                },
                "rollback_tag": rollback_tag,
                "rollback_tag_v": rollback_tag_v,
                "commit_checkpoint": is_branch,
            }
            # The rollback notification is delivered with exactly one instruction.
            rollback_tag_v = False

            if on_correct_path:
                # Only correct-path instructions have a predictable source mapping.
                if is_branch:
                    exp_rs_branch.append(frat[rs])
                else:
                    exp_rs_arith.append(frat[rs])
                correct_path_id += 1
            else:
                wrong_path_id += 1

            if rd != 0:
                # Register 0 is never renamed; any other destination takes the next free register.
                frat[rd] = free_rp
                free_rp += 1
                if free_rp == gen_params.phys_regs:
                    free_rp = 1

            if is_branch:
                is_misprediction = random.randint(0, 1)
                if on_correct_path:
                    in_order_branch_encoding.append(
                        BranchEncoding.CORRECT_PATH_MISPRED_EXIT if is_misprediction else BranchEncoding.CORRECT_PATH_OK
                    )
                    if is_misprediction:
                        on_correct_path = False
                        frat_to_restore = frat.copy()
                else:
                    in_order_branch_encoding.append(
                        BranchEncoding.WRONG_PATH_WITH_ROLLBACK if is_misprediction else BranchEncoding.WRONG_PATH_OK
                    )
            else:
                in_order_arith_encoding.append(on_correct_path)

            return instr

        async def input_process(sim):
            """Push `instr_cnt` generated instructions into the scheduler with random gaps."""
            for _ in range(instr_cnt):
                data = get_instr()
                await dut.instr_inp.call(sim, count=1, data=[data])
                await self.random_wait_geom(sim, 0.5)

        # Maps ROB id -> `imm` sequence number, filled in when an instruction reaches an RS.
        rob_id_to_imm_id = {}

        async def free_rf_process(sim):
            """Supply free physical registers 1, 2, ..., phys_regs - 1 cyclically (same order as `get_instr`)."""
            free_rp_inp = 1
            while True:
                await dut.free_rf_inp.call(sim, {"ident": free_rp_inp})
                free_rp_inp += 1
                if free_rp_inp == gen_params.phys_regs:
                    free_rp_inp = 1

        # Next correct-path sequence number expected to retire, and the tag of the retiring instruction.
        retire_imm_ids = 0
        current_tag = 0

        async def rob_retire_process(sim):
            """Retire `instr_cnt` instructions, checking in-order retirement under active tags."""
            nonlocal current_tag, retire_imm_ids
            for _ in range(instr_cnt):
                await self.random_wait_geom(sim, 0.4)

                # All four methods are sampled together so the peeked entry matches the retired one.
                _, active_tags, peek_res, rob_idxs = (
                    await CallTrigger(sim)
                    .call(dut.rob_retire, count=1)
                    .call(dut.get_active_tags)
                    .call(dut.rob_peek)
                    .call(dut.rob_get_indices)
                    .until_all_done()
                )
                active_tags = active_tags["active_tags"]
                entry = peek_res.entries[0]["rob_data"]
                rob_id = rob_idxs["start"]

                current_tag += entry["tag_increment"]
                current_tag %= 2**gen_params.tag_bits

                if active_tags[current_tag]:
                    # check for instructions on valid speculation path retiring in order
                    assert rob_id_to_imm_id[rob_id] == retire_imm_ids
                    retire_imm_ids += 1

                if entry["tag_increment"]:
                    await dut.free_tag.call(sim)

        @def_method_mock(lambda: dut.core_state)
        def core_state_mock():
            return {"flushing": 0}

        @def_method_mock(lambda: dut.rs_alloc[0], enable=lambda: random.random() < 0.9)
        def rs_alloc_arith():
            return {"rs_entry_id": 0}

        @def_method_mock(lambda: dut.rs_alloc[1], enable=lambda: random.random() < 0.9)
        def rs_alloc_branch():
            return {"rs_entry_id": 0}

        @def_method_mock(lambda: dut.rs_insert[1])
        def rs_insert_branch(arg):
            """Branch RS mock: verify the source mapping and register the branch as in flight."""

            @MethodMock.effect
            def _():
                rs_data = arg["rs_data"]
                rob_id_to_imm_id[rs_data["rob_id"]] = rs_data["imm"]

                encoding = in_order_branch_encoding.popleft()
                if encoding in (BranchEncoding.CORRECT_PATH_OK, BranchEncoding.CORRECT_PATH_MISPRED_EXIT):
                    assert rs_data["rp_s1"] == exp_rs_branch[0]
                    exp_rs_branch.popleft()

                branch_in_flight.append(
                    {
                        "encoding": encoding,
                        "rob_id": rs_data["rob_id"],
                        "tag": rs_data["tag"],
                    }
                )

        # ROB ids of instructions that finished execution and wait to be marked as done.
        rob_done_queue = deque()

        async def rs_insert_arithmetic(sim):
            """Arithmetic RS consumer: accept inserts at random times and verify the source mapping."""
            while True:
                arg = None
                while arg is None:
                    await self.random_wait_geom(sim, 0.5)
                    arg = await dut.rs_insert[0].call_try(sim)
                arg = arg["rs_data"]

                rob_id_to_imm_id[arg["rob_id"]] = arg["imm"]

                # The recorded flag tells whether this instruction was generated on the correct path.
                if in_order_arith_encoding[0]:
                    assert arg["rp_s1"] == exp_rs_arith[0]
                    exp_rs_arith.popleft()

                in_order_arith_encoding.popleft()
                rob_done_queue.append(arg["rob_id"])

        async def active_tags_call_process(sim):
            """Keep `get_active_tags` called so `branch_fu_process` can read its outputs."""
            while True:
                await dut.get_active_tags.call(sim)

        async def branch_fu_process(sim):
            """Resolve in-flight branches in random order, issuing rollbacks for mispredictions."""
            nonlocal on_correct_path, frat, rollback_tag, rollback_tag_v

            while True:
                await self.random_wait_geom(sim, 0.5)
                if not branch_in_flight:
                    continue
                # Same single RNG draw as `random.choice`, but over a deterministically ordered list.
                instr = branch_in_flight.pop(random.randrange(len(branch_in_flight)))

                await sim.delay(1e-9)

                active_tags_val = dut.get_active_tags.get_outputs(sim)["active_tags"]
                # A wrong-path misprediction may only roll back while its tag is still active.
                wrong_path_rollback_legal = instr["encoding"] == BranchEncoding.WRONG_PATH_WITH_ROLLBACK and (
                    active_tags_val[instr["tag"]]
                )

                if wrong_path_rollback_legal or instr["encoding"] == BranchEncoding.CORRECT_PATH_MISPRED_EXIT:
                    await dut.rollback.call(sim, tag=instr["tag"])
                    rollback_tag = instr["tag"]
                    rollback_tag_v = True
                    if instr["encoding"] == BranchEncoding.CORRECT_PATH_MISPRED_EXIT:
                        # Back on the correct path: restore the model mapping saved at the misprediction.
                        frat = frat_to_restore.copy()
                        on_correct_path = True

                rob_done_queue.append(instr["rob_id"])

        async def mark_done_process(sim):
            """Report finished instructions to the ROB in the order they were queued."""
            while True:
                while not rob_done_queue:
                    await sim.tick()
                await dut.rob_done.call(sim, rob_id=rob_done_queue[0])
                rob_done_queue.popleft()

        with self.run_simulation(dut, max_cycles=2000) as sim:
            sim.add_testbench(input_process)
            sim.add_testbench(free_rf_process, background=True)
            sim.add_testbench(branch_fu_process, background=True)
            sim.add_testbench(rs_insert_arithmetic, background=True)
            sim.add_testbench(mark_done_process, background=True)
            sim.add_testbench(active_tags_call_process, background=True)
            sim.add_testbench(rob_retire_process)