/ coreblocks / params / instr.py
instr.py
  1  """
  2  
  3  Based on riscv-python-model by Stefan Wallentowitz
  4  https://github.com/wallento/riscv-python-model
  5  """
  6  
  7  from dataclasses import dataclass
  8  from abc import ABC
  9  from enum import Enum
 10  from typing import Optional
 11  
 12  from amaranth.hdl import ValueCastable
 13  from amaranth import *
 14  from amaranth_types import ValueLike
 15  
 16  from coreblocks.arch import Opcode, Registers, Funct3, Funct12
 17  
 18  
 19  __all__ = [
 20      "RISCVInstr",
 21      "RTypeInstr",
 22      "ITypeInstr",
 23      "STypeInstr",
 24      "BTypeInstr",
 25      "UTypeInstr",
 26      "JTypeInstr",
 27      "IllegalInstr",
 28      "EBreakInstr",
 29  ]
 30  
 31  
 32  @dataclass(kw_only=True)
 33  class Field:
 34      """Information about a field in a RISC-V instruction.
 35  
 36      Attributes
 37      ----------
 38      base: int | list[int]
 39          A bit position (or a list of positions) where this field (or parts of the field)
 40          would map in the instruction.
 41      size: int | list[int]
 42          Size (or sizes of the parts) of the field
 43      signed: bool
 44          Whether this field encodes a signed value.
 45      offset: int
 46          How many bits of this field should be skipped when encoding the instruction.
 47          For example, the immediate of the jump instruction always skips the least
 48          significant bit. This only affects encoding procedures, so externally (for example
 49          when creating an instance of a instruction) full-size values should be always used.
 50      static_value: Optional[Value]
 51          Whether the field should have a static value for a given type of an instruction.
 52      """
 53  
 54      base: int | list[int]
 55      size: int | list[int]
 56  
 57      signed: bool = False
 58      offset: int = 0
 59      static_value: Optional[Value] = None
 60  
 61      _name: str = ""
 62  
 63      def bases(self) -> list[int]:
 64          return [self.base] if isinstance(self.base, int) else self.base
 65  
 66      def sizes(self) -> list[int]:
 67          return [self.size] if isinstance(self.size, int) else self.size
 68  
 69      def shape(self) -> Shape:
 70          return Shape(width=sum(self.sizes()) + self.offset, signed=self.signed)
 71  
 72      def __set_name__(self, owner, name):
 73          self._name = name
 74  
 75      def __get__(self, obj, objtype=None) -> Value:
 76          if self.static_value is not None:
 77              return self.static_value
 78  
 79          return obj.__dict__.get(self._name, C(0, self.shape()))
 80  
 81      def __set__(self, obj, value) -> None:
 82          if self.static_value is not None:
 83              raise AttributeError("Can't overwrite the static value of a field.")
 84  
 85          expected_shape = self.shape()
 86  
 87          field_val: Value = C(0)
 88          if isinstance(value, Enum):
 89              field_val = Const(value.value, expected_shape)
 90          elif isinstance(value, int):
 91              field_val = Const(value, expected_shape)
 92          else:
 93              field_val = Value.cast(value)
 94  
 95              if field_val.shape().width != expected_shape.width:
 96                  raise AttributeError(
 97                      f"Expected width of the value: {expected_shape.width}, given: {field_val.shape().width}"
 98                  )
 99              if field_val.shape().signed and not expected_shape.signed:
100                  raise AttributeError(
101                      f"Expected signedness of the value: {expected_shape.signed}, given: {field_val.shape().signed}"
102                  )
103  
104          obj.__dict__[self._name] = field_val
105  
106      def get_parts(self, value: Value) -> list[Value]:
107          base = self.bases()
108          size = self.sizes()
109          offset = self.offset
110  
111          ret: list[Value] = []
112          for i in range(len(base)):
113              ret.append(value[offset : offset + size[i]])
114              offset += size[i]
115  
116          return ret
117  
118  
def _get_fields(cls: type) -> list[Field]:
    """Return the :class:`Field` descriptors that are visible on ``cls``.

    Classes are searched in method resolution order, and a name defined in a more
    derived class hides the same name in its bases, just like normal attribute
    lookup. A field redefined in a subclass therefore replaces the inherited one
    instead of being encoded twice. Each field object is returned at most once,
    even if it is reachable through several bases.
    """
    fields: list[Field] = []
    seen_names: set[str] = set()
    # Field is a (non-frozen) dataclass with generated __eq__, so it is
    # unhashable. Object identities are tracked instead.
    seen_ids: set[int] = set()

    for klass in cls.__mro__:
        for name, member in vars(klass).items():
            if name in seen_names:
                continue  # shadowed by a more derived class
            seen_names.add(name)

            if isinstance(member, Field) and id(member) not in seen_ids:
                seen_ids.add(id(member))
                fields.append(member)

    return fields
