/ coreblocks / func_blocks / fu / fpu / fpu_mul.py
fpu_mul.py
  1  from amaranth import *
  2  from transactron import TModule, Method, def_method
  3  from transactron.utils.transactron_helpers import from_method_layout
  4  from coreblocks.func_blocks.fu.fpu.fpu_common import (
  5      RoundingModes,
  6      FPUParams,
  7      create_data_input_layout,
  8      create_output_layout,
  9      FPUCommonValues,
 10  )
 11  from coreblocks.func_blocks.fu.fpu.fpu_error_module import FPUErrorModule
 12  from coreblocks.func_blocks.fu.fpu.fpu_rounding_module import FPURounding
 13  from transactron.utils.amaranth_ext import count_leading_zeros
 14  from coreblocks.func_blocks.fu.unsigned_multiplication.fast_recursive import FastRecursiveMul
 15  
 16  
 17  class FPUMulMethodLayout:
 18      """FPU multiplication module method layout
 19  
 20      Parameters
 21      ----------
 22      fpu_params; FPUParams
 23          FPU parameters
 24      """
 25  
 26      def __init__(self, *, fpu_params: FPUParams):
 27          self.mul_in_layout = [
 28              ("op_1", create_data_input_layout(fpu_params)),
 29              ("op_2", create_data_input_layout(fpu_params)),
 30              ("rounding_mode", RoundingModes),
 31          ]
 32          """
 33          | Input layout for multiplication
 34          | op_1 - layout containing data of the first operand
 35          | op_2 - layout containing data of the second operand
 36          | rounding_mode - selected rounding mode
 37          | op_1 and op_2 are created using
 38            :meth:`create_data_input_layout <coreblocks.func_blocks.fu.fpu.fpu_common.create_data_input_layout>`
 39          """
 40          self.mul_out_layout = create_output_layout(fpu_params)
 41          """
 42          Output layout for multiplication. Created using
 43          :meth:`create_output_layout <coreblocks.func_blocks.fu.fpu.fpu_common.create_output_layout>`
 44          """
 45  
 46  
 47  class FPUMulModule(Elaboratable):
 48      """
 49      | FPU multiplication top module
 50      | The floating point multiplication consists of two parts:
 51      | 1. Exponent calcuation - turning both exponents from biased form into un-biased form
 52      | and then adding them together and turing result back into biased form
 53      | 2. Significand multiplication - This is essentialy fixed-point multiplication with
 54      | two bits for integer part and 2*n - 2 bits for fractional part.
 55      | We deal with with subnormal number by extending exponents range and turning subnormal
 56      | numbers into normalised numbers.
 57  
 58      Parameters
 59      ----------
 60      fpu_params: FPUParams
 61          FPU rounding module parameters
 62  
 63      Attributes
 64      ----------
 65      mul_request: Method
 66          Transactional method for initiating multiplication.
 67          Takes
 68          :meth:`mul_in_layout <coreblocks.func_blocks.fu.fpu.fpu_add_sub.FPUMulMethodLayout.add_sub_in_layout>`
 69          as argument.
 70          Returns result as
 71          :meth:`mul_out_layout <coreblocks.func_blocks.fu.fpu.fpu_add_sub.FPUMulMethodLayout.add_sub_out_layout>`.
 72      """
 73  
 74      def __init__(self, *, fpu_params: FPUParams):
 75          self.fpu_params = fpu_params
 76          self.method_layouts = FPUMulMethodLayout(fpu_params=self.fpu_params)
 77          self.common_values = FPUCommonValues(self.fpu_params)
 78          self.mul_request = Method(
 79              i=self.method_layouts.mul_in_layout,
 80              o=self.method_layouts.mul_out_layout,
 81          )
 82          self.mul_params = {"isa": {"xlen": self.fpu_params.sig_width}}
 83  
 84      def elaborate(self, platform):
 85          m = TModule()
 86  
 87          m.submodules.rounding_module = rounding_module = FPURounding(fpu_params=self.fpu_params)
 88          m.submodules.exception_module = exception_module = FPUErrorModule(fpu_params=self.fpu_params)
 89          m.submodules.multiplier = multiplier = FastRecursiveMul(
 90              self.fpu_params.sig_width, self.fpu_params.sig_width // 2
 91          )
 92  
 93          rounding_response = Signal(from_method_layout(rounding_module.method_layouts.rounding_out_layout))
 94          exception_response = Signal(from_method_layout(exception_module.method_layouts.error_out_layout))
 95  
 96          bias = self.common_values.bias
 97          min_real_exp = Const(1 - bias)
 98          max_real_exp = Const(self.common_values.max_exp - bias)
 99  
