/ coreblocks / func_blocks / fu / priv.py
priv.py
  1  from dataclasses import dataclass, KW_ONLY, field
  2  from amaranth import *
  3  from amaranth.lib import data
  4  
  5  from enum import IntFlag, auto, unique
  6  from typing import Sequence
  7  from coreblocks.arch.isa_consts import Funct12, Funct3, Funct7, Opcode, PrivilegeLevel, SatpMode
  8  
  9  
 10  from transactron import *
 11  from transactron.lib import logging
 12  from transactron.lib.metrics import TaggedCounter
 13  from transactron.lib.simultaneous import condition
 14  from transactron.utils import DependencyContext, OneHotSwitch
 15  
 16  from coreblocks.params import *
 17  from coreblocks.params import GenParams, FunctionalComponentParams
 18  from coreblocks.arch import OpType, ExceptionCause
 19  from coreblocks.interface.layouts import PrivUnitLayouts
 20  from coreblocks.interface.keys import (
 21      MretKey,
 22      SretKey,
 23      AsyncInterruptInsertSignalKey,
 24      ExceptionReportKey,
 25      CSRInstancesKey,
 26      InstructionPrecommitKey,
 27      UnsafeInstructionResolvedKey,
 28      FlushICacheKey,
 29      WaitForInterruptResumeKey,
 30  )
 31  from coreblocks.func_blocks.interface.func_protocols import FuncUnit
 32  
 33  from coreblocks.func_blocks.fu.common import DecoderManager, FuncUnitBase
 34  
 35  
# Module-wide hardware logger, registered under the name "backend.fu.priv";
# used by PrivilegedFuncUnit.elaborate for its info/error messages.
log = logging.HardwareLogger("backend.fu.priv")
 37  
 38  
 39  class PrivilegedFn(DecoderManager):
 40      def __init__(self, supervisor_enable=False) -> None:
 41          self.supervisor_enable = supervisor_enable
 42  
 43      @unique
 44      class Fn(IntFlag):
 45          MRET = auto()
 46          FENCEI = auto()
 47          WFI = auto()
 48          SRET = auto()
 49          SFENCEVMA = auto()
 50  
 51      def get_instructions(self) -> Sequence[tuple]:
 52          return [
 53              (self.Fn.MRET, OpType.MRET),
 54              (self.Fn.FENCEI, OpType.FENCEI),
 55              (self.Fn.WFI, OpType.WFI),
 56          ] + [
 57              (self.Fn.SRET, OpType.SRET),
 58              (self.Fn.SFENCEVMA, OpType.SFENCEVMA),
 59          ] * self.supervisor_enable
 60  
 61  
 62  class PrivilegedFuncUnit(FuncUnitBase[PrivilegedFn]):
 63      def __init__(self, gen_params: GenParams, fn=PrivilegedFn()):
 64          super().__init__(gen_params, fn)
 65  
 66          self.dm = DependencyContext.get()
 67  
 68          self.perf_instr = TaggedCounter(
 69              "backend.fu.priv.instr",
 70              "Number of instructions precommited with side effects by the privilege unit",
 71              tags=PrivilegedFn.Fn,
 72          )
 73  
 74          self.exception_report = self.dm.get_dependency(ExceptionReportKey())()
 75  
 76      def elaborate(self, platform):
 77          m = super().elaborate(platform)
 78  
 79          m.submodules += [self.perf_instr]
 80  
 81          instr_valid = Signal()
 82          finished = Signal()
 83          illegal_instruction = Signal()
 84  
 85          instr_rob = Signal(self.gen_params.rob_entries_bits)
 86          instr_pc = Signal(self.gen_params.isa.xlen)
 87          instr_fn = self.fn.get_function()
 88  
 89          instr_imm = Signal(self.gen_params.isa.xlen)
 90          instr_s1_val = Signal(self.gen_params.isa.xlen)
 91          instr_s2_val = Signal(self.gen_params.isa.xlen)
 92  
 93          mret = self.dm.get_dependency(MretKey())
 94          sret = self.dm.get_optional_dependency(SretKey())
 95          async_interrupt_active = self.dm.get_dependency(AsyncInterruptInsertSignalKey())
 96          wfi_resume = self.dm.get_dependency(WaitForInterruptResumeKey())
 97          csr = self.dm.get_dependency(CSRInstancesKey())
 98          priv_mode = csr.m_mode.priv_mode
 99          flush_icache = self.dm.get_dependency(FlushICacheKey())
