# coreblocks/cache/icache.py
  1  from functools import reduce
  2  import operator
  3  
  4  from amaranth import *
  5  from amaranth.lib.data import View
  6  import amaranth.lib.memory as memory
  7  from amaranth.utils import exact_log2
  8  
  9  from transactron.core import def_method, Priority, TModule
 10  from transactron import Method, Transaction
 11  from coreblocks.params import ICacheParameters
 12  from coreblocks.interface.layouts import ICacheLayouts
 13  from transactron.utils import assign, OneHotSwitchDynamic
 14  from transactron.lib import *
 15  from transactron.lib import logging
 16  from coreblocks.peripherals.bus_adapter import BusMasterInterface
 17  
 18  from coreblocks.cache.iface import CacheInterface, CacheRefillerInterface
 19  from transactron.utils.transactron_helpers import make_layout
 20  
 21  __all__ = [
 22      "ICache",
 23      "ICacheBypass",
 24  ]
 25  
 26  log = logging.HardwareLogger("frontend.icache")
 27  
 28  
class ICacheBypass(Elaboratable, CacheInterface):
    """Pass-through implementation of the instruction cache interface.

    Instead of caching, every `issue_req` call is translated directly into a
    single bus read and `accept_res` returns the raw bus response. Only
    configurations where the fetch block is exactly one bus word are supported.
    """

    def __init__(self, layouts: ICacheLayouts, params: ICacheParameters, bus_master: BusMasterInterface) -> None:
        """
        Parameters
        ----------
        layouts : ICacheLayouts
            Layouts used to create the cache methods.
        params : ICacheParameters
            Cache parameters (address width, word width, fetch block size).
        bus_master : BusMasterInterface
            The bus master used to perform instruction fetches.
        """
        self.params = params
        self.bus_master = bus_master

        self.issue_req = Method(i=layouts.issue_req)
        self.accept_res = Method(o=layouts.accept_res)
        self.flush = Method()

        # One bus read yields exactly one word, so the fetch block must be one word wide.
        if params.words_in_fetch_block != 1:
            raise ValueError("ICacheBypass only supports fetch block size equal to the word size.")

    def elaborate(self, platform):
        m = TModule()

        # Physical address of the most recently issued request.
        # NOTE(review): written on every request but never read in this module --
        # looks like a leftover; confirm before removing.
        req_addr = Signal(self.params.addr_width)

        @def_method(m, self.issue_req)
        def _(paddr: Value) -> None:
            m.d.sync += req_addr.eq(paddr)
            # Convert the byte address to a word address and request a
            # full-width read (all byte-select lanes enabled).
            self.bus_master.request_read(
                m,
                addr=paddr >> exact_log2(self.params.word_width_bytes),
                sel=C(1).replicate(self.bus_master.params.data_width // self.bus_master.params.granularity),
            )

        @def_method(m, self.accept_res)
        def _():
            # Forward the bus response unchanged; bus `err` maps to the `error` field.
            res = self.bus_master.get_read_response(m)
            return {
                "fetch_block": res.data,
                "error": res.err,
            }

        @def_method(m, self.flush)
        def _() -> None:
            # Nothing to flush -- this module holds no cached state.
            pass

        return m
 68  
 69  
class ICache(Elaboratable, CacheInterface):
    """A simple set-associative instruction cache.

    The replacement policy is a pseudo random scheme. Every time a line is trashed,
    we select the next way we write to (we keep one global counter for selecting the next way).

    Refilling a cache line is abstracted away from this module. ICache module needs two methods
    from the refiller `refiller_start`, which is called whenever we need to refill a cache line.
    `refiller_accept` should be ready to be called whenever the refiller has another fetch block
    ready to be written to cache. `refiller_accept` should set `last` bit when either an error
    occurs or the transfer is over. After issuing `last` bit, `refiller_accept` shouldn't be ready
    until the next transfer is started.
    """

    def __init__(self, layouts: ICacheLayouts, params: ICacheParameters, refiller: CacheRefillerInterface) -> None:
        """
        Parameters
        ----------
        layouts : ICacheLayouts
            Instance of ICacheLayouts used to create cache methods.
        params : ICacheParameters
            Instance of ICacheParameters with parameters which should be used to generate
            the cache.
        refiller : CacheRefillerInterface
            The refiller used to fetch cache lines from memory. Its `start_refill`
            method is called with input layout ICacheLayouts.start_refill, and its
            `accept_refill` method returns output layout ICacheLayouts.accept_refill.
        """
        self.layouts = layouts
        self.params = params

        self.refiller = refiller

        self.issue_req = Method(i=layouts.issue_req)
        self.accept_res = Method(o=layouts.accept_res)
        self.flush = Method()
        # A flush must not race with an incoming request in the same cycle;
        # when both are ready, the flush wins.
        self.flush.add_conflict(self.issue_req, Priority.LEFT)

        # Physical address split into cache-relevant fields (LSB first).
        self.addr_layout = make_layout(
            ("offset", self.params.offset_bits),
            ("index", self.params.index_bits),
            ("tag", self.params.tag_bits),
        )

        # Performance counters and request-latency measurement.
        self.perf_loads = HwCounter("frontend.icache.loads", "Number of requests to the L1 Instruction Cache")
        self.perf_hits = HwCounter("frontend.icache.hits")
        self.perf_misses = HwCounter("frontend.icache.misses")
        self.perf_errors = HwCounter("frontend.icache.fetch_errors")
        self.perf_flushes = HwCounter("frontend.icache.flushes")
        self.req_latency = FIFOLatencyMeasurer(
            "frontend.icache.req_latency", "Latencies of cache requests", slots_number=2, max_latency=500
        )

    def deserialize_addr(self, raw_addr: Value) -> dict[str, Value]:
        """Split a raw physical address into its offset/index/tag fields."""
        return {
            "offset": raw_addr[: self.params.offset_bits],
            "index": raw_addr[self.params.index_start_bit : self.params.index_end_bit + 1],
            "tag": raw_addr[-self.params.tag_bits :],
        }

    def serialize_addr(self, addr: View) -> Value:
        """Reassemble an `addr_layout` view back into a flat physical address."""
        return Cat(addr.offset, addr.index, addr.tag)

    def elaborate(self, platform):
        m = TModule()

        m.submodules += [
            self.perf_loads,
            self.perf_hits,
            self.perf_misses,
            self.perf_errors,
            self.perf_flushes,
            self.req_latency,
        ]

        m.submodules.mem = self.mem = ICacheMemory(self.params)
        # Pairs each issued request address (argument) with its eventual
        # response (result), so accept_res returns them in order.
        m.submodules.req_zipper = req_zipper = ArgumentsToResultsZipper(self.addr_layout, self.layouts.accept_res)

        # State machine logic
        needs_refill = Signal()
        refill_finish = Signal()
        refill_error = Signal()
        # Sticky error flag: a refill error is remembered until the response
        # for the failed request has been forwarded.
        refill_error_saved = Signal()

        flush_start = Signal()
        flush_finish = Signal()

        with Transaction().body(m):
            self.perf_flushes.incr(m, enable_call=flush_finish)

        # Starts in FLUSH so the tag memory is invalidated before first use.
        with m.FSM(init="FLUSH") as fsm:
            with m.State("FLUSH"):
                with m.If(flush_finish):
                    m.next = "LOOKUP"

            with m.State("LOOKUP"):
                with m.If(needs_refill):
                    m.next = "REFILL"
                with m.Elif(flush_start):
                    m.next = "FLUSH"

            with m.State("REFILL"):
                with m.If(refill_finish):
                    m.next = "LOOKUP"

        # Replacement policy
        # One-hot pointer to the way that the next refill will overwrite,
        # rotated after every completed refill (pseudo-random global policy).
        way_selector = Signal(self.params.num_of_ways, init=1)
        with m.If(refill_finish):
            m.d.sync += way_selector.eq(way_selector.rotate_left(1))

        # Fast path - read requests
        # The memory read address defaults to the previously latched request
        # address; issue_req overrides it combinationally for back-to-back reads.
        mem_read_addr = Signal(self.addr_layout)
        prev_mem_read_addr = Signal(self.addr_layout)
        m.d.comb += assign(mem_read_addr, prev_mem_read_addr)

        mem_read_output_valid = Signal()
        forwarding_response_now = Signal()
        # A new request can be taken if no lookup is pending, or the pending
        # one is being answered in this very cycle.
        accepting_requests = ~mem_read_output_valid | forwarding_response_now

        with Transaction(name="MemRead").body(
            m, ready=fsm.ongoing("LOOKUP") & (mem_read_output_valid | refill_error_saved)
        ):
            req_addr = req_zipper.peek_arg(m)

            # Compare the request tag against every way; a hit requires both
            # a valid line and a matching tag.
            tag_hit = [tag_data.valid & (tag_data.tag == req_addr.tag) for tag_data in self.mem.tag_rd_data]
            tag_hit_any = reduce(operator.or_, tag_hit)

            with m.If(tag_hit_any | refill_error_saved):
                m.d.comb += forwarding_response_now.eq(1)
                self.perf_hits.incr(m, enable_call=tag_hit_any)
                mem_out = Signal(self.params.fetch_block_bytes * 8)
                # Select the data of the (unique) hitting way.
                for i in OneHotSwitchDynamic(m, Cat(tag_hit)):
                    m.d.av_comb += mem_out.eq(self.mem.data_rd_data[i])

                req_zipper.write_results(m, fetch_block=mem_out, error=refill_error_saved)
                m.d.sync += refill_error_saved.eq(0)
                m.d.sync += mem_read_output_valid.eq(0)
            with m.Else():
                # Miss: switch the FSM to REFILL and kick off the refiller.
                self.perf_misses.incr(m)

                m.d.comb += needs_refill.eq(1)

                # Align to the beginning of the cache line
                aligned_addr = self.serialize_addr(req_addr) & ~((1 << self.params.offset_bits) - 1)
                log.debug(m, True, "Refilling line 0x{:x}", aligned_addr)
                self.refiller.start_refill(m, paddr=aligned_addr)

        @def_method(m, self.accept_res)
        def _():
            self.req_latency.stop(m)

            output = req_zipper.read(m)
            return output.results

        @def_method(m, self.issue_req, ready=accepting_requests)
        def _(paddr: Value) -> None:
            self.perf_loads.incr(m)
            self.req_latency.start(m)

            deserialized = self.deserialize_addr(paddr)
            # Drive the memory read this cycle and remember the address so the
            # read stays stable while the lookup result is pending.
            m.d.comb += assign(mem_read_addr, deserialized)
            m.d.sync += assign(prev_mem_read_addr, deserialized)
            req_zipper.write_args(m, deserialized)

            m.d.sync += mem_read_output_valid.eq(1)

        m.d.comb += [
            self.mem.tag_rd_index.eq(mem_read_addr.index),
            self.mem.data_rd_addr.index.eq(mem_read_addr.index),
            self.mem.data_rd_addr.offset.eq(mem_read_addr.offset),
        ]

        # Flush logic
        # Walk all sets, one per cycle, invalidating every way of each set.
        flush_index = Signal(self.params.index_bits)
        with m.If(fsm.ongoing("FLUSH")):
            m.d.sync += flush_index.eq(flush_index + 1)

        @def_method(m, self.flush, ready=accepting_requests)
        def _() -> None:
            log.info(m, True, "Flushing the cache...")
            m.d.sync += flush_index.eq(0)
            m.d.comb += flush_start.eq(1)

        m.d.comb += flush_finish.eq(flush_index == self.params.num_of_sets - 1)

        # Slow path - data refilling
        with Transaction().body(m):
            ret = self.refiller.accept_refill(m)
            deserialized = self.deserialize_addr(ret.paddr)

            self.perf_errors.incr(m, enable_call=ret.error)

            m.d.top_comb += [
                self.mem.data_wr_addr.index.eq(deserialized["index"]),
                self.mem.data_wr_addr.offset.eq(deserialized["offset"]),
                self.mem.data_wr_data.eq(ret.fetch_block),
            ]

            m.d.comb += self.mem.data_wr_en.eq(1)
            m.d.comb += refill_finish.eq(ret.last)
            m.d.comb += refill_error.eq(ret.error)
            with m.If(ret.error):
                m.d.sync += refill_error_saved.eq(1)

        # Tag memory writes: during FLUSH every way of the current set is
        # invalidated; otherwise the selected way's tag is written when a
        # refill completes (marked invalid if the refill errored).
        with m.If(fsm.ongoing("FLUSH")):
            m.d.comb += [
                self.mem.way_wr_en.eq(C(1).replicate(self.params.num_of_ways)),
                self.mem.tag_wr_index.eq(flush_index),
                self.mem.tag_wr_data.valid.eq(0),
                self.mem.tag_wr_data.tag.eq(0),
                self.mem.tag_wr_en.eq(1),
            ]
        with m.Else():
            m.d.comb += [
                self.mem.way_wr_en.eq(way_selector),
                self.mem.tag_wr_index.eq(mem_read_addr.index),
                self.mem.tag_wr_data.valid.eq(~refill_error),
                self.mem.tag_wr_data.tag.eq(mem_read_addr.tag),
                self.mem.tag_wr_en.eq(refill_finish),
            ]

        return m
292  
293  
class ICacheMemory(Elaboratable):
    """Banked memory arrays backing the instruction cache.

    For an associative cache the address and write-data lines are common to
    all ways; the one-hot `way_wr_en` signal decides which way a write lands
    in. Read data of every way is exposed separately through the
    `tag_rd_data` / `data_rd_data` arrays.

    The data memory is addressed at fetch-block granularity.
    """

    def __init__(self, params: ICacheParameters) -> None:
        self.params = params

        # Tag entry: a valid bit plus the stored tag.
        self.tag_data_layout = make_layout(("valid", 1), ("tag", self.params.tag_bits))

        # One-hot write enable selecting the target way.
        self.way_wr_en = Signal(self.params.num_of_ways)

        # Tag memory ports (read index, per-way read data, shared write port).
        self.tag_rd_index = Signal(self.params.index_bits)
        self.tag_rd_data = Array([Signal(self.tag_data_layout) for _ in range(self.params.num_of_ways)])
        self.tag_wr_index = Signal(self.params.index_bits)
        self.tag_wr_en = Signal()
        self.tag_wr_data = Signal(self.tag_data_layout)

        # Data memory is addressed by set index + byte offset within the line.
        self.data_addr_layout = make_layout(("index", self.params.index_bits), ("offset", self.params.offset_bits))

        self.fetch_block_bits = params.fetch_block_bytes * 8

        # Data memory ports (read address, per-way read data, shared write port).
        self.data_rd_addr = Signal(self.data_addr_layout)
        self.data_rd_data = Array([Signal(self.fetch_block_bits) for _ in range(self.params.num_of_ways)])
        self.data_wr_addr = Signal(self.data_addr_layout)
        self.data_wr_en = Signal()
        self.data_wr_data = Signal(self.fetch_block_bits)

    def elaborate(self, platform):
        m = TModule()

        # The data RAM holds whole fetch blocks, so the low bits addressing
        # bytes within a block are discarded from both addresses. These
        # expressions are way-independent, so they are built once up front.
        block_rd_addr = Cat(self.data_rd_addr.offset, self.data_rd_addr.index)[self.params.fetch_block_bytes_log :]
        block_wr_addr = Cat(self.data_wr_addr.offset, self.data_wr_addr.index)[self.params.fetch_block_bytes_log :]

        for way in range(self.params.num_of_ways):
            selected = self.way_wr_en[way]

            # Per-way tag store, one entry per set. The read port is
            # transparent w.r.t. the write port.
            tag_mem = memory.Memory(shape=self.tag_data_layout, depth=self.params.num_of_sets, init=[])
            tag_wp = tag_mem.write_port()
            tag_rp = tag_mem.read_port(transparent_for=[tag_wp])
            m.submodules[f"tag_mem_{way}"] = tag_mem

            m.d.comb += [
                tag_rp.addr.eq(self.tag_rd_index),
                assign(self.tag_rd_data[way], tag_rp.data),
                tag_wp.addr.eq(self.tag_wr_index),
                assign(tag_wp.data, self.tag_wr_data),
                tag_wp.en.eq(self.tag_wr_en & selected),
            ]

            # Per-way data store, one fetch block per entry.
            data_mem = memory.Memory(
                shape=self.fetch_block_bits, depth=self.params.num_of_sets * self.params.fetch_blocks_in_line, init=[]
            )
            data_wp = data_mem.write_port()
            data_rp = data_mem.read_port(transparent_for=[data_wp])
            m.submodules[f"data_mem_{way}"] = data_mem

            m.d.comb += [
                data_rp.addr.eq(block_rd_addr),
                self.data_rd_data[way].eq(data_rp.data),
                data_wp.addr.eq(block_wr_addr),
                data_wp.data.eq(self.data_wr_data),
                data_wp.en.eq(self.data_wr_en & selected),
            ]

        return m