/ coreblocks / func_blocks / fu / fpu / far_path.py
far_path.py
  1  from amaranth import *
  2  from transactron import TModule, Method, def_method
  3  from coreblocks.func_blocks.fu.fpu.fpu_common import RoundingModes, FPUParams
  4  
  5  
  6  class FarPathMethodLayout:
  7      """Far path module layouts for methods
  8  
  9      Parameters
 10      ----------
 11      fpu_params; FPUParams
 12          FPU parameters
 13      """
 14  
 15      def __init__(self, *, fpu_params: FPUParams):
 16          """
 17          r_sign - result sign
 18          sig_a - significand of first operand (for effective subtraction in two's complement form)
 19          sig_b - significand of second operand (for effective subtraction in two's complement form)
 20          exp - exponent of result before shift
 21          sub_op - effective operation. 1 for subtraction 0 for addition
 22          rounding_mode - rounding mode
 23          guard_bit - guard bit (pth bit of second significand where p is precision)
 24          round_bit - round bit ((p+1)th bit of second significand where p is precision)
 25          sticky_bit - sticky_bit
 26          (OR of all bits with index >=p of second significand where p is precision)
 27          """
 28          self.far_path_in_layout = [
 29              ("r_sign", 1),
 30              ("sig_a", fpu_params.sig_width),
 31              ("sig_b", fpu_params.sig_width),
 32              ("exp", fpu_params.exp_width),
 33              ("sub_op", 1),
 34              ("rounding_mode", RoundingModes),
 35              ("guard_bit", 1),
 36              ("round_bit", 1),
 37              ("sticky_bit", 1),
 38          ]
 39          self.far_path_out_layout = [
 40              ("out_exp", fpu_params.exp_width),
 41              ("out_sig", fpu_params.sig_width),
 42              ("output_round", 1),
 43              ("output_sticky", 1),
 44          ]
 45  
 46  
