/ test / frontend / test_fetch.py
test_fetch.py
  1  import pytest
  2  from typing import Optional
  3  from collections import deque
  4  from dataclasses import dataclass
  5  from parameterized import parameterized_class
  6  import random
  7  
  8  from amaranth import Elaboratable, Module
  9  
 10  from transactron.core import Method
 11  from transactron.lib import Adapter, BasicFifo
 12  from transactron.testing.method_mock import MethodMock
 13  from transactron.utils import ModuleConnector, DependencyContext
 14  from transactron.testing import (
 15      TestCaseWithSimulator,
 16      TestbenchIO,
 17      def_method_mock,
 18      SimpleTestCircuit,
 19      TestbenchContext,
 20      ProcessContext,
 21  )
 22  
 23  from coreblocks.frontend.fetch.fetch import FetchUnit, PredictionChecker
 24  from coreblocks.cache.iface import CacheInterface
 25  from coreblocks.arch import *
 26  from coreblocks.params import *
 27  from coreblocks.params.configurations import test_core_config
 28  from coreblocks.interface.layouts import ICacheLayouts, FetchLayouts
 29  from coreblocks.interface.keys import CSRInstancesKey
 30  from coreblocks.priv.csr.csr_instances import CSRInstances
 31  
 32  
 33  class MockedICache(Elaboratable, CacheInterface):
 34      def __init__(self, gen_params: GenParams):
 35          layouts = gen_params.get(ICacheLayouts)
 36  
 37          self.issue_req_io = TestbenchIO(Adapter(i=layouts.issue_req))
 38          self.accept_res_io = TestbenchIO(Adapter(o=layouts.accept_res))
 39  
 40          self.issue_req = self.issue_req_io.adapter.iface
 41          self.accept_res = self.accept_res_io.adapter.iface
 42          self.flush = Method()
 43  
 44      def elaborate(self, platform):
 45          m = Module()
 46  
 47          m.submodules.issue_req_io = self.issue_req_io
 48          m.submodules.accept_res_io = self.accept_res_io
 49  
 50          return m
 51  
 52  
 53  @pytest.mark.parametrize("fetch_block_log", [2, 3, 4])
 54  @pytest.mark.parametrize("with_rvc", [False, True])
 55  @pytest.mark.parametrize("superscalarity", [1, 2])
 56  class TestFetchUnit(TestCaseWithSimulator):
 57      @pytest.fixture(autouse=True)
 58      def setup(self, fixture_initialize_testing_env, fetch_block_log: int, with_rvc: bool, superscalarity: int):
 59          self.with_rvc = with_rvc
 60          self.pc = 0
 61          self.gen_params = GenParams(
 62              test_core_config.replace(
 63                  start_pc=self.pc,
 64                  compressed=with_rvc,
 65                  fetch_block_bytes_log=fetch_block_log,
 66                  frontend_superscalarity=superscalarity,
 67              )
 68          )
 69  
 70          self.csr_instances = CSRInstances(self.gen_params)
 71          DependencyContext.get().add_dependency(CSRInstancesKey(), self.csr_instances)
 72  
 73          self.icache = MockedICache(self.gen_params)
 74          fifo = BasicFifo(self.gen_params.get(FetchLayouts).fetch_result, depth=2)
 75          self.fifo = SimpleTestCircuit(fifo, exclude={"write"})
 76          self.fetch_resume_mock = TestbenchIO(Adapter())
 77  
 78          fetch_unit = FetchUnit(self.gen_params, self.icache)
 79          fetch_unit.cont.provide(fifo.write)
 80  
 81          self.fetch = SimpleTestCircuit(fetch_unit, exclude={"cont"})
 82  
 83          self.m = ModuleConnector(self.csr_instances, self.icache, self.fifo, self.fetch)
 84  
 85          self.instr_queue = deque()
 86          self.mem = {}
 87          self.memerr = set()
 88          self.input_q = deque()
 89          self.output_q = deque()
 90          self.stalled = False
 91  
 92          self.next_fetch_request = self.pc
 93          self.last_redirect = None
 94          self.backend_redirect = deque()
 95  
 96          random.seed(41)
 97  
 98      def add_instr(self, data: int, jumps: bool, jump_offset: int = 0, branch_taken: bool = False) -> int:
 99          rvc = (data & 0b11) != 0b11
