/ coreblocks / func_blocks / fu / division / long_division.py
long_division.py
  1  """
  2  Algorithm - multi-cycle array divider
  3  Method described here: https://yuhei1-horibe.medium.com/designing-divider-213fbd32beb2
  4  """
  5  
  6  from amaranth import *
  7  
  8  from coreblocks.params import GenParams
  9  from transactron import *
 10  from transactron.core import def_method
 11  from coreblocks.func_blocks.fu.division.common import DividerBase
 12  
 13  
 14  class RecursiveDivison(Elaboratable):
 15      """
 16      Module that calculates n bits of quotient and
 17      yields remainder that can be used in next iteration
 18  
 19      If count == xlen the module is basically a one-cycle divider
 20  
 21      If count if not aligned to power of 2, then in last iteration we need to calculate
 22      different amount of bits.
 23      So to optimize resource usage, there is partial remainder
 24      that allows to reuse this module for a shorter calculation.
 25  
 26      Attributes
 27      ----------
 28      size: int
 29          Size of inputs
 30      n: int
 31          Number of steps
 32      partial_remainder_count: int
 33          Number of steps for last iteration
 34      divisor: Signal
 35          Input divisor
 36      dividend: Signal
 37          Input dividend
 38      input_remainder: Signal
 39          Remainder carried over from previous iteration
 40      quotient: Signal
 41          Calculated n bits of quotient
 42      remainder: Signal
 43          Calculated remainder
 44      partial_remainder: Signal
 45          Calculated partial remainder
 46      """
 47  
 48      def __init__(self, step_count: int, size: int, partial_remainder_count: int = 0):
 49          self.size = size
 50          self.step_count = step_count
 51          self.partial_remainder_count = partial_remainder_count
 52  
 53          self.divisor = Signal(unsigned(size))
 54          self.dividend = Signal(unsigned(size))
 55          self.input_remainder = Signal(unsigned(size))
 56  
 57          self.quotient = Signal(unsigned(size))
 58          self.remainder = Signal(unsigned(size))
 59          self.partial_remainder = Signal(unsigned(size))
 60  
 61      def elaborate(self, platform) -> TModule:
 62          if self.step_count == 0:
 63              # default case
 64              m = TModule()
 65  
 66              m.d.comb += self.quotient.eq(0)
 67              m.d.comb += self.remainder.eq(self.input_remainder)
 68              m.d.comb += self.partial_remainder.eq(self.input_remainder)
 69  
 70              return m
 71          else:
 72              return self.recursive_module()
 73  
 74      def recursive_module(self) -> TModule:
 75          m = TModule()
 76  
 77          # adding bit from dividend
 78          concat = Signal(self.size)
 79          m.d.comb += concat.eq(Cat(self.dividend[self.step_count - 1], self.input_remainder))
 80  
 81          # recursive module
 82          m.submodules.rec_div = rec_div = RecursiveDivison(
 83              self.step_count - 1, self.size, partial_remainder_count=self.partial_remainder_count - 1
 84          )
 85  
 86          m.d.comb += rec_div.dividend.eq(self.dividend)
 87          m.d.comb += rec_div.divisor.eq(self.divisor)
 88  
 89          # Single step as described in article
 90          with m.If(concat >= self.divisor):
 91              m.d.comb += self.quotient[self.step_count - 1].eq(1)
 92              m.d.comb += rec_div.input_remainder.eq(concat - self.divisor)
 93          with m.Else():
 94              m.d.comb += self.quotient[self.step_count - 1].eq(0)
 95              m.d.comb += rec_div.input_remainder.eq(concat)
 96  
 97          # wiring up rest of result from recursive module
 98          m.d.comb += self.quotient[: (self.step_count - 1)].eq(rec_div.quotient)
 99          m.d.comb += self.remainder.eq(rec_div.remainder)
