long_division.py
"""
Algorithm - multi-cycle array divider
Method described here: https://yuhei1-horibe.medium.com/designing-divider-213fbd32beb2
"""

from amaranth import *

from coreblocks.params import GenParams
from transactron import *
from transactron.core import def_method
from coreblocks.func_blocks.fu.division.common import DividerBase


class RecursiveDivison(Elaboratable):
    """
    Combinational module that computes `step_count` bits of the quotient
    and yields the remainder, which can be fed into the next iteration.

    Each level performs a single shift-and-compare/subtract step and
    instantiates a child `RecursiveDivison` with `step_count - 1` steps;
    the level with `step_count == 0` terminates the recursion.

    If `step_count == size` the module is effectively a one-cycle divider.

    If `size` is not a multiple of `step_count`, the last iteration of an
    iterative divider has to compute fewer bits. To reuse this module for
    that shorter calculation, `partial_remainder` exposes the remainder
    obtained after only `partial_remainder_count` steps.

    Attributes
    ----------
    size: int
        Width of the inputs and outputs.
    step_count: int
        Number of quotient bits (division steps) computed by this module.
    partial_remainder_count: int
        Number of steps after which `partial_remainder` is taken
        (meaningful for `0 <= partial_remainder_count <= step_count`).
    divisor: Signal
        Input divisor.
    dividend: Signal
        Input dividend; only bits `step_count - 1` down to `0` are consumed,
        most significant first.
    input_remainder: Signal
        Remainder carried over from the previous iteration.
    quotient: Signal
        Calculated quotient bits in positions `[0, step_count)`; upper bits are 0.
    remainder: Signal
        Remainder after all `step_count` steps.
    partial_remainder: Signal
        Remainder after `partial_remainder_count` steps.
    """

    def __init__(self, step_count: int, size: int, partial_remainder_count: int = 0):
        # Each step reads dividend[step_count - 1]; larger values would index
        # past the signal, negative ones would never reach the base case.
        if not 0 <= step_count <= size:
            raise ValueError(f"step_count must be in range [0, {size}], got {step_count}")

        self.size = size
        self.step_count = step_count
        self.partial_remainder_count = partial_remainder_count

        self.divisor = Signal(unsigned(size))
        self.dividend = Signal(unsigned(size))
        self.input_remainder = Signal(unsigned(size))

        self.quotient = Signal(unsigned(size))
        self.remainder = Signal(unsigned(size))
        self.partial_remainder = Signal(unsigned(size))

    def elaborate(self, platform) -> TModule:
        if self.step_count == 0:
            # Recursion terminator: no steps left, the remainder passes through.
            m = TModule()

            m.d.comb += self.quotient.eq(0)
            m.d.comb += self.remainder.eq(self.input_remainder)
            m.d.comb += self.partial_remainder.eq(self.input_remainder)

            return m

        return self.recursive_module()

    def recursive_module(self) -> TModule:
        """Build one division step and recurse for the remaining `step_count - 1` steps."""
        m = TModule()

        # Shift the remainder left by one and bring in the next dividend bit
        # (most significant unconsumed bit first).
        concat = Signal(self.size)
        m.d.comb += concat.eq(Cat(self.dividend[self.step_count - 1], self.input_remainder))

        # Child module handles the remaining steps.
        m.submodules.rec_div = rec_div = RecursiveDivison(
            self.step_count - 1, self.size, partial_remainder_count=self.partial_remainder_count - 1
        )

        m.d.comb += rec_div.dividend.eq(self.dividend)
        m.d.comb += rec_div.divisor.eq(self.divisor)

        # Single step as described in the article: if the shifted remainder
        # fits the divisor, emit a 1 and subtract; otherwise emit a 0 and keep it.
        with m.If(concat >= self.divisor):
            m.d.comb += self.quotient[self.step_count - 1].eq(1)
            m.d.comb += rec_div.input_remainder.eq(concat - self.divisor)
        with m.Else():
            m.d.comb += self.quotient[self.step_count - 1].eq(0)
            m.d.comb += rec_div.input_remainder.eq(concat)

        # Lower quotient bits and the final remainder come from the child.
        m.d.comb += self.quotient[: (self.step_count - 1)].eq(rec_div.quotient)
        m.d.comb += self.remainder.eq(rec_div.remainder)

        # partial_remainder_count is decremented per level; the level where it
        # hits 0 has performed exactly that many steps on its input remainder.
        if self.partial_remainder_count == 0:
            m.d.comb += self.partial_remainder.eq(self.input_remainder)
        else:
            m.d.comb += self.partial_remainder.eq(rec_div.partial_remainder)

        return m


