pipelined.py
from amaranth import *
import math

from coreblocks.func_blocks.fu.unsigned_multiplication.common import MulBaseUnsigned, DSPMulUnit
from coreblocks.params import GenParams
from transactron import *
from transactron.core import def_method

__all__ = ["PipelinedUnsignedMul"]


class PipelinedMul(Elaboratable):
    """
    Pipelined multiplier built from a fixed number of narrow DSP multipliers.

    Both operands are cut into ``dsp_width``-bit chunks. Every pair of chunks
    is multiplied on one of the ``dsp_number`` DSP units, ``dsp_number`` pairs
    per step, and the partial products are then merged by a registered
    reduction tree that combines a 2x2 group of partial products per level.

    Attributes
    ----------
    n : int
        Bit width of the numbers to be multiplied.
    dsp_width : int
        Width of each DSP unit.
    dsp_number : int
        Number of DSP units available.
    number_of_chunks_multiplication : int
        Number of ``dsp_width``-bit chunks needed to cover ``n`` bits.
    number_of_chunks : int
        ``number_of_chunks_multiplication`` rounded up to a power of two, so
        that the reduction tree can halve the grid on every level.
    n_padding : int
        Operand width after padding, ``dsp_width * number_of_chunks``.
    number_of_multiplications : int
        Number of chunk-by-chunk multiplications per operand pair.
    number_of_steps : int
        Number of steps needed to run all chunk multiplications.
    result_lvl : int
        Number of reduction levels (log2 of ``number_of_chunks``).

    Signals
    -------
    ready : Signal
        Indicates if the multiplier is ready to accept new inputs.
    issue : Signal
        Signals the start of a new multiplication.
    step : Signal
        Step counter; equals ``number_of_steps`` when no multiplication
        step is pending.
    i1 : Signal
        First multiplicand, ``n_padding`` bits wide.
    i2 : Signal
        Second multiplicand, ``n_padding`` bits wide.
    result : Signal
        Final result of the multiplication, ``2 * n`` bits wide.
    valid : Signal
        Indicates if the result is valid.
    getting_result : Signal
        Signals that the result is being retrieved.
    """

    def __init__(self, dsp_width: int, dsp_number: int, n: int):
        """
        Parameters
        ----------
        dsp_width : int
            Width of each DSP unit; must be positive.
        dsp_number : int
            Number of DSP units available; must be positive.
        n : int
            Bit width of the numbers to be multiplied; must be positive.

        Raises
        ------
        ValueError
            If any of the parameters is not positive.
        """
        if dsp_width <= 0 or dsp_number <= 0 or n <= 0:
            raise ValueError("dsp_width, dsp_number and n must be positive")

        self.n = n
        self.dsp_width = dsp_width
        self.dsp_number = dsp_number

        # Integer ceil-division: chunks that actually carry operand bits.
        self.number_of_chunks_multiplication = -(-n // dsp_width)
        # Round the chunk count up to a power of two. Pure integer arithmetic
        # keeps every derived size an int and also covers n < dsp_width (one
        # chunk), where a float log2 of n / dsp_width would go negative.
        self.number_of_chunks = 1 << (self.number_of_chunks_multiplication - 1).bit_length()
        self.n_padding = dsp_width * self.number_of_chunks
        self.number_of_multiplications = self.number_of_chunks_multiplication**2
        self.number_of_steps = -(-self.number_of_multiplications // self.dsp_number)
        # number_of_chunks is a power of two, so this is its exact log2.
        self.result_lvl = self.number_of_chunks.bit_length() - 1

        self.ready = Signal()
        self.issue = Signal()
        # The counter starts at number_of_steps, i.e. with no step pending.
        self.step = Signal(range(self.number_of_steps + 1), init=self.number_of_steps)
        self.i1 = Signal(self.n_padding)
        self.i2 = Signal(self.n_padding)
        self.result = Signal(2 * n)
        self.valid = Signal()
        self.getting_result = Signal()

        # values_array[i] is a square grid of partial results for level i.
        # Each level halves the grid side and doubles the element width.
        self.values_array = [
            [
                [Signal(self.dsp_width * (1 << i) * 2) for _ in range(self.number_of_chunks >> i)]
                for _ in range(self.number_of_chunks >> i)
            ]
            for i in range(self.result_lvl + 1)
        ]
        # One valid flag per level of values_array.
        self.valid_array = [Signal() for _ in range(self.result_lvl + 1)]

    def elaborate(self, platform=None):
        m = TModule()

        self.dsp_units = []
        for i in range(self.dsp_number):
            unit = DSPMulUnit(self.dsp_width)
            self.dsp_units.append(unit)
            setattr(m.submodules, f"dsp_unit_{i}", unit)

        # Not ready while earlier steps are still pending, or while a finished
        # result is waiting to be collected.
        with m.If((self.step < (self.number_of_steps - 1)) | (self.valid & ~self.getting_result)):
            m.d.comb += self.ready.eq(0)
        with m.Else():
            m.d.comb += self.ready.eq(1)

        # The last level holds a single element: the full product.
        m.d.comb += self.result.eq(self.values_array[self.result_lvl][0][0])
        m.d.comb += self.valid.eq(self.valid_array[self.result_lvl])

        # Valid flags advance one level per cycle unless a valid result is
        # waiting and nobody is collecting it.
        with m.If(~self.valid | self.getting_result):
            for i in range(1, self.result_lvl + 1):
                m.d.sync += self.valid_array[i].eq(self.valid_array[i - 1])

            # Level 0 becomes valid on the cycle after the last step ran.
            with m.If(self.step == self.number_of_steps - 1):
                m.d.sync += self.valid_array[0].eq(1)
            with m.Else():
                m.d.sync += self.valid_array[0].eq(0)

        # The counter counts up and saturates at number_of_steps.
        with m.If(self.step < self.number_of_steps):
            m.d.sync += self.step.eq(self.step + 1)

        # Issuing restarts the counter; being the later assignment, it takes
        # priority over the increment above.
        with m.If(self.issue):
            m.d.sync += self.step.eq(0)

        # Multiplication i pairs chunk a of i1 with chunk b of i2. It runs on
        # step i // dsp_number using DSP unit i % dsp_number.
        for i in range(self.number_of_multiplications):
            a = i // self.number_of_chunks_multiplication
            b = i % self.number_of_chunks_multiplication
            chunk_i1 = self.i1[a * self.dsp_width : (a + 1) * self.dsp_width]
            chunk_i2 = self.i2[b * self.dsp_width : (b + 1) * self.dsp_width]
            run_on_step = i // self.dsp_number
            dsp_idx = i % self.dsp_number
            with Transaction().body(m, ready=(self.step == run_on_step)):
                res = self.dsp_units[dsp_idx].compute(m, i1=chunk_i1, i2=chunk_i2)
                m.d.sync += self.values_array[0][a][b].eq(res)

        # Reduction tree. With w = shift_size / 2 being the chunk width of the
        # previous level:
        #   (a0 + (a1 << w)) * (b0 + (b1 << w))
        #     = a0*b0 + ((a1*b0 + a0*b1) << w) + ((a1*b1) << 2w)
        # These data registers are written every cycle; only the valid flags
        # above are gated.
        shift_size = self.dsp_width
        for i in range(1, self.result_lvl + 1):
            shift_size = shift_size << 1
            for j in range(self.number_of_chunks >> i):
                for k in range(self.number_of_chunks >> i):
                    ll = self.values_array[i - 1][2 * j][2 * k]
                    lu = self.values_array[i - 1][2 * j][2 * k + 1]
                    ul = self.values_array[i - 1][2 * j + 1][2 * k]
                    uu = self.values_array[i - 1][2 * j + 1][2 * k + 1]
                    m.d.sync += self.values_array[i][j][k].eq(
                        ll + ((ul + lu) << (shift_size >> 1)) + (uu << shift_size)
                    )
        return m


class PipelinedUnsignedMul(MulBaseUnsigned):
    """
    Unsigned multiplication unit backed by `PipelinedMul`.

    Connects the `issue` and `accept` methods of this unit to a `PipelinedMul`
    instance sized for ``gen_params.isa.xlen``-bit operands.

    Parameters
    ----------
    gen_params : GenParams
        Core generation parameters; ``isa.xlen`` sets the operand width.
    dsp_width : int
        Width of each DSP unit passed to `PipelinedMul`.
    dsp_number : int
        Number of DSP units passed to `PipelinedMul`.
    """

    def __init__(self, gen_params: GenParams, dsp_width: int = 18, dsp_number: int = 7):
        super().__init__(gen_params)
        self.dsp_width = dsp_width
        self.dsp_number = dsp_number

    def elaborate(self, platform):
        m = TModule()
        m.submodules.mul = mul = PipelinedMul(self.dsp_width, self.dsp_number, self.gen_params.isa.xlen)

        # Latch the operands and start a new multiplication.
        @def_method(m, self.issue, ready=mul.ready)
        def _(arg):
            m.d.sync += mul.i1.eq(arg.i1)
            m.d.sync += mul.i2.eq(arg.i2)
            m.d.comb += mul.issue.eq(1)

        # Hand out the finished product and mark it as collected.
        @def_method(m, self.accept, ready=mul.valid)
        def _(arg):
            m.d.comb += mul.getting_result.eq(1)
            return mul.result

        return m