memory.py
from abc import ABC, abstractmethod
from collections.abc import Callable
from enum import Enum, IntFlag, auto
from typing import Optional, TypeVar
from dataclasses import dataclass, replace
from elftools.elf.constants import P_FLAGS
from elftools.elf.elffile import ELFFile, Segment
from coreblocks.params.configurations import CoreConfiguration
from transactron.utils import align_to_power_of_two, align_down_to_power_of_two

# Public API of this module. It must be spelled ``__all__``: a plain ``all`` name neither
# restricts ``from ... import *`` nor leaves the ``all()`` builtin usable in star-importers.
__all__ = [
    "ReplyStatus",
    "SegmentFlags",
    "ReadRequest",
    "ReadReply",
    "WriteRequest",
    "WriteReply",
    "MemorySegment",
    "RandomAccessMemory",
    "CoreMemoryModel",
    "load_segment",
    "load_segments_from_elf",
]


class ReplyStatus(Enum):
    """Status carried by ``ReadReply`` and ``WriteReply``."""

    OK = auto()
    ERROR = auto()
    RETRY = auto()


class SegmentFlags(IntFlag):
    """Access permissions of a memory segment, checked by ``CoreMemoryModel`` before each access."""

    READ = auto()
    WRITE = auto()
    EXECUTABLE = auto()


@dataclass
class ReadRequest:
    """A read access issued to the memory model."""

    addr: int  # absolute address; CoreMemoryModel rebases it to be segment-relative
    byte_count: int  # number of bytes to read
    byte_sel: int  # byte-enable mask; not consulted by RandomAccessMemory.read
    exec: bool  # if True the target segment must be EXECUTABLE (presumably an instruction fetch)


@dataclass
class ReadReply:
    """Result of a read: little-endian data word and a status."""

    data: int = 0
    status: ReplyStatus = ReplyStatus.OK


@dataclass
class WriteRequest:
    """A write access issued to the memory model."""

    addr: int  # absolute address; CoreMemoryModel rebases it to be segment-relative
    data: int  # little-endian data word
    byte_count: int  # width of the access in bytes
    byte_sel: int  # bit i set => byte i of ``data`` is written


@dataclass
class WriteReply:
    """Result of a write."""

    status: ReplyStatus = ReplyStatus.OK


class MemorySegment(ABC):
    """Contiguous region of the address space with fixed access permissions.

    Subclasses implement ``read`` and ``write``. When invoked through ``CoreMemoryModel``,
    the request's ``addr`` has already been made relative to ``address_range.start``.
    """

    def __init__(self, address_range: range, flags: SegmentFlags):
        self.address_range = address_range
        self.flags = flags

    @abstractmethod
    def read(self, req: ReadRequest) -> ReadReply:
        raise NotImplementedError

    @abstractmethod
    def write(self, req: WriteRequest) -> WriteReply:
        raise NotImplementedError


class RandomAccessMemory(MemorySegment):
    """Memory segment backed by a mutable byte buffer of exactly ``len(address_range)`` bytes.

    Raises:
        ValueError: if ``len(data)`` differs from ``len(address_range)``.
    """

    def __init__(self, address_range: range, flags: SegmentFlags, data: bytes):
        super().__init__(address_range, flags)
        self.data = bytearray(data)

        if len(self.data) != len(address_range):
            raise ValueError("Data length must be equal to the length of the address range")

    def read(self, req: ReadRequest) -> ReadReply:
        """Return ``req.byte_count`` little-endian bytes starting at segment offset ``req.addr``.

        ``req.byte_sel`` is ignored. If the access runs past the end of the buffer, only the
        bytes that exist are read (the result is zero-extended).
        """
        return ReadReply(data=int.from_bytes(self.data[req.addr : req.addr + req.byte_count], "little"))

    def write(self, req: WriteRequest) -> WriteReply:
        """Merge the bytes of ``req.data`` selected by ``req.byte_sel`` into the buffer.

        The access is ``req.byte_count`` bytes wide, starting at segment offset ``req.addr``.
        Bit ``i`` of ``byte_sel`` enables byte ``i``; unselected bytes keep their old contents.
        The buffer never changes size: bytes that would land past its end are dropped.
        """
        count = req.byte_count
        end = req.addr + count
        # Copy of the bytes being replaced; shorter than ``count`` if the access overruns the buffer.
        window = self.data[req.addr : end]
        old = int.from_bytes(window, "little")
        mask = int.from_bytes(bytes(0xFF if (req.byte_sel >> i) & 1 else 0x00 for i in range(count)), "little")
        merged = ((old & ~mask) | (req.data & mask)).to_bytes(count, "little")
        # Assigning exactly len(window) bytes keeps the bytearray length (and thus the address map) fixed.
        self.data[req.addr : end] = merged[: len(window)]
        return WriteReply()


TReq = TypeVar("TReq", bound=ReadRequest | WriteRequest)
TRep = TypeVar("TRep", bound=ReadReply | WriteReply)


