far_path.py
from amaranth import *
from transactron import TModule, Method, def_method
from coreblocks.func_blocks.fu.fpu.fpu_common import RoundingModes, FPUParams


class FarPathMethodLayout:
    """Far path module layouts for methods

    Parameters
    ----------
    fpu_params: FPUParams
        FPU parameters
    """

    def __init__(self, *, fpu_params: FPUParams):
        """Build the input and output layouts of the far path method.

        Input layout (``far_path_in_layout``):

        r_sign - result sign
        sig_a - significand of first operand (for effective subtraction in two's complement form)
        sig_b - significand of second operand (for effective subtraction in two's complement form)
        exp - exponent of result before shift
        sub_op - effective operation. 1 for subtraction 0 for addition
        rounding_mode - rounding mode
        guard_bit - guard bit (pth bit of second significand where p is precision)
        round_bit - round bit ((p+1)th bit of second significand where p is precision)
        sticky_bit - sticky_bit
        (OR of all bits with index >=p of second significand where p is precision)
        NOTE(review): guard and round are described as bits p and p+1, so
        "index >=p" is probably meant to be "index >=p+2" - confirm.

        Output layout (``far_path_out_layout``):

        out_exp - result exponent after the final one-position shift adjustment
        out_sig - result significand after the final one-position shift adjustment
        output_round - round bit accompanying the result
        output_sticky - sticky bit accompanying the result
        """
        self.far_path_in_layout = [
            ("r_sign", 1),
            ("sig_a", fpu_params.sig_width),
            ("sig_b", fpu_params.sig_width),
            ("exp", fpu_params.exp_width),
            ("sub_op", 1),
            ("rounding_mode", RoundingModes),
            ("guard_bit", 1),
            ("round_bit", 1),
            ("sticky_bit", 1),
        ]
        self.far_path_out_layout = [
            ("out_exp", fpu_params.exp_width),
            ("out_sig", fpu_params.sig_width),
            ("output_round", 1),
            ("output_sticky", 1),
        ]


