/ coreblocks / func_blocks / fu / fpu / fpu_add_sub.py
fpu_add_sub.py
  1  from amaranth import *
  2  from transactron import TModule, Method, def_method
  3  from transactron.utils import assign
  4  from transactron.utils.transactron_helpers import from_method_layout
  5  from coreblocks.func_blocks.fu.fpu.fpu_common import (
  6      RoundingModes,
  7      FPUParams,
  8      create_data_input_layout,
  9      create_output_layout,
 10      create_raw_float_layout,
 11      FPUCommonValues,
 12  )
 13  from coreblocks.func_blocks.fu.fpu.far_path import FarPathModule
 14  from coreblocks.func_blocks.fu.fpu.close_path import ClosePathModule
 15  from coreblocks.func_blocks.fu.fpu.fpu_error_module import FPUErrorModule
 16  
 17  
 18  class FPUAddSubMethodLayout:
 19      """FPU addition/subtraction top module method layout
 20  
 21      Parameters
 22      ----------
 23      fpu_params; FPUParams
 24          FPU parameters
 25      """
 26  
 27      def __init__(self, *, fpu_params: FPUParams):
 28          self.add_sub_in_layout = [
 29              ("op_1", create_data_input_layout(fpu_params)),
 30              ("op_2", create_data_input_layout(fpu_params)),
 31              ("rounding_mode", RoundingModes),
 32              ("operation", 1),
 33          ]
 34          """
 35          | Input layout for addition/subtraction
 36          | op_1 - layout containing data of the first operand
 37          | op_2 - layout containing data of the second operand
 38          | rounding_mode - selected rounding mode
 39          | operation - selected operation; 1 - subtraction, 0 - addition
 40          | op_1 and op_2 are created using
 41            :meth:`create_data_input_layout <coreblocks.func_blocks.fu.fpu.fpu_common.create_data_input_layout>`
 42          """
 43          self.add_sub_out_layout = create_output_layout(fpu_params)
 44          """
 45          Output layout for addition/subtraction. Created using
 46          :meth:`create_output_layout <coreblocks.func_blocks.fu.fpu.fpu_common.create_output_layout>`
 47          """
 48          self.raw_float_layout = create_raw_float_layout(fpu_params)
 49          """
 50          Output layout for raw float. Created using
 51          :meth:`create_raw_float_layout <coreblocks.func_blocks.fu.fpu.fpu_common.create_raw_float_layout>`
 52          """
 53          ext_paramas = FPUParams(sig_width=fpu_params.sig_width + 2, exp_width=fpu_params.exp_width)
 54          self.ext_float_layout = create_raw_float_layout(ext_paramas)
 55          """
 56          Output layout for raw float with significand larger by two bits from selected format.
 57          Created using
 58          :meth:`create_raw_float_layout <coreblocks.func_blocks.fu.fpu.fpu_common.create_raw_float_layout>`
 59          """
 60  
 61  
 62  class FPUAddSubModule(Elaboratable):
 63      """
 64      | FPU addition/subtraction top module
 65      | This module implements addition and subtraction using
 66        two path approach with rounding prediction.
 67      | The module can be divided into two segments:
 68      | 1. Receiving data and preparing it for one of the two path submodules
 69        by calculating effective operation, swapping operands and aligning exponents.
 70      | 2. Receiving data from one of the path submodules and preparing it for error checking
 71        module by checking for various conditions
 72      | For more info about close path and far path check
 73        :class:`close path module <coreblocks.func_blocks.fu.fpu.close_path.ClosePathModule>`
 74         and
 75        :class:`far path module <coreblocks.func_blocks.fu.fpu.far_path.FarPathModule>`
 76  
 77      Parameters
 78      ----------
 79      fpu_params: FPUParams
 80          FPU rounding module parameters
 81  
 82      Attributes
 83      ----------
 84      add_sub_request: Method
 85          Transactional method for initiating addition or subtraction.
 86          Takes
 87          :meth:`add_sub_in_layout <coreblocks.func_blocks.fu.fpu.fpu_add_sub.FPUAddSubMethodLayout.add_sub_in_layout>`
 88          as argument.
 89          Returns result as
 90          :meth:`add_sub_out_layout <coreblocks.func_blocks.fu.fpu.fpu_add_sub.FPUAddSubMethodLayout.add_sub_out_layout>`.
 91      """
 92  
 93      def __init__(self, *, fpu_params: FPUParams):
 94          self.fpu_params = fpu_params
 95          self.method_layouts = FPUAddSubMethodLayout(fpu_params=self.fpu_params)
 96          self.common_values = FPUCommonValues(self.fpu_params)
 97          self.add_sub_request = Method(
 98              i=self.method_layouts.add_sub_in_layout,
 99              o=self.method_layouts.add_sub_out_layout,
100          )
101  
102      def elaborate(self, platform):
103          m = TModule()
104  
105          def assign_values(lhs, exp, sig, sign):
106              m.d.av_comb += assign(lhs, {"sign": sign, "exp": exp, "sig": sig})
107  
108          m.submodules.close_path_module = close_path_module = ClosePathModule(fpu_params=self.fpu_params)
109          m.submodules.far_path_module = far_path_module = FarPathModule(fpu_params=self.fpu_params)
110          m.submodules.exception_module = exception_module = FPUErrorModule(fpu_params=self.fpu_params)
111  
112          max_exp = (2 ** (self.fpu_params.exp_width)) - 1
113  
114          final_sign = Signal(1)
115          exp_diff = Signal(range(-max_exp, max_exp + 1))
116          norm_shift_amount = Signal(range(max_exp))
117          sticky_bit = Signal(1)
118          true_operation = Signal(1)
119          exception_round_bit = Signal(1)
120          exception_sticky_bit = Signal(1)
121          invalid_operation = Signal(1)
122  
123          path_response = Signal(from_method_layout(far_path_module.method_layouts.far_path_out_layout))
124          exception_response = Signal(from_method_layout(exception_module.method_layouts.error_out_layout))
125  
126          @def_method(m, self.add_sub_request)
127          def _(op_1, op_2, rounding_mode, operation):
128              op_2_adjusted_sign = Signal()
129              m.d.av_comb += op_2_adjusted_sign.eq(operation ^ op_2.sign)
130  
131              m.d.av_comb += exp_diff.eq(op_1.exp - op_2.exp)
132  
133              # Swapping operands to ensure that abs(pre_shift_op1) >= abs(pre_shift_op2)
134              pre_shift_op1 = Signal(from_method_layout(self.method_layouts.ext_float_layout))
135              pre_shift_op2 = Signal(from_method_layout(self.method_layouts.ext_float_layout))
136              op_1_abs_float = Cat(op_1.sig, op_1.exp)
137              op_2_abs_float = Cat(op_2.sig, op_2.exp)
138  
139              with m.If(op_1_abs_float > op_2_abs_float):
140                  assign_values(pre_shift_op1, op_1.exp, op_1.sig << 2, op_1.sign)
141                  assign_values(pre_shift_op2, op_2.exp, op_2.sig << 2, op_2_adjusted_sign)
142              with m.Else():
143                  assign_values(pre_shift_op1, op_2.exp, op_2.sig << 2, op_2_adjusted_sign)
144                  assign_values(pre_shift_op2, op_1.exp, op_1.sig << 2, op_1.sign)
145  
146              # Calculating true operation based on signs of operands
147              sign_xor = op_1.sign ^ op_2_adjusted_sign
148  
149              m.d.av_comb += final_sign.eq(pre_shift_op1.sign)
150  
151              with m.If(~sign_xor):
152                  m.d.av_comb += true_operation.eq(0)
153              with m.Else():
154                  m.d.av_comb += true_operation.eq(1)
155  
156              is_one_subnormal = (pre_shift_op1.exp > 0) & (pre_shift_op2.exp == 0)
157              m.d.av_comb += norm_shift_amount.eq(pre_shift_op1.exp - pre_shift_op2.exp - is_one_subnormal)
158  
159              # Aligning exponents and calculating GRS bits
160              path_op1 = Signal(from_method_layout(self.method_layouts.raw_float_layout))
161              far_path_op2_ext = Signal(from_method_layout(self.method_layouts.ext_float_layout))
162              far_path_op2 = Signal(from_method_layout(self.method_layouts.raw_float_layout))
163              close_path_op2 = Signal(from_method_layout(self.method_layouts.raw_float_layout))
164  
165              m.d.av_comb += path_op1.sig.eq(pre_shift_op1.sig)
166              with m.If(norm_shift_amount > (self.fpu_params.sig_width + 2)):
167                  m.d.av_comb += sticky_bit.eq(pre_shift_op2.sig.any())
168                  m.d.av_comb += far_path_op2_ext.sig.eq(0)
169              with m.Else():
170                  sticky_bit_mask = Cat(Signal().replicate(self.fpu_params.sig_width), pre_shift_op2.sig).bit_select(
171                      norm_shift_amount, self.fpu_params.sig_width
172                  )
173                  m.d.av_comb += sticky_bit.eq(sticky_bit_mask.any())
174                  m.d.av_comb += far_path_op2_ext.sig.eq(pre_shift_op2.sig >> norm_shift_amount)
175  
176              close_path_guard_bit = Signal()
177  
178              with m.If(norm_shift_amount[0] == 0):
179                  m.d.av_comb += close_path_op2.sig.eq(~(pre_shift_op2.sig >> 2))
180                  m.d.av_comb += close_path_guard_bit.eq(0)
181              with m.Else():
182                  m.d.av_comb += close_path_op2.sig.eq(~(pre_shift_op2.sig >> 3))
183                  m.d.av_comb += close_path_guard_bit.eq(pre_shift_op2.sig[2])
184  
185              guard_bit = far_path_op2_ext.sig[1]
186              round_bit = far_path_op2_ext.sig[0]
187  
188              # Assigning operands for close path and far path
189              assign_values(path_op1, pre_shift_op1.exp, pre_shift_op1.sig >> 2, pre_shift_op1.sign)
190              assign_values(
191                  far_path_op2,
192                  (pre_shift_op2.exp + norm_shift_amount),
193                  Mux(true_operation, ~(far_path_op2_ext.sig >> 2), far_path_op2_ext.sig >> 2),
194                  pre_shift_op2.sign,
195              )
196  
197              close_path = (norm_shift_amount <= 1) & true_operation
198              resp = Mux(
199                  close_path,
200                  close_path_module.close_path_request(
201                      m,
202                      r_sign=final_sign,
203                      sig_a=path_op1.sig,
204                      sig_b=close_path_op2.sig,
205                      exp=path_op1.exp,
206                      rounding_mode=rounding_mode,
207                      guard_bit=close_path_guard_bit,
208                  ),
209                  far_path_module.far_path_request(
210                      m,
211                      r_sign=final_sign,
212                      sig_a=path_op1.sig,
213                      sig_b=far_path_op2.sig,
214                      exp=path_op1.exp,
215                      sub_op=true_operation,
216                      rounding_mode=rounding_mode,
217                      guard_bit=guard_bit,
218                      round_bit=round_bit,
219                      sticky_bit=sticky_bit,
220                  ),
221              )
222  
223              # Preparing data for error checking module
224              m.d.av_comb += path_response.eq(resp)
225              eq_signs = pre_shift_op2.sign == path_op1.sign
226              is_inf = op_1.is_inf | op_2.is_inf
227              wrong_inf = (op_1.is_inf & op_2.is_inf) & ~(eq_signs)
228              is_nan = (op_1.is_nan | op_2.is_nan) | wrong_inf
229              output_zero = (path_response["out_exp"] == 0) & (path_response["out_sig"] == 0)
230              output_exact = ~(path_response["output_round"] | path_response["output_sticky"])
231              both_op_zero = op_1.is_zero & op_2.is_zero
232              is_zero = Signal()
233              m.d.av_comb += is_zero.eq(both_op_zero | (output_exact & output_zero))
234              normal_case = ~(is_nan | is_inf | is_zero)
235              exception_op = Signal(from_method_layout(self.method_layouts.raw_float_layout))
236  
237              with m.If(~normal_case):
238                  m.d.av_comb += exception_round_bit.eq(0)
239                  m.d.av_comb += exception_sticky_bit.eq(0)
240              with m.If(is_nan):
241                  is_any_snan = ((~op_1.sig[-2]) & op_1.is_nan) | ((~op_2.sig[-2]) & op_2.is_nan)
242                  with m.If(is_any_snan | wrong_inf):
243                      m.d.av_comb += invalid_operation.eq(1)
244                  m.d.av_comb += exception_op.sign.eq(0)
245                  m.d.av_comb += exception_op.exp.eq(max_exp)
246                  m.d.av_comb += exception_op.sig.eq(self.common_values.canonical_nan_sig)
247              with m.Elif(is_inf & ~(wrong_inf)):
248                  assign_values(
249                      exception_op,
250                      Mux(op_1.is_inf, op_1.exp, op_2.exp),
251                      Mux(op_1.is_inf, op_1.sig, op_2.sig),
252                      Mux(op_1.is_inf, op_1.sign, pre_shift_op2.sign),
253                  )
254              with m.Elif(is_zero):
255                  with m.If(eq_signs):
256                      m.d.av_comb += exception_op.sign.eq(op_1.sign)
257                  with m.Else():
258                      m.d.av_comb += exception_op.sign.eq(rounding_mode == RoundingModes.ROUND_DOWN)
259                  m.d.av_comb += exception_op.exp.eq(0)
260                  m.d.av_comb += exception_op.sig.eq(0)
261              with m.Elif(normal_case):
262                  m.d.av_comb += exception_op.sign.eq(final_sign)
263                  m.d.av_comb += exception_op.exp.eq(path_response["out_exp"])
264                  m.d.av_comb += exception_round_bit.eq(path_response["output_round"])
265                  m.d.av_comb += exception_sticky_bit.eq(path_response["output_sticky"])
266                  with m.If(path_response["out_exp"] == max_exp):
267                      m.d.av_comb += exception_op.sig.eq(2 ** (self.fpu_params.sig_width - 1))
268                  with m.Else():
269                      m.d.av_comb += exception_op.sig.eq(path_response["out_sig"])
270  
271              inexact = exception_sticky_bit | exception_round_bit
272              resp = exception_module.error_checking_request(
273                  m,
274                  sign=exception_op.sign,
275                  sig=exception_op.sig,
276                  exp=exception_op.exp,
277                  rounding_mode=rounding_mode,
278                  inexact=inexact,
279                  invalid_operation=invalid_operation,
280                  division_by_zero=0,
281                  input_inf=is_inf,
282              )
283              m.d.av_comb += exception_response.eq(resp)
284  
285              return exception_response
286  
287          return m