100          resume_core = self.dm.get_dependency(UnsafeInstructionResolvedKey())
101  
102          @def_method(m, self.issue_decoded, ready=~instr_valid)
103          def _(arg):
104              m.d.sync += [
105                  instr_valid.eq(1),
106                  instr_rob.eq(arg.rob_id),
107                  instr_pc.eq(arg.pc),
108                  instr_fn.eq(arg.decode_fn),
109                  instr_s1_val.eq(arg.s1_val),
110                  instr_s2_val.eq(arg.s2_val),
111                  instr_imm.eq(arg.imm),
112              ]
113  
114          with Transaction().body(m, ready=instr_valid & ~finished):
115              precommit = self.dm.get_dependency(InstructionPrecommitKey())
116              info = precommit(m, instr_rob)
117              m.d.sync += finished.eq(1)
118              self.perf_instr.incr(m, instr_fn, enable_call=info.side_fx)
119  
120              priv_data = priv_mode.read(m).data
121  
122              illegal_mret = (instr_fn == PrivilegedFn.Fn.MRET) & (priv_data != PrivilegeLevel.MACHINE)
123  
124              if self.fn.supervisor_enable:
125                  illegal_sret = (instr_fn == PrivilegedFn.Fn.SRET) & (
126                      (priv_data == PrivilegeLevel.USER)
127                      | ((priv_data == PrivilegeLevel.SUPERVISOR) & csr.m_mode.mstatus_tsr.read(m).data)
128                  )
129              else:
130                  illegal_sret = 0
131  
132              if self.fn.supervisor_enable:
133                  illegal_sfencevma = (instr_fn == PrivilegedFn.Fn.SFENCEVMA) & (
134                      (priv_data == PrivilegeLevel.USER)
135                      | ((priv_data == PrivilegeLevel.SUPERVISOR) & csr.m_mode.mstatus_tvm.read(m).data)
136                  )
137              else:
138                  illegal_sfencevma = 0
139  
140              illegal_wfi = (instr_fn == PrivilegedFn.Fn.WFI) & (
141                  ((priv_data == PrivilegeLevel.USER) if self.gen_params.supervisor_mode else 0)
142                  | ((priv_data < PrivilegeLevel.MACHINE) & (csr.m_mode.mstatus_tw.read(m).data))
143              )
144  
145              with condition(m, nonblocking=True) as branch:
146                  with branch(info.side_fx & (instr_fn == PrivilegedFn.Fn.MRET) & ~illegal_mret):
147                      mret(m)
148                  if self.fn.supervisor_enable:
149                      assert sret is not None
150                      with branch(info.side_fx & (instr_fn == PrivilegedFn.Fn.SRET) & ~illegal_sret):
151                          sret(m)
152  
153                      # TODO: implement proper SFENCE.VMA, for BARE only - NO-OP is ok
154                      assert self.gen_params.vmem_params.supported_schemes == {SatpMode.BARE}
155  
156                  with branch(info.side_fx & (instr_fn == PrivilegedFn.Fn.FENCEI)):
157                      flush_icache(m)
158                  with branch(info.side_fx & (instr_fn == PrivilegedFn.Fn.WFI) & ~illegal_wfi):
159                      # async_interrupt_active implies wfi_resume. WFI should continue normal execution
160                      # when interrupt is enabled in xie, but disabled via global mstatus.xIE
161                      m.d.sync += finished.eq(wfi_resume)
162  
163              m.d.sync += illegal_instruction.eq(illegal_wfi | illegal_mret | illegal_sret | illegal_sfencevma)
164  
165          with Transaction().body(m, ready=instr_valid & finished):
166              m.d.sync += instr_valid.eq(0)
167              m.d.sync += finished.eq(0)
168  
169              ret_pc = Signal(self.gen_params.isa.xlen)
170  
171              with OneHotSwitch(m, instr_fn) as OneHotCase:
172                  with OneHotCase(PrivilegedFn.Fn.MRET):
173                      m.d.av_comb += ret_pc.eq(csr.m_mode.mepc.read(m).data)
174                  if self.fn.supervisor_enable:
175                      with OneHotCase(PrivilegedFn.Fn.SRET):
176                          m.d.av_comb += ret_pc.eq(csr.s_mode.sepc.read(m).data)
177                  # SFENCE.VMA, FENCE.I and WFI can't be compressed, so next PC is always pc+4
178                  if self.fn.supervisor_enable:
179                      with OneHotCase(PrivilegedFn.Fn.SFENCEVMA):
180                          m.d.av_comb += ret_pc.eq(instr_pc + 4)
181                  with OneHotCase(PrivilegedFn.Fn.FENCEI):
182                      m.d.av_comb += ret_pc.eq(instr_pc + 4)
183                  with OneHotCase(PrivilegedFn.Fn.WFI):
184                      m.d.av_comb += ret_pc.eq(instr_pc + 4)
185  
186              with m.If(illegal_instruction):
187                  m.d.av_comb += ret_pc.eq(instr_pc)
188  
189              exception = Signal()
190              with m.If(illegal_instruction):
191                  m.d.av_comb += exception.eq(1)
192  
193                  # Replace with const zero if turns out not worth to re-encode instruction
194                  instr = Signal(self.gen_params.isa.xlen)
195                  m.d.av_comb += instr[0:2].eq(0b11)
196                  m.d.av_comb += instr[2:7].eq(Opcode.SYSTEM)
197                  m.d.av_comb += instr[7:12].eq(0)
198                  m.d.av_comb += instr[12:15].eq(Funct3.PRIV)
199                  m.d.av_comb += instr[15:20].eq(0)
200                  with m.Switch(instr_fn):
201                      with m.Case(PrivilegedFn.Fn.MRET):
202                          m.d.av_comb += instr[20:32].eq(Funct12.MRET)
203                      with m.Case(PrivilegedFn.Fn.WFI):
204                          m.d.av_comb += instr[20:32].eq(Funct12.WFI)
205                      if self.fn.supervisor_enable:
206                          with m.Case(PrivilegedFn.Fn.SRET):
207                              m.d.av_comb += instr[20:32].eq(Funct12.SRET)
208                          with m.Case(PrivilegedFn.Fn.SFENCEVMA):
209                              imm_view = data.View(self.gen_params.get(PrivUnitLayouts).sfencevma_imm_layout, instr_imm)
210                              m.d.av_comb += instr[15:20].eq(imm_view.rs1)
211                              m.d.av_comb += instr[20:25].eq(imm_view.rs2)
212                              m.d.av_comb += instr[25:32].eq(Funct7.SFENCEVMA)
213                      with m.Default():
214                          log.error(m, True, "missing Funct12 case")
215  
216                  self.exception_report(
217                      m, cause=ExceptionCause.ILLEGAL_INSTRUCTION, pc=ret_pc, rob_id=instr_rob, mtval=instr
218                  )
219              with m.Elif(async_interrupt_active):
220                  # SPEC: "These conditions for an interrupt trap to occur [..] must also be evaluated immediately
221                  # following the execution of an xRET instruction."
222                  # mret() method is called from precommit() that was executed at least one cycle earlier (because
223                  # of finished condition). If calling mret() caused interrupt to be active, it is already represented
224                  # by updated async_interrupt_active signal.
225                  # Interrupt is reported on this xRET instruction with return address set to instruction that we
226                  # would normally return to (mepc value is preserved)
227                  m.d.av_comb += exception.eq(1)
228                  self.exception_report(
229                      m, cause=ExceptionCause._COREBLOCKS_ASYNC_INTERRUPT, pc=ret_pc, rob_id=instr_rob, mtval=0
230                  )
231              with m.Else():
232                  log.info(m, True, "Unstalling fetch from the priv unit new_pc=0x{:x}", ret_pc)
233                  # Unstall the fetch
234                  resume_core(m, pc=ret_pc)
235  
236              self.push_result(
237                  m,
238                  rob_id=instr_rob,
239                  exception=exception,
240                  rp_dst=0,
241                  result=0,
242              )
243  
244          return m
245  
246  
@dataclass(frozen=True)
class PrivilegedUnitComponent(FunctionalComponentParams):
    """Configuration entry for the privileged functional unit.

    Attributes
    ----------
    supervisor_enable : bool
        Whether supervisor-only instructions (SRET, SFENCE.VMA) are supported.
        It must equal ``gen_params.supervisor_mode`` when the module is built.
    """

    _: KW_ONLY
    supervisor_enable: bool = False
    # Excluded from __init__. Presumably populated by the base class from
    # get_decoder_manager(); confirm in FunctionalComponentParams.
    decoder_manager: PrivilegedFn = field(init=False)

    def get_decoder_manager(self) -> PrivilegedFn:
        """Build the decoder manager matching this configuration."""
        return PrivilegedFn(supervisor_enable=self.supervisor_enable)

    def get_module(self, gen_params: GenParams) -> FuncUnit:
        """Instantiate the privileged functional unit.

        Raises
        ------
        AssertionError
            If ``supervisor_enable`` differs from ``gen_params.supervisor_mode``.
        """
        # Report both values so a configuration mismatch is easy to diagnose.
        assert self.supervisor_enable == gen_params.supervisor_mode, (
            f"PrivilegedUnitComponent.supervisor_enable={self.supervisor_enable} does not match "
            f"gen_params.supervisor_mode={gen_params.supervisor_mode}"
        )

        return PrivilegedFuncUnit(gen_params, self.decoder_manager)