# coreblocks/func_blocks/fu/fpu/fpu_error_module.py
  1  from amaranth import *
  2  from transactron import TModule, Method, def_method
  3  from coreblocks.func_blocks.fu.fpu.fpu_common import (
  4      RoundingModes,
  5      FPUParams,
  6      Errors,
  7  )
  8  
  9  
 10  class FPUErrorMethodLayout:
 11      """FPU error checking module layouts for methods
 12  
 13      Parameters
 14      ----------
 15      fpu_params: FPUParams
 16          FPU parameters
 17      """
 18  
 19      def __init__(self, *, fpu_params: FPUParams):
 20          """
 21          input_inf is a flag that comes from previous stage.
 22          Its purpose is to indicate that the infinity on input
 23          is a result of infinity arithmetic and not a result of overflow
 24          """
 25          self.error_in_layout = [
 26              ("sign", 1),
 27              ("sig", fpu_params.sig_width),
 28              ("exp", fpu_params.exp_width),
 29              ("rounding_mode", RoundingModes),
 30              ("inexact", 1),
 31              ("invalid_operation", 1),
 32              ("division_by_zero", 1),
 33              ("input_inf", 1),
 34          ]
 35          self.error_out_layout = [
 36              ("sign", 1),
 37              ("sig", fpu_params.sig_width),
 38              ("exp", fpu_params.exp_width),
 39              ("errors", Errors),
 40          ]
 41  
 42  
