# test_jb_unit.py
1 from amaranth import * 2 from amaranth.lib.data import StructLayout 3 from parameterized import parameterized_class 4 5 from coreblocks.params import * 6 from coreblocks.func_blocks.fu.jumpbranch import JumpBranchFuncUnit, JumpBranchFn 7 from transactron import Method, def_method, TModule, Transaction 8 9 from coreblocks.interface.layouts import FuncUnitLayouts, JumpBranchLayouts 10 from coreblocks.func_blocks.interface.func_protocols import FuncUnit 11 from coreblocks.arch import Funct3, OpType, ExceptionCause 12 from coreblocks.interface.keys import PredictedJumpTargetKey 13 14 from transactron.utils import signed_to_int, DependencyContext 15 from transactron.lib import BasicFifo 16 17 from test.func_blocks.fu.functional_common import ExecFn, FunctionalUnitTestCase 18 19 20 class JumpBranchWrapper(FuncUnit, Elaboratable): 21 def __init__(self, gen_params: GenParams, auipc_test: bool): 22 self.gp = gen_params 23 self.auipc_test = auipc_test 24 layouts = gen_params.get(JumpBranchLayouts) 25 26 self.target_pred_req = Method(i=layouts.predicted_jump_target_req) 27 self.target_pred_resp = Method(o=layouts.predicted_jump_target_resp) 28 29 DependencyContext.get().add_dependency(PredictedJumpTargetKey(), (self.target_pred_req, self.target_pred_resp)) 30 31 self.jb = JumpBranchFuncUnit(gen_params) 32 self.issue = self.jb.issue 33 self.push_result = Method( 34 i=StructLayout( 35 gen_params.get(FuncUnitLayouts).push_result.members 36 | (gen_params.get(JumpBranchLayouts).verify_branch.members if not auipc_test else {}) 37 ) 38 ) 39 40 def elaborate(self, platform): 41 m = TModule() 42 43 m.submodules.jb_unit = self.jb 44 m.submodules.res_fifo = res_fifo = BasicFifo(self.gp.get(FuncUnitLayouts).push_result, 2) 45 46 self.jb.push_result.provide(res_fifo.write) 47 48 @def_method(m, self.target_pred_req) 49 def _(): 50 pass 51 52 @def_method(m, self.target_pred_resp) 53 def _(arg): 54 return {"valid": 0, "cfi_target": 0} 55 56 with Transaction().body(m): 57 res = 
res_fifo.read(m) 58 ret = { 59 "result": res.result, 60 "rob_id": res.rob_id, 61 "rp_dst": res.rp_dst, 62 "exception": res.exception, 63 } 64 if not self.auipc_test: 65 verify = self.jb.fifo_branch_resolved.read(m) 66 ret = ret | { 67 "next_pc": verify.next_pc, 68 "from_pc": verify.from_pc, 69 "misprediction": verify.misprediction, 70 } 71 72 self.push_result(m, ret) 73 74 return m 75 76 77 class JumpBranchWrapperComponent(FunctionalComponentParams): 78 def __init__(self, auipc_test: bool): 79 self.auipc_test = auipc_test 80 81 def get_module(self, gen_params: GenParams) -> FuncUnit: 82 return JumpBranchWrapper(gen_params, self.auipc_test) 83 84 def get_optypes(self) -> set[OpType]: 85 return JumpBranchFn().get_op_types() 86 87 88 @staticmethod 89 def compute_result(i1: int, i2: int, i_imm: int, pc: int, fn: JumpBranchFn.Fn, xlen: int) -> dict[str, int]: 90 max_int = 2**xlen - 1 91 branch_target = pc + signed_to_int(i_imm, xlen) 92 next_pc = 0 93 res = pc + 4 94 95 match fn: 96 case JumpBranchFn.Fn.JAL: 97 next_pc = pc + signed_to_int(i_imm, xlen) 98 case JumpBranchFn.Fn.JALR: 99 next_pc = (i1 + signed_to_int(i_imm, xlen)) & ~0x1 100 case JumpBranchFn.Fn.BEQ: 101 next_pc = branch_target if i1 == i2 else pc + 4 102 case JumpBranchFn.Fn.BNE: 103 next_pc = branch_target if i1 != i2 else pc + 4 104 case JumpBranchFn.Fn.BLT: 105 next_pc = branch_target if signed_to_int(i1, xlen) < signed_to_int(i2, xlen) else pc + 4 106 case JumpBranchFn.Fn.BLTU: 107 next_pc = branch_target if i1 < i2 else pc + 4 108 case JumpBranchFn.Fn.BGE: 109 next_pc = branch_target if signed_to_int(i1, xlen) >= signed_to_int(i2, xlen) else pc + 4 110 case JumpBranchFn.Fn.BGEU: 111 next_pc = branch_target if i1 >= i2 else pc + 4 112 113 next_pc &= max_int 114 res &= max_int 115 116 misprediction = next_pc != pc + 4 117 118 exception = None 119 exception_pc = pc 120 mtval = 0 121 if next_pc & 0b11 != 0: 122 exception = ExceptionCause.INSTRUCTION_ADDRESS_MISALIGNED 123 mtval = next_pc 124 elif 
misprediction: 125 exception = ExceptionCause._COREBLOCKS_MISPREDICTION 126 exception_pc = next_pc 127 128 return {"result": res, "from_pc": pc, "next_pc": next_pc, "misprediction": misprediction} | ( 129 {"exception": exception, "exception_pc": exception_pc, "mtval": mtval} if exception is not None else {} 130 ) 131 132 133 @staticmethod 134 def compute_result_auipc(i1: int, i2: int, i_imm: int, pc: int, fn: JumpBranchFn.Fn, xlen: int) -> dict[str, int]: 135 max_int = 2**xlen - 1 136 res = pc + 4 137 138 if fn == JumpBranchFn.Fn.AUIPC: 139 res = pc + i_imm 140 141 res &= max_int 142 143 return {"result": res} 144 145 146 ops = { 147 JumpBranchFn.Fn.BEQ: ExecFn(OpType.BRANCH, Funct3.BEQ), 148 JumpBranchFn.Fn.BNE: ExecFn(OpType.BRANCH, Funct3.BNE), 149 JumpBranchFn.Fn.BLT: ExecFn(OpType.BRANCH, Funct3.BLT), 150 JumpBranchFn.Fn.BLTU: ExecFn(OpType.BRANCH, Funct3.BLTU), 151 JumpBranchFn.Fn.BGE: ExecFn(OpType.BRANCH, Funct3.BGE), 152 JumpBranchFn.Fn.BGEU: ExecFn(OpType.BRANCH, Funct3.BGEU), 153 JumpBranchFn.Fn.JAL: ExecFn(OpType.JAL), 154 JumpBranchFn.Fn.JALR: ExecFn(OpType.JALR), 155 } 156 157 ops_auipc = { 158 JumpBranchFn.Fn.AUIPC: ExecFn(OpType.AUIPC), 159 } 160 161 162 @parameterized_class( 163 ("name", "ops", "func_unit", "compute_result"), 164 [ 165 ( 166 "branches_and_jumps", 167 ops, 168 JumpBranchWrapperComponent(auipc_test=False), 169 compute_result, 170 ), 171 ( 172 "auipc", 173 ops_auipc, 174 JumpBranchWrapperComponent(auipc_test=True), 175 compute_result_auipc, 176 ), 177 ], 178 ) 179 class TestJumpBranchUnit(FunctionalUnitTestCase[JumpBranchFn.Fn]): 180 compute_result = compute_result 181 zero_imm = False 182 183 def test_fu(self): 184 self.run_standard_fu_test()