/ test / regression / memory.py
memory.py
  1  from abc import ABC, abstractmethod
  2  from collections.abc import Callable
  3  from enum import Enum, IntFlag, auto
  4  from typing import Optional, TypeVar
  5  from dataclasses import dataclass, replace
  6  from elftools.elf.constants import P_FLAGS
  7  from elftools.elf.elffile import ELFFile, Segment
  8  from coreblocks.params.configurations import CoreConfiguration
  9  from transactron.utils import align_to_power_of_two, align_down_to_power_of_two
 10  
 11  all = [
 12      "ReplyStatus",
 13      "ReadRequest",
 14      "ReadReply",
 15      "WriteRequest",
 16      "WriteReply",
 17      "MemoryModel",
 18      "RAMSegment",
 19      "CoreMemoryModel",
 20  ]
 21  
 22  
 23  class ReplyStatus(Enum):
 24      OK = auto()
 25      ERROR = auto()
 26      RETRY = auto()
 27  
 28  
 29  class SegmentFlags(IntFlag):
 30      READ = auto()
 31      WRITE = auto()
 32      EXECUTABLE = auto()
 33  
 34  
 35  @dataclass
 36  class ReadRequest:
 37      addr: int
 38      byte_count: int
 39      byte_sel: int
 40      exec: bool
 41  
 42  
 43  @dataclass
 44  class ReadReply:
 45      data: int = 0
 46      status: ReplyStatus = ReplyStatus.OK
 47  
 48  
 49  @dataclass
 50  class WriteRequest:
 51      addr: int
 52      data: int
 53      byte_count: int
 54      byte_sel: int
 55  
 56  
 57  @dataclass
 58  class WriteReply:
 59      status: ReplyStatus = ReplyStatus.OK
 60  
 61  
 62  class MemorySegment(ABC):
 63      def __init__(self, address_range: range, flags: SegmentFlags):
 64          self.address_range = address_range
 65          self.flags = flags
 66  
 67      @abstractmethod
 68      def read(self, req: ReadRequest) -> ReadReply:
 69          raise NotImplementedError
 70  
 71      @abstractmethod
 72      def write(self, req: WriteRequest) -> WriteReply:
 73          raise NotImplementedError
 74  
 75  
 76  class RandomAccessMemory(MemorySegment):
 77      def __init__(self, address_range: range, flags: SegmentFlags, data: bytes):
 78          super().__init__(address_range, flags)
 79          self.data = bytearray(data)
 80  
 81          if len(self.data) != len(address_range):
 82              raise ValueError("Data length must be equal to the length of the address range")
 83  
 84      def read(self, req: ReadRequest) -> ReadReply:
 85          return ReadReply(data=int.from_bytes(self.data[req.addr : req.addr + req.byte_count], "little"))
 86  
 87      def write(self, req: WriteRequest) -> WriteReply:
 88          mask_bytes = [b"\x00", b"\xff"]
 89          mask = int.from_bytes(b"".join(mask_bytes[1 & (req.byte_sel >> i)] for i in range(4)), "little")
 90          old = int.from_bytes(self.data[req.addr : req.addr + req.byte_count], "little")
 91          self.data[req.addr : req.addr + req.byte_count] = (old & ~mask | req.data & mask).to_bytes(4, "little")
 92          return WriteReply()
 93  
 94  
 95  TReq = TypeVar("TReq", bound=ReadRequest | WriteRequest)
 96  TRep = TypeVar("TRep", bound=ReadReply | WriteReply)
 97  
 98  
 99  class CoreMemoryModel:
