# coreblocks/frontend/fetch/fetch.py
  1  from math import lcm
  2  from amaranth import *
  3  from amaranth.lib.data import ArrayLayout
  4  from transactron.lib import BasicFifo, WideFifo, Semaphore, logging, Pipe
  5  from transactron.lib.metrics import *
  6  from transactron.lib.simultaneous import condition
  7  from transactron.utils import count_trailing_zeros, popcount, assign, StableSelectingNetwork
  8  from transactron.utils.transactron_helpers import make_layout
  9  from transactron.utils.amaranth_ext.coding import PriorityEncoder
 10  from transactron import *
 11  
 12  from coreblocks.cache.iface import CacheInterface
 13  from coreblocks.frontend.decoder.rvc import InstrDecompress, is_instr_compressed
 14  from coreblocks.priv.pmp import PMPChecker, PMPOperationMode
 15  
 16  from coreblocks.arch import *
 17  from coreblocks.params import *
 18  from coreblocks.interface.layouts import *
 19  from coreblocks.frontend import FrontendParams
 20  from coreblocks.priv.vmem.translation import AddressTranslator, AddressTranslatorMode
 21  
 22  log = logging.HardwareLogger("frontend.fetch")
 23  
 24  
class FetchUnit(Elaboratable):
    """Superscalar Fetch Unit

    This module is responsible for retrieving instructions from memory and forwarding them to the decode stage.

    It works with 'fetch blocks', chunks of data it handles at a time. The size of these blocks
    depends on GenParams.fetch_block_bytes and is related to how many instructions the unit can
    handle at once, which can vary if extension C is on.

    The unit also deals with expanding compressed instructions and managing instructions that aren't aligned to
    4-byte boundaries.
    """

    fetch_request: Provided[Method]
    """Requests a fetch of the instruction block at the given PC."""

    fetch_writeback: Required[Method]
    """Invoked to write back the status of the requested fetch block."""

    flush: Provided[Method]
    """Flushes the fetch unit from the currently processed fetch blocks, so it can be redirected or/and stalled."""

    cont: Required[Method]
    """Should be invoked to send fetched instruction to the next step."""

    stall_unsafe: Required[Method]
    """Called when an unsafe instruction is fetched."""

    def __init__(self, gen_params: GenParams, icache: CacheInterface) -> None:
        """
        Parameters
        ----------
        gen_params : GenParams
            Instance of GenParams with parameters which should be used to generate
            fetch unit.
        icache : CacheInterface
            Instruction Cache
        """
        self.gen_params = gen_params
        self.icache = icache
        self.stall_unsafe = Method()

        self.layouts = self.gen_params.get(FetchLayouts)

        self.cont = Method(i=self.layouts.fetch_result)
        self.fetch_request = Method(i=self.layouts.fetch_request)
        self.fetch_writeback = Method(i=self.layouts.fetch_writeback)

        self.flush = Method()

        # Histogram of how many valid instructions each fetch block carried
        # (0 .. fetch_width inclusive).
        self.perf_fetch_utilization = TaggedCounter(
            "frontend.fetch.fetch_block_util",
            "Number of valid instructions in fetch blocks",
            tags=range(self.gen_params.fetch_width + 1),
        )

    def elaborate(self, platform):
        """Build the three-stage fetch pipeline (request/translate, cache
        response + RVC expansion, predecode + prediction check) plus the
        serializer that feeds the backend."""
        m = TModule()

        # Virtual-to-physical translation, configured for instruction fetches.
        m.submodules.addr_translator = addr_translator = AddressTranslator(
            self.gen_params, mode=AddressTranslatorMode.INSTRUCTION
        )
        m.submodules += [self.perf_fetch_utilization]

        fetch_width = self.gen_params.fetch_width
        fields = self.gen_params.get(CommonLayoutFields)
        params = self.gen_params.get(FrontendParams)

        # Serializer creates a continuous instruction stream from fetch
        # blocks, which can have holes in them.
        m.submodules.aligner = aligner = StableSelectingNetwork(fetch_width, self.layouts.raw_instr)
        # Depth is a multiple of both the write width (fetch_width) and the read
        # width (frontend_superscalarity); doubled, presumably so one full write
        # can proceed while a read drains — TODO confirm sizing rationale.
        serializer_depth = 2 * lcm(self.gen_params.frontend_superscalarity, fetch_width)
        m.submodules.serializer = serializer = WideFifo(
            self.layouts.raw_instr,
            depth=serializer_depth,
            read_width=self.gen_params.frontend_superscalarity,
            write_width=fetch_width,
        )

        with Transaction(name="cont").body(m):
            peek_result = serializer.peek(m)
            count = Signal(range(self.gen_params.frontend_superscalarity + 1))
            # we want at most one branch insn in scheduling group, and only at the end (for simplicity)
            # some insts in peek_result.data might not be valid, but this is still correct
            # Bit i is set iff instruction i-1 is a branch, so the number of
            # trailing zeros equals the group length that ends just after the
            # first branch (or the full read width when there is no branch).
            which_is_branch = [0] + [instr.cfi_type == CfiType.BRANCH for instr in peek_result.data][:-1]
            m.d.comb += count.eq(count_trailing_zeros(Cat(which_is_branch)))
            result = serializer.read(m, count=count)
            for i in range(self.gen_params.frontend_superscalarity):
                log.info(
                    m,
                    i < result.count,
                    "Sending an instr to the backend pc=0x{:x} instr=0x{:x}",
                    result.data[i].pc,
                    result.data[i].instr,
                )
            self.cont(m, result)

        # Requests in flight between stage 0 (translation) and stage 1 (cache
        # response), together with any fault detected during translation.
        m.submodules.fetch_requests = fetch_requests = BasicFifo(
            make_layout(fields.pc, ("access_fault", 1), ("page_fault", 1)),
            depth=2,
        )

        # This limits number of fetch blocks the fetch unit can process
        # at a time. We start counting when sending a request to the cache and
        # stop when pushing a fetch packet out of the fetch unit.
        m.submodules.req_counter = req_counter = Semaphore(4)
        # After a flush: how many already-issued fetch blocks must still be
        # drained and discarded before stage 2 resumes emitting instructions.
        flushing_counter = Signal.like(req_counter.count)

        flush_now = Signal()

        def flush():
            # Combinationally raise the flush strobe for the current cycle.
            m.d.comb += flush_now.eq(1)

        #
        # Fetch - stage 0
        # ================
        # - send a request to the instruction cache
        # - check PMP execute permission (if PMP is enabled)
        #
        m.submodules.pmp_checker = pmp_checker = PMPChecker(
            self.gen_params,
            mode=PMPOperationMode.INSTRUCTION_FETCH,
        )

        @def_method(m, self.fetch_request)
        def _(pc):
            log.info(m, True, "[IFU] request pc=0x{:x}", pc)
            req_counter.acquire(m)

            addr_translator.request(m, addr=pc)

        with Transaction().body(m):
            translated = addr_translator.accept(m)
            access_fault = Signal()

            # Merge translation faults with a PMP execute-permission denial
            # into a single access-fault indication.
            m.d.av_comb += pmp_checker.paddr.eq(translated.paddr)
            m.d.av_comb += access_fault.eq(translated.access_fault | ~pmp_checker.result.x)

            # Only touch the cache for fault-free addresses; faulting requests
            # still flow down the pipeline so the fault can be reported with
            # the right PC.
            with m.If(~translated.page_fault & ~access_fault):
                self.icache.issue_req(m, paddr=translated.paddr)

            fetch_requests.write(
                m,
                pc=translated.vaddr,
                access_fault=access_fault,
                page_fault=translated.page_fault,
            )

        #
        # State passed between stage 1 and stage 2
        #
        m.submodules.s1_s2_pipe = s1_s2_pipe = Pipe(
            [
                fields.fb_addr,
                ("instr_valid", fetch_width),
                ("access_fault", FetchLayouts.FaultFlag),
                ("rvc", fetch_width),
                ("instrs", ArrayLayout(self.gen_params.isa.ilen, fetch_width)),
                ("instr_block_cross", 1),
            ]
        )

        #
        # Fetch - stage 1
        # ================
        # - read the response from the cache
        # - expand compressed instructions (if applicable)
        # - find where each instruction begins
        # - handle instructions that cross a fetch boundary
        #
        rvc_expanders = [InstrDecompress(self.gen_params) for _ in range(fetch_width)]
        for n, module in enumerate(rvc_expanders):
            m.submodules[f"rvc_expander_{n}"] = module

        # With the C extension enabled, a single instruction can
        # be located on a boundary of two fetch blocks. Hence,
        # this requires some statefulness of the stage 1.
        prev_half = Signal(16)
        prev_half_addr = Signal(make_layout(fields.fb_addr).size)
        prev_half_v = Signal()
        with Transaction(name="Fetch_Stage1").body(m):
            fetch_request = fetch_requests.read(m)

            # The address of the fetch block.
            fetch_block_addr = params.fb_addr(fetch_request.pc)
            # The index (in instructions) of the first instruction that we should process.
            fetch_block_offset = params.fb_instr_idx(fetch_request.pc)

            # Conditionally read from icache or mark fault
            cache_resp = Signal(self.gen_params.get(ICacheLayouts).accept_res)
            access_fault = Signal(FetchLayouts.FaultFlag)

            with condition(m) as branch:
                with branch(fetch_request.page_fault | fetch_request.access_fault):
                    # Translation already faulted: no cache response exists,
                    # so synthesize an all-zero block and carry the fault kind.
                    with m.If(fetch_request.page_fault):
                        m.d.av_comb += access_fault.eq(FetchLayouts.FaultFlag.PAGE_FAULT)
                    with m.Else():
                        m.d.av_comb += access_fault.eq(FetchLayouts.FaultFlag.ACCESS_FAULT)
                    m.d.av_comb += cache_resp.fetch_block.eq(0)
                    m.d.av_comb += cache_resp.error.eq(0)
                with branch():
                    # Cache errors are reported as access faults too.
                    m.d.av_comb += cache_resp.eq(self.icache.accept_res(m))
                    m.d.av_comb += access_fault.eq(Mux(cache_resp.error, FetchLayouts.FaultFlag.ACCESS_FAULT, 0))

            #
            # Expand compressed instructions from the fetch block.
            #
            expanded_instr = [Signal(self.gen_params.isa.ilen) for _ in range(fetch_width)]
            is_rvc = Signal(fetch_width)

            # Whether in this cycle we have a fetch block that contains
            # an instruction that crosses a fetch boundary
            instr_block_cross = Signal()
            m.d.av_comb += instr_block_cross.eq(prev_half_v & ((prev_half_addr + 1) == fetch_block_addr))

            for i in range(fetch_width):
                if Extension.ZCA in self.gen_params.isa.extensions:
                    # With RVC, candidate instructions start on every 16-bit
                    # parcel; slot i looks at bits [16*i, 16*i+32).
                    full_instr = Signal(self.gen_params.isa.ilen)
                    if i == 0:
                        # If we have a half of an instruction from the previous block - we need to use it now.
                        with m.If(instr_block_cross):
                            m.d.av_comb += full_instr.eq(Cat(prev_half, cache_resp.fetch_block[0:16]))
                        with m.Else():
                            m.d.av_comb += full_instr.eq(cache_resp.fetch_block[:32])
                    elif i == fetch_width - 1:
                        # We will have only 16 bits for the last instruction, so append 16 zeroes.
                        m.d.av_comb += full_instr.eq(Cat(cache_resp.fetch_block[-16:], C(0, 16)))
                    else:
                        m.d.av_comb += full_instr.eq(cache_resp.fetch_block[i * 16 : i * 16 + 32])

                    m.d.av_comb += is_rvc[i].eq(is_instr_compressed(full_instr))
                    m.d.av_comb += rvc_expanders[i].instr_in.eq(full_instr[:16])
                    m.d.av_comb += expanded_instr[i].eq(Mux(is_rvc[i], rvc_expanders[i].instr_out, full_instr))
                else:
                    # Without RVC every instruction is an aligned 32-bit word.
                    m.d.av_comb += expanded_instr[i].eq(cache_resp.fetch_block[i * 32 : (i + 1) * 32])

            # Mask denoting at which offsets expected instructions start (depends on rvc indication and start address)
            instr_start = [Signal() for _ in range(fetch_width)]
            for i in range(fetch_width):
                if Extension.ZCA in self.gen_params.isa.extensions:
                    if i == 0:
                        m.d.av_comb += instr_start[i].eq(fetch_block_offset == 0)
                    elif i == 1:
                        # Slot 1 starts an instruction if slot 0 was compressed,
                        # didn't start one, or was the tail of a cross-block instruction.
                        m.d.av_comb += instr_start[i].eq(
                            (fetch_block_offset <= i) & (~instr_start[0] | is_rvc[0] | instr_block_cross)
                        )
                    else:
                        m.d.av_comb += instr_start[i].eq(
                            (fetch_block_offset <= i) & (~instr_start[i - 1] | is_rvc[i - 1])
                        )
                else:
                    m.d.av_comb += instr_start[i].eq(fetch_block_offset <= i)

            if Extension.ZCA in self.gen_params.isa.extensions:
                # The last 16-bit slot only yields a whole instruction when it is
                # compressed; a 32-bit instruction starting there spills into the
                # next block and is completed via prev_half.
                instr_position_mask = Cat(instr_start[:-1], instr_start[-1] & is_rvc[-1])

                # Remember the dangling upper half only if it genuinely starts an
                # instruction, no fault occurred, and we aren't discarding blocks
                # (<= 1 presumably because the counter decrements this cycle —
                # TODO confirm against stage 2 timing).
                m.d.sync += prev_half_v.eq(
                    (flushing_counter <= 1) & ~access_fault.any() & ~is_rvc[-1] & instr_start[-1]
                )
                m.d.sync += prev_half.eq(cache_resp.fetch_block[-16:])
                m.d.sync += prev_half_addr.eq(fetch_block_addr)
            else:
                instr_position_mask = Cat(instr_start)

            # Reported fault pc (signalled by emitting an instruction) must always match first requested instruction
            access_fault_instr_position = 1 << fetch_block_offset

            s1_s2_pipe.write(
                m,
                fb_addr=fetch_block_addr,
                instr_valid=Mux(access_fault.any(), access_fault_instr_position, instr_position_mask),
                access_fault=access_fault,
                rvc=is_rvc,
                instrs=expanded_instr,
                instr_block_cross=instr_block_cross,
            )

        # Make sure to clean the state
        with m.If(flush_now):
            m.d.sync += prev_half_v.eq(0)

        #
        # Fetch - stage 2
        # ================
        # - predecode instructions
        # - verify the branch prediction
        # - redirect the frontend if mispredicted
        # - check if any of instructions stalls the frontend
        # - enqueue a packet of instructions
        #

        predecoders = [Predecoder(self.gen_params) for _ in range(fetch_width)]
        for n, module in enumerate(predecoders):
            m.submodules[f"predecoder_{n}"] = module

        m.submodules.prediction_checker = prediction_checker = PredictionChecker(self.gen_params)

        with Transaction(name="Fetch_Stage2").body(m):
            req_counter.release(m)
            s1_data = s1_s2_pipe.read(m)

            instrs = s1_data.instrs
            fetch_block_addr = s1_data.fb_addr
            instr_valid = s1_data.instr_valid
            access_fault = s1_data.access_fault
            fault_any = Signal()
            m.d.av_comb += fault_any.eq(access_fault.any())

            # Predecode instructions
            predecoded_instr = [predecoders[i].predecode(m, instrs[i]) for i in range(fetch_width)]

            # No prediction for now
            prediction = Signal(self.layouts.bpu_prediction)

            # The method is guarded by the If to make sure that the metrics
            # are updated only if not flushing.
            with m.If(flushing_counter == 0):
                predcheck_res = prediction_checker.check(
                    m,
                    fb_addr=fetch_block_addr,
                    instr_block_cross=s1_data.instr_block_cross,
                    instr_valid=instr_valid,
                    predecoded=predecoded_instr,
                    prediction=prediction,
                )

            # Is the instruction unsafe (i.e. stalls the frontend until the backend resumes it).
            instr_unsafe = Signal(fetch_width)
            for i in range(fetch_width):
                # If there was an access fault, mark every instruction as unsafe
                m.d.av_comb += instr_unsafe[i].eq((predecoded_instr[i].unsafe | fault_any) & instr_valid[i])

            # Index of the first unsafe instruction in the block, if any.
            m.submodules.unsafe_prio_encoder = unsafe_prio_encoder = PriorityEncoder(fetch_width)
            m.d.av_comb += unsafe_prio_encoder.i.eq(instr_unsafe)

            unsafe_idx = unsafe_prio_encoder.o[: self.gen_params.fetch_width_log]
            has_unsafe = Signal()
            m.d.av_comb += has_unsafe.eq(~unsafe_prio_encoder.n)

            # Does a misprediction-triggered redirect occur strictly before the
            # first unsafe instruction? If so the redirect wins.
            redirect_before_unsafe = Signal()
            m.d.av_comb += redirect_before_unsafe.eq(predcheck_res.fb_instr_idx < unsafe_idx)

            redirect = Signal()
            unsafe_stall = Signal()
            redirect_or_unsafe_idx = Signal(range(fetch_width))

            with m.If(predcheck_res.mispredicted & (~has_unsafe | redirect_before_unsafe)):
                m.d.av_comb += [
                    redirect.eq(~predcheck_res.stall),
                    unsafe_stall.eq(predcheck_res.stall),
                    redirect_or_unsafe_idx.eq(predcheck_res.fb_instr_idx),
                ]
            with m.Elif(has_unsafe):
                m.d.av_comb += [
                    unsafe_stall.eq(1),
                    redirect_or_unsafe_idx.eq(unsafe_idx),
                ]

            # This mask denotes what prefix of instructions we should enqueue.
            valid_instr_prefix = Signal(fetch_width)
            with m.If(redirect | unsafe_stall):
                # If there is an instruction that redirects or stalls the frontend, enqueue
                # instructions only up to that instruction.
                m.d.av_comb += valid_instr_prefix.eq((1 << (redirect_or_unsafe_idx + 1)) - 1)
            with m.Else():
                m.d.av_comb += valid_instr_prefix.eq(C(1).replicate(fetch_width))

            # The ultimate mask that tells which instructions should be sent to the backend.
            fetch_mask = Signal(fetch_width)
            m.d.av_comb += fetch_mask.eq(instr_valid & valid_instr_prefix)

            # Aggregate all signals that will be sent out of the fetch unit.
            raw_instrs = Signal(ArrayLayout(self.layouts.raw_instr, fetch_width))
            for i in range(fetch_width):
                m.d.av_comb += [
                    raw_instrs[i].instr.eq(instrs[i]),
                    raw_instrs[i].pc.eq(params.pc_from_fb(fetch_block_addr, i)),
                    raw_instrs[i].rvc.eq(s1_data.rvc[i]),
                    raw_instrs[i].predicted_taken.eq(redirect & (predcheck_res.fb_instr_idx == i)),
                    raw_instrs[i].access_fault.eq(s1_data.access_fault),
                    raw_instrs[i].cfi_type.eq(predecoded_instr[i].cfi_type),
                ]

            if Extension.ZCA in self.gen_params.isa.extensions:
                with m.If(s1_data.instr_block_cross):
                    # A cross-boundary instruction really starts 2 bytes before
                    # this fetch block.
                    m.d.av_comb += raw_instrs[0].pc.eq(params.pc_from_fb(fetch_block_addr, 0) - 2)
                    with m.If(s1_data.access_fault):
                        # Mark that access/page fault happened only at second (current) half.
                        # If fault happened on the first half `instr_block_cross` would be false
                        m.d.av_comb += raw_instrs[0].access_fault.eq(
                            s1_data.access_fault | FetchLayouts.FaultFlag.EXCEPTION_ON_SECOND_HALF
                        )

            with condition(m) as branch:
                with branch(flushing_counter == 0):
                    with m.If(fault_any | unsafe_stall):
                        self.stall_unsafe(m)

                    with m.If(fault_any | unsafe_stall | redirect):
                        self.fetch_writeback(
                            m,
                            redirect=redirect,
                            redirect_target=predcheck_res.redirect_target,
                        )
                        flush()

                    self.perf_fetch_utilization.incr(m, popcount(fetch_mask))

                    # Make sure this is called only once to avoid a huge mux on arguments
                    m.d.av_comb += [aligner.valids.eq(fetch_mask), aligner.inputs.eq(raw_instrs)]
                    serializer.write(m, data=aligner.outputs, count=aligner.output_cnt)
                with branch():
                    # Still draining blocks issued before the flush: drop them.
                    m.d.sync += flushing_counter.eq(flushing_counter - 1)

        with m.If(flush_now):
            # Everything currently in flight (tracked by the semaphore) must be
            # discarded before normal operation resumes.
            m.d.sync += flushing_counter.eq(req_counter.count_next)

        @def_method(m, self.flush)
        def _():
            flush()
            serializer.clear(m)

        return m
