fpu_error_module.py
from amaranth import *
from transactron import TModule, Method, def_method
from coreblocks.func_blocks.fu.fpu.fpu_common import (
    RoundingModes,
    FPUParams,
    Errors,
)


class FPUErrorMethodLayout:
    """FPU error checking module layouts for methods

    Parameters
    ----------
    fpu_params: FPUParams
        FPU parameters

    Attributes
    ----------
    error_in_layout: list
        Argument layout of the error checking method: sign, significand
        (``sig_width`` bits, leading bit included), exponent, rounding mode
        and the exception flags reported by the previous stage.
        ``input_inf`` is a flag that comes from the previous stage. It
        indicates that an infinity on the input is a result of infinity
        arithmetic and not a result of overflow.
    error_out_layout: list
        Result layout of the error checking method: final sign, significand,
        exponent and the accumulated ``Errors`` flags.
    """

    def __init__(self, *, fpu_params: FPUParams):
        self.error_in_layout = [
            ("sign", 1),
            ("sig", fpu_params.sig_width),
            ("exp", fpu_params.exp_width),
            ("rounding_mode", RoundingModes),
            ("inexact", 1),
            ("invalid_operation", 1),
            ("division_by_zero", 1),
            ("input_inf", 1),
        ]
        self.error_out_layout = [
            ("sign", 1),
            ("sig", fpu_params.sig_width),
            ("exp", fpu_params.exp_width),
            ("errors", Errors),
        ]


class FPUErrorModule(Elaboratable):
    """FPU error checking module

    Takes a rounded floating point number together with the exception flags
    raised by earlier stages. It produces the final number and the complete
    set of exception flags:

    * NaN / infinity inputs are passed through unchanged.
    * An all-ones exponent on an otherwise ordinary number is treated as
      overflow. The result becomes +/-infinity or +/-largest finite number,
      depending on the rounding mode and the sign.
    * Tininess is detected on the rounded result (exponent 0 with the
      leading significand bit clear). Underflow is raised when the result
      is both tiny and inexact.

    Parameters
    ----------
    fpu_params: FPUParams
        FPU parameters (exponent and significand widths)

    Attributes
    ----------
    error_checking_request: Method
        Transactional method for initiating error checking of a floating point number.
        Takes 'error_in_layout' as argument
        Returns final number and errors as 'error_out_layout'
    """

    def __init__(self, *, fpu_params: FPUParams):
        self.fpu_errors_params = fpu_params
        self.method_layouts = FPUErrorMethodLayout(fpu_params=self.fpu_errors_params)
        self.error_checking_request = Method(
            i=self.method_layouts.error_in_layout,
            o=self.method_layouts.error_out_layout,
        )

    def elaborate(self, platform):
        m = TModule()

        exp_width = self.fpu_errors_params.exp_width
        sig_width = self.fpu_errors_params.sig_width

        # All-ones exponent: encodes infinities and NaNs.
        max_exp = C(2**exp_width - 1, unsigned(exp_width))
        # Largest exponent of a finite number.
        max_normal_exp = C(2**exp_width - 2, unsigned(exp_width))
        # Largest significand (leading bit and all fraction bits set).
        max_sig = C(2**sig_width - 1, unsigned(sig_width))
        # `sig` carries the leading (normally implicit) bit in its MSB.
        implicit_bit = 2 ** (sig_width - 1)

        overflow = Signal()
        underflow = Signal()
        inexact = Signal()
        tininess = Signal()

        final_exp = Signal(exp_width)
        final_sig = Signal(sig_width)
        final_sign = Signal()
        # NOTE(review): width assumed to cover all five Errors flags OR-ed
        # below - confirm if the Errors definition changes.
        final_errors = Signal(5)

        @def_method(m, self.error_checking_request)
        def _(arg):
            def drive_result(exp, sig):
                # Every result path forwards the input sign unchanged.
                m.d.av_comb += final_exp.eq(exp)
                m.d.av_comb += final_sig.eq(sig)
                m.d.av_comb += final_sign.eq(arg.sign)

            def drive_infinity():
                # Infinity is encoded as all-ones exponent, only leading bit set.
                drive_result(max_exp, implicit_bit)

            def drive_max_finite():
                # Largest finite magnitude.
                drive_result(max_normal_exp, max_sig)

            # NaN: flagged by a previous stage, or an all-ones exponent with a
            # non-zero fraction (the leading bit in the MSB is ignored).
            is_nan = arg.invalid_operation | ((arg.exp == max_exp) & arg.sig[:-1].any())
            is_inf = arg.division_by_zero | arg.input_inf
            input_not_special = ~is_nan & ~is_inf

            # A non-special number with an all-ones exponent has overflowed.
            m.d.av_comb += overflow.eq(input_not_special & (arg.exp == max_exp))
            # Tininess after rounding: zero exponent and leading bit clear.
            m.d.av_comb += tininess.eq((arg.exp == 0) & ~arg.sig[-1])
            # Overflow always makes the result inexact; special values never are.
            m.d.av_comb += inexact.eq(overflow | (input_not_special & arg.inexact))
            m.d.av_comb += underflow.eq(tininess & inexact)

            with m.If(is_nan | is_inf):
                drive_result(arg.exp, arg.sig)

            with m.Elif(overflow):
                # Overflow result depends on the rounding direction and sign.
                # No default case: an unlisted mode leaves outputs at reset value.
                with m.Switch(arg.rounding_mode):
                    with m.Case(RoundingModes.ROUND_NEAREST_AWAY, RoundingModes.ROUND_NEAREST_EVEN):
                        drive_infinity()

                    with m.Case(RoundingModes.ROUND_ZERO):
                        drive_max_finite()

                    with m.Case(RoundingModes.ROUND_DOWN):
                        # Towards -inf: negative overflows to -inf, positive saturates.
                        with m.If(arg.sign):
                            drive_infinity()
                        with m.Else():
                            drive_max_finite()

                    with m.Case(RoundingModes.ROUND_UP):
                        # Towards +inf: positive overflows to +inf, negative saturates.
                        with m.If(arg.sign):
                            drive_max_finite()
                        with m.Else():
                            drive_infinity()

            with m.Else():
                # Zero exponent with the leading bit set: presumably a subnormal
                # that rounding carried into the normal range, so the encoded
                # exponent becomes 1.
                promoted = (arg.exp == 0) & arg.sig[-1]
                drive_result(Mux(promoted, 1, arg.exp), arg.sig)

            m.d.av_comb += final_errors.eq(
                Mux(arg.invalid_operation, Errors.INVALID_OPERATION, 0)
                | Mux(arg.division_by_zero, Errors.DIVISION_BY_ZERO, 0)
                | Mux(overflow, Errors.OVERFLOW, 0)
                | Mux(underflow, Errors.UNDERFLOW, 0)
                | Mux(inexact, Errors.INEXACT, 0)
            )

            return {
                "exp": final_exp,
                "sig": final_sig,
                "sign": final_sign,
                "errors": final_errors,
            }

        return m