close_path.py
from amaranth import *
from transactron import TModule, Method, def_method, Transaction
from transactron.utils.transactron_helpers import from_method_layout
from coreblocks.func_blocks.fu.fpu.fpu_common import RoundingModes, FPUParams
from coreblocks.func_blocks.fu.fpu.lza import LZAModule


class ClosePathMethodLayout:
    """Close path module layouts for methods

    Parameters
    ----------
    fpu_params: FPUParams
        FPU parameters
    """

    def __init__(self, *, fpu_params: FPUParams):
        """
        Input layout ('close_path_in_layout'):
        r_sign - sign of the result
        sig_a - two's complement form of first significand
        sig_b - two's complement form of second significand
        exp - exponent of result before shifts
        rounding_mode - rounding mode
        guard_bit - guard bit (p-th bit of second significand, where p is precision)

        Output layout ('close_path_out_layout'):
        out_exp - exponent of the result
        out_sig - significand of the result
        output_round - round bit of the result
        output_sticky - sticky bit of the result
        """

        self.close_path_in_layout = [
            ("r_sign", 1),
            ("sig_a", fpu_params.sig_width),
            ("sig_b", fpu_params.sig_width),
            ("exp", fpu_params.exp_width),
            ("rounding_mode", RoundingModes),
            ("guard_bit", 1),
        ]
        self.close_path_out_layout = [
            ("out_exp", fpu_params.exp_width),
            ("out_sig", fpu_params.sig_width),
            ("output_round", 1),
            ("output_sticky", 1),
        ]


