/ test / func_blocks / fu / test_jb_unit.py
test_jb_unit.py
  1  from amaranth import *
  2  from amaranth.lib.data import StructLayout
  3  from parameterized import parameterized_class
  4  
  5  from coreblocks.params import *
  6  from coreblocks.func_blocks.fu.jumpbranch import JumpBranchFuncUnit, JumpBranchFn
  7  from transactron import Method, def_method, TModule, Transaction
  8  
  9  from coreblocks.interface.layouts import FuncUnitLayouts, JumpBranchLayouts
 10  from coreblocks.func_blocks.interface.func_protocols import FuncUnit
 11  from coreblocks.arch import Funct3, OpType, ExceptionCause
 12  from coreblocks.interface.keys import PredictedJumpTargetKey
 13  
 14  from transactron.utils import signed_to_int, DependencyContext
 15  from transactron.lib import BasicFifo
 16  
 17  from test.func_blocks.fu.functional_common import ExecFn, FunctionalUnitTestCase
 18  
 19  
 20  class JumpBranchWrapper(FuncUnit, Elaboratable):
 21      def __init__(self, gen_params: GenParams, auipc_test: bool):
 22          self.gp = gen_params
 23          self.auipc_test = auipc_test
 24          layouts = gen_params.get(JumpBranchLayouts)
 25  
 26          self.target_pred_req = Method(i=layouts.predicted_jump_target_req)
 27          self.target_pred_resp = Method(o=layouts.predicted_jump_target_resp)
 28  
 29          DependencyContext.get().add_dependency(PredictedJumpTargetKey(), (self.target_pred_req, self.target_pred_resp))
 30  
 31          self.jb = JumpBranchFuncUnit(gen_params)
 32          self.issue = self.jb.issue
 33          self.push_result = Method(
 34              i=StructLayout(
 35                  gen_params.get(FuncUnitLayouts).push_result.members
 36                  | (gen_params.get(JumpBranchLayouts).verify_branch.members if not auipc_test else {})
 37              )
 38          )
 39  
 40      def elaborate(self, platform):
 41          m = TModule()
 42  
 43          m.submodules.jb_unit = self.jb
 44          m.submodules.res_fifo = res_fifo = BasicFifo(self.gp.get(FuncUnitLayouts).push_result, 2)
 45  
 46          self.jb.push_result.provide(res_fifo.write)
 47  
 48          @def_method(m, self.target_pred_req)
 49          def _():
 50              pass
 51  
 52          @def_method(m, self.target_pred_resp)
 53          def _(arg):
 54              return {"valid": 0, "cfi_target": 0}
 55  
 56          with Transaction().body(m):
 57              res = res_fifo.read(m)
 58              ret = {
 59                  "result": res.result,
 60                  "rob_id": res.rob_id,
 61                  "rp_dst": res.rp_dst,
 62                  "exception": res.exception,
 63              }
 64              if not self.auipc_test:
 65                  verify = self.jb.fifo_branch_resolved.read(m)
 66                  ret = ret | {
 67                      "next_pc": verify.next_pc,
 68                      "from_pc": verify.from_pc,
 69                      "misprediction": verify.misprediction,
 70                  }
 71  
 72              self.push_result(m, ret)
 73  
 74          return m
 75  
 76  
 77  class JumpBranchWrapperComponent(FunctionalComponentParams):
 78      def __init__(self, auipc_test: bool):
 79          self.auipc_test = auipc_test
 80  
 81      def get_module(self, gen_params: GenParams) -> FuncUnit:
 82          return JumpBranchWrapper(gen_params, self.auipc_test)
 83  
 84      def get_optypes(self) -> set[OpType]:
 85          return JumpBranchFn().get_op_types()
 86  
 87  
 88  @staticmethod
 89  def compute_result(i1: int, i2: int, i_imm: int, pc: int, fn: JumpBranchFn.Fn, xlen: int) -> dict[str, int]:
 90      max_int = 2**xlen - 1
 91      branch_target = pc + signed_to_int(i_imm, xlen)
 92      next_pc = 0
 93      res = pc + 4
 94  
 95      match fn:
 96          case JumpBranchFn.Fn.JAL:
 97              next_pc = pc + signed_to_int(i_imm, xlen)
 98          case JumpBranchFn.Fn.JALR:
 99              next_pc = (i1 + signed_to_int(i_imm, xlen)) & ~0x1
