/ coreblocks / func_blocks / fu / fpu / lza.py
lza.py
  1  from amaranth import *
  2  from transactron import TModule, Method, def_method
  3  from coreblocks.func_blocks.fu.fpu.fpu_common import FPUParams
  4  from transactron.utils.amaranth_ext import count_leading_zeros
  5  
  6  
  7  class LZAMethodLayout:
  8      """LZA module layouts for methods
  9  
 10      Parameters
 11      ----------
 12      fpu_params: FPUParams
 13          FPU parameters
 14      """
 15  
 16      def __init__(self, *, fpu_params: FPUParams):
 17          """
 18          sig_a - significand of a
 19          sig_b - significand of b
 20          carry - indicates if we want to predict result of a+b or a+b+1
 21          shift_amount - position to shift needed to normalize number
 22          is_zero - indicates if result is zero
 23          """
 24          self.predict_in_layout = [
 25              ("sig_a", fpu_params.sig_width),
 26              ("sig_b", fpu_params.sig_width),
 27              ("carry", 1),
 28          ]
 29          self.predict_out_layout = [
 30              ("shift_amount", range(fpu_params.sig_width)),
 31              ("is_zero", 1),
 32          ]
 33  
 34  
 35  class LZAModule(Elaboratable):
 36      """LZA module
 37      Based on: https://userpages.cs.umbc.edu/phatak/645/supl/lza/lza-survey-arith01.pdf
 38      After performing subtracion, we may have to normalize floating point numbers and
 39      For that, we have to know the number of leading zeros.
 40      The most basic approach includes using LZC (leading zero counter) after subtracion,
 41      a more advanced approach includes using LZA (Leading Zero Anticipator) to predict the number of
 42      leading zeroes. It is worth noting that this LZA module works under assumptions that
 43      significands are in two's complement and that before complementation sig_a was greater
 44      or equal to sig_b. Another thing worth noting is that LZA works with error = 1.
 45      That means that if 'n' is the result of the LZA module, in reality, to normalize
 46      number we may have to shift left by 'n' or 'n+1'. There are few techniques of
 47      dealing with that error like specially designed shifters or predicting the error
 48      but the most basic approach is to just use multiplexer after shifter to perform
 49      one more shift left if necessary.
 50  
 51      Parameters
 52      ----------
 53      fpu_params: FPUParams
 54          FPU rounding module parameters
 55  
 56      Attributes
 57      ----------
 58      predict_request: Method
 59          Transactional method for initiating leading zeros prediction.
 60          Takes 'predict_in_layout' as argument
 61          Returns shift amount as 'predict_out_layout'
 62      """
 63  
 64      def __init__(self, *, fpu_params: FPUParams):
 65  
 66          self.lza_params = fpu_params
 67          self.method_layouts = LZAMethodLayout(fpu_params=self.lza_params)
 68          self.predict_request = Method(
 69              i=self.method_layouts.predict_in_layout,
 70              o=self.method_layouts.predict_out_layout,
 71          )
 72  
 73      def elaborate(self, platform):
 74          m = TModule()
 75  
 76          @def_method(m, self.predict_request)
 77          def _(sig_a, sig_b, carry):
 78  
 79              t = Signal(self.lza_params.sig_width + 1)
 80              g = Signal(self.lza_params.sig_width + 1)
 81              z = Signal(self.lza_params.sig_width + 1)
 82              f = Signal(self.lza_params.sig_width)
 83              shift_amount = Signal(range(self.lza_params.sig_width))
 84              is_zero = Signal(1)
 85  
 86              m.d.av_comb += t.eq((sig_a ^ sig_b) << 1)
 87              m.d.av_comb += g.eq((sig_a & sig_b) << 1)
 88              m.d.av_comb += z.eq(((sig_a | sig_b) << 1))
 89              with m.If(carry):
 90                  m.d.av_comb += g[0].eq(1)
 91                  m.d.av_comb += z[0].eq(1)
 92  
 93              for i in reversed(range(1, self.lza_params.sig_width + 1)):
 94                  m.d.av_comb += f[i - 1].eq((t[i] ^ z[i - 1]))
 95  
 96              m.d.av_comb += shift_amount.eq(count_leading_zeros(f))
 97  
 98              m.d.av_comb += is_zero.eq((carry & t[1 : self.lza_params.sig_width].all()))
 99  
100              return {
101                  "shift_amount": shift_amount,
102                  "is_zero": is_zero,
103              }
104  
105          return m