class LongDivider(DividerBase):
    """
    Module that handles iterative calculation.

    Every cycle, `ipc` quotient bits are computed by a combinational
    `RecursiveDivison`. The dividend is kept in a shift register whose top
    `ipc` bits are fed to the divider, and new quotient bits are shifted in
    at the bottom of the quotient register. When `xlen` is not a multiple of
    `ipc`, the final iteration computes only `partial_remainder_count` bits.
    Operands and results are `unsigned(xlen)` signals.

    Attributes
    ----------
    gen_params: GenParams
        Gen Params
    ipc: int
        Number of steps (quotient bits) per cycle.
    partial_remainder_count: int
        Number of bits computed in the last iteration (`xlen % ipc`, 0 if none).
    stages: int
        Number of required iterations (`ceil(xlen / ipc)`).
    odd_iteration: bool
        Flag whether the last iteration requires a partial calculation.

    Raises
    ------
    ValueError
        If `ipc` is not in the range `[1, xlen]`.
    """

    def __init__(self, gen_params: GenParams, ipc: int = 4):
        super().__init__(gen_params)
        xlen = self.gen_params.isa.xlen

        # ipc == 0 would divide by zero below; ipc > xlen would make the
        # recursive divider index past the end of its dividend.
        if not 1 <= ipc <= xlen:
            raise ValueError(f"ipc must be in range [1, {xlen}], got {ipc}")

        self.ipc = ipc
        self.partial_remainder_count = xlen % ipc

        self.stages = xlen // ipc + (1 if self.partial_remainder_count > 0 else 0)
        self.odd_iteration = self.partial_remainder_count != 0

    def elaborate(self, platform) -> TModule:
        m = TModule()
        xlen = self.gen_params.isa.xlen
        xlen_log = self.gen_params.isa.xlen_log

        m.submodules.divider = divider = RecursiveDivison(
            self.ipc, xlen, partial_remainder_count=self.partial_remainder_count
        )

        # High while the unit is idle and can accept a new operation.
        ready = Signal(1, init=1)

        dividend = Signal(unsigned(xlen))
        divisor = Signal(unsigned(xlen))

        quotient = Signal(unsigned(xlen))
        remainder = Signal(unsigned(xlen))

        # Counts completed iterations; equals self.stages when the result is ready.
        stage = Signal(unsigned(xlen_log + 1))

        # starting calculations
        @def_method(m, self.issue, ready=ready)
        def _(arg):
            m.d.sync += dividend.eq(arg.dividend)
            m.d.sync += divisor.eq(arg.divisor)
            m.d.sync += remainder.eq(0)
            m.d.sync += quotient.eq(0)
            m.d.sync += stage.eq(0)

            m.d.sync += ready.eq(0)

        # returning results once all stages are done
        @def_method(m, self.accept, ready=(~ready & (stage == self.stages)))
        def _(arg):
            m.d.sync += ready.eq(1)
            return {"quotient": quotient, "remainder": remainder}

        # clearing the unit
        @def_method(m, self.clear)
        def _():
            m.d.sync += stage.eq(0)
            m.d.sync += ready.eq(1)

        # performing calculations
        with m.If(~ready & (stage != self.stages)):
            special_stage = (self.stages == stage + 1) & self.odd_iteration

            # assigning inputs to recursive divider
            m.d.comb += divider.divisor.eq(divisor)
            m.d.comb += divider.dividend.eq(dividend[xlen - self.ipc :])
            m.d.comb += divider.input_remainder.eq(remainder)

            # dividend is a shift register,
            # so in each iteration its upper bits are fed into the recursive divider
            m.d.sync += dividend.eq(dividend << self.ipc)

            # last stage with an uneven number of bits: take only the top
            # partial_remainder_count quotient bits and the matching remainder
            with m.If(special_stage):
                m.d.sync += remainder.eq(divider.partial_remainder)
                m.d.sync += quotient.eq(
                    Cat(divider.quotient[self.ipc - self.partial_remainder_count : self.ipc], quotient)
                )
            # normal iteration: shift all ipc new quotient bits in at the bottom
            with m.Else():
                m.d.sync += remainder.eq(divider.remainder)
                m.d.sync += quotient.eq(Cat(divider.quotient[: self.ipc], quotient))

            m.d.sync += stage.eq(stage + 1)

        return m