130  
131  
class RISCVInstr(ABC, ValueCastable):
    """Base class for RISC-V instruction encoders.

    Subclasses declare their fields as class-level :class:`Field` descriptors.
    :meth:`as_value` concatenates the parts of all fields, ordered by bit
    position, into a single Amaranth value.
    """

    opcode = Field(base=0, size=7)

    def __init__(self, opcode: Opcode):
        # Bits 1:0 are fixed to 0b11 (the RISC-V marker for non-compressed
        # instructions); ``opcode`` supplies bits 6:2.
        self.opcode = Cat(C(0b11, 2), opcode)

    def encode(self) -> int:
        """Return the encoded instruction as a Python integer.

        Every field must hold a constant expression, since the concatenated value
        is converted with ``Const.cast``.
        """
        const = Const.cast(self.as_value())
        return const.value  # type: ignore

    def as_value(self) -> Value:
        """Concatenate the parts of all fields, from the lowest bit position upwards.

        Raises
        ------
        ValueError
            If two field parts cover overlapping bits.
        """
        parts: list[tuple[int, Value]] = []

        for field in _get_fields(type(self)):
            value = field.__get__(self, type(self))
            parts.extend(zip(field.bases(), field.get_parts(value)))

        # Sort on the bit position only. A plain tuple sort would compare the
        # Amaranth values themselves whenever two positions are equal, which
        # does not yield a Python bool and fails.
        parts.sort(key=lambda part: part[0])

        end = 0
        for base, part in parts:
            if base < end:
                raise ValueError(
                    f"{type(self).__name__}: field part at bit {base} overlaps the part ending at bit {end - 1}"
                )
            end = base + len(part)

        return Cat([part for _, part in parts])

    def shape(self) -> Shape:
        """Shape of the value returned by :meth:`as_value`."""
        return self.as_value().shape()
154  
155  
class InstructionFunct3Type(RISCVInstr):
    """Base for instruction formats that carry a 3-bit ``funct3`` field at bits 14:12."""

    funct3 = Field(size=3, base=12)
158  
159  
class InstructionFunct7Type(RISCVInstr):
    """Base for instruction formats that carry a 7-bit ``funct7`` field at bits 31:25."""

    funct7 = Field(size=7, base=25)
162  
163  
class RTypeInstr(InstructionFunct3Type, InstructionFunct7Type):
    """R-type instruction format: two source registers and one destination register.

    Bit layout: ``funct7`` 31:25, ``rs2`` 24:20, ``rs1`` 19:15, ``funct3`` 14:12,
    ``rd`` 11:7, ``opcode`` 6:0.
    """

    rd = Field(size=5, base=7)
    rs1 = Field(size=5, base=15)
    rs2 = Field(size=5, base=20)

    def __init__(
        self,
        opcode: Opcode,
        funct3: ValueLike,
        funct7: ValueLike,
        rd: ValueLike,
        rs1: ValueLike,
        rs2: ValueLike,
    ):
        super().__init__(opcode)
        # Tuple targets are assigned left to right, so fields are set in the
        # same order as listed in the signature.
        self.funct3, self.funct7 = funct3, funct7
        self.rd, self.rs1, self.rs2 = rd, rs1, rs2
178  
179  
class ITypeInstr(InstructionFunct3Type):
    """I-type instruction format: one source register, a destination register and
    a 12-bit signed immediate.

    Bit layout: ``imm[11:0]`` 31:20, ``rs1`` 19:15, ``funct3`` 14:12, ``rd`` 11:7,
    ``opcode`` 6:0.
    """

    rd = Field(size=5, base=7)
    rs1 = Field(size=5, base=15)
    imm = Field(signed=True, size=12, base=20)

    def __init__(self, opcode: Opcode, funct3: ValueLike, rd: ValueLike, rs1: ValueLike, imm: ValueLike):
        super().__init__(opcode)
        self.funct3, self.rd = funct3, rd
        self.rs1, self.imm = rs1, imm