100          case JumpBranchFn.Fn.BEQ:
101              next_pc = branch_target if i1 == i2 else pc + 4
102          case JumpBranchFn.Fn.BNE:
103              next_pc = branch_target if i1 != i2 else pc + 4
104          case JumpBranchFn.Fn.BLT:
105              next_pc = branch_target if signed_to_int(i1, xlen) < signed_to_int(i2, xlen) else pc + 4
106          case JumpBranchFn.Fn.BLTU:
107              next_pc = branch_target if i1 < i2 else pc + 4
108          case JumpBranchFn.Fn.BGE:
109              next_pc = branch_target if signed_to_int(i1, xlen) >= signed_to_int(i2, xlen) else pc + 4
110          case JumpBranchFn.Fn.BGEU:
111              next_pc = branch_target if i1 >= i2 else pc + 4
112  
113      next_pc &= max_int
114      res &= max_int
115  
116      misprediction = next_pc != pc + 4
117  
118      exception = None
119      exception_pc = pc
120      mtval = 0
121      if next_pc & 0b11 != 0:
122          exception = ExceptionCause.INSTRUCTION_ADDRESS_MISALIGNED
123          mtval = next_pc
124      elif misprediction:
125          exception = ExceptionCause._COREBLOCKS_MISPREDICTION
126          exception_pc = next_pc
127  
128      return {"result": res, "from_pc": pc, "next_pc": next_pc, "misprediction": misprediction} | (
129          {"exception": exception, "exception_pc": exception_pc, "mtval": mtval} if exception is not None else {}
130      )
131  
132  
@staticmethod
def compute_result_auipc(i1: int, i2: int, i_imm: int, pc: int, fn: JumpBranchFn.Fn, xlen: int) -> dict[str, int]:
    """Model the expected ``result`` of the jump/branch unit for the AUIPC test.

    For ``AUIPC`` the result is ``pc + i_imm``; for any other function it is
    ``pc + 4``. Either sum is truncated to ``xlen`` bits. ``i1`` and ``i2`` are
    not used.
    """
    addend = i_imm if fn == JumpBranchFn.Fn.AUIPC else 4
    return {"result": (pc + addend) & (2**xlen - 1)}
144  
145  
# Functions exercised by the "branches_and_jumps" test case, each mapped to the
# ExecFn used for it (op type and, for conditional branches, funct3).
# All conditional branches share OpType.BRANCH and differ only in funct3.
ops = {
    fn: ExecFn(OpType.BRANCH, funct3)
    for fn, funct3 in (
        (JumpBranchFn.Fn.BEQ, Funct3.BEQ),
        (JumpBranchFn.Fn.BNE, Funct3.BNE),
        (JumpBranchFn.Fn.BLT, Funct3.BLT),
        (JumpBranchFn.Fn.BLTU, Funct3.BLTU),
        (JumpBranchFn.Fn.BGE, Funct3.BGE),
        (JumpBranchFn.Fn.BGEU, Funct3.BGEU),
    )
} | {
    JumpBranchFn.Fn.JAL: ExecFn(OpType.JAL),
    JumpBranchFn.Fn.JALR: ExecFn(OpType.JALR),
}

# The "auipc" test case exercises only AUIPC.
ops_auipc = {JumpBranchFn.Fn.AUIPC: ExecFn(OpType.AUIPC)}
160  
161  
@parameterized_class(
    [
        {
            "name": "branches_and_jumps",
            "ops": ops,
            "func_unit": JumpBranchWrapperComponent(auipc_test=False),
            "compute_result": compute_result,
        },
        {
            "name": "auipc",
            "ops": ops_auipc,
            "func_unit": JumpBranchWrapperComponent(auipc_test=True),
            "compute_result": compute_result_auipc,
        },
    ]
)
class TestJumpBranchUnit(FunctionalUnitTestCase[JumpBranchFn.Fn]):
    """Functional-unit test for the jump/branch unit.

    Instantiated twice by ``parameterized_class``: once for branches and jumps,
    and once for AUIPC. Each instance gets its own op table, wrapper component
    and reference model.
    """

    compute_result = compute_result
    zero_imm = False

    def test_fu(self):
        """Run the standard test from ``FunctionalUnitTestCase``."""
        self.run_standard_fu_test()