/ coreblocks / func_blocks / fu / alu.py
alu.py
  1  from dataclasses import dataclass, KW_ONLY, field
  2  from typing import Sequence
  3  from amaranth import *
  4  from amaranth_types import HasElaborate
  5  
  6  from transactron import *
  7  from transactron.lib.metrics import *
  8  
  9  from coreblocks.arch import OpType, Funct3, Funct7
 10  from coreblocks.params import GenParams, FunctionalComponentParams
 11  from transactron.utils import OneHotSwitch
 12  
 13  from coreblocks.func_blocks.fu.common import DecoderManager, FuncUnitBase
 14  from enum import IntFlag, auto
 15  
 16  from coreblocks.func_blocks.interface.func_protocols import FuncUnit
 17  
 18  from transactron.utils import popcount, count_leading_zeros
 19  
 20  __all__ = ["AluFuncUnit", "ALUComponent"]
 21  
 22  
 23  class AluFn(DecoderManager):
 24      def __init__(self, zba_enable=False, zbb_enable=False, zicond_enable=False) -> None:
 25          self.zba_enable = zba_enable
 26          self.zbb_enable = zbb_enable
 27          self.zicond_enable = zicond_enable
 28  
 29      class Fn(IntFlag):
 30          ADD = auto()  # Addition
 31          XOR = auto()  # Bitwise xor
 32          OR = auto()  # Bitwise or
 33          AND = auto()  # Bitwise and
 34          SUB = auto()  # Subtraction
 35          SLT = auto()  # Set if less than (signed)
 36          SLTU = auto()  # Set if less than (unsigned)
 37  
 38          # ZBA extension
 39          SH1ADD = auto()  # Logic left shift by 1 and add
 40          SH2ADD = auto()  # Logic left shift by 2 and add
 41          SH3ADD = auto()  # Logic left shift by 3 and add
 42  
 43          # ZBB extension
 44          ANDN = auto()  # Bitwise ANDN
 45          ORN = auto()  # Bitwise ORN
 46          XNOR = auto()  # Bitwise XNOR
 47  
 48          CLZ = auto()  # Count leading zeros
 49          CTZ = auto()  # Count trailing zeros
 50          CPOP = auto()  # Count set bits
 51  
 52          MAX = auto()  # Maximum
 53          MAXU = auto()  # Unsigned maximum
 54          MIN = auto()  # Minimum
 55          MINU = auto()  # Unsigned minimum
 56  
 57          SEXTB = auto()  # Sign-extend byte
 58          SEXTH = auto()  # Sign-extend halfword
 59          ZEXTH = auto()  # Zero extend halfword
 60  
 61          ORCB = auto()  # Bitwise or combine
 62          REV8 = auto()  # Reverse byte ordering
 63  
 64          # ZICOND extension
 65          CZEROEQZ = auto()  # Move zero if condition if equal to zero
 66          CZERONEZ = auto()  # Move zero if condition is nonzero
 67  
 68      def get_instructions(self) -> Sequence[tuple]:
 69          return (
 70              [
 71                  (self.Fn.ADD, OpType.ARITHMETIC, Funct3.ADD, Funct7.ADD),
 72                  (self.Fn.SUB, OpType.ARITHMETIC, Funct3.ADD, Funct7.SUB),
 73                  (self.Fn.SLT, OpType.COMPARE, Funct3.SLT),
 74                  (self.Fn.SLTU, OpType.COMPARE, Funct3.SLTU),
 75                  (self.Fn.XOR, OpType.LOGIC, Funct3.XOR),
 76                  (self.Fn.OR, OpType.LOGIC, Funct3.OR),
 77                  (self.Fn.AND, OpType.LOGIC, Funct3.AND),
 78              ]
 79              + [
 80                  (self.Fn.SH1ADD, OpType.ADDRESS_GENERATION, Funct3.SH1ADD, Funct7.SH1ADD),
 81                  (self.Fn.SH2ADD, OpType.ADDRESS_GENERATION, Funct3.SH2ADD, Funct7.SH2ADD),
 82                  (self.Fn.SH3ADD, OpType.ADDRESS_GENERATION, Funct3.SH3ADD, Funct7.SH3ADD),
 83              ]
 84              * self.zba_enable
 85              + [
 86                  (self.Fn.ANDN, OpType.BIT_MANIPULATION, Funct3.ANDN, Funct7.ANDN),
 87                  (self.Fn.XNOR, OpType.BIT_MANIPULATION, Funct3.XNOR, Funct7.XNOR),
 88                  (self.Fn.ORN, OpType.BIT_MANIPULATION, Funct3.ORN, Funct7.ORN),
 89                  (self.Fn.MAX, OpType.BIT_MANIPULATION, Funct3.MAX, Funct7.MAX),
 90                  (self.Fn.MAXU, OpType.BIT_MANIPULATION, Funct3.MAXU, Funct7.MAX),
 91                  (self.Fn.MIN, OpType.BIT_MANIPULATION, Funct3.MIN, Funct7.MIN),
 92                  (self.Fn.MINU, OpType.BIT_MANIPULATION, Funct3.MINU, Funct7.MIN),
 93                  (self.Fn.REV8, OpType.UNARY_BIT_MANIPULATION_1, Funct3.REV8),
 94                  (self.Fn.SEXTB, OpType.UNARY_BIT_MANIPULATION_1, Funct3.SEXTB),
 95                  (self.Fn.ZEXTH, OpType.UNARY_BIT_MANIPULATION_1, Funct3.ZEXTH),
 96                  (self.Fn.ORCB, OpType.UNARY_BIT_MANIPULATION_2, Funct3.ORCB),
 97                  (self.Fn.SEXTH, OpType.UNARY_BIT_MANIPULATION_2, Funct3.SEXTH),
 98                  (self.Fn.CLZ, OpType.UNARY_BIT_MANIPULATION_3, Funct3.CLZ),
 99                  (self.Fn.CTZ, OpType.UNARY_BIT_MANIPULATION_4, Funct3.CTZ),
100                  (self.Fn.CPOP, OpType.UNARY_BIT_MANIPULATION_5, Funct3.CPOP),
101              ]
102              * self.zbb_enable
103              + [
104                  (self.Fn.CZEROEQZ, OpType.CZERO, Funct3.CZEROEQZ),
105                  (self.Fn.CZERONEZ, OpType.CZERO, Funct3.CZERONEZ),
106              ]
107              * self.zicond_enable
108          )
109  
110  
class CLZSubmodule(Elaboratable):
    """Combinational count-leading-zeros block.

    Attributes
    ----------
    in_sig : Signal, in
        Value whose leading zeros are counted; ``xlen`` bits wide.
    out_sig : Signal, out
        Result of ``count_leading_zeros(in_sig)``; ``xlen`` bits wide.
    """

    def __init__(self, gen_params: GenParams):
        width = gen_params.isa.xlen
        self.in_sig = Signal(width)
        self.out_sig = Signal(width)

    def elaborate(self, platform) -> HasElaborate:
        m = Module()

        leading_zeros = count_leading_zeros(self.in_sig)
        m.d.comb += self.out_sig.eq(leading_zeros)

        return m
