/ coreblocks / func_blocks / fu / jumpbranch.py
jumpbranch.py
  1  from dataclasses import dataclass
  2  from amaranth import *
  3  
  4  from enum import IntFlag, auto
  5  
  6  from typing import Sequence
  7  
  8  from transactron import *
  9  from transactron.core import def_method
 10  from transactron.lib import *
 11  from transactron.lib import logging
 12  from transactron.utils import DependencyContext, from_method_layout
 13  from coreblocks.params import GenParams, FunctionalComponentParams
 14  from coreblocks.arch import Funct3, OpType, ExceptionCause, Extension
 15  from coreblocks.interface.layouts import JumpBranchLayouts, CommonLayoutFields
 16  from coreblocks.interface.keys import (
 17      AsyncInterruptInsertSignalKey,
 18      BranchVerifyKey,
 19      ExceptionReportKey,
 20      PredictedJumpTargetKey,
 21  )
 22  from transactron.utils import OneHotSwitch
 23  from transactron.utils.transactron_helpers import make_layout
 24  from coreblocks.func_blocks.interface.func_protocols import FuncUnit
 25  from coreblocks.func_blocks.fu.common import DecoderManager, FuncUnitBase
 26  
 27  __all__ = ["JumpBranchFuncUnit", "JumpComponent"]
 28  
 29  
 30  log = logging.HardwareLogger("backend.fu.jumpbranch")
 31  
 32  
 33  class JumpBranchFn(DecoderManager):
 34      class Fn(IntFlag):
 35          JAL = auto()
 36          JALR = auto()
 37          AUIPC = auto()
 38          BEQ = auto()
 39          BNE = auto()
 40          BLT = auto()
 41          BLTU = auto()
 42          BGE = auto()
 43          BGEU = auto()
 44  
 45      def get_instructions(self) -> Sequence[tuple]:
 46          return [
 47              (self.Fn.BEQ, OpType.BRANCH, Funct3.BEQ),
 48              (self.Fn.BNE, OpType.BRANCH, Funct3.BNE),
 49              (self.Fn.BLT, OpType.BRANCH, Funct3.BLT),
 50              (self.Fn.BLTU, OpType.BRANCH, Funct3.BLTU),
 51              (self.Fn.BGE, OpType.BRANCH, Funct3.BGE),
 52              (self.Fn.BGEU, OpType.BRANCH, Funct3.BGEU),
 53              (self.Fn.JAL, OpType.JAL),
 54              (self.Fn.JALR, OpType.JALR, Funct3.JALR),
 55              (self.Fn.AUIPC, OpType.AUIPC),
 56          ]
 57  
 58  
 59  class JumpBranch(Elaboratable):
 60      def __init__(self, gen_params: GenParams, fn=JumpBranchFn()):
 61          self.gen_params = gen_params
 62  
 63          xlen = gen_params.isa.xlen
 64          self.fn = fn.get_function()
 65          self.in1 = Signal(xlen)
 66          self.in2 = Signal(xlen)
 67          self.in_pc = Signal(xlen)
 68          self.in_imm = Signal(xlen)
 69          self.in_rvc = Signal()
 70          self.jmp_addr = Signal(xlen)
 71          self.reg_res = Signal(xlen)
 72          self.taken = Signal()
 73  
 74      def elaborate(self, platform):
 75          m = Module()
 76  
 77          m.d.comb += self.jmp_addr.eq(self.in_pc + self.in_imm)
 78          m.d.comb += self.reg_res.eq(self.in_pc + 4)
 79  
 80          if Extension.ZCA in self.gen_params.isa.extensions:
 81              with m.If(self.in_rvc):
 82                  m.d.comb += self.reg_res.eq(self.in_pc + 2)
 83  
 84          with OneHotSwitch(m, self.fn) as OneHotCase:
 85              with OneHotCase(JumpBranchFn.Fn.JAL):
 86                  m.d.comb += self.taken.eq(1)
 87              with OneHotCase(JumpBranchFn.Fn.JALR):
 88                  m.d.comb += self.jmp_addr.eq(self.in1 + self.in_imm)
 89                  m.d.comb += self.jmp_addr[0].eq(0)
 90                  m.d.comb += self.taken.eq(1)
 91              with OneHotCase(JumpBranchFn.Fn.AUIPC):
 92                  m.d.comb += self.reg_res.eq(self.jmp_addr)
 93              with OneHotCase(JumpBranchFn.Fn.BEQ):
 94                  m.d.comb += self.taken.eq(self.in1 == self.in2)
 95              with OneHotCase(JumpBranchFn.Fn.BNE):
 96                  m.d.comb += self.taken.eq(self.in1 != self.in2)
 97              with OneHotCase(JumpBranchFn.Fn.BLT):
 98                  m.d.comb += self.taken.eq(self.in1.as_signed() < self.in2.as_signed())
 99              with OneHotCase(JumpBranchFn.Fn.BLTU):
