# jumpbranch.py
# Jump/branch functional unit: computes jump/branch targets, branch outcomes and
# link (rd) values, and reports branch resolution, mispredictions and
# instruction-address-misaligned exceptions.
from dataclasses import dataclass
from amaranth import *

from enum import IntFlag, auto

from typing import Sequence

from transactron import *
from transactron.core import def_method
from transactron.lib import *
from transactron.lib import logging
from transactron.utils import DependencyContext, from_method_layout
from coreblocks.params import GenParams, FunctionalComponentParams
from coreblocks.arch import Funct3, OpType, ExceptionCause, Extension
from coreblocks.interface.layouts import JumpBranchLayouts, CommonLayoutFields
from coreblocks.interface.keys import (
    AsyncInterruptInsertSignalKey,
    BranchVerifyKey,
    ExceptionReportKey,
    PredictedJumpTargetKey,
)
from transactron.utils import OneHotSwitch
from transactron.utils.transactron_helpers import make_layout
from coreblocks.func_blocks.interface.func_protocols import FuncUnit
from coreblocks.func_blocks.fu.common import DecoderManager, FuncUnitBase

__all__ = ["JumpBranchFuncUnit", "JumpComponent"]


log = logging.HardwareLogger("backend.fu.jumpbranch")


class JumpBranchFn(DecoderManager):
    """Decoder manager for the jump/branch unit.

    Maps instruction encodings (OpType, optionally Funct3) to the `Fn` flags
    used to select the operation inside `JumpBranch`.
    """

    class Fn(IntFlag):
        """Operations handled by the jump/branch unit; each `auto()` member is a distinct bit."""

        JAL = auto()
        JALR = auto()
        AUIPC = auto()
        BEQ = auto()
        BNE = auto()
        BLT = auto()
        BLTU = auto()
        BGE = auto()
        BGEU = auto()

    def get_instructions(self) -> Sequence[tuple]:
        """Return the `(Fn, OpType[, Funct3])` tuples this unit accepts.

        Conditional branches and JALR are matched on OpType and Funct3.
        JAL and AUIPC are matched on OpType alone.
        """
        return [
            (self.Fn.BEQ, OpType.BRANCH, Funct3.BEQ),
            (self.Fn.BNE, OpType.BRANCH, Funct3.BNE),
            (self.Fn.BLT, OpType.BRANCH, Funct3.BLT),
            (self.Fn.BLTU, OpType.BRANCH, Funct3.BLTU),
            (self.Fn.BGE, OpType.BRANCH, Funct3.BGE),
            (self.Fn.BGEU, OpType.BRANCH, Funct3.BGEU),
            (self.Fn.JAL, OpType.JAL),
            (self.Fn.JALR, OpType.JALR, Funct3.JALR),
            (self.Fn.AUIPC, OpType.AUIPC),
        ]


class JumpBranch(Elaboratable):
    """Combinational core computing the jump target, the register result and the taken flag.

    Inputs: `fn` (operation select), `in1`/`in2` (source operand values),
    `in_pc`, `in_imm` and `in_rvc` (instruction is compressed).
    Outputs: `jmp_addr` (target address), `reg_res` (value written to rd:
    the link address, or pc + imm for AUIPC) and `taken`.
    """

    def __init__(self, gen_params: GenParams, fn=JumpBranchFn()):
        """
        Parameters
        ----------
        gen_params : GenParams
            Core configuration; `isa.xlen` sets the data width and
            `isa.extensions` is checked for ZCA at elaboration time.
        fn : JumpBranchFn
            Decoder manager whose `get_function()` provides the `fn` select
            signal. NOTE: the default instance is created once, at function
            definition time, and shared between calls.
        """
        self.gen_params = gen_params

        xlen = gen_params.isa.xlen
        self.fn = fn.get_function()
        self.in1 = Signal(xlen)
        self.in2 = Signal(xlen)
        self.in_pc = Signal(xlen)
        self.in_imm = Signal(xlen)
        self.in_rvc = Signal()
        self.jmp_addr = Signal(xlen)
        self.reg_res = Signal(xlen)
        self.taken = Signal()

    def elaborate(self, platform):
        m = Module()

        # Default values. The cases below override them, and in Amaranth a later
        # comb assignment to the same signal takes precedence, so statement
        # order here matters.
        m.d.comb += self.jmp_addr.eq(self.in_pc + self.in_imm)
        m.d.comb += self.reg_res.eq(self.in_pc + 4)

        # With compressed instructions (ZCA) the link address is pc + 2 for a
        # 16-bit instruction.
        if Extension.ZCA in self.gen_params.isa.extensions:
            with m.If(self.in_rvc):
                m.d.comb += self.reg_res.eq(self.in_pc + 2)

        with OneHotSwitch(m, self.fn) as OneHotCase:
            with OneHotCase(JumpBranchFn.Fn.JAL):
                m.d.comb += self.taken.eq(1)
            with OneHotCase(JumpBranchFn.Fn.JALR):
                # JALR target is (rs1 + imm) with the least-significant bit cleared.
                m.d.comb += self.jmp_addr.eq(self.in1 + self.in_imm)
                m.d.comb += self.jmp_addr[0].eq(0)
                m.d.comb += self.taken.eq(1)
            with OneHotCase(JumpBranchFn.Fn.AUIPC):
                # AUIPC writes pc + imm (the default jmp_addr) to rd. `taken` is
                # not driven in this case, so it stays at its reset value 0.
                m.d.comb += self.reg_res.eq(self.jmp_addr)
            with OneHotCase(JumpBranchFn.Fn.BEQ):
                m.d.comb += self.taken.eq(self.in1 == self.in2)
            with OneHotCase(JumpBranchFn.Fn.BNE):
                m.d.comb += self.taken.eq(self.in1 != self.in2)
            with OneHotCase(JumpBranchFn.Fn.BLT):
                m.d.comb += self.taken.eq(self.in1.as_signed() < self.in2.as_signed())
            with OneHotCase(JumpBranchFn.Fn.BLTU):
                m.d.comb += self.taken.eq(self.in1.as_unsigned() < self.in2.as_unsigned())
            with OneHotCase(JumpBranchFn.Fn.BGE):
                m.d.comb += self.taken.eq(self.in1.as_signed() >= self.in2.as_signed())
            with OneHotCase(JumpBranchFn.Fn.BGEU):
                m.d.comb += self.taken.eq(self.in1.as_unsigned() >= self.in2.as_unsigned())

        return m


