/ coreblocks / func_blocks / fu / fpu / close_path.py
close_path.py
  1  from amaranth import *
  2  from transactron import TModule, Method, def_method, Transaction
  3  from transactron.utils.transactron_helpers import from_method_layout
  4  from coreblocks.func_blocks.fu.fpu.fpu_common import RoundingModes, FPUParams
  5  from coreblocks.func_blocks.fu.fpu.lza import LZAModule
  6  
  7  
  8  class ClosePathMethodLayout:
  9      """Close path module layouts for methods
 10  
 11      Parameters
 12      ----------
 13      fpu_params; FPUParams
 14          FPU parameters
 15      """
 16  
 17      def __init__(self, *, fpu_params: FPUParams):
 18          """
 19          r_sign - sign of the result
 20          sig_a - two's complement form of first significand
 21          sig_2 - two's complement form of second significand
 22          exp - exponent of result before shifts
 23          rounding_mode - rounding mode
 24          guard_bit - guard_bit (pth bit of second significand where p is precision)
 25          """
 26  
 27          self.close_path_in_layout = [
 28              ("r_sign", 1),
 29              ("sig_a", fpu_params.sig_width),
 30              ("sig_b", fpu_params.sig_width),
 31              ("exp", fpu_params.exp_width),
 32              ("rounding_mode", RoundingModes),
 33              ("guard_bit", 1),
 34          ]
 35          self.close_path_out_layout = [
 36              ("out_exp", fpu_params.exp_width),
 37              ("out_sig", fpu_params.sig_width),
 38              ("output_round", 1),
 39              ("output_sticky", 1),
 40          ]
 41  
 42  
