pipelined.py
  1  from amaranth import *
  2  import math
  3  
  4  from coreblocks.func_blocks.fu.unsigned_multiplication.common import MulBaseUnsigned, DSPMulUnit
  5  from coreblocks.params import GenParams
  6  from transactron import *
  7  from transactron.core import def_method
  8  
# Names exported by a star-import of this module: only the transactional
# wrapper is listed; PipelinedMul is not part of that exported set.
__all__ = ["PipelinedUnsignedMul"]
 10  
 11  
 12  class PipelinedMul(Elaboratable):
 13      """
 14      PipelinedMul uses pipelining mechanism in order to reduce the time needed to multiply multiple pairs of numbers.
 15      This class breaks down the multiplication of large bit-width operands into smaller chunks that
 16      fit into the DSP units,
 17      processes these chunks iteratively,
 18      and then combines the results using a multi-level pipeline to form the final product.
 19  
 20      Attributes
 21      ----------
 22          dsp_width (int): Width of each DSP unit.
 23          dsp_number (int): Number of DSP units available.
 24          n (int): Bit width of the numbers to be multiplied.
 25  
 26      Signals:
 27      --------
 28          ready (Signal): Indicates if the multiplier is ready to accept new inputs.
 29          issue (Signal): Signals the start of a new multiplication.
 30          i1 (Signal): First multiplicand, extended to align with the DSP width.
 31          i2 (Signal): Second multiplicand, extended to align with the DSP width.
 32          result (Signal): Final result of the multiplication.
 33          valid (Signal): Indicates if the result is valid.
 34          getting_result (Signal): Signals the process of retrieving the result.
 35      """
 36  
 37      def __init__(self, dsp_width: int, dsp_number: int, n: int):
 38          self.n = n
 39          self.n_padding = dsp_width * 2 ** (math.ceil(math.log2(n / dsp_width)))
 40          self.dsp_width = dsp_width
 41          self.dsp_number = dsp_number
 42          self.number_of_chunks = self.n_padding // self.dsp_width
 43          self.number_of_chunks_multiplication = math.ceil(n / dsp_width)
 44          self.number_of_multiplications = self.number_of_chunks_multiplication**2
 45          self.number_of_steps = math.ceil(self.number_of_multiplications / self.dsp_number)
 46          self.result_lvl = math.ceil(math.log2(self.number_of_chunks))
 47  
 48          self.ready = Signal()
 49          self.issue = Signal()
 50          self.step = Signal(range(self.number_of_steps + 1), init=self.number_of_steps)
 51          self.i1 = Signal(self.n_padding)
 52          self.i2 = Signal(self.n_padding)
 53          self.result = Signal(2 * n)
 54          self.valid = Signal()
 55          self.getting_result = Signal()
 56  
 57          self.values_array = [
 58              [
 59                  [Signal(self.dsp_width * (1 << i) * 2) for _ in range(self.number_of_chunks >> i)]
 60                  for _ in range(self.number_of_chunks >> i)
 61              ]
 62              for i in range(self.result_lvl + 1)
 63          ]
 64          self.valid_array = [Signal() for _ in range(self.result_lvl + 1)]
 65  
 66      def elaborate(self, platform=None):
 67          m = TModule()
 68  
 69          self.dsp_units = []
 70          for i in range(self.dsp_number):
 71              unit = DSPMulUnit(self.dsp_width)
 72              self.dsp_units.append(unit)
 73              setattr(m.submodules, f"dsp_unit_{i}", unit)
 74  
 75          with m.If((self.step < (self.number_of_steps - 1)) | (self.valid & ~self.getting_result)):
 76              m.d.comb += self.ready.eq(0)
 77          with m.Else():
 78              m.d.comb += self.ready.eq(1)
 79  
 80          m.d.comb += self.result.eq(self.values_array[self.result_lvl][0][0])
 81          m.d.comb += self.valid.eq(self.valid_array[self.result_lvl])
 82  
 83          with m.If(~self.valid | self.getting_result):
 84              for i in range(1, self.result_lvl + 1):
 85                  m.d.sync += self.valid_array[i].eq(self.valid_array[i - 1])
 86  
 87              with m.If(self.step == self.number_of_steps - 1):
 88                  m.d.sync += self.valid_array[0].eq(1)
 89              with m.Else():
 90                  m.d.sync += self.valid_array[0].eq(0)
 91  
 92              with m.If(self.step < self.number_of_steps):
 93                  m.d.sync += self.step.eq(self.step + 1)
 94  
 95              with m.If(self.issue):
 96                  m.d.sync += self.step.eq(0)
 97  
 98              for i in range(self.number_of_multiplications):
 99                  a = i // self.number_of_chunks_multiplication
