/ coreblocks / priv / vmem / translation.py
translation.py
  1  from amaranth import *
  2  from amaranth.lib.data import StructLayout
  3  from amaranth.lib.enum import Enum, unique, auto
  4  
  5  from transactron import Method, TModule, def_method
  6  from transactron.lib import Forwarder
  7  from transactron.utils import DependencyContext
  8  
  9  from coreblocks.arch.isa_consts import PrivilegeLevel, SatpMode, PAGE_SIZE, PAGE_SIZE_LOG
 10  from coreblocks.interface.keys import CSRInstancesKey
 11  from coreblocks.interface.layouts import AddressTranslationLayouts
 12  from coreblocks.params import GenParams
 13  
 14  __all__ = ["AddressTranslator", "AddressTranslatorMode"]
 15  
 16  
 17  @unique
 18  class AddressTranslatorMode(Enum):
 19      INSTRUCTION = auto()
 20      LSU = auto()
 21  
 22  
 23  def level_count(mode: SatpMode) -> int:
 24      match mode:
 25          case SatpMode.BARE:
 26              return 0
 27          case SatpMode.SV32:
 28              return 2
 29          case SatpMode.SV39:
 30              return 3
 31          case SatpMode.SV48:
 32              return 4
 33          case SatpMode.SV57:
 34              return 5
 35          case _:
 36              raise ValueError(f"Unsupported SATP mode: {mode}")
 37  
 38  
 39  def page_table_entry_format(mode: SatpMode) -> StructLayout:
 40      match mode:
 41          case SatpMode.BARE:
 42              return StructLayout({})
 43          case SatpMode.SV32:
 44              return StructLayout(
 45                  {
 46                      "V": 1,
 47                      "R": 1,
 48                      "W": 1,
 49                      "X": 1,
 50                      "U": 1,
 51                      "G": 1,
 52                      "A": 1,
 53                      "D": 1,
 54                      "RSW": 2,
 55                      "ppn": 22,
 56                  }
 57              )
 58          case SatpMode.SV39 | SatpMode.SV48 | SatpMode.SV57:
 59              return StructLayout(
 60                  {
 61                      "V": 1,
 62                      "R": 1,
 63                      "W": 1,
 64                      "X": 1,
 65                      "U": 1,
 66                      "G": 1,
 67                      "A": 1,
 68                      "D": 1,
 69                      "RSW": 2,
 70                      "ppn": 44,
 71                      "reserved": 7,
 72                      "PBMT": 2,
 73                      "N": 1,
 74                  }
 75              )
 76          case _:
 77              raise ValueError(f"Unsupported SATP mode: {mode}")
 78  
 79  
 80  def bits_per_level(mode: SatpMode) -> int:
 81      """Number of virtual address bits translated at each page table level."""
 82      num_entries = PAGE_SIZE // page_table_entry_format(mode).as_shape().width
 83      return num_entries.bit_length() - 1
 84  
 85  
 86  class AddressTranslator(Elaboratable):
 87      """Address translator from virtual to physical addresses."""
 88  
 89      def __init__(
 90          self,
 91          gen_params: GenParams,
 92          *,
 93          mode: AddressTranslatorMode,
 94      ) -> None:
 95          self.gen_params = gen_params
 96          self.mode = mode
 97          self.layouts = self.gen_params.get(AddressTranslationLayouts)
 98  
 99          self.request = Method(i=self.layouts.request)
100          self.accept = Method(o=self.layouts.accept)
101  
102      def elaborate(self, platform):
103          m = TModule()
104          m.submodules.resp_fwd = resp_fwd = Forwarder(self.layouts.accept)
105          csr = DependencyContext.get().get_dependency(CSRInstancesKey())
106  
107          @def_method(m, self.request)
108          def _(addr: Value):
109              access_fault = Signal()
110              page_fault = Signal()
111              effective_priv_mode = Signal(PrivilegeLevel)
112  
113              priv_mode = csr.m_mode.priv_mode.read(m).data
114  
115              match self.mode:
116                  case AddressTranslatorMode.LSU:
117                      mprv = csr.m_mode.mstatus_mprv.read(m).data
118                      mpp = csr.m_mode.mstatus_mpp.read(m).data
119  
120                      m.d.av_comb += effective_priv_mode.eq(Mux(mprv, mpp, priv_mode))
121  
122                  case AddressTranslatorMode.INSTRUCTION:
123                      m.d.av_comb += effective_priv_mode.eq(priv_mode)
124  
125              effective_satp_mode = Signal(SatpMode, init=SatpMode.BARE)
126  
127              if self.gen_params.supervisor_mode:
128                  with m.If(effective_priv_mode < PrivilegeLevel.MACHINE):
129                      m.d.av_comb += effective_satp_mode.eq(csr.s_mode.satp_mode)
130  
131              poffset = Signal(PAGE_SIZE_LOG)
132              ppn = Signal(self.gen_params.phys_addr_bits - PAGE_SIZE_LOG)
133              vpn = Signal(self.gen_params.isa.xlen - PAGE_SIZE_LOG)
134  
135              m.d.av_comb += Cat(poffset, vpn).eq(addr)
136  
137              max_ppn = 1 << (self.gen_params.phys_addr_bits - PAGE_SIZE_LOG) - 1
138  
139              vpn_invalid = Signal()
140              with m.Switch(effective_satp_mode):
141                  for mode in self.gen_params.vmem_params.supported_schemes - {SatpMode.BARE}:
142                      vpn_len = bits_per_level(mode) * level_count(mode)
143  
144                      with m.Case(mode):
145                          # virtual modes require sign-extended vaddr
146                          m.d.av_comb += vpn_invalid.eq(vpn[vpn_len:].any() & ~vpn[vpn_len:].all())
147  
148              ppn_invalid = Signal()
149              with m.Switch(effective_satp_mode):
150                  with m.Case(SatpMode.BARE):
151                      m.d.av_comb += ppn.eq(vpn)
152                      m.d.av_comb += ppn_invalid.eq(vpn > max_ppn)
153  
154                  # TODO: implement actual page table walking
155  
156              with m.If(vpn_invalid):
157                  m.d.av_comb += page_fault.eq(1)
158  
159              with m.If(ppn_invalid):
160                  m.d.av_comb += access_fault.eq(1)
161  
162              resp_fwd.write(
163                  m,
164                  vaddr=addr,
165                  paddr=Cat(poffset, ppn),
166                  page_fault=page_fault,
167                  access_fault=access_fault,
168              )
169  
170          @def_method(m, self.accept)
171          def _():
172              return resp_fwd.read(m)
173  
174          return m