/ coreblocks / priv / pmp.py
pmp.py
  1  from amaranth import *
  2  from amaranth.lib import data
  3  from amaranth.lib.enum import Enum, auto, unique
  4  from amaranth_types import HasElaborate
  5  from transactron.core import TModule
  6  from transactron.utils import DependencyContext, assign
  7  
  8  from coreblocks.arch.isa_consts import PMPAFlagEncoding, PMPCfgLayout, PrivilegeLevel
  9  from coreblocks.interface.keys import CSRInstancesKey
 10  from coreblocks.params import *
 11  
 12  
 13  class PMPLayout(data.StructLayout):
 14      def __init__(self):
 15          super().__init__({"r": 1, "w": 1, "x": 1})
 16  
 17  
 18  @unique
 19  class PMPOperationMode(Enum):
 20      LSU = auto()
 21      INSTRUCTION_FETCH = auto()
 22      MMU = auto()
 23  
 24  
 25  class PMPChecker(Elaboratable):
 26      """
 27      Implementation of physical memory protection checker.
 28      This is a combinational circuit with return value read from `result` output.
 29  
 30      Effective mode depends on `mode`:
 31      - LSU: MPRV-aware (using MPP when MPRV=1 and current mode is M)
 32      - INSTRUCTION_FETCH: uses only current privilege mode
 33      - MMU: always behaves as supervisor mode
 34  
 35      In machine mode, accesses bypass PMP by default (result = 1/1/1) unless a
 36      matching locked entry (L=1) is found. S/U-mode accesses default to no
 37      access (0/0/0) if no entry matches.
 38  
 39      Attributes
 40      ----------
 41      paddr : Signal
 42          Memory address, for which PMP checks are requested.
 43      result : PMPLayout
 44          RWX permission bits for the given address based on current PMP configuration
 45          and privilege mode. Bits are set to 0 if access is denied.
 46      """
 47  
 48      def __init__(self, gen_params: GenParams, *, mode: PMPOperationMode) -> None:
 49          self.gen_params = gen_params
 50          self.mode = mode
 51          self.csr = DependencyContext.get().get_dependency(CSRInstancesKey()).m_mode
 52          self.paddr = Signal(gen_params.phys_addr_bits)
 53          self.result = Signal(PMPLayout())
 54  
 55      def elaborate(self, platform) -> HasElaborate:
 56          m = TModule()
 57  
 58          grain = self.gen_params.pmp_grain
 59          n = self.gen_params.pmp_register_count
 60  
 61          if n == 0:
 62              m.d.comb += self.result.eq(PMPLayout().const({"r": 1, "w": 1, "x": 1}))
 63              return m
 64  
 65          priv_mode = self.csr.priv_mode.value
 66          mprv = self.csr.mstatus_mprv.value
 67          mpp = self.csr.mstatus_mpp.value
 68  
 69          effective_priv_mode = Signal(PrivilegeLevel)
 70          match self.mode:
 71              case PMPOperationMode.LSU:
 72                  m.d.comb += effective_priv_mode.eq(Mux(mprv, mpp, priv_mode))
 73              case PMPOperationMode.INSTRUCTION_FETCH:
 74                  m.d.comb += effective_priv_mode.eq(priv_mode)
 75              case PMPOperationMode.MMU:
 76                  m.d.comb += effective_priv_mode.eq(PrivilegeLevel.SUPERVISOR)
 77  
 78          with m.If(effective_priv_mode == PrivilegeLevel.MACHINE):
 79              m.d.comb += self.result.eq(PMPLayout().const({"r": 1, "w": 1, "x": 1}))
 80  
 81          entry_matches = []
 82          cfgs = []
 83          addr_vals = []
 84  
 85          for i in range(n):
 86              cfg_val = data.View(PMPCfgLayout(), self.csr.pmpxcfg[i].value)
 87              addr_val = self.csr.pmpaddrx[i].value
 88              cfgs.append(cfg_val)
 89              addr_vals.append(addr_val)
 90  
 91              entry_match = Signal(name=f"match_{i}")
 92  
 93              with m.Switch(cfg_val.A):
 94                  with m.Case(PMPAFlagEncoding.OFF):
 95                      m.d.comb += entry_match.eq(0)
 96                  with m.Case(PMPAFlagEncoding.TOR):
 97                      lower = addr_vals[i - 1][grain:] if i > 0 else 0
 98                      m.d.comb += entry_match.eq(
 99                          (self.paddr[2 + grain :] >= lower) & (self.paddr[2 + grain :] < addr_val[grain:])
100                      )
101                  with m.Case(PMPAFlagEncoding.NA4):
102                      if grain == 0:
103                          m.d.comb += entry_match.eq(self.paddr[2:] == addr_val)
104                      else:
105                          m.d.comb += entry_match.eq(0)
106                  with m.Case(PMPAFlagEncoding.NAPOT):
107                      # NAPOT region size is encoded by trailing ones in pmpaddr.
108                      # XOR with (pmpaddr + 1) extracts those trailing ones as a mask.
109                      # Bits below the mask define the region; bits above must match.
110                      # With grain > 0, lower bits are forced to 1 so we skip them.
111                      start_bit = max(0, grain - 1)
112                      napot_mask = addr_val[start_bit:] ^ (addr_val[start_bit:] + 1)
113                      m.d.comb += entry_match.eq(
114                          (self.paddr[2 + start_bit :] & ~napot_mask) == (addr_val[start_bit:] & ~napot_mask)
115                      )
116  
117              entry_matches.append(entry_match)
118  
119          matches = Cat(entry_matches)
120          one_hot = matches & (~matches + 1)
121  
122          r_bits = Cat(cfg.R for cfg in cfgs)
123          w_bits = Cat(cfg.W for cfg in cfgs)
124          x_bits = Cat(cfg.X for cfg in cfgs)
125          l_bits = Cat(cfg.L for cfg in cfgs)
126  
127          selected_r = (one_hot & r_bits).any()
128          selected_w = (one_hot & w_bits).any()
129          selected_x = (one_hot & x_bits).any()
130          selected_l = (one_hot & l_bits).any()
131  
132          with m.If(matches.any() & ((effective_priv_mode != PrivilegeLevel.MACHINE) | selected_l)):
133              m.d.comb += assign(self.result, {"r": selected_r, "w": selected_w, "x": selected_x})
134  
135          return m