icache.py
from functools import reduce
import operator

from amaranth import *
from amaranth.lib.data import View
import amaranth.lib.memory as memory
from amaranth.utils import exact_log2

from transactron.core import def_method, Priority, TModule
from transactron import Method, Transaction
from coreblocks.params import ICacheParameters
from coreblocks.interface.layouts import ICacheLayouts
from transactron.utils import assign, OneHotSwitchDynamic
from transactron.lib import *
from transactron.lib import logging
from coreblocks.peripherals.bus_adapter import BusMasterInterface

from coreblocks.cache.iface import CacheInterface, CacheRefillerInterface
from transactron.utils.transactron_helpers import make_layout

__all__ = [
    "ICache",
    "ICacheBypass",
]

log = logging.HardwareLogger("frontend.icache")


class ICacheBypass(Elaboratable, CacheInterface):
    """A cache-less, pass-through implementation of `CacheInterface`.

    Every `issue_req` call is forwarded directly as a read request on the
    bus master; `accept_res` simply returns the bus read response. Only
    configurations where one fetch block equals one bus word are supported
    (enforced in `__init__`), so no word selection is needed on the response.
    """

    def __init__(self, layouts: ICacheLayouts, params: ICacheParameters, bus_master: BusMasterInterface) -> None:
        """
        Parameters
        ----------
        layouts : ICacheLayouts
            Layouts used to construct the cache methods.
        params : ICacheParameters
            Cache parameters; only the address width and word/fetch-block
            geometry are used here.
        bus_master : BusMasterInterface
            Bus master on which read requests are issued.
        """
        self.params = params
        self.bus_master = bus_master

        self.issue_req = Method(i=layouts.issue_req)
        self.accept_res = Method(o=layouts.accept_res)
        self.flush = Method()

        # Without a word-select path on responses, a fetch block must be
        # exactly one bus word wide.
        if params.words_in_fetch_block != 1:
            raise ValueError("ICacheBypass only supports fetch block size equal to the word size.")

    def elaborate(self, platform):
        m = TModule()

        # NOTE(review): req_addr is written on every request but never read
        # anywhere in the visible code — it looks like dead state left over
        # from a word-select path; confirm against callers before removing.
        req_addr = Signal(self.params.addr_width)

        @def_method(m, self.issue_req)
        def _(paddr: Value) -> None:
            m.d.sync += req_addr.eq(paddr)
            # Convert the byte address to a bus-word address and enable
            # all byte lanes of the bus.
            self.bus_master.request_read(
                m,
                addr=paddr >> exact_log2(self.params.word_width_bytes),
                sel=C(1).replicate(self.bus_master.params.data_width // self.bus_master.params.granularity),
            )

        @def_method(m, self.accept_res)
        def _():
            # Forward the bus response unchanged; `err` maps to the cache
            # `error` bit.
            res = self.bus_master.get_read_response(m)
            return {
                "fetch_block": res.data,
                "error": res.err,
            }

        # Flushing is a no-op: there is no cached state to invalidate.
        @def_method(m, self.flush)
        def _() -> None:
            pass

        return m
class ICache(Elaboratable, CacheInterface):
    """A simple set-associative instruction cache.

    The replacement policy is a pseudo random scheme. Every time a line is trashed,
    we select the next way we write to (we keep one global counter for selecting the next way).

    Refilling a cache line is abstracted away from this module. ICache module needs two methods
    from the refiller: `refiller_start`, which is called whenever we need to refill a cache line,
    and `refiller_accept`, which should be ready to be called whenever the refiller has another
    fetch block ready to be written to cache. `refiller_accept` should set `last` bit when either
    an error occurs or the transfer is over. After issuing `last` bit, `refiller_accept` shouldn't
    be ready until the next transfer is started.
    """

    def __init__(self, layouts: ICacheLayouts, params: ICacheParameters, refiller: CacheRefillerInterface) -> None:
        """
        Parameters
        ----------
        layouts : ICacheLayouts
            Instance of ICacheLayouts used to create cache methods.
        params : ICacheParameters
            Instance of ICacheParameters with parameters which should be used to generate
            the cache.
        refiller : CacheRefillerInterface
            Refiller providing `start_refill` (input layout ICacheLayouts.start_refill)
            and `accept_refill` (output layout ICacheLayouts.accept_refill).
        """
        self.layouts = layouts
        self.params = params

        self.refiller = refiller

        self.issue_req = Method(i=layouts.issue_req)
        self.accept_res = Method(o=layouts.accept_res)
        self.flush = Method()
        # A flush must not race with a request issued in the same cycle;
        # when both could run, the flush wins.
        self.flush.add_conflict(self.issue_req, Priority.LEFT)

        # Physical addresses decompose (LSB first) into offset | index | tag.
        self.addr_layout = make_layout(
            ("offset", self.params.offset_bits),
            ("index", self.params.index_bits),
            ("tag", self.params.tag_bits),
        )

        # Performance counters; sampled in elaborate().
        self.perf_loads = HwCounter("frontend.icache.loads", "Number of requests to the L1 Instruction Cache")
        self.perf_hits = HwCounter("frontend.icache.hits")
        self.perf_misses = HwCounter("frontend.icache.misses")
        self.perf_errors = HwCounter("frontend.icache.fetch_errors")
        self.perf_flushes = HwCounter("frontend.icache.flushes")
        # slots_number=2 matches the at-most-two requests in flight
        # (one registered read + one zipped response) — TODO confirm.
        self.req_latency = FIFOLatencyMeasurer(
            "frontend.icache.req_latency", "Latencies of cache requests", slots_number=2, max_latency=500
        )

    def deserialize_addr(self, raw_addr: Value) -> dict[str, Value]:
        """Split a flat physical address into its offset/index/tag fields."""
        return {
            "offset": raw_addr[: self.params.offset_bits],
            "index": raw_addr[self.params.index_start_bit : self.params.index_end_bit + 1],
            "tag": raw_addr[-self.params.tag_bits :],
        }

    def serialize_addr(self, addr: View) -> Value:
        """Reassemble an offset/index/tag view back into a flat address."""
        return Cat(addr.offset, addr.index, addr.tag)

    def elaborate(self, platform):
        m = TModule()

        m.submodules += [
            self.perf_loads,
            self.perf_hits,
            self.perf_misses,
            self.perf_errors,
            self.perf_flushes,
            self.req_latency,
        ]

        m.submodules.mem = self.mem = ICacheMemory(self.params)
        # The zipper pairs each issued request (its decoded address) with the
        # response produced for it, so accept_res returns results in order.
        m.submodules.req_zipper = req_zipper = ArgumentsToResultsZipper(self.addr_layout, self.layouts.accept_res)

        # State machine logic
        needs_refill = Signal()
        refill_finish = Signal()
        refill_error = Signal()
        # Latched refill error, reported to the requester on the next lookup.
        refill_error_saved = Signal()

        flush_start = Signal()
        flush_finish = Signal()

        with Transaction().body(m):
            self.perf_flushes.incr(m, enable_call=flush_finish)

        # Start in FLUSH so that all tag lines are invalidated after reset.
        with m.FSM(init="FLUSH") as fsm:
            with m.State("FLUSH"):
                with m.If(flush_finish):
                    m.next = "LOOKUP"

            with m.State("LOOKUP"):
                with m.If(needs_refill):
                    m.next = "REFILL"
                with m.Elif(flush_start):
                    m.next = "FLUSH"

            with m.State("REFILL"):
                with m.If(refill_finish):
                    m.next = "LOOKUP"

        # Replacement policy
        # One-hot victim-way selector, rotated after every completed refill.
        way_selector = Signal(self.params.num_of_ways, init=1)
        with m.If(refill_finish):
            m.d.sync += way_selector.eq(way_selector.rotate_left(1))

        # Fast path - read requests
        # By default the memory read address holds the previously registered
        # request address, so the read ports keep presenting the same set
        # while a response is pending (e.g. during a refill).
        mem_read_addr = Signal(self.addr_layout)
        prev_mem_read_addr = Signal(self.addr_layout)
        m.d.comb += assign(mem_read_addr, prev_mem_read_addr)

        mem_read_output_valid = Signal()
        forwarding_response_now = Signal()
        # A new request can be accepted either when no read is pending or when
        # the pending read's response is being forwarded this very cycle.
        accepting_requests = ~mem_read_output_valid | forwarding_response_now

        # Ready also when a refill error was latched: the error must be
        # reported even though the memory output is not valid.
        with Transaction(name="MemRead").body(
            m, ready=fsm.ongoing("LOOKUP") & (mem_read_output_valid | refill_error_saved)
        ):
            req_addr = req_zipper.peek_arg(m)

            # Compare the request tag against every way in the selected set.
            tag_hit = [tag_data.valid & (tag_data.tag == req_addr.tag) for tag_data in self.mem.tag_rd_data]
            tag_hit_any = reduce(operator.or_, tag_hit)

            with m.If(tag_hit_any | refill_error_saved):
                m.d.comb += forwarding_response_now.eq(1)
                self.perf_hits.incr(m, enable_call=tag_hit_any)
                # Mux the data of the hitting way (tag_hit is one-hot).
                mem_out = Signal(self.params.fetch_block_bytes * 8)
                for i in OneHotSwitchDynamic(m, Cat(tag_hit)):
                    m.d.av_comb += mem_out.eq(self.mem.data_rd_data[i])

                req_zipper.write_results(m, fetch_block=mem_out, error=refill_error_saved)
                # The latched error is consumed together with the response.
                m.d.sync += refill_error_saved.eq(0)
                m.d.sync += mem_read_output_valid.eq(0)
            with m.Else():
                self.perf_misses.incr(m)

                # Miss: request a refill; the FSM moves to REFILL.
                m.d.comb += needs_refill.eq(1)

                # Align to the beginning of the cache line
                aligned_addr = self.serialize_addr(req_addr) & ~((1 << self.params.offset_bits) - 1)
                log.debug(m, True, "Refilling line 0x{:x}", aligned_addr)
                self.refiller.start_refill(m, paddr=aligned_addr)

        @def_method(m, self.accept_res)
        def _():
            self.req_latency.stop(m)

            output = req_zipper.read(m)
            return output.results

        @def_method(m, self.issue_req, ready=accepting_requests)
        def _(paddr: Value) -> None:
            self.perf_loads.incr(m)
            self.req_latency.start(m)

            deserialized = self.deserialize_addr(paddr)
            # Drive the memories with the new address this cycle and register
            # it so it is still presented on following cycles.
            m.d.comb += assign(mem_read_addr, deserialized)
            m.d.sync += assign(prev_mem_read_addr, deserialized)
            req_zipper.write_args(m, deserialized)

            m.d.sync += mem_read_output_valid.eq(1)

        m.d.comb += [
            self.mem.tag_rd_index.eq(mem_read_addr.index),
            self.mem.data_rd_addr.index.eq(mem_read_addr.index),
            self.mem.data_rd_addr.offset.eq(mem_read_addr.offset),
        ]

        # Flush logic
        # While in FLUSH, walk through every set, invalidating its tags.
        flush_index = Signal(self.params.index_bits)
        with m.If(fsm.ongoing("FLUSH")):
            m.d.sync += flush_index.eq(flush_index + 1)

        @def_method(m, self.flush, ready=accepting_requests)
        def _() -> None:
            log.info(m, True, "Flushing the cache...")
            m.d.sync += flush_index.eq(0)
            m.d.comb += flush_start.eq(1)

        m.d.comb += flush_finish.eq(flush_index == self.params.num_of_sets - 1)

        # Slow path - data refilling
        with Transaction().body(m):
            ret = self.refiller.accept_refill(m)
            deserialized = self.deserialize_addr(ret.paddr)

            self.perf_errors.incr(m, enable_call=ret.error)

            m.d.top_comb += [
                self.mem.data_wr_addr.index.eq(deserialized["index"]),
                self.mem.data_wr_addr.offset.eq(deserialized["offset"]),
                self.mem.data_wr_data.eq(ret.fetch_block),
            ]

            m.d.comb += self.mem.data_wr_en.eq(1)
            m.d.comb += refill_finish.eq(ret.last)
            m.d.comb += refill_error.eq(ret.error)
            with m.If(ret.error):
                # Remember the error so the MemRead transaction can report it.
                m.d.sync += refill_error_saved.eq(1)

        with m.If(fsm.ongoing("FLUSH")):
            # Invalidate the tags of all ways in the set picked by flush_index.
            m.d.comb += [
                self.mem.way_wr_en.eq(C(1).replicate(self.params.num_of_ways)),
                self.mem.tag_wr_index.eq(flush_index),
                self.mem.tag_wr_data.valid.eq(0),
                self.mem.tag_wr_data.tag.eq(0),
                self.mem.tag_wr_en.eq(1),
            ]
        with m.Else():
            # On refill completion, write the tag of the victim way. The tag
            # address comes from mem_read_addr, which holds the registered
            # address of the request that triggered the refill. A line that
            # finished with an error is written as invalid.
            m.d.comb += [
                self.mem.way_wr_en.eq(way_selector),
                self.mem.tag_wr_index.eq(mem_read_addr.index),
                self.mem.tag_wr_data.valid.eq(~refill_error),
                self.mem.tag_wr_data.tag.eq(mem_read_addr.tag),
                self.mem.tag_wr_en.eq(refill_finish),
            ]

        return m


class ICacheMemory(Elaboratable):
    """A helper module for managing memories used in the instruction cache.

    In case of an associative cache, all address and write data lines are shared.
    Writes are multiplexed using one-hot `way_wr_en` signal. Read data lines from all
    ways are separately exposed (as an array).

    The data memory is addressed using fetch blocks.
    """

    def __init__(self, params: ICacheParameters) -> None:
        self.params = params

        self.tag_data_layout = make_layout(("valid", 1), ("tag", self.params.tag_bits))

        # One-hot write enable selecting which way is written.
        self.way_wr_en = Signal(self.params.num_of_ways)

        # Tag memory interface: shared read/write index, per-way read data.
        self.tag_rd_index = Signal(self.params.index_bits)
        self.tag_rd_data = Array([Signal(self.tag_data_layout) for _ in range(self.params.num_of_ways)])
        self.tag_wr_index = Signal(self.params.index_bits)
        self.tag_wr_en = Signal()
        self.tag_wr_data = Signal(self.tag_data_layout)

        self.data_addr_layout = make_layout(("index", self.params.index_bits), ("offset", self.params.offset_bits))

        self.fetch_block_bits = params.fetch_block_bytes * 8

        # Data memory interface: shared addresses, per-way read data.
        self.data_rd_addr = Signal(self.data_addr_layout)
        self.data_rd_data = Array([Signal(self.fetch_block_bits) for _ in range(self.params.num_of_ways)])
        self.data_wr_addr = Signal(self.data_addr_layout)
        self.data_wr_en = Signal()
        self.data_wr_data = Signal(self.fetch_block_bits)

    def elaborate(self, platform):
        m = TModule()

        # Instantiate one tag memory and one data memory per way; read ports
        # are transparent with respect to their write ports, so a write is
        # visible to a simultaneous read of the same address.
        for i in range(self.params.num_of_ways):
            way_wr = self.way_wr_en[i]

            tag_mem = memory.Memory(shape=self.tag_data_layout, depth=self.params.num_of_sets, init=[])
            tag_mem_wp = tag_mem.write_port()
            tag_mem_rp = tag_mem.read_port(transparent_for=[tag_mem_wp])
            m.submodules[f"tag_mem_{i}"] = tag_mem

            m.d.comb += [
                assign(self.tag_rd_data[i], tag_mem_rp.data),
                tag_mem_rp.addr.eq(self.tag_rd_index),
                tag_mem_wp.addr.eq(self.tag_wr_index),
                assign(tag_mem_wp.data, self.tag_wr_data),
                tag_mem_wp.en.eq(self.tag_wr_en & way_wr),
            ]

            data_mem = memory.Memory(
                shape=self.fetch_block_bits, depth=self.params.num_of_sets * self.params.fetch_blocks_in_line, init=[]
            )
            data_mem_wp = data_mem.write_port()
            data_mem_rp = data_mem.read_port(transparent_for=[data_mem_wp])
            m.submodules[f"data_mem_{i}"] = data_mem

            # We address the data RAM using fetch blocks, so we have to
            # discard a few least significant bits from the address.
            rd_addr = Cat(self.data_rd_addr.offset, self.data_rd_addr.index)[self.params.fetch_block_bytes_log :]
            wr_addr = Cat(self.data_wr_addr.offset, self.data_wr_addr.index)[self.params.fetch_block_bytes_log :]

            m.d.comb += [
                self.data_rd_data[i].eq(data_mem_rp.data),
                data_mem_rp.addr.eq(rd_addr),
                data_mem_wp.addr.eq(wr_addr),
                data_mem_wp.data.eq(self.data_wr_data),
                data_mem_wp.en.eq(self.data_wr_en & way_wr),
            ]

        return m