pmp.py
from amaranth import *
from amaranth.lib import data
from amaranth.lib.enum import Enum, auto, unique
from amaranth_types import HasElaborate
from transactron.core import TModule
from transactron.utils import DependencyContext, assign

from coreblocks.arch.isa_consts import PMPAFlagEncoding, PMPCfgLayout, PrivilegeLevel
from coreblocks.interface.keys import CSRInstancesKey
from coreblocks.params import *


class PMPLayout(data.StructLayout):
    """Struct layout of a PMP verdict: three 1-bit fields named `r`, `w` and `x`."""

    def __init__(self):
        super().__init__({"r": 1, "w": 1, "x": 1})


@unique
class PMPOperationMode(Enum):
    """Selects how `PMPChecker` derives the privilege level it checks against."""

    LSU = auto()
    INSTRUCTION_FETCH = auto()
    MMU = auto()


class PMPChecker(Elaboratable):
    """
    Combinational physical memory protection (PMP) checker.

    Drive `paddr` and read the permissions for that address from `result`;
    the module contains no synchronous logic.

    The privilege level used for the check is chosen by `mode`:
    - LSU: `mstatus.MPP` while `mstatus.MPRV` is set, otherwise the current
      privilege mode (the current mode is not re-checked when MPRV is set)
    - INSTRUCTION_FETCH: the current privilege mode
    - MMU: always supervisor

    When several entries match, the one with the lowest index decides.
    A machine-level check yields 1/1/1 unless the deciding entry is locked
    (L=1), in which case that entry's R/W/X bits are reported. For other
    levels the deciding entry's R/W/X bits are reported, or 0/0/0 when no
    entry matches. If the core is configured with zero PMP registers,
    `result` is constantly 1/1/1.

    Attributes
    ----------
    paddr : Signal
        Physical address to be checked.
    result : PMPLayout
        R/W/X permission bits for `paddr`. A bit equal to 0 means that the
        corresponding kind of access is denied.
    """

    def __init__(self, gen_params: GenParams, *, mode: PMPOperationMode) -> None:
        """
        Parameters
        ----------
        gen_params : GenParams
            Core generation parameters; `phys_addr_bits`, `pmp_grain` and
            `pmp_register_count` are read from it.
        mode : PMPOperationMode
            Selects the source of the privilege level used for the check.
        """
        self.gen_params = gen_params
        self.mode = mode
        self.csr = DependencyContext.get().get_dependency(CSRInstancesKey()).m_mode
        self.paddr = Signal(gen_params.phys_addr_bits)
        self.result = Signal(PMPLayout())

    def elaborate(self, platform) -> HasElaborate:
        m = TModule()

        granularity = self.gen_params.pmp_grain
        entry_count = self.gen_params.pmp_register_count

        # Verdict granting read, write and execute at once.
        allow_all = PMPLayout().const({"r": 1, "w": 1, "x": 1})

        if entry_count == 0:
            # No PMP registers configured: every access is allowed.
            m.d.comb += self.result.eq(allow_all)
            return m

        current_priv = self.csr.priv_mode.value
        mprv_bit = self.csr.mstatus_mprv.value
        mpp_field = self.csr.mstatus_mpp.value

        # Privilege level the check is performed for. The Python variable name
        # doubles as the signal name, so it is kept stable.
        effective_priv_mode = Signal(PrivilegeLevel)
        mode = self.mode
        if mode == PMPOperationMode.LSU:
            m.d.comb += effective_priv_mode.eq(Mux(mprv_bit, mpp_field, current_priv))
        elif mode == PMPOperationMode.INSTRUCTION_FETCH:
            m.d.comb += effective_priv_mode.eq(current_priv)
        elif mode == PMPOperationMode.MMU:
            m.d.comb += effective_priv_mode.eq(PrivilegeLevel.SUPERVISOR)

        # Machine-level default. The conditional assignment at the end of this
        # method is added later and therefore takes precedence when it fires.
        with m.If(effective_priv_mode == PrivilegeLevel.MACHINE):
            m.d.comb += self.result.eq(allow_all)

        cfg_views = [data.View(PMPCfgLayout(), self.csr.pmpxcfg[i].value) for i in range(entry_count)]
        addr_regs = [self.csr.pmpaddrx[i].value for i in range(entry_count)]

        hits = []
        for index, (cfg, addr) in enumerate(zip(cfg_views, addr_regs)):
            hit = Signal(name=f"match_{index}")
            hits.append(hit)

            with m.Switch(cfg.A):
                with m.Case(PMPAFlagEncoding.OFF):
                    m.d.comb += hit.eq(0)

                with m.Case(PMPAFlagEncoding.TOR):
                    # Half-open range [previous pmpaddr, this pmpaddr); entry 0
                    # uses 0 as the lower bound. The low `granularity` bits of
                    # the address registers are left out of the comparison.
                    floor = addr_regs[index - 1][granularity:] if index > 0 else 0
                    probe = self.paddr[2 + granularity :]
                    m.d.comb += hit.eq((probe >= floor) & (probe < addr[granularity:]))

                with m.Case(PMPAFlagEncoding.NA4):
                    # Handled only for granularity 0; otherwise never matches.
                    na4_hit = (self.paddr[2:] == addr) if granularity == 0 else 0
                    m.d.comb += hit.eq(na4_hit)

                with m.Case(PMPAFlagEncoding.NAPOT):
                    # The run of trailing ones in pmpaddr encodes the region
                    # size. `encoded ^ (encoded + 1)` sets exactly that run
                    # plus the zero bit above it, so its complement selects
                    # the address bits that have to be equal. For
                    # granularity > 0 the lowest `granularity - 1` bits are
                    # skipped on both sides.
                    skip = granularity - 1 if granularity > 0 else 0
                    encoded = addr[skip:]
                    compared_bits = ~(encoded ^ (encoded + 1))
                    m.d.comb += hit.eq((self.paddr[2 + skip :] & compared_bits) == (encoded & compared_bits))

        match_vector = Cat(hits)
        # Isolate the least significant set bit: the lowest-index matching entry.
        lowest_hit = match_vector & (~match_vector + 1)

        def chosen(flag: str):
            """Return the `flag` bit of the configuration selected by `lowest_hit`."""
            return (lowest_hit & Cat(getattr(cfg, flag) for cfg in cfg_views)).any()

        picked_r = chosen("R")
        picked_w = chosen("W")
        picked_x = chosen("X")
        picked_l = chosen("L")

        # The matching entry is applied for any non-machine check, and for a
        # machine-level check only when that entry is locked.
        entry_applies = (effective_priv_mode != PrivilegeLevel.MACHINE) | picked_l
        with m.If(match_vector.any() & entry_applies):
            m.d.comb += assign(self.result, {"r": picked_r, "w": picked_w, "x": picked_x})

        return m