fpu_add_sub.py
from amaranth import *
from transactron import TModule, Method, def_method
from transactron.utils import assign
from transactron.utils.transactron_helpers import from_method_layout
from coreblocks.func_blocks.fu.fpu.fpu_common import (
    RoundingModes,
    FPUParams,
    create_data_input_layout,
    create_output_layout,
    create_raw_float_layout,
    FPUCommonValues,
)
from coreblocks.func_blocks.fu.fpu.far_path import FarPathModule
from coreblocks.func_blocks.fu.fpu.close_path import ClosePathModule
from coreblocks.func_blocks.fu.fpu.fpu_error_module import FPUErrorModule


class FPUAddSubMethodLayout:
    """FPU addition/subtraction top module method layout

    Parameters
    ----------
    fpu_params: FPUParams
        FPU parameters
    """

    def __init__(self, *, fpu_params: FPUParams):
        self.add_sub_in_layout = [
            ("op_1", create_data_input_layout(fpu_params)),
            ("op_2", create_data_input_layout(fpu_params)),
            ("rounding_mode", RoundingModes),
            ("operation", 1),
        ]
        """
        | Input layout for addition/subtraction
        | op_1 - layout containing data of the first operand
        | op_2 - layout containing data of the second operand
        | rounding_mode - selected rounding mode
        | operation - selected operation; 1 - subtraction, 0 - addition
        | op_1 and op_2 are created using
          :meth:`create_data_input_layout <coreblocks.func_blocks.fu.fpu.fpu_common.create_data_input_layout>`
        """
        self.add_sub_out_layout = create_output_layout(fpu_params)
        """
        Output layout for addition/subtraction. Created using
        :meth:`create_output_layout <coreblocks.func_blocks.fu.fpu.fpu_common.create_output_layout>`
        """
        self.raw_float_layout = create_raw_float_layout(fpu_params)
        """
        Output layout for raw float.
        Created using
        :meth:`create_raw_float_layout <coreblocks.func_blocks.fu.fpu.fpu_common.create_raw_float_layout>`
        """
        # Same exponent width, significand widened by two bits.
        ext_params = FPUParams(sig_width=fpu_params.sig_width + 2, exp_width=fpu_params.exp_width)
        self.ext_float_layout = create_raw_float_layout(ext_params)
        """
        Output layout for raw float with significand larger by two bits from selected format.
        Created using
        :meth:`create_raw_float_layout <coreblocks.func_blocks.fu.fpu.fpu_common.create_raw_float_layout>`
        """


