/ coreblocks / func_blocks / fu / div_unit.py
div_unit.py
  1  from dataclasses import KW_ONLY, dataclass
  2  from enum import IntFlag, auto
  3  from collections.abc import Sequence
  4  
  5  from amaranth import *
  6  from amaranth.lib import data
  7  
  8  from coreblocks.params.fu_params import FunctionalComponentParams
  9  from coreblocks.params import GenParams
 10  from coreblocks.arch import OpType, Funct3
 11  from transactron import *
 12  from transactron.core import def_method
 13  from transactron.lib import *
 14  
 15  from coreblocks.func_blocks.fu.common import DecoderManager, FuncUnitBase
 16  
 17  from transactron.utils import OneHotSwitch
 18  from coreblocks.func_blocks.interface.func_protocols import FuncUnit
 19  from coreblocks.func_blocks.fu.division.long_division import LongDivider
 20  
 21  
 22  class DivFn(DecoderManager):
 23      class Fn(IntFlag):
 24          DIV = auto()
 25          DIVU = auto()
 26          REM = auto()
 27          REMU = auto()
 28  
 29      def get_instructions(self) -> Sequence[tuple]:
 30          return [
 31              (self.Fn.DIV, OpType.DIV_REM, Funct3.DIV),
 32              (self.Fn.DIVU, OpType.DIV_REM, Funct3.DIVU),
 33              (self.Fn.REM, OpType.DIV_REM, Funct3.REM),
 34              (self.Fn.REMU, OpType.DIV_REM, Funct3.REMU),
 35          ]
 36  
 37  
 38  def get_input(arg: data.View) -> tuple[Value, Value]:
 39      return arg.s1_val, Mux(arg.imm, arg.imm, arg.s2_val)
 40  
 41  
 42  class DivUnit(FuncUnitBase[DivFn]):
 43      def __init__(self, gen_params: GenParams, ipc: int = 4, fn=DivFn()):
 44          super().__init__(gen_params, fn)
 45          self.ipc = ipc
 46  
 47          self.clear = Method()
 48  
 49      def elaborate(self, platform):
 50          m = super().elaborate(platform)
 51  
 52          m.submodules.params_fifo = params_fifo = FIFO(
 53              [
 54                  ("rob_id", self.gen_params.rob_entries_bits),
 55                  ("rp_dst", self.gen_params.phys_regs_bits),
 56                  ("flip_sign", 1),
 57                  ("rem_res", 1),
 58              ],
 59              2,
 60          )
 61          m.submodules.divider = divider = LongDivider(self.gen_params, self.ipc)
 62  
 63          xlen = self.gen_params.isa.xlen
 64          sign_bit = xlen - 1  # position of sign bit
 65  
 66          @def_method(m, self.clear)
 67          def _():
 68              divider.clear(m)
 69  
 70          @def_method(m, self.issue_decoded)
 71          def _(arg):
 72              i1, i2 = get_input(arg)
 73  
 74              flip_sign = Signal(1)  # if result is negative number
 75              rem_res = Signal(1)  # flag whether we want quotient or remainder
 76  
 77              dividend = Signal(xlen)
 78              divisor = Signal(xlen)
 79  
 80              def _abs(s: Value) -> Value:
 81                  return Mux(s.as_signed() < 0, -s, s)
 82  
 83              with OneHotSwitch(m, arg.decode_fn) as OneHotCase:
 84                  with OneHotCase(DivFn.Fn.DIVU):
 85                      m.d.av_comb += flip_sign.eq(0)
 86                      m.d.av_comb += rem_res.eq(0)
 87                      m.d.av_comb += dividend.eq(i1)
 88                      m.d.av_comb += divisor.eq(i2)
 89                  with OneHotCase(DivFn.Fn.DIV):
 90                      # quotient is negative if divisor and dividend have different signs
 91                      m.d.av_comb += flip_sign.eq(i1[sign_bit] ^ i2[sign_bit])
 92                      m.d.av_comb += rem_res.eq(0)
 93                      m.d.av_comb += dividend.eq(_abs(i1))
 94                      m.d.av_comb += divisor.eq(_abs(i2))
 95                  with OneHotCase(DivFn.Fn.REMU):
 96                      m.d.av_comb += flip_sign.eq(0)
 97                      m.d.av_comb += rem_res.eq(1)
 98                      m.d.av_comb += dividend.eq(i1)
 99                      m.d.av_comb += divisor.eq(i2)
100                  with OneHotCase(DivFn.Fn.REM):
101                      # sign of remainder is equal to sign of dividend
102                      m.d.av_comb += flip_sign.eq(i1[sign_bit])
103                      m.d.av_comb += rem_res.eq(1)
104                      m.d.av_comb += dividend.eq(_abs(i1))
105                      m.d.av_comb += divisor.eq(_abs(i2))
106  
107              params_fifo.write(m, rob_id=arg.rob_id, rp_dst=arg.rp_dst, flip_sign=flip_sign, rem_res=rem_res)
108  
109              divider.issue(m, dividend=dividend, divisor=divisor)
110  
111          with Transaction().body(m):
112              response = divider.accept(m)
113              params = params_fifo.read(m)
114              result = Mux(params.rem_res, response.remainder, response.quotient)
115              # change sign but only if it was requested and sign is not correct
116              flip_sig = Mux(params.flip_sign, ~result[sign_bit], 0)
117              sign_result = Mux(flip_sig, -result, result)
118  
119              self.push_result(m, rob_id=params.rob_id, result=sign_result, rp_dst=params.rp_dst, exception=0)
120  
121          return m
122  
123  
@dataclass(frozen=True)
class DivComponent(FunctionalComponentParams):
    """Immutable, keyword-only configuration from which a ``DivUnit`` is built."""

    _: KW_ONLY
    # last step is registered
    # NOTE(review): not read by get_module in this file - presumably consumed elsewhere; confirm
    result_fifo: bool = False
    # iterations per cycle
    ipc: int = 3
    # decoder manager handed to the created DivUnit
    decoder_manager: DivFn = DivFn()

    def get_module(self, gen_params: GenParams) -> FuncUnit:
        """Create a ``DivUnit`` configured with this component's ``ipc`` and decoder manager."""
        unit = DivUnit(gen_params, ipc=self.ipc, fn=self.decoder_manager)
        return unit