from math import lcm
from amaranth import *
from amaranth.lib.data import ArrayLayout
from transactron.lib import BasicFifo, WideFifo, Semaphore, logging, Pipe
from transactron.lib.metrics import *
from transactron.lib.simultaneous import condition
from transactron.utils import count_trailing_zeros, popcount, assign, StableSelectingNetwork
from transactron.utils.transactron_helpers import make_layout
from transactron.utils.amaranth_ext.coding import PriorityEncoder
from transactron import *

from coreblocks.cache.iface import CacheInterface
from coreblocks.frontend.decoder.rvc import InstrDecompress, is_instr_compressed
from coreblocks.priv.pmp import PMPChecker, PMPOperationMode

from coreblocks.arch import *
from coreblocks.params import *
from coreblocks.interface.layouts import *
from coreblocks.frontend import FrontendParams
from coreblocks.priv.vmem.translation import AddressTranslator, AddressTranslatorMode

log = logging.HardwareLogger("frontend.fetch")


class FetchUnit(Elaboratable):
    """Superscalar Fetch Unit

    This module is responsible for retrieving instructions from memory and forwarding them to the decode stage.

    It works with 'fetch blocks', chunks of data it handles at a time. The size of these blocks
    depends on GenParams.fetch_block_bytes and is related to how many instructions the unit can
    handle at once, which can vary if extension C is on.

    The unit also deals with expanding compressed instructions and managing instructions that aren't aligned to
    4-byte boundaries.
    """

    fetch_request: Provided[Method]
    """Requests a fetch of the instruction block at the given PC."""

    fetch_writeback: Required[Method]
    """Invoked to write back the status of the requested fetch block."""

    flush: Provided[Method]
    """Flushes the fetch unit from the currently processed fetch blocks, so it can be redirected or/and stalled."""

    cont: Required[Method]
    """Should be invoked to send fetched instruction to the next step."""

    stall_unsafe: Required[Method]
    """Called when an unsafe instruction is fetched."""

    def __init__(self, gen_params: GenParams, icache: CacheInterface) -> None:
        """
        Parameters
        ----------
        gen_params : GenParams
            Instance of GenParams with parameters which should be used to generate
            fetch unit.
        icache : CacheInterface
            Instruction Cache
        """
        self.gen_params = gen_params
        self.icache = icache
        self.stall_unsafe = Method()

        self.layouts = self.gen_params.get(FetchLayouts)

        self.cont = Method(i=self.layouts.fetch_result)
        self.fetch_request = Method(i=self.layouts.fetch_request)
        self.fetch_writeback = Method(i=self.layouts.fetch_writeback)

        self.flush = Method()

        # Histogram of how many valid instructions each fetch block carried
        # (0..fetch_width), incremented once per non-flushed fetch block in stage 2.
        self.perf_fetch_utilization = TaggedCounter(
            "frontend.fetch.fetch_block_util",
            "Number of valid instructions in fetch blocks",
            tags=range(self.gen_params.fetch_width + 1),
        )

    def elaborate(self, platform):
        m = TModule()

        m.submodules.addr_translator = addr_translator = AddressTranslator(
            self.gen_params, mode=AddressTranslatorMode.INSTRUCTION
        )
        m.submodules += [self.perf_fetch_utilization]

        fetch_width = self.gen_params.fetch_width
        fields = self.gen_params.get(CommonLayoutFields)
        params = self.gen_params.get(FrontendParams)

        # Serializer creates a continuous instruction stream from fetch
        # blocks, which can have holes in them.
        m.submodules.aligner = aligner = StableSelectingNetwork(fetch_width, self.layouts.raw_instr)
        # Depth is a multiple of lcm(read_width, write_width) so whole fetch blocks
        # and whole read groups always fit without fragmentation.
        serializer_depth = 2 * lcm(self.gen_params.frontend_superscalarity, fetch_width)
        m.submodules.serializer = serializer = WideFifo(
            self.layouts.raw_instr,
            depth=serializer_depth,
            read_width=self.gen_params.frontend_superscalarity,
            write_width=fetch_width,
        )

        # Drain the serializer toward the decode stage, cutting each read so that
        # a BRANCH instruction can only appear as the last slot of the group.
        with Transaction(name="cont").body(m):
            peek_result = serializer.peek(m)
            count = Signal(range(self.gen_params.frontend_superscalarity + 1))
            # we want at most one branch insn in scheduling group, and only at the end (for simplicity)
            # some insts in peek_result.data might not be valid, but this is still correct
            which_is_branch = [0] + [instr.cfi_type == CfiType.BRANCH for instr in peek_result.data][:-1]
            m.d.comb += count.eq(count_trailing_zeros(Cat(which_is_branch)))
            result = serializer.read(m, count=count)
            for i in range(self.gen_params.frontend_superscalarity):
                log.info(
                    m,
                    i < result.count,
                    "Sending an instr to the backend pc=0x{:x} instr=0x{:x}",
                    result.data[i].pc,
                    result.data[i].instr,
                )
            self.cont(m, result)

        # Per-request metadata passed from stage 0 to stage 1 (virtual PC plus
        # any fault detected during translation/PMP check).
        m.submodules.fetch_requests = fetch_requests = BasicFifo(
            make_layout(fields.pc, ("access_fault", 1), ("page_fault", 1)),
            depth=2,
        )

        # This limits number of fetch blocks the fetch unit can process
        # at a time. We start counting when sending a request to the cache and
        # stop when pushing a fetch packet out of the fetch unit.
        m.submodules.req_counter = req_counter = Semaphore(4)
        # Number of in-flight responses that still have to be read and discarded
        # after a flush; loaded from req_counter on flush, decremented in stage 2.
        flushing_counter = Signal.like(req_counter.count)

        flush_now = Signal()

        # Combinationally request a flush in the current cycle.
        def flush():
            m.d.comb += flush_now.eq(1)

        #
        # Fetch - stage 0
        # ================
        # - send a request to the instruction cache
        # - check PMP execute permission (if PMP is enabled)
        #

        m.submodules.pmp_checker = pmp_checker = PMPChecker(
            self.gen_params,
            mode=PMPOperationMode.INSTRUCTION_FETCH,
        )

        @def_method(m, self.fetch_request)
        def _(pc):
            log.info(m, True, "[IFU] request pc=0x{:x}", pc)
            req_counter.acquire(m)

            addr_translator.request(m, addr=pc)

        # Completes stage 0: collect the translation result, run the PMP check,
        # and forward the request to the cache unless a fault was detected.
        with Transaction().body(m):
            translated = addr_translator.accept(m)
            access_fault = Signal()

            m.d.av_comb += pmp_checker.paddr.eq(translated.paddr)
            m.d.av_comb += access_fault.eq(translated.access_fault | ~pmp_checker.result.x)

            # Only issue the cache request for fault-free fetches; faulting requests
            # still enter the FIFO so stage 1 can report them in order.
            with m.If(~translated.page_fault & ~access_fault):
                self.icache.issue_req(m, paddr=translated.paddr)

            fetch_requests.write(
                m,
                pc=translated.vaddr,
                access_fault=access_fault,
                page_fault=translated.page_fault,
            )

        #
        # State passed between stage 1 and stage 2
        #
        m.submodules.s1_s2_pipe = s1_s2_pipe = Pipe(
            [
                fields.fb_addr,
                ("instr_valid", fetch_width),
                ("access_fault", FetchLayouts.FaultFlag),
                ("rvc", fetch_width),
                ("instrs", ArrayLayout(self.gen_params.isa.ilen, fetch_width)),
                ("instr_block_cross", 1),
            ]
        )

        #
        # Fetch - stage 1
        # ================
        # - read the response from the cache
        # - expand compressed instructions (if applicable)
        # - find where each instruction begins
        # - handle instructions that cross a fetch boundary
        #
        rvc_expanders = [InstrDecompress(self.gen_params) for _ in range(fetch_width)]
        for n, module in enumerate(rvc_expanders):
            m.submodules[f"rvc_expander_{n}"] = module

        # With the C extension enabled, a single instruction can
        # be located on a boundary of two fetch blocks. Hence,
        # this requires some statefulness of the stage 1.
        prev_half = Signal(16)  # upper 16 bits of the previous fetch block
        prev_half_addr = Signal(make_layout(fields.fb_addr).size)  # fb_addr they came from
        prev_half_v = Signal()  # whether prev_half holds the first half of an instruction
        with Transaction(name="Fetch_Stage1").body(m):
            fetch_request = fetch_requests.read(m)

            # The address of the fetch block.
            fetch_block_addr = params.fb_addr(fetch_request.pc)
            # The index (in instructions) of the first instruction that we should process.
            fetch_block_offset = params.fb_instr_idx(fetch_request.pc)

            # Conditionally read from icache or mark fault
            cache_resp = Signal(self.gen_params.get(ICacheLayouts).accept_res)
            access_fault = Signal(FetchLayouts.FaultFlag)

            with condition(m) as branch:
                with branch(fetch_request.page_fault | fetch_request.access_fault):
                    # No cache request was issued for this fetch (see stage 0),
                    # so synthesize an all-zero response carrying the fault flag.
                    with m.If(fetch_request.page_fault):
                        m.d.av_comb += access_fault.eq(FetchLayouts.FaultFlag.PAGE_FAULT)
                    with m.Else():
                        m.d.av_comb += access_fault.eq(FetchLayouts.FaultFlag.ACCESS_FAULT)
                    m.d.av_comb += cache_resp.fetch_block.eq(0)
                    m.d.av_comb += cache_resp.error.eq(0)
                with branch():
                    m.d.av_comb += cache_resp.eq(self.icache.accept_res(m))
                    m.d.av_comb += access_fault.eq(Mux(cache_resp.error, FetchLayouts.FaultFlag.ACCESS_FAULT, 0))

            #
            # Expand compressed instructions from the fetch block.
            #
            expanded_instr = [Signal(self.gen_params.isa.ilen) for _ in range(fetch_width)]
            is_rvc = Signal(fetch_width)

            # Whether in this cycle we have a fetch block that contains
            # an instruction that crosses a fetch boundary
            instr_block_cross = Signal()
            m.d.av_comb += instr_block_cross.eq(prev_half_v & ((prev_half_addr + 1) == fetch_block_addr))

            for i in range(fetch_width):
                if Extension.ZCA in self.gen_params.isa.extensions:
                    # With compression, instruction slots are 16 bits apart; gather
                    # 32 bits starting at each slot so both RVC and full-width
                    # instructions can be decoded from it.
                    full_instr = Signal(self.gen_params.isa.ilen)
                    if i == 0:
                        # If we have a half of an instruction from the previous block - we need to use it now.
                        with m.If(instr_block_cross):
                            m.d.av_comb += full_instr.eq(Cat(prev_half, cache_resp.fetch_block[0:16]))
                        with m.Else():
                            m.d.av_comb += full_instr.eq(cache_resp.fetch_block[:32])
                    elif i == fetch_width - 1:
                        # We will have only 16 bits for the last instruction, so append 16 zeroes.
                        m.d.av_comb += full_instr.eq(Cat(cache_resp.fetch_block[-16:], C(0, 16)))
                    else:
                        m.d.av_comb += full_instr.eq(cache_resp.fetch_block[i * 16 : i * 16 + 32])

                    m.d.av_comb += is_rvc[i].eq(is_instr_compressed(full_instr))
                    m.d.av_comb += rvc_expanders[i].instr_in.eq(full_instr[:16])
                    m.d.av_comb += expanded_instr[i].eq(Mux(is_rvc[i], rvc_expanders[i].instr_out, full_instr))
                else:
                    # No compression: fixed 32-bit slots, no expansion needed.
                    m.d.av_comb += expanded_instr[i].eq(cache_resp.fetch_block[i * 32 : (i + 1) * 32])

            # Mask denoting at which offsets expected instructions start (depends on rvc indication and start address)
            instr_start = [Signal() for _ in range(fetch_width)]
            for i in range(fetch_width):
                if Extension.ZCA in self.gen_params.isa.extensions:
                    if i == 0:
                        m.d.av_comb += instr_start[i].eq(fetch_block_offset == 0)
                    elif i == 1:
                        # Slot 1 starts an instruction unless slot 0 started a full-width
                        # one; a cross-boundary instruction consumes slot 0 as well.
                        m.d.av_comb += instr_start[i].eq(
                            (fetch_block_offset <= i) & (~instr_start[0] | is_rvc[0] | instr_block_cross)
                        )
                    else:
                        m.d.av_comb += instr_start[i].eq(
                            (fetch_block_offset <= i) & (~instr_start[i - 1] | is_rvc[i - 1])
                        )
                else:
                    m.d.av_comb += instr_start[i].eq(fetch_block_offset <= i)

            if Extension.ZCA in self.gen_params.isa.extensions:
                # A full-width instruction starting in the last slot is incomplete in
                # this block, so it is excluded from the mask and carried to the next one.
                instr_position_mask = Cat(instr_start[:-1], instr_start[-1] & is_rvc[-1])

                # Don't carry the half over when flushing (counter > 1 means this block
                # itself is being discarded) or when the block faulted.
                m.d.sync += prev_half_v.eq(
                    (flushing_counter <= 1) & ~access_fault.any() & ~is_rvc[-1] & instr_start[-1]
                )
                m.d.sync += prev_half.eq(cache_resp.fetch_block[-16:])
                m.d.sync += prev_half_addr.eq(fetch_block_addr)
            else:
                instr_position_mask = Cat(instr_start)

            # Reported fault pc (signalled by emitting an instruction) must always match first requested instruction
            access_fault_instr_position = 1 << fetch_block_offset

            s1_s2_pipe.write(
                m,
                fb_addr=fetch_block_addr,
                instr_valid=Mux(access_fault.any(), access_fault_instr_position, instr_position_mask),
                access_fault=access_fault,
                rvc=is_rvc,
                instrs=expanded_instr,
                instr_block_cross=instr_block_cross,
            )

        # Make sure to clean the state
        with m.If(flush_now):
            m.d.sync += prev_half_v.eq(0)

        #
        # Fetch - stage 2
        # ================
        # - predecode instructions
        # - verify the branch prediction
        # - redirect the frontend if mispredicted
        # - check if any of instructions stalls the frontend
        # - enqueue a packet of instructions
        #

        predecoders = [Predecoder(self.gen_params) for _ in range(fetch_width)]
        for n, module in enumerate(predecoders):
            m.submodules[f"predecoder_{n}"] = module

        m.submodules.prediction_checker = prediction_checker = PredictionChecker(self.gen_params)

        with Transaction(name="Fetch_Stage2").body(m):
            req_counter.release(m)
            s1_data = s1_s2_pipe.read(m)

            instrs = s1_data.instrs
            fetch_block_addr = s1_data.fb_addr
            instr_valid = s1_data.instr_valid
            access_fault = s1_data.access_fault
            fault_any = Signal()
            m.d.av_comb += fault_any.eq(access_fault.any())

            # Predecode instructions
            predecoded_instr = [predecoders[i].predecode(m, instrs[i]) for i in range(fetch_width)]

            # No prediction for now
            prediction = Signal(self.layouts.bpu_prediction)

            # The method is guarded by the If to make sure that the metrics
            # are updated only if not flushing.
            with m.If(flushing_counter == 0):
                predcheck_res = prediction_checker.check(
                    m,
                    fb_addr=fetch_block_addr,
                    instr_block_cross=s1_data.instr_block_cross,
                    instr_valid=instr_valid,
                    predecoded=predecoded_instr,
                    prediction=prediction,
                )

            # Is the instruction unsafe (i.e. stalls the frontend until the backend resumes it).
            instr_unsafe = Signal(fetch_width)
            for i in range(fetch_width):
                # If there was an access fault, mark every instruction as unsafe
                m.d.av_comb += instr_unsafe[i].eq((predecoded_instr[i].unsafe | fault_any) & instr_valid[i])

            m.submodules.unsafe_prio_encoder = unsafe_prio_encoder = PriorityEncoder(fetch_width)
            m.d.av_comb += unsafe_prio_encoder.i.eq(instr_unsafe)

            unsafe_idx = unsafe_prio_encoder.o[: self.gen_params.fetch_width_log]
            has_unsafe = Signal()
            m.d.av_comb += has_unsafe.eq(~unsafe_prio_encoder.n)

            # Does the (earliest) predicted redirection happen before the earliest
            # unsafe instruction? If so, the redirection wins.
            redirect_before_unsafe = Signal()
            m.d.av_comb += redirect_before_unsafe.eq(predcheck_res.fb_instr_idx < unsafe_idx)

            redirect = Signal()
            unsafe_stall = Signal()
            redirect_or_unsafe_idx = Signal(range(fetch_width))

            with m.If(predcheck_res.mispredicted & (~has_unsafe | redirect_before_unsafe)):
                m.d.av_comb += [
                    redirect.eq(~predcheck_res.stall),
                    unsafe_stall.eq(predcheck_res.stall),
                    redirect_or_unsafe_idx.eq(predcheck_res.fb_instr_idx),
                ]
            with m.Elif(has_unsafe):
                m.d.av_comb += [
                    unsafe_stall.eq(1),
                    redirect_or_unsafe_idx.eq(unsafe_idx),
                ]

            # This mask denotes what prefix of instructions we should enqueue.
            valid_instr_prefix = Signal(fetch_width)
            with m.If(redirect | unsafe_stall):
                # If there is an instruction that redirects or stalls the frontend, enqueue
                # instructions only up to that instruction.
                m.d.av_comb += valid_instr_prefix.eq((1 << (redirect_or_unsafe_idx + 1)) - 1)
            with m.Else():
                m.d.av_comb += valid_instr_prefix.eq(C(1).replicate(fetch_width))

            # The ultimate mask that tells which instructions should be sent to the backend.
            fetch_mask = Signal(fetch_width)
            m.d.av_comb += fetch_mask.eq(instr_valid & valid_instr_prefix)

            # Aggregate all signals that will be sent out of the fetch unit.
            raw_instrs = Signal(ArrayLayout(self.layouts.raw_instr, fetch_width))
            for i in range(fetch_width):
                m.d.av_comb += [
                    raw_instrs[i].instr.eq(instrs[i]),
                    raw_instrs[i].pc.eq(params.pc_from_fb(fetch_block_addr, i)),
                    raw_instrs[i].rvc.eq(s1_data.rvc[i]),
                    raw_instrs[i].predicted_taken.eq(redirect & (predcheck_res.fb_instr_idx == i)),
                    raw_instrs[i].access_fault.eq(s1_data.access_fault),
                    raw_instrs[i].cfi_type.eq(predecoded_instr[i].cfi_type),
                ]

            if Extension.ZCA in self.gen_params.isa.extensions:
                with m.If(s1_data.instr_block_cross):
                    # The cross-boundary instruction really starts 2 bytes before this block.
                    m.d.av_comb += raw_instrs[0].pc.eq(params.pc_from_fb(fetch_block_addr, 0) - 2)
                    with m.If(s1_data.access_fault):
                        # Mark that access/page fault happened only at second (current) half.
                        # If fault happened on the first half `instr_block_cross` would be false
                        m.d.av_comb += raw_instrs[0].access_fault.eq(
                            s1_data.access_fault | FetchLayouts.FaultFlag.EXCEPTION_ON_SECOND_HALF
                        )

            with condition(m) as branch:
                with branch(flushing_counter == 0):
                    with m.If(fault_any | unsafe_stall):
                        self.stall_unsafe(m)

                    with m.If(fault_any | unsafe_stall | redirect):
                        self.fetch_writeback(
                            m,
                            redirect=redirect,
                            redirect_target=predcheck_res.redirect_target,
                        )
                        flush()

                    self.perf_fetch_utilization.incr(m, popcount(fetch_mask))

                    # Make sure this is called only once to avoid a huge mux on arguments
                    m.d.av_comb += [aligner.valids.eq(fetch_mask), aligner.inputs.eq(raw_instrs)]
                    serializer.write(m, data=aligner.outputs, count=aligner.output_cnt)
                with branch():
                    # Flushing: discard this fetch block and count it down.
                    m.d.sync += flushing_counter.eq(flushing_counter - 1)

        with m.If(flush_now):
            # Everything currently in flight (per the semaphore) must be discarded.
            m.d.sync += flushing_counter.eq(req_counter.count_next)

        @def_method(m, self.flush)
        def _():
            flush()
            serializer.clear(m)

        return m