100                  b = i % self.number_of_chunks_multiplication
101                  chunk_i1 = self.i1[a * self.dsp_width : (a + 1) * self.dsp_width]
102                  chunk_i2 = self.i2[b * self.dsp_width : (b + 1) * self.dsp_width]
103                  run_on_step = i // self.dsp_number
104                  dsp_idx = i % self.dsp_number
105                  with Transaction().body(m, ready=(self.step == run_on_step)):
106                      res = self.dsp_units[dsp_idx].compute(m, i1=chunk_i1, i2=chunk_i2)
107                      m.d.sync += self.values_array[0][a][b].eq(res)
108  
109              shift_size = self.dsp_width
110              for i in range(1, self.result_lvl + 1):
111                  shift_size = shift_size << 1
112                  for j in range(self.number_of_chunks >> i):
113                      for k in range(self.number_of_chunks >> i):
114                          ll = self.values_array[i - 1][2 * j][2 * k]
115                          lu = self.values_array[i - 1][2 * j][2 * k + 1]
116                          ul = self.values_array[i - 1][2 * j + 1][2 * k]
117                          uu = self.values_array[i - 1][2 * j + 1][2 * k + 1]
118                          m.d.sync += self.values_array[i][j][k].eq(
119                              ll + ((ul + lu) << (shift_size >> 1)) + (uu << shift_size)
120                          )
121          return m
122  
123  
class PipelinedUnsignedMul(MulBaseUnsigned):
    """
    Transactional front-end for ``PipelinedMul``.

    Connects the ``issue`` and ``accept`` methods of this unit (presumably
    declared by ``MulBaseUnsigned`` - confirm there) to the handshake signals
    of a ``PipelinedMul`` instance whose operand width is ``isa.xlen``.

    Attributes
    ----------
    dsp_width : int
        Width of each DSP unit, forwarded to ``PipelinedMul``.
    dsp_number : int
        Number of DSP units, forwarded to ``PipelinedMul``.
    """

    def __init__(self, gen_params: GenParams, dsp_width: int = 18, dsp_number: int = 7):
        """
        Parameters
        ----------
        gen_params : GenParams
            Core generation parameters, passed on to the base class.
        dsp_width : int
            Width of each DSP unit.
        dsp_number : int
            Number of DSP units to instantiate.
        """
        super().__init__(gen_params)
        self.dsp_number = dsp_number
        self.dsp_width = dsp_width

    def elaborate(self, platform):
        m = TModule()

        operand_width = self.gen_params.isa.xlen
        core = PipelinedMul(self.dsp_width, self.dsp_number, operand_width)
        m.submodules.mul = core

        @def_method(m, self.issue, ready=core.ready)
        def _(arg):
            # Announce the new multiplication in this cycle; the operands are
            # registered, so the multiplier sees them from the next cycle on.
            m.d.comb += core.issue.eq(1)
            m.d.sync += core.i2.eq(arg.i2)
            m.d.sync += core.i1.eq(arg.i1)

        @def_method(m, self.accept, ready=core.valid)
        def _(arg):
            # Tell the multiplier its output is being consumed in this cycle.
            m.d.comb += core.getting_result.eq(1)
            return core.result

        return m
145          return m