100          if rvc:
101              self.mem[self.pc] = data
102          else:
103              self.mem[self.pc] = data & 0xFFFF
104              self.mem[self.pc + 2] = data >> 16
105  
106          next_pc = self.pc + (2 if rvc else 4)
107          if jumps and branch_taken:
108              next_pc = self.pc + jump_offset
109  
110          self.instr_queue.append(
111              {
112                  "instr": data,
113                  "pc": self.pc,
114                  "jumps": jumps,
115                  "branch_taken": branch_taken,
116                  "next_pc": next_pc,
117                  "rvc": rvc,
118              }
119          )
120  
121          instr_pc = self.pc
122          self.pc = next_pc
123  
124          return instr_pc
125  
126      def gen_non_branch_instr(self, rvc: bool) -> int:
127          if rvc:
128              data = (random.randrange(2**11) << 2) | 0b01
129          else:
130              data = random.randrange(2**32) & ~0b1111111
131              data |= 0b11  # 2 lowest bits must be set in 32-bit long instructions
132  
133          return self.add_instr(data, False)
134  
135      def gen_jal(self, offset: int) -> int:
136          data = JTypeInstr(opcode=Opcode.JAL, rd=0, imm=offset).encode()
137  
138          return self.add_instr(data, True, jump_offset=offset, branch_taken=True)
139  
140      def gen_branch(self, offset: int, taken: bool):
141          data = BTypeInstr(opcode=Opcode.BRANCH, imm=offset, funct3=Funct3.BEQ, rs1=0, rs2=0).encode()
142  
143          return self.add_instr(data, True, jump_offset=offset, branch_taken=taken)
144  
145      async def cache_process(self, sim: ProcessContext):
146          while True:
147              while len(self.input_q) == 0:
148                  await sim.tick()
149  
150              await self.random_wait_geom(sim, 0.5)
151  
152              req_addr = self.input_q.popleft() & ~(self.gen_params.fetch_block_bytes - 1)
153  
154              def load_or_gen_mem(addr):
155                  if addr in self.mem:
156                      return self.mem[addr]
157  
158                  # Make sure to generate a compressed instruction to avoid
159                  # random cross boundary instructions.
160                  return random.randrange(2**16) & ~(0b11)
161  
162              fetch_block = 0
163              bad_addr = False
164              for i in range(0, self.gen_params.fetch_block_bytes, 2):
165                  fetch_block |= load_or_gen_mem(req_addr + i) << (8 * i)
166                  if req_addr + i in self.memerr:
167                      if random.random() < 0.3:
168                          fetch_block = 0
169                      elif random.random() < 0.3:
170                          fetch_block = random.randrange(1 << (self.gen_params.fetch_block_bytes * 8))
171                      bad_addr = True
172  
173              self.output_q.append({"fetch_block": fetch_block, "error": bad_addr})
174  
175      @def_method_mock(
176          lambda self: self.icache.issue_req_io, enable=lambda self: len(self.input_q) < 2
177      )  # TODO had sched_prio
178      def issue_req_mock(self, paddr):
179          @MethodMock.effect
180          def eff():
181              self.input_q.append(paddr)
182  
183      @def_method_mock(lambda self: self.icache.accept_res_io, enable=lambda self: len(self.output_q) > 0)
184      def accept_res_mock(self):
185          @MethodMock.effect
186          def eff():
187              self.output_q.popleft()
188  
189          if self.output_q:
190              return self.output_q[0]
191  
192      @def_method_mock(lambda self: self.fetch.stall_unsafe)
193      def stall_lock_unsafe(self):
194          pass
195  
196      @def_method_mock(lambda self: self.fetch.fetch_writeback)
197      def fetch_writeback_mock(self, redirect, redirect_target):
198          @MethodMock.effect
199          def eff():
200              if redirect:
201                  self.last_redirect = redirect_target
202              else:
203                  self.stalled = True
204  
205      async def fetch_out_check(self, sim: TestbenchContext):
206          async def check_instr(instr, v):
207              access_fault = FetchLayouts.FaultFlag.ACCESS_FAULT if instr["pc"] in self.memerr else 0
208              if not instr["rvc"]:
209                  if instr["pc"] + 2 in self.memerr:
210                      access_fault = (
211                          FetchLayouts.FaultFlag.ACCESS_FAULT | FetchLayouts.FaultFlag.EXCEPTION_ON_SECOND_HALF
212                          if not access_fault
213                          else access_fault
214                      )
215  
216              print(instr, v["pc"], v["access_fault"])
217              assert v["pc"] == instr["pc"]
218              assert v["access_fault"] == access_fault
219  
220              if not access_fault:
221                  instr_data = instr["instr"]
222                  if (instr_data & 0b11) == 0b11:
223                      assert v["instr"] == instr_data
224  
225              if (instr["jumps"] and (instr["branch_taken"] != v["predicted_taken"])) or access_fault:
226                  await self.random_wait(sim, 5)
227                  self.stalled = True
228                  await self.fetch.flush.call(sim)
229                  await self.random_wait(sim, 5)
230  
231                  # Empty the pipeline
232                  await self.fifo.clear.call_try(sim)
233                  await sim.tick()
234  
235                  resume_pc = instr["next_pc"]
236                  if access_fault:
237                      # Resume from the next fetch block
238                      resume_pc = (
239                          instr["pc"] & ~(self.gen_params.fetch_block_bytes - 1)
240                      ) + self.gen_params.fetch_block_bytes * (
241                          2 if FetchLayouts.FaultFlag.EXCEPTION_ON_SECOND_HALF in access_fault else 1
242                      )
243  
244                  self.backend_redirect.append(resume_pc)
245  
246                  return True
247  
248          while self.instr_queue:
249              v = await self.fifo.read.call(sim)
250  
251              for k in range(v.count):
252                  instr = self.instr_queue.popleft()
253                  # if fault happened, throw away rest of insns
254                  if await check_instr(instr, v.data[k]):
255                      break
256                  # test ended, garbage insns ahead
257                  if not self.instr_queue:
258                      break
259  
260      async def requester(self, sim: ProcessContext):
261          while True:
262              ret = None
263              if self.stalled:
264                  await sim.tick()
265              else:
266                  ret = await self.fetch.fetch_request.call_try(sim, pc=self.next_fetch_request)
267  
268              await sim.delay(0)
269              if self.stalled:
270                  while not self.backend_redirect:
271                      await self.tick(sim)
272  
273                  self.stalled = False
274                  self.next_fetch_request = self.backend_redirect[0]
275                  self.backend_redirect.popleft()
276  
277              elif self.last_redirect is not None:
278                  self.next_fetch_request = self.last_redirect
279              elif ret is not None:
280                  self.next_fetch_request = (
281                      1 + (self.next_fetch_request // self.gen_params.fetch_block_bytes)
282                  ) * self.gen_params.fetch_block_bytes
283  
284              self.last_redirect = None
285  
286      def run_sim(self):
287          with self.run_simulation(self.m, max_cycles=1000) as sim:
288              sim.add_process(self.cache_process)
289              sim.add_process(self.requester)
290              sim.add_testbench(self.fetch_out_check)
291  
292      def test_simple_no_jumps(self):
293          for _ in range(50):
294              self.gen_non_branch_instr(rvc=False)
295  
296          self.run_sim()
297  
298      def test_simple_no_jumps_rvc(self):
299          if not self.with_rvc:
300              self.run_sim()  # Run simulation to avoid unused warnings.
301              return
302  
303          # Try a fetch block full of non-RVC instructions
304          for _ in range(self.gen_params.fetch_width // 2):
305              self.gen_non_branch_instr(rvc=False)
306  
307          # Try a fetch block full of RVC instructions
308          for _ in range(self.gen_params.fetch_width):
309              self.gen_non_branch_instr(rvc=True)
310  
311          # Try what if an instruction crossed a boundary of a fetch block
312          self.gen_non_branch_instr(rvc=True)
313          for _ in range(self.gen_params.fetch_width - 1):
314              self.gen_non_branch_instr(rvc=False)
315  
316          self.gen_non_branch_instr(rvc=True)
317  
318          # We are now at the beginning of a fetch block again.
319  
320          # RVC interleaved with non-RVC
321          for _ in range(self.gen_params.fetch_width):
322              self.gen_non_branch_instr(rvc=True)
323              self.gen_non_branch_instr(rvc=False)
324  
325          # Random sequence
326          for _ in range(50):
327              self.gen_non_branch_instr(rvc=random.randrange(2) == 1)
328  
329          self.run_sim()
330  
331      def test_jumps(self):
332          # Jump to the next instruction
333          self.gen_jal(4)
334          for _ in range(self.gen_params.fetch_block_bytes // 4 - 1):
335              self.gen_non_branch_instr(rvc=False)
336  
337          # Jump to the next fetch block
338          self.gen_jal(self.gen_params.fetch_block_bytes)
339  
340          # Two fetch blocks-worth of instructions
341          for _ in range(self.gen_params.fetch_block_bytes // 2):
342              self.gen_non_branch_instr(rvc=False)
343  
344          # Jump to the next fetch block, but fill the block with other jump instructions
345          block_pc = self.gen_jal(self.gen_params.fetch_block_bytes)
346          for i in range(self.gen_params.fetch_block_bytes // 4 - 1):
347              data = JTypeInstr(opcode=Opcode.JAL, rd=0, imm=-8).encode()
348              self.mem[block_pc + (i + 1) * 4] = data & 0xFFFF
349              self.mem[block_pc + (i + 1) * 4 + 2] = data >> 16
350  
351          # Jump to the last instruction of a fetch block
352          self.gen_jal(2 * self.gen_params.fetch_block_bytes - 4)
353  
354          self.gen_non_branch_instr(rvc=False)
355  
356          # Jump as the last instruction of the fetch block
357          for _ in range(self.gen_params.fetch_block_bytes // 4 - 1):
358              self.gen_non_branch_instr(rvc=False)
359          self.gen_jal(20)
360  
361          # A chain of jumps
362          for _ in range(10):
363              self.gen_jal(random.randrange(4, 100, 4))
364  
365          # A big jump
366          self.gen_jal(1000)
367          self.gen_non_branch_instr(rvc=False)
368  
369          # And a jump backwards
370          self.gen_jal(-200)
371          for _ in range(5):
372              self.gen_non_branch_instr(rvc=False)
373  
374          self.run_sim()
375  
376      def test_jumps_rvc(self):
377          if not self.with_rvc:
378              self.run_sim()
379              return
380  
381          # Jump to the last instruction of a fetch block
382          self.gen_jal(2 * self.gen_params.fetch_block_bytes - 2)
383          self.gen_non_branch_instr(rvc=True)
384  
385          # Again, but the last instruction spans two fetch blocks
386          self.gen_jal(2 * self.gen_params.fetch_block_bytes - 2)
387          self.gen_non_branch_instr(rvc=False)
388  
389          for _ in range(self.gen_params.fetch_width - 1):
390              self.gen_non_branch_instr(rvc=True)
391  
392          # Make a jump instruction that spans two fetch blocks
393          for _ in range(self.gen_params.fetch_width - 1):
394              self.gen_non_branch_instr(rvc=True)
395          self.gen_jal(self.gen_params.fetch_block_bytes + 2)
396  
397          self.gen_non_branch_instr(rvc=False)
398  
399          self.run_sim()
400  
401      def test_branches(self):
402          # Taken branch forward
403          self.gen_branch(offset=self.gen_params.fetch_block_bytes, taken=True)
404  
405          for _ in range(self.gen_params.fetch_width):
406              self.gen_non_branch_instr(rvc=False)
407  
408          # Not taken branch forward
409          self.gen_branch(offset=self.gen_params.fetch_block_bytes, taken=False)
410  
411          for _ in range(self.gen_params.fetch_width):
412              self.gen_non_branch_instr(rvc=False)
413  
414          # Jump somewhere far - biggest possible value
415          self.gen_branch(offset=4092, taken=True)
416  
417          for _ in range(self.gen_params.fetch_width):
418              self.gen_non_branch_instr(rvc=False)
419  
420          # Chain a few branches
421          for i in range(10):
422              self.gen_branch(offset=1028, taken=(i % 2 == 0))
423  
424          self.gen_non_branch_instr(rvc=False)
425  
426          self.run_sim()
427  
428      def test_access_fault(self):
429          for _ in range(self.gen_params.fetch_width):
430              self.gen_non_branch_instr(rvc=False)
431  
432          # Access fault at the beginning of the fetch block
433          pc = self.gen_non_branch_instr(rvc=False)
434          self.memerr.add(pc)
435  
436          # We will resume from the next fetch block
437          self.pc = pc + self.gen_params.fetch_block_bytes
438  
439          for _ in range(self.gen_params.fetch_width):
440              self.gen_non_branch_instr(rvc=False)
441  
442          # Access fault in a block with a jump
443          pc = self.gen_jal(2 * self.gen_params.fetch_block_bytes)
444          self.memerr.add(pc)
445  
446          # We will resume from the next fetch block
447          self.pc = pc + self.gen_params.fetch_block_bytes
448  
449          if self.with_rvc:
450              # Access fault on sencond half on instruction
451              for _ in range(self.gen_params.fetch_width - 1):
452                  self.gen_non_branch_instr(rvc=True)
453              pc = self.gen_non_branch_instr(rvc=False)  # 4 byte instruction crossing block
454              self.memerr.add(pc + 2)
455  
456              # We will resume from next valid block
457              self.pc = pc + 2 + self.gen_params.fetch_block_bytes
458  
459          self.gen_non_branch_instr(rvc=False)
460  
461          self.run_sim()
462  
463      def test_random(self):
464          for _ in range(500):
465              r = random.random()
466              if r < 0.6:
467                  rvc = random.randrange(2) == 0 if self.with_rvc else False
468                  self.gen_non_branch_instr(rvc=rvc)
469              else:
470                  offset = random.randrange(0, 1000, 2)
471                  if not self.with_rvc:
472                      offset = offset & ~(0b11)
473                  if r < 0.8:
474                      self.gen_jal(offset)
475                  else:
476                      self.gen_branch(offset, taken=random.randrange(2) == 0)
477  
478          with self.run_simulation(self.m) as sim:
479              sim.add_process(self.cache_process)
480              sim.add_testbench(self.fetch_out_check)
481              sim.add_process(self.requester)
482  
483  
@dataclass(frozen=True)
class CheckerResult:
    """Response of one ``PredictionChecker`` check, as collected by ``TestPredictionChecker.check``."""

    mispredicted: bool  # the checker reported a misprediction
    stall: bool  # the checker requested a stall instead of a redirect
    fb_instr_idx: int  # index within the fetch block of the instruction the result refers to
    redirect_target: int  # address fetching should be redirected to
490  
491  
492  @parameterized_class(
493      ("name", "fetch_block_log", "with_rvc"),
494      [
495          ("block4B", 2, False),
496          ("block4B_rvc", 2, True),
497          ("block8B", 3, False),
498          ("block8B_rvc", 3, True),
499          ("block16B", 4, False),
500          ("block16B_rvc", 4, True),
501      ],
502  )
503  class TestPredictionChecker(TestCaseWithSimulator):
504      fetch_block_log: int
505      with_rvc: bool
506  
507      @pytest.fixture(autouse=True)
508      def setup(self, fixture_initialize_testing_env):
509          self.gen_params = GenParams(
510              test_core_config.replace(compressed=self.with_rvc, fetch_block_bytes_log=self.fetch_block_log)
511          )
512  
513          self.m = SimpleTestCircuit(PredictionChecker(self.gen_params))
514  
515      async def check(
516          self,
517          sim: TestbenchContext,
518          pc: int,
519          block_cross: bool,
520          predecoded: list[tuple[CfiType, int]],
521          branch_mask: int,
522          cfi_idx: int,
523          cfi_type: CfiType,
524          cfi_target: Optional[int],
525          valid_mask: int = -1,
526      ) -> CheckerResult:
527          # Fill the array with non-CFI instructions
528          for _ in range(self.gen_params.fetch_width - len(predecoded)):
529              predecoded.append((CfiType.INVALID, 0))
530          predecoded_raw = [
531              {"cfi_type": predecoded[i][0], "cfi_offset": predecoded[i][1], "unsafe": 0}
532              for i in range(self.gen_params.fetch_width)
533          ]
534  
535          prediction = {
536              "branch_mask": branch_mask,
537              "cfi_idx": cfi_idx,
538              "cfi_type": cfi_type,
539              "cfi_target": cfi_target or 0,
540              "cfi_target_valid": 1 if cfi_target is not None else 0,
541          }
542  
543          instr_start = (
544              pc & ((1 << self.gen_params.fetch_block_bytes_log) - 1)
545          ) >> self.gen_params.min_instr_width_bytes_log
546  
547          instr_valid = (((1 << self.gen_params.fetch_width) - 1) << instr_start) & valid_mask
548  
549          res = await self.m.check.call(
550              sim,
551              fb_addr=pc >> self.gen_params.fetch_block_bytes_log,
552              instr_block_cross=block_cross,
553              instr_valid=instr_valid,
554              predecoded=predecoded_raw,
555              prediction=prediction,
556          )
557  
558          return CheckerResult(
559              mispredicted=bool(res["mispredicted"]),
560              stall=bool(res["stall"]),
561              fb_instr_idx=res["fb_instr_idx"],
562              redirect_target=res["redirect_target"],
563          )
564  
565      def assert_resp(
566          self,
567          res: CheckerResult,
568          mispredicted: Optional[bool] = None,
569          stall: Optional[bool] = None,
570          fb_instr_idx: Optional[int] = None,
571          redirect_target: Optional[int] = None,
572      ):
573          if mispredicted is not None:
574              assert res.mispredicted == mispredicted
575          if stall is not None:
576              assert res.stall == stall
577          if fb_instr_idx is not None:
578              assert res.fb_instr_idx == fb_instr_idx
579          if redirect_target is not None:
580              assert res.redirect_target == redirect_target
581  
    def test_no_misprediction(self):
        """Predictions consistent with the predecoded CFIs must not be reported as mispredictions."""
        instr_width = self.gen_params.min_instr_width_bytes
        fetch_width = self.gen_params.fetch_width

        async def proc(sim: TestbenchContext):
            # No CFI at all
            ret = await self.check(sim, 0x100, False, [], 0, 0, CfiType.INVALID, None)
            self.assert_resp(ret, mispredicted=False)

            # There is one forward branch that we didn't predict
            ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, 100)], 0, 0, CfiType.INVALID, None)
            self.assert_resp(ret, mispredicted=False)

            # There are many forward branches that we didn't predict
            ret = await self.check(
                sim, 0x100, False, [(CfiType.BRANCH, 100)] * fetch_width, 0, 0, CfiType.INVALID, None
            )
            self.assert_resp(ret, mispredicted=False)

            # There is a predicted JAL instr
            ret = await self.check(sim, 0x100, False, [(CfiType.JAL, 100)], 0, 0, CfiType.JAL, 0x100 + 100)
            self.assert_resp(ret, mispredicted=False)

            # There is a predicted JALR instr - the predecoded offset can now be anything
            ret = await self.check(sim, 0x100, False, [(CfiType.JALR, 200)], 0, 0, CfiType.JALR, 0x100 + 100)
            self.assert_resp(ret, mispredicted=False)

            # There is a forward taken-predicted branch
            ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, 100)], 0b1, 0, CfiType.BRANCH, 0x100 + 100)
            self.assert_resp(ret, mispredicted=False)

            # There is a backward taken-predicted branch
            ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, -100)], 0b1, 0, CfiType.BRANCH, 0x100 - 100)
            self.assert_resp(ret, mispredicted=False)

            # Branch located between two fetch blocks: it starts 2 bytes before 0x100,
            # hence the extra -2 in the expected target
            if self.with_rvc:
                ret = await self.check(
                    sim, 0x100, True, [(CfiType.BRANCH, -100)], 0b1, 0, CfiType.BRANCH, 0x100 - 100 - 2
                )
                self.assert_resp(ret, mispredicted=False)

            # One branch predicted as not taken
            ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, -100)], 0b1, 0, CfiType.INVALID, 0)
            self.assert_resp(ret, mispredicted=False)

            # Now tests for fetch blocks with multiple instructions
            if fetch_width < 2:
                return

            # Predicted taken branch as the second instruction
            ret = await self.check(
                sim,
                0x100,
                False,
                [(CfiType.INVALID, 0), (CfiType.BRANCH, -100)],
                0b10,
                1,
                CfiType.BRANCH,
                0x100 + instr_width - 100,
            )
            self.assert_resp(ret, mispredicted=False)

            # Predicted, but not taken branch as the second instruction
            ret = await self.check(
                sim, 0x100, False, [(CfiType.INVALID, 0), (CfiType.BRANCH, -100)], 0b10, 0, CfiType.INVALID, 0
            )
            self.assert_resp(ret, mispredicted=False)

            if self.with_rvc:
                # Predicted taken branch as the second instruction, with block_cross set
                ret = await self.check(
                    sim,
                    0x100,
                    True,
                    [(CfiType.INVALID, 0), (CfiType.BRANCH, -100)],
                    0b10,
                    1,
                    CfiType.BRANCH,
                    0x100 + instr_width - 100,
                )
                self.assert_resp(ret, mispredicted=False)

                # valid_mask=0b10 leaves only the second slot valid, so the JAL in the first slot is ignored
                ret = await self.check(
                    sim,
                    0x100,
                    True,
                    [(CfiType.JAL, 100), (CfiType.JAL, -100)],
                    0b00,
                    1,
                    CfiType.JAL,
                    0x100 + instr_width - 100,
                    valid_mask=0b10,
                )
                self.assert_resp(ret, mispredicted=False)

            # Two branches, covering every outcome: none taken, the first taken, the second taken
            ret = await self.check(
                sim, 0x100, False, [(CfiType.BRANCH, -100), (CfiType.BRANCH, 100)], 0b11, 0, CfiType.INVALID, 0
            )
            self.assert_resp(ret, mispredicted=False)
            ret = await self.check(
                sim, 0x100, False, [(CfiType.BRANCH, -100), (CfiType.BRANCH, 100)], 0b11, 0, CfiType.BRANCH, 0x100 - 100
            )
            self.assert_resp(ret, mispredicted=False)
            ret = await self.check(
                sim,
                0x100,
                False,
                [(CfiType.BRANCH, -100), (CfiType.BRANCH, 100)],
                0b11,
                1,
                CfiType.BRANCH,
                0x100 + instr_width + 100,
            )
            self.assert_resp(ret, mispredicted=False)

            # JAL at the beginning, but we start from the second instruction
            ret = await self.check(sim, 0x100 + instr_width, False, [(CfiType.JAL, -100)], 0b0, 0, CfiType.INVALID, 0)
            self.assert_resp(ret, mispredicted=False)

            # JAL (before the start PC, so ignored) and a forward branch that we didn't predict
            ret = await self.check(
                sim,
                0x100 + instr_width,
                False,
                [(CfiType.JAL, -100), (CfiType.BRANCH, 100)],
                0b00,
                0,
                CfiType.INVALID,
                0,
            )
            self.assert_resp(ret, mispredicted=False)

            # Two JAL instructions, but we start from the second one
            ret = await self.check(
                sim,
                0x100 + instr_width,
                False,
                [(CfiType.JAL, -100), (CfiType.JAL, 100)],
                0b00,
                1,
                CfiType.JAL,
                0x100 + instr_width + 100,
            )
            self.assert_resp(ret, mispredicted=False)

            # JAL and a branch, but we start from the second instruction
            ret = await self.check(
                sim,
                0x100 + instr_width,
                False,
                [(CfiType.JAL, -100), (CfiType.BRANCH, 100)],
                0b10,
                1,
                CfiType.BRANCH,
                0x100 + instr_width + 100,
            )
            self.assert_resp(ret, mispredicted=False)

        with self.run_simulation(self.m) as sim:
            sim.add_testbench(proc)
