sequence.py
  1  from amaranth import *
  2  
  3  from coreblocks.func_blocks.fu.unsigned_multiplication.common import DSPMulUnit, MulBaseUnsigned
  4  from coreblocks.params import GenParams
  5  from transactron import *
  6  from transactron.core import def_method
  7  
# Only the MulBaseUnsigned-compatible multiplier is exported; RecursiveWithSingleDSPMul
# is a helper used by it and is not part of this module's public API.
__all__ = ["SequentialUnsignedMul"]
  9  
 10  
 11  class RecursiveWithSingleDSPMul(Elaboratable):
 12      """
 13      Module with combinatorial connection for sequential multiplication using single DSP unit.
 14      It uses classic recursive multiplication algorithm.
 15  
 16      Attributes
 17      ----------
 18      i1: Signal(unsigned(n)), in
 19          First factor.
 20      i2: Signal(unsigned(n)), in
 21          Second factor.
 22      result: Signal(unsigned(n * 2)), out
 23          Product of inputted factors.
 24      confirm: Signal(1), out
 25          Signal providing information if computation is finished.
 26      reset: Signal(1), in
 27          Signal erasing previous result, and starting new computation of provided inputs.
 28      """
 29  
 30      def __init__(self, dsp: DSPMulUnit, n: int):
 31          """
 32          Parameters
 33          ----------
 34          dsp: DSPMulUnit
 35              Dsp unit performing multiplications in single clock cycle.
 36          n: int
 37              Bit width of multiplied numbers.
 38          """
 39          self.n = n
 40          self.dsp = dsp
 41  
 42          self.i1 = Signal(unsigned(n))
 43          self.i2 = Signal(unsigned(n))
 44          self.result = Signal(unsigned(n * 2))
 45          self.confirm = Signal()
 46          self.reset = Signal()
 47  
 48      def elaborate(self, platform) -> TModule:
 49          if self.n <= self.dsp.n:
 50              m = TModule()
 51              with m.If(self.reset):
 52                  m.d.sync += self.confirm.eq(0)
 53  
 54              with m.If(~self.confirm & ~self.reset):
 55                  with Transaction().body(m):
 56                      res = self.dsp.compute(m, i1=self.i1, i2=self.i2)
 57                      m.d.sync += self.result.eq(res)
 58                      m.d.sync += self.confirm.eq(1)
 59  
 60              return m
 61          else:
 62              return self.recursive_module()
 63  
 64      def recursive_module(self) -> TModule:
 65          # Classic Multiplying Algorithm
 66          #
 67          # bit: N       N/2      0
 68          #      +--------+-------+
 69          # i1 : | high_1 | low_1 |
 70          #      +--------+-------+
 71          # i2 : | high_2 | low_2 |
 72          #      +--------+-------+
 73          #
 74          #  result_ll = low_1 * low_2
 75          #  result_uu = high_1 * high_2
 76          #  result_lu = low_1 * high_2 +
 77          #  result_ul = high_1 * low_2
 78          #
 79          #  i1 * i2 = (high_1 << N/2 + low_1) * (high_2 << N/2 + low_2) =
 80          #          = (high_1 * high_2) << N + (high_1 * low_2 + high_2 * low_1) << N/2 + low_1 * low_2
 81          #          = result_uu << N + (result_lu + result_ul) << N/2 + result_ll
 82  
 83          m = TModule()
 84  
 85          m.submodules.low_mul = mul1 = RecursiveWithSingleDSPMul(self.dsp, self.n // 2)
 86          m.submodules.mid_mul = mul2 = RecursiveWithSingleDSPMul(self.dsp, self.n // 2)
 87          m.submodules.upper_mul = mul3 = RecursiveWithSingleDSPMul(self.dsp, self.n // 2)
 88          m.submodules.mul4 = mul4 = RecursiveWithSingleDSPMul(self.dsp, self.n // 2)
 89  
 90          m.d.comb += mul1.reset.eq(self.reset)
 91          m.d.comb += mul2.reset.eq(self.reset)
 92          m.d.comb += mul3.reset.eq(self.reset)
 93          m.d.comb += mul4.reset.eq(self.reset)
 94  
 95          m.d.comb += self.confirm.eq(mul1.confirm & mul2.confirm & mul3.confirm & mul4.confirm)
 96  
 97          result_ll = Signal(unsigned(self.n))
 98          result_ul = Signal(unsigned(self.n))
 99          result_lu = Signal(unsigned(self.n))
100          result_uu = Signal(unsigned(self.n))
101  
102          m.d.comb += mul1.i1.eq(self.i1[: self.n // 2])
103          m.d.comb += mul1.i2.eq(self.i2[: self.n // 2])
104          m.d.comb += result_ll.eq(mul1.result)
105  
106          m.d.comb += mul2.i1.eq(self.i1[self.n // 2 :])
107          m.d.comb += mul2.i2.eq(self.i2[: self.n // 2])
108          m.d.comb += result_ul.eq(mul2.result)
109  
110          m.d.comb += mul3.i1.eq(self.i1[: self.n // 2])
111          m.d.comb += mul3.i2.eq(self.i2[self.n // 2 :])
112          m.d.comb += result_lu.eq(mul3.result)
113  
114          m.d.comb += mul4.i1.eq(self.i1[self.n // 2 :])
115          m.d.comb += mul4.i2.eq(self.i2[self.n // 2 :])
116          m.d.comb += result_uu.eq(mul4.result)
117  
118          m.d.comb += self.result.eq(result_ll + ((result_ul + result_lu) << self.n // 2) + (result_uu << self.n))
119  
120          return m
121  
122  
class SequentialUnsignedMul(MulBaseUnsigned):
    """
    Sequential unsigned multiplier exposing the @see{MulBaseUnsigned} interface.

    Partial products are computed one after another on a single shared DSP
    unit, following the classic recursive multiplication algorithm implemented
    by `RecursiveWithSingleDSPMul`. Only one multiplication is in flight at a
    time: `issue` is ready only after the previous result has been taken by
    `accept`.
    """

    def __init__(self, gen_params: GenParams, dsp_width: int = 8):
        """
        Parameters
        ----------
        gen_params: GenParams
            Core generation parameters; `gen_params.isa.xlen` sets the operand width.
        dsp_width: int
            Width passed to the `DSPMulUnit` that performs the partial products.
        """
        super().__init__(gen_params)
        self.dsp_width = dsp_width

    def elaborate(self, platform):
        m = TModule()

        operand_width = self.gen_params.isa.xlen
        m.submodules.dsp = shared_dsp = DSPMulUnit(self.dsp_width)
        m.submodules.multiplier = core = RecursiveWithSingleDSPMul(shared_dsp, operand_width)

        # High while no multiplication is in flight, i.e. while `issue` may fire.
        idle = Signal(1, init=1)

        # `core.reset` is a one-cycle pulse: low by default, and raised only in
        # the cycle after `issue`. This default must be added before the
        # override inside `issue`, because the later assignment wins.
        m.d.sync += core.reset.eq(0)

        @def_method(m, self.issue, ready=idle)
        def _(arg):
            m.d.sync += core.i1.eq(arg.i1)
            m.d.sync += core.i2.eq(arg.i2)
            m.d.sync += core.reset.eq(1)
            m.d.sync += idle.eq(0)

        # While the reset pulse is active, `core.confirm` may still reflect the
        # previous computation, so it is masked out here.
        result_valid = (~idle) & core.confirm & ~core.reset

        @def_method(m, self.accept, ready=result_valid)
        def _(arg):
            m.d.sync += idle.eq(1)
            return {"o": core.result}

        return m
154          return m