/ coreblocks / func_blocks / fu / zbc.py
zbc.py
  1  from dataclasses import dataclass, KW_ONLY
  2  from enum import IntFlag, auto, unique
  3  from typing import Sequence
  4  
  5  from amaranth import *
  6  
  7  from coreblocks.func_blocks.fu.common import DecoderManager, FuncUnitBase
  8  from coreblocks.func_blocks.interface.func_protocols import FuncUnit
  9  from coreblocks.params import GenParams, FunctionalComponentParams
 10  from coreblocks.arch import OpType, Funct3
 11  from transactron import Transaction, def_method
 12  from transactron.lib import FIFO
 13  from transactron.utils import OneHotSwitch
 14  
 15  
 16  class ZbcFn(DecoderManager):
 17      @unique
 18      class Fn(IntFlag):
 19          CLMUL = auto()
 20          CLMULH = auto()
 21          CLMULR = auto()
 22  
 23      def get_instructions(self) -> Sequence[tuple]:
 24          return [
 25              (ZbcFn.Fn.CLMUL, OpType.CLMUL, Funct3.CLMUL),
 26              (ZbcFn.Fn.CLMULH, OpType.CLMUL, Funct3.CLMULH),
 27              (ZbcFn.Fn.CLMULR, OpType.CLMUL, Funct3.CLMULR),
 28          ]
 29  
 30  
 31  class ClMultiplier(Elaboratable):
 32      """
 33      Module for computing carry-less product
 34  
 35      Attributes
 36      ----------
 37      i1: Signal(unsigned(n)), in
 38          First factor.
 39      i2: Signal(unsigned(n)), in
 40          Second factor.
 41      result: Signal(unsigned(n * 2)), out
 42          Result.
 43      reset: Signal(1), in
 44          Setting this signal to 1 will start a new computation with provided inputs
 45      busy: Signal(1), out
 46          Set to 1 while a computation is in progress
 47      """
 48  
 49      def __init__(self, bit_width: int, recursion_depth: int):
 50          """
 51          Parameters
 52          ----------
 53          bit_width: int
 54              Bit width of inputs
 55          recursion_depth: int
 56              Depth of recursive submodules for parallel computation (assumes bit_width to be a power of 2)
 57          """
 58          if bit_width.bit_count() != 1:
 59              raise ValueError("bit_width should be a power of 2")
 60          if bit_width.bit_length() <= recursion_depth:
 61              raise ValueError("Too large recursion depth")
 62  
 63          self.recursion_depth = recursion_depth
 64          self.bit_width = bit_width
 65  
 66          self.i1 = Signal(unsigned(bit_width))
 67          self.i2 = Signal(unsigned(bit_width))
 68          self.result = Signal(unsigned(bit_width * 2))
 69          self.reset = Signal()
 70          self.busy = Signal()
 71  
 72      def elaborate(self, platform):
 73          if self.recursion_depth == 0:
 74              return self.iterative_module()
 75          else:
 76              return self.recursive_module()
 77  
 78      def iterative_module(self):
 79          m = Module()
 80  
 81          m.d.sync += self.busy.eq(0)
 82  
 83          v1 = Signal(unsigned(self.bit_width * 2))
 84          v2 = Signal(unsigned(self.bit_width))
 85          with m.If(self.reset):
 86              m.d.sync += self.result.eq(0)
 87              m.d.sync += [
 88                  v1.eq(self.i1),
 89                  v2.eq(self.i2),
 90              ]
 91              m.d.sync += self.busy.eq(1)
 92  
 93          with m.Elif(v2.bool()):
 94              with m.If(v2[0]):
 95                  m.d.sync += self.result.eq(self.result ^ v1)
 96              m.d.sync += [
 97                  v1.eq(v1 << 1),
 98                  v2.eq(v2 >> 1),
 99              ]
