rf.py
from amaranth import *
from amaranth.lib.data import ArrayLayout
from amaranth.lib import memory
from transactron import Methods, Transaction, def_methods, TModule
from transactron.utils.amaranth_ext.elaboratables import OneHotMux
from coreblocks.interface.layouts import RFLayouts
from coreblocks.params import GenParams
from transactron.lib import logging
from transactron.lib.metrics import HwExpHistogram, TaggedLatencyMeasurer
from transactron.lib.storage import MemoryBank
from transactron.utils.amaranth_ext.functions import popcount

__all__ = ["RegisterFile"]

log = logging.HardwareLogger("core_structs.rf")


class RegisterFile(Elaboratable):
    """Physical register file with a valid bit per register.

    Register values are stored in a `MemoryBank`. A separate `Array` of valid
    bits tracks which physical registers currently hold a written value.
    Reads are split into a request method and a response method. The response
    forwards the value of a register that a write port is writing in the same
    cycle, in addition to reading the memory bank.

    Physical register 0 starts out valid. Write and free calls with
    ``reg_id == 0`` do not touch the memory bank, the valid bits or the
    latency metric, and reads of register 0 never take a forwarded value.

    Attributes
    ----------
    read_req : Methods
        ``read_ports`` methods; each issues a memory-bank read of ``reg_id``.
    read_resp : Methods
        ``read_ports`` methods; each returns ``reg_val`` and ``valid`` for
        ``reg_id``.
    write : Methods
        ``write_ports`` methods; each stores ``reg_val`` into ``reg_id`` and
        marks it valid.
    free : Methods
        ``free_ports`` methods; each marks ``reg_id`` invalid.
    """

    def __init__(self, *, gen_params: GenParams, read_ports: int, write_ports: int, free_ports: int):
        """
        Parameters
        ----------
        gen_params : GenParams
            Core configuration. Supplies ``phys_regs_bits`` (register count is
            ``2**phys_regs_bits``), ``isa.xlen`` (register width) and the
            multiport memory type.
        read_ports : int
            Number of ``read_req`` / ``read_resp`` method pairs.
        write_ports : int
            Number of ``write`` methods. With more than one write port,
            ``gen_params.multiport_memory_type`` backs the storage; otherwise
            a plain ``memory.Memory`` is used.
        free_ports : int
            Number of ``free`` methods.
        """
        self.gen_params = gen_params

        layouts = gen_params.get(RFLayouts)
        self.read_layout = layouts.rf_read_out
        self.entries = MemoryBank(
            memory_type=gen_params.multiport_memory_type if write_ports > 1 else memory.Memory,
            shape=gen_params.isa.xlen,
            depth=2**gen_params.phys_regs_bits,
            read_ports=read_ports,
            write_ports=write_ports,
            read_on_resp=True,
        )
        # Only register 0 starts valid; all other registers become valid when written.
        self.valids = Array(Signal(init=k == 0) for k in range(2**gen_params.phys_regs_bits))

        self.read_req = Methods(read_ports, i=layouts.rf_read_in)
        self.read_resp = Methods(read_ports, i=layouts.rf_read_in, o=layouts.rf_read_out)
        self.write = Methods(write_ports, i=layouts.rf_write)
        self.free = Methods(free_ports, i=layouts.rf_free)

        # Measures how long each register stays valid: started on write, stopped on free.
        self.perf_rf_valid_time = TaggedLatencyMeasurer(
            "struct.rf.valid_time",
            description="Distribution of time registers are valid in RF",
            slots_number=2**gen_params.phys_regs_bits,
            max_latency=1000,
            ways=max(write_ports, free_ports),
        )
        # One extra bit so that the count of all 2**phys_regs_bits registers fits.
        self.perf_num_valid = HwExpHistogram(
            "struct.rf.num_valid",
            description="Number of valid registers in RF",
            bucket_count=gen_params.phys_regs_bits + 1,
            sample_width=gen_params.phys_regs_bits + 1,
        )

    def elaborate(self, platform):
        m = TModule()

        m.submodules += [self.entries, self.perf_rf_valid_time, self.perf_num_valid]

        # Per write port: the register id and value written in the current cycle.
        # read_resp uses these to forward a value that is being written right now.
        being_written = Signal(ArrayLayout(self.gen_params.phys_regs_bits, len(self.write)))
        written_value = Signal(ArrayLayout(self.gen_params.isa.xlen, len(self.write)))

        @def_methods(m, self.read_req)
        def _(k: int, reg_id: Value):
            self.entries.read_req[k](m, addr=reg_id)

        @def_methods(m, self.read_resp)
        def _(k: int, reg_id: Value):
            forward = Signal()
            # One bit per write port: set when that port is writing reg_id (never for register 0).
            reg_written = Signal(len(self.write))
            m.d.av_comb += reg_written.eq(
                Cat((being_written[i] == reg_id) & (reg_id != 0) for i in range(len(self.write)))
            )
            m.d.av_comb += forward.eq(reg_written.any())
            # Select the value from the matching write port. The memory-bank read response is
            # passed as the third argument, presumably the value used when no port matches.
            reg_val = OneHotMux.create(
                m,
                [(reg_written[i], written_value[i]) for i in range(len(self.write))],
                self.entries.read_resp[k](m).data,
            )
            return {
                "reg_val": reg_val,
                "valid": forward | self.valids[reg_id],
            }

        @def_methods(m, self.write)
        def _(k: int, reg_id: Value, reg_val: Value):
            # NOTE(review): a `comb` assignment inside a method body is expected to apply only
            # while the method runs, leaving being_written[k] at 0 otherwise. Confirm against
            # Transactron's TModule semantics.
            m.d.comb += being_written[k].eq(reg_id)
            m.d.av_comb += written_value[k].eq(reg_val)
            with m.If(reg_id != 0):
                log.assertion(m, ~self.valids[reg_id], "Valid register {} written", reg_id)
                self.entries.write[k](m, addr=reg_id, data=reg_val)
                m.d.sync += self.valids[reg_id].eq(1)
                self.perf_rf_valid_time.start[k](m, slot=reg_id)

        @def_methods(m, self.free)
        def _(k: int, reg_id: Value):
            with m.If(reg_id != 0):
                log.assertion(m, self.valids[reg_id], "Invalid register {} freed", reg_id)
                m.d.sync += self.valids[reg_id].eq(0)
                self.perf_rf_valid_time.stop[k](m, slot=reg_id)

        # It is assumed that two simultaneous write calls never write the same physical register.
        for k1, m1 in enumerate(self.write):
            # start= makes k2 the absolute port index rather than the offset within the slice,
            # so the error message names the correct pair of methods.
            for k2, m2 in enumerate(self.write[k1 + 1 :], start=k1 + 1):
                log.error(
                    m,
                    m1.run & m2.run & (m1.data_in.reg_id == m2.data_in.reg_id) & (m1.data_in.reg_id != 0),
                    "Write methods {} and {} both called with reg_id {}",
                    k1,
                    k2,
                    m1.data_in.reg_id,
                )

        # It is assumed that two simultaneous free calls never free the same physical register.
        for k1, m1 in enumerate(self.free):
            # start= makes k2 the absolute port index rather than the offset within the slice.
            for k2, m2 in enumerate(self.free[k1 + 1 :], start=k1 + 1):
                log.error(
                    m,
                    m1.run & m2.run & (m1.data_in.reg_id == m2.data_in.reg_id) & (m1.data_in.reg_id != 0),
                    "Free methods {} and {} both called with reg_id {}",
                    k1,
                    k2,
                    m1.data_in.reg_id,
                )

        # Count the valid registers each cycle and feed the count into the histogram,
        # but only build this logic when metrics are enabled.
        if self.perf_num_valid.metrics_enabled():
            num_valid = Signal(self.gen_params.phys_regs_bits + 1)
            m.d.comb += num_valid.eq(
                popcount(Cat(self.valids[reg_id] for reg_id in range(2**self.gen_params.phys_regs_bits)))
            )
            with Transaction(name="perf").body(m):
                self.perf_num_valid.add(m, num_valid)

        return m