# fpu_rounding_module.py
"""Floating-point rounding module.

Takes a sign, significand, exponent, round bit and sticky bit, applies the
selected rounding mode and returns the rounded significand/exponent together
with an inexact flag.
"""

from amaranth import *
from transactron import TModule, Method, def_method
from coreblocks.func_blocks.fu.fpu.fpu_common import (
    RoundingModes,
    FPUParams,
)


class FPURoudningMethodLayout:
    """FPU Rounding module layouts for methods

    The class name keeps its original spelling ("Roudning") because other
    code may refer to it by that name.

    Parameters
    ----------
    fpu_params: FPUParams
        FPU parameters
    """

    def __init__(self, *, fpu_params: FPUParams):
        sig_width = fpu_params.sig_width
        exp_width = fpu_params.exp_width

        # Argument layout of FPURounding.rounding_request.
        self.rounding_in_layout = [
            ("sign", 1),
            ("sig", sig_width),
            ("exp", exp_width),
            ("round_bit", 1),
            ("sticky_bit", 1),
            ("rounding_mode", RoundingModes),
        ]
        # Result layout of FPURounding.rounding_request.
        self.rounding_out_layout = [
            ("sig", sig_width),
            ("exp", exp_width),
            ("inexact", 1),
        ]


class FPURounding(Elaboratable):
    """FPU Rounding module

    Parameters
    ----------
    fpu_params: FPUParams
        FPU parameters

    Attributes
    ----------
    rounding_request: Method
        Transactional method for rounding a floating point number.
        Takes 'rounding_in_layout' as argument.
        Returns the rounded number and the inexact flag as 'rounding_out_layout'.

    Notes
    -----
    The significand is incremented by one when the selected mode requires it:

    * ROUND_NEAREST_EVEN: round_bit set and (sticky_bit set or sig LSB set)
    * ROUND_NEAREST_AWAY: round_bit set
    * ROUND_UP: sign clear and (round_bit or sticky_bit set)
    * ROUND_DOWN: sign set and (round_bit or sticky_bit set)
    * any other mode: never incremented

    If the increment carries out of the significand, the significand is
    shifted right by one and the exponent is incremented. The incremented
    exponent is truncated to exp_width; exponent overflow is not checked here.
    ``inexact`` is set exactly when round_bit or sticky_bit is set.
    NOTE(review): assumes ``sig`` carries the explicit leading bit of the
    significand — confirm against the producing stage.
    """

    def __init__(self, *, fpu_params: FPUParams):
        self.fpu_rounding_params = fpu_params
        self.method_layouts = FPURoudningMethodLayout(fpu_params=fpu_params)
        self.rounding_request = Method(
            i=self.method_layouts.rounding_in_layout,
            o=self.method_layouts.rounding_out_layout,
        )

    def elaborate(self, platform):
        m = TModule()

        sig_width = self.fpu_rounding_params.sig_width
        exp_width = self.fpu_rounding_params.exp_width

        increment = Signal()
        # One extra bit so the carry-out of the increment is kept.
        sig_plus = Signal(sig_width + 1)
        result_sig = Signal(sig_width)
        result_exp = Signal(exp_width)
        out_round_bit = Signal()
        out_sticky_bit = Signal()
        is_inexact = Signal()

        @def_method(m, self.rounding_request)
        def _(arg):
            mode = arg.rounding_mode
            any_discarded = arg.round_bit | arg.sticky_bit

            # Increment conditions, one per rounding mode.
            up_nearest_even = (
                (mode == RoundingModes.ROUND_NEAREST_EVEN) & arg.round_bit & (arg.sticky_bit | arg.sig[0])
            )
            up_nearest_away = (mode == RoundingModes.ROUND_NEAREST_AWAY) & arg.round_bit
            up_toward_pos = (mode == RoundingModes.ROUND_UP) & ~arg.sign & any_discarded
            up_toward_neg = (mode == RoundingModes.ROUND_DOWN) & arg.sign & any_discarded

            m.d.av_comb += increment.eq(up_nearest_even | up_nearest_away | up_toward_pos | up_toward_neg)
            m.d.av_comb += sig_plus.eq(arg.sig + increment)

            # The MSB of sig_plus is the carry-out: when set, the significand
            # overflowed and is renormalised by a right shift plus exponent bump.
            carry = sig_plus[-1]
            m.d.av_comb += result_sig.eq(Mux(carry, sig_plus >> 1, sig_plus))
            m.d.av_comb += result_exp.eq(Mux(carry, arg.exp + 1, arg.exp))
            m.d.av_comb += out_round_bit.eq(Mux(carry, sig_plus[0], arg.round_bit))
            m.d.av_comb += out_sticky_bit.eq(Mux(carry, any_discarded, arg.sticky_bit))

            m.d.av_comb += is_inexact.eq(out_round_bit | out_sticky_bit)

            return {
                "exp": result_exp,
                "sig": result_sig,
                "inexact": is_inexact,
            }

        return m