class JumpBranchFuncUnit(FuncUnitBase[JumpBranchFn]):
    """Functional unit executing jumps, conditional branches and AUIPC.

    Instructions are computed at issue by a `JumpBranch` core and buffered in
    a FIFO. A transaction then resolves them: it checks for mispredictions
    and misaligned targets, reports exceptions, publishes branch resolutions
    and pushes the result.
    """

    def __init__(self, gen_params: GenParams, fn=JumpBranchFn()):
        """
        Parameters
        ----------
        gen_params : GenParams
            Core configuration.
        fn : JumpBranchFn
            Decoder manager passed to `FuncUnitBase`. NOTE: the default
            instance is shared between calls.
        """
        super().__init__(gen_params, fn)

        # Resolved branches (from_pc, next_pc, misprediction). Its `read` method
        # is exported to other components via the BranchVerifyKey dependency.
        self.fifo_branch_resolved = FIFO(self.gen_params.get(JumpBranchLayouts).verify_branch, 2)

        self.dm = DependencyContext.get()
        self.dm.add_dependency(BranchVerifyKey(), self.fifo_branch_resolved.read)

        self.perf_misaligned = HwCounter(
            "backend.fu.jumpbranch.misaligned", "Number of instructions with misaligned target address"
        )
        self.perf_mispredictions = HwCounter("backend.fu.jumpbranch.mispredictions", "Number of branch mispredictions")

        # The ExceptionReportKey dependency is called to obtain the method used
        # below to report exceptions.
        self.exception_report = self.dm.get_dependency(ExceptionReportKey())()

    def elaborate(self, platform):
        m = super().elaborate(platform)

        m.submodules += [
            self.perf_misaligned,
            self.perf_mispredictions,
        ]

        # Request/response method pair for the predicted jump target. The
        # request is made at issue and the response consumed at resolution.
        jump_target_req, jump_target_resp = self.dm.get_dependency(PredictedJumpTargetKey())

        m.submodules.jb = jb = JumpBranch(self.gen_params, fn=self.fn)
        m.submodules.fifo_branch_resolved = self.fifo_branch_resolved

        # Per-instruction record carried from issue to resolution.
        fields = self.gen_params.get(CommonLayoutFields)
        instr_fifo_layout = make_layout(
            fields.rob_id,
            fields.pc,
            fields.rp_dst,
            ("type", JumpBranchFn.Fn),
            ("jmp_addr", self.gen_params.isa.xlen),
            ("reg_res", self.gen_params.isa.xlen),
            ("taken", 1),
            fields.predicted_taken,
            fields.tag,
        )
        m.submodules.instr_fifo = instr_fifo = BasicFifo(instr_fifo_layout, 2)

        # Resolution stage: one buffered instruction per transaction.
        with Transaction().body(m):
            instr = instr_fifo.read(m)
            target_prediction = jump_target_resp(m)

            # Actual next pc: the target if taken, otherwise the fall-through
            # address (reg_res).
            jump_result = Mux(instr.taken, instr.jmp_addr, instr.reg_res)
            is_auipc = instr.type == JumpBranchFn.Fn.AUIPC

            # Only JALR targets are checked against the predicted target. For
            # other instructions the address part of the prediction counts as correct.
            predicted_addr_correctly = (instr.type != JumpBranchFn.Fn.JALR) | (
                target_prediction.valid & (target_prediction.cfi_target == instr.jmp_addr)
            )

            # Mispredicted: not AUIPC, and either the JALR target or the taken
            # direction differs from the prediction.
            misprediction = Signal()
            m.d.av_comb += misprediction.eq(
                ~(is_auipc | (predicted_addr_correctly & (instr.taken == instr.predicted_taken)))
            )
            self.perf_mispredictions.incr(m, enable_call=misprediction)

            # Required alignment: 2 bytes with ZCA (compressed), 4 bytes otherwise.
            jmp_addr_misaligned = (
                instr.jmp_addr & (0b1 if Extension.ZCA in self.gen_params.isa.extensions else 0b11)
            ) != 0

            async_interrupt_active = self.dm.get_dependency(AsyncInterruptInsertSignalKey())

            exception = Signal()

            # At most one exception is reported per instruction, in priority
            # order: misaligned target, then async interrupt, then misprediction.
            with m.If(~is_auipc & instr.taken & jmp_addr_misaligned):
                self.perf_misaligned.incr(m)
                # Spec: "[...] if the target address is not four-byte aligned. This exception is reported on the branch
                # or jump instruction, not on the target instruction. No instruction-address-misaligned exception is
                # generated for a conditional branch that is not taken."
                m.d.comb += exception.eq(1)
                self.exception_report(
                    m,
                    rob_id=instr.rob_id,
                    cause=ExceptionCause.INSTRUCTION_ADDRESS_MISALIGNED,
                    pc=instr.pc,
                    mtval=instr.jmp_addr,
                )

            with m.Elif(async_interrupt_active & ~is_auipc):
                # Jump instructions are entry points for async interrupts.
                # This way we can store the known pc via a report to the global exception register and avoid storing it in the ROB.
                # Exceptions have priority, because the instruction that reports an async interrupt is committed
                # and the exception would be lost.
                m.d.comb += exception.eq(1)
                self.exception_report(
                    m, rob_id=instr.rob_id, cause=ExceptionCause._COREBLOCKS_ASYNC_INTERRUPT, pc=jump_result, mtval=0
                )
            with m.Elif(misprediction):
                # Async interrupts can have priority, because `jump_result` is handled in the same way.
                # No extra misprediction penalty will be introduced at interrupt return to the `jump_result` address.
                m.d.comb += exception.eq(1)
                self.exception_report(
                    m, rob_id=instr.rob_id, cause=ExceptionCause._COREBLOCKS_MISPREDICTION, pc=jump_result, mtval=0
                )

            # Publish the resolution of every control-flow instruction (AUIPC excluded).
            with m.If(~is_auipc):
                self.fifo_branch_resolved.write(m, from_pc=instr.pc, next_pc=jump_result, misprediction=misprediction)
                log.debug(
                    m,
                    True,
                    "branch resolved from 0x{:08x} to 0x{:08x}; misprediction: {}",
                    instr.pc,
                    jump_result,
                    misprediction,
                )

            self.push_result(
                m,
                rob_id=instr.rob_id,
                result=instr.reg_res,
                rp_dst=instr.rp_dst,
                exception=exception,
            )

        # Issue stage: drive the combinational JumpBranch core from the decoded
        # operands, request a jump-target prediction, and buffer the computed
        # values for the resolution transaction above.
        @def_method(m, self.issue_decoded)
        def _(arg):
            m.d.top_comb += jb.fn.eq(arg.decode_fn)
            m.d.top_comb += jb.in1.eq(arg.s1_val)
            m.d.top_comb += jb.in2.eq(arg.s2_val)
            m.d.top_comb += jb.in_pc.eq(arg.pc)
            m.d.top_comb += jb.in_imm.eq(arg.imm)

            # exec_fn.funct7 is reinterpreted using the JumpBranchLayouts.funct7_info
            # layout, which provides the `rvc` and `predicted_taken` flags used
            # here (presumably packed upstream by the decoder/frontend — verify).
            funct7_info = Signal(from_method_layout(self.gen_params.get(JumpBranchLayouts).funct7_info))
            m.d.top_comb += funct7_info.eq(arg.exec_fn.funct7)
            m.d.top_comb += jb.in_rvc.eq(funct7_info.rvc)

            jump_target_req(m)

            instr_fifo.write(
                m,
                rob_id=arg.rob_id,
                pc=arg.pc,
                rp_dst=arg.rp_dst,
                type=arg.decode_fn,
                jmp_addr=jb.jmp_addr,
                reg_res=jb.reg_res,
                taken=jb.taken,
                predicted_taken=funct7_info.predicted_taken,
                tag=arg.tag,
            )

        return m


@dataclass(frozen=True)
class JumpComponent(FunctionalComponentParams):
    """Component parameters that build a `JumpBranchFuncUnit` with `decoder_manager`."""

    decoder_manager: JumpBranchFn = JumpBranchFn()

    def get_module(self, gen_params: GenParams) -> FuncUnit:
        """Instantiate the jump/branch functional unit for `gen_params`."""
        return JumpBranchFuncUnit(gen_params, self.decoder_manager)