100          @def_method(m, self.mul_request)
101          def _(op_1, op_2, rounding_mode):
102  
103              final_sign = Signal()
104              m.d.av_comb += final_sign.eq(op_1.sign ^ op_2.sign)
105              op_1_subn = ~op_1.sig[-1]
106              op_2_subn = ~op_2.sig[-1]
107  
108              # exponent before potential operand normalization
109              pre_op_norm_exp = Signal(signed(self.fpu_params.exp_width + 1))
110              sum_of_exp = op_1.exp + op_2.exp - 2 * bias
111              # Because exp = 0 and exp = 1 represent the same exponent emin,
112              # to properly calcuate pre_op_norm_exp we have to adjust those exponents
113              # in case they are 0 and represent subnormal number
114              subn_correction = op_1_subn + op_2_subn
115              m.d.av_comb += pre_op_norm_exp.eq(sum_of_exp + subn_correction)
116  
117              # One of the ways to deal with subnormal values is to normalise them,
118              # record additional shifts in exponent and adjust for this during normalization
119              op_1_norm_shift = Signal(range(0, self.fpu_params.sig_width + 1))
120              op_2_norm_shift = Signal(range(0, self.fpu_params.sig_width + 1))
121              m.d.av_comb += op_1_norm_shift.eq(count_leading_zeros(op_1.sig))
122              m.d.av_comb += op_2_norm_shift.eq(count_leading_zeros(op_2.sig))
123  
124              norm_op_1_sig = Signal(self.fpu_params.sig_width)
125              norm_op_2_sig = Signal(self.fpu_params.sig_width)
126  
127              m.d.av_comb += norm_op_1_sig.eq(op_1.sig << op_1_norm_shift)
128              m.d.av_comb += norm_op_2_sig.eq(op_2.sig << op_2_norm_shift)
129  
130              shifted_out_bit = Signal()
131  
132              sig_product = Signal(2 * self.fpu_params.sig_width)
133              m.d.av_comb += multiplier.i1.eq(norm_op_1_sig)
134              m.d.av_comb += multiplier.i2.eq(norm_op_2_sig)
135              m.d.av_comb += sig_product.eq(multiplier.r)
136  
137              # First step of normalization
138              # if sig is between [1,2) leave it alone
139              # if sig is between [2,4) shift right by one
140              fixed_sig_product_norm = Signal((2 * self.fpu_params.sig_width) - 1)
141              m.d.av_comb += fixed_sig_product_norm.eq(sig_product)
142              with m.If(sig_product[-1]):
143                  m.d.av_comb += fixed_sig_product_norm.eq(sig_product >> 1)
144                  m.d.av_comb += shifted_out_bit.eq(sig_product[0])
145  
146              post_multiplication_exp = Signal(signed(self.fpu_params.exp_width + 1))
147              m.d.av_comb += post_multiplication_exp.eq(
148                  pre_op_norm_exp - (op_1_norm_shift + op_2_norm_shift) + sig_product[-1]
149              )
150  
151              sticky_bit = Signal()
152              round_bit = Signal()
153              mult_exp = Signal(self.fpu_params.exp_width)
154              mult_sig = Signal(self.fpu_params.sig_width + 1)
155              # One additional bit for round bit
156              normalised_ext_sig = Signal(self.fpu_params.sig_width + 1)
157              # The entire number consists of 2*(sig_width - 1) bits for fractional size
158              # and 2 bits for integer part so to turn this number into floating point number
159              # (1 bit for integer part and sig_width - 1 bits for fractional part)
160              # we have to shift number right by sig_width - 1 bits but because we
161              # want to keep one aditional bit for round bit we shift by sig_width - 2
162              m.d.av_comb += normalised_ext_sig.eq(fixed_sig_product_norm >> (self.fpu_params.sig_width - 2))
163  
164              any_shifted_out = Signal()
165              with m.If(post_multiplication_exp >= min_real_exp):
166                  m.d.av_comb += mult_exp.eq(
167                      Mux(
168                          post_multiplication_exp >= max_real_exp,
169                          self.common_values.max_exp,
170                          post_multiplication_exp + bias,
171                      )
172                  )
173                  m.d.av_comb += mult_sig.eq(normalised_ext_sig)
174                  lost_bits_or_red = fixed_sig_product_norm.bit_select(0, self.fpu_params.sig_width - 2).any()
175                  m.d.av_comb += any_shifted_out.eq(lost_bits_or_red)
176                  with m.If(mult_sig[-1] == 0):
177                      m.d.av_comb += mult_exp.eq(0)
178              with m.Elif(post_multiplication_exp < min_real_exp):
179                  # In this case value always will be subnormal
180                  m.d.av_comb += mult_exp.eq(0)
181                  shift_needed = Signal(unsigned(self.fpu_params.exp_width))
182                  m.d.av_comb += shift_needed.eq(min_real_exp - post_multiplication_exp)
183                  m.d.av_comb += mult_sig.eq(normalised_ext_sig >> shift_needed)
184                  with m.If(shift_needed > (self.fpu_params.sig_width)):
185                      m.d.av_comb += any_shifted_out.eq(fixed_sig_product_norm.any())
186                  with m.Else():
187                      # product has (2*p) - 1 bits, p ms bits represent the fp number
188                      # p+1 bit is round bit and p - 2 ls bits for initial sticky bit
189                      # For sticky bit we have to catch those p - 2 ls bits and shift_needed bits
190                      # from p + 1 ms bits
191                      padding = Signal().replicate(self.fpu_params.sig_width)
192                      shifted_out = Cat(padding, fixed_sig_product_norm).bit_select(
193                          shift_needed, 2 * self.fpu_params.sig_width - 2
194                      )
195                      m.d.av_comb += any_shifted_out.eq(shifted_out.any())
196              m.d.av_comb += sticky_bit.eq(any_shifted_out | shifted_out_bit)
197              m.d.av_comb += round_bit.eq(mult_sig[0])
198              resp = rounding_module.rounding_request(
199                  m,
200                  sign=final_sign,
201                  sig=mult_sig >> 1,
202                  exp=mult_exp,
203                  round_bit=round_bit,
204                  sticky_bit=sticky_bit,
205                  rounding_mode=rounding_mode,
206              )
207              m.d.av_comb += rounding_response.eq(resp)
208  
209              is_inf = Signal()
210              m.d.av_comb += is_inf.eq(op_1.is_inf | op_2.is_inf)
211              bad_inf = Signal()
212              m.d.av_comb += bad_inf.eq((op_1.is_inf & op_2.is_zero) | (op_2.is_inf & op_1.is_zero))
213              is_zero = Signal()
214              m.d.av_comb += is_zero.eq(op_1.is_zero | op_2.is_zero)
215              is_nan = Signal()
216              m.d.av_comb += is_nan.eq(op_1.is_nan | op_2.is_nan | bad_inf)
217  
218              exc_sig = Signal(self.fpu_params.sig_width)
219              exc_exp = Signal(self.fpu_params.exp_width)
220              exc_sign = Signal()
221              inexact = Signal()
222              invalid_operation = Signal()
223              with m.If(is_nan | is_inf | is_zero):
224                  m.d.av_comb += inexact.eq(0)
225                  with m.If(is_nan):
226                      is_any_snan = ((~op_1.sig[-2]) & op_1.is_nan) | ((~op_2.sig[-2]) & op_2.is_nan)
227                      m.d.av_comb += exc_sign.eq(0)
228                      m.d.av_comb += exc_exp.eq(self.common_values.max_exp)
229                      m.d.av_comb += exc_sig.eq(self.common_values.canonical_nan_sig)
230                      m.d.av_comb += invalid_operation.eq(is_any_snan | bad_inf)
231                  with m.Elif(is_inf & ~(bad_inf)):
232                      m.d.av_comb += exc_sign.eq(Mux(op_1.is_inf, op_1.sign, op_2.sign))
233                      m.d.av_comb += (exc_exp.eq(Mux(op_1.is_inf, op_1.exp, op_2.exp)),)
234                      m.d.av_comb += (exc_sig.eq(Mux(op_1.is_inf, op_1.sig, op_2.sig)),)
235                  with m.Elif(is_zero):
236                      m.d.av_comb += exc_sign.eq(final_sign)
237                      m.d.av_comb += exc_exp.eq(0)
238                      m.d.av_comb += exc_sig.eq(0)
239              with m.Else():
240                  m.d.av_comb += exc_sign.eq(final_sign)
241                  m.d.av_comb += exc_exp.eq(rounding_response["exp"])
242                  m.d.av_comb += inexact.eq(rounding_response["inexact"])
243                  with m.If(rounding_response["exp"] == self.common_values.max_exp):
244                      m.d.av_comb += exc_sig.eq(2 ** (self.fpu_params.sig_width - 1))
245                  with m.Else():
246                      m.d.av_comb += exc_sig.eq(rounding_response["sig"])
247  
248              resp = exception_module.error_checking_request(
249                  m,
250                  sign=exc_sign,
251                  sig=exc_sig,
252                  exp=exc_exp,
253                  rounding_mode=rounding_mode,
254                  inexact=inexact,
255                  invalid_operation=invalid_operation,
256                  division_by_zero=0,
257                  input_inf=is_inf,
258              )
259              m.d.av_comb += exception_response.eq(resp)
260  
261              return exception_response
262  
263          return m