121  
122  
class Alu(Elaboratable):
    """Combinational ALU whose operation is chosen by a one-hot function code.

    Attributes
    ----------
    fn : Value, in
        Operation selector obtained from ``alu_fn.get_function()``; it is
        decoded with ``OneHotSwitch`` against the members of ``AluFn.Fn``.
    in1 : Signal, in
        First operand, ``xlen`` bits wide.
    in2 : Signal, in
        Second operand, ``xlen`` bits wide.
    out : Signal, out
        Result of the selected operation, ``xlen`` bits wide.
    """

    def __init__(self, gen_params: GenParams, alu_fn=AluFn()):
        """
        Parameters
        ----------
        gen_params : GenParams
            Core generation parameters; only ``isa.xlen`` is read here.
        alu_fn : AluFn
            Decoder manager supplying the extension flags and the selector.
        """
        xlen = gen_params.isa.xlen

        self.gen_params = gen_params
        self.zba_enable = alu_fn.zba_enable
        self.zbb_enable = alu_fn.zbb_enable
        self.zicond_enable = alu_fn.zicond_enable

        self.fn = alu_fn.get_function()
        self.in1 = Signal(xlen)
        self.in2 = Signal(xlen)

        self.out = Signal(xlen)

    def elaborate(self, platform):
        m = TModule()

        xlen = self.gen_params.isa.xlen
        Fn = AluFn.Fn
        a = self.in1
        b = self.in2
        result = self.out

        def select(cond, when_true, when_false):
            # Drive the output with one of two values depending on `cond`.
            with m.If(cond):
                m.d.comb += result.eq(when_true)
            with m.Else():
                m.d.comb += result.eq(when_false)

        def sign_extend(width):
            # Low `width` bits of in1 followed by copies of their top bit.
            return Cat(a[0:width], a[width - 1].replicate(xlen - width))

        def byte_of(sig, idx):
            # The idx-th byte lane of `sig` (lane 0 is the least significant).
            return sig[idx * 8 : (idx + 1) * 8]

        with OneHotSwitch(m, self.fn) as OneHotCase:
            # Base integer operations: a single expression each.
            base_ops = [
                (Fn.ADD, a + b),
                (Fn.XOR, a ^ b),
                (Fn.OR, a | b),
                (Fn.AND, a & b),
                (Fn.SUB, a - b),
                (Fn.SLT, a.as_signed() < b.as_signed()),
                (Fn.SLTU, a < b),
            ]
            for fn, value in base_ops:
                with OneHotCase(fn):
                    m.d.comb += result.eq(value)

            if self.zba_enable:
                # Shift-and-add: in1 shifted left by 1, 2 or 3, plus in2.
                shift_adds = [(Fn.SH1ADD, 1), (Fn.SH2ADD, 2), (Fn.SH3ADD, 3)]
                for fn, amount in shift_adds:
                    with OneHotCase(fn):
                        m.d.comb += result.eq((a << amount) + b)

            if self.zbb_enable:
                clz = CLZSubmodule(self.gen_params)
                m.submodules.clz = clz

                # Logic operations with an inverted operand or result.
                inverted_logic = [
                    (Fn.ANDN, a & ~b),
                    (Fn.XNOR, ~(a ^ b)),
                    (Fn.ORN, a | ~b),
                ]
                for fn, value in inverted_logic:
                    with OneHotCase(fn):
                        m.d.comb += result.eq(value)

                # Min/max: output in1 when the condition holds, otherwise in2.
                min_max = [
                    (Fn.MIN, a.as_signed() < b.as_signed()),
                    (Fn.MINU, a < b),
                    (Fn.MAX, a.as_signed() >= b.as_signed()),
                    (Fn.MAXU, a >= b),
                ]
                for fn, take_first in min_max:
                    with OneHotCase(fn):
                        select(take_first, a, b)

                with OneHotCase(Fn.CPOP):
                    m.d.comb += result.eq(popcount(a))

                # CLZ and CTZ share one counter; CTZ feeds it the bit-reversed input.
                with OneHotCase(Fn.CLZ):
                    m.d.comb += clz.in_sig.eq(a)
                    m.d.comb += result.eq(clz.out_sig)
                with OneHotCase(Fn.CTZ):
                    m.d.comb += clz.in_sig.eq(a[::-1])
                    m.d.comb += result.eq(clz.out_sig)

                with OneHotCase(Fn.SEXTH):
                    m.d.comb += result.eq(sign_extend(16))
                with OneHotCase(Fn.SEXTB):
                    m.d.comb += result.eq(sign_extend(8))
                with OneHotCase(Fn.ZEXTH):
                    m.d.comb += result.eq(Cat(a[0:16], C(0, shape=unsigned(xlen - 16))))

                byte_count = xlen // 8

                with OneHotCase(Fn.ORCB):
                    # Each output byte is all ones if the matching input byte has any bit set.
                    for idx in range(byte_count):
                        m.d.comb += byte_of(result, idx).eq(byte_of(a, idx).any().replicate(8))

                with OneHotCase(Fn.REV8):
                    # Output byte idx comes from the mirrored input byte.
                    for idx in range(byte_count):
                        mirrored = byte_count - idx - 1
                        m.d.comb += byte_of(result, idx).eq(byte_of(a, mirrored))

            if self.zicond_enable:
                # (function, output when in2 is nonzero, output when in2 is zero)
                czero_table = [
                    (Fn.CZERONEZ, 0, a),
                    (Fn.CZEROEQZ, a, 0),
                ]
                for fn, if_nonzero, if_zero in czero_table:
                    with OneHotCase(fn):
                        select(b.any(), if_nonzero, if_zero)

        return m
