fpu_mul.py
from amaranth import *
from transactron import TModule, Method, def_method
from transactron.utils.transactron_helpers import from_method_layout
from coreblocks.func_blocks.fu.fpu.fpu_common import (
    RoundingModes,
    FPUParams,
    create_data_input_layout,
    create_output_layout,
    FPUCommonValues,
)
from coreblocks.func_blocks.fu.fpu.fpu_error_module import FPUErrorModule
from coreblocks.func_blocks.fu.fpu.fpu_rounding_module import FPURounding
from transactron.utils.amaranth_ext import count_leading_zeros
from coreblocks.func_blocks.fu.unsigned_multiplication.fast_recursive import FastRecursiveMul


class FPUMulMethodLayout:
    """FPU multiplication module method layout

    Parameters
    ----------
    fpu_params: FPUParams
        FPU parameters
    """

    def __init__(self, *, fpu_params: FPUParams):
        self.mul_in_layout = [
            ("op_1", create_data_input_layout(fpu_params)),
            ("op_2", create_data_input_layout(fpu_params)),
            ("rounding_mode", RoundingModes),
        ]
        """
        | Input layout for multiplication
        | op_1 - layout containing data of the first operand
        | op_2 - layout containing data of the second operand
        | rounding_mode - selected rounding mode
        | op_1 and op_2 are created using
        :meth:`create_data_input_layout <coreblocks.func_blocks.fu.fpu.fpu_common.create_data_input_layout>`
        """
        self.mul_out_layout = create_output_layout(fpu_params)
        """
        Output layout for multiplication. Created using
        :meth:`create_output_layout <coreblocks.func_blocks.fu.fpu.fpu_common.create_output_layout>`
        """