class FarPathModule(Elaboratable):
    """Far Path module

    Based on: https://userpages.cs.umbc.edu/phatak/645/supl/lza/lza-survey-arith01.pdf.
    This module implements the far path of a floating point adder/subtractor.
    It performs subtraction for operands whose exponents differ by more than 1 and addition
    for all combinations of operands. Besides addition it also performs rounding at the same
    time as the addition by using two adders (one producing a+b and a second one producing
    a+b+1). The correct output is chosen by flags that differ for each rounding mode. To deal
    with certain complications that may arise during addition in certain rounding modes, the
    inputs of the second adder may be either the input operands or (a & b)<<1 and (a^b).
    This allows the second adder to compute a+b+2 in special cases that are better explained
    in the paper linked above.

    Parameters
    ----------
    fpu_params: FPUParams
        FPU parameters (``sig_width`` and ``exp_width`` are used to size the datapath).

    Attributes
    ----------
    far_path_request: Method
        Transactional method for initiating far path computation.
        Takes 'far_path_in_layout' as argument.
        Returns result as 'far_path_out_layout'.
    """

    def __init__(self, *, fpu_params: FPUParams):

        self.params = fpu_params
        self.method_layouts = FarPathMethodLayout(fpu_params=self.params)
        self.far_path_request = Method(
            i=self.method_layouts.far_path_in_layout,
            o=self.method_layouts.far_path_out_layout,
        )

    def elaborate(self, platform):
        # Builds the far path datapath and defines the body of `far_path_request`.
        m = TModule()

        # Operands of the two adders: adder 0 always gets the raw significands,
        # adder 1 gets either the raw significands or their carry/xor decomposition.
        input_sig_add_0_a = Signal(self.params.sig_width)
        input_sig_add_0_b = Signal(self.params.sig_width)
        input_sig_add_1_a = Signal(self.params.sig_width)
        input_sig_add_1_b = Signal(self.params.sig_width)
        # Adder results are one bit wider than the operands so the carry out is kept.
        output_sig_add_0 = Signal(self.params.sig_width + 1)
        output_sig_add_1 = Signal(self.params.sig_width + 1)
        # Adder result selected by the rounding logic, before normalization.
        output_sig = Signal(self.params.sig_width + 1)
        output_exp = Signal(self.params.exp_width + 1)
        # Normalized results returned by the method.
        output_final_exp = Signal(self.params.exp_width)
        output_final_sig = Signal(self.params.sig_width)

        # Round/sticky bits returned by the method.
        output_round_bit = Signal()
        output_sticky_bit = Signal()
        # Guard/round/sticky bits after the subtraction adjustment (see below).
        final_guard_bit = Signal()
        final_round_bit = Signal()
        final_sticky_bit = Signal()

        # Set the LSB of a+b instead of selecting the second adder (ROUND_UP / ROUND_DOWN only).
        round_up_inc_1 = Signal()
        round_down_inc_1 = Signal()
        # Effective addition in ROUND_UP or ROUND_DOWN: the second adder is fed with
        # the carry/xor decomposition of the operands.
        round_to_inf_special_case = Signal()
        xor_sig = Signal(self.params.sig_width)
        carry_sig = Signal(self.params.sig_width)
        # MSB of carry_sig, which is lost when (carry_sig << 1) is truncated to sig_width bits.
        carry_add1 = Signal()
        # OR / AND of the incoming guard, round and sticky bits.
        rgs_any = Signal()
        rgs_all = Signal()

        # No right shift
        nrs = Signal()
        # One right shift
        ors = Signal()
        # No left shift
        nls = Signal()
        # One left shift
        ols = Signal()
        # No shift in either direction (nls | nrs)
        nxs = Signal()

        # Per rounding mode conditions for selecting the second adder when no shift is needed.
        nxs_rtne = Signal()
        nxs_rtna = Signal()
        nxs_zero = Signal()
        nxs_up = Signal()
        nxs_down = Signal()

        # Per rounding mode conditions for selecting the second adder when one right shift is needed.
        ors_rtne = Signal()
        ors_rtna = Signal()
        ors_zero = Signal()
        ors_up = Signal()
        ors_down = Signal()

        # Per rounding mode conditions for selecting the second adder when one left shift is needed.
        ols_rtne = Signal()
        ols_rtna = Signal()
        ols_zero = Signal()
        ols_up = Signal()
        ols_down = Signal()

        # Per rounding mode value of the bit shifted into the LSB on a left shift.
        shift_in_bit_rtne = Signal()
        shift_in_bit_rtna = Signal()
        shift_in_bit_zero = Signal()
        shift_in_bit_up = Signal()
        shift_in_bit_down = Signal()
        shift_in_bit = Signal()

        # Final adder select: 1 chooses output_sig_add_1, 0 chooses output_sig_add_0.
        g = Signal()

        @def_method(m, self.far_path_request)
        def _(
            r_sign,
            sig_a,
            sig_b,
            exp,
            sub_op,
            rounding_mode,
            guard_bit,
            round_bit,
            sticky_bit,
        ):
            m.d.av_comb += input_sig_add_0_a.eq(sig_a)
            m.d.av_comb += input_sig_add_0_b.eq(sig_b)
            # Carry/xor decomposition: sig_a + sig_b == xor_sig + (carry_sig << 1).
            m.d.av_comb += xor_sig.eq(sig_a ^ sig_b)
            m.d.av_comb += carry_sig.eq(sig_a & sig_b)
            # Default value; overridden with 0 in the Else branch below
            # (in Amaranth the assignment added last takes precedence).
            m.d.av_comb += carry_add1.eq(carry_sig[-1])
            m.d.av_comb += rgs_any.eq(guard_bit | round_bit | sticky_bit)
            m.d.av_comb += rgs_all.eq(guard_bit & round_bit & sticky_bit)
            m.d.av_comb += round_to_inf_special_case.eq(
                (~sub_op) & ((rounding_mode == RoundingModes.ROUND_DOWN) | (rounding_mode == RoundingModes.ROUND_UP))
            )

            with m.If(round_to_inf_special_case):
                # The LSB of (carry_sig << 1) is always 0, so OR-ing in ~xor_sig[0] adds an
                # extra 1 exactly when the LSB of the xor is 0. Together with the "+ 1" of
                # the second adder this yields a+b+2 in that case and a+b+1 otherwise.
                m.d.av_comb += input_sig_add_1_a.eq((carry_sig << 1) | (~xor_sig[0]))
                m.d.av_comb += input_sig_add_1_b.eq(xor_sig)
            with m.Else():
                m.d.av_comb += input_sig_add_1_a.eq(sig_a)
                m.d.av_comb += input_sig_add_1_b.eq(sig_b)
                # Nothing was truncated in this branch, so there is no lost carry to restore.
                m.d.av_comb += carry_add1.eq(0)

            m.d.av_comb += output_sig_add_0.eq(input_sig_add_0_a + input_sig_add_0_b)
            # NOTE(review): the bit dropped from (carry_sig << 1) has weight 2**sig_width,
            # but carry_add1 is OR-ed in at bit (sig_width - 1) - confirm this position
            # is intended.
            m.d.av_comb += output_sig_add_1.eq(
                (input_sig_add_1_a + input_sig_add_1_b + 1) | (carry_add1 << (self.params.sig_width - 1))
            )

            # Shift classification. For addition it depends on the carry out of a+b.
            # For subtraction it depends on bit [-2] of the adder result, taken from
            # a+b when any of guard/round/sticky is set and from a+b+1 otherwise.
            m.d.av_comb += nrs.eq((~sub_op) & (~output_sig_add_0[-1]))
            m.d.av_comb += ors.eq((~sub_op) & (output_sig_add_0[-1]))
            m.d.av_comb += nls.eq(sub_op & (((~rgs_any) & output_sig_add_1[-2]) | (rgs_any & output_sig_add_0[-2])))
            m.d.av_comb += ols.eq(
                sub_op & (((~rgs_any) & (~output_sig_add_1[-2])) | (rgs_any & (~output_sig_add_0[-2])))
            )
            m.d.av_comb += nxs.eq(nls | nrs)

            # --- No shift: second adder select for each rounding mode ---
            subtraction = sub_op & ((~r_sign) | (~rgs_any))
            addition = (~sub_op) & ((sig_a[0] ^ sig_b[0]) & ((~r_sign) & (rgs_any)))
            m.d.av_comb += nxs_up.eq(subtraction | addition)

            subtraction = sub_op & (r_sign | (~rgs_any))
            addition = (~sub_op) & ((sig_a[0] ^ sig_b[0]) & (r_sign & (rgs_any)))
            m.d.av_comb += nxs_down.eq(subtraction | addition)

            m.d.av_comb += nxs_zero.eq(sub_op & (~rgs_any))

            subtraction = sub_op & ((~guard_bit) | (guard_bit & (~round_bit) & (~sticky_bit) & (sig_a[0] ^ sig_b[0])))
            addition = (~sub_op) & guard_bit & (round_bit | sticky_bit | (sig_a[0] ^ sig_b[0]))
            m.d.av_comb += nxs_rtne.eq(subtraction | addition)

            subtraction = sub_op & (((~guard_bit) ^ ((~round_bit) & (~sticky_bit))) | (~rgs_any))
            addition = (~sub_op) & guard_bit
            m.d.av_comb += nxs_rtna.eq(subtraction | addition)

            # --- One right shift: second adder select for each rounding mode ---
            # (sig_a[0] ^ sig_b[0]) is the LSB of a+b, i.e. the bit shifted out on a right shift.
            m.d.av_comb += ors_up.eq((~r_sign) & ((sig_a[0] ^ sig_b[0]) | rgs_any))
            m.d.av_comb += ors_down.eq(r_sign & ((sig_a[0] ^ sig_b[0]) | rgs_any))
            # ors is only set for addition, so this term never contributes to g.
            m.d.av_comb += ors_zero.eq(sub_op & (~rgs_any))
            m.d.av_comb += ors_rtne.eq((sig_a[0] ^ sig_b[0]) & (rgs_any | (sig_a[1] ^ sig_b[1])))
            m.d.av_comb += ors_rtna.eq(sig_a[0] ^ sig_b[0])

            # --- One left shift: second adder select for each rounding mode ---
            m.d.av_comb += ols_up.eq(((~r_sign) & (~guard_bit)) | (r_sign & (~rgs_any)))
            m.d.av_comb += ols_down.eq((r_sign & (~guard_bit)) | ((~r_sign) & (~rgs_any)))
            m.d.av_comb += ols_zero.eq(sub_op & (~rgs_any))
            m.d.av_comb += ols_rtne.eq((~guard_bit) & ((~round_bit) | (~sticky_bit)))
            m.d.av_comb += ols_rtna.eq((~guard_bit) & ((~round_bit) | (~sticky_bit)))
            # --- Bit shifted into the LSB on a left shift, for each rounding mode ---
            m.d.av_comb += shift_in_bit_up.eq(
                ((~r_sign) & guard_bit)
                | (r_sign & ((guard_bit & (~round_bit) & (~sticky_bit)) | ((~guard_bit) & (round_bit | sticky_bit))))
            )
            m.d.av_comb += shift_in_bit_down.eq(
                (r_sign & guard_bit)
                | ((~r_sign) & ((guard_bit & (~round_bit) & (~sticky_bit)) | ((~guard_bit) & (round_bit | sticky_bit))))
            )
            m.d.av_comb += shift_in_bit_zero.eq(
                ((~guard_bit) & (round_bit | sticky_bit)) | (guard_bit & (~round_bit) & (~sticky_bit))
            )
            m.d.av_comb += shift_in_bit_rtne.eq(((~guard_bit) & round_bit & sticky_bit) | (guard_bit & (~round_bit)))
            m.d.av_comb += shift_in_bit_rtna.eq(
                ((~guard_bit) & round_bit & sticky_bit) | (guard_bit & (~(round_bit & sticky_bit)))
            )

            # Combine the flags of the active rounding mode: in each shift case the
            # matching per-mode flag decides whether the second adder is selected.
            # There is no default case, so g and shift_in_bit keep their reset value
            # for any other rounding_mode encoding.
            with m.Switch(rounding_mode):
                with m.Case(RoundingModes.ROUND_UP):
                    m.d.av_comb += g.eq((ors & ors_up) | (nxs & nxs_up) | (ols & ols_up))
                    m.d.av_comb += shift_in_bit.eq(shift_in_bit_up)
                with m.Case(RoundingModes.ROUND_DOWN):
                    m.d.av_comb += g.eq((ors & ors_down) | (nxs & nxs_down) | (ols & ols_down))
                    m.d.av_comb += shift_in_bit.eq(shift_in_bit_down)
                with m.Case(RoundingModes.ROUND_ZERO):
                    m.d.av_comb += g.eq((ors & ors_zero) | (nxs & nxs_zero) | (ols & ols_zero))
                    m.d.av_comb += shift_in_bit.eq(shift_in_bit_zero)

                with m.Case(RoundingModes.ROUND_NEAREST_EVEN):
                    m.d.av_comb += g.eq((ors & ors_rtne) | (nxs & nxs_rtne) | (ols & ols_rtne))
                    m.d.av_comb += shift_in_bit.eq(shift_in_bit_rtne)

                with m.Case(RoundingModes.ROUND_NEAREST_AWAY):
                    m.d.av_comb += g.eq((ors & ors_rtna) | (nxs & nxs_rtna) | (ols & ols_rtna))
                    m.d.av_comb += shift_in_bit.eq(shift_in_bit_rtna)

            # Addition without carry out where the second adder was not selected, the LSB
            # of a+b is 0 and some of guard/round/sticky is set: rounding towards the
            # infinity matching the result sign only needs the LSB of a+b to be set.
            m.d.av_comb += round_up_inc_1.eq(
                (rounding_mode == RoundingModes.ROUND_UP)
                & nrs
                & (~g)
                & (~(sig_a[0] ^ sig_b[0]))
                & ((~r_sign) & (rgs_any))
            )
            m.d.av_comb += round_down_inc_1.eq(
                (rounding_mode == RoundingModes.ROUND_DOWN)
                & nrs
                & (~g)
                & (~(sig_a[0] ^ sig_b[0]))
                & (r_sign & (rgs_any))
            )
            # Select the significand: second adder, a+b with forced LSB, or plain a+b.
            with m.If(g):
                m.d.av_comb += output_sig.eq(output_sig_add_1)
            with m.Else():
                with m.If(round_down_inc_1 | round_up_inc_1):
                    m.d.av_comb += output_sig.eq(output_sig_add_0 | 1)
                with m.Else():
                    m.d.av_comb += output_sig.eq(output_sig_add_0)
            m.d.av_comb += output_exp.eq(exp)

            with m.If(sub_op):
                # These three expressions equal the 3-bit two's complement negation
                # of {guard_bit, round_bit, sticky_bit}.
                m.d.av_comb += final_guard_bit.eq((~guard_bit) ^ ((~round_bit) & (~sticky_bit)))
                m.d.av_comb += final_round_bit.eq((~round_bit) ^ (~sticky_bit))
                m.d.av_comb += final_sticky_bit.eq(sticky_bit)

            with m.Else():
                m.d.av_comb += final_guard_bit.eq(guard_bit)
                m.d.av_comb += final_round_bit.eq(round_bit)
                m.d.av_comb += final_sticky_bit.eq(sticky_bit)

            # Returned round/sticky bits, aligned with the shift applied to the significand.
            with m.If(ors):
                # Right shift: the LSB of a+b becomes the round bit, everything below is sticky.
                m.d.av_comb += output_sticky_bit.eq(final_guard_bit | final_round_bit | final_sticky_bit)
                m.d.av_comb += output_round_bit.eq(sig_a[0] ^ sig_b[0])
            with m.Elif(ols):
                # Left shift: the former round bit becomes the round bit.
                m.d.av_comb += output_sticky_bit.eq(final_sticky_bit)
                m.d.av_comb += output_round_bit.eq(final_round_bit)
            with m.Else():
                # No shift: the guard bit acts as the round bit.
                m.d.av_comb += output_sticky_bit.eq(final_round_bit | final_sticky_bit)
                m.d.av_comb += output_round_bit.eq(final_guard_bit)

            # Normalization of the selected significand and exponent adjustment.
            with m.If((~sub_op) & (output_sig[-1])):
                # Addition produced a carry out: shift right by one, increment the exponent.
                m.d.av_comb += output_final_sig.eq(output_sig >> 1)
                m.d.av_comb += output_final_exp.eq(output_exp + 1)

            with m.Elif((sub_op & (~output_sig[-2])) & (output_exp > 0)):
                # Subtraction cleared bit [-2] of the result: decrement the exponent.
                # The significand is left unshifted when the exponent is 1
                # (presumably because the result becomes subnormal - TODO confirm).
                with m.If(output_exp == 1):
                    m.d.av_comb += output_final_sig.eq(output_sig)
                with m.Else():
                    m.d.av_comb += output_final_sig.eq((output_sig << 1) | shift_in_bit)
                m.d.av_comb += output_final_exp.eq(output_exp - 1)

            with m.Else():
                m.d.av_comb += output_final_sig.eq(output_sig)
                # Exponent 0 with bit [-2] of the result set: the exponent is forced to 1
                # (presumably a subnormal input producing a normal result - TODO confirm).
                with m.If((output_exp == 0) & ((output_sig[-2]))):
                    m.d.av_comb += output_final_exp.eq(1)
                with m.Else():
                    m.d.av_comb += output_final_exp.eq(output_exp)

            return {
                "out_exp": output_final_exp,
                "out_sig": output_final_sig,
                "output_round": output_round_bit,
                "output_sticky": output_sticky_bit,
            }

        return m