rf.py
  1  from amaranth import *
  2  from amaranth.lib.data import ArrayLayout
  3  from amaranth.lib import memory
  4  from transactron import Methods, Transaction, def_methods, TModule
  5  from transactron.utils.amaranth_ext.elaboratables import OneHotMux
  6  from coreblocks.interface.layouts import RFLayouts
  7  from coreblocks.params import GenParams
  8  from transactron.lib import logging
  9  from transactron.lib.metrics import HwExpHistogram, TaggedLatencyMeasurer
 10  from transactron.lib.storage import MemoryBank
 11  from transactron.utils.amaranth_ext.functions import popcount
 12  
# Names exported by this module.
__all__ = ["RegisterFile"]

# Hardware logger used by RegisterFile for its assertions and error reports.
log = logging.HardwareLogger("core_structs.rf")
 16  
 17  
 18  class RegisterFile(Elaboratable):
 19      def __init__(self, *, gen_params: GenParams, read_ports: int, write_ports: int, free_ports: int):
 20          self.gen_params = gen_params
 21  
 22          layouts = gen_params.get(RFLayouts)
 23          self.read_layout = layouts.rf_read_out
 24          self.entries = MemoryBank(
 25              memory_type=gen_params.multiport_memory_type if write_ports > 1 else memory.Memory,
 26              shape=gen_params.isa.xlen,
 27              depth=2**gen_params.phys_regs_bits,
 28              read_ports=read_ports,
 29              write_ports=write_ports,
 30              read_on_resp=True,
 31          )
 32          self.valids = Array(Signal(init=k == 0) for k in range(2**gen_params.phys_regs_bits))
 33  
 34          self.read_req = Methods(read_ports, i=layouts.rf_read_in)
 35          self.read_resp = Methods(read_ports, i=layouts.rf_read_in, o=layouts.rf_read_out)
 36          self.write = Methods(write_ports, i=layouts.rf_write)
 37          self.free = Methods(free_ports, i=layouts.rf_free)
 38  
 39          self.perf_rf_valid_time = TaggedLatencyMeasurer(
 40              "struct.rf.valid_time",
 41              description="Distribution of time registers are valid in RF",
 42              slots_number=2**gen_params.phys_regs_bits,
 43              max_latency=1000,
 44              ways=max(write_ports, free_ports),
 45          )
 46          self.perf_num_valid = HwExpHistogram(
 47              "struct.rf.num_valid",
 48              description="Number of valid registers in RF",
 49              bucket_count=gen_params.phys_regs_bits + 1,
 50              sample_width=gen_params.phys_regs_bits + 1,
 51          )
 52  
 53      def elaborate(self, platform):
 54          m = TModule()
 55  
 56          m.submodules += [self.entries, self.perf_rf_valid_time, self.perf_num_valid]
 57  
 58          being_written = Signal(ArrayLayout(self.gen_params.phys_regs_bits, len(self.write)))
 59          written_value = Signal(ArrayLayout(self.gen_params.isa.xlen, len(self.write)))
 60  
 61          @def_methods(m, self.read_req)
 62          def _(k: int, reg_id: Value):
 63              self.entries.read_req[k](m, addr=reg_id)
 64  
 65          @def_methods(m, self.read_resp)
 66          def _(k: int, reg_id: Value):
 67              forward = Signal()
 68              reg_written = Signal(len(self.write))
 69              m.d.av_comb += reg_written.eq(
 70                  Cat((being_written[i] == reg_id) & (reg_id != 0) for i in range(len(self.write)))
 71              )
 72              m.d.av_comb += forward.eq(reg_written.any())
 73              reg_val = OneHotMux.create(
 74                  m,
 75                  [(reg_written[i], written_value[i]) for i in range(len(self.write))],
 76                  self.entries.read_resp[k](m).data,
 77              )
 78              return {
 79                  "reg_val": reg_val,
 80                  "valid": forward | self.valids[reg_id],
 81              }
 82  
 83          @def_methods(m, self.write)
 84          def _(k: int, reg_id: Value, reg_val: Value):
 85              m.d.comb += being_written[k].eq(reg_id)
 86              m.d.av_comb += written_value[k].eq(reg_val)
 87              with m.If(reg_id != 0):
 88                  log.assertion(m, ~self.valids[reg_id], "Valid register {} written", reg_id)
 89                  self.entries.write[k](m, addr=reg_id, data=reg_val)
 90                  m.d.sync += self.valids[reg_id].eq(1)
 91                  self.perf_rf_valid_time.start[k](m, slot=reg_id)
 92  
 93          @def_methods(m, self.free)
 94          def _(k: int, reg_id: Value):
 95              with m.If(reg_id != 0):
 96                  log.assertion(m, self.valids[reg_id], "Invalid register {} freed", reg_id)
 97                  m.d.sync += self.valids[reg_id].eq(0)
 98                  self.perf_rf_valid_time.stop[k](m, slot=reg_id)
 99  
100          # It is assumed that two simultaneous write calls never write the same physical register.
101          for k1, m1 in enumerate(self.write):
102              for k2, m2 in enumerate(self.write[k1 + 1 :]):
103                  log.error(
104                      m,
105                      m1.run & m2.run & (m1.data_in.reg_id == m2.data_in.reg_id) & (m1.data_in.reg_id != 0),
106                      "Write methods {} and {} both called with reg_id {}",
107                      k1,
108                      k2,
109                      m1.data_in.reg_id,
110                  )
111  
112          # It is assumed that two simultaneous free calls never free the same physical register.
113          for k1, m1 in enumerate(self.free):
114              for k2, m2 in enumerate(self.free[k1 + 1 :]):
115                  log.error(
116                      m,
117                      m1.run & m2.run & (m1.data_in.reg_id == m2.data_in.reg_id) & (m1.data_in.reg_id != 0),
118                      "Free methods {} and {} both called with reg_id {}",
119                      k1,
120                      k2,
121                      m1.data_in.reg_id,
122                  )
123  
124          if self.perf_num_valid.metrics_enabled():
125              num_valid = Signal(self.gen_params.phys_regs_bits + 1)
126              m.d.comb += num_valid.eq(
127                  popcount(Cat(self.valids[reg_id] for reg_id in range(2**self.gen_params.phys_regs_bits)))
128              )
129              with Transaction(name="perf").body(m):
130                  self.perf_num_valid.add(m, num_valid)
131  
132          return m