test_fetch.py
import pytest
from typing import Optional
from collections import deque
from dataclasses import dataclass
from parameterized import parameterized_class
import random

from amaranth import Elaboratable, Module

from transactron.core import Method
from transactron.lib import Adapter, BasicFifo
from transactron.testing.method_mock import MethodMock
from transactron.utils import ModuleConnector, DependencyContext
from transactron.testing import (
    TestCaseWithSimulator,
    TestbenchIO,
    def_method_mock,
    SimpleTestCircuit,
    TestbenchContext,
    ProcessContext,
)

from coreblocks.frontend.fetch.fetch import FetchUnit, PredictionChecker
from coreblocks.cache.iface import CacheInterface
from coreblocks.arch import *
from coreblocks.params import *
from coreblocks.params.configurations import test_core_config
from coreblocks.interface.layouts import ICacheLayouts, FetchLayouts
from coreblocks.interface.keys import CSRInstancesKey
from coreblocks.priv.csr.csr_instances import CSRInstances


class MockedICache(Elaboratable, CacheInterface):
    """Instruction cache stand-in driven from the testbench.

    `issue_req` and `accept_res` are the interfaces of adapters wrapped in
    `TestbenchIO`, so the test decides what the "cache" answers.
    `flush` is a bare `Method` that is never defined here.
    """

    def __init__(self, gen_params: GenParams):
        layouts = gen_params.get(ICacheLayouts)

        self.issue_req_io = TestbenchIO(Adapter(i=layouts.issue_req))
        self.accept_res_io = TestbenchIO(Adapter(o=layouts.accept_res))

        self.issue_req = self.issue_req_io.adapter.iface
        self.accept_res = self.accept_res_io.adapter.iface
        self.flush = Method()

    def elaborate(self, platform):
        m = Module()

        m.submodules.issue_req_io = self.issue_req_io
        m.submodules.accept_res_io = self.accept_res_io

        return m