449  
450  
class Predecoder(Elaboratable):
    """Instruction predecoder

    The module performs basic analysis on instructions. It identifies if an instruction
    is a jump instruction, determines the type of jump, and finds the jump's target.

    Its role is to give quick feedback to the fetch unit and potentially the branch predictor
    about the fetched instruction. This helps in redirecting the fetch unit promptly if needed.
    """

    def __init__(self, gen_params: GenParams) -> None:
        """
        Parameters
        ----------
        gen_params: GenParams
            Core generation parameters.
        """
        self.gen_params = gen_params

        layouts = self.gen_params.get(FetchLayouts)
        fields = self.gen_params.get(CommonLayoutFields)

        # Takes a raw (already decompressed) instruction word and returns its
        # CFI classification, immediate offset and unsafety flag.
        self.predecode = Method(i=make_layout(fields.instr), o=layouts.predecoded_instr)

    def elaborate(self, platform):
        """Build the combinational predecode logic."""
        m = TModule()

        @def_method(m, self.predecode)
        def _(instr):
            # Standard RISC-V base-instruction field slices.
            quadrant = instr[0:2]  # 0b11 for full 32-bit instructions
            opcode = instr[2:7]
            funct3 = instr[12:15]
            rd = instr[7:12]
            rs1 = instr[15:20]

            bimm = Signal(signed(13))
            jimm = Signal(signed(21))
            iimm = Signal(signed(12))

            # Reassemble the scattered immediate fields (RISC-V B/J/I formats);
            # B and J immediates have an implicit zero LSB.
            m.d.av_comb += [
                iimm.eq(instr[20:]),
                bimm.eq(Cat(0, instr[8:12], instr[25:31], instr[7], instr[31])),
                jimm.eq(Cat(0, instr[21:31], instr[20], instr[12:20], instr[31])),
            ]

            ret = Signal.like(self.predecode.data_out)

            with m.Switch(opcode):
                with m.Case(Opcode.BRANCH):
                    m.d.av_comb += ret.cfi_type.eq(CfiType.BRANCH)
                    m.d.av_comb += ret.cfi_offset.eq(bimm)
                with m.Case(Opcode.JAL):
                    # JAL writing the link register (x1/x5) is a call by the
                    # RISC-V calling convention hints.
                    m.d.av_comb += ret.cfi_type.eq(
                        Mux((rd == Registers.X1) | (rd == Registers.X5), CfiType.CALL, CfiType.JAL)
                    )
                    m.d.av_comb += ret.cfi_offset.eq(jimm)
                with m.Case(Opcode.JALR):
                    # JALR through x1/x5 is treated as a return.
                    m.d.av_comb += ret.cfi_type.eq(
                        Mux((rs1 == Registers.X1) | (rs1 == Registers.X5), CfiType.RET, CfiType.JALR)
                    )
                    m.d.av_comb += ret.cfi_offset.eq(iimm)
                with m.Default():
                    m.d.av_comb += ret.cfi_type.eq(CfiType.INVALID)

            # Non-0b11 quadrants are compressed encodings; the opcode slice above
            # is meaningless for them, so override the classification.
            # (This assignment intentionally comes after the Switch: in Amaranth,
            # the later statement takes priority.)
            with m.If(quadrant != 0b11):
                m.d.av_comb += ret.cfi_type.eq(CfiType.INVALID)

            # SYSTEM and FENCE.I instructions stall the frontend until the
            # backend resumes it.
            m.d.av_comb += ret.unsafe.eq(
                (opcode == Opcode.SYSTEM) | ((opcode == Opcode.MISC_MEM) & (funct3 == Funct3.FENCEI))
            )

            return ret

        return m