100                  m.d.comb += self.taken.eq(self.in1.as_unsigned() < self.in2.as_unsigned())
101              with OneHotCase(JumpBranchFn.Fn.BGE):
102                  m.d.comb += self.taken.eq(self.in1.as_signed() >= self.in2.as_signed())
103              with OneHotCase(JumpBranchFn.Fn.BGEU):
104                  m.d.comb += self.taken.eq(self.in1.as_unsigned() >= self.in2.as_unsigned())
105  
106          return m
107  
108  
class JumpBranchFuncUnit(FuncUnitBase[JumpBranchFn]):
    """Functional unit executing jumps, conditional branches and AUIPC.

    ``issue_decoded`` drives the combinational ``JumpBranch`` datapath and
    buffers its results in ``instr_fifo``. A separate resolution transaction
    then does the following:

    * reads a buffered entry;
    * compares the outcome with the predicted direction and (for JALR) the
      predicted target;
    * reports a misaligned-target exception, an async-interrupt entry or a
      misprediction;
    * writes the resolved branch to ``fifo_branch_resolved``;
    * pushes the register result.
    """

    def __init__(self, gen_params: GenParams, fn=JumpBranchFn()):
        super().__init__(gen_params, fn)

        # Resolved control transfers. Its read method is published below under BranchVerifyKey.
        self.fifo_branch_resolved = FIFO(self.gen_params.get(JumpBranchLayouts).verify_branch, 2)

        self.dm = DependencyContext.get()
        self.dm.add_dependency(BranchVerifyKey(), self.fifo_branch_resolved.read)

        self.perf_misaligned = HwCounter(
            "backend.fu.jumpbranch.misaligned", "Number of instructions with misaligned target address"
        )
        self.perf_mispredictions = HwCounter("backend.fu.jumpbranch.mispredictions", "Number of branch mispredictions")

        # The ExceptionReportKey dependency is called to obtain the exception-reporting method.
        self.exception_report = self.dm.get_dependency(ExceptionReportKey())()

    def elaborate(self, platform):
        m = super().elaborate(platform)

        m.submodules += [
            self.perf_misaligned,
            self.perf_mispredictions,
        ]

        # Predicted-target request/response pair. The request is made in
        # `issue_decoded`; the response is consumed in the resolution transaction.
        jump_target_req, jump_target_resp = self.dm.get_dependency(PredictedJumpTargetKey())

        m.submodules.jb = jb = JumpBranch(self.gen_params, fn=self.fn)
        m.submodules.fifo_branch_resolved = self.fifo_branch_resolved

        fields = self.gen_params.get(CommonLayoutFields)
        # Per-instruction data captured at issue time for the resolution transaction.
        instr_fifo_layout = make_layout(
            fields.rob_id,
            fields.pc,
            fields.rp_dst,
            ("type", JumpBranchFn.Fn),
            ("jmp_addr", self.gen_params.isa.xlen),
            ("reg_res", self.gen_params.isa.xlen),
            ("taken", 1),
            fields.predicted_taken,
            fields.tag,
        )
        m.submodules.instr_fifo = instr_fifo = BasicFifo(instr_fifo_layout, 2)

        # Resolution transaction: check the prediction, report exceptions, publish results.
        with Transaction().body(m):
            instr = instr_fifo.read(m)
            target_prediction = jump_target_resp(m)

            # Address execution actually continues at: the target if taken,
            # otherwise reg_res (pc + 4, or pc + 2 for compressed instructions).
            # Only used for non-AUIPC instructions.
            jump_result = Mux(instr.taken, instr.jmp_addr, instr.reg_res)
            is_auipc = instr.type == JumpBranchFn.Fn.AUIPC

            # Only JALR has a register-dependent target, so only JALR needs a valid,
            # matching target prediction. Every other type passes this check trivially.
            predicted_addr_correctly = (instr.type != JumpBranchFn.Fn.JALR) | (
                target_prediction.valid & (target_prediction.cfi_target == instr.jmp_addr)
            )

            # AUIPC never mispredicts. Otherwise both the direction and (for JALR)
            # the target must match the prediction.
            misprediction = Signal()
            m.d.av_comb += misprediction.eq(
                ~(is_auipc | (predicted_addr_correctly & (instr.taken == instr.predicted_taken)))
            )
            self.perf_mispredictions.incr(m, enable_call=misprediction)

            # Targets must be 4-byte aligned, or only 2-byte aligned when ZCA
            # (compressed instructions) is enabled.
            jmp_addr_misaligned = (
                instr.jmp_addr & (0b1 if Extension.ZCA in self.gen_params.isa.extensions else 0b11)
            ) != 0

            async_interrupt_active = self.dm.get_dependency(AsyncInterruptInsertSignalKey())

            # Set when any exception is reported for this instruction; forwarded to push_result.
            exception = Signal()

            # Reporting priority: misaligned target > async-interrupt entry > misprediction.
            with m.If(~is_auipc & instr.taken & jmp_addr_misaligned):
                self.perf_misaligned.incr(m)
                # Spec: "[...] if the target address is not four-byte aligned. This exception is reported on the branch
                # or jump instruction, not on the target instruction. No instruction-address-misaligned exception is
                # generated for a conditional branch that is not taken."
                # (With ZCA the required alignment is two bytes; see jmp_addr_misaligned.)
                m.d.comb += exception.eq(1)
                self.exception_report(
                    m,
                    rob_id=instr.rob_id,
                    cause=ExceptionCause.INSTRUCTION_ADDRESS_MISALIGNED,
                    pc=instr.pc,
                    mtval=instr.jmp_addr,
                )

            with m.Elif(async_interrupt_active & ~is_auipc):
                # Jump instructions are entry points for async interrupts.
                # This way the known pc can be stored via the report to the global exception
                # register, so it does not have to be kept in the ROB.
                # Exceptions have priority, because the instruction that reports an async
                # interrupt is committed and the exception would be lost.
                m.d.comb += exception.eq(1)
                self.exception_report(
                    m, rob_id=instr.rob_id, cause=ExceptionCause._COREBLOCKS_ASYNC_INTERRUPT, pc=jump_result, mtval=0
                )
            with m.Elif(misprediction):
                # Async interrupts can have priority, because `jump_result` is handled in the same way.
                # No extra misprediction penalty is introduced at interrupt return to the `jump_result` address.
                m.d.comb += exception.eq(1)
                self.exception_report(
                    m, rob_id=instr.rob_id, cause=ExceptionCause._COREBLOCKS_MISPREDICTION, pc=jump_result, mtval=0
                )

            # Every control transfer (not AUIPC) is published as resolved,
            # even when it also raised an exception above.
            with m.If(~is_auipc):
                self.fifo_branch_resolved.write(m, from_pc=instr.pc, next_pc=jump_result, misprediction=misprediction)
                log.debug(
                    m,
                    True,
                    "branch resolved from 0x{:08x} to 0x{:08x}; misprediction: {}",
                    instr.pc,
                    jump_result,
                    misprediction,
                )

            self.push_result(
                m,
                rob_id=instr.rob_id,
                result=instr.reg_res,
                rp_dst=instr.rp_dst,
                exception=exception,
            )

        @def_method(m, self.issue_decoded)
        def _(arg):
            """Accept a decoded instruction and buffer its computed results."""
            # Feed the combinational JumpBranch datapath.
            m.d.top_comb += jb.fn.eq(arg.decode_fn)
            m.d.top_comb += jb.in1.eq(arg.s1_val)
            m.d.top_comb += jb.in2.eq(arg.s2_val)
            m.d.top_comb += jb.in_pc.eq(arg.pc)
            m.d.top_comb += jb.in_imm.eq(arg.imm)

            # exec_fn.funct7 is reinterpreted through the funct7_info layout,
            # which carries the rvc flag and the predicted_taken bit.
            funct7_info = Signal(from_method_layout(self.gen_params.get(JumpBranchLayouts).funct7_info))
            m.d.top_comb += funct7_info.eq(arg.exec_fn.funct7)
            m.d.top_comb += jb.in_rvc.eq(funct7_info.rvc)

            # Paired with jump_target_resp in the resolution transaction.
            jump_target_req(m)

            instr_fifo.write(
                m,
                rob_id=arg.rob_id,
                pc=arg.pc,
                rp_dst=arg.rp_dst,
                type=arg.decode_fn,
                jmp_addr=jb.jmp_addr,
                reg_res=jb.reg_res,
                taken=jb.taken,
                predicted_taken=funct7_info.predicted_taken,
                tag=arg.tag,
            )

        return m
255  
256  
@dataclass(frozen=True)
class JumpComponent(FunctionalComponentParams):
    """Functional-component parameters that build a ``JumpBranchFuncUnit``.

    ``decoder_manager`` is passed to the created unit as its ``fn``
    argument.
    """

    decoder_manager: JumpBranchFn = JumpBranchFn()

    def get_module(self, gen_params: GenParams) -> FuncUnit:
        unit = JumpBranchFuncUnit(gen_params, fn=self.decoder_manager)
        return unit