alu.py
from dataclasses import dataclass, KW_ONLY, field
from typing import Sequence
from amaranth import *
from amaranth_types import HasElaborate

from transactron import *
from transactron.lib.metrics import *

from coreblocks.arch import OpType, Funct3, Funct7
from coreblocks.params import GenParams, FunctionalComponentParams
from transactron.utils import OneHotSwitch

from coreblocks.func_blocks.fu.common import DecoderManager, FuncUnitBase
from enum import IntFlag, auto

from coreblocks.func_blocks.interface.func_protocols import FuncUnit

from transactron.utils import popcount, count_leading_zeros

__all__ = ["AluFuncUnit", "ALUComponent"]


class AluFn(DecoderManager):
    """
    Decoder manager listing the operations supported by the ALU.

    The base operations are always present. Operations of the ZBA, ZBB and
    ZICOND groups are reported by `get_instructions` only when the matching
    flag is enabled.

    Attributes
    ----------
    zba_enable: bool
        Enables the SH1ADD/SH2ADD/SH3ADD operations.
    zbb_enable: bool
        Enables the bit manipulation operations (ANDN, ORN, XNOR, CLZ, CTZ,
        CPOP, MIN/MAX, sign/zero extension, ORCB, REV8).
    zicond_enable: bool
        Enables the CZEROEQZ/CZERONEZ operations.
    """

    def __init__(self, zba_enable=False, zbb_enable=False, zicond_enable=False) -> None:
        self.zba_enable = zba_enable
        self.zbb_enable = zbb_enable
        self.zicond_enable = zicond_enable

    # NOTE: values are assigned by `auto()` in declaration order, so the order
    # of the members below determines the encoding - do not reorder them.
    class Fn(IntFlag):
        ADD = auto()  # Addition
        XOR = auto()  # Bitwise xor
        OR = auto()  # Bitwise or
        AND = auto()  # Bitwise and
        SUB = auto()  # Subtraction
        SLT = auto()  # Set if less than (signed)
        SLTU = auto()  # Set if less than (unsigned)

        # ZBA extension
        SH1ADD = auto()  # Logic left shift by 1 and add
        SH2ADD = auto()  # Logic left shift by 2 and add
        SH3ADD = auto()  # Logic left shift by 3 and add

        # ZBB extension
        ANDN = auto()  # Bitwise ANDN
        ORN = auto()  # Bitwise ORN
        XNOR = auto()  # Bitwise XNOR

        CLZ = auto()  # Count leading zeros
        CTZ = auto()  # Count trailing zeros
        CPOP = auto()  # Count set bits

        MAX = auto()  # Maximum
        MAXU = auto()  # Unsigned maximum
        MIN = auto()  # Minimum
        MINU = auto()  # Unsigned minimum

        SEXTB = auto()  # Sign-extend byte
        SEXTH = auto()  # Sign-extend halfword
        ZEXTH = auto()  # Zero extend halfword

        ORCB = auto()  # Bitwise or combine
        REV8 = auto()  # Reverse byte ordering

        # ZICOND extension
        CZEROEQZ = auto()  # Move zero if condition is equal to zero
        CZERONEZ = auto()  # Move zero if condition is nonzero

    def get_instructions(self) -> Sequence[tuple]:
        """
        Return the instruction table for the currently enabled operations.

        Returns
        -------
        Sequence[tuple]
            Tuples starting with an `Fn` member, followed by the `OpType` and
            `Funct3` (and, for some entries, `Funct7`) values it is selected by.
            Base operations come first, then ZBA, ZBB and ZICOND entries for
            every enabled group.
        """
        instructions: list[tuple] = [
            (self.Fn.ADD, OpType.ARITHMETIC, Funct3.ADD, Funct7.ADD),
            (self.Fn.SUB, OpType.ARITHMETIC, Funct3.ADD, Funct7.SUB),
            (self.Fn.SLT, OpType.COMPARE, Funct3.SLT),
            (self.Fn.SLTU, OpType.COMPARE, Funct3.SLTU),
            (self.Fn.XOR, OpType.LOGIC, Funct3.XOR),
            (self.Fn.OR, OpType.LOGIC, Funct3.OR),
            (self.Fn.AND, OpType.LOGIC, Funct3.AND),
        ]

        # Explicit `if` gating (instead of `list * flag`) so that any truthy
        # flag value adds each entry exactly once.
        if self.zba_enable:
            instructions += [
                (self.Fn.SH1ADD, OpType.ADDRESS_GENERATION, Funct3.SH1ADD, Funct7.SH1ADD),
                (self.Fn.SH2ADD, OpType.ADDRESS_GENERATION, Funct3.SH2ADD, Funct7.SH2ADD),
                (self.Fn.SH3ADD, OpType.ADDRESS_GENERATION, Funct3.SH3ADD, Funct7.SH3ADD),
            ]

        if self.zbb_enable:
            instructions += [
                (self.Fn.ANDN, OpType.BIT_MANIPULATION, Funct3.ANDN, Funct7.ANDN),
                (self.Fn.XNOR, OpType.BIT_MANIPULATION, Funct3.XNOR, Funct7.XNOR),
                (self.Fn.ORN, OpType.BIT_MANIPULATION, Funct3.ORN, Funct7.ORN),
                (self.Fn.MAX, OpType.BIT_MANIPULATION, Funct3.MAX, Funct7.MAX),
                (self.Fn.MAXU, OpType.BIT_MANIPULATION, Funct3.MAXU, Funct7.MAX),
                (self.Fn.MIN, OpType.BIT_MANIPULATION, Funct3.MIN, Funct7.MIN),
                (self.Fn.MINU, OpType.BIT_MANIPULATION, Funct3.MINU, Funct7.MIN),
                (self.Fn.REV8, OpType.UNARY_BIT_MANIPULATION_1, Funct3.REV8),
                (self.Fn.SEXTB, OpType.UNARY_BIT_MANIPULATION_1, Funct3.SEXTB),
                (self.Fn.ZEXTH, OpType.UNARY_BIT_MANIPULATION_1, Funct3.ZEXTH),
                (self.Fn.ORCB, OpType.UNARY_BIT_MANIPULATION_2, Funct3.ORCB),
                (self.Fn.SEXTH, OpType.UNARY_BIT_MANIPULATION_2, Funct3.SEXTH),
                (self.Fn.CLZ, OpType.UNARY_BIT_MANIPULATION_3, Funct3.CLZ),
                (self.Fn.CTZ, OpType.UNARY_BIT_MANIPULATION_4, Funct3.CTZ),
                (self.Fn.CPOP, OpType.UNARY_BIT_MANIPULATION_5, Funct3.CPOP),
            ]

        if self.zicond_enable:
            instructions += [
                (self.Fn.CZEROEQZ, OpType.CZERO, Funct3.CZEROEQZ),
                (self.Fn.CZERONEZ, OpType.CZERO, Funct3.CZERONEZ),
            ]

        return instructions


