/ test / scheduler / test_checkpointing.py
test_checkpointing.py
  1  import random
  2  import pytest
  3  from amaranth.lib.enum import auto
  4  from collections import deque
  5  from enum import Enum
  6  
  7  from coreblocks.arch import OpType
  8  from coreblocks.params import GenParams
  9  from coreblocks.params.configurations import test_core_config
 10  from transactron.testing import CallTrigger, MethodMock, TestCaseWithSimulator, def_method_mock
 11  
 12  from test.scheduler.test_scheduler import SchedulerTestCircuit, MockedBlockComponent
 13  
 14  
 15  class TestSchedulerCheckpointing(TestCaseWithSimulator):
 16      @pytest.mark.parametrize("tag_bits, checkpoint_count", [(2, 3), (5, 8)])
 17      def test_randomized(self, tag_bits: int, checkpoint_count: int):
 18          gen_params = GenParams(
 19              test_core_config.replace(
 20                  func_units_config=(
 21                      MockedBlockComponent({OpType.ARITHMETIC}, rs_entries=4),
 22                      MockedBlockComponent({OpType.BRANCH}, rs_entries=4),
 23                  ),
 24                  tag_bits=tag_bits,
 25                  checkpoint_count=checkpoint_count,
 26                  allow_partial_extensions=True,
 27              )
 28          )
 29  
 30          dut = SchedulerTestCircuit(gen_params)
 31  
 32          branch_in_flight = set()
 33  
 34          instr_cnt = 512
 35          exp_rs_branch = deque()
 36          exp_rs_arith = deque()
 37  
 38          correct_path_id = 0
 39          wrong_path_id = 0x8000
 40          on_correct_path = True
 41          free_rp = 1
 42          frat_to_restore = []
 43  
 44          rollback_tag = 0
 45          rollback_tag_v = False
 46  
 47          random.seed(42)
 48  
 49          end = False
 50  
 51          class BranchEncoding(Enum):
 52              CORRECT_PATH_OK = auto()
 53              CORRECT_PATH_MISPRED_EXIT = auto()
 54              WRONG_PATH_OK = auto()
 55              WRONG_PATH_WITH_ROLLBACK = auto()
 56  
 57          in_order_branch_encoding = deque()
 58          in_order_arith_encoding = deque()
 59          frat = [0 for _ in range(2**gen_params.isa.reg_cnt_log)]
 60  
 61          def get_instr():
 62              nonlocal correct_path_id, wrong_path_id, on_correct_path, free_rp, frat, frat_to_restore, rollback_tag
 63              nonlocal rollback_tag_v
 64              is_branch = random.randint(0, 1)
 65  
 66              rd = random.randrange(0, 4)
 67              rs = random.randrange(0, 4)
 68              instr = {
 69                  "exec_fn": {"op_type": OpType.BRANCH if is_branch else OpType.ARITHMETIC},
 70                  "imm": correct_path_id if on_correct_path else wrong_path_id,
 71                  "regs_l": {
 72                      "rl_dst": rd,
 73                      "rl_s1": rs,
 74                  },
 75                  "rollback_tag": rollback_tag,
 76                  "rollback_tag_v": rollback_tag_v,
 77                  "commit_checkpoint": is_branch,
 78              }
 79              rollback_tag_v = False
 80  
 81              if on_correct_path:
 82                  if is_branch:
 83                      exp_rs_branch.append(frat[rs])
 84                  else:
 85                      exp_rs_arith.append(frat[rs])
 86                  correct_path_id += 1
 87              else:
 88                  wrong_path_id += 1
 89  
 90              if rd != 0:
 91                  frat[rd] = free_rp
 92                  free_rp += 1
 93                  if free_rp == gen_params.phys_regs:
 94                      free_rp = 1
 95  
 96              if is_branch:
 97                  is_misprediction = random.randint(0, 1)
 98                  if on_correct_path:
 99                      in_order_branch_encoding.append(
100                          BranchEncoding.CORRECT_PATH_MISPRED_EXIT if is_misprediction else BranchEncoding.CORRECT_PATH_OK
101                      )
102                      if is_misprediction:
103                          on_correct_path = False
104                          frat_to_restore = frat.copy()
105                  else:
106                      in_order_branch_encoding.append(
107                          BranchEncoding.WRONG_PATH_WITH_ROLLBACK if is_misprediction else BranchEncoding.WRONG_PATH_OK
108                      )
109              else:
110                  in_order_arith_encoding.append(on_correct_path)
111  
112              return instr
113  
114          async def input_process(sim):
115              nonlocal end
116              for _ in range(instr_cnt):
117                  data = get_instr()
118                  await dut.instr_inp.call(sim, count=1, data=[data])
119                  await self.random_wait_geom(sim, 0.5)
120              end = True
121  
122          rob_id_to_imm_id = {}
123  
124          async def free_rf_process(sim):
125              free_rp_inp = 1
126              while True:
127                  await dut.free_rf_inp.call(sim, {"ident": free_rp_inp})
128                  free_rp_inp += 1
129                  if free_rp_inp == gen_params.phys_regs:
130                      free_rp_inp = 1
131  
132          retire_imm_ids = 0
133          current_tag = 0
134  
135          async def rob_retire_process(sim):
136              nonlocal current_tag, retire_imm_ids, end
137              for _ in range(instr_cnt):
138                  await self.random_wait_geom(sim, 0.4)
139  
140                  _, active_tags, peek_res, rob_idxs = (
141                      await CallTrigger(sim)
142                      .call(dut.rob_retire, count=1)
143                      .call(dut.get_active_tags)
144                      .call(dut.rob_peek)
145                      .call(dut.rob_get_indices)
146                      .until_all_done()
147                  )
148                  active_tags = active_tags["active_tags"]
149                  entry = peek_res.entries[0]["rob_data"]
150                  rob_id = rob_idxs["start"]
151  
152                  current_tag += entry["tag_increment"]
153                  current_tag %= 2**gen_params.tag_bits
154  
155                  if active_tags[current_tag]:
156                      # check for instructions on valid speculation path retiring in order
157                      assert rob_id_to_imm_id[rob_id] == retire_imm_ids
158                      retire_imm_ids += 1
159  
160                  if entry["tag_increment"]:
161                      await dut.free_tag.call(sim)
162  
163          @def_method_mock(lambda: dut.core_state)
164          def core_state_mock():
165              return {"flushing": 0}
166  
167          @def_method_mock(lambda: dut.rs_alloc[0], enable=lambda: random.random() < 0.9)
168          def rs_alloc_arith():
169              return {"rs_entry_id": 0}
170  
171          @def_method_mock(lambda: dut.rs_alloc[1], enable=lambda: random.random() < 0.9)
172          def rs_alloc_branch():
173              return {"rs_entry_id": 0}
174  
175          @def_method_mock(lambda: dut.rs_insert[1])
176          def rs_insert_branch(arg):
177              nonlocal rob_id_to_imm_id
178  
179              @MethodMock.effect
180              def _():
181                  nonlocal arg
182                  arg = arg["rs_data"]
183                  rob_id_to_imm_id[arg["rob_id"]] = arg["imm"]
184  
185                  br_on_correct_path = (
186                      in_order_branch_encoding[0] == BranchEncoding.CORRECT_PATH_OK
187                      or in_order_branch_encoding[0] == BranchEncoding.CORRECT_PATH_MISPRED_EXIT
188                  )
189                  if br_on_correct_path:
190                      assert arg["rp_s1"] == exp_rs_branch[0]
191                      exp_rs_branch.popleft()
192  
193                  br = {
194                      "encoding": in_order_branch_encoding[0],
195                      "rob_id": arg["rob_id"],
196                      "tag": arg["tag"],
197                  }
198  
199                  in_order_branch_encoding.popleft()
200                  branch_in_flight.add(frozenset(br.items()))
201  
202          rob_done_queue = deque()
203  
204          async def rs_insert_arithmetic(sim):
205              while True:
206                  nonlocal rob_id_to_imm_id
207                  arg = None
208                  while arg is None:
209                      await self.random_wait_geom(sim, 0.5)
210                      arg = await dut.rs_insert[0].call_try(sim)
211                  arg = arg["rs_data"]
212  
213                  rob_id_to_imm_id[arg["rob_id"]] = arg["imm"]
214  
215                  if in_order_arith_encoding[0]:
216                      assert arg["rp_s1"] == exp_rs_arith[0]
217                      exp_rs_arith.popleft()
218  
219                  in_order_arith_encoding.popleft()
220                  rob_done_queue.append(arg["rob_id"])
221  
222          async def active_tags_call_process(sim):
223              while True:
224                  await dut.get_active_tags.call(sim)
225  
226          async def branch_fu_process(sim):
227              nonlocal on_correct_path, frat, rollback_tag, rollback_tag_v, frat_to_restore
228  
229              while True:
230                  await self.random_wait_geom(sim, 0.5)
231                  if not branch_in_flight:
232                      continue
233                  instr = random.choice(tuple(branch_in_flight))
234                  branch_in_flight.remove(instr)
235                  instr = dict(instr)
236  
237                  await sim.delay(1e-9)
238  
239                  active_tags_val = dut.get_active_tags.get_outputs(sim)["active_tags"]
240                  wrong_path_rollback_legal = instr["encoding"] == BranchEncoding.WRONG_PATH_WITH_ROLLBACK and (
241                      active_tags_val[instr["tag"]]
242                  )
243  
244                  if wrong_path_rollback_legal or instr["encoding"] == BranchEncoding.CORRECT_PATH_MISPRED_EXIT:
245                      await dut.rollback.call(sim, tag=instr["tag"])
246                      rollback_tag = instr["tag"]
247                      rollback_tag_v = True
248                      if instr["encoding"] == BranchEncoding.CORRECT_PATH_MISPRED_EXIT:
249                          frat = frat_to_restore.copy()
250                          on_correct_path = True
251  
252                  rob_done_queue.append(instr["rob_id"])
253  
254          async def mark_done_process(sim):
255              while True:
256                  while not rob_done_queue:
257                      await sim.tick()
258                  await dut.rob_done.call(sim, rob_id=rob_done_queue[0])
259                  rob_done_queue.popleft()
260  
261          with self.run_simulation(dut, max_cycles=2000) as sim:
262              sim.add_testbench(input_process)
263              sim.add_testbench(free_rf_process, background=True)
264              sim.add_testbench(branch_fu_process, background=True)
265              sim.add_testbench(rs_insert_arithmetic, background=True)
266              sim.add_testbench(mark_done_process, background=True)
267              sim.add_testbench(active_tags_call_process, background=True)
268              sim.add_testbench(rob_retire_process)