class Predecoder(Elaboratable):
    """Instruction predecoder

    The module performs basic analysis on instructions. It identifies if an instruction
    is a jump instruction, determines the type of jump, and finds the jump's target.

    Its role is to give quick feedback to the fetch unit and potentially the branch predictor
    about the fetched instruction. This helps in redirecting the fetch unit promptly if needed.
    """

    def __init__(self, gen_params: GenParams) -> None:
        """
        Parameters
        ----------
        gen_params: GenParams
            Core generation parameters.
        """
        self.gen_params = gen_params

        layouts = self.gen_params.get(FetchLayouts)
        fields = self.gen_params.get(CommonLayoutFields)

        self.predecode = Method(i=make_layout(fields.instr), o=layouts.predecoded_instr)

    def elaborate(self, platform):
        m = TModule()

        @def_method(m, self.predecode)
        def _(instr):
            # Standard RISC-V field slices of a 32-bit instruction.
            quadrant = instr[0:2]
            opcode = instr[2:7]
            funct3 = instr[12:15]
            rd = instr[7:12]
            rs1 = instr[15:20]

            # Sign-extended immediates: B-type (13-bit), J-type (21-bit), I-type (12-bit).
            bimm = Signal(signed(13))
            jimm = Signal(signed(21))
            iimm = Signal(signed(12))

            m.d.av_comb += [
                iimm.eq(instr[20:]),
                bimm.eq(Cat(0, instr[8:12], instr[25:31], instr[7], instr[31])),
                jimm.eq(Cat(0, instr[21:31], instr[20], instr[12:20], instr[31])),
            ]

            ret = Signal.like(self.predecode.data_out)

            with m.Switch(opcode):
                with m.Case(Opcode.BRANCH):
                    m.d.av_comb += ret.cfi_type.eq(CfiType.BRANCH)
                    m.d.av_comb += ret.cfi_offset.eq(bimm)
                with m.Case(Opcode.JAL):
                    # JAL writing the link register (x1/x5) is classified as a call.
                    m.d.av_comb += ret.cfi_type.eq(
                        Mux((rd == Registers.X1) | (rd == Registers.X5), CfiType.CALL, CfiType.JAL)
                    )
                    m.d.av_comb += ret.cfi_offset.eq(jimm)
                with m.Case(Opcode.JALR):
                    # JALR through the link register (x1/x5) is classified as a return.
                    m.d.av_comb += ret.cfi_type.eq(
                        Mux((rs1 == Registers.X1) | (rs1 == Registers.X5), CfiType.RET, CfiType.JALR)
                    )
                    m.d.av_comb += ret.cfi_offset.eq(iimm)
                with m.Default():
                    m.d.av_comb += ret.cfi_type.eq(CfiType.INVALID)

            # Quadrant != 0b11 means a compressed encoding; expanded instructions
            # are expected here, so treat it as not a CFI.
            with m.If(quadrant != 0b11):
                m.d.av_comb += ret.cfi_type.eq(CfiType.INVALID)

            # SYSTEM and FENCE.I instructions stall the frontend (see FetchUnit).
            m.d.av_comb += ret.unsafe.eq(
                (opcode == Opcode.SYSTEM) | ((opcode == Opcode.MISC_MEM) & (funct3 == Funct3.FENCEI))
            )

            return ret

        return m