234  
235  
class AluFuncUnit(FuncUnitBase[AluFn]):
    """Functional unit wrapping the combinational ``Alu``.

    Parameters
    ----------
    gen_params : GenParams
        Core generation parameters, forwarded to the base class and the ALU.
    fn : AluFn
        Decoder manager selecting which ALU extensions are available.
    """

    def __init__(self, gen_params: GenParams, fn=AluFn()):
        super().__init__(gen_params, fn)

    def elaborate(self, platform):
        m = super().elaborate(platform)

        alu = Alu(self.gen_params, alu_fn=self.fn)
        m.submodules.alu = alu

        @def_method(m, self.issue_decoded)
        def _(arg):
            # The immediate replaces the second source operand whenever it is nonzero.
            second_operand = Mux(arg.imm, arg.imm, arg.s2_val)

            m.d.av_comb += alu.fn.eq(arg.decode_fn)
            m.d.av_comb += alu.in1.eq(arg.s1_val)
            m.d.av_comb += alu.in2.eq(second_operand)

            # The ALU result is pushed with the exception field set to 0.
            self.push_result(m, rob_id=arg.rob_id, result=alu.out, rp_dst=arg.rp_dst, exception=0)

        return m
254  
255  
@dataclass(frozen=True)
class ALUComponent(FunctionalComponentParams):
    """Configuration of an ALU functional unit; all options are keyword-only."""

    _: KW_ONLY
    # Not read in this file; presumably consumed by the base class or the
    # functional-unit framework (confirm against FunctionalComponentParams).
    result_fifo: bool = True
    # Extension switches passed to AluFn by get_decoder_manager().
    zba_enable: bool = False
    zbb_enable: bool = False
    zicond_enable: bool = False
    # Excluded from the constructor; presumably populated by the base class
    # from get_decoder_manager() (TODO confirm).
    decoder_manager: AluFn = field(init=False)

    def get_decoder_manager(self):
        """Build an ``AluFn`` configured with this component's extension flags."""
        extension_flags = {
            "zba_enable": self.zba_enable,
            "zbb_enable": self.zbb_enable,
            "zicond_enable": self.zicond_enable,
        }
        return AluFn(**extension_flags)

    def get_module(self, gen_params: GenParams) -> FuncUnit:
        """Create the ``AluFuncUnit`` driven by this component's decoder manager."""
        unit = AluFuncUnit(gen_params, self.decoder_manager)
        return unit