priv.py
from dataclasses import dataclass, KW_ONLY, field
from amaranth import *
from amaranth.lib import data

from enum import IntFlag, auto, unique
from typing import Sequence
from coreblocks.arch.isa_consts import Funct12, Funct3, Funct7, Opcode, PrivilegeLevel, SatpMode


from transactron import *
from transactron.lib import logging
from transactron.lib.metrics import TaggedCounter
from transactron.lib.simultaneous import condition
from transactron.utils import DependencyContext, OneHotSwitch

from coreblocks.params import *
from coreblocks.params import GenParams, FunctionalComponentParams
from coreblocks.arch import OpType, ExceptionCause
from coreblocks.interface.layouts import PrivUnitLayouts
from coreblocks.interface.keys import (
    MretKey,
    SretKey,
    AsyncInterruptInsertSignalKey,
    ExceptionReportKey,
    CSRInstancesKey,
    InstructionPrecommitKey,
    UnsafeInstructionResolvedKey,
    FlushICacheKey,
    WaitForInterruptResumeKey,
)
from coreblocks.func_blocks.interface.func_protocols import FuncUnit

from coreblocks.func_blocks.fu.common import DecoderManager, FuncUnitBase


log = logging.HardwareLogger("backend.fu.priv")


class PrivilegedFn(DecoderManager):
    """Decoder manager for the instructions executed by the privileged unit.

    MRET, FENCE.I and WFI are always decoded. SRET and SFENCE.VMA are decoded
    only when ``supervisor_enable`` is set.
    """

    def __init__(self, supervisor_enable: bool = False) -> None:
        self.supervisor_enable = supervisor_enable

    @unique
    class Fn(IntFlag):
        """Function codes of the privileged unit (one bit per instruction)."""

        MRET = auto()
        FENCEI = auto()
        WFI = auto()
        SRET = auto()
        SFENCEVMA = auto()

    def get_instructions(self) -> Sequence[tuple]:
        """Return the ``(Fn, OpType)`` pairs recognized by this decoder.

        The S-mode instructions (SRET, SFENCE.VMA) are listed only when
        ``supervisor_enable`` is set.
        """
        instructions = [
            (self.Fn.MRET, OpType.MRET),
            (self.Fn.FENCEI, OpType.FENCEI),
            (self.Fn.WFI, OpType.WFI),
        ]
        if self.supervisor_enable:
            instructions += [
                (self.Fn.SRET, OpType.SRET),
                (self.Fn.SFENCEVMA, OpType.SFENCEVMA),
            ]
        return instructions