class FarPathModule(Elaboratable):
    """Far Path module
    Based on: https://userpages.cs.umbc.edu/phatak/645/supl/lza/lza-survey-arith01.pdf.
    This module implements the far path of an adder/subtractor.
    It performs subtraction for operands whose exponents differ by more than 1 and addition
    for all combinations of operands. Besides addition it also performs rounding at the same time
    as addition, using two adders (one producing a+b and a second one producing a+b+1). The correct
    output is chosen by flags that differ for each rounding mode. To deal with certain
    complications that may arise during addition in certain rounding modes, the inputs of the second
    adder may be either the input operands or (a & b)<<1 and (a^b). This allows the second adder to
    compute a+b+2 in special cases that are better explained in the paper linked above.

    Parameters
    ----------
    fpu_params: FPUParams
        FPU parameters

    Attributes
    ----------
    far_path_request: Method
        Transactional method for initiating far path computation.
        Takes 'far_path_in_layout' as argument.
        Returns result as 'far_path_out_layout'.
    """

    def __init__(self, *, fpu_params: FPUParams):

        self.params = fpu_params
        self.method_layouts = FarPathMethodLayout(fpu_params=self.params)
        self.far_path_request = Method(
            i=self.method_layouts.far_path_in_layout,
            o=self.method_layouts.far_path_out_layout,
        )

    def elaborate(self, platform):
        """Build the far path logic and bind it to ``far_path_request``.

        All signals below are driven from inside the method body, so the
        returned module contains only the logic of that single method.
        """
        m = TModule()

        # Operands and results of the two adders. Adder 0 always computes
        # sig_a + sig_b; adder 1 computes its two inputs plus one. The sums
        # carry one extra bit compared to the operands.
        input_sig_add_0_a = Signal(self.params.sig_width)
        input_sig_add_0_b = Signal(self.params.sig_width)
        input_sig_add_1_a = Signal(self.params.sig_width)
        input_sig_add_1_b = Signal(self.params.sig_width)
        output_sig_add_0 = Signal(self.params.sig_width + 1)
        output_sig_add_1 = Signal(self.params.sig_width + 1)
        # Selected sum / exponent before the final shift, and the values
        # returned after it.
        output_sig = Signal(self.params.sig_width + 1)
        output_exp = Signal(self.params.exp_width + 1)
        output_final_exp = Signal(self.params.exp_width)
        output_final_sig = Signal(self.params.sig_width)

        # Returned round/sticky bits and the guard/round/sticky bits after the
        # adjustment applied for effective subtraction.
        output_round_bit = Signal()
        output_sticky_bit = Signal()
        final_guard_bit = Signal()
        final_round_bit = Signal()
        final_sticky_bit = Signal()

        # Flags forcing the LSB of adder 0's sum to one (ROUND_UP / ROUND_DOWN).
        round_up_inc_1 = Signal()
        round_down_inc_1 = Signal()
        # Effective addition in ROUND_UP or ROUND_DOWN: adder 1 is fed with the
        # xor/carry decomposition of the operands instead of the operands.
        round_to_inf_special_case = Signal()
        xor_sig = Signal(self.params.sig_width)
        carry_sig = Signal(self.params.sig_width)
        # Top bit of carry_sig, which does not fit into adder 1's input once
        # carry_sig is shifted left by one.
        carry_add1 = Signal()
        # OR / AND of the incoming guard, round and sticky bits.
        rgs_any = Signal()
        rgs_all = Signal()

        # No right shift
        nrs = Signal()
        # One right shift
        ors = Signal()
        # No left shift
        nls = Signal()
        # One left shift
        ols = Signal()
        # No shift in either direction (nls | nrs)
        nxs = Signal()

        # Per-rounding-mode conditions for selecting adder 1 when no shift is
        # needed.
        nxs_rtne = Signal()
        nxs_rtna = Signal()
        nxs_zero = Signal()
        nxs_up = Signal()
        nxs_down = Signal()

        # Per-rounding-mode conditions for selecting adder 1 on one right shift.
        ors_rtne = Signal()
        ors_rtna = Signal()
        ors_zero = Signal()
        ors_up = Signal()
        ors_down = Signal()

        # Per-rounding-mode conditions for selecting adder 1 on one left shift.
        ols_rtne = Signal()
        ols_rtna = Signal()
        ols_zero = Signal()
        ols_up = Signal()
        ols_down = Signal()

        # Per-rounding-mode value of the bit shifted into the LSB on one left
        # shift, and the one chosen for the active rounding mode.
        shift_in_bit_rtne = Signal()
        shift_in_bit_rtna = Signal()
        shift_in_bit_zero = Signal()
        shift_in_bit_up = Signal()
        shift_in_bit_down = Signal()
        shift_in_bit = Signal()

        # Selects adder 1's sum (when set) over adder 0's sum.
        g = Signal()

        @def_method(m, self.far_path_request)
        def _(
            r_sign,
            sig_a,
            sig_b,
            exp,
            sub_op,
            rounding_mode,
            guard_bit,
            round_bit,
            sticky_bit,
        ):
            m.d.av_comb += input_sig_add_0_a.eq(sig_a)
            m.d.av_comb += input_sig_add_0_b.eq(sig_b)
            # Half-adder decomposition: sig_a + sig_b == xor_sig + (carry_sig << 1).
            m.d.av_comb += xor_sig.eq(sig_a ^ sig_b)
            m.d.av_comb += carry_sig.eq(sig_a & sig_b)
            # Default value; overridden with 0 in the Else branch below.
            m.d.av_comb += carry_add1.eq(carry_sig[-1])
            m.d.av_comb += rgs_any.eq(guard_bit | round_bit | sticky_bit)
            # rgs_all is driven here but not read anywhere else in this method.
            m.d.av_comb += rgs_all.eq(guard_bit & round_bit & sticky_bit)
            m.d.av_comb += round_to_inf_special_case.eq(
                (~sub_op) & ((rounding_mode == RoundingModes.ROUND_DOWN) | (rounding_mode == RoundingModes.ROUND_UP))
            )

            with m.If(round_to_inf_special_case):
                # The LSB of (carry_sig << 1) is always zero, so it is free to
                # carry an extra increment: it is set when the LSB of the plain
                # sum is zero, making adder 1 produce a+b+2 instead of a+b+1.
                m.d.av_comb += input_sig_add_1_a.eq((carry_sig << 1) | (~xor_sig[0]))
                m.d.av_comb += input_sig_add_1_b.eq(xor_sig)
            with m.Else():
                m.d.av_comb += input_sig_add_1_a.eq(sig_a)
                m.d.av_comb += input_sig_add_1_b.eq(sig_b)
                # Nothing was truncated in this branch, so no bit to restore.
                m.d.av_comb += carry_add1.eq(0)

            m.d.av_comb += output_sig_add_0.eq(input_sig_add_0_a + input_sig_add_0_b)
            # NOTE(review): carry_add1 is carry_sig[-1], whose weight after the
            # `<< 1` above is 2**sig_width, yet it is OR-ed back in at bit
            # sig_width - 1. Confirm this is intended for the significand
            # format used by the callers.
            m.d.av_comb += output_sig_add_1.eq(
                (input_sig_add_1_a + input_sig_add_1_b + 1) | (carry_add1 << (self.params.sig_width - 1))
            )

            # Shift classification. For addition the carry-out of adder 0
            # decides between no shift and one right shift. For subtraction
            # bit [-2] of the relevant sum decides between no shift and one
            # left shift; adder 1 is consulted when all of guard/round/sticky
            # are zero and adder 0 otherwise.
            m.d.av_comb += nrs.eq((~sub_op) & (~output_sig_add_0[-1]))
            m.d.av_comb += ors.eq((~sub_op) & (output_sig_add_0[-1]))
            m.d.av_comb += nls.eq(sub_op & (((~rgs_any) & output_sig_add_1[-2]) | (rgs_any & output_sig_add_0[-2])))
            m.d.av_comb += ols.eq(
                sub_op & (((~rgs_any) & (~output_sig_add_1[-2])) | (rgs_any & (~output_sig_add_0[-2])))
            )
            m.d.av_comb += nxs.eq(nls | nrs)

            # --- Adder selection terms for the "no shift" case ---
            # `subtraction` / `addition` are plain Python names for
            # sub-expressions and are rebound for every rounding mode.
            subtraction = sub_op & ((~r_sign) | (~rgs_any))
            addition = (~sub_op) & ((sig_a[0] ^ sig_b[0]) & ((~r_sign) & (rgs_any)))
            m.d.av_comb += nxs_up.eq(subtraction | addition)

            subtraction = sub_op & (r_sign | (~rgs_any))
            addition = (~sub_op) & ((sig_a[0] ^ sig_b[0]) & (r_sign & (rgs_any)))
            m.d.av_comb += nxs_down.eq(subtraction | addition)

            m.d.av_comb += nxs_zero.eq(sub_op & (~rgs_any))

            subtraction = sub_op & ((~guard_bit) | (guard_bit & (~round_bit) & (~sticky_bit) & (sig_a[0] ^ sig_b[0])))
            addition = (~sub_op) & guard_bit & (round_bit | sticky_bit | (sig_a[0] ^ sig_b[0]))
            m.d.av_comb += nxs_rtne.eq(subtraction | addition)

            subtraction = sub_op & (((~guard_bit) ^ ((~round_bit) & (~sticky_bit))) | (~rgs_any))
            addition = (~sub_op) & guard_bit
            m.d.av_comb += nxs_rtna.eq(subtraction | addition)

            # --- Adder selection terms for the "one right shift" case ---
            # ors itself is only set for addition, so ors_zero (which requires
            # sub_op) never contributes to g.
            m.d.av_comb += ors_up.eq((~r_sign) & ((sig_a[0] ^ sig_b[0]) | rgs_any))
            m.d.av_comb += ors_down.eq(r_sign & ((sig_a[0] ^ sig_b[0]) | rgs_any))
            m.d.av_comb += ors_zero.eq(sub_op & (~rgs_any))
            m.d.av_comb += ors_rtne.eq((sig_a[0] ^ sig_b[0]) & (rgs_any | (sig_a[1] ^ sig_b[1])))
            m.d.av_comb += ors_rtna.eq(sig_a[0] ^ sig_b[0])

            # --- Adder selection terms for the "one left shift" case ---
            # ols_rtne and ols_rtna use the same expression.
            m.d.av_comb += ols_up.eq(((~r_sign) & (~guard_bit)) | (r_sign & (~rgs_any)))
            m.d.av_comb += ols_down.eq((r_sign & (~guard_bit)) | ((~r_sign) & (~rgs_any)))
            m.d.av_comb += ols_zero.eq(sub_op & (~rgs_any))
            m.d.av_comb += ols_rtne.eq((~guard_bit) & ((~round_bit) | (~sticky_bit)))
            m.d.av_comb += ols_rtna.eq((~guard_bit) & ((~round_bit) | (~sticky_bit)))
            # --- Bit shifted into the LSB on one left shift, per rounding mode ---
            m.d.av_comb += shift_in_bit_up.eq(
                ((~r_sign) & guard_bit)
                | (r_sign & ((guard_bit & (~round_bit) & (~sticky_bit)) | ((~guard_bit) & (round_bit | sticky_bit))))
            )
            m.d.av_comb += shift_in_bit_down.eq(
                (r_sign & guard_bit)
                | ((~r_sign) & ((guard_bit & (~round_bit) & (~sticky_bit)) | ((~guard_bit) & (round_bit | sticky_bit))))
            )
            m.d.av_comb += shift_in_bit_zero.eq(
                ((~guard_bit) & (round_bit | sticky_bit)) | (guard_bit & (~round_bit) & (~sticky_bit))
            )
            m.d.av_comb += shift_in_bit_rtne.eq(((~guard_bit) & round_bit & sticky_bit) | (guard_bit & (~round_bit)))
            m.d.av_comb += shift_in_bit_rtna.eq(
                ((~guard_bit) & round_bit & sticky_bit) | (guard_bit & (~(round_bit & sticky_bit)))
            )

            # Pick the selection flag and shift-in bit of the active rounding
            # mode. At most one of ors / nxs / ols matches the shift case, so
            # each OR below reduces to the term of that case.
            with m.Switch(rounding_mode):
                with m.Case(RoundingModes.ROUND_UP):
                    m.d.av_comb += g.eq((ors & ors_up) | (nxs & nxs_up) | (ols & ols_up))
                    m.d.av_comb += shift_in_bit.eq(shift_in_bit_up)
                with m.Case(RoundingModes.ROUND_DOWN):
                    m.d.av_comb += g.eq((ors & ors_down) | (nxs & nxs_down) | (ols & ols_down))
                    m.d.av_comb += shift_in_bit.eq(shift_in_bit_down)
                with m.Case(RoundingModes.ROUND_ZERO):
                    m.d.av_comb += g.eq((ors & ors_zero) | (nxs & nxs_zero) | (ols & ols_zero))
                    m.d.av_comb += shift_in_bit.eq(shift_in_bit_zero)

                with m.Case(RoundingModes.ROUND_NEAREST_EVEN):
                    m.d.av_comb += g.eq((ors & ors_rtne) | (nxs & nxs_rtne) | (ols & ols_rtne))
                    m.d.av_comb += shift_in_bit.eq(shift_in_bit_rtne)

                with m.Case(RoundingModes.ROUND_NEAREST_AWAY):
                    m.d.av_comb += g.eq((ors & ors_rtna) | (nxs & nxs_rtna) | (ols & ols_rtna))
                    m.d.av_comb += shift_in_bit.eq(shift_in_bit_rtna)

            # Unshifted addition whose operand LSBs are equal (so the LSB of
            # a+b is zero) and which has to be rounded away from zero in the
            # direction of the result sign: a+b+1 is obtained by setting the
            # LSB of adder 0's sum instead of using adder 1.
            m.d.av_comb += round_up_inc_1.eq(
                (rounding_mode == RoundingModes.ROUND_UP)
                & nrs
                & (~g)
                & (~(sig_a[0] ^ sig_b[0]))
                & ((~r_sign) & (rgs_any))
            )
            m.d.av_comb += round_down_inc_1.eq(
                (rounding_mode == RoundingModes.ROUND_DOWN)
                & nrs
                & (~g)
                & (~(sig_a[0] ^ sig_b[0]))
                & (r_sign & (rgs_any))
            )
            # Select the significand before the final shift.
            with m.If(g):
                m.d.av_comb += output_sig.eq(output_sig_add_1)
            with m.Else():
                with m.If(round_down_inc_1 | round_up_inc_1):
                    m.d.av_comb += output_sig.eq(output_sig_add_0 | 1)
                with m.Else():
                    m.d.av_comb += output_sig.eq(output_sig_add_0)
            m.d.av_comb += output_exp.eq(exp)

            # For subtraction the three expressions below equal the two's
            # complement negation of the 3-bit value {guard, round, sticky};
            # for addition the bits are passed through unchanged.
            with m.If(sub_op):
                m.d.av_comb += final_guard_bit.eq((~guard_bit) ^ ((~round_bit) & (~sticky_bit)))
                m.d.av_comb += final_round_bit.eq((~round_bit) ^ (~sticky_bit))
                m.d.av_comb += final_sticky_bit.eq(sticky_bit)

            with m.Else():
                m.d.av_comb += final_guard_bit.eq(guard_bit)
                m.d.av_comb += final_round_bit.eq(round_bit)
                m.d.av_comb += final_sticky_bit.eq(sticky_bit)

            # Returned round/sticky bits depend on the shift case:
            # - one right shift: the LSB of a+b becomes the round bit and all
            #   of guard/round/sticky fold into sticky,
            # - one left shift: round and sticky are taken over as they are,
            # - no shift: guard becomes the round bit, round|sticky the sticky.
            with m.If(ors):
                m.d.av_comb += output_sticky_bit.eq(final_guard_bit | final_round_bit | final_sticky_bit)
                m.d.av_comb += output_round_bit.eq(sig_a[0] ^ sig_b[0])
            with m.Elif(ols):
                m.d.av_comb += output_sticky_bit.eq(final_sticky_bit)
                m.d.av_comb += output_round_bit.eq(final_round_bit)
            with m.Else():
                m.d.av_comb += output_sticky_bit.eq(final_round_bit | final_sticky_bit)
                m.d.av_comb += output_round_bit.eq(final_guard_bit)

            # Final one-position shift, decided on the selected (rounded) sum.
            # Addition with the top bit set: shift right, exponent + 1.
            with m.If((~sub_op) & (output_sig[-1])):
                m.d.av_comb += output_final_sig.eq(output_sig >> 1)
                m.d.av_comb += output_final_exp.eq(output_exp + 1)

            # Subtraction with bit [-2] clear and a non-zero exponent:
            # exponent - 1. The significand is shifted left (taking in
            # shift_in_bit) unless the exponent is exactly 1, in which case it
            # is kept as is while the exponent drops to 0 (presumably the
            # transition to a subnormal result - confirm).
            with m.Elif((sub_op & (~output_sig[-2])) & (output_exp > 0)):
                with m.If(output_exp == 1):
                    m.d.av_comb += output_final_sig.eq(output_sig)
                with m.Else():
                    m.d.av_comb += output_final_sig.eq((output_sig << 1) | shift_in_bit)
                m.d.av_comb += output_final_exp.eq(output_exp - 1)

            # Otherwise the significand is kept. A zero exponent is raised to 1
            # when bit [-2] of the sum is set (presumably a subnormal input
            # that became normal - confirm); else the exponent is unchanged.
            with m.Else():
                m.d.av_comb += output_final_sig.eq(output_sig)
                with m.If((output_exp == 0) & ((output_sig[-2]))):
                    m.d.av_comb += output_final_exp.eq(1)
                with m.Else():
                    m.d.av_comb += output_final_exp.eq(output_exp)

            return {
                "out_exp": output_final_exp,
                "out_sig": output_final_sig,
                "output_round": output_round_bit,
                "output_sticky": output_sticky_bit,
            }

        return m