instr.py
"""RISC-V instruction encoders built on Amaranth values.

Based on riscv-python-model by Stefan Wallentowitz
https://github.com/wallento/riscv-python-model
"""

from dataclasses import dataclass
from abc import ABC
from enum import Enum
from typing import Optional

from amaranth.hdl import ValueCastable
from amaranth import *
from amaranth_types import ValueLike

from coreblocks.arch import Opcode, Registers, Funct3, Funct12


__all__ = [
    "RISCVInstr",
    "RTypeInstr",
    "ITypeInstr",
    "STypeInstr",
    "BTypeInstr",
    "UTypeInstr",
    "JTypeInstr",
    "IllegalInstr",
    "EBreakInstr",
]


@dataclass(kw_only=True)
class Field:
    """Information about a field in a RISC-V instruction.

    A ``Field`` is a descriptor. Declare it as a class attribute of a
    :class:`RISCVInstr` subclass, assign it in ``__init__`` and read it back as
    an Amaranth value.

    Attributes
    ----------
    base: int | list[int]
        A bit position (or a list of positions) where this field (or parts of the field)
        maps in the instruction.
    size: int | list[int]
        Size (or sizes of the parts) of the field. Must have the same number of
        entries as ``base``.
    signed: bool
        Whether this field encodes a signed value.
    offset: int
        How many least significant bits of this field are skipped when encoding the instruction.
        For example, the immediate of the jump instruction always skips the least
        significant bit. This only affects encoding procedures, so externally (for example
        when creating an instance of an instruction) full-size values should always be used.
    static_value: Optional[Value]
        A fixed value of this field for a given type of instruction. When set, the field
        cannot be assigned and always reads back as this value.
    """

    base: int | list[int]
    size: int | list[int]

    signed: bool = False
    offset: int = 0
    static_value: Optional[Value] = None

    _name: str = ""

    def __post_init__(self) -> None:
        # Parts are paired with sizes positionally when encoding. Without this
        # check, a length mismatch would either raise IndexError at encode time
        # or silently drop the trailing parts.
        if len(self.bases()) != len(self.sizes()):
            raise ValueError(
                f"Field has {len(self.bases())} base position(s) but {len(self.sizes())} size(s)"
            )

    def bases(self) -> list[int]:
        """Return the bit positions of the field's parts as a list."""
        return [self.base] if isinstance(self.base, int) else self.base

    def sizes(self) -> list[int]:
        """Return the widths of the field's parts as a list."""
        return [self.size] if isinstance(self.size, int) else self.size

    def shape(self) -> Shape:
        """Return the shape expected for values assigned to this field.

        The width includes the ``offset`` skipped bits, so callers always supply
        full-size values.
        """
        return Shape(width=sum(self.sizes()) + self.offset, signed=self.signed)

    def __set_name__(self, owner, name):
        # Remember the attribute name; per-instance values are stored under it.
        self._name = name

    def __get__(self, obj, objtype=None) -> Value:
        """Return the static value if one is set.

        Otherwise return the value assigned on ``obj``, or a zero constant of
        the field's shape if nothing was assigned.
        """
        if self.static_value is not None:
            return self.static_value

        return obj.__dict__.get(self._name, C(0, self.shape()))

    def __set__(self, obj, value) -> None:
        """Validate ``value`` against the field's shape and store it on ``obj``.

        Python enum members and plain integers are converted to constants of the
        field's shape. Anything else goes through ``Value.cast``.

        Raises
        ------
        AttributeError
            If the field has a static value, if ``value`` has the wrong width,
            or if ``value`` is signed while the field is unsigned.
        """
        if self.static_value is not None:
            raise AttributeError("Can't overwrite the static value of a field.")

        expected_shape = self.shape()

        field_val: Value
        if isinstance(value, Enum):
            field_val = Const(value.value, expected_shape)
        elif isinstance(value, int):
            field_val = Const(value, expected_shape)
        else:
            field_val = Value.cast(value)

        if field_val.shape().width != expected_shape.width:
            raise AttributeError(
                f"Expected width of the value: {expected_shape.width}, given: {field_val.shape().width}"
            )
        if field_val.shape().signed and not expected_shape.signed:
            raise AttributeError(
                f"Expected signedness of the value: {expected_shape.signed}, given: {field_val.shape().signed}"
            )

        obj.__dict__[self._name] = field_val

    def get_parts(self, value: Value) -> list[Value]:
        """Split ``value`` into the slices that go to each of ``bases()``, in order.

        The lowest ``offset`` bits are skipped. The remaining bits are handed out
        to consecutive parts, least significant first, with widths from ``sizes()``.
        """
        parts: list[Value] = []
        start = self.offset
        for width in self.sizes():
            parts.append(value[start : start + width])
            start += width

        return parts


def _get_fields(cls: type) -> list[Field]:
    """Return the fields that attribute lookup on ``cls`` resolves to.

    Classes are scanned in MRO order, and each attribute name is taken from the
    first class that defines it. As a result, a field redefined in a subclass
    replaces the inherited one instead of being encoded twice. Fields inherited
    through several bases (e.g. ``opcode``) appear only once.
    """
    fields: list[Field] = []
    seen_names: set[str] = set()
    for klass in cls.__mro__:
        for name, member in vars(klass).items():
            if name in seen_names:
                continue
            seen_names.add(name)
            if isinstance(member, Field):
                fields.append(member)

    return fields