191  
192  
class STypeInstr(InstructionFunct3Type):
    """S-type instruction format: two source registers and a 12-bit signed
    immediate split into two parts.

    Bit layout: ``imm[11:5]`` 31:25, ``rs2`` 24:20, ``rs1`` 19:15, ``funct3`` 14:12,
    ``imm[4:0]`` 11:7, ``opcode`` 6:0.
    """

    rs1 = Field(size=5, base=15)
    rs2 = Field(size=5, base=20)
    imm = Field(signed=True, size=[5, 7], base=[7, 25])

    def __init__(self, opcode: Opcode, funct3: ValueLike, rs1: ValueLike, rs2: ValueLike, imm: ValueLike):
        super().__init__(opcode)
        self.funct3 = funct3
        self.rs1, self.rs2 = rs1, rs2
        self.imm = imm
204  
205  
class BTypeInstr(InstructionFunct3Type):
    """B-type instruction format: two source registers and a 13-bit signed
    immediate whose bit 0 is not encoded.

    Bit layout: ``imm[12]`` 31, ``imm[10:5]`` 30:25, ``rs2`` 24:20, ``rs1`` 19:15,
    ``funct3`` 14:12, ``imm[4:1]`` 11:8, ``imm[11]`` 7, ``opcode`` 6:0.
    """

    rs1 = Field(size=5, base=15)
    rs2 = Field(size=5, base=20)
    imm = Field(signed=True, offset=1, size=[4, 6, 1, 1], base=[8, 25, 7, 31])

    def __init__(self, opcode: Opcode, funct3: ValueLike, rs1: ValueLike, rs2: ValueLike, imm: ValueLike):
        super().__init__(opcode)
        self.funct3 = funct3
        self.rs1, self.rs2 = rs1, rs2
        self.imm = imm
217  
218  
class UTypeInstr(RISCVInstr):
    """U-type instruction format: a destination register and a 32-bit signed
    immediate whose low 12 bits are not encoded.

    Bit layout: ``imm[31:12]`` 31:12, ``rd`` 11:7, ``opcode`` 6:0.
    """

    rd = Field(size=5, base=7)
    imm = Field(signed=True, offset=12, size=20, base=12)

    def __init__(self, opcode: Opcode, rd: ValueLike, imm: ValueLike):
        super().__init__(opcode)
        self.rd, self.imm = rd, imm
227  
228  
class JTypeInstr(RISCVInstr):
    """J-type instruction format: a destination register and a 21-bit signed
    immediate whose bit 0 is not encoded.

    Bit layout: ``imm[20]`` 31, ``imm[10:1]`` 30:21, ``imm[11]`` 20,
    ``imm[19:12]`` 19:12, ``rd`` 11:7, ``opcode`` 6:0.
    """

    rd = Field(size=5, base=7)
    imm = Field(signed=True, offset=1, size=[10, 1, 8, 1], base=[21, 20, 12, 31])

    def __init__(self, opcode: Opcode, rd: ValueLike, imm: ValueLike):
        super().__init__(opcode)
        self.rd, self.imm = rd, imm
237  
238  
class IllegalInstr(RISCVInstr):
    """Instruction with opcode ``Opcode.RESERVED`` and all of bits 31:7 set to one.

    The upper bits form a static field, so the instance takes no arguments and
    that field cannot be reassigned.
    """

    illegal = Field(size=25, base=7, static_value=Cat(1).replicate(25))

    def __init__(self):
        super().__init__(Opcode.RESERVED)
244  
245  
class EBreakInstr(ITypeInstr):
    """The ``EBREAK`` instruction, encoded as an I-type instruction.

    It uses ``Opcode.SYSTEM`` and ``Funct3.PRIV``, with ``rd`` and ``rs1`` set to
    ``Registers.ZERO`` and the immediate set to ``Funct12.EBREAK``.
    """

    def __init__(self):
        super().__init__(
            opcode=Opcode.SYSTEM,
            funct3=Funct3.PRIV,
            rd=Registers.ZERO,
            rs1=Registers.ZERO,
            imm=Funct12.EBREAK,
        )