class FPUErrorModule(Elaboratable):
    """FPU error checking module

    Takes a floating point number together with the rounding mode and the
    exception flags reported by a previous stage, applies overflow handling
    according to the rounding mode, and produces the final number together
    with the final set of error flags.

    Parameters
    ----------
    fpu_params: FPUParams
        FPU parameters; ``exp_width`` and ``sig_width`` set the widths of
        the exponent and significand fields.

    Attributes
    ----------
    error_checking_request: Method
        Transactional method for initiating error checking of a floating point number.
        Takes 'error_in_layout' as argument
        Returns final number and errors as 'error_out_layout'
    """

    def __init__(self, *, fpu_params: FPUParams):

        self.fpu_errors_params = fpu_params
        self.method_layouts = FPUErrorMethodLayout(fpu_params=self.fpu_errors_params)
        self.error_checking_request = Method(
            i=self.method_layouts.error_in_layout,
            o=self.method_layouts.error_out_layout,
        )

    def elaborate(self, platform):
        m = TModule()

        # All-ones exponent; below it marks NaNs and infinities.
        max_exp = C(
            2 ** (self.fpu_errors_params.exp_width) - 1,
            unsigned(self.fpu_errors_params.exp_width),
        )
        # Largest exponent of a finite number; used when an overflow
        # saturates to the largest finite magnitude.
        max_normal_exp = C(
            2 ** (self.fpu_errors_params.exp_width) - 2,
            unsigned(self.fpu_errors_params.exp_width),
        )
        # All-ones significand; together with max_normal_exp it encodes the
        # largest finite magnitude.
        max_sig = C(
            2 ** (self.fpu_errors_params.sig_width) - 1,
            unsigned(self.fpu_errors_params.sig_width),
        )

        # Value of the most significant bit of `sig`. The significand keeps
        # the leading (implicit) bit explicitly in its MSB: an infinity is
        # produced below as exp == max_exp with sig == implicit_bit.
        implicit_bit = 2 ** (self.fpu_errors_params.sig_width - 1)

        overflow = Signal()
        underflow = Signal()
        inexact = Signal()
        tininess = Signal()

        final_exp = Signal(self.fpu_errors_params.exp_width)
        final_sig = Signal(self.fpu_errors_params.sig_width)
        final_sign = Signal()
        # Holds the OR of the five Errors flags combined at the end.
        # NOTE(review): width is hard-coded; assumes Errors fits in 5 bits.
        final_errors = Signal(5)

        @def_method(m, self.error_checking_request)
        def _(arg):
            # NaN: either flagged invalid by the previous stage, or an
            # all-ones exponent with any significand bit set below the MSB.
            is_nan = arg.invalid_operation | ((arg.exp == max_exp) & (((arg.sig & ~(implicit_bit)).any())))
            # Genuine infinity: division by zero, or infinity arithmetic
            # reported by the previous stage through input_inf.
            is_inf = arg.division_by_zero | arg.input_inf
            input_not_special = ~(is_nan) & ~(is_inf)
            # A non-special number with the all-ones exponent is treated as
            # an overflow (genuine infinities are excluded via input_inf).
            m.d.av_comb += overflow.eq(input_not_special & (arg.exp == max_exp))
            # Tiny result: zero exponent and the significand MSB clear.
            m.d.av_comb += tininess.eq((arg.exp == 0) & (~arg.sig[-1]))
            # Overflow always raises inexact; otherwise the incoming inexact
            # flag is kept only for non-special results.
            m.d.av_comb += inexact.eq(overflow | (input_not_special & arg.inexact))
            # Underflow is raised only when the result is both tiny and inexact.
            m.d.av_comb += underflow.eq(tininess & inexact)

            with m.If(is_nan | is_inf):

                # Special values are passed through unchanged.
                m.d.av_comb += final_exp.eq(arg.exp)
                m.d.av_comb += final_sig.eq(arg.sig)
                m.d.av_comb += final_sign.eq(arg.sign)

            with m.Elif(overflow):

                # Depending on the rounding mode and the sign, the result is
                # either infinity (max_exp, implicit_bit) or the largest finite
                # magnitude (max_normal_exp, max_sig); the sign is kept.
                # If rounding_mode matches none of the cases below, the final_*
                # signals are not assigned in this branch.
                with m.Switch(arg.rounding_mode):
                    with m.Case(RoundingModes.ROUND_NEAREST_AWAY, RoundingModes.ROUND_NEAREST_EVEN):

                        # Round-to-nearest modes always overflow to infinity.
                        m.d.av_comb += final_exp.eq(max_exp)
                        m.d.av_comb += final_sig.eq(implicit_bit)
                        m.d.av_comb += final_sign.eq(arg.sign)

                    with m.Case(RoundingModes.ROUND_ZERO):

                        # Rounding toward zero never reaches infinity.
                        m.d.av_comb += final_exp.eq(max_normal_exp)
                        m.d.av_comb += final_sig.eq(max_sig)
                        m.d.av_comb += final_sign.eq(arg.sign)

                    with m.Case(RoundingModes.ROUND_DOWN):

                        # Negative values go to -infinity; positive values
                        # saturate to the largest finite magnitude.
                        with m.If(arg.sign):

                            m.d.av_comb += final_exp.eq(max_exp)
                            m.d.av_comb += final_sig.eq(implicit_bit)
                            m.d.av_comb += final_sign.eq(arg.sign)

                        with m.Else():

                            m.d.av_comb += final_exp.eq(max_normal_exp)
                            m.d.av_comb += final_sig.eq(max_sig)
                            m.d.av_comb += final_sign.eq(arg.sign)

                    with m.Case(RoundingModes.ROUND_UP):

                        # Negative values saturate to the largest finite
                        # magnitude; positive values go to +infinity.
                        with m.If(arg.sign):

                            m.d.av_comb += final_exp.eq(max_normal_exp)
                            m.d.av_comb += final_sig.eq(max_sig)
                            m.d.av_comb += final_sign.eq(arg.sign)

                        with m.Else():

                            m.d.av_comb += final_exp.eq(max_exp)
                            m.d.av_comb += final_sig.eq(implicit_bit)
                            m.d.av_comb += final_sign.eq(arg.sign)

            with m.Else():
                # Zero exponent with the significand MSB set: the value is in
                # the normal range (presumably a subnormal that rounded up),
                # so its exponent becomes 1.
                with m.If((arg.exp == 0) & (arg.sig[-1] == 1)):
                    m.d.av_comb += final_exp.eq(1)
                with m.Else():
                    m.d.av_comb += final_exp.eq(arg.exp)
                m.d.av_comb += final_sig.eq(arg.sig)
                m.d.av_comb += final_sign.eq(arg.sign)

            # Combine the individual conditions into the Errors bit set.
            m.d.av_comb += final_errors.eq(
                Mux(arg.invalid_operation, Errors.INVALID_OPERATION, 0)
                | Mux(arg.division_by_zero, Errors.DIVISION_BY_ZERO, 0)
                | Mux(overflow, Errors.OVERFLOW, 0)
                | Mux(underflow, Errors.UNDERFLOW, 0)
                | Mux(inexact, Errors.INEXACT, 0)
            )

            return {
                "exp": final_exp,
                "sig": final_sig,
                "sign": final_sign,
                "errors": final_errors,
            }

        return m