class CoreMemoryModel:
    """Dispatch core memory requests to the segment whose address range contains the request address.

    Args:
        segments: searched in list order; the first segment containing ``req.addr`` handles it.
        fail_on_undefined_read: if True, a read hitting no segment raises ``RuntimeError``;
            otherwise a ``ReadReply`` with ``ReplyStatus.ERROR`` is returned.
        fail_on_undefined_write: same as above, for writes.
    """

    def __init__(
        self,
        segments: list[MemorySegment],
        fail_on_undefined_read: bool = False,
        fail_on_undefined_write: bool = True,
    ) -> None:
        self.segments = segments
        self.fail_on_undefined_read = fail_on_undefined_read  # Core may do undefined reads speculatively
        self.fail_on_undefined_write = fail_on_undefined_write

    def _run_on_range(self, f: Callable[[MemorySegment, TReq], TRep], req: TReq) -> Optional[TRep]:
        """Apply ``f`` to the first segment containing ``req.addr``; return None if there is none."""
        for seg in self.segments:
            if req.addr in seg.address_range:
                return f(seg, req)
        return None

    def _do_read(self, seg: MemorySegment, req: ReadRequest) -> ReadReply:
        """Check permissions, then forward the read to ``seg`` with a segment-relative address.

        Raises:
            RuntimeError: if ``seg`` is not readable, or ``req.exec`` is set and ``seg`` is not executable.
        """
        if SegmentFlags.READ not in seg.flags:
            raise RuntimeError("Tried to read from non-read memory: %x" % req.addr)
        if req.exec and SegmentFlags.EXECUTABLE not in seg.flags:
            raise RuntimeError("Memory is not executable: %x" % req.addr)

        return seg.read(replace(req, addr=req.addr - seg.address_range.start))

    def _do_write(self, seg: MemorySegment, req: WriteRequest) -> WriteReply:
        """Check permissions, then forward the write to ``seg`` with a segment-relative address.

        Raises:
            RuntimeError: if ``seg`` is not writable.
        """
        if SegmentFlags.WRITE not in seg.flags:
            raise RuntimeError("Tried to write to non-writable memory: %x" % req.addr)

        return seg.write(replace(req, addr=req.addr - seg.address_range.start))

    def read(self, req: ReadRequest) -> ReadReply:
        """Serve a read; undefined addresses raise or return an ERROR reply per ``fail_on_undefined_read``."""
        rep = self._run_on_range(self._do_read, req)
        if rep is not None:
            return rep
        if self.fail_on_undefined_read:
            raise RuntimeError("Undefined read: %x" % req.addr)
        else:
            return ReadReply(status=ReplyStatus.ERROR)

    def write(self, req: WriteRequest) -> WriteReply:
        """Serve a write; undefined addresses raise or return an ERROR reply per ``fail_on_undefined_write``."""
        rep = self._run_on_range(self._do_write, req)
        if rep is not None:
            return rep
        if self.fail_on_undefined_write:
            raise RuntimeError("Undefined write: %x <= %x" % (req.addr, req.data))
        else:
            return WriteReply(status=ReplyStatus.ERROR)


def load_segment(
    segment: Segment, *, disable_write_protection: bool = False, force_executable: bool = False
) -> RandomAccessMemory:
    """Build a ``RandomAccessMemory`` for one ELF program segment.

    The segment is placed at its physical address (``p_paddr``), and its file contents are
    zero-padded up to ``p_memsz``. ELF ``PF_R``/``PF_W``/``PF_X`` map to READ/WRITE/EXECUTABLE.
    ``disable_write_protection`` forces WRITE, and ``force_executable`` forces EXECUTABLE.

    Executable segments are additionally aligned to full instruction-cache lines (using
    ``icache_line_bytes_log`` from a default ``CoreConfiguration``). They are also extended by
    ``2**icache_line_bytes_log`` extra zero bytes; the padding is zero-filled.
    """
    paddr = segment.header["p_paddr"]
    memsz = segment.header["p_memsz"]
    flags_raw = segment.header["p_flags"]

    seg_start = paddr
    seg_end = paddr + memsz

    data = segment.data()

    # fill the rest of the segment with zeroes
    data = data + b"\x00" * (seg_end - seg_start - len(data))

    flags = SegmentFlags(0)
    if flags_raw & P_FLAGS.PF_R:
        flags |= SegmentFlags.READ
    if flags_raw & P_FLAGS.PF_W or disable_write_protection:
        flags |= SegmentFlags.WRITE
    if flags_raw & P_FLAGS.PF_X or force_executable:
        flags |= SegmentFlags.EXECUTABLE

    config = CoreConfiguration()
    if flags & SegmentFlags.EXECUTABLE:
        # align instruction section to full icache lines
        align_bits = config.icache_line_bytes_log
        # workaround for fetching/stalling issue
        extend_end = 2**config.icache_line_bytes_log
    else:
        align_bits = 0
        extend_end = 0

    align_data_front = seg_start - align_down_to_power_of_two(seg_start, align_bits)
    align_data_back = align_to_power_of_two(seg_end, align_bits) - seg_end + extend_end

    data = b"\x00" * align_data_front + data + b"\x00" * align_data_back

    seg_start = align_down_to_power_of_two(seg_start, align_bits)
    seg_end = align_to_power_of_two(seg_end, align_bits) + extend_end

    return RandomAccessMemory(range(seg_start, seg_end), flags, data)


def load_segments_from_elf(
    file_path: str, *, disable_write_protection: bool = False, force_executable: bool = False
) -> list[RandomAccessMemory]:
    """Load the PT_LOAD and PT_NULL program segments of the ELF file at ``file_path``.

    Each segment is converted with ``load_segment``, which receives the same keyword flags.
    Segments are returned in program-header order.
    """
    segments: list[RandomAccessMemory] = []

    with open(file_path, "rb") as f:
        elffile = ELFFile(f)
        for segment in elffile.iter_segments():
            # NOTE(review): PT_NULL headers are loaded too, not only PT_LOAD; confirm this is intended.
            if segment.header["p_type"] != "PT_LOAD" and segment.header["p_type"] != "PT_NULL":
                continue
            segments.append(
                load_segment(
                    segment, disable_write_protection=disable_write_protection, force_executable=force_executable
                )
            )

    return segments