class CLZSubmodule(Elaboratable):
    """
    Combinational wrapper around `count_leading_zeros`.

    Attributes
    ----------
    in_sig: Signal(xlen), in
        Value passed to `count_leading_zeros`.
    out_sig: Signal(xlen), out
        Result of `count_leading_zeros(in_sig)`.
    """

    def __init__(self, gen_params: GenParams):
        xlen = gen_params.isa.xlen
        self.in_sig = Signal(xlen)
        self.out_sig = Signal(xlen)

    def elaborate(self, platform) -> HasElaborate:
        m = Module()
        m.d.comb += self.out_sig.eq(count_leading_zeros(self.in_sig))
        return m


class Alu(Elaboratable):
    """
    Combinational ALU computing `out` from `in1`, `in2` and the selected `fn`.

    Attributes
    ----------
    fn: Value, in
        Operation selector obtained from `alu_fn.get_function()`; it is
        matched against `AluFn.Fn` members with `OneHotSwitch`.
    in1: Signal(xlen), in
        First operand.
    in2: Signal(xlen), in
        Second operand.
    out: Signal(xlen), out
        Result of the selected operation.
    """

    def __init__(self, gen_params: GenParams, alu_fn: "AluFn | None" = None):
        """
        Parameters
        ----------
        gen_params: GenParams
            Core generation parameters; `isa.xlen` sets the operand width.
        alu_fn: AluFn, optional
            Decoder manager selecting the enabled extensions. When omitted, a
            fresh `AluFn()` (all extensions disabled) is created for this
            instance instead of sharing one default object between instances.
        """
        if alu_fn is None:
            alu_fn = AluFn()

        self.zba_enable = alu_fn.zba_enable
        self.zbb_enable = alu_fn.zbb_enable
        self.zicond_enable = alu_fn.zicond_enable
        self.gen_params = gen_params

        self.fn = alu_fn.get_function()
        self.in1 = Signal(gen_params.isa.xlen)
        self.in2 = Signal(gen_params.isa.xlen)

        self.out = Signal(gen_params.isa.xlen)

    def elaborate(self, platform):
        m = TModule()

        xlen = self.gen_params.isa.xlen
        in1, in2, out = self.in1, self.in2, self.out

        with OneHotSwitch(m, self.fn) as OneHotCase:
            # Base operations: every case drives `out` with a single expression.
            base_ops = [
                (AluFn.Fn.ADD, in1 + in2),
                (AluFn.Fn.XOR, in1 ^ in2),
                (AluFn.Fn.OR, in1 | in2),
                (AluFn.Fn.AND, in1 & in2),
                (AluFn.Fn.SUB, in1 - in2),
                (AluFn.Fn.SLT, in1.as_signed() < in2.as_signed()),
                (AluFn.Fn.SLTU, in1 < in2),
            ]
            for fn, result in base_ops:
                with OneHotCase(fn):
                    m.d.comb += out.eq(result)

            if self.zba_enable:
                # SHnADD: shift the first operand left by n, then add the second.
                shift_add_ops = [
                    (AluFn.Fn.SH1ADD, 1),
                    (AluFn.Fn.SH2ADD, 2),
                    (AluFn.Fn.SH3ADD, 3),
                ]
                for fn, shift in shift_add_ops:
                    with OneHotCase(fn):
                        m.d.comb += out.eq((in1 << shift) + in2)

            if self.zbb_enable:
                # A single leading-zero counter is shared by CLZ and CTZ.
                m.submodules.clz = clz = CLZSubmodule(self.gen_params)

                with OneHotCase(AluFn.Fn.ANDN):
                    m.d.comb += out.eq(in1 & ~in2)
                with OneHotCase(AluFn.Fn.XNOR):
                    m.d.comb += out.eq(~(in1 ^ in2))
                with OneHotCase(AluFn.Fn.ORN):
                    m.d.comb += out.eq(in1 | ~in2)

                # MIN/MAX variants: `in1` is selected when the condition holds,
                # `in2` otherwise.
                min_max_ops = [
                    (AluFn.Fn.MIN, in1.as_signed() < in2.as_signed()),
                    (AluFn.Fn.MINU, in1 < in2),
                    (AluFn.Fn.MAX, in1.as_signed() >= in2.as_signed()),
                    (AluFn.Fn.MAXU, in1 >= in2),
                ]
                for fn, select_in1 in min_max_ops:
                    with OneHotCase(fn):
                        with m.If(select_in1):
                            m.d.comb += out.eq(in1)
                        with m.Else():
                            m.d.comb += out.eq(in2)

                with OneHotCase(AluFn.Fn.CPOP):
                    m.d.comb += out.eq(popcount(in1))
                with OneHotCase(AluFn.Fn.CLZ):
                    m.d.comb += clz.in_sig.eq(in1)
                    m.d.comb += out.eq(clz.out_sig)
                with OneHotCase(AluFn.Fn.CTZ):
                    # Trailing zeros are the leading zeros of the bit-reversed operand.
                    m.d.comb += clz.in_sig.eq(in1[::-1])
                    m.d.comb += out.eq(clz.out_sig)
                with OneHotCase(AluFn.Fn.SEXTH):
                    m.d.comb += out.eq(Cat(in1[0:16], in1[15].replicate(xlen - 16)))
                with OneHotCase(AluFn.Fn.SEXTB):
                    m.d.comb += out.eq(Cat(in1[0:8], in1[7].replicate(xlen - 8)))
                with OneHotCase(AluFn.Fn.ZEXTH):
                    m.d.comb += out.eq(Cat(in1[0:16], C(0, shape=unsigned(xlen - 16))))
                with OneHotCase(AluFn.Fn.ORCB):

                    def _or(s: Value) -> Value:
                        # All eight bits are set iff any bit of the byte is set.
                        return s.any().replicate(8)

                    for i in range(xlen // 8):
                        m.d.comb += out[i * 8 : (i + 1) * 8].eq(_or(in1[i * 8 : (i + 1) * 8]))
                with OneHotCase(AluFn.Fn.REV8):
                    byte_count = xlen // 8
                    for i in range(byte_count):
                        # Output byte i is taken from the mirrored input byte.
                        j = byte_count - i - 1
                        m.d.comb += out[i * 8 : (i + 1) * 8].eq(in1[j * 8 : (j + 1) * 8])

            if self.zicond_enable:
                # Each entry maps "is in2 zero?" (a Python bool, resolved while
                # building the module) to the value driven on `out`.
                czero_cases = [
                    (AluFn.Fn.CZERONEZ, lambda is_zero: in1 if is_zero else 0),
                    (AluFn.Fn.CZEROEQZ, lambda is_zero: 0 if is_zero else in1),
                ]
                for fn, output_fn in czero_cases:
                    with OneHotCase(fn):
                        with m.If(in2.any()):
                            m.d.comb += out.eq(output_fn(False))
                        with m.Else():
                            m.d.comb += out.eq(output_fn(True))

        return m


class AluFuncUnit(FuncUnitBase[AluFn]):
    """
    Functional unit wrapping `Alu`.

    The `issue_decoded` method feeds the decoded operation and operands into
    the ALU and passes the ALU output to `push_result`.
    """

    def __init__(self, gen_params: GenParams, fn: "AluFn | None" = None):
        """
        Parameters
        ----------
        gen_params: GenParams
            Core generation parameters.
        fn: AluFn, optional
            Decoder manager selecting the enabled extensions. When omitted, a
            fresh `AluFn()` is created for this unit instead of sharing one
            default object between units.
        """
        super().__init__(gen_params, AluFn() if fn is None else fn)

    def elaborate(self, platform):
        m = super().elaborate(platform)

        m.submodules.alu = alu = Alu(self.gen_params, alu_fn=self.fn)

        @def_method(m, self.issue_decoded)
        def _(arg):
            m.d.av_comb += alu.fn.eq(arg.decode_fn)
            m.d.av_comb += alu.in1.eq(arg.s1_val)
            # The second operand is the immediate when it is non-zero and
            # `s2_val` otherwise.
            m.d.av_comb += alu.in2.eq(Mux(arg.imm, arg.imm, arg.s2_val))

            self.push_result(m, rob_id=arg.rob_id, result=alu.out, rp_dst=arg.rp_dst, exception=0)

        return m


@dataclass(frozen=True)
class ALUComponent(FunctionalComponentParams):
    """
    Configuration of an ALU functional unit.

    All fields are keyword-only. `get_module` builds an `AluFuncUnit` from
    `decoder_manager`; `get_decoder_manager` builds an `AluFn` from the
    extension flags.
    """

    _: KW_ONLY
    # NOTE(review): not read in this file - presumably consumed by the base class; confirm.
    result_fifo: bool = True
    # Extension flags forwarded to `AluFn` by `get_decoder_manager`.
    zba_enable: bool = False
    zbb_enable: bool = False
    zicond_enable: bool = False
    # Excluded from `__init__`.
    # NOTE(review): presumably filled in by the base class using `get_decoder_manager()` - confirm.
    decoder_manager: AluFn = field(init=False)

    def get_decoder_manager(self):
        """Return an `AluFn` configured with this component's extension flags."""
        return AluFn(zba_enable=self.zba_enable, zbb_enable=self.zbb_enable, zicond_enable=self.zicond_enable)

    def get_module(self, gen_params: GenParams) -> FuncUnit:
        """Return an `AluFuncUnit` built for `gen_params` with `decoder_manager`."""
        return AluFuncUnit(gen_params, self.decoder_manager)