@pytest.mark.parametrize("fetch_block_log", [2, 3, 4])
@pytest.mark.parametrize("with_rvc", [False, True])
@pytest.mark.parametrize("superscalarity", [1, 2])
class TestFetchUnit(TestCaseWithSimulator):
    """Tests of `FetchUnit` against a mocked instruction cache.

    Each test first lays out a program in `self.mem` through the `gen_*`
    helpers (which also record the expected instruction stream in
    `self.instr_queue`) and then runs the simulation, where
    `fetch_out_check` compares the fetch output with that expected stream.
    """

    @pytest.fixture(autouse=True)
    def setup(self, fixture_initialize_testing_env, fetch_block_log: int, with_rvc: bool, superscalarity: int):
        """Build the circuit under test and reset the test-side model state."""
        self.with_rvc = with_rvc
        self.pc = 0
        self.gen_params = GenParams(
            test_core_config.replace(
                start_pc=self.pc,
                compressed=with_rvc,
                fetch_block_bytes_log=fetch_block_log,
                frontend_superscalarity=superscalarity,
            )
        )

        self.csr_instances = CSRInstances(self.gen_params)
        DependencyContext.get().add_dependency(CSRInstancesKey(), self.csr_instances)

        self.icache = MockedICache(self.gen_params)
        fifo = BasicFifo(self.gen_params.get(FetchLayouts).fetch_result, depth=2)
        self.fifo = SimpleTestCircuit(fifo, exclude={"write"})
        self.fetch_resume_mock = TestbenchIO(Adapter())

        fetch_unit = FetchUnit(self.gen_params, self.icache)
        # The fetch unit pushes its results directly into the FIFO read by the test.
        fetch_unit.cont.provide(fifo.write)

        self.fetch = SimpleTestCircuit(fetch_unit, exclude={"cont"})

        self.m = ModuleConnector(self.csr_instances, self.icache, self.fifo, self.fetch)

        # Expected instruction stream, in program order.
        self.instr_queue = deque()
        # Modelled memory: address of a 16-bit parcel -> parcel value.
        self.mem = {}
        # Parcel addresses for which the mocked cache reports an error.
        self.memerr = set()
        # Requests accepted by the mocked cache / responses waiting to be returned.
        self.input_q = deque()
        self.output_q = deque()
        self.stalled = False

        self.next_fetch_request = self.pc
        self.last_redirect = None
        self.backend_redirect = deque()

        random.seed(41)

    def add_instr(self, data: int, jumps: bool, jump_offset: int = 0, branch_taken: bool = False) -> int:
        """Place an instruction at the current `self.pc` and record it as expected.

        The instruction is treated as compressed when its two lowest bits are
        not `0b11`. Afterwards `self.pc` moves either past the instruction or,
        for a taken jump, to `pc + jump_offset`.

        Returns the address at which the instruction was placed.
        """
        rvc = (data & 0b11) != 0b11
        if rvc:
            self.mem[self.pc] = data
        else:
            # A 32-bit instruction occupies two consecutive 16-bit parcels.
            self.mem[self.pc] = data & 0xFFFF
            self.mem[self.pc + 2] = data >> 16

        next_pc = self.pc + (2 if rvc else 4)
        if jumps and branch_taken:
            next_pc = self.pc + jump_offset

        self.instr_queue.append(
            {
                "instr": data,
                "pc": self.pc,
                "jumps": jumps,
                "branch_taken": branch_taken,
                "next_pc": next_pc,
                "rvc": rvc,
            }
        )

        instr_pc = self.pc
        self.pc = next_pc

        return instr_pc

    def gen_non_branch_instr(self, rvc: bool) -> int:
        """Add a random instruction registered as non-jumping; returns its address."""
        if rvc:
            data = (random.randrange(2**11) << 2) | 0b01
        else:
            data = random.randrange(2**32) & ~0b1111111
            data |= 0b11  # 2 lowest bits must be set in 32-bit long instructions

        return self.add_instr(data, False)

    def gen_jal(self, offset: int) -> int:
        """Add a `JAL` (rd=0) jumping by `offset`; returns its address."""
        data = JTypeInstr(opcode=Opcode.JAL, rd=0, imm=offset).encode()

        return self.add_instr(data, True, jump_offset=offset, branch_taken=True)

    def gen_branch(self, offset: int, taken: bool):
        """Add a `BEQ` branch by `offset`, expected to be `taken` or not; returns its address."""
        data = BTypeInstr(opcode=Opcode.BRANCH, imm=offset, funct3=Funct3.BEQ, rs1=0, rs2=0).encode()

        return self.add_instr(data, True, jump_offset=offset, branch_taken=taken)

    async def cache_process(self, sim: ProcessContext):
        """Model the cache: turn queued requests into fetch-block responses.

        Each request is served after a random delay. Parcels missing from
        `self.mem` are filled with random data; a block containing an address
        from `self.memerr` is marked as erroneous and its data may be
        replaced with zeros or garbage.
        """

        def load_or_gen_mem(addr):
            if addr in self.mem:
                return self.mem[addr]

            # Make sure to generate a compressed instruction to avoid
            # random cross boundary instructions.
            return random.randrange(2**16) & ~(0b11)

        while True:
            while len(self.input_q) == 0:
                await sim.tick()

            await self.random_wait_geom(sim, 0.5)

            # Align the requested address down to the fetch block boundary.
            req_addr = self.input_q.popleft() & ~(self.gen_params.fetch_block_bytes - 1)

            fetch_block = 0
            bad_addr = False
            for i in range(0, self.gen_params.fetch_block_bytes, 2):
                fetch_block |= load_or_gen_mem(req_addr + i) << (8 * i)
                if req_addr + i in self.memerr:
                    if random.random() < 0.3:
                        fetch_block = 0
                    elif random.random() < 0.3:
                        fetch_block = random.randrange(1 << (self.gen_params.fetch_block_bytes * 8))
                    bad_addr = True

            self.output_q.append({"fetch_block": fetch_block, "error": bad_addr})

    @def_method_mock(
        lambda self: self.icache.issue_req_io, enable=lambda self: len(self.input_q) < 2
    )  # TODO had sched_prio
    def issue_req_mock(self, paddr):
        """Accept a cache request while fewer than two are pending."""

        @MethodMock.effect
        def eff():
            self.input_q.append(paddr)

    @def_method_mock(lambda self: self.icache.accept_res_io, enable=lambda self: len(self.output_q) > 0)
    def accept_res_mock(self):
        """Return the oldest ready response and drop it from the queue."""

        @MethodMock.effect
        def eff():
            self.output_q.popleft()

        if self.output_q:
            return self.output_q[0]

    @def_method_mock(lambda self: self.fetch.stall_unsafe)
    def stall_lock_unsafe(self):
        """Mock of `stall_unsafe`; does nothing."""
        pass

    @def_method_mock(lambda self: self.fetch.fetch_writeback)
    def fetch_writeback_mock(self, redirect, redirect_target):
        """Record a redirect target, or mark the requester as stalled when there is none."""

        @MethodMock.effect
        def eff():
            if redirect:
                self.last_redirect = redirect_target
            else:
                self.stalled = True

    async def fetch_out_check(self, sim: TestbenchContext):
        """Read the fetch output and compare it with the expected instruction stream.

        On a misprediction or an access fault the fetch unit is flushed, the
        output FIFO is cleared and the resume address is queued in
        `self.backend_redirect` for the requester.
        """

        async def check_instr(instr, v):
            # Returns True when a redirect was issued, so the remaining
            # instructions of the current read must be discarded.
            access_fault = FetchLayouts.FaultFlag.ACCESS_FAULT if instr["pc"] in self.memerr else 0
            if not instr["rvc"]:
                if instr["pc"] + 2 in self.memerr:
                    access_fault = (
                        FetchLayouts.FaultFlag.ACCESS_FAULT | FetchLayouts.FaultFlag.EXCEPTION_ON_SECOND_HALF
                        if not access_fault
                        else access_fault
                    )

            # Trace of the compared values, useful when an assertion below fails.
            print(instr, v["pc"], v["access_fault"])
            assert v["pc"] == instr["pc"]
            assert v["access_fault"] == access_fault

            if not access_fault:
                instr_data = instr["instr"]
                # Only 32-bit instructions are compared bit for bit.
                if (instr_data & 0b11) == 0b11:
                    assert v["instr"] == instr_data

            if (instr["jumps"] and (instr["branch_taken"] != v["predicted_taken"])) or access_fault:
                await self.random_wait(sim, 5)
                self.stalled = True
                await self.fetch.flush.call(sim)
                await self.random_wait(sim, 5)

                # Empty the pipeline
                await self.fifo.clear.call_try(sim)
                await sim.tick()

                resume_pc = instr["next_pc"]
                if access_fault:
                    # Resume from the next fetch block
                    resume_pc = (
                        instr["pc"] & ~(self.gen_params.fetch_block_bytes - 1)
                    ) + self.gen_params.fetch_block_bytes * (
                        2 if FetchLayouts.FaultFlag.EXCEPTION_ON_SECOND_HALF in access_fault else 1
                    )

                self.backend_redirect.append(resume_pc)

                return True

        while self.instr_queue:
            v = await self.fifo.read.call(sim)

            for k in range(v.count):
                instr = self.instr_queue.popleft()
                # if fault happened, throw away rest of insns
                if await check_instr(instr, v.data[k]):
                    break
                # test ended, garbage insns ahead
                if not self.instr_queue:
                    break

    async def requester(self, sim: ProcessContext):
        """Issue fetch requests, following redirects and backend resumes.

        While stalled, no request is sent until a resume address shows up in
        `self.backend_redirect`. Otherwise the next request goes to the last
        reported redirect target or, when the request was accepted, to the
        start of the following fetch block.
        """
        while True:
            ret = None
            if self.stalled:
                await sim.tick()
            else:
                ret = await self.fetch.fetch_request.call_try(sim, pc=self.next_fetch_request)

            await sim.delay(0)
            if self.stalled:
                while not self.backend_redirect:
                    await self.tick(sim)

                self.stalled = False
                self.next_fetch_request = self.backend_redirect.popleft()

            elif self.last_redirect is not None:
                self.next_fetch_request = self.last_redirect
            elif ret is not None:
                self.next_fetch_request = (
                    1 + (self.next_fetch_request // self.gen_params.fetch_block_bytes)
                ) * self.gen_params.fetch_block_bytes

            self.last_redirect = None

    def run_sim(self):
        """Run the cache model, the requester and the output checker for at most 1000 cycles."""
        with self.run_simulation(self.m, max_cycles=1000) as sim:
            sim.add_process(self.cache_process)
            sim.add_process(self.requester)
            sim.add_testbench(self.fetch_out_check)

    def test_simple_no_jumps(self):
        """A straight run of 32-bit instructions."""
        for _ in range(50):
            self.gen_non_branch_instr(rvc=False)

        self.run_sim()

    def test_simple_no_jumps_rvc(self):
        """Straight-line code mixing compressed and 32-bit instructions."""
        if not self.with_rvc:
            self.run_sim()  # Run simulation to avoid unused warnings.
            return

        # Try a fetch block full of non-RVC instructions
        for _ in range(self.gen_params.fetch_width // 2):
            self.gen_non_branch_instr(rvc=False)

        # Try a fetch block full of RVC instructions
        for _ in range(self.gen_params.fetch_width):
            self.gen_non_branch_instr(rvc=True)

        # Try what if an instruction crossed a boundary of a fetch block
        self.gen_non_branch_instr(rvc=True)
        for _ in range(self.gen_params.fetch_width - 1):
            self.gen_non_branch_instr(rvc=False)

        self.gen_non_branch_instr(rvc=True)

        # We are now at the beginning of a fetch block again.

        # RVC interleaved with non-RVC
        for _ in range(self.gen_params.fetch_width):
            self.gen_non_branch_instr(rvc=True)
            self.gen_non_branch_instr(rvc=False)

        # Random sequence
        for _ in range(50):
            self.gen_non_branch_instr(rvc=random.randrange(2) == 1)

        self.run_sim()

    def test_jumps(self):
        """`JAL` instructions at various positions and with various targets."""
        # Jump to the next instruction
        self.gen_jal(4)
        for _ in range(self.gen_params.fetch_block_bytes // 4 - 1):
            self.gen_non_branch_instr(rvc=False)

        # Jump to the next fetch block
        self.gen_jal(self.gen_params.fetch_block_bytes)

        # Two fetch blocks-worth of instructions
        for _ in range(self.gen_params.fetch_block_bytes // 2):
            self.gen_non_branch_instr(rvc=False)

        # Jump to the next fetch block, but fill the block with other jump instructions
        block_pc = self.gen_jal(self.gen_params.fetch_block_bytes)
        # These filler jumps are written straight to memory: they are not
        # added to the expected stream.
        data = JTypeInstr(opcode=Opcode.JAL, rd=0, imm=-8).encode()
        for i in range(self.gen_params.fetch_block_bytes // 4 - 1):
            self.mem[block_pc + (i + 1) * 4] = data & 0xFFFF
            self.mem[block_pc + (i + 1) * 4 + 2] = data >> 16

        # Jump to the last instruction of a fetch block
        self.gen_jal(2 * self.gen_params.fetch_block_bytes - 4)

        self.gen_non_branch_instr(rvc=False)

        # Jump as the last instruction of the fetch block
        for _ in range(self.gen_params.fetch_block_bytes // 4 - 1):
            self.gen_non_branch_instr(rvc=False)
        self.gen_jal(20)

        # A chain of jumps
        for _ in range(10):
            self.gen_jal(random.randrange(4, 100, 4))

        # A big jump
        self.gen_jal(1000)
        self.gen_non_branch_instr(rvc=False)

        # And a jump backwards
        self.gen_jal(-200)
        for _ in range(5):
            self.gen_non_branch_instr(rvc=False)

        self.run_sim()

    def test_jumps_rvc(self):
        """Jumps combined with instructions that are not 4-byte aligned."""
        if not self.with_rvc:
            self.run_sim()
            return

        # Jump to the last instruction of a fetch block
        self.gen_jal(2 * self.gen_params.fetch_block_bytes - 2)
        self.gen_non_branch_instr(rvc=True)

        # Again, but the last instruction spans two fetch blocks
        self.gen_jal(2 * self.gen_params.fetch_block_bytes - 2)
        self.gen_non_branch_instr(rvc=False)

        for _ in range(self.gen_params.fetch_width - 1):
            self.gen_non_branch_instr(rvc=True)

        # Make a jump instruction that spans two fetch blocks
        for _ in range(self.gen_params.fetch_width - 1):
            self.gen_non_branch_instr(rvc=True)
        self.gen_jal(self.gen_params.fetch_block_bytes + 2)

        self.gen_non_branch_instr(rvc=False)

        self.run_sim()

    def test_branches(self):
        """Taken and not-taken conditional branches."""
        # Taken branch forward
        self.gen_branch(offset=self.gen_params.fetch_block_bytes, taken=True)

        for _ in range(self.gen_params.fetch_width):
            self.gen_non_branch_instr(rvc=False)

        # Not taken branch forward
        self.gen_branch(offset=self.gen_params.fetch_block_bytes, taken=False)

        for _ in range(self.gen_params.fetch_width):
            self.gen_non_branch_instr(rvc=False)

        # Jump somewhere far - biggest possible value
        self.gen_branch(offset=4092, taken=True)

        for _ in range(self.gen_params.fetch_width):
            self.gen_non_branch_instr(rvc=False)

        # Chain a few branches
        for i in range(10):
            self.gen_branch(offset=1028, taken=(i % 2 == 0))

        self.gen_non_branch_instr(rvc=False)

        self.run_sim()

    def test_access_fault(self):
        """Cache errors at the start of a block, on a jump and on a second half."""
        for _ in range(self.gen_params.fetch_width):
            self.gen_non_branch_instr(rvc=False)

        # Access fault at the beginning of the fetch block
        pc = self.gen_non_branch_instr(rvc=False)
        self.memerr.add(pc)

        # We will resume from the next fetch block
        self.pc = pc + self.gen_params.fetch_block_bytes

        for _ in range(self.gen_params.fetch_width):
            self.gen_non_branch_instr(rvc=False)

        # Access fault in a block with a jump
        pc = self.gen_jal(2 * self.gen_params.fetch_block_bytes)
        self.memerr.add(pc)

        # We will resume from the next fetch block
        self.pc = pc + self.gen_params.fetch_block_bytes

        if self.with_rvc:
            # Access fault on the second half of an instruction
            for _ in range(self.gen_params.fetch_width - 1):
                self.gen_non_branch_instr(rvc=True)
            pc = self.gen_non_branch_instr(rvc=False)  # 4 byte instruction crossing block
            self.memerr.add(pc + 2)

            # We will resume from next valid block
            self.pc = pc + 2 + self.gen_params.fetch_block_bytes

        self.gen_non_branch_instr(rvc=False)

        self.run_sim()

    def test_random(self):
        """A random mix of plain instructions, jumps and branches."""
        for _ in range(500):
            r = random.random()
            if r < 0.6:
                rvc = random.randrange(2) == 0 if self.with_rvc else False
                self.gen_non_branch_instr(rvc=rvc)
            else:
                offset = random.randrange(0, 1000, 2)
                if not self.with_rvc:
                    # Without compressed instructions targets must be 4-byte aligned.
                    offset = offset & ~(0b11)
                if r < 0.8:
                    self.gen_jal(offset)
                else:
                    self.gen_branch(offset, taken=random.randrange(2) == 0)

        # Unlike run_sim(), no explicit max_cycles limit is passed here.
        with self.run_simulation(self.m) as sim:
            sim.add_process(self.cache_process)
            sim.add_testbench(self.fetch_out_check)
            sim.add_process(self.requester)


@dataclass(frozen=True)
class CheckerResult:
    """Result of a single `PredictionChecker.check` call, converted by `TestPredictionChecker.check`."""

    mispredicted: bool
    stall: bool
    fb_instr_idx: int
    redirect_target: int


@parameterized_class(
    ("name", "fetch_block_log", "with_rvc"),
    [
        ("block4B", 2, False),
        ("block4B_rvc", 2, True),
        ("block8B", 3, False),
        ("block8B_rvc", 3, True),
        ("block16B", 4, False),
        ("block16B_rvc", 4, True),
    ],
)
class TestPredictionChecker(TestCaseWithSimulator):
    """Tests of `PredictionChecker` for several fetch block sizes, with and without RVC."""

    fetch_block_log: int
    with_rvc: bool

    @pytest.fixture(autouse=True)
    def setup(self, fixture_initialize_testing_env):
        """Instantiate the checker for the parameters of the current class variant."""
        self.gen_params = GenParams(
            test_core_config.replace(compressed=self.with_rvc, fetch_block_bytes_log=self.fetch_block_log)
        )

        self.m = SimpleTestCircuit(PredictionChecker(self.gen_params))

    async def check(
        self,
        sim: TestbenchContext,
        pc: int,
        block_cross: bool,
        predecoded: list[tuple[CfiType, int]],
        branch_mask: int,
        cfi_idx: int,
        cfi_type: CfiType,
        cfi_target: Optional[int],
        valid_mask: int = -1,
    ) -> CheckerResult:
        """Call the checker once and return its response.

        `predecoded` holds `(cfi_type, cfi_offset)` pairs for the leading
        instructions of the fetch block; it is padded with non-CFI entries up
        to the fetch width. The argument itself is not modified.
        `cfi_target=None` means that the prediction carries no valid target.
        `valid_mask` additionally masks the valid bits derived from `pc`.
        """
        fetch_width = self.gen_params.fetch_width

        # Fill a local copy with non-CFI instructions, leaving the caller's list untouched
        padded = list(predecoded)
        padded.extend([(CfiType.INVALID, 0)] * (fetch_width - len(padded)))
        predecoded_raw = [
            {"cfi_type": entry[0], "cfi_offset": entry[1], "unsafe": 0} for entry in padded[:fetch_width]
        ]

        prediction = {
            "branch_mask": branch_mask,
            "cfi_idx": cfi_idx,
            "cfi_type": cfi_type,
            "cfi_target": cfi_target or 0,
            "cfi_target_valid": 1 if cfi_target is not None else 0,
        }

        # Index, within the fetch block, of the instruction that `pc` points to.
        instr_start = (
            pc & ((1 << self.gen_params.fetch_block_bytes_log) - 1)
        ) >> self.gen_params.min_instr_width_bytes_log

        instr_valid = (((1 << fetch_width) - 1) << instr_start) & valid_mask

        res = await self.m.check.call(
            sim,
            fb_addr=pc >> self.gen_params.fetch_block_bytes_log,
            instr_block_cross=block_cross,
            instr_valid=instr_valid,
            predecoded=predecoded_raw,
            prediction=prediction,
        )

        return CheckerResult(
            mispredicted=bool(res["mispredicted"]),
            stall=bool(res["stall"]),
            fb_instr_idx=res["fb_instr_idx"],
            redirect_target=res["redirect_target"],
        )

    def assert_resp(
        self,
        res: CheckerResult,
        mispredicted: Optional[bool] = None,
        stall: Optional[bool] = None,
        fb_instr_idx: Optional[int] = None,
        redirect_target: Optional[int] = None,
    ):
        """Assert the fields of `res` for which an expected value was given; `None` skips a field."""
        if mispredicted is not None:
            assert res.mispredicted == mispredicted
        if stall is not None:
            assert res.stall == stall
        if fb_instr_idx is not None:
            assert res.fb_instr_idx == fb_instr_idx
        if redirect_target is not None:
            assert res.redirect_target == redirect_target

    def test_no_misprediction(self):
        """Predictions consistent with the predecoded block are not reported as mispredicted."""
        instr_width = self.gen_params.min_instr_width_bytes
        fetch_width = self.gen_params.fetch_width

        async def proc(sim: TestbenchContext):
            # No CFI at all
            ret = await self.check(sim, 0x100, False, [], 0, 0, CfiType.INVALID, None)
            self.assert_resp(ret, mispredicted=False)

            # There is one forward branch that we didn't predict
            ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, 100)], 0, 0, CfiType.INVALID, None)
            self.assert_resp(ret, mispredicted=False)

            # There are many forward branches that we didn't predict
            ret = await self.check(
                sim, 0x100, False, [(CfiType.BRANCH, 100)] * fetch_width, 0, 0, CfiType.INVALID, None
            )
            self.assert_resp(ret, mispredicted=False)

            # There is a predicted JAL instr
            ret = await self.check(sim, 0x100, False, [(CfiType.JAL, 100)], 0, 0, CfiType.JAL, 0x100 + 100)
            self.assert_resp(ret, mispredicted=False)

            # There is a predicted JALR instr - the predecoded offset can now be anything
            ret = await self.check(sim, 0x100, False, [(CfiType.JALR, 200)], 0, 0, CfiType.JALR, 0x100 + 100)
            self.assert_resp(ret, mispredicted=False)

            # There is a forward taken-predicted branch
            ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, 100)], 0b1, 0, CfiType.BRANCH, 0x100 + 100)
            self.assert_resp(ret, mispredicted=False)

            # There is a backward taken-predicted branch
            ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, -100)], 0b1, 0, CfiType.BRANCH, 0x100 - 100)
            self.assert_resp(ret, mispredicted=False)

            # Branch located between two fetch blocks
            if self.with_rvc:
                ret = await self.check(
                    sim, 0x100, True, [(CfiType.BRANCH, -100)], 0b1, 0, CfiType.BRANCH, 0x100 - 100 - 2
                )
                self.assert_resp(ret, mispredicted=False)

            # One branch predicted as not taken
            ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, -100)], 0b1, 0, CfiType.INVALID, 0)
            self.assert_resp(ret, mispredicted=False)

            # Now tests for fetch blocks with multiple instructions
            if fetch_width < 2:
                return

            # Predicted taken branch as the second instruction
            ret = await self.check(
                sim,
                0x100,
                False,
                [(CfiType.INVALID, 0), (CfiType.BRANCH, -100)],
                0b10,
                1,
                CfiType.BRANCH,
                0x100 + instr_width - 100,
            )
            self.assert_resp(ret, mispredicted=False)

            # Predicted, but not taken branch as the second instruction
            ret = await self.check(
                sim, 0x100, False, [(CfiType.INVALID, 0), (CfiType.BRANCH, -100)], 0b10, 0, CfiType.INVALID, 0
            )
            self.assert_resp(ret, mispredicted=False)

            if self.with_rvc:
                ret = await self.check(
                    sim,
                    0x100,
                    True,
                    [(CfiType.INVALID, 0), (CfiType.BRANCH, -100)],
                    0b10,
                    1,
                    CfiType.BRANCH,
                    0x100 + instr_width - 100,
                )
                self.assert_resp(ret, mispredicted=False)

                ret = await self.check(
                    sim,
                    0x100,
                    True,
                    [(CfiType.JAL, 100), (CfiType.JAL, -100)],
                    0b00,
                    1,
                    CfiType.JAL,
                    0x100 + instr_width - 100,
                    valid_mask=0b10,
                )
                self.assert_resp(ret, mispredicted=False)

            # Two branches with all possible combinations taken/not-taken
            ret = await self.check(
                sim, 0x100, False, [(CfiType.BRANCH, -100), (CfiType.BRANCH, 100)], 0b11, 0, CfiType.INVALID, 0
            )
            self.assert_resp(ret, mispredicted=False)
            ret = await self.check(
                sim, 0x100, False, [(CfiType.BRANCH, -100), (CfiType.BRANCH, 100)], 0b11, 0, CfiType.BRANCH, 0x100 - 100
            )
            self.assert_resp(ret, mispredicted=False)
            ret = await self.check(
                sim,
                0x100,
                False,
                [(CfiType.BRANCH, -100), (CfiType.BRANCH, 100)],
                0b11,
                1,
                CfiType.BRANCH,
                0x100 + instr_width + 100,
            )
            self.assert_resp(ret, mispredicted=False)

            # JAL at the beginning, but we start from the second instruction
            ret = await self.check(sim, 0x100 + instr_width, False, [(CfiType.JAL, -100)], 0b0, 0, CfiType.INVALID, 0)
            self.assert_resp(ret, mispredicted=False)

            # JAL and a forward branch that we didn't predict
            ret = await self.check(
                sim,
                0x100 + instr_width,
                False,
                [(CfiType.JAL, -100), (CfiType.BRANCH, 100)],
                0b00,
                0,
                CfiType.INVALID,
                0,
            )
            self.assert_resp(ret, mispredicted=False)

            # two JAL instructions, but we start from the second one
            ret = await self.check(
                sim,
                0x100 + instr_width,
                False,
                [(CfiType.JAL, -100), (CfiType.JAL, 100)],
                0b00,
                1,
                CfiType.JAL,
                0x100 + instr_width + 100,
            )
            self.assert_resp(ret, mispredicted=False)

            # JAL and a branch, but we start from the second instruction
            ret = await self.check(
                sim,
                0x100 + instr_width,
                False,
                [(CfiType.JAL, -100), (CfiType.BRANCH, 100)],
                0b10,
                1,
                CfiType.BRANCH,
                0x100 + instr_width + 100,
            )
            self.assert_resp(ret, mispredicted=False)

        with self.run_simulation(self.m) as sim:
            sim.add_testbench(proc)

    def test_preceding_redirection(self):
        """A CFI located before the predicted one (or not predicted at all) causes a redirect or a stall."""
        instr_width = self.gen_params.min_instr_width_bytes
        fetch_width = self.gen_params.fetch_width

        async def proc(sim: TestbenchContext):
            # No prediction was made, but there is a JAL at the beginning
            ret = await self.check(sim, 0x100, False, [(CfiType.JAL, 0x20)], 0, 0, CfiType.INVALID, None)
            self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 + 0x20)

            # The same, but the jump is between two fetch blocks
            if self.with_rvc:
                ret = await self.check(sim, 0x100, True, [(CfiType.JAL, 0x20)], 0, 0, CfiType.INVALID, None)
                self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 + 0x20 - 2)

            # Not predicted backward branch
            ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, -100)], 0b0, 0, CfiType.INVALID, 0)
            self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 - 100)

            # Now tests for fetch blocks with multiple instructions
            if fetch_width < 2:
                return

            # We predicted the branch on the second instruction, but there's a JAL on the first one.
            ret = await self.check(
                sim,
                0x100,
                False,
                [(CfiType.JAL, -100), (CfiType.BRANCH, 100)],
                0b10,
                1,
                CfiType.BRANCH,
                0x100 + instr_width + 100,
            )
            self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 - 100)

            # We predicted the branch on the second instruction, but there's a JALR on the first one.
            ret = await self.check(
                sim,
                0x100,
                False,
                [(CfiType.JALR, -100), (CfiType.BRANCH, 100)],
                0b10,
                1,
                CfiType.BRANCH,
                0x100 + instr_width + 100,
            )
            self.assert_resp(ret, mispredicted=True, stall=True, fb_instr_idx=0)

            # We predicted the branch on the second instruction, but there's a backward branch on the first one.
            ret = await self.check(
                sim,
                0x100,
                False,
                [(CfiType.BRANCH, -100), (CfiType.BRANCH, 100)],
                0b10,
                1,
                CfiType.BRANCH,
                0x100 + instr_width + 100,
            )
            self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 - 100)

            # Unpredicted backward branch as the second instruction
            ret = await self.check(
                sim, 0x100, False, [(CfiType.INVALID, 0), (CfiType.BRANCH, -100)], 0b00, 0, CfiType.INVALID, 0
            )
            self.assert_resp(
                ret, mispredicted=True, stall=False, fb_instr_idx=1, redirect_target=0x100 + instr_width - 100
            )

            # Unpredicted JAL as the second instruction
            ret = await self.check(
                sim, 0x100, False, [(CfiType.INVALID, 0), (CfiType.JAL, 100)], 0b00, 0, CfiType.INVALID, 0
            )
            self.assert_resp(
                ret, mispredicted=True, stall=False, fb_instr_idx=1, redirect_target=0x100 + instr_width + 100
            )

            # Unpredicted JALR as the second instruction
            ret = await self.check(
                sim, 0x100, False, [(CfiType.INVALID, 0), (CfiType.JALR, 100)], 0b00, 0, CfiType.INVALID, 0
            )
            self.assert_resp(ret, mispredicted=True, stall=True, fb_instr_idx=1)

            if fetch_width < 3:
                return

            # We start from the second instruction, so the JAL before the
            # start is skipped; the unpredicted JAL at index 2 is reported.
            ret = await self.check(
                sim,
                0x100 + instr_width,
                False,
                [(CfiType.JAL, -100), (CfiType.INVALID, 100), (CfiType.JAL, 100)],
                0b0,
                0,
                CfiType.INVALID,
                None,
            )
            self.assert_resp(
                ret, mispredicted=True, stall=False, fb_instr_idx=2, redirect_target=0x100 + 2 * instr_width + 100
            )

        with self.run_simulation(self.m) as sim:
            sim.add_testbench(proc)

    def test_mispredicted_cfi_type(self):
        """The predicted CFI type differs from the predecoded one."""
        instr_width = self.gen_params.min_instr_width_bytes
        fetch_width = self.gen_params.fetch_width
        fb_bytes = self.gen_params.fetch_block_bytes

        async def proc(sim: TestbenchContext):
            # We predicted a JAL, but in fact there is a non-CFI instruction
            ret = await self.check(sim, 0x100, False, [(CfiType.INVALID, 0)], 0, 0, CfiType.JAL, 100)
            self.assert_resp(
                ret, mispredicted=True, stall=False, fb_instr_idx=fetch_width - 1, redirect_target=0x100 + fb_bytes
            )

            # We predicted a JAL, but in fact there is a branch
            ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, -100)], 0, 0, CfiType.JAL, 100)
            self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 - 100)

            # We predicted a JAL, but in fact there is a JALR instruction
            ret = await self.check(sim, 0x100, False, [(CfiType.JALR, -100)], 0, 0, CfiType.JAL, 100)
            self.assert_resp(ret, mispredicted=True, stall=True, fb_instr_idx=0)

            # We predicted a branch, but in fact there is a JAL
            ret = await self.check(sim, 0x100, False, [(CfiType.JAL, -100)], 0b1, 0, CfiType.BRANCH, 100)
            self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 - 100)

            if fetch_width < 2:
                return

            # There is a branch and a non-CFI, but we predicted two branches
            ret = await self.check(
                sim, 0x100, False, [(CfiType.BRANCH, -100), (CfiType.INVALID, 0)], 0b11, 1, CfiType.BRANCH, 100
            )
            self.assert_resp(
                ret, mispredicted=True, stall=False, fb_instr_idx=fetch_width - 1, redirect_target=0x100 + fb_bytes
            )

            # The same as above, but we start from the second instruction
            ret = await self.check(
                sim,
                0x100 + instr_width,
                False,
                [(CfiType.BRANCH, -100), (CfiType.INVALID, 0)],
                0b11,
                1,
                CfiType.BRANCH,
                100,
            )
            self.assert_resp(
                ret, mispredicted=True, stall=False, fb_instr_idx=fetch_width - 1, redirect_target=0x100 + fb_bytes
            )

        with self.run_simulation(self.m) as sim:
            sim.add_testbench(proc)

    def test_mispredicted_cfi_target(self):
        """The CFI type matches, but the predicted target is wrong or missing."""
        instr_width = self.gen_params.min_instr_width_bytes
        fetch_width = self.gen_params.fetch_width

        async def proc(sim: TestbenchContext):
            # We predicted a wrong JAL target
            ret = await self.check(sim, 0x100, False, [(CfiType.JAL, 100)], 0, 0, CfiType.JAL, 200)
            self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 + 100)

            # We predicted a wrong branch target
            ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, 100)], 0b1, 0, CfiType.BRANCH, 200)
            self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 + 100)

            # We didn't provide the branch target
            ret = await self.check(sim, 0x100, False, [(CfiType.BRANCH, 100)], 0b1, 0, CfiType.BRANCH, None)
            self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 + 100)

            # We predicted a wrong JAL target that is between two fetch blocks
            if self.with_rvc:
                ret = await self.check(sim, 0x100, True, [(CfiType.JAL, 100)], 0, 0, CfiType.JAL, 300)
                self.assert_resp(ret, mispredicted=True, stall=False, fb_instr_idx=0, redirect_target=0x100 + 100 - 2)

            if fetch_width < 2:
                return

            # The second instruction is a branch without the target
            ret = await self.check(
                sim, 0x100, False, [(CfiType.INVALID, 0), (CfiType.BRANCH, 100)], 0b10, 1, CfiType.BRANCH, None
            )
            self.assert_resp(
                ret, mispredicted=True, stall=False, fb_instr_idx=1, redirect_target=0x100 + instr_width + 100
            )

            # The second instruction is a JAL with a wrong target
            ret = await self.check(
                sim, 0x100, False, [(CfiType.INVALID, 0), (CfiType.JAL, 100)], 0b10, 1, CfiType.JAL, 200
            )
            self.assert_resp(
                ret, mispredicted=True, stall=False, fb_instr_idx=1, redirect_target=0x100 + instr_width + 100
            )

        with self.run_simulation(self.m) as sim:
            sim.add_testbench(proc)