class RISCVInstr(ABC, ValueCastable):
    """Base class of encodable RISC-V instructions.

    Subclasses describe their layout with :class:`Field` class attributes. An
    instance is value-castable to an Amaranth value, and ``encode`` turns it
    into an integer.
    """

    opcode = Field(base=0, size=7)

    def __init__(self, opcode: Opcode):
        # The two lowest bits of the 7-bit opcode field are fixed to 0b11;
        # `opcode` supplies the remaining upper bits.
        self.opcode = Cat(C(0b11, 2), opcode)

    def encode(self) -> int:
        """Return the instruction as an integer.

        Requires every field to hold a constant, because the assembled value is
        passed through ``Const.cast``.
        """
        const = Const.cast(self.as_value())
        return const.value  # type: ignore

    def as_value(self) -> Value:
        """Assemble the instruction by placing every field part at its base bit position."""
        parts: list[tuple[int, Value]] = []

        for field in _get_fields(type(self)):
            value = field.__get__(self, type(self))
            parts.extend(zip(field.bases(), field.get_parts(value), strict=True))

        # Sort by bit position only. On a tie, comparing the Amaranth values
        # would build an expression whose truth value raises TypeError.
        parts.sort(key=lambda part: part[0])
        return Cat([part for _, part in parts])

    def shape(self) -> Shape:
        """Return the shape of the assembled instruction value."""
        return self.as_value().shape()


class InstructionFunct3Type(RISCVInstr):
    """Instruction layout with a ``funct3`` field at bits [14:12]."""

    funct3 = Field(base=12, size=3)


class InstructionFunct7Type(RISCVInstr):
    """Instruction layout with a ``funct7`` field at bits [31:25]."""

    funct7 = Field(base=25, size=7)


class RTypeInstr(InstructionFunct3Type, InstructionFunct7Type):
    """R-type instruction: ``rd``, ``rs1``, ``rs2``, ``funct3`` and ``funct7`` fields."""

    rd = Field(base=7, size=5)
    rs1 = Field(base=15, size=5)
    rs2 = Field(base=20, size=5)

    def __init__(
        self, opcode: Opcode, funct3: ValueLike, funct7: ValueLike, rd: ValueLike, rs1: ValueLike, rs2: ValueLike
    ):
        super().__init__(opcode)
        self.funct3 = funct3
        self.funct7 = funct7
        self.rd = rd
        self.rs1 = rs1
        self.rs2 = rs2


class ITypeInstr(InstructionFunct3Type):
    """I-type instruction with a 12-bit signed immediate at bits [31:20]."""

    rd = Field(base=7, size=5)
    rs1 = Field(base=15, size=5)
    imm = Field(base=20, size=12, signed=True)

    def __init__(self, opcode: Opcode, funct3: ValueLike, rd: ValueLike, rs1: ValueLike, imm: ValueLike):
        super().__init__(opcode)
        self.funct3 = funct3
        self.rd = rd
        self.rs1 = rs1
        self.imm = imm


class STypeInstr(InstructionFunct3Type):
    """S-type instruction with a 12-bit signed immediate split into two parts.

    imm[4:0] goes to bits [11:7] and imm[11:5] to bits [31:25].
    """

    rs1 = Field(base=15, size=5)
    rs2 = Field(base=20, size=5)
    imm = Field(base=[7, 25], size=[5, 7], signed=True)

    def __init__(self, opcode: Opcode, funct3: ValueLike, rs1: ValueLike, rs2: ValueLike, imm: ValueLike):
        super().__init__(opcode)
        self.funct3 = funct3
        self.rs1 = rs1
        self.rs2 = rs2
        self.imm = imm


class BTypeInstr(InstructionFunct3Type):
    """B-type instruction with a 13-bit signed immediate whose bit 0 is not encoded.

    imm[4:1] goes to bits [11:8], imm[10:5] to [30:25], imm[11] to bit 7 and
    imm[12] to bit 31.
    """

    rs1 = Field(base=15, size=5)
    rs2 = Field(base=20, size=5)
    imm = Field(base=[8, 25, 7, 31], size=[4, 6, 1, 1], offset=1, signed=True)

    def __init__(self, opcode: Opcode, funct3: ValueLike, rs1: ValueLike, rs2: ValueLike, imm: ValueLike):
        super().__init__(opcode)
        self.funct3 = funct3
        self.rs1 = rs1
        self.rs2 = rs2
        self.imm = imm


class UTypeInstr(RISCVInstr):
    """U-type instruction with a 32-bit signed immediate whose low 12 bits are not encoded.

    imm[31:12] goes to bits [31:12].
    """

    rd = Field(base=7, size=5)
    imm = Field(base=12, size=20, offset=12, signed=True)

    def __init__(self, opcode: Opcode, rd: ValueLike, imm: ValueLike):
        super().__init__(opcode)
        self.rd = rd
        self.imm = imm


class JTypeInstr(RISCVInstr):
    """J-type instruction with a 21-bit signed immediate whose bit 0 is not encoded.

    imm[10:1] goes to bits [30:21], imm[11] to bit 20, imm[19:12] to [19:12]
    and imm[20] to bit 31.
    """

    rd = Field(base=7, size=5)
    imm = Field(base=[21, 20, 12, 31], size=[10, 1, 8, 1], offset=1, signed=True)

    def __init__(self, opcode: Opcode, rd: ValueLike, imm: ValueLike):
        super().__init__(opcode)
        self.rd = rd
        self.imm = imm


class IllegalInstr(RISCVInstr):
    """Instruction with the RESERVED opcode and all 25 remaining bits set to one."""

    illegal = Field(base=7, size=25, static_value=Cat(1).replicate(25))

    def __init__(self):
        super().__init__(opcode=Opcode.RESERVED)


class EBreakInstr(ITypeInstr):
    """EBREAK: SYSTEM opcode, PRIV funct3, zero registers and ``Funct12.EBREAK`` as immediate."""

    def __init__(self):
        super().__init__(
            opcode=Opcode.SYSTEM, rd=Registers.ZERO, funct3=Funct3.PRIV, rs1=Registers.ZERO, imm=Funct12.EBREAK
        )