class FPUAddSubModule(Elaboratable):
    """
    | FPU addition/subtraction top module
    | This module implements addition and subtraction using
      two path approach with rounding prediction.
    | The module can be divided into two segments:
    | 1. Receiving data and preparing it for one of the two path submodules
      by calculating effective operation, swapping operands and aligning exponents.
    | 2. Receiving data from one of the path submodules and preparing it for error checking
      module by checking for various conditions
    | For more info about close path and far path check
      :class:`close path module <coreblocks.func_blocks.fu.fpu.close_path.ClosePathModule>`
      and
      :class:`far path module <coreblocks.func_blocks.fu.fpu.far_path.FarPathModule>`

    Parameters
    ----------
    fpu_params: FPUParams
        FPU parameters (significand and exponent widths) of the handled format

    Attributes
    ----------
    add_sub_request: Method
        Transactional method for initiating addition or subtraction.
        Takes
        :meth:`add_sub_in_layout <coreblocks.func_blocks.fu.fpu.fpu_add_sub.FPUAddSubMethodLayout.add_sub_in_layout>`
        as argument.
        Returns result as
        :meth:`add_sub_out_layout <coreblocks.func_blocks.fu.fpu.fpu_add_sub.FPUAddSubMethodLayout.add_sub_out_layout>`.
    """

    def __init__(self, *, fpu_params: FPUParams):
        self.fpu_params = fpu_params
        self.method_layouts = FPUAddSubMethodLayout(fpu_params=self.fpu_params)
        self.common_values = FPUCommonValues(self.fpu_params)
        self.add_sub_request = Method(
            i=self.method_layouts.add_sub_in_layout,
            o=self.method_layouts.add_sub_out_layout,
        )

    def elaborate(self, platform):
        m = TModule()

        def assign_values(lhs, exp, sig, sign):
            # Combinationally drive the three raw-float fields of `lhs` at once.
            m.d.av_comb += assign(lhs, {"sign": sign, "exp": exp, "sig": sig})

        m.submodules.close_path_module = close_path_module = ClosePathModule(fpu_params=self.fpu_params)
        m.submodules.far_path_module = far_path_module = FarPathModule(fpu_params=self.fpu_params)
        m.submodules.exception_module = exception_module = FPUErrorModule(fpu_params=self.fpu_params)

        # All-ones exponent value for the configured exponent width.
        max_exp = (2 ** (self.fpu_params.exp_width)) - 1

        final_sign = Signal(1)
        exp_diff = Signal(range(-max_exp, max_exp + 1))
        norm_shift_amount = Signal(range(max_exp))
        sticky_bit = Signal(1)
        true_operation = Signal(1)
        exception_round_bit = Signal(1)
        exception_sticky_bit = Signal(1)
        invalid_operation = Signal(1)

        # The far path output layout is used to hold the response of whichever path is selected.
        path_response = Signal(from_method_layout(far_path_module.method_layouts.far_path_out_layout))
        exception_response = Signal(from_method_layout(exception_module.method_layouts.error_out_layout))

        @def_method(m, self.add_sub_request)
        def _(op_1, op_2, rounding_mode, operation):
            # Subtraction (operation == 1) is addition with the sign of op_2 flipped.
            op_2_adjusted_sign = Signal()
            m.d.av_comb += op_2_adjusted_sign.eq(operation ^ op_2.sign)

            # NOTE(review): exp_diff is driven here but not read anywhere else in this method.
            m.d.av_comb += exp_diff.eq(op_1.exp - op_2.exp)

            # Swapping operands to ensure that abs(pre_shift_op1) >= abs(pre_shift_op2)
            pre_shift_op1 = Signal(from_method_layout(self.method_layouts.ext_float_layout))
            pre_shift_op2 = Signal(from_method_layout(self.method_layouts.ext_float_layout))
            # Exponent in the high bits, significand in the low bits: an unsigned
            # comparison of these values orders the operands by magnitude.
            op_1_abs_float = Cat(op_1.sig, op_1.exp)
            op_2_abs_float = Cat(op_2.sig, op_2.exp)

            # Significands are shifted left by two to make room for the two extra low bits
            # of the extended layout.
            with m.If(op_1_abs_float > op_2_abs_float):
                assign_values(pre_shift_op1, op_1.exp, op_1.sig << 2, op_1.sign)
                assign_values(pre_shift_op2, op_2.exp, op_2.sig << 2, op_2_adjusted_sign)
            with m.Else():
                assign_values(pre_shift_op1, op_2.exp, op_2.sig << 2, op_2_adjusted_sign)
                assign_values(pre_shift_op2, op_1.exp, op_1.sig << 2, op_1.sign)

            # Calculating true operation based on signs of operands
            sign_xor = op_1.sign ^ op_2_adjusted_sign

            # The result takes the sign of the operand with the larger magnitude.
            m.d.av_comb += final_sign.eq(pre_shift_op1.sign)

            # true_operation: 0 - effective addition, 1 - effective subtraction
            with m.If(~sign_xor):
                m.d.av_comb += true_operation.eq(0)
            with m.Else():
                m.d.av_comb += true_operation.eq(1)

            # When exactly the smaller operand has a zero exponent field, the alignment
            # shift is reduced by one.
            is_one_subnormal = (pre_shift_op1.exp > 0) & (pre_shift_op2.exp == 0)
            m.d.av_comb += norm_shift_amount.eq(pre_shift_op1.exp - pre_shift_op2.exp - is_one_subnormal)

            # Aligning exponents and calculating GRS bits
            path_op1 = Signal(from_method_layout(self.method_layouts.raw_float_layout))
            far_path_op2_ext = Signal(from_method_layout(self.method_layouts.ext_float_layout))
            far_path_op2 = Signal(from_method_layout(self.method_layouts.raw_float_layout))
            close_path_op2 = Signal(from_method_layout(self.method_layouts.raw_float_layout))

            m.d.av_comb += path_op1.sig.eq(pre_shift_op1.sig)
            with m.If(norm_shift_amount > (self.fpu_params.sig_width + 2)):
                # Everything is shifted out: the whole significand collapses into the sticky bit.
                m.d.av_comb += sticky_bit.eq(pre_shift_op2.sig.any())
                m.d.av_comb += far_path_op2_ext.sig.eq(0)
            with m.Else():
                # Prepending sig_width zero bits below the significand and selecting a
                # sig_width-wide window at offset norm_shift_amount gathers the bits that
                # the right shift below discards.
                sticky_bit_mask = Cat(Signal().replicate(self.fpu_params.sig_width), pre_shift_op2.sig).bit_select(
                    norm_shift_amount, self.fpu_params.sig_width
                )
                m.d.av_comb += sticky_bit.eq(sticky_bit_mask.any())
                m.d.av_comb += far_path_op2_ext.sig.eq(pre_shift_op2.sig >> norm_shift_amount)

            close_path_guard_bit = Signal()

            # Close path is only selected for shift amounts 0 and 1 (see `close_path` below),
            # so the lowest bit of the shift amount is enough to pick the alignment.
            with m.If(norm_shift_amount[0] == 0):
                m.d.av_comb += close_path_op2.sig.eq(~(pre_shift_op2.sig >> 2))
                m.d.av_comb += close_path_guard_bit.eq(0)
            with m.Else():
                m.d.av_comb += close_path_op2.sig.eq(~(pre_shift_op2.sig >> 3))
                m.d.av_comb += close_path_guard_bit.eq(pre_shift_op2.sig[2])

            # The two extra low bits of the extended significand act as guard and round bits.
            guard_bit = far_path_op2_ext.sig[1]
            round_bit = far_path_op2_ext.sig[0]

            # Assigning operands for close path and far path
            assign_values(path_op1, pre_shift_op1.exp, pre_shift_op1.sig >> 2, pre_shift_op1.sign)
            assign_values(
                far_path_op2,
                (pre_shift_op2.exp + norm_shift_amount),
                # For effective subtraction the aligned significand is passed inverted.
                Mux(true_operation, ~(far_path_op2_ext.sig >> 2), far_path_op2_ext.sig >> 2),
                pre_shift_op2.sign,
            )

            # Both path methods are called; the response of the applicable one is selected.
            close_path = (norm_shift_amount <= 1) & true_operation
            resp = Mux(
                close_path,
                close_path_module.close_path_request(
                    m,
                    r_sign=final_sign,
                    sig_a=path_op1.sig,
                    sig_b=close_path_op2.sig,
                    exp=path_op1.exp,
                    rounding_mode=rounding_mode,
                    guard_bit=close_path_guard_bit,
                ),
                far_path_module.far_path_request(
                    m,
                    r_sign=final_sign,
                    sig_a=path_op1.sig,
                    sig_b=far_path_op2.sig,
                    exp=path_op1.exp,
                    sub_op=true_operation,
                    rounding_mode=rounding_mode,
                    guard_bit=guard_bit,
                    round_bit=round_bit,
                    sticky_bit=sticky_bit,
                ),
            )

            # Preparing data for error checking module
            m.d.av_comb += path_response.eq(resp)
            # Both sides are adjusted signs, so this holds regardless of the operand swap.
            eq_signs = pre_shift_op2.sign == path_op1.sign
            is_inf = op_1.is_inf | op_2.is_inf
            # Infinities of opposite effective signs: the result is NaN.
            wrong_inf = (op_1.is_inf & op_2.is_inf) & ~(eq_signs)
            is_nan = (op_1.is_nan | op_2.is_nan) | wrong_inf
            output_zero = (path_response["out_exp"] == 0) & (path_response["out_sig"] == 0)
            output_exact = ~(path_response["output_round"] | path_response["output_sticky"])
            both_op_zero = op_1.is_zero & op_2.is_zero
            is_zero = Signal()
            m.d.av_comb += is_zero.eq(both_op_zero | (output_exact & output_zero))
            normal_case = ~(is_nan | is_inf | is_zero)
            exception_op = Signal(from_method_layout(self.method_layouts.raw_float_layout))

            with m.If(~normal_case):
                # Special results are exact by construction.
                m.d.av_comb += exception_round_bit.eq(0)
                m.d.av_comb += exception_sticky_bit.eq(0)
                with m.If(is_nan):
                    # A NaN operand whose second-highest significand bit is clear is
                    # treated as signalling.
                    is_any_snan = ((~op_1.sig[-2]) & op_1.is_nan) | ((~op_2.sig[-2]) & op_2.is_nan)
                    with m.If(is_any_snan | wrong_inf):
                        m.d.av_comb += invalid_operation.eq(1)
                    m.d.av_comb += exception_op.sign.eq(0)
                    m.d.av_comb += exception_op.exp.eq(max_exp)
                    m.d.av_comb += exception_op.sig.eq(self.common_values.canonical_nan_sig)
                with m.Elif(is_inf & ~(wrong_inf)):
                    # The result is the infinite operand. When only op_2 is infinite, its
                    # sign must include the add/sub selector. pre_shift_op2 cannot be used
                    # here: after the magnitude swap it holds the finite op_1, which would
                    # make e.g. 1.0 + (-inf) come out as +inf.
                    assign_values(
                        exception_op,
                        Mux(op_1.is_inf, op_1.exp, op_2.exp),
                        Mux(op_1.is_inf, op_1.sig, op_2.sig),
                        Mux(op_1.is_inf, op_1.sign, op_2_adjusted_sign),
                    )
                with m.Elif(is_zero):
                    with m.If(eq_signs):
                        m.d.av_comb += exception_op.sign.eq(op_1.sign)
                    with m.Else():
                        # Zero from operands of opposite signs is negative only when
                        # rounding down.
                        m.d.av_comb += exception_op.sign.eq(rounding_mode == RoundingModes.ROUND_DOWN)
                    m.d.av_comb += exception_op.exp.eq(0)
                    m.d.av_comb += exception_op.sig.eq(0)
            with m.Elif(normal_case):
                m.d.av_comb += exception_op.sign.eq(final_sign)
                m.d.av_comb += exception_op.exp.eq(path_response["out_exp"])
                m.d.av_comb += exception_round_bit.eq(path_response["output_round"])
                m.d.av_comb += exception_sticky_bit.eq(path_response["output_sticky"])
                with m.If(path_response["out_exp"] == max_exp):
                    # Exponent reached the all-ones value: only the top significand bit is kept.
                    # NOTE(review): presumably the infinity encoding expected by the error
                    # module on overflow - confirm against FPUErrorModule.
                    m.d.av_comb += exception_op.sig.eq(2 ** (self.fpu_params.sig_width - 1))
                with m.Else():
                    m.d.av_comb += exception_op.sig.eq(path_response["out_sig"])

            inexact = exception_sticky_bit | exception_round_bit
            resp = exception_module.error_checking_request(
                m,
                sign=exception_op.sign,
                sig=exception_op.sig,
                exp=exception_op.exp,
                rounding_mode=rounding_mode,
                inexact=inexact,
                invalid_operation=invalid_operation,
                division_by_zero=0,
                input_inf=is_inf,
            )
            m.d.av_comb += exception_response.eq(resp)

            return exception_response

        return m