translation.py
1 from amaranth import * 2 from amaranth.lib.data import StructLayout 3 from amaranth.lib.enum import Enum, unique, auto 4 5 from transactron import Method, TModule, def_method 6 from transactron.lib import Forwarder 7 from transactron.utils import DependencyContext 8 9 from coreblocks.arch.isa_consts import PrivilegeLevel, SatpMode, PAGE_SIZE, PAGE_SIZE_LOG 10 from coreblocks.interface.keys import CSRInstancesKey 11 from coreblocks.interface.layouts import AddressTranslationLayouts 12 from coreblocks.params import GenParams 13 14 __all__ = ["AddressTranslator", "AddressTranslatorMode"] 15 16 17 @unique 18 class AddressTranslatorMode(Enum): 19 INSTRUCTION = auto() 20 LSU = auto() 21 22 23 def level_count(mode: SatpMode) -> int: 24 match mode: 25 case SatpMode.BARE: 26 return 0 27 case SatpMode.SV32: 28 return 2 29 case SatpMode.SV39: 30 return 3 31 case SatpMode.SV48: 32 return 4 33 case SatpMode.SV57: 34 return 5 35 case _: 36 raise ValueError(f"Unsupported SATP mode: {mode}") 37 38 39 def page_table_entry_format(mode: SatpMode) -> StructLayout: 40 match mode: 41 case SatpMode.BARE: 42 return StructLayout({}) 43 case SatpMode.SV32: 44 return StructLayout( 45 { 46 "V": 1, 47 "R": 1, 48 "W": 1, 49 "X": 1, 50 "U": 1, 51 "G": 1, 52 "A": 1, 53 "D": 1, 54 "RSW": 2, 55 "ppn": 22, 56 } 57 ) 58 case SatpMode.SV39 | SatpMode.SV48 | SatpMode.SV57: 59 return StructLayout( 60 { 61 "V": 1, 62 "R": 1, 63 "W": 1, 64 "X": 1, 65 "U": 1, 66 "G": 1, 67 "A": 1, 68 "D": 1, 69 "RSW": 2, 70 "ppn": 44, 71 "reserved": 7, 72 "PBMT": 2, 73 "N": 1, 74 } 75 ) 76 case _: 77 raise ValueError(f"Unsupported SATP mode: {mode}") 78 79 80 def bits_per_level(mode: SatpMode) -> int: 81 """Number of virtual address bits translated at each page table level.""" 82 num_entries = PAGE_SIZE // page_table_entry_format(mode).as_shape().width 83 return num_entries.bit_length() - 1 84 85 86 class AddressTranslator(Elaboratable): 87 """Address translator from virtual to physical addresses.""" 88 89 def __init__( 90 self, 91 gen_params: GenParams, 92 *, 93 mode: AddressTranslatorMode, 94 ) -> None: 95 self.gen_params = gen_params 96 self.mode = mode 97 self.layouts = self.gen_params.get(AddressTranslationLayouts) 98 99 self.request = Method(i=self.layouts.request) 100 self.accept = Method(o=self.layouts.accept) 101 102 def elaborate(self, platform): 103 m = TModule() 104 m.submodules.resp_fwd = resp_fwd = Forwarder(self.layouts.accept) 105 csr = DependencyContext.get().get_dependency(CSRInstancesKey()) 106 107 @def_method(m, self.request) 108 def _(addr: Value): 109 access_fault = Signal() 110 page_fault = Signal() 111 effective_priv_mode = Signal(PrivilegeLevel) 112 113 priv_mode = csr.m_mode.priv_mode.read(m).data 114 115 match self.mode: 116 case AddressTranslatorMode.LSU: 117 mprv = csr.m_mode.mstatus_mprv.read(m).data 118 mpp = csr.m_mode.mstatus_mpp.read(m).data 119 120 m.d.av_comb += effective_priv_mode.eq(Mux(mprv, mpp, priv_mode)) 121 122 case AddressTranslatorMode.INSTRUCTION: 123 m.d.av_comb += effective_priv_mode.eq(priv_mode) 124 125 effective_satp_mode = Signal(SatpMode, init=SatpMode.BARE) 126 127 if self.gen_params.supervisor_mode: 128 with m.If(effective_priv_mode < PrivilegeLevel.MACHINE): 129 m.d.av_comb += effective_satp_mode.eq(csr.s_mode.satp_mode) 130 131 poffset = Signal(PAGE_SIZE_LOG) 132 ppn = Signal(self.gen_params.phys_addr_bits - PAGE_SIZE_LOG) 133 vpn = Signal(self.gen_params.isa.xlen - PAGE_SIZE_LOG) 134 135 m.d.av_comb += Cat(poffset, vpn).eq(addr) 136 137 max_ppn = 1 << (self.gen_params.phys_addr_bits - PAGE_SIZE_LOG) - 1 138 139 vpn_invalid = Signal() 140 with m.Switch(effective_satp_mode): 141 for mode in self.gen_params.vmem_params.supported_schemes - {SatpMode.BARE}: 142 vpn_len = bits_per_level(mode) * level_count(mode) 143 144 with m.Case(mode): 145 # virtual modes require sign-extended vaddr 146 m.d.av_comb += vpn_invalid.eq(vpn[vpn_len:].any() & ~vpn[vpn_len:].all()) 147 148 ppn_invalid = Signal() 149 with m.Switch(effective_satp_mode): 150 with m.Case(SatpMode.BARE): 151 m.d.av_comb += ppn.eq(vpn) 152 m.d.av_comb += ppn_invalid.eq(vpn > max_ppn) 153 154 # TODO: implement actual page table walking 155 156 with m.If(vpn_invalid): 157 m.d.av_comb += page_fault.eq(1) 158 159 with m.If(ppn_invalid): 160 m.d.av_comb += access_fault.eq(1) 161 162 resp_fwd.write( 163 m, 164 vaddr=addr, 165 paddr=Cat(poffset, ppn), 166 page_fault=page_fault, 167 access_fault=access_fault, 168 ) 169 170 @def_method(m, self.accept) 171 def _(): 172 return resp_fwd.read(m) 173 174 return m