# zbc.py
from dataclasses import dataclass, KW_ONLY
from enum import IntFlag, auto, unique
from typing import Sequence

from amaranth import *

from coreblocks.func_blocks.fu.common import DecoderManager, FuncUnitBase
from coreblocks.func_blocks.interface.func_protocols import FuncUnit
from coreblocks.params import GenParams, FunctionalComponentParams
from coreblocks.arch import OpType, Funct3
from transactron import Transaction, def_method
from transactron.lib import FIFO
from transactron.utils import OneHotSwitch


class ZbcFn(DecoderManager):
    """
    Decoder manager for the Zbc (carry-less multiplication) instructions.

    Each supported instruction is mapped to a distinct flag of ``Fn``
    (``auto()`` on an ``IntFlag`` yields 1, 2, 4), which ``ZbcUnit`` later
    dispatches on with ``OneHotSwitch``.
    """

    @unique
    class Fn(IntFlag):
        CLMUL = auto()
        CLMULH = auto()
        CLMULR = auto()

    def get_instructions(self) -> Sequence[tuple]:
        # (function flag, op type, funct3): all three instructions share
        # OpType.CLMUL and are told apart only by funct3.
        return [
            (ZbcFn.Fn.CLMUL, OpType.CLMUL, Funct3.CLMUL),
            (ZbcFn.Fn.CLMULH, OpType.CLMUL, Funct3.CLMULH),
            (ZbcFn.Fn.CLMULR, OpType.CLMUL, Funct3.CLMULR),
        ]