class PredictionChecker(Elaboratable):
    """Branch prediction checker

    This module checks if branch predictions are correct by looking at predecoded data.
    It checks for the following errors:

    - a JAL/JALR instruction was not predicted taken,
    - mistaking non-control flow instructions (CFI) for control flow ones,
    - getting the target of JAL/BRANCH instructions wrong.
    """

    def __init__(self, gen_params: GenParams) -> None:
        """
        Parameters
        ----------
        gen_params: GenParams
            Core generation parameters.
        """
        self.gen_params = gen_params

        layouts = gen_params.get(FetchLayouts)

        self.check = Method(i=layouts.pred_checker_i, o=layouts.pred_checker_o)

        self.perf_preceding_redirection = TaggedCounter(
            "frontend.fetch.pred_checker.preceding_redirection",
            "Number of redirections caused by undetected CFIs",
            tags=CfiType,
        )
        self.perf_mispredicted_cfi_type = TaggedCounter(
            "frontend.fetch.pred_checker.cfi_type_mispredict",
            "Number of redirections caused by misprediction of the CFI type",
            tags=CfiType,
        )
        self.perf_mispredicted_cfi_target = TaggedCounter(
            "frontend.fetch.pred_checker.cfi_target_mispredict",
            "Number of redirections caused by misprediction of the CFI target",
            tags=CfiType,
        )

    def elaborate(self, platform):
        m = TModule()

        params = self.gen_params.get(FrontendParams)

        m.submodules += [
            self.perf_mispredicted_cfi_type,
            self.perf_preceding_redirection,
            self.perf_mispredicted_cfi_target,
        ]

        @def_method(m, self.check)
        def _(fb_addr, instr_block_cross, instr_valid, predecoded, prediction):
            decoded_cfi_types = Array([predecoded[i].cfi_type for i in range(self.gen_params.fetch_width)])
            decoded_cfi_offsets = Array([predecoded[i].cfi_offset for i in range(self.gen_params.fetch_width)])

            # First find all the instructions that would redirect the fetch unit.
            decoded_redirections = Signal(self.gen_params.fetch_width)
            for i in range(self.gen_params.fetch_width):
                # Here we make a static prediction: forward branches not taken and backward
                # taken. This prediction will be used if the branch prediction unit
                # didn't detect the branch at all.
                m.d.av_comb += decoded_redirections[i].eq(
                    CfiType.is_jal(decoded_cfi_types[i])
                    | CfiType.is_jalr(decoded_cfi_types[i])
                    | (
                        CfiType.is_branch(decoded_cfi_types[i])
                        & ~prediction.branch_mask[i]
                        & (decoded_cfi_offsets[i] < 0)
                    )
                )

            # Find the earliest one
            m.submodules.pd_redirection_enc = pd_redirection_enc = PriorityEncoder(self.gen_params.fetch_width)
            m.d.av_comb += pd_redirection_enc.i.eq(decoded_redirections & instr_valid)

            pd_redirect_idx = Signal(self.gen_params.fetch_width_log)
            m.d.av_comb += pd_redirect_idx.eq(pd_redirection_enc.o[: self.gen_params.fetch_width_log])

            # For a given instruction index, returns a CFI target based on the predecode info
            def get_decoded_target_for(idx: Value) -> Value:
                base = params.pc_from_fb(fb_addr, idx) + decoded_cfi_offsets[idx]
                if Extension.ZCA in self.gen_params.isa.extensions:
                    # A block-crossing instruction at slot 0 actually starts 2 bytes earlier.
                    return base - Mux(instr_block_cross & (idx == 0), 2, 0)
                return base

            # Target of a CFI that would redirect the frontend according to the prediction
            decoded_target_for_predicted_cfi = Signal(self.gen_params.isa.xlen)
            m.d.av_comb += decoded_target_for_predicted_cfi.eq(get_decoded_target_for(prediction.cfi_idx))

            # Target of a CFI that would redirect the frontend according to predecode info
            decoded_target_for_decoded_cfi = Signal(self.gen_params.isa.xlen)
            m.d.av_comb += decoded_target_for_decoded_cfi.eq(get_decoded_target_for(pd_redirect_idx))

            # A redirecting CFI the predictor either missed entirely or placed later
            # than where predecode found one.
            preceding_redirection = ~pd_redirection_enc.n & (
                ((CfiType.valid(prediction.cfi_type) & (pd_redirect_idx < prediction.cfi_idx)))
                | ~CfiType.valid(prediction.cfi_type)
            )

            mispredicted_cfi_type = CfiType.valid(prediction.cfi_type) & (
                prediction.cfi_type != decoded_cfi_types[prediction.cfi_idx]
            )

            # Targets are only statically checkable for BRANCH/JAL (JALR targets are
            # register-dependent).
            mispredicted_cfi_target = (CfiType.is_branch(prediction.cfi_type) | CfiType.is_jal(prediction.cfi_type)) & (
                ~prediction.cfi_target_valid | (decoded_target_for_predicted_cfi != prediction.cfi_target)
            )

            ret = Signal.like(self.check.data_out)

            with m.If(preceding_redirection):
                self.perf_preceding_redirection.incr(m, decoded_cfi_types[pd_redirect_idx])
                m.d.av_comb += assign(
                    ret,
                    {
                        "mispredicted": 1,
                        "stall": CfiType.is_jalr(decoded_cfi_types[pd_redirect_idx]),
                        "fb_instr_idx": pd_redirect_idx,
                        "redirect_target": decoded_target_for_decoded_cfi,
                    },
                )
            with m.Elif(mispredicted_cfi_type):
                self.perf_mispredicted_cfi_type.incr(m, prediction.cfi_type)
                # If predecode found no redirecting CFI, fall through to the next block.
                fallthrough_addr = params.pc_from_fb(fb_addr + 1, 0)
                m.d.av_comb += assign(
                    ret,
                    {
                        "mispredicted": 1,
                        "stall": CfiType.is_jalr(decoded_cfi_types[pd_redirect_idx]),
                        "fb_instr_idx": Mux(pd_redirection_enc.n, self.gen_params.fetch_width - 1, pd_redirect_idx),
                        "redirect_target": Mux(pd_redirection_enc.n, fallthrough_addr, decoded_target_for_decoded_cfi),
                    },
                )
            with m.Elif(mispredicted_cfi_target):
                self.perf_mispredicted_cfi_target.incr(m, prediction.cfi_type)
                m.d.av_comb += assign(
                    ret,
                    {
                        "mispredicted": 1,
                        "fb_instr_idx": prediction.cfi_idx,
                        "redirect_target": decoded_target_for_predicted_cfi,
                    },
                )

            return ret

        return m