class PrivilegedFuncUnit(FuncUnitBase[PrivilegedFn]):
    """Functional unit executing xRET, FENCE.I, WFI and SFENCE.VMA.

    The unit holds at most one instruction at a time and works in two steps.

    1. When the instruction reaches precommit, its side effects are performed
       (only if ``side_fx`` allows them) and its legality at the current
       privilege level is checked.
    2. The outcome is then reported: an illegal-instruction exception, an
       asynchronous interrupt, or resuming the core at the next PC. Finally
       the result is pushed.
    """

    def __init__(self, gen_params: GenParams, fn: PrivilegedFn | None = None):
        """
        Parameters
        ----------
        gen_params : GenParams
            Core generation parameters.
        fn : PrivilegedFn, optional
            Decoder manager. When omitted, a fresh ``PrivilegedFn()``
            (S-mode instructions disabled) is created for this unit.
        """
        # Build the default per unit. A default-argument instance would be
        # created once, at function-definition time, and shared by every unit.
        if fn is None:
            fn = PrivilegedFn()
        super().__init__(gen_params, fn)

        self.dm = DependencyContext.get()

        self.perf_instr = TaggedCounter(
            "backend.fu.priv.instr",
            "Number of instructions precommited with side effects by the privilege unit",
            tags=PrivilegedFn.Fn,
        )

        self.exception_report = self.dm.get_dependency(ExceptionReportKey())()

    def elaborate(self, platform):
        m = super().elaborate(platform)

        m.submodules += [self.perf_instr]

        # High while an instruction is held. Cleared when its result is pushed.
        instr_valid = Signal()
        # High once the precommit step for the held instruction is complete.
        finished = Signal()
        # Registered outcome of the privilege checks done in the precommit step.
        illegal_instruction = Signal()

        # Latched fields of the held instruction.
        instr_rob = Signal(self.gen_params.rob_entries_bits)
        instr_pc = Signal(self.gen_params.isa.xlen)
        instr_fn = self.fn.get_function()
        instr_imm = Signal(self.gen_params.isa.xlen)

        mret = self.dm.get_dependency(MretKey())
        sret = self.dm.get_optional_dependency(SretKey())
        async_interrupt_active = self.dm.get_dependency(AsyncInterruptInsertSignalKey())
        wfi_resume = self.dm.get_dependency(WaitForInterruptResumeKey())
        csr = self.dm.get_dependency(CSRInstancesKey())
        priv_mode = csr.m_mode.priv_mode
        flush_icache = self.dm.get_dependency(FlushICacheKey())
        resume_core = self.dm.get_dependency(UnsafeInstructionResolvedKey())

        # Accept a new instruction only while no other one is held.
        @def_method(m, self.issue_decoded, ready=~instr_valid)
        def _(arg):
            m.d.sync += [
                instr_valid.eq(1),
                instr_rob.eq(arg.rob_id),
                instr_pc.eq(arg.pc),
                instr_fn.eq(arg.decode_fn),
                instr_imm.eq(arg.imm),
            ]

        # Step 1: precommit. Perform the side effects and check legality.
        with Transaction().body(m, ready=instr_valid & ~finished):
            precommit = self.dm.get_dependency(InstructionPrecommitKey())
            info = precommit(m, instr_rob)
            # Done after this cycle by default. The WFI branch below overrides this
            # so that the transaction keeps firing until wfi_resume is high.
            m.d.sync += finished.eq(1)
            self.perf_instr.incr(m, instr_fn, enable_call=info.side_fx)

            priv_data = priv_mode.read(m).data

            # MRET is legal only in M-mode.
            illegal_mret = (instr_fn == PrivilegedFn.Fn.MRET) & (priv_data != PrivilegeLevel.MACHINE)

            if self.fn.supervisor_enable:
                # SRET is illegal in U-mode, and in S-mode when mstatus.TSR is set.
                illegal_sret = (instr_fn == PrivilegedFn.Fn.SRET) & (
                    (priv_data == PrivilegeLevel.USER)
                    | ((priv_data == PrivilegeLevel.SUPERVISOR) & csr.m_mode.mstatus_tsr.read(m).data)
                )
                # SFENCE.VMA is illegal in U-mode, and in S-mode when mstatus.TVM is set.
                illegal_sfencevma = (instr_fn == PrivilegedFn.Fn.SFENCEVMA) & (
                    (priv_data == PrivilegeLevel.USER)
                    | ((priv_data == PrivilegeLevel.SUPERVISOR) & csr.m_mode.mstatus_tvm.read(m).data)
                )
            else:
                illegal_sret = 0
                illegal_sfencevma = 0

            # WFI is illegal in U-mode when S-mode is implemented, and in any mode
            # below M when mstatus.TW is set.
            illegal_wfi = (instr_fn == PrivilegedFn.Fn.WFI) & (
                ((priv_data == PrivilegeLevel.USER) if self.gen_params.supervisor_mode else 0)
                | ((priv_data < PrivilegeLevel.MACHINE) & (csr.m_mode.mstatus_tw.read(m).data))
            )

            with condition(m, nonblocking=True) as branch:
                with branch(info.side_fx & (instr_fn == PrivilegedFn.Fn.MRET) & ~illegal_mret):
                    mret(m)
                if self.fn.supervisor_enable:
                    assert sret is not None
                    with branch(info.side_fx & (instr_fn == PrivilegedFn.Fn.SRET) & ~illegal_sret):
                        sret(m)

                    # TODO: implement proper SFENCE.VMA, for BARE only - NO-OP is ok
                    assert self.gen_params.vmem_params.supported_schemes == {SatpMode.BARE}

                with branch(info.side_fx & (instr_fn == PrivilegedFn.Fn.FENCEI)):
                    flush_icache(m)
                with branch(info.side_fx & (instr_fn == PrivilegedFn.Fn.WFI) & ~illegal_wfi):
                    # async_interrupt_active implies wfi_resume. WFI should continue normal execution
                    # when interrupt is enabled in xie, but disabled via global mstatus.xIE
                    m.d.sync += finished.eq(wfi_resume)

            m.d.sync += illegal_instruction.eq(illegal_wfi | illegal_mret | illegal_sret | illegal_sfencevma)

        # Step 2: report the outcome, push the result and release the unit.
        with Transaction().body(m, ready=instr_valid & finished):
            m.d.sync += instr_valid.eq(0)
            m.d.sync += finished.eq(0)

            # Address at which execution continues after this instruction.
            ret_pc = Signal(self.gen_params.isa.xlen)

            with OneHotSwitch(m, instr_fn) as OneHotCase:
                with OneHotCase(PrivilegedFn.Fn.MRET):
                    m.d.av_comb += ret_pc.eq(csr.m_mode.mepc.read(m).data)
                # SFENCE.VMA, FENCE.I and WFI can't be compressed, so next PC is always pc+4
                if self.fn.supervisor_enable:
                    with OneHotCase(PrivilegedFn.Fn.SRET):
                        m.d.av_comb += ret_pc.eq(csr.s_mode.sepc.read(m).data)
                    with OneHotCase(PrivilegedFn.Fn.SFENCEVMA):
                        m.d.av_comb += ret_pc.eq(instr_pc + 4)
                with OneHotCase(PrivilegedFn.Fn.FENCEI):
                    m.d.av_comb += ret_pc.eq(instr_pc + 4)
                with OneHotCase(PrivilegedFn.Fn.WFI):
                    m.d.av_comb += ret_pc.eq(instr_pc + 4)

            exception = Signal()
            with m.If(illegal_instruction):
                # An illegal instruction traps at its own address. This overrides
                # the next-PC value selected above.
                m.d.av_comb += ret_pc.eq(instr_pc)
                m.d.av_comb += exception.eq(1)

                # Re-encode the instruction so it can be reported in mtval.
                # Replace with const zero if turns out not worth to re-encode instruction
                instr = Signal(self.gen_params.isa.xlen)
                m.d.av_comb += instr[0:2].eq(0b11)
                m.d.av_comb += instr[2:7].eq(Opcode.SYSTEM)
                m.d.av_comb += instr[7:12].eq(0)
                m.d.av_comb += instr[12:15].eq(Funct3.PRIV)
                m.d.av_comb += instr[15:20].eq(0)
                with m.Switch(instr_fn):
                    with m.Case(PrivilegedFn.Fn.MRET):
                        m.d.av_comb += instr[20:32].eq(Funct12.MRET)
                    with m.Case(PrivilegedFn.Fn.WFI):
                        m.d.av_comb += instr[20:32].eq(Funct12.WFI)
                    if self.fn.supervisor_enable:
                        with m.Case(PrivilegedFn.Fn.SRET):
                            m.d.av_comb += instr[20:32].eq(Funct12.SRET)
                        with m.Case(PrivilegedFn.Fn.SFENCEVMA):
                            # SFENCE.VMA carries rs1/rs2, which the decoder passed in the immediate.
                            imm_view = data.View(self.gen_params.get(PrivUnitLayouts).sfencevma_imm_layout, instr_imm)
                            m.d.av_comb += instr[15:20].eq(imm_view.rs1)
                            m.d.av_comb += instr[20:25].eq(imm_view.rs2)
                            m.d.av_comb += instr[25:32].eq(Funct7.SFENCEVMA)
                    with m.Default():
                        # Not expected: FENCE.I is never flagged illegal in the precommit step.
                        log.error(m, True, "missing Funct12 case")

                self.exception_report(
                    m, cause=ExceptionCause.ILLEGAL_INSTRUCTION, pc=ret_pc, rob_id=instr_rob, mtval=instr
                )
            with m.Elif(async_interrupt_active):
                # SPEC: "These conditions for an interrupt trap to occur [..] must also be evaluated immediately
                # following the execution of an xRET instruction."
                # mret() method is called from precommit() that was executed at least one cycle earlier (because
                # of finished condition). If calling mret() caused interrupt to be active, it is already represented
                # by updated async_interrupt_active signal.
                # Interrupt is reported on this xRET instruction with return address set to instruction that we
                # would normally return to (mepc value is preserved)
                m.d.av_comb += exception.eq(1)
                self.exception_report(
                    m, cause=ExceptionCause._COREBLOCKS_ASYNC_INTERRUPT, pc=ret_pc, rob_id=instr_rob, mtval=0
                )
            with m.Else():
                log.info(m, True, "Unstalling fetch from the priv unit new_pc=0x{:x}", ret_pc)
                # Unstall the fetch
                resume_core(m, pc=ret_pc)

            self.push_result(
                m,
                rob_id=instr_rob,
                exception=exception,
                rp_dst=0,
                result=0,
            )

        return m


@dataclass(frozen=True)
class PrivilegedUnitComponent(FunctionalComponentParams):
    """Configuration component that builds a :class:`PrivilegedFuncUnit`.

    ``supervisor_enable`` selects whether the S-mode instructions (SRET,
    SFENCE.VMA) are decoded. It must match ``gen_params.supervisor_mode``.
    """

    _: KW_ONLY
    supervisor_enable: bool = False
    # Not an __init__ argument. Presumably filled from get_decoder_manager() by
    # the FunctionalComponentParams base class — NOTE(review): confirm.
    decoder_manager: PrivilegedFn = field(init=False)

    def get_decoder_manager(self) -> PrivilegedFn:
        return PrivilegedFn(supervisor_enable=self.supervisor_enable)

    def get_module(self, gen_params: GenParams) -> FuncUnit:
        # The decoder and the core configuration must agree on S-mode support.
        assert (
            self.supervisor_enable == gen_params.supervisor_mode
        ), "PrivilegedUnitComponent.supervisor_enable must match gen_params.supervisor_mode"

        return PrivilegedFuncUnit(gen_params, self.decoder_manager)