class FPUMulModule(Elaboratable):
    """
    | FPU multiplication top module
    | The floating point multiplication consists of two parts:
    | 1. Exponent calculation - turning both exponents from biased form into unbiased form,
    |    adding them together and turning the result back into biased form
    | 2. Significand multiplication - This is essentially fixed-point multiplication with
    |    two bits for the integer part and 2*n - 2 bits for the fractional part.
    | Subnormal operands are handled by extending the internal exponent range and
    | normalising them before the multiplication.

    Parameters
    ----------
    fpu_params: FPUParams
        FPU parameters

    Attributes
    ----------
    mul_request: Method
        Transactional method for initiating multiplication.
        Takes
        :meth:`mul_in_layout <coreblocks.func_blocks.fu.fpu.fpu_mul.FPUMulMethodLayout.mul_in_layout>`
        as argument.
        Returns result as
        :meth:`mul_out_layout <coreblocks.func_blocks.fu.fpu.fpu_mul.FPUMulMethodLayout.mul_out_layout>`.
    """

    def __init__(self, *, fpu_params: FPUParams):
        self.fpu_params = fpu_params
        self.method_layouts = FPUMulMethodLayout(fpu_params=self.fpu_params)
        self.common_values = FPUCommonValues(self.fpu_params)
        self.mul_request = Method(
            i=self.method_layouts.mul_in_layout,
            o=self.method_layouts.mul_out_layout,
        )
        self.mul_params = {"isa": {"xlen": self.fpu_params.sig_width}}

    def elaborate(self, platform):
        m = TModule()

        m.submodules.rounding_module = rounding_module = FPURounding(fpu_params=self.fpu_params)
        m.submodules.exception_module = exception_module = FPUErrorModule(fpu_params=self.fpu_params)
        m.submodules.multiplier = multiplier = FastRecursiveMul(
            self.fpu_params.sig_width, self.fpu_params.sig_width // 2
        )

        rounding_response = Signal(from_method_layout(rounding_module.method_layouts.rounding_out_layout))
        exception_response = Signal(from_method_layout(exception_module.method_layouts.error_out_layout))

        sig_width = self.fpu_params.sig_width
        bias = self.common_values.bias
        max_exp = self.common_values.max_exp
        min_real_exp = Const(1 - bias)
        max_real_exp = Const(max_exp - bias)

        # Range of the unbiased exponent before clamping / denormalisation.
        # Lower bound: both operands have encoded exponent 0 (real exponent 1 - bias)
        # and each needs a normalisation shift of up to sig_width.
        # Upper bound: both operands have the largest encoded exponent and the
        # significand product lies in [2, 4).
        # The exponent signals must cover this whole range. A narrower signal
        # (e.g. exp_width + 1 bits) wraps very small products around to large
        # positive exponents, turning an underflow into an overflow to infinity.
        min_internal_exp = 2 * (1 - bias) - 2 * sig_width
        max_internal_exp = 2 * (max_exp - bias) + 1
        internal_exp_range = range(min_internal_exp, max_internal_exp + 1)
        # Largest right shift needed to denormalise the smallest possible product.
        max_denorm_shift = (1 - bias) - min_internal_exp

        @def_method(m, self.mul_request)
        def _(op_1, op_2, rounding_mode):
            # The sign of a product is the XOR of the operand signs.
            final_sign = Signal()
            m.d.av_comb += final_sign.eq(op_1.sign ^ op_2.sign)
            # An operand is treated as subnormal when the explicit leading bit
            # of its significand is 0.
            op_1_subn = ~op_1.sig[-1]
            op_2_subn = ~op_2.sig[-1]

            # Exponent before potential operand normalization
            pre_op_norm_exp = Signal(internal_exp_range)
            sum_of_exp = op_1.exp + op_2.exp - 2 * bias
            # Because exp = 0 and exp = 1 represent the same exponent emin,
            # to properly calculate pre_op_norm_exp we have to adjust those exponents
            # in case they are 0 and represent a subnormal number.
            subn_correction = op_1_subn + op_2_subn
            m.d.av_comb += pre_op_norm_exp.eq(sum_of_exp + subn_correction)

            # Subnormal operands are normalised (shifted left until the leading bit
            # is set). The shifts are subtracted from the exponent below.
            op_1_norm_shift = Signal(range(0, sig_width + 1))
            op_2_norm_shift = Signal(range(0, sig_width + 1))
            m.d.av_comb += op_1_norm_shift.eq(count_leading_zeros(op_1.sig))
            m.d.av_comb += op_2_norm_shift.eq(count_leading_zeros(op_2.sig))

            norm_op_1_sig = Signal(sig_width)
            norm_op_2_sig = Signal(sig_width)
            m.d.av_comb += norm_op_1_sig.eq(op_1.sig << op_1_norm_shift)
            m.d.av_comb += norm_op_2_sig.eq(op_2.sig << op_2_norm_shift)

            sig_product = Signal(2 * sig_width)
            m.d.av_comb += multiplier.i1.eq(norm_op_1_sig)
            m.d.av_comb += multiplier.i2.eq(norm_op_2_sig)
            m.d.av_comb += sig_product.eq(multiplier.r)

            # First step of normalization:
            # if sig is in [1, 2) leave it alone,
            # if sig is in [2, 4) shift it right by one and keep the lost bit for sticky.
            shifted_out_bit = Signal()
            fixed_sig_product_norm = Signal(2 * sig_width - 1)
            m.d.av_comb += fixed_sig_product_norm.eq(sig_product)
            with m.If(sig_product[-1]):
                m.d.av_comb += fixed_sig_product_norm.eq(sig_product >> 1)
                m.d.av_comb += shifted_out_bit.eq(sig_product[0])

            post_multiplication_exp = Signal(internal_exp_range)
            m.d.av_comb += post_multiplication_exp.eq(
                pre_op_norm_exp - (op_1_norm_shift + op_2_norm_shift) + sig_product[-1]
            )

            sticky_bit = Signal()
            round_bit = Signal()
            mult_exp = Signal(self.fpu_params.exp_width)
            mult_sig = Signal(sig_width + 1)
            # One additional bit for round bit
            normalised_ext_sig = Signal(sig_width + 1)
            # The product has 2*(sig_width - 1) bits of fractional part and (after the
            # first normalization step) 1 bit of integer part. To turn it into a
            # floating point significand (1 integer bit and sig_width - 1 fractional
            # bits) it must be shifted right by sig_width - 1 bits. One additional bit
            # is kept as the round bit, so the shift is sig_width - 2.
            m.d.av_comb += normalised_ext_sig.eq(fixed_sig_product_norm >> (sig_width - 2))

            any_shifted_out = Signal()
            with m.If(post_multiplication_exp >= min_real_exp):
                # Normal range; exponents at or above max_real_exp are clamped to max_exp.
                m.d.av_comb += mult_exp.eq(
                    Mux(
                        post_multiplication_exp >= max_real_exp,
                        max_exp,
                        post_multiplication_exp + bias,
                    )
                )
                m.d.av_comb += mult_sig.eq(normalised_ext_sig)
                lost_bits = fixed_sig_product_norm.bit_select(0, sig_width - 2).any()
                m.d.av_comb += any_shifted_out.eq(lost_bits)
                with m.If(mult_sig[-1] == 0):
                    m.d.av_comb += mult_exp.eq(0)
            with m.Else():
                # Below emin: in this case the value will always be subnormal.
                m.d.av_comb += mult_exp.eq(0)
                shift_needed = Signal(range(max_denorm_shift + 1))
                m.d.av_comb += shift_needed.eq(min_real_exp - post_multiplication_exp)
                m.d.av_comb += mult_sig.eq(normalised_ext_sig >> shift_needed)
                with m.If(shift_needed > sig_width):
                    # Everything is shifted out of the significand and round bit.
                    m.d.av_comb += any_shifted_out.eq(fixed_sig_product_norm.any())
                with m.Else():
                    # The product has (2*p) - 1 bits: the p most significant bits
                    # represent the fp number, bit p+1 is the round bit and the p - 2
                    # least significant bits form the initial sticky bit.
                    # For the sticky bit we have to catch those p - 2 ls bits and the
                    # shift_needed bits shifted out of the p + 1 ms bits.
                    padding = Const(0, sig_width)
                    shifted_out = Cat(padding, fixed_sig_product_norm).bit_select(
                        shift_needed, 2 * sig_width - 2
                    )
                    m.d.av_comb += any_shifted_out.eq(shifted_out.any())
            m.d.av_comb += sticky_bit.eq(any_shifted_out | shifted_out_bit)
            m.d.av_comb += round_bit.eq(mult_sig[0])
            resp = rounding_module.rounding_request(
                m,
                sign=final_sign,
                sig=mult_sig >> 1,
                exp=mult_exp,
                round_bit=round_bit,
                sticky_bit=sticky_bit,
                rounding_mode=rounding_mode,
            )
            m.d.av_comb += rounding_response.eq(resp)

            is_inf = Signal()
            m.d.av_comb += is_inf.eq(op_1.is_inf | op_2.is_inf)
            # inf * 0 is an invalid operation producing NaN
            bad_inf = Signal()
            m.d.av_comb += bad_inf.eq((op_1.is_inf & op_2.is_zero) | (op_2.is_inf & op_1.is_zero))
            is_zero = Signal()
            m.d.av_comb += is_zero.eq(op_1.is_zero | op_2.is_zero)
            is_nan = Signal()
            m.d.av_comb += is_nan.eq(op_1.is_nan | op_2.is_nan | bad_inf)

            exc_sig = Signal(sig_width)
            exc_exp = Signal(self.fpu_params.exp_width)
            exc_sign = Signal()
            inexact = Signal()
            invalid_operation = Signal()
            with m.If(is_nan | is_inf | is_zero):
                m.d.av_comb += inexact.eq(0)
                with m.If(is_nan):
                    # The quiet bit sits right below the explicit leading bit; a
                    # NaN with it cleared is signalling.
                    is_any_snan = ((~op_1.sig[-2]) & op_1.is_nan) | ((~op_2.sig[-2]) & op_2.is_nan)
                    m.d.av_comb += exc_sign.eq(0)
                    m.d.av_comb += exc_exp.eq(max_exp)
                    m.d.av_comb += exc_sig.eq(self.common_values.canonical_nan_sig)
                    m.d.av_comb += invalid_operation.eq(is_any_snan | bad_inf)
                with m.Elif(is_inf & ~bad_inf):
                    # inf * finite is an infinity whose sign is the XOR of the
                    # operand signs, not the sign of the infinite operand.
                    m.d.av_comb += exc_sign.eq(final_sign)
                    m.d.av_comb += exc_exp.eq(Mux(op_1.is_inf, op_1.exp, op_2.exp))
                    m.d.av_comb += exc_sig.eq(Mux(op_1.is_inf, op_1.sig, op_2.sig))
                with m.Elif(is_zero):
                    m.d.av_comb += exc_sign.eq(final_sign)
                    m.d.av_comb += exc_exp.eq(0)
                    m.d.av_comb += exc_sig.eq(0)
            with m.Else():
                m.d.av_comb += exc_sign.eq(final_sign)
                m.d.av_comb += exc_exp.eq(rounding_response["exp"])
                m.d.av_comb += inexact.eq(rounding_response["inexact"])
                with m.If(rounding_response["exp"] == max_exp):
                    # Overflow: only the leading bit set (zero fraction) at max_exp.
                    m.d.av_comb += exc_sig.eq(2 ** (sig_width - 1))
                with m.Else():
                    m.d.av_comb += exc_sig.eq(rounding_response["sig"])

            resp = exception_module.error_checking_request(
                m,
                sign=exc_sign,
                sig=exc_sig,
                exp=exc_exp,
                rounding_mode=rounding_mode,
                inexact=inexact,
                invalid_operation=invalid_operation,
                division_by_zero=0,
                input_inf=is_inf,
            )
            m.d.av_comb += exception_response.eq(resp)

            return exception_response

        return m