525  
526  
class PredictionChecker(Elaboratable):
    """Branch prediction checker

    This module checks if branch predictions are correct by looking at predecoded data.
    It checks for the following errors:

     - a JAL/JALR instruction was not predicted taken,
     - mistaking non-control flow instructions (CFI) for control flow ones,
     - getting the target of JAL/BRANCH instructions wrong.
    """

    def __init__(self, gen_params: GenParams) -> None:
        """
        Parameters
        ----------
        gen_params: GenParams
            Core generation parameters.
        """
        self.gen_params = gen_params

        layouts = gen_params.get(FetchLayouts)

        self.check = Method(i=layouts.pred_checker_i, o=layouts.pred_checker_o)

        self.perf_preceding_redirection = TaggedCounter(
            "frontend.fetch.pred_checker.preceding_redirection",
            "Number of redirections caused by undetected CFIs",
            tags=CfiType,
        )
        self.perf_mispredicted_cfi_type = TaggedCounter(
            "frontend.fetch.pred_checker.cfi_type_mispredict",
            "Number of redirections caused by misprediction of the CFI type",
            tags=CfiType,
        )
        self.perf_mispredicted_cfi_target = TaggedCounter(
            "frontend.fetch.pred_checker.cfi_target_mispredict",
            "Number of redirections caused by misprediction of the CFI target",
            tags=CfiType,
        )

    def elaborate(self, platform):
        """Build the combinational prediction-verification logic."""
        m = TModule()

        params = self.gen_params.get(FrontendParams)

        m.submodules += [
            self.perf_mispredicted_cfi_type,
            self.perf_preceding_redirection,
            self.perf_mispredicted_cfi_target,
        ]

        @def_method(m, self.check)
        def _(fb_addr, instr_block_cross, instr_valid, predecoded, prediction):
            decoded_cfi_types = Array([predecoded[i].cfi_type for i in range(self.gen_params.fetch_width)])
            decoded_cfi_offsets = Array([predecoded[i].cfi_offset for i in range(self.gen_params.fetch_width)])

            # First find all the instructions that would redirect the fetch unit.
            decoded_redirections = Signal(self.gen_params.fetch_width)
            for i in range(self.gen_params.fetch_width):
                # Here we make a static prediction: forward branches not taken and backward
                # taken. This prediction will be used if the branch prediction unit
                # didn't detect the branch at all.
                m.d.av_comb += decoded_redirections[i].eq(
                    CfiType.is_jal(decoded_cfi_types[i])
                    | CfiType.is_jalr(decoded_cfi_types[i])
                    | (
                        CfiType.is_branch(decoded_cfi_types[i])
                        & ~prediction.branch_mask[i]
                        & (decoded_cfi_offsets[i] < 0)
                    )
                )

            # Find the earliest one
            m.submodules.pd_redirection_enc = pd_redirection_enc = PriorityEncoder(self.gen_params.fetch_width)
            m.d.av_comb += pd_redirection_enc.i.eq(decoded_redirections & instr_valid)

            pd_redirect_idx = Signal(self.gen_params.fetch_width_log)
            m.d.av_comb += pd_redirect_idx.eq(pd_redirection_enc.o[: self.gen_params.fetch_width_log])

            # For a given instruction index, returns a CFI target based on the predecode info
            def get_decoded_target_for(idx: Value) -> Value:
                base = params.pc_from_fb(fb_addr, idx) + decoded_cfi_offsets[idx]
                if Extension.ZCA in self.gen_params.isa.extensions:
                    # Slot 0 of a block-crossing instruction really starts 2
                    # bytes earlier, in the previous fetch block.
                    return base - Mux(instr_block_cross & (idx == 0), 2, 0)
                return base

            # Target of a CFI that would redirect the frontend according to the prediction
            decoded_target_for_predicted_cfi = Signal(self.gen_params.isa.xlen)
            m.d.av_comb += decoded_target_for_predicted_cfi.eq(get_decoded_target_for(prediction.cfi_idx))

            # Target of a CFI that would redirect the frontend according to predecode info
            decoded_target_for_decoded_cfi = Signal(self.gen_params.isa.xlen)
            m.d.av_comb += decoded_target_for_decoded_cfi.eq(get_decoded_target_for(pd_redirect_idx))

            # A decoded CFI redirects earlier than the predicted one (or the
            # predictor saw no CFI at all) - the predictor missed it.
            preceding_redirection = ~pd_redirection_enc.n & (
                ((CfiType.valid(prediction.cfi_type) & (pd_redirect_idx < prediction.cfi_idx)))
                | ~CfiType.valid(prediction.cfi_type)
            )

            # The predictor found a CFI, but predecode disagrees about its kind.
            mispredicted_cfi_type = CfiType.valid(prediction.cfi_type) & (
                prediction.cfi_type != decoded_cfi_types[prediction.cfi_idx]
            )

            # Only direct CFIs (branch/JAL) have a target computable from the
            # instruction itself, so only those can be target-checked here.
            mispredicted_cfi_target = (CfiType.is_branch(prediction.cfi_type) | CfiType.is_jal(prediction.cfi_type)) & (
                ~prediction.cfi_target_valid | (decoded_target_for_predicted_cfi != prediction.cfi_target)
            )

            ret = Signal.like(self.check.data_out)

            # The If/Elif chain establishes check priority: a missed earlier CFI
            # wins over a type mismatch, which wins over a target mismatch.
            with m.If(preceding_redirection):
                self.perf_preceding_redirection.incr(m, decoded_cfi_types[pd_redirect_idx])
                m.d.av_comb += assign(
                    ret,
                    {
                        "mispredicted": 1,
                        # JALR targets are unknown at fetch time - stall instead
                        # of redirecting.
                        "stall": CfiType.is_jalr(decoded_cfi_types[pd_redirect_idx]),
                        "fb_instr_idx": pd_redirect_idx,
                        "redirect_target": decoded_target_for_decoded_cfi,
                    },
                )
            with m.Elif(mispredicted_cfi_type):
                self.perf_mispredicted_cfi_type.incr(m, prediction.cfi_type)
                # If no decoded CFI redirects either, resume sequentially at the
                # start of the next fetch block.
                fallthrough_addr = params.pc_from_fb(fb_addr + 1, 0)
                m.d.av_comb += assign(
                    ret,
                    {
                        "mispredicted": 1,
                        "stall": CfiType.is_jalr(decoded_cfi_types[pd_redirect_idx]),
                        "fb_instr_idx": Mux(pd_redirection_enc.n, self.gen_params.fetch_width - 1, pd_redirect_idx),
                        "redirect_target": Mux(pd_redirection_enc.n, fallthrough_addr, decoded_target_for_decoded_cfi),
                    },
                )
            with m.Elif(mispredicted_cfi_target):
                self.perf_mispredicted_cfi_target.incr(m, prediction.cfi_type)
                m.d.av_comb += assign(
                    ret,
                    {
                        "mispredicted": 1,
                        "fb_instr_idx": prediction.cfi_idx,
                        "redirect_target": decoded_target_for_predicted_cfi,
                    },
                )

            return ret

        return m