div_unit.py
# Division functional unit: DIV, DIVU, REM and REMU built on top of an
# unsigned `LongDivider`, with signs fixed up when results are retired.
from dataclasses import KW_ONLY, dataclass
from enum import IntFlag, auto
from collections.abc import Sequence

from amaranth import *
from amaranth.lib import data

from coreblocks.params.fu_params import FunctionalComponentParams
from coreblocks.params import GenParams
from coreblocks.arch import OpType, Funct3
from transactron import *
from transactron.core import def_method
from transactron.lib import *

from coreblocks.func_blocks.fu.common import DecoderManager, FuncUnitBase

from transactron.utils import OneHotSwitch
from coreblocks.func_blocks.interface.func_protocols import FuncUnit
from coreblocks.func_blocks.fu.division.long_division import LongDivider


class DivFn(DecoderManager):
    """Decoder manager for the functions handled by `DivUnit`."""

    class Fn(IntFlag):
        DIV = auto()
        DIVU = auto()
        REM = auto()
        REMU = auto()

    def get_instructions(self) -> Sequence[tuple]:
        """Map each `Fn` flag to the (OpType, Funct3) pair that selects it."""
        return [
            (self.Fn.DIV, OpType.DIV_REM, Funct3.DIV),
            (self.Fn.DIVU, OpType.DIV_REM, Funct3.DIVU),
            (self.Fn.REM, OpType.DIV_REM, Funct3.REM),
            (self.Fn.REMU, OpType.DIV_REM, Funct3.REMU),
        ]


def get_input(arg: data.View) -> tuple[Value, Value]:
    """Return the (first, second) operand pair of an issued instruction.

    The second operand is `imm` when it is non-zero, otherwise `s2_val`.
    """
    return arg.s1_val, Mux(arg.imm, arg.imm, arg.s2_val)


class DivUnit(FuncUnitBase[DivFn]):
    """Functional unit computing signed and unsigned quotients/remainders.

    Signed operations are reduced to unsigned ones: operand magnitudes are
    issued to `LongDivider` and the sign of the result is corrected when it
    is retired. Per-instruction metadata (ROB id, destination register and
    the sign/result-select flags) is queued in `params_fifo`, one entry per
    operation issued to the divider, and popped together with each divider
    result, so the two must always be flushed together.

    Parameters
    ----------
    gen_params : GenParams
        Core configuration.
    ipc : int
        Iterations per cycle, forwarded to `LongDivider`.
    fn : DivFn
        Decoder manager describing the supported functions.
    """

    def __init__(self, gen_params: GenParams, ipc: int = 4, fn=DivFn()):
        super().__init__(gen_params, fn)
        self.ipc = ipc

        # Flushes the divider and the queued per-instruction metadata.
        self.clear = Method()

    def elaborate(self, platform):
        m = super().elaborate(platform)

        # Metadata of in-flight divisions, consumed in issue order alongside the
        # divider's results. BasicFifo is used because it can be cleared, which
        # `self.clear` needs in order to drop stale entries.
        m.submodules.params_fifo = params_fifo = BasicFifo(
            [
                ("rob_id", self.gen_params.rob_entries_bits),
                ("rp_dst", self.gen_params.phys_regs_bits),
                ("flip_sign", 1),
                ("rem_res", 1),
            ],
            2,
        )
        m.submodules.divider = divider = LongDivider(self.gen_params, self.ipc)

        xlen = self.gen_params.isa.xlen
        sign_bit = xlen - 1  # position of sign bit

        @def_method(m, self.clear)
        def _():
            # Every issued operation puts one entry in the divider and one in
            # params_fifo. Clearing only the divider would leave stale metadata
            # that the next result would be paired with (wrong rob_id/rp_dst).
            # NOTE(review): assumes divider.clear discards in-flight operations.
            params_fifo.clear(m)
            divider.clear(m)

        @def_method(m, self.issue_decoded)
        def _(arg):
            i1, i2 = get_input(arg)

            flip_sign = Signal(1)  # if result is negative number
            rem_res = Signal(1)  # flag whether we want quotient or remainder

            dividend = Signal(xlen)
            divisor = Signal(xlen)

            def _abs(s: Value) -> Value:
                # Magnitude of `s` interpreted as a two's-complement number.
                return Mux(s.as_signed() < 0, -s, s)

            with OneHotSwitch(m, arg.decode_fn) as OneHotCase:
                with OneHotCase(DivFn.Fn.DIVU):
                    m.d.av_comb += flip_sign.eq(0)
                    m.d.av_comb += rem_res.eq(0)
                    m.d.av_comb += dividend.eq(i1)
                    m.d.av_comb += divisor.eq(i2)
                with OneHotCase(DivFn.Fn.DIV):
                    # quotient is negative if divisor and dividend have different signs
                    m.d.av_comb += flip_sign.eq(i1[sign_bit] ^ i2[sign_bit])
                    m.d.av_comb += rem_res.eq(0)
                    m.d.av_comb += dividend.eq(_abs(i1))
                    m.d.av_comb += divisor.eq(_abs(i2))
                with OneHotCase(DivFn.Fn.REMU):
                    m.d.av_comb += flip_sign.eq(0)
                    m.d.av_comb += rem_res.eq(1)
                    m.d.av_comb += dividend.eq(i1)
                    m.d.av_comb += divisor.eq(i2)
                with OneHotCase(DivFn.Fn.REM):
                    # sign of remainder is equal to sign of dividend
                    m.d.av_comb += flip_sign.eq(i1[sign_bit])
                    m.d.av_comb += rem_res.eq(1)
                    m.d.av_comb += dividend.eq(_abs(i1))
                    m.d.av_comb += divisor.eq(_abs(i2))

            params_fifo.write(m, rob_id=arg.rob_id, rp_dst=arg.rp_dst, flip_sign=flip_sign, rem_res=rem_res)

            divider.issue(m, dividend=dividend, divisor=divisor)

        with Transaction().body(m):
            response = divider.accept(m)
            params = params_fifo.read(m)
            result = Mux(params.rem_res, response.remainder, response.quotient)
            # Change sign only if it was requested and the sign is not already
            # correct. A value whose top bit is set (e.g. the magnitude 2**(xlen-1)
            # from dividing the most negative number by 1, or an all-ones quotient
            # on division by zero, if LongDivider follows the RISC-V convention)
            # already encodes the right signed result and must not be negated.
            negate = Mux(params.flip_sign, ~result[sign_bit], 0)
            sign_result = Mux(negate, -result, result)

            self.push_result(m, rob_id=params.rob_id, result=sign_result, rp_dst=params.rp_dst, exception=0)

        return m


@dataclass(frozen=True)
class DivComponent(FunctionalComponentParams):
    """Configuration parameters used to instantiate a `DivUnit`."""

    _: KW_ONLY
    result_fifo: bool = False  # last step is registered; NOTE(review): not read by get_module
    ipc: int = 3  # iterations per cycle
    decoder_manager: DivFn = DivFn()

    def get_module(self, gen_params: GenParams) -> FuncUnit:
        """Build a `DivUnit` configured with this component's parameters."""
        return DivUnit(gen_params, self.ipc, self.decoder_manager)