# fast_recursive.py
from amaranth import *

from coreblocks.func_blocks.fu.unsigned_multiplication.common import MulBaseUnsigned, DSPMulUnit
from coreblocks.params import GenParams
from transactron import *
from transactron.core import def_method
from transactron.lib import FIFO

__all__ = ["RecursiveUnsignedMul"]


class FastRecursiveMul(Elaboratable):
    """
    Module with combinatorial connection for fast recursive multiplication using as many DSPMulUnit as required for
    one clock multiplication.

    Factors wider than `dsp_width` are split into halves and multiplied by three smaller
    `FastRecursiveMul` instances; factors that fit are handed to a single `DSPMulUnit`.

    Attributes
    ----------
    i1: Signal(unsigned(n)), in
        First factor.
    i2: Signal(unsigned(n)), in
        Second factor.
    r: Signal(unsigned(n * 2)), out
        Product of inputted factors.
    """

    def __init__(self, n: int, dsp_width: int):
        """
        Parameters
        ----------
        n: int
            Bit width of multiplied numbers.
        dsp_width: int
            Bit width of number multiplied by dsp unit.

        Raises
        ------
        ValueError
            If `n` exceeds `dsp_width` and `dsp_width` is smaller than 3. The middle
            sub-multiplier is `ceil(n / 2) + 1` bits wide, which is not smaller than `n`
            for `n` of 2 or 3, so the recursion would never reach the DSP base case.
        """
        if n > dsp_width and dsp_width < 3:
            raise ValueError(
                f"dsp_width must be at least 3 to split a {n}-bit multiplication, got {dsp_width}"
            )

        self.n = n
        self.dsp_width = dsp_width

        self.i1 = Signal(unsigned(n))
        self.i2 = Signal(unsigned(n))
        self.r = Signal(unsigned(n * 2))

    def elaborate(self, platform):
        # Wide factors are decomposed recursively; everything else is the base case.
        if self.n > self.dsp_width:
            return self.recursive_module()

        m = TModule()
        m.submodules.dsp = dsp = DSPMulUnit(self.dsp_width)
        with Transaction().body(m):
            # The bit width of the `i1` and `i2` parameters of `dsp` is different than of `self.i1`
            # and `self.i2`, which triggers an error. Using `| 0` silences it.
            res = dsp.compute(m, i1=self.i1 | 0, i2=self.i2 | 0)
            m.d.comb += self.r.eq(res)

        return m

    def recursive_module(self) -> TModule:
        """
        Build the module that multiplies using three narrower multipliers.

        Returns
        -------
        TModule
            Module wiring `i1`, `i2` and `r` to the `low_mul`, `mid_mul` and `upper_mul` submodules.
        """
        # Fast Recursive Multiplying Algorithm
        #
        #      bit:  N       N/2      0
        #            +--------+-------+
        #       i1 : | high_1 | low_1 |
        #            +--------+-------+
        #       i2 : | high_2 | low_2 |
        #            +--------+-------+
        #
        # result_low = low_1 * low_2
        # result_upper = high_1 * high_2
        # result_mid = (low_1 + high_1) * (low_2 + high_2) =
        #            = low_1 * low_2 + high_1 * high_2 + low_1 * high_2 + low_2 * high_1
        #            = result_low + result_upper + low_1 * high_2 + low_2 * high_1
        #
        # i1 * i2 = ((high_1 << N/2) + low_1) * ((high_2 << N/2) + low_2) =
        #         = ((high_1 * high_2) << N) + ((high_1 * low_2 + high_2 * low_1) << N/2) + low_1 * low_2 =
        #         = (result_upper << N) + ((result_mid - result_low - result_upper) << N/2) + result_low

        m = TModule()

        # For odd `n` the low half gets the extra bit: lower = ceil(n / 2), upper = floor(n / 2).
        upper = self.n // 2
        lower = (self.n + 1) // 2
        m.submodules.low_mul = low_mul = FastRecursiveMul(lower, self.dsp_width)
        # The sum of a `lower`-bit and an `upper`-bit value needs one extra bit.
        m.submodules.mid_mul = mid_mul = FastRecursiveMul(lower + 1, self.dsp_width)
        m.submodules.upper_mul = upper_mul = FastRecursiveMul(upper, self.dsp_width)

        result_low = Signal(unsigned(2 * lower))
        result_mid = Signal(unsigned(2 * lower + 2))
        result_upper = Signal(unsigned(2 * upper))

        m.d.comb += low_mul.i1.eq(self.i1[:lower])
        m.d.comb += low_mul.i2.eq(self.i2[:lower])
        m.d.comb += result_low.eq(low_mul.r)

        m.d.comb += mid_mul.i1.eq(self.i1[:lower] + self.i1[lower:])
        m.d.comb += mid_mul.i2.eq(self.i2[:lower] + self.i2[lower:])
        m.d.comb += result_mid.eq(mid_mul.r)

        m.d.comb += upper_mul.i1.eq(self.i1[lower:])
        m.d.comb += upper_mul.i2.eq(self.i2[lower:])
        m.d.comb += result_upper.eq(upper_mul.r)

        # Recombine the partial products; the split point is `lower`, so it is the shift unit.
        m.d.comb += self.r.eq(
            result_low + ((result_mid - result_low - result_upper) << lower) + (result_upper << 2 * lower)
        )

        return m


class RecursiveUnsignedMul(MulBaseUnsigned):
    """
    Module with @see{MulBaseUnsigned} interface performing fast recursive multiplication within 1 clock cycle.
    """

    def __init__(self, gen_params: GenParams, dsp_width: int = 8):
        """
        Parameters
        ----------
        gen_params: GenParams
            Core generation parameters; `isa.xlen` is used as the width of the factors.
        dsp_width: int
            Bit width of numbers multiplied by a single dsp unit.
        """
        super().__init__(gen_params)
        self.dsp_width = dsp_width

    def elaborate(self, platform):
        m = TModule()
        # Results are buffered in a FIFO (constructed with a size argument of 2) holding
        # full-width products, so `issue` and `accept` are decoupled.
        m.submodules.fifo = fifo = FIFO([("o", 2 * self.gen_params.isa.xlen)], 2)

        m.submodules.mul = mul = FastRecursiveMul(self.gen_params.isa.xlen, self.dsp_width)

        @def_method(m, self.issue)
        def _(arg):
            # The multiplier is combinational: drive its inputs and store the product
            # in the FIFO within the same method call.
            m.d.comb += mul.i1.eq(arg.i1)
            m.d.comb += mul.i2.eq(arg.i2)
            fifo.write(m, o=mul.r)

        @def_method(m, self.accept)
        def _(arg):
            return fifo.read(m)

        return m