100  
101          # partial remainder
102          if self.partial_remainder_count == 0:
103              m.d.comb += self.partial_remainder.eq(self.input_remainder)
104          else:
105              m.d.comb += self.partial_remainder.eq(rec_div.partial_remainder)
106  
107          return m
108  
109  
class LongDivider(DividerBase):
    """
    Multi-cycle unsigned divider built on `RecursiveDivison`.

    Each cycle, the top `ipc` bits of the dividend register are fed into a
    `RecursiveDivison` chain together with the remainder from the previous
    cycle. The new quotient bits are shifted into the quotient register from
    the bottom, and the dividend register is shifted left by `ipc`. When
    `xlen` is not a multiple of `ipc`, the last iteration keeps only the first
    `partial_remainder_count` quotient bits and the matching partial remainder.

    Methods:
    - `issue` is ready only while the unit is idle. It latches the operands and
      starts a new calculation.
    - `accept` is ready once all `stages` iterations are done. It returns
      `quotient` and `remainder` and makes the unit idle again.
    - `clear` resets the iteration counter and makes the unit idle, abandoning
      any calculation in progress.

    Attributes
    ----------
    gen_params: GenParams
        Core generation parameters; `gen_params.isa.xlen` is the operand width.
    ipc: int
        Number of quotient bits (division steps) computed per cycle.
    partial_remainder_count: int
        Number of quotient bits computed in the last iteration when `xlen` is
        not a multiple of `ipc` (`xlen % ipc`); 0 otherwise.
    stages: int
        Number of required iterations, i.e. `ceil(xlen / ipc)`.
    odd_iteration: bool
        Whether the last iteration requires a partial calculation.

    Raises
    ------
    ValueError
        If `ipc` is not in the range `1..xlen`.
    """

    def __init__(self, gen_params: GenParams, ipc=4):
        super().__init__(gen_params)
        xlen = self.gen_params.isa.xlen

        # Each iteration consumes the top `ipc` dividend bits, so `ipc` must be
        # positive and cannot exceed the operand width.
        if not 1 <= ipc <= xlen:
            raise ValueError(f"ipc must be between 1 and xlen ({xlen}), got {ipc}")

        self.ipc = ipc
        self.partial_remainder_count = xlen % ipc
        self.odd_iteration = self.partial_remainder_count != 0
        # ceil(xlen / ipc): the last iteration may compute fewer than `ipc` bits
        self.stages = xlen // ipc + (1 if self.odd_iteration else 0)

    def elaborate(self, platform):
        m = TModule()
        xlen = self.gen_params.isa.xlen

        m.submodules.divider = divider = RecursiveDivison(
            self.ipc, xlen, partial_remainder_count=self.partial_remainder_count
        )

        # High while the unit is idle and can accept a new operation
        ready = Signal(1, init=1)

        # Operand registers; `dividend` also acts as a shift register
        dividend = Signal(unsigned(xlen))
        divisor = Signal(unsigned(xlen))

        # Result registers, built up over the iterations
        quotient = Signal(unsigned(xlen))
        remainder = Signal(unsigned(xlen))

        # Number of completed iterations, 0..stages
        stage = Signal(range(self.stages + 1))

        # starting calculations
        @def_method(m, self.issue, ready=ready)
        def _(arg):
            m.d.sync += dividend.eq(arg.dividend)
            m.d.sync += divisor.eq(arg.divisor)
            m.d.sync += remainder.eq(0)
            m.d.sync += quotient.eq(0)
            m.d.sync += stage.eq(0)

            m.d.sync += ready.eq(0)

        # returning results once all iterations are done
        @def_method(m, self.accept, ready=(~ready & (stage == self.stages)))
        def _(arg):
            m.d.sync += ready.eq(1)
            return {"quotient": quotient, "remainder": remainder}

        # clearing the unit (abandons any calculation in progress)
        @def_method(m, self.clear)
        def _():
            m.d.sync += stage.eq(0)
            m.d.sync += ready.eq(1)

        def store_full_iteration():
            """Latch the result of an iteration that computed all `ipc` quotient bits."""
            m.d.sync += remainder.eq(divider.remainder)
            m.d.sync += quotient.eq(Cat(divider.quotient[: self.ipc], quotient))

        def store_partial_iteration():
            """Latch only the first `partial_remainder_count` steps of the chain."""
            m.d.sync += remainder.eq(divider.partial_remainder)
            m.d.sync += quotient.eq(
                Cat(divider.quotient[self.ipc - self.partial_remainder_count : self.ipc], quotient)
            )

        # performing calculations
        with m.If(~ready & (stage != self.stages)):
            # assigning inputs to recursive divider
            m.d.comb += divider.divisor.eq(divisor)
            m.d.comb += divider.dividend.eq(dividend[xlen - self.ipc :])
            m.d.comb += divider.input_remainder.eq(remainder)

            # dividend is a shift register,
            # so in each iteration its upper bits are fed into the recursive divider
            m.d.sync += dividend.eq(dividend << self.ipc)

            # Only build the partial-iteration path when it can actually be taken
            if self.odd_iteration:
                with m.If(stage == self.stages - 1):
                    store_partial_iteration()
                with m.Else():
                    store_full_iteration()
            else:
                store_full_iteration()

            m.d.sync += stage.eq(stage + 1)

        return m
205  
206          return m