class ClosePathModule(Elaboratable):
    """Close path module

    Based on http://i.stanford.edu/pub/cstr/reports/csl/tr/90/442/CSL-TR-90-442.pdf.
    This module computes results for effective subtraction
    whenever the difference of the exponents is less than 2.
    Besides computing the result, this implementation also performs rounding at the same time
    as subtraction by using two adders (one computing a+b and the other one computing a+b+1).
    The correct output is chosen based on flags that are different for each rounding mode.
    The sticky bit of the result ('output_sticky') is always 0 in this implementation.

    Parameters
    ----------
    fpu_params: FPUParams
        FPU parameters

    Attributes
    ----------
    close_path_request: Method
        Transactional method for initiating close path computation.
        Takes 'close_path_in_layout' as argument.
        Returns result as 'close_path_out_layout'.
    """

    def __init__(self, *, fpu_params: FPUParams):

        self.params = fpu_params
        self.method_layouts = ClosePathMethodLayout(fpu_params=self.params)
        self.close_path_request = Method(
            i=self.method_layouts.close_path_in_layout,
            o=self.method_layouts.close_path_out_layout,
        )

    def elaborate(self, platform):
        m = TModule()

        # Candidate sums of the two significands: a + b and a + b + 1.
        result_add_zero = Signal(self.params.sig_width)
        result_add_one = Signal(self.params.sig_width)
        # The candidate sum selected by l_flag.
        final_result = Signal(self.params.sig_width)
        shift_correction = Signal()
        shift_amount = Signal(range(self.params.sig_width))
        bit_shift_amount = Signal(range(self.params.sig_width))
        check_shift_amount = Signal(range(self.params.sig_width))
        final_sig = Signal(self.params.sig_width)
        final_exp = Signal(self.params.exp_width)
        final_round = Signal()

        shift_in_bit = Signal()
        l_flag = Signal()
        is_zero = Signal()

        # One LZA per candidate sum (carry-in 0 and carry-in 1).
        m.submodules.zero_lza = zero_lza = LZAModule(fpu_params=self.params)
        m.submodules.one_lza = one_lza = LZAModule(fpu_params=self.params)
        lza_resp = Signal(from_method_layout(zero_lza.method_layouts.predict_out_layout))

        @def_method(m, self.close_path_request)
        def _(
            r_sign,
            sig_a,
            sig_b,
            exp,
            rounding_mode,
            guard_bit,
        ):
            m.d.av_comb += result_add_zero.eq(sig_a + sig_b)
            m.d.av_comb += result_add_one.eq(sig_a + sig_b + 1)
            # Bit shifted into the highest vacated position during left normalization shifts.
            m.d.av_comb += shift_in_bit.eq(guard_bit)

            # l_flag selects a + b + 1 instead of a + b. The condition depends on the rounding
            # mode; in every listed mode a cleared guard bit selects a + b + 1.
            with m.Switch(rounding_mode):
                with m.Case(RoundingModes.ROUND_UP):
                    m.d.av_comb += l_flag.eq((~(r_sign) & result_add_zero[-1] & guard_bit) | ~(guard_bit))
                with m.Case(RoundingModes.ROUND_DOWN):
                    m.d.av_comb += l_flag.eq((r_sign & result_add_zero[-1] & guard_bit) | ~(guard_bit))
                with m.Case(RoundingModes.ROUND_ZERO):
                    m.d.av_comb += l_flag.eq(~(guard_bit))
                with m.Case(RoundingModes.ROUND_NEAREST_EVEN):
                    m.d.av_comb += l_flag.eq((result_add_zero[-1] & guard_bit & result_add_zero[0]) | ~(guard_bit))
                with m.Case(RoundingModes.ROUND_NEAREST_AWAY):
                    m.d.av_comb += l_flag.eq((result_add_zero[-1] & guard_bit) | ~(guard_bit))

            with Transaction().body(m):
                m.d.av_comb += final_result.eq(Mux(l_flag, result_add_one, result_add_zero))
                # Both LZAs are queried; the prediction matching the selected sum is kept.
                resp = Mux(
                    l_flag,
                    one_lza.predict_request(m, sig_a=sig_a, sig_b=sig_b, carry=1),
                    zero_lza.predict_request(m, sig_a=sig_a, sig_b=sig_b, carry=0),
                )
                m.d.av_comb += lza_resp.eq(resp)
                m.d.av_comb += is_zero.eq(lza_resp["is_zero"])

            # NOTE(review): this file was pasted with its indentation lost; the normalization
            # logic below is reconstructed at method-body level (a sibling of the Transaction
            # above, not inside it) - confirm against upstream.
            with m.If(is_zero | (exp == 0)):
                # Zero result or zero exponent: no normalization shift, exponent 0.
                m.d.av_comb += final_sig.eq(final_result)
                m.d.av_comb += final_exp.eq(0)
                m.d.av_comb += final_round.eq(guard_bit)
            with m.Elif(exp <= lza_resp["shift_amount"]):
                # The predicted shift would not leave a positive exponent, so shift left only
                # by exp - 1 and produce exponent 0.
                with m.If(exp == 1):
                    m.d.av_comb += final_sig.eq(final_result)
                    m.d.av_comb += final_round.eq(guard_bit)
                with m.Else():
                    m.d.av_comb += shift_amount.eq(exp - 1)
                    m.d.av_comb += bit_shift_amount.eq(exp - 2)
                    # The guard bit fills the highest vacated position.
                    m.d.av_comb += final_sig.eq((final_result << shift_amount) | (shift_in_bit << bit_shift_amount))
                    m.d.av_comb += final_round.eq(0)
                m.d.av_comb += final_exp.eq(0)
            with m.Else():
                shifted_sig = Signal(self.params.sig_width)
                shifted_exp = Signal(self.params.exp_width)

                # Normalize by the LZA-predicted shift amount.
                m.d.av_comb += shifted_sig.eq(final_result << lza_resp["shift_amount"])
                m.d.av_comb += shifted_exp.eq(exp - lza_resp["shift_amount"])
                m.d.av_comb += check_shift_amount.eq(lza_resp["shift_amount"] - 1)
                # shift_correction is the MSB after the predicted shift, with the guard bit placed
                # at the highest vacated position. When it is 0 the significand is shifted left by
                # one more position below (presumably the LZA prediction may be one short).
                # When shift_amount == 0 nothing was vacated: check_shift_amount then wraps to
                # its all-ones value, which equals sig_width - 1 whenever sig_width is a power
                # of two, so the guard bit must not be considered in that case.
                m.d.av_comb += shift_correction.eq(
                    Mux(
                        lza_resp["shift_amount"] == 0,
                        shifted_sig[self.params.sig_width - 1],
                        (shifted_sig | (guard_bit << check_shift_amount))[self.params.sig_width - 1],
                    )
                )

                with m.If(shift_correction):
                    with m.If(lza_resp["shift_amount"] == 0):
                        m.d.av_comb += final_sig.eq(shifted_sig)
                        m.d.av_comb += final_round.eq(guard_bit)
                    with m.Else():
                        m.d.av_comb += bit_shift_amount.eq(lza_resp["shift_amount"] - 1)
                        m.d.av_comb += final_sig.eq(shifted_sig | (shift_in_bit << bit_shift_amount))
                        m.d.av_comb += final_round.eq(0)
                    m.d.av_comb += final_exp.eq(shifted_exp)
                with m.Else():
                    m.d.av_comb += final_round.eq(0)
                    with m.If(shifted_exp == 1):
                        # Exponent cannot be decreased further: keep the predicted shift and
                        # produce exponent 0.
                        with m.If(lza_resp["shift_amount"] > 0):
                            m.d.av_comb += bit_shift_amount.eq(lza_resp["shift_amount"] - 1)
                            m.d.av_comb += final_sig.eq(shifted_sig | (shift_in_bit << bit_shift_amount))
                        with m.Else():
                            # NOTE(review): no shift happens here, yet final_round stays 0, while
                            # the other unshifted paths output guard_bit as the round bit -
                            # confirm this is intended.
                            m.d.av_comb += final_sig.eq(shifted_sig)
                        m.d.av_comb += final_exp.eq(0)
                    with m.Else():
                        # Shift one extra position; the guard bit fills the lowest vacated slot
                        # that is now at index shift_amount.
                        m.d.av_comb += final_sig.eq((shifted_sig << 1) | (shift_in_bit << (lza_resp["shift_amount"])))
                        m.d.av_comb += final_exp.eq(shifted_exp - 1)

            return {"out_exp": final_exp, "out_sig": final_sig, "output_round": final_round, "output_sticky": 0}

        return m