743  
744      def test_preceding_redirection(self):
745          instr_width = self.gen_params.min_instr_width_bytes
746          fetch_width = self.gen_params.fetch_width
747  
748          async def proc(sim: TestbenchContext):
749              # No prediction was made, but there is a JAL at the beginning
750              ret = await self.check(sim, 0x100, False, [(CfiType.JAL, 0x20)], 0, 0, CfiType.INVALID, None)
751              self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 + 0x20)
752  
753              # The same, but the jump is between two fetch blocks
754              if self.with_rvc:
755                  ret = await self.check(sim, 0x100, True, [(CfiType.JAL, 0x20)], 0, 0, CfiType.INVALID, None)
756                  self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 + 0x20 - 2)
757  
758              # Not predicted backward branch
759              ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, -100)], 0b0, 0, CfiType.INVALID, 0)
760              self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 - 100)
761  
762              # Now tests for fetch blocks with multiple instructions
763              if fetch_width < 2:
764                  return
765  
766              # We predicted the branch on the second instruction, but there's a JAL on the first one.
767              ret = await self.check(
768                  sim,
769                  0x100,
770                  False,
771                  [(CfiType.JAL, -100), (CfiType.BRANCH, 100)],
772                  0b10,
773                  1,
774                  CfiType.BRANCH,
775                  0x100 + instr_width + 100,
776              )
777              self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 - 100)
778  
779              # We predicted the branch on the second instruction, but there's a JALR on the first one.
780              ret = await self.check(
781                  sim,
782                  0x100,
783                  False,
784                  [(CfiType.JALR, -100), (CfiType.BRANCH, 100)],
785                  0b10,
786                  1,
787                  CfiType.BRANCH,
788                  0x100 + instr_width + 100,
789              )
790              self.assert_resp(ret, mispredicted=True, stall=True, fb_instr_idx=0)
791  
792              # We predicted the branch on the second instruction, but there's a backward on the first one.
793              ret = await self.check(
794                  sim,
795                  0x100,
796                  False,
797                  [(CfiType.BRANCH, -100), (CfiType.BRANCH, 100)],
798                  0b10,
799                  1,
800                  CfiType.BRANCH,
801                  0x100 + instr_width + 100,
802              )
803              self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 - 100)
804  
805              # Unpredicted backward branch as the second instruction
806              ret = await self.check(
807                  sim, 0x100, False, [(CfiType.INVALID, 0), (CfiType.BRANCH, -100)], 0b00, 0, CfiType.INVALID, 0
808              )
809              self.assert_resp(
810                  ret, mispredicted=True, stall=False, fb_instr_idx=1, redirect_target=0x100 + instr_width - 100
811              )
812  
813              # Unpredicted JAL as the second instruction
814              ret = await self.check(
815                  sim, 0x100, False, [(CfiType.INVALID, 0), (CfiType.JAL, 100)], 0b00, 0, CfiType.INVALID, 0
816              )
817              self.assert_resp(
818                  ret, mispredicted=True, stall=False, fb_instr_idx=1, redirect_target=0x100 + instr_width + 100
819              )
820  
821              # Unpredicted JALR as the second instruction
822              ret = await self.check(
823                  sim, 0x100, False, [(CfiType.INVALID, 0), (CfiType.JALR, 100)], 0b00, 0, CfiType.INVALID, 0
824              )
825              self.assert_resp(ret, mispredicted=True, stall=True, fb_instr_idx=1)
826  
827              if fetch_width < 3:
828                  return
829  
830              ret = await self.check(
831                  sim,
832                  0x100 + instr_width,
833                  False,
834                  [(CfiType.JAL, -100), (CfiType.INVALID, 100), (CfiType.JAL, 100)],
835                  0b0,
836                  0,
837                  CfiType.INVALID,
838                  None,
839              )
840              self.assert_resp(
841                  ret, mispredicted=True, stall=False, fb_instr_idx=2, redirect_target=0x100 + 2 * instr_width + 100
842              )
843  
844          with self.run_simulation(self.m) as sim:
845              sim.add_testbench(proc)
846  
847      def test_mispredicted_cfi_type(self):
848          instr_width = self.gen_params.min_instr_width_bytes
849          fetch_width = self.gen_params.fetch_width
850          fb_bytes = self.gen_params.fetch_block_bytes
851  
852          async def proc(sim: TestbenchContext):
853              # We predicted a JAL, but in fact there is a non-CFI instruction
854              ret = await self.check(sim, 0x100, False, [(CfiType.INVALID, 0)], 0, 0, CfiType.JAL, 100)
855              self.assert_resp(
856                  ret, mispredicted=True, stall=False, fb_instr_idx=fetch_width - 1, redirect_target=0x100 + fb_bytes
857              )
858  
859              # We predicted a JAL, but in fact there is a branch
860              ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, -100)], 0, 0, CfiType.JAL, 100)
861              self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 - 100)
862  
863              # We predicted a JAL, but in fact there is a JALR instruction
864              ret = await self.check(sim, 0x100, False, [(CfiType.JALR, -100)], 0, 0, CfiType.JAL, 100)
865              self.assert_resp(ret, mispredicted=True, stall=True, fb_instr_idx=0)
866  
867              # We predicted a branch, but in fact there is a JAL
868              ret = await self.check(sim, 0x100, False, [(CfiType.JAL, -100)], 0b1, 0, CfiType.BRANCH, 100)
869              self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 - 100)
870  
871              if fetch_width < 2:
872                  return
873  
874              # There is a branch and a non-CFI, but we predicted two branches
875              ret = await self.check(
876                  sim, 0x100, False, [(CfiType.BRANCH, -100), (CfiType.INVALID, 0)], 0b11, 1, CfiType.BRANCH, 100
877              )
878              self.assert_resp(
879                  ret, mispredicted=True, stall=False, fb_instr_idx=fetch_width - 1, redirect_target=0x100 + fb_bytes
880              )
881  
882              # The same as above, but we start from the second instruction
883              ret = await self.check(
884                  sim,
885                  0x100 + instr_width,
886                  False,
887                  [(CfiType.BRANCH, -100), (CfiType.INVALID, 0)],
888                  0b11,
889                  1,
890                  CfiType.BRANCH,
891                  100,
892              )
893              self.assert_resp(
894                  ret, mispredicted=True, stall=False, fb_instr_idx=fetch_width - 1, redirect_target=0x100 + fb_bytes
895              )
896  
897          with self.run_simulation(self.m) as sim:
898              sim.add_testbench(proc)
899  
900      def test_mispredicted_cfi_target(self):
901          instr_width = self.gen_params.min_instr_width_bytes
902          fetch_width = self.gen_params.fetch_width
903  
904          async def proc(sim: TestbenchContext):
905              # We predicted a wrong JAL target
906              ret = await self.check(sim, 0x100, False, [(CfiType.JAL, 100)], 0, 0, CfiType.JAL, 200)
907              self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 + 100)
908  
909              # We predicted a wrong branch target
910              ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, 100)], 0b1, 0, CfiType.BRANCH, 200)
911              self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 + 100)
912  
913              # We didn't provide the branch target
914              ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, 100)], 0b1, 0, CfiType.BRANCH, None)
915              self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 + 100)
916  
917              # We predicted a wrong JAL target that is between two fetch blocks
918              if self.with_rvc:
919                  ret = await self.check(sim, 0x100, True, [(CfiType.JAL, 100)], 0, 0, CfiType.JAL, 300)
920                  self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 + 100 - 2)
921  
922              if fetch_width < 2:
923                  return
924  
925              # The second instruction is a branch without the target
926              ret = await self.check(
927                  sim, 0x100, False, [(CfiType.INVALID, 0), (CfiType.BRANCH, 100)], 0b10, 1, CfiType.BRANCH, None
928              )
929              self.assert_resp(
930                  ret, mispredicted=True, stall=False, fb_instr_idx=1, redirect_target=0x100 + instr_width + 100
931              )
932  
933              # The second instruction is a JAL with a wrong target
934              ret = await self.check(
935                  sim, 0x100, False, [(CfiType.INVALID, 0), (CfiType.JAL, 100)], 0b10, 1, CfiType.JAL, 200
936              )
937              self.assert_resp(
938                  ret, mispredicted=True, stall=False, fb_instr_idx=1, redirect_target=0x100 + instr_width + 100
939              )
940  
941          with self.run_simulation(self.m) as sim:
942              sim.add_testbench(proc)