fast_recursive.py
  1  from amaranth import *
  2  
  3  from coreblocks.func_blocks.fu.unsigned_multiplication.common import MulBaseUnsigned, DSPMulUnit
  4  from coreblocks.params import GenParams
  5  from transactron import *
  6  from transactron.core import def_method
  7  
# Public names for `from ... import *`: only the functional-unit wrapper is listed;
# the `FastRecursiveMul` helper defined in this module is not exported.
__all__ = ["RecursiveUnsignedMul"]
  9  
 10  from transactron.lib import FIFO
 11  
 12  
 13  class FastRecursiveMul(Elaboratable):
 14      """
 15      Module with combinatorial connection for fast recursive multiplication using as many DSPMulUnit as required for
 16      one clock multiplication.
 17  
 18      Attributes
 19      ----------
 20      i1: Signal(unsigned(n)), in
 21          First factor.
 22      i2: Signal(unsigned(n)), in
 23          Second factor.
 24      r: Signal(unsigned(n * 2)), out
 25          Product of inputted factors.
 26      """
 27  
 28      def __init__(self, n: int, dsp_width: int):
 29          """
 30          Parameters
 31          ----------
 32          n: int
 33              Bit width of multiplied numbers.
 34          dsp_width: int
 35              Bit width of number multiplied bu dsp unit.
 36          """
 37          self.n = n
 38          self.dsp_width = dsp_width
 39  
 40          self.i1 = Signal(unsigned(n))
 41          self.i2 = Signal(unsigned(n))
 42          self.r = Signal(unsigned(n * 2))
 43  
 44      def elaborate(self, platform):
 45          if self.n <= self.dsp_width:
 46              m = TModule()
 47              m.submodules.dsp = dsp = DSPMulUnit(self.dsp_width)
 48              with Transaction().body(m):
 49                  # The bit width of the `i1` and `i2` parameters of `dsp` is different than of `self.i1`
 50                  # and `self.i2`, which triggers an error. Using `| 0` silences it.
 51                  res = dsp.compute(m, i1=self.i1 | 0, i2=self.i2 | 0)
 52                  m.d.comb += self.r.eq(res)
 53  
 54              return m
 55          else:
 56              return self.recursive_module()
 57  
 58      def recursive_module(self) -> TModule:
 59          # Fast Recursive Multiplying Algorithm
 60          #
 61          # bit: N       N/2      0
 62          #      +--------+-------+
 63          # i1 : | high_1 | low_1 |
 64          #      +--------+-------+
 65          # i2 : | high_2 | low_2 |
 66          #      +--------+-------+
 67          #
 68          #  result_low   = low_1 * low_2
 69          #  result_upper = high_1 * high_2
 70          #  result_mid   = (low_1 + high_1) * (low_2 + high_2) =
 71          #               = low_1 * low_2 + high_1 * high_2 + low_1 * high_2 + low_2 * high_1
 72          #               = result_low + result_upper + low_1 * high_2 + low_2 * high_1
 73          #
 74          #  i1 * i2 = (high_1 << N/2 + low_1) * (high_2 << N/2 + low_2) =
 75          #          = (high_1 * high_2) << N + (high_1 * low_2 + high_2 * low_1) << N/2 + low_1 * low_2 =
 76          #          = result_upper << N + (result_mid - result_low - result_upper) << N/2 + result_low
 77  
 78          m = TModule()
 79  
 80          upper = self.n // 2
 81          lower = (self.n + 1) // 2
 82          m.submodules.low_mul = low_mul = FastRecursiveMul(lower, self.dsp_width)
 83          m.submodules.mid_mul = mid_mul = FastRecursiveMul(lower + 1, self.dsp_width)
 84          m.submodules.upper_mul = upper_mul = FastRecursiveMul(upper, self.dsp_width)
 85  
 86          result_low = Signal(unsigned(2 * lower))
 87          result_mid = Signal(unsigned(2 * lower + 2))
 88          result_upper = Signal(unsigned(2 * upper))
 89  
 90          m.d.comb += low_mul.i1.eq(self.i1[:lower])
 91          m.d.comb += low_mul.i2.eq(self.i2[:lower])
 92          m.d.comb += result_low.eq(low_mul.r)
 93  
 94          m.d.comb += mid_mul.i1.eq(self.i1[:lower] + self.i1[lower:])
 95          m.d.comb += mid_mul.i2.eq(self.i2[:lower] + self.i2[lower:])
 96          m.d.comb += result_mid.eq(mid_mul.r)
 97  
 98          m.d.comb += upper_mul.i1.eq(self.i1[lower:])
 99          m.d.comb += upper_mul.i2.eq(self.i2[lower:])
100          m.d.comb += result_upper.eq(upper_mul.r)
101  
102          m.d.comb += self.r.eq(
103              result_low + ((result_mid - result_low - result_upper) << lower) + (result_upper << 2 * lower)
104          )
105  
106          return m
107  
108  
class RecursiveUnsignedMul(MulBaseUnsigned):
    """
    Unsigned multiplier exposing the @see{MulBaseUnsigned} interface.

    The product is computed by a single combinational `FastRecursiveMul` tree within 1 clock cycle and
    stored in a FIFO (constructed with depth argument 2) until it is taken out through `accept`.
    """

    def __init__(self, gen_params: GenParams, dsp_width: int = 8):
        """
        Parameters
        ----------
        gen_params: GenParams
            Core generation parameters; `gen_params.isa.xlen` determines the operand width.
        dsp_width: int
            Bit width of the operands multiplied by a single DSP unit.
        """
        super().__init__(gen_params)
        self.dsp_width = dsp_width

    def elaborate(self, platform):
        xlen = self.gen_params.isa.xlen
        m = TModule()

        # Holds finished products between `issue` and `accept`.
        m.submodules.fifo = result_fifo = FIFO([("o", 2 * xlen)], 2)
        m.submodules.mul = multiplier = FastRecursiveMul(xlen, self.dsp_width)

        @def_method(m, self.issue)
        def _(arg):
            m.d.comb += [
                multiplier.i1.eq(arg.i1),
                multiplier.i2.eq(arg.i2),
            ]
            result_fifo.write(m, o=multiplier.r)

        @def_method(m, self.accept)
        def _(arg):
            return result_fifo.read(m)

        return m
134          return m