class ClMultiplier(Elaboratable):
    """
    Module for computing carry-less product

    With ``recursion_depth == 0`` the product is computed by an iterative
    shift-and-XOR loop that consumes one bit of ``i2`` per clock cycle and
    stops once the remaining bits of ``i2`` are all zero. With a positive
    ``recursion_depth`` both factors are split into halves, the four partial
    products are computed in parallel by nested ``ClMultiplier`` instances
    (with ``recursion_depth - 1``), and the partial products are combined with
    XOR.

    Attributes
    ----------
    i1: Signal(unsigned(n)), in
        First factor.
    i2: Signal(unsigned(n)), in
        Second factor.
    result: Signal(unsigned(n * 2)), out
        Result.
    reset: Signal(1), in
        Setting this signal to 1 will start a new computation with provided inputs.
        The inputs are only sampled (by the iterative leaf multipliers) in the
        cycle in which ``reset`` is 1.
    busy: Signal(1), out
        Set to 1 while a computation is in progress. It is driven by registers
        in the iterative leaf multipliers, so it is still 0 in the cycle in
        which ``reset`` is asserted and rises on the following cycle;
        ``result`` holds the final product once ``busy`` falls again.
    """

    def __init__(self, bit_width: int, recursion_depth: int):
        """
        Parameters
        ----------
        bit_width: int
            Bit width of inputs. Must be a positive power of 2.
        recursion_depth: int
            Depth of recursive submodules for parallel computation (assumes bit_width to be a power of 2).
            Must satisfy ``0 <= recursion_depth <= log2(bit_width)``; at the
            maximum depth the leaves are 1-bit iterative multipliers.

        Raises
        ------
        ValueError
            If ``bit_width`` is not a positive power of 2, or ``recursion_depth``
            is negative or too large for ``bit_width``.
        """
        # int.bit_count() counts the bits of the absolute value, so a negative
        # width such as -8 would otherwise pass the power-of-2 test.
        if bit_width <= 0 or bit_width.bit_count() != 1:
            raise ValueError("bit_width should be a power of 2")
        # For bit_width == 2**k, bit_length() == k + 1, so this rejects
        # recursion_depth > k (each recursion level halves the width).
        if bit_width.bit_length() <= recursion_depth:
            raise ValueError("Too large recursion depth")
        # A negative depth would never reach the iterative base case and would
        # only fail deep inside elaboration once the width is halved down to 0.
        if recursion_depth < 0:
            raise ValueError("recursion_depth should be non-negative")

        self.recursion_depth = recursion_depth
        self.bit_width = bit_width

        self.i1 = Signal(unsigned(bit_width))
        self.i2 = Signal(unsigned(bit_width))
        self.result = Signal(unsigned(bit_width * 2))
        self.reset = Signal()
        self.busy = Signal()

    def elaborate(self, platform):
        if self.recursion_depth == 0:
            return self.iterative_module()
        else:
            return self.recursive_module()

    def iterative_module(self):
        """
        Build the sequential shift-and-XOR multiplier used at recursion depth 0.

        On ``reset`` the factors are latched and ``result`` is cleared. On each
        following cycle, while any bit of the remaining multiplier is set, the
        (shifted) multiplicand is XORed into ``result`` if the current
        multiplier bit is 1, then the multiplicand is shifted left and the
        multiplier right by one position.
        """
        m = Module()

        # Default: not busy; overridden below while there is work to do.
        m.d.sync += self.busy.eq(0)

        # v1: multiplicand, shifted left once per step. It is 2n bits wide so
        # that no bits are lost before they reach the top half of the result.
        v1 = Signal(unsigned(self.bit_width * 2))
        # v2: bits of the multiplier not yet processed, shifted right once per step.
        v2 = Signal(unsigned(self.bit_width))
        with m.If(self.reset):
            m.d.sync += self.result.eq(0)
            m.d.sync += [
                v1.eq(self.i1),
                v2.eq(self.i2),
            ]
            m.d.sync += self.busy.eq(1)

        with m.Elif(v2.bool()):
            with m.If(v2[0]):
                # Carry-less addition is XOR.
                m.d.sync += self.result.eq(self.result ^ v1)
            m.d.sync += [
                v1.eq(v1 << 1),
                v2.eq(v2 >> 1),
            ]
            m.d.sync += self.busy.eq(1)

        return m

    def recursive_module(self):
        """
        Build a multiplier from four half-width sub-multipliers.

        Writing each factor as ``x = x_u * X**h + x_l`` (``h = n / 2``), the
        carry-less product is
        ``a_u*b_u * X**n  ^  (a_u*b_l ^ a_l*b_u) * X**h  ^  a_l*b_l``,
        where ``*`` is carry-less multiplication and ``X**k`` is a left shift
        by ``k`` bits. All four partial products are computed in parallel.
        """
        m = Module()

        half_width = self.bit_width // 2

        # Naming: mul_<i1 half><i2 half> is not followed literally; see the
        # input assignments below for which halves each submodule receives.
        m.submodules.mul_ll = mul_ll = ClMultiplier(half_width, self.recursion_depth - 1)
        m.submodules.mul_lu = mul_lu = ClMultiplier(half_width, self.recursion_depth - 1)
        m.submodules.mul_ul = mul_ul = ClMultiplier(half_width, self.recursion_depth - 1)
        m.submodules.mul_uu = mul_uu = ClMultiplier(half_width, self.recursion_depth - 1)

        m.d.comb += [
            mul_ll.reset.eq(self.reset),
            mul_ul.reset.eq(self.reset),
            mul_lu.reset.eq(self.reset),
            mul_uu.reset.eq(self.reset),
        ]

        # Busy while any partial product is still being computed.
        m.d.comb += self.busy.eq(mul_ll.busy | mul_lu.busy | mul_ul.busy | mul_uu.busy)

        # i1 lower half * i2 lower half
        m.d.comb += [
            mul_ll.i1.eq(self.i1[:half_width]),
            mul_ll.i2.eq(self.i2[:half_width]),
        ]
        # i1 upper half * i2 lower half
        m.d.comb += [
            mul_lu.i1.eq(self.i1[half_width:]),
            mul_lu.i2.eq(self.i2[:half_width]),
        ]
        # i1 lower half * i2 upper half
        m.d.comb += [
            mul_ul.i1.eq(self.i1[:half_width]),
            mul_ul.i2.eq(self.i2[half_width:]),
        ]
        # i1 upper half * i2 upper half
        m.d.comb += [
            mul_uu.i1.eq(self.i1[half_width:]),
            mul_uu.i2.eq(self.i2[half_width:]),
        ]

        m.d.comb += self.result.eq(
            (mul_uu.result << self.bit_width)
            ^ (mul_ul.result << half_width)
            ^ (mul_lu.result << half_width)
            ^ mul_ll.result
        )

        return m