100      def __init__(self, segments: list[MemorySegment], fail_on_undefined_read=False, fail_on_undefined_write=True):
101          self.segments = segments
102          self.fail_on_undefined_read = fail_on_undefined_read  # Core may do undefined reads speculatively
103          self.fail_on_undefined_write = fail_on_undefined_write
104  
105      def _run_on_range(self, f: Callable[[MemorySegment, TReq], TRep], req: TReq) -> Optional[TRep]:
106          for seg in self.segments:
107              if req.addr in seg.address_range:
108                  return f(seg, req)
109  
110      def _do_read(self, seg: MemorySegment, req: ReadRequest) -> ReadReply:
111          if SegmentFlags.READ not in seg.flags:
112              raise RuntimeError("Tried to read from non-read memory: %x" % req.addr)
113          if req.exec and SegmentFlags.EXECUTABLE not in seg.flags:
114              raise RuntimeError("Memory is not executable: %x" % req.addr)
115  
116          return seg.read(replace(req, addr=req.addr - seg.address_range.start))
117  
118      def _do_write(self, seg: MemorySegment, req: WriteRequest) -> WriteReply:
119          if SegmentFlags.WRITE not in seg.flags:
120              raise RuntimeError("Tried to write to non-writable memory: %x" % req.addr)
121  
122          return seg.write(replace(req, addr=req.addr - seg.address_range.start))
123  
124      def read(self, req: ReadRequest) -> ReadReply:
125          rep = self._run_on_range(self._do_read, req)
126          if rep is not None:
127              return rep
128          if self.fail_on_undefined_read:
129              raise RuntimeError("Undefined read: %x" % req.addr)
130          else:
131              return ReadReply(status=ReplyStatus.ERROR)
132  
133      def write(self, req: WriteRequest) -> WriteReply:
134          rep = self._run_on_range(self._do_write, req)
135          if rep is not None:
136              return rep
137          if self.fail_on_undefined_write:
138              raise RuntimeError("Undefined write: %x <= %x" % (req.addr, req.data))
139          else:
140              return WriteReply(status=ReplyStatus.ERROR)
141  
142  
def load_segment(
    segment: Segment,
    *,
    disable_write_protection: bool = False,
    force_executable: bool = False,
    config: Optional[CoreConfiguration] = None,
) -> RandomAccessMemory:
    """Build a RAM segment from an ELF program header.

    The segment starts at ``p_paddr`` and spans ``p_memsz`` bytes. The file contents are copied in and the
    part of ``p_memsz`` not backed by the file is zero-filled. For executable segments, the bounds are
    aligned to ``2**config.icache_line_bytes_log`` bytes, and the end is extended by one more such line.

    Args:
        segment: ELF program header to load.
        disable_write_protection: Make the segment writable regardless of ``PF_W``.
        force_executable: Make the segment executable regardless of ``PF_X``.
        config: Configuration whose ``icache_line_bytes_log`` drives the alignment of executable segments.
            Defaults to ``CoreConfiguration()``. Pass the configuration of the core under test if its
            instruction-cache line size differs.

    Returns:
        A ``RandomAccessMemory`` covering the (possibly widened) address range.
    """
    if config is None:
        config = CoreConfiguration()

    seg_start = segment.header["p_paddr"]
    seg_end = seg_start + segment.header["p_memsz"]
    flags_raw = segment.header["p_flags"]

    # fill the rest of the segment (memory size beyond the file contents) with zeroes
    data = segment.data()
    data = data + b"\x00" * (seg_end - seg_start - len(data))

    flags = SegmentFlags(0)
    if flags_raw & P_FLAGS.PF_R:
        flags |= SegmentFlags.READ
    if flags_raw & P_FLAGS.PF_W or disable_write_protection:
        flags |= SegmentFlags.WRITE
    if flags_raw & P_FLAGS.PF_X or force_executable:
        flags |= SegmentFlags.EXECUTABLE

    if flags & SegmentFlags.EXECUTABLE:
        # align instruction section to full icache lines
        align_bits = config.icache_line_bytes_log
        # workaround for fetching/stalling issue: keep one extra icache line mapped past the end
        extend_end = 2**config.icache_line_bytes_log
    else:
        align_bits = 0
        extend_end = 0

    aligned_start = align_down_to_power_of_two(seg_start, align_bits)
    aligned_end = align_to_power_of_two(seg_end, align_bits) + extend_end

    # zero-pad the contents so they cover the widened range exactly
    data = b"\x00" * (seg_start - aligned_start) + data + b"\x00" * (aligned_end - seg_end)

    return RandomAccessMemory(range(aligned_start, aligned_end), flags, data)
185  
186  
def load_segments_from_elf(
    file_path: str, *, disable_write_protection: bool = False, force_executable: bool = False
) -> list[RandomAccessMemory]:
    """Load the ``PT_LOAD`` and ``PT_NULL`` program segments of an ELF file as RAM segments.

    NOTE(review): ``PT_NULL`` entries are normally unused header slots; confirm that loading them is
    intended (with ``force_executable`` they become padded executable segments).

    Args:
        file_path: Path of the ELF file.
        disable_write_protection: Forwarded to ``load_segment``.
        force_executable: Forwarded to ``load_segment``.

    Returns:
        One ``RandomAccessMemory`` per selected segment, in program-header order.
    """
    kept_types = ("PT_LOAD", "PT_NULL")

    with open(file_path, "rb") as f:
        # segment data is read lazily from ``f``, so everything must be loaded before the file closes
        return [
            load_segment(
                segment, disable_write_protection=disable_write_protection, force_executable=force_executable
            )
            for segment in ELFFile(f).iter_segments()
            if segment.header["p_type"] in kept_types
        ]