rs.py
  1  from abc import abstractmethod
  2  from collections.abc import Iterable
  3  from typing import Optional
  4  from amaranth import *
  5  from amaranth.lib.data import ArrayLayout
  6  from amaranth.utils import ceil_log2
  7  from amaranth_types import ValueLike
  8  from transactron import Method, Methods, Transaction, def_method, TModule, def_methods
  9  from transactron.lib import logging
 10  from transactron.lib.allocators import PreservedOrderAllocator
 11  from transactron.utils.amaranth_ext.elaboratables import OneHotMux
 12  from coreblocks.params import GenParams
 13  from coreblocks.arch import OpType
 14  from coreblocks.interface.layouts import RSLayouts
 15  from transactron.lib.metrics import HwExpHistogram, TaggedLatencyMeasurer
 16  from transactron.utils import RecordDict
 17  from transactron.utils.assign import assign, AssignType
 18  from transactron.utils.amaranth_ext.functions import popcount
 19  from transactron.utils.transactron_helpers import make_layout
 20  
# Names exported by a star-import of this module.
__all__ = ["RSBase", "RS"]
 22  
 23  
 24  class RSBase(Elaboratable):
 25      def __init__(
 26          self,
 27          gen_params: GenParams,
 28          rs_entries: int,
 29          rs_number: int,
 30          rs_ways: int = 1,
 31          ready_for: Optional[Iterable[Iterable[OpType]]] = None,
 32      ) -> None:
 33          ready_for = ready_for or ((op for op in OpType),)
 34          self.gen_params = gen_params
 35          self.rs_entries = rs_entries
 36          self.rs_ways = rs_ways
 37          self.layouts = gen_params.get(RSLayouts, rs_entries=self.rs_entries)
 38          self.internal_layout = make_layout(
 39              ("rs_data", self.layouts.rs.data_layout),
 40              ("rec_full", 1),
 41          )
 42  
 43          self.insert = Method(i=self.layouts.rs.insert_in)
 44          self.select = Method(o=self.layouts.rs.select_out)
 45          self.update = Methods(rs_ways, i=self.layouts.rs.update_in)
 46          self.take = Method(i=self.layouts.take_in, o=self.layouts.take_out)
 47  
 48          self.ready_for = [list(op_list) for op_list in ready_for]
 49          self.get_ready_list = [Method(o=self.layouts.get_ready_list_out) for _ in self.ready_for]
 50  
 51          self.data = Signal(ArrayLayout(self.internal_layout, self.rs_entries))
 52          self.data_ready = Signal(self.rs_entries)
 53  
 54          self.perf_rs_wait_time = TaggedLatencyMeasurer(
 55              f"fu.block_{rs_number}.rs.valid_time",
 56              description=f"Distribution of time instructions wait in RS {rs_number}",
 57              slots_number=self.rs_entries,
 58              max_latency=1000,
 59          )
 60          self.perf_num_full = HwExpHistogram(
 61              f"fu.block_{rs_number}.rs.num_full",
 62              description=f"Number of full entries in RS {rs_number}",
 63              bucket_count=ceil_log2(self.rs_entries + 1),
 64              sample_width=ceil_log2(self.rs_entries + 1),
 65          )
 66          self.log = logging.HardwareLogger(f"backend.rs.{rs_number}")
 67  
 68      @abstractmethod
 69      def elaborate(self, platform) -> TModule:
 70          raise NotImplementedError
 71  
 72      def _elaborate(self, m: TModule, takeable_mask: ValueLike, alloc: Method, free_idx: Method, order: Method):
 73          # The role of _elaborate is to accomodate FifoRS, which is currently
 74          # used in the LSU. This is a stop-gap - if one day a real LSU
 75          # is implemented, FifoRS should be removed.
 76  
 77          # The alloc, free_idx, order parameters follow the interface of
 78          # Transactron's PreservedOrderAllocator.
 79          # The takeable_mask parameter is a bitmask which marks which rows can
 80          # be taken. For a normal RS, it should contain all ones. For FifoRS,
 81          # only one row is takeable at a given moment.
 82  
 83          m.submodules += [self.perf_rs_wait_time, self.perf_num_full]
 84  
 85          for i, record in enumerate(iter(self.data)):
 86              m.d.comb += self.data_ready[i].eq(
 87                  ~record.rs_data.rp_s1.bool() & ~record.rs_data.rp_s2.bool() & record.rec_full
 88              )
 89  
 90          ready_lists: list[Value] = []
 91          for op_list in self.ready_for:
 92              op_vector = Cat(Cat(record.rs_data.exec_fn.op_type == op for op in op_list).any() for record in self.data)
 93              ready_lists.append(self.data_ready & op_vector)
 94  
 95          @def_method(m, self.select)
 96          def _() -> RecordDict:
 97              selected_id = alloc(m).ident
 98              self.log.debug(m, True, "selected entry {}", selected_id)
 99              return {"rs_entry_id": selected_id}