class ClosePathModule(Elaboratable):
    """Close path module
    Based on http://i.stanford.edu/pub/cstr/reports/csl/tr/90/442/CSL-TR-90-442.pdf.
    This module computes results for efficient subtraction,
    whenever the difference of exponents is less than 2.
    Besides computing the result, this implementation also performs rounding at the same time
    as subtraction by using two adders (one computing a+b and the other one computing a+b+1).
    The correct output is chosen based on a flag that is computed differently for each rounding mode.
    The chosen sum is then shifted left according to the response of an LZA submodule and
    the exponent is adjusted accordingly.

    Parameters
    ----------
    fpu_params: FPUParams
        FPU parameters

    Attributes
    ----------
    close_path_request: Method
        Transactional method for initiating close path computation.
        Takes 'close_path_in_layout' as argument.
        Returns result as 'close_path_out_layout'.
    """

    def __init__(self, *, fpu_params: FPUParams):

        self.params = fpu_params
        self.method_layouts = ClosePathMethodLayout(fpu_params=self.params)
        self.close_path_request = Method(
            i=self.method_layouts.close_path_in_layout,
            o=self.method_layouts.close_path_out_layout,
        )

    def elaborate(self, platform):
        """Build the close path logic and define `close_path_request`.

        All assignments are made in the `av_comb` domain of the TModule;
        the method returns the computed exponent, significand and round bit,
        with `output_sticky` tied to constant 0.
        """
        m = TModule()

        # Outputs of the two adders: a+b and a+b+1.
        result_add_zero = Signal(self.params.sig_width)
        result_add_one = Signal(self.params.sig_width)
        # Sum selected by l_flag, before any shifting.
        final_result = Signal(self.params.sig_width)
        # Top bit of the shifted sum (with the guard bit merged in); decides
        # whether one more left shift is applied in the last branch below.
        shift_correction = Signal()
        shift_amount = Signal(range(self.params.sig_width))
        # Bit position at which the guard bit is inserted after shifting.
        bit_shift_amount = Signal(range(self.params.sig_width))
        check_shift_amount = Signal(range(self.params.sig_width))
        final_sig = Signal(self.params.sig_width)
        final_exp = Signal(self.params.exp_width)
        final_round = Signal()

        # Copy of guard_bit; the bit that is shifted into the significand.
        shift_in_bit = Signal()
        # When set, result_add_one (and the carry=1 LZA response) is used.
        l_flag = Signal()
        is_zero = Signal()

        # Two LZA instances: one is called with carry=0 and the other with carry=1,
        # matching the two adders. (LZA presumably stands for leading zero
        # anticipator - confirm against lza.py.)
        m.submodules.zero_lza = zero_lza = LZAModule(fpu_params=self.params)
        m.submodules.one_lza = one_lza = LZAModule(fpu_params=self.params)
        lza_resp = Signal(from_method_layout(zero_lza.method_layouts.predict_out_layout))

        @def_method(m, self.close_path_request)
        def _(
            r_sign,
            sig_a,
            sig_b,
            exp,
            rounding_mode,
            guard_bit,
        ):
            m.d.av_comb += result_add_zero.eq(sig_a + sig_b)
            m.d.av_comb += result_add_one.eq(sig_a + sig_b + 1)
            m.d.av_comb += shift_in_bit.eq(guard_bit)

            # Every rounding mode sets l_flag when guard_bit is 0 (the `| ~(guard_bit)` term).
            # With guard_bit set, the modes differ in the extra condition, which is built
            # from the top bit of a+b, the result sign and the lowest bit of a+b.
            # There is no default case: for any other rounding_mode value l_flag is not assigned here.
            with m.Switch(rounding_mode):
                with m.Case(RoundingModes.ROUND_UP):
                    m.d.av_comb += l_flag.eq((~(r_sign) & result_add_zero[-1] & guard_bit) | ~(guard_bit))
                with m.Case(RoundingModes.ROUND_DOWN):
                    m.d.av_comb += l_flag.eq((r_sign & result_add_zero[-1] & guard_bit) | ~(guard_bit))
                with m.Case(RoundingModes.ROUND_ZERO):
                    m.d.av_comb += l_flag.eq(~(guard_bit))
                with m.Case(RoundingModes.ROUND_NEAREST_EVEN):
                    m.d.av_comb += l_flag.eq((result_add_zero[-1] & guard_bit & result_add_zero[0]) | ~(guard_bit))
                with m.Case(RoundingModes.ROUND_NEAREST_AWAY):
                    m.d.av_comb += l_flag.eq((result_add_zero[-1] & guard_bit) | ~(guard_bit))

            # The LZA methods are called from a separate Transaction created inside the method body.
            # Both predict_request calls are made; the Mux only selects which response is used,
            # in step with the adder output selected for final_result.
            with Transaction().body(m):
                m.d.av_comb += final_result.eq(Mux(l_flag, result_add_one, result_add_zero))
                resp = Mux(
                    l_flag,
                    one_lza.predict_request(m, sig_a=sig_a, sig_b=sig_b, carry=1),
                    zero_lza.predict_request(m, sig_a=sig_a, sig_b=sig_b, carry=0),
                )
                m.d.av_comb += lza_resp.eq(resp)
                m.d.av_comb += is_zero.eq(lza_resp["is_zero"])

            # Case 1: the LZA reports zero or the input exponent is 0.
            # The selected sum is passed through unshifted with exponent 0.
            with m.If(is_zero | (exp == 0)):
                m.d.av_comb += final_sig.eq(final_result)
                m.d.av_comb += final_exp.eq(0)
                m.d.av_comb += final_round.eq(guard_bit)
            # Case 2: the exponent is not larger than the predicted shift, so the full
            # shift cannot be applied; the output exponent is 0
            # (presumably a subnormal result - confirm).
            with m.Elif(exp <= lza_resp["shift_amount"]):
                with m.If(exp == 1):
                    # No shift at all; the guard bit stays outside the significand as the round bit.
                    m.d.av_comb += final_sig.eq(final_result)
                    m.d.av_comb += final_round.eq(guard_bit)
                with m.Else():
                    # Shift by exp - 1 and place the guard bit just below the shifted-in position.
                    m.d.av_comb += shift_amount.eq(exp - 1)
                    m.d.av_comb += bit_shift_amount.eq(exp - 2)
                    m.d.av_comb += final_sig.eq((final_result << shift_amount) | (shift_in_bit << bit_shift_amount))
                    m.d.av_comb += final_round.eq(0)
                m.d.av_comb += final_exp.eq(0)
            # Case 3: the exponent is larger than the predicted shift; apply the full shift
            # and then check whether one additional shift is needed.
            with m.Else():
                shifted_sig = Signal(self.params.sig_width)
                shifted_exp = Signal(self.params.exp_width)

                m.d.av_comb += shifted_sig.eq(final_result << lza_resp["shift_amount"])
                m.d.av_comb += shifted_exp.eq(exp - lza_resp["shift_amount"])
                # NOTE(review): when shift_amount is 0 this subtraction goes below zero;
                # confirm the value stored in check_shift_amount cannot equal sig_width - 1,
                # otherwise guard_bit would be read as the top bit below.
                m.d.av_comb += check_shift_amount.eq(lza_resp["shift_amount"] - 1)
                # shift_correction is the top bit of the shifted sum with the guard bit merged in.
                m.d.av_comb += shift_correction.eq(
                    (shifted_sig | (guard_bit << check_shift_amount))[self.params.sig_width - 1]
                )

                with m.If(shift_correction):
                    # Top bit is already set: no further shift is applied.
                    with m.If(lza_resp["shift_amount"] == 0):
                        # Nothing was shifted, so the guard bit remains the round bit.
                        m.d.av_comb += final_sig.eq(shifted_sig)
                        m.d.av_comb += final_round.eq(guard_bit)
                    with m.Else():
                        # The guard bit is inserted at position shift_amount - 1 and no round bit is left.
                        m.d.av_comb += bit_shift_amount.eq(lza_resp["shift_amount"] - 1)
                        m.d.av_comb += final_sig.eq(shifted_sig | (shift_in_bit << bit_shift_amount))
                        m.d.av_comb += final_round.eq(0)
                    m.d.av_comb += final_exp.eq(shifted_exp)
                with m.Else():
                    # Top bit is still clear after the predicted shift
                    # (presumably the LZA prediction was one too small - confirm).
                    m.d.av_comb += final_round.eq(0)
                    with m.If(shifted_exp == 1):
                        # The extra shift is not applied here; the output exponent is set to 0 instead.
                        with m.If(lza_resp["shift_amount"] > 0):
                            m.d.av_comb += bit_shift_amount.eq(lza_resp["shift_amount"] - 1)
                            m.d.av_comb += final_sig.eq(shifted_sig | (shift_in_bit << bit_shift_amount))
                        with m.Else():
                            m.d.av_comb += final_sig.eq(shifted_sig)
                        m.d.av_comb += final_exp.eq(0)
                    with m.Else():
                        # Shift one more position, insert the guard bit at position shift_amount
                        # and decrement the exponent once more.
                        m.d.av_comb += final_sig.eq((shifted_sig << 1) | (shift_in_bit << (lza_resp["shift_amount"])))
                        m.d.av_comb += final_exp.eq(shifted_exp - 1)

            # The sticky output is always reported as 0 by this module.
            return {"out_exp": final_exp, "out_sig": final_sig, "output_round": final_round, "output_sticky": 0}

        return m