class ZbcUnit(FuncUnitBase[ZbcFn]):
    """
    Executes Zbc instructions (carry-less multiplication).

    A single ``ClMultiplier`` of width ``xlen`` is shared by all instructions;
    at most one instruction is in flight, since the bookkeeping FIFO has
    depth 1.
    """

    def __init__(self, gen_params: GenParams, recursion_depth: int, fn: ZbcFn):
        """
        Parameters
        ----------
        gen_params: GenParams
            Core configuration; ``isa.xlen`` sets the multiplier width.
        recursion_depth: int
            Forwarded to ``ClMultiplier`` as its recursion depth.
        fn: ZbcFn
            Decoder manager, passed to ``FuncUnitBase``.
        """
        super().__init__(gen_params, fn)

        self.recursion_depth = recursion_depth

    def elaborate(self, platform):
        m = super().elaborate(platform)

        # Holds the per-instruction data needed when the result is produced:
        # destination identifiers and how to post-process the 2*xlen product.
        m.submodules.params_fifo = params_fifo = FIFO(
            [
                ("rob_id", self.gen_params.rob_entries_bits),
                ("rp_dst", self.gen_params.phys_regs_bits),
                ("high_res", 1),
                ("rev_res", 1),
            ],
            1,
        )
        m.submodules.clmul = clmul = ClMultiplier(self.gen_params.isa.xlen, self.recursion_depth)

        # Default value; issue_decoded below drives it to 1 when it runs.
        m.d.comb += clmul.reset.eq(0)

        # Push a result once the multiplier is idle and an issued instruction
        # is waiting in params_fifo.
        # NOTE(review): busy only rises the cycle after reset, so this relies
        # on params_fifo not exposing an entry in the same cycle it is written.
        with Transaction().body(m, ready=~clmul.busy):
            xlen = self.gen_params.isa.xlen

            output = clmul.result
            params = params_fifo.read(m)

            # CLMULH takes the upper half of the product, the others the lower half.
            result = Mux(params.high_res, output[xlen:], output[:xlen])
            # CLMULR additionally reverses the bits of the selected half.
            reversed_result = Mux(params.rev_res, result[::-1], result)

            self.push_result(m, rob_id=params.rob_id, rp_dst=params.rp_dst, result=reversed_result, exception=0)

        @def_method(m, self.issue_decoded)
        def _(exec_fn, decode_fn, imm, s1_val, s2_val, rob_id, rp_dst, pc, tag):
            i1 = s1_val
            # Use the immediate when it is non-zero, otherwise the rs2 value.
            i2 = Mux(imm, imm, s2_val)

            value1 = Signal(self.gen_params.isa.xlen)
            value2 = Signal(self.gen_params.isa.xlen)
            high_res = Signal(1)
            rev_res = Signal(1)

            with OneHotSwitch(m, decode_fn) as OneHotCase:
                with OneHotCase(ZbcFn.Fn.CLMUL):
                    m.d.av_comb += high_res.eq(0)
                    m.d.av_comb += rev_res.eq(0)
                    m.d.av_comb += value1.eq(i1)
                    m.d.av_comb += value2.eq(i2)
                with OneHotCase(ZbcFn.Fn.CLMULH):
                    m.d.av_comb += high_res.eq(1)
                    m.d.av_comb += rev_res.eq(0)
                    m.d.av_comb += value1.eq(i1)
                    m.d.av_comb += value2.eq(i2)
                with OneHotCase(ZbcFn.Fn.CLMULR):
                    # clmulr is equivalent to bit-reversing the inputs,
                    # performing a clmul,
                    # then bit-reversing the output.
                    m.d.av_comb += high_res.eq(0)
                    m.d.av_comb += rev_res.eq(1)
                    m.d.av_comb += value1.eq(i1[::-1])
                    m.d.av_comb += value2.eq(i2[::-1])

            params_fifo.write(m, rob_id=rob_id, rp_dst=rp_dst, high_res=high_res, rev_res=rev_res)

            m.d.av_comb += clmul.i1.eq(value1)
            m.d.av_comb += clmul.i2.eq(value2)
            # Start a new computation with the operands above.
            m.d.comb += clmul.reset.eq(1)

        return m


@dataclass(frozen=True)
class ZbcComponent(FunctionalComponentParams):
    """
    Configuration of the Zbc functional unit.

    ``recursion_depth`` (default 3) is forwarded to ``ZbcUnit`` and from there
    to ``ClMultiplier``.
    """

    _: KW_ONLY
    recursion_depth: int = 3
    decoder_manager: ZbcFn = ZbcFn()

    def get_module(self, gen_params: GenParams) -> FuncUnit:
        return ZbcUnit(gen_params, self.recursion_depth, self.decoder_manager)