100  
101          matches_s1 = Signal(ArrayLayout(len(self.update), self.rs_entries))
102          matches_s2 = Signal(ArrayLayout(len(self.update), self.rs_entries))
103  
104          @def_methods(m, self.update)
105          def _(k: int, reg_id: Value, reg_val: Value) -> None:
106              for i, record in enumerate(iter(self.data)):
107                  m.d.comb += matches_s1[i][k].eq(record.rs_data.rp_s1 == reg_id)
108                  m.d.comb += matches_s2[i][k].eq(record.rs_data.rp_s2 == reg_id)
109  
110          # It is assumed that two simultaneous update calls never update the same physical register.
111          for k1, u1 in enumerate(self.update):
112              for k2, u2 in enumerate(self.update[k1 + 1 :]):
113                  self.log.error(
114                      m,
115                      u1.run & u2.run & (u1.data_in.reg_id == u2.data_in.reg_id),
116                      "Update methods {} and {} both called with reg_id {}",
117                      k1,
118                      k2,
119                      u1.data_in.reg_id,
120                  )
121  
122          for i, record in enumerate(iter(self.data)):
123              with m.If(matches_s1[i].any()):
124                  m.d.sync += record.rs_data.rp_s1.eq(0)
125                  m.d.sync += record.rs_data.s1_val.eq(
126                      OneHotMux.create(
127                          m,
128                          [(matches_s1[i][k], self.update[k].data_in.reg_val) for k in range(self.rs_ways)],
129                          C(0, self.gen_params.isa.xlen),
130                      )
131                  )
132  
133              with m.If(matches_s2[i].any()):
134                  m.d.sync += record.rs_data.rp_s2.eq(0)
135                  m.d.sync += record.rs_data.s2_val.eq(
136                      OneHotMux.create(
137                          m,
138                          [(matches_s2[i][k], self.update[k].data_in.reg_val) for k in range(self.rs_ways)],
139                          C(0, self.gen_params.isa.xlen),
140                      )
141                  )
142  
143          @def_method(m, self.insert)
144          def _(rs_entry_id: Value, rs_data: Value) -> None:
145              m.d.sync += self.data[rs_entry_id].rs_data.eq(rs_data)
146              m.d.sync += self.data[rs_entry_id].rec_full.eq(1)
147              self.perf_rs_wait_time.start(m, slot=rs_entry_id)
148              self.log.debug(m, True, "inserted entry {}", rs_entry_id)
149  
150          with Transaction().body(m):
151              self.order = order(m).order  # always ready!
152  
153          @def_method(m, self.take)
154          def _(rs_entry_id: Value) -> RecordDict:
155              actual_rs_entry_id = Signal.like(rs_entry_id)
156              m.d.av_comb += actual_rs_entry_id.eq(self.order[rs_entry_id])
157              record = self.data[actual_rs_entry_id]
158              free_idx(m, idx=rs_entry_id)
159              m.d.sync += record.rec_full.eq(0)
160              self.perf_rs_wait_time.stop(m, slot=actual_rs_entry_id)
161              out = Signal(self.layouts.take_out)
162              m.d.av_comb += assign(out, record.rs_data, fields=AssignType.COMMON)
163              self.log.debug(m, True, "taken entry {} at idx {}", actual_rs_entry_id, rs_entry_id)
164              return out
165  
166          for get_ready_list, ready_list in zip(self.get_ready_list, ready_lists):
167              tk_ready_list = ready_list & takeable_mask
168              reordered_list = Cat(tk_ready_list.bit_select(self.order[i], 1) for i in range(self.rs_entries))
169  
170              @def_method(m, get_ready_list, ready=tk_ready_list.any(), nonexclusive=True)
171              def _() -> RecordDict:
172                  return {"ready_list": reordered_list}
173  
174          if self.perf_num_full.metrics_enabled():
175              num_full = Signal(range(self.rs_entries + 1))
176              m.d.comb += num_full.eq(popcount(Cat(self.data[entry_id].rec_full for entry_id in range(self.rs_entries))))
177              with Transaction(name="perf").body(m):
178                  self.perf_num_full.add(m, num_full)
179  
180  
class RS(RSBase):
    """RS variant backed by a ``PreservedOrderAllocator`` with every row takeable."""

    def elaborate(self, platform):
        """Create the module, register the allocator and build the shared RS logic."""
        m = TModule()

        allocator = PreservedOrderAllocator(self.rs_entries)
        m.submodules.allocator = allocator

        # -1 is an all-ones mask, i.e. no row is excluded from being taken.
        self._elaborate(
            m,
            takeable_mask=-1,
            alloc=allocator.alloc,
            free_idx=allocator.free_idx,
            order=allocator.order,
        )

        return m
189          return m