100              m.d.sync += self.busy.eq(1)
101  
102          return m
103  
104      def recursive_module(self):
105          m = Module()
106  
107          half_width = self.bit_width // 2
108  
109          m.submodules.mul_ll = mul_ll = ClMultiplier(half_width, self.recursion_depth - 1)
110          m.submodules.mul_lu = mul_lu = ClMultiplier(half_width, self.recursion_depth - 1)
111          m.submodules.mul_ul = mul_ul = ClMultiplier(half_width, self.recursion_depth - 1)
112          m.submodules.mul_uu = mul_uu = ClMultiplier(half_width, self.recursion_depth - 1)
113  
114          m.d.comb += [
115              mul_ll.reset.eq(self.reset),
116              mul_ul.reset.eq(self.reset),
117              mul_lu.reset.eq(self.reset),
118              mul_uu.reset.eq(self.reset),
119          ]
120  
121          m.d.comb += self.busy.eq(mul_ll.busy | mul_lu.busy | mul_ul.busy | mul_uu.busy)
122  
123          m.d.comb += [
124              mul_ll.i1.eq(self.i1[:half_width]),
125              mul_ll.i2.eq(self.i2[:half_width]),
126          ]
127          m.d.comb += [
128              mul_lu.i1.eq(self.i1[half_width:]),
129              mul_lu.i2.eq(self.i2[:half_width]),
130          ]
131          m.d.comb += [
132              mul_ul.i1.eq(self.i1[:half_width]),
133              mul_ul.i2.eq(self.i2[half_width:]),
134          ]
135          m.d.comb += [
136              mul_uu.i1.eq(self.i1[half_width:]),
137              mul_uu.i2.eq(self.i2[half_width:]),
138          ]
139  
140          m.d.comb += self.result.eq(
141              (mul_uu.result << self.bit_width)
142              ^ (mul_ul.result << half_width)
143              ^ (mul_lu.result << half_width)
144              ^ mul_ll.result
145          )
146  
147          return m
148  
149  
class ZbcUnit(FuncUnitBase[ZbcFn]):
    """
    Executes Zbc instructions (carry-less multiplication).

    Issuing an instruction hands the operands to a `ClMultiplier` and stores
    the data needed to report the result in a FIFO; a transaction pushes the
    result once the multiplier is no longer busy.
    """

    def __init__(self, gen_params: GenParams, recursion_depth: int, fn: ZbcFn):
        """
        Parameters
        ----------
        gen_params: GenParams
            Core generation parameters.
        recursion_depth: int
            Recursion depth passed to the `ClMultiplier`.
        fn: ZbcFn
            Decoder manager passed to the base class.
        """
        super().__init__(gen_params, fn)

        self.recursion_depth = recursion_depth

    def elaborate(self, platform):
        m = super().elaborate(platform)

        xlen = self.gen_params.isa.xlen

        # Data that has to survive until the multiplier finishes.
        pending_layout = [
            ("rob_id", self.gen_params.rob_entries_bits),
            ("rp_dst", self.gen_params.phys_regs_bits),
            ("high_res", 1),
            ("rev_res", 1),
        ]
        m.submodules.params_fifo = params_fifo = FIFO(pending_layout, 1)
        m.submodules.clmul = clmul = ClMultiplier(xlen, self.recursion_depth)

        # Default value; the issue method below drives 1 to start a computation.
        m.d.comb += clmul.reset.eq(0)

        with Transaction().body(m, ready=~clmul.busy):
            pending = params_fifo.read(m)
            product = clmul.result

            # Pick the requested half of the double-width product, then
            # optionally bit-reverse it.
            selected_half = Mux(pending.high_res, product[xlen:], product[:xlen])
            final_value = Mux(pending.rev_res, selected_half[::-1], selected_half)

            self.push_result(m, rob_id=pending.rob_id, rp_dst=pending.rp_dst, result=final_value, exception=0)

        @def_method(m, self.issue_decoded)
        def _(exec_fn, decode_fn, imm, s1_val, s2_val, rob_id, rp_dst, pc, tag):
            first_operand = s1_val
            # The immediate is used as the second operand whenever it is selected by Mux
            # (presumably when non-zero - confirm against Amaranth's Mux semantics).
            second_operand = Mux(imm, imm, s2_val)

            value1 = Signal(xlen)
            value2 = Signal(xlen)
            high_res = Signal(1)
            rev_res = Signal(1)

            # (function, high_res, rev_res, first multiplier input, second multiplier input)
            variants = (
                (ZbcFn.Fn.CLMUL, 0, 0, first_operand, second_operand),
                (ZbcFn.Fn.CLMULH, 1, 0, first_operand, second_operand),
                # clmulr is equivalent to bit-reversing the inputs,
                # performing a clmul,
                # then bit-reversing the output.
                (ZbcFn.Fn.CLMULR, 0, 1, first_operand[::-1], second_operand[::-1]),
            )

            with OneHotSwitch(m, decode_fn) as OneHotCase:
                for zbc_fn, take_high, reverse, input1, input2 in variants:
                    with OneHotCase(zbc_fn):
                        m.d.av_comb += high_res.eq(take_high)
                        m.d.av_comb += rev_res.eq(reverse)
                        m.d.av_comb += value1.eq(input1)
                        m.d.av_comb += value2.eq(input2)

            params_fifo.write(m, rob_id=rob_id, rp_dst=rp_dst, high_res=high_res, rev_res=rev_res)

            # Feed the multiplier and start the computation.
            m.d.av_comb += clmul.i1.eq(value1)
            m.d.av_comb += clmul.i2.eq(value2)
            m.d.comb += clmul.reset.eq(1)

        return m
224  
225  
@dataclass(frozen=True)
class ZbcComponent(FunctionalComponentParams):
    """
    Parameters from which a `ZbcUnit` is created.
    """

    _: KW_ONLY
    # Recursion depth handed to ZbcUnit.
    recursion_depth: int = 3
    # Decoder manager handed to ZbcUnit.
    decoder_manager: ZbcFn = ZbcFn()

    def get_module(self, gen_params: GenParams) -> FuncUnit:
        """
        Create a `ZbcUnit` configured with this component's parameters.
        """
        return ZbcUnit(
            gen_params=gen_params,
            recursion_depth=self.recursion_depth,
            fn=self.decoder_manager,
        )