sequence.py
from amaranth import *

from coreblocks.func_blocks.fu.unsigned_multiplication.common import DSPMulUnit, MulBaseUnsigned
from coreblocks.params import GenParams
from transactron import *
from transactron.core import def_method

__all__ = ["SequentialUnsignedMul"]


class RecursiveWithSingleDSPMul(Elaboratable):
    """
    Module with combinatorial connection for sequential multiplication using single DSP unit.
    It uses classic recursive multiplication algorithm.

    Operands wider than the DSP unit are split into a low part of ``n // 2`` bits and a
    high part of ``n - n // 2`` bits, so odd bit widths are multiplied correctly.

    Attributes
    ----------
    i1: Signal(unsigned(n)), in
        First factor.
    i2: Signal(unsigned(n)), in
        Second factor.
    result: Signal(unsigned(n * 2)), out
        Product of inputted factors.
    confirm: Signal(1), out
        Signal providing information if computation is finished.
    reset: Signal(1), in
        Signal erasing previous result, and starting new computation of provided inputs.
    """

    def __init__(self, dsp: DSPMulUnit, n: int):
        """
        Parameters
        ----------
        dsp: DSPMulUnit
            Dsp unit performing multiplications in single clock cycle.
        n: int
            Bit width of multiplied numbers.
        """
        self.n = n
        self.dsp = dsp

        self.i1 = Signal(unsigned(n))
        self.i2 = Signal(unsigned(n))
        self.result = Signal(unsigned(n * 2))
        self.confirm = Signal()
        self.reset = Signal()

    def elaborate(self, platform) -> TModule:
        # Operands wider than the DSP are split and multiplied recursively.
        if self.n > self.dsp.n:
            return self.recursive_module()

        # Base case: the operands fit the DSP, so one DSP call produces the product.
        m = TModule()

        with m.If(self.reset):
            # Invalidate the previous result; the new computation starts once reset drops.
            m.d.sync += self.confirm.eq(0)

        with m.If(~self.confirm & ~self.reset):
            # Request the shared DSP once; after it fires, `confirm` blocks further requests.
            with Transaction().body(m):
                res = self.dsp.compute(m, i1=self.i1, i2=self.i2)
                m.d.sync += self.result.eq(res)
                m.d.sync += self.confirm.eq(1)

        return m

    def recursive_module(self) -> TModule:
        # Classic multiplication algorithm.
        #
        # Let L = n // 2 be the width of the low parts and H = n - L the width of the high
        # parts (H == L for even n, H == L + 1 for odd n).
        #
        #  bit:   n        L       0
        #         +--------+-------+
        #  i1 :   | high_1 | low_1 |
        #         +--------+-------+
        #  i2 :   | high_2 | low_2 |
        #         +--------+-------+
        #
        #  result_ll = low_1  * low_2
        #  result_ul = high_1 * low_2
        #  result_lu = low_1  * high_2
        #  result_uu = high_1 * high_2
        #
        #  i1 * i2 = ((high_1 << L) + low_1) * ((high_2 << L) + low_2)
        #          = (result_uu << 2L) + ((result_ul + result_lu) << L) + result_ll
        #
        # The high parts must get their own width H: sizing every sub-multiplier with
        # n // 2 would truncate the top bit of each operand whenever n is odd.

        m = TModule()

        low_w = self.n // 2
        high_w = self.n - low_w

        # Submodule roles: low_mul -> ll, mid_mul -> ul, upper_mul -> lu, mul4 -> uu.
        # Mixed products use width H; their narrower low operand is zero-extended on assignment.
        m.submodules.low_mul = mul_ll = RecursiveWithSingleDSPMul(self.dsp, low_w)
        m.submodules.mid_mul = mul_ul = RecursiveWithSingleDSPMul(self.dsp, high_w)
        m.submodules.upper_mul = mul_lu = RecursiveWithSingleDSPMul(self.dsp, high_w)
        m.submodules.mul4 = mul_uu = RecursiveWithSingleDSPMul(self.dsp, high_w)

        for sub_mul in (mul_ll, mul_ul, mul_lu, mul_uu):
            m.d.comb += sub_mul.reset.eq(self.reset)

        # The product is ready only when every partial product is ready.
        m.d.comb += self.confirm.eq(mul_ll.confirm & mul_ul.confirm & mul_lu.confirm & mul_uu.confirm)

        low_1, high_1 = self.i1[:low_w], self.i1[low_w:]
        low_2, high_2 = self.i2[:low_w], self.i2[low_w:]

        # Named partial products, each sized to its sub-multiplier's full result width.
        result_ll = Signal(unsigned(2 * low_w))
        result_ul = Signal(unsigned(2 * high_w))
        result_lu = Signal(unsigned(2 * high_w))
        result_uu = Signal(unsigned(2 * high_w))

        m.d.comb += [
            mul_ll.i1.eq(low_1),
            mul_ll.i2.eq(low_2),
            result_ll.eq(mul_ll.result),
            mul_ul.i1.eq(high_1),
            mul_ul.i2.eq(low_2),
            result_ul.eq(mul_ul.result),
            mul_lu.i1.eq(low_1),
            mul_lu.i2.eq(high_2),
            result_lu.eq(mul_lu.result),
            mul_uu.i1.eq(high_1),
            mul_uu.i2.eq(high_2),
            result_uu.eq(mul_uu.result),
        ]

        # Intermediate sums widen automatically; the exact product fits in 2n bits.
        m.d.comb += self.result.eq(result_ll + ((result_ul + result_lu) << low_w) + (result_uu << (2 * low_w)))

        return m


class SequentialUnsignedMul(MulBaseUnsigned):
    """
    Module with @see{MulBaseUnsigned} interface performing sequential multiplication using single DSP unit.
    It uses classic recursive multiplication algorithm.
    """

    def __init__(self, gen_params: GenParams, dsp_width: int = 8):
        """
        Parameters
        ----------
        gen_params: GenParams
            Core generation parameters; `gen_params.isa.xlen` sets the operand width.
        dsp_width: int
            Bit width of operands accepted by the single DSP unit.
        """
        super().__init__(gen_params)
        self.dsp_width = dsp_width

    def elaborate(self, platform):
        m = TModule()
        m.submodules.dsp = dsp = DSPMulUnit(self.dsp_width)
        m.submodules.multiplier = multiplier = RecursiveWithSingleDSPMul(dsp, self.gen_params.isa.xlen)

        # High while no computation is pending, i.e. a new `issue` may be taken.
        accepted = Signal(1, init=1)
        # `reset` is a one-cycle pulse: it defaults to 0 and `issue` overrides it with 1.
        m.d.sync += multiplier.reset.eq(0)

        @def_method(m, self.issue, ready=accepted)
        def _(arg):
            m.d.sync += multiplier.i1.eq(arg.i1)
            m.d.sync += multiplier.i2.eq(arg.i2)

            m.d.sync += multiplier.reset.eq(1)
            m.d.sync += accepted.eq(0)

        # `~multiplier.reset` masks the stale `confirm` still visible during the reset cycle.
        @def_method(m, self.accept, ready=(~accepted) & multiplier.confirm & ~multiplier.reset)
        def _(arg):
            m.d.sync += accepted.eq(1)
            return {"o": multiplier.result}

        return m