/ coreblocks / func_blocks / fu / fpu / fpu_rounding_module.py
fpu_rounding_module.py
  1  from amaranth import *
  2  from transactron import TModule, Method, def_method
  3  from coreblocks.func_blocks.fu.fpu.fpu_common import (
  4      RoundingModes,
  5      FPUParams,
  6  )
  7  
  8  
  9  class FPURoudningMethodLayout:
 10      """FPU Rounding module layouts for methods
 11  
 12      Parameters
 13      ----------
 14      fpu_params: FPUParams
 15          FPU parameters
 16      """
 17  
 18      def __init__(self, *, fpu_params: FPUParams):
 19          self.rounding_in_layout = [
 20              ("sign", 1),
 21              ("sig", fpu_params.sig_width),
 22              ("exp", fpu_params.exp_width),
 23              ("round_bit", 1),
 24              ("sticky_bit", 1),
 25              ("rounding_mode", RoundingModes),
 26          ]
 27          self.rounding_out_layout = [
 28              ("sig", fpu_params.sig_width),
 29              ("exp", fpu_params.exp_width),
 30              ("inexact", 1),
 31          ]
 32  
 33  
 34  class FPURounding(Elaboratable):
 35      """FPU Rounding module
 36  
 37      Parameters
 38      ----------
 39      fpu_params: FPUParams
 40          FPU parameters
 41  
 42      Attributes
 43      ----------
 44      rounding_request: Method
 45          Transactional method for initiating rounding of a floating point number.
 46          Takes 'rounding_in_layout' as argument
 47          Returns rounded number and errors as 'rounding_out_layout'
 48      """
 49  
 50      def __init__(self, *, fpu_params: FPUParams):
 51  
 52          self.fpu_rounding_params = fpu_params
 53          self.method_layouts = FPURoudningMethodLayout(fpu_params=self.fpu_rounding_params)
 54          self.rounding_request = Method(
 55              i=self.method_layouts.rounding_in_layout,
 56              o=self.method_layouts.rounding_out_layout,
 57          )
 58  
 59      def elaborate(self, platform):
 60          m = TModule()
 61  
 62          add_one = Signal()
 63          inc_rtnte = Signal()
 64          inc_rtnta = Signal()
 65          inc_rtpi = Signal()
 66          inc_rtmi = Signal()
 67  
 68          rounded_sig = Signal(self.fpu_rounding_params.sig_width + 1)
 69          normalised_sig = Signal(self.fpu_rounding_params.sig_width)
 70          rounded_exp = Signal(self.fpu_rounding_params.exp_width)
 71  
 72          final_round_bit = Signal()
 73          final_sticky_bit = Signal()
 74  
 75          inexact = Signal()
 76  
 77          @def_method(m, self.rounding_request)
 78          def _(arg):
 79  
 80              m.d.av_comb += inc_rtnte.eq(
 81                  (arg.rounding_mode == RoundingModes.ROUND_NEAREST_EVEN)
 82                  & (arg.round_bit & (arg.sticky_bit | arg.sig[0]))
 83              )
 84              m.d.av_comb += inc_rtnta.eq((arg.rounding_mode == RoundingModes.ROUND_NEAREST_AWAY) & (arg.round_bit))
 85              m.d.av_comb += inc_rtpi.eq(
 86                  (arg.rounding_mode == RoundingModes.ROUND_UP) & (~arg.sign & (arg.round_bit | arg.sticky_bit))
 87              )
 88              m.d.av_comb += inc_rtmi.eq(
 89                  (arg.rounding_mode == RoundingModes.ROUND_DOWN) & (arg.sign & (arg.round_bit | arg.sticky_bit))
 90              )
 91  
 92              m.d.av_comb += add_one.eq(inc_rtmi | inc_rtnta | inc_rtnte | inc_rtpi)
 93  
 94              m.d.av_comb += rounded_sig.eq(arg.sig + add_one)
 95  
 96              with m.If(rounded_sig[-1]):
 97  
 98                  m.d.av_comb += normalised_sig.eq(rounded_sig >> 1)
 99                  m.d.av_comb += final_round_bit.eq(rounded_sig[0])
100                  m.d.av_comb += final_sticky_bit.eq(arg.round_bit | arg.sticky_bit)
101                  m.d.av_comb += rounded_exp.eq(arg.exp + 1)
102  
103              with m.Else():
104                  m.d.av_comb += normalised_sig.eq(rounded_sig)
105                  m.d.av_comb += final_round_bit.eq(arg.round_bit)
106                  m.d.av_comb += final_sticky_bit.eq(arg.sticky_bit)
107                  m.d.av_comb += rounded_exp.eq(arg.exp)
108  
109              m.d.av_comb += inexact.eq(final_round_bit | final_sticky_bit)
110  
111              return {
112                  "exp": rounded_exp,
113                  "sig": normalised_sig,
114                  "inexact": inexact,
115              }
116  
117          return m