# retirement.py
from amaranth import *
from coreblocks.interface.layouts import (
    CoreInstructionCounterLayouts,
    ExceptionRegisterLayouts,
    FetchLayouts,
    InternalInterruptControllerLayouts,
    RATLayouts,
    RFLayouts,
    ROBLayouts,
    RetirementLayouts,
)

from transactron.core import Method, Transaction, TModule, def_method
from transactron.lib.simultaneous import condition
from transactron.utils.dependencies import DependencyContext
from transactron.lib.metrics import *

from coreblocks.params.genparams import GenParams
from coreblocks.arch import ExceptionCause, PrivilegeLevel
from coreblocks.arch.csr_address import CounterEnableFieldOffsets
from coreblocks.interface.keys import (
    CoreStateKey,
    CSRInstancesKey,
    InstructionPrecommitKey,
)
from coreblocks.priv.csr.csr_instances import CSRAddress, DoubleCounterCSR, counteren_access_filter
from coreblocks.arch.isa_consts import TrapVectorMode


class Retirement(Elaboratable):
    """Retirement stage of the core.

    In normal operation the oldest ROB entry (``entries[0]`` of ``rob_peek``)
    is removed from the ROB one at a time and either:

    * retired - its ``rl_dst -> rp_dst`` mapping is committed to the R-RAT,
      the previously mapped physical register is freed and ``instret`` is
      incremented, or
    * flushed - its own ``rp_dst`` is freed and the C-RAT mapping for its
      ``rl_dst`` is restored from the R-RAT.

    An entry flagged with an exception starts trap handling. Architectural
    traps (everything except branch mispredictions) write the cause, EPC and
    tval CSRs of the privilege level returned by ``trap_entry``. The core is
    then drained in the ``TRAP_FLUSH`` state (unless it is already empty).
    Fetch is resumed in ``TRAP_RESUME``, either at the trap handler address
    computed from ``xtvec`` or, for mispredictions, at the corrected PC.

    The ``core_state`` and ``precommit`` methods are defined by this module
    and registered in the dependency manager. All other ``Method`` attributes
    created in ``__init__`` are called by this module but defined elsewhere.
    """

    def __init__(
        self,
        gen_params: GenParams,
    ):
        """
        Parameters
        ----------
        gen_params : GenParams
            Core configuration. It sizes the method layouts, selects the
            32-bit high-half instret CSRs, and decides (in ``elaborate``)
            whether supervisor-mode trap CSRs are handled.
        """
        self.gen_params = gen_params

        # Methods called by this module; their definitions live outside this file.
        self.rob_peek = Method(o=gen_params.get(ROBLayouts).peek_layout)
        self.rob_retire = Method(i=gen_params.get(ROBLayouts).retire_layout)
        self.r_rat_commit = Method(
            i=gen_params.get(RATLayouts).rrat_commit_in, o=gen_params.get(RATLayouts).rrat_commit_out
        )
        self.r_rat_peek = Method(i=gen_params.get(RATLayouts).rrat_peek_in, o=gen_params.get(RATLayouts).rrat_peek_out)
        self.free_rf_put = Method(i=[("ident", range(gen_params.phys_regs))])
        self.rf_free = Method(i=gen_params.get(RFLayouts).rf_free)
        self.exception_cause_get = Method(o=gen_params.get(ExceptionRegisterLayouts).get)
        self.exception_cause_clear = Method()
        self.c_rat_restore = Method(i=gen_params.get(RATLayouts).crat_flush_restore)
        self.fetch_continue = Method(i=self.gen_params.get(FetchLayouts).resume)
        self.instr_decrement = Method(
            i=gen_params.get(CoreInstructionCounterLayouts).decrement_in,
            o=gen_params.get(CoreInstructionCounterLayouts).decrement_out,
        )
        self.trap_entry = Method(i=[("cause", gen_params.isa.xlen)], o=[("target_priv", PrivilegeLevel)])
        interrupt_controller_layouts = gen_params.get(InternalInterruptControllerLayouts)
        self.async_interrupt_cause = Method(o=interrupt_controller_layouts.interrupt_cause)
        self.checkpoint_tag_free = Method()
        # NOTE(review): not called anywhere in this file - confirm it is still needed.
        self.checkpoint_get_active_tags = Method(o=gen_params.get(RATLayouts).get_active_tags_out)

        # minstret / instret counter; the high-half CSRs only exist when xlen == 32.
        self.instret_csr = DoubleCounterCSR(
            gen_params,
            CSRAddress.MINSTRET,
            CSRAddress.MINSTRETH if gen_params.isa.xlen == 32 else None,
            CSRAddress.INSTRET,
            CSRAddress.INSTRETH if gen_params.isa.xlen == 32 else None,
            shadow_access_filter=counteren_access_filter(gen_params, CounterEnableFieldOffsets.IR),
        )
        self.perf_instr_ret = HwCounter("backend.retirement.retired_instr", "Number of retired instructions")
        self.perf_trap_latency = FIFOLatencyMeasurer(
            "backend.retirement.trap_latency",
            "Cycles spent flushing the core after a trap",
            slots_number=1,
            # Bounded at twice the ROB capacity.
            max_latency=2 * 2**gen_params.rob_entries_bits,
        )

        layouts = self.gen_params.get(RetirementLayouts)
        self.dependency_manager = DependencyContext.get()

        # Methods defined by this module and exported through the dependency manager.
        self.core_state = Method(o=layouts.core_state)
        self.dependency_manager.add_dependency(CoreStateKey(), self.core_state)

        self.precommit = Method(i=layouts.precommit_in, o=layouts.precommit_out)
        self.dependency_manager.add_dependency(InstructionPrecommitKey(), self.precommit)

    def elaborate(self, platform):
        """Build the retirement FSM (NORMAL -> TRAP_FLUSH -> TRAP_RESUME) and
        define the ``core_state`` and ``precommit`` methods."""
        m = TModule()

        m.submodules += [self.perf_instr_ret, self.perf_trap_latency]

        csr_instances = self.dependency_manager.get_dependency(CSRInstancesKey())
        m_csr = csr_instances.m_mode
        s_csr = csr_instances.s_mode if self.gen_params.supervisor_mode else None
        m.submodules.instret_csr = self.instret_csr

        # Reported through `precommit`; deasserted while the FSM is in TRAP_FLUSH.
        side_fx = Signal(init=1)

        def free_phys_reg(rp_dst: Value):
            """Return physical register `rp_dst` to the register file and free list."""
            # mark reg in Register File as free
            self.rf_free(m, rp_dst)
            # put to Free RF list
            with m.If(rp_dst):  # don't put rp0 to free list - reserved to no-return instructions
                self.free_rf_put(m, rp_dst)

        def retire_instr(rob_entry):
            """Commit `rob_entry`'s register mapping and count it as retired."""
            # set rl_dst -> rp_dst in R-RAT
            rat_out = self.r_rat_commit(m, rl_dst=rob_entry.rob_data.rl_dst, rp_dst=rob_entry.rob_data.rp_dst)

            # free old rp_dst from overwritten R-RAT mapping
            free_phys_reg(rat_out.old_rp_dst)

            self.instret_csr.increment(m)
            self.perf_instr_ret.incr(m)

        def flush_instr(rob_entry):
            """Discard `rob_entry`'s result and restore the committed mapping in the C-RAT."""
            # get original rp_dst mapped to instruction rl_dst in R-RAT
            rat_out = self.r_rat_peek(m, rl_dst=rob_entry.rob_data.rl_dst)

            # free the "new" instruction rp_dst - result is flushed
            free_phys_reg(rob_entry.rob_data.rp_dst)

            # restore original rl_dst->rp_dst mapping in F-RAT
            self.c_rat_restore(m, rl_dst=rob_entry.rob_data.rl_dst, rp_dst=rat_out.old_rp_dst)

        retire_valid = Signal()
        with Transaction().body(m) as validate_transaction:
            # Ensure that when exception is processed, correct entry is already in ExceptionCauseRegister
            rob_entries = self.rob_peek(m)
            rob_entry = rob_entries.entries[0]
            ecr_entry = self.exception_cause_get(m)
            m.d.comb += retire_valid.eq(
                ~rob_entry.exception | (rob_entry.exception & ecr_entry.valid & (ecr_entry.rob_id == rob_entry.rob_id))
            )

        # Set on misprediction: resume from `continue_pc` instead of the trap handler.
        continue_pc_override = Signal()
        continue_pc = Signal(self.gen_params.isa.xlen)
        core_flushing = Signal()
        # Privilege level whose xtvec is used to compute the handler address in TRAP_RESUME.
        trap_target_priv = Signal(PrivilegeLevel, init=PrivilegeLevel.MACHINE)

        with m.FSM("NORMAL") as fsm:
            with m.State("NORMAL"):
                with Transaction().body(m, ready=retire_valid) as retire_transaction:
                    rob_entries = self.rob_peek(m)
                    rob_entry = rob_entries.entries[0]
                    self.rob_retire(m, count=1)

                    with m.If(rob_entry.rob_data.tag_increment):
                        self.checkpoint_tag_free(m)

                    core_empty = self.instr_decrement(m, count=1)

                    # 1 - retire the entry, 0 - flush it.
                    commit = Signal()

                    with m.If(rob_entry.exception):
                        self.perf_trap_latency.start(m)

                        cause_register = self.exception_cause_get(m)

                        cause_entry = Signal(self.gen_params.isa.xlen)

                        arch_trap = Signal(init=1)

                        with m.If(cause_register.cause == ExceptionCause._COREBLOCKS_ASYNC_INTERRUPT):
                            # Async interrupts are inserted only by JumpBranchUnit and conditionally by MRET and CSR
                            # The PC field is set to address of instruction to resume from interrupt (e.g. for jumps
                            # it is a jump result).
                            # Instruction that reported interrupt is the last one that is committed.
                            m.d.av_comb += commit.eq(1)

                            # Set MSB - the Interrupt bit
                            m.d.av_comb += cause_entry.eq(
                                (1 << (self.gen_params.isa.xlen - 1)) | self.async_interrupt_cause(m).cause
                            )
                        with m.Elif(cause_register.cause == ExceptionCause._COREBLOCKS_MISPREDICTION):
                            # Branch misprediction - commit jump, flush core and continue from correct pc.
                            m.d.av_comb += commit.eq(1)
                            # Do not modify trap related CSRs
                            m.d.av_comb += arch_trap.eq(0)

                            m.d.sync += continue_pc_override.eq(1)
                            m.d.sync += continue_pc.eq(cause_register.pc)
                        with m.Else():
                            # RISC-V synchronous exceptions - don't retire instruction that caused exception,
                            # and later resume from it.
                            # Value of ExceptionCauseRegister pc field is the instruction address.
                            m.d.av_comb += commit.eq(0)

                            m.d.av_comb += cause_entry.eq(cause_register.cause)

                        with m.If(arch_trap):
                            # Register RISC-V architectural trap in CSRs.
                            target_priv = self.trap_entry(m, cause=cause_entry).target_priv

                            def set_trap_csrs(cause_reg, epc_reg, tval_reg):
                                cause_reg.write(m, cause_entry)
                                epc_reg.write(m, cause_register.pc)
                                tval_reg.write(m, cause_register.mtval)

                            with m.Switch(target_priv):
                                if self.gen_params.supervisor_mode:
                                    with m.Case(PrivilegeLevel.SUPERVISOR):
                                        assert s_csr is not None
                                        set_trap_csrs(s_csr.scause, s_csr.sepc, s_csr.stval)
                                with m.Case(PrivilegeLevel.MACHINE):
                                    set_trap_csrs(m_csr.mcause, m_csr.mepc, m_csr.mtval)

                            m.d.sync += trap_target_priv.eq(target_priv)

                        # Fetch is already stalled by ExceptionCauseRegister
                        with m.If(core_empty):
                            m.next = "TRAP_RESUME"
                        with m.Else():
                            m.next = "TRAP_FLUSH"

                    with m.Else():
                        # Normally retire all non-trap instructions
                        m.d.av_comb += commit.eq(1)

                    # Condition is used to avoid FRAT locking during normal operation
                    with condition(m) as cond:
                        with cond(commit):
                            retire_instr(rob_entry)
                        with cond(~commit):
                            # Not using default condition, because we want to block if branch is not ready.
                            # A catch-all clause would flush a committing instruction whenever the
                            # retire branch's methods are not ready.
                            flush_instr(rob_entry)

                            m.d.comb += core_flushing.eq(1)

                validate_transaction.schedule_before(retire_transaction)

            with m.State("TRAP_FLUSH"):
                with Transaction().body(m):
                    # Flush entire core
                    rob_entries = self.rob_peek(m)
                    rob_entry = rob_entries.entries[0]
                    self.rob_retire(m, count=1)

                    with m.If(rob_entry.rob_data.tag_increment):
                        self.checkpoint_tag_free(m)

                    core_empty = self.instr_decrement(m, count=1)

                    flush_instr(rob_entry)

                    with m.If(core_empty):
                        m.next = "TRAP_RESUME"

                m.d.comb += core_flushing.eq(1)

            with m.State("TRAP_RESUME"):
                with Transaction().body(m):
                    # Resume core operation
                    self.perf_trap_latency.stop(m)

                    handler_pc = Signal(self.gen_params.isa.xlen)
                    tvec_offset = Signal(self.gen_params.isa.xlen)
                    tvec_base = Signal(self.gen_params.isa.xlen)
                    tvec_mode = Signal(TrapVectorMode)
                    tcause = Signal(self.gen_params.isa.xlen)

                    def set_vals(reg_base, reg_mode, reg_cause):
                        m.d.av_comb += [
                            tvec_base.eq(reg_base.read(m).data),
                            tvec_mode.eq(reg_mode.read(m).data),
                            tcause.eq(reg_cause.read(m).data),
                        ]

                    with m.Switch(trap_target_priv):
                        if self.gen_params.supervisor_mode:
                            with m.Case(PrivilegeLevel.SUPERVISOR):
                                assert s_csr is not None
                                set_vals(s_csr.stvec_base, s_csr.stvec_mode, s_csr.scause)
                        with m.Case(PrivilegeLevel.MACHINE):
                            set_vals(m_csr.mtvec_base, m_csr.mtvec_mode, m_csr.mcause)

                    # When mode is Vectored, interrupts set pc to base + 4 * cause_number
                    with m.If(tcause[-1] & (tvec_mode == TrapVectorMode.VECTORED)):
                        m.d.av_comb += tvec_offset.eq(tcause << 2)

                    # (xtvec_base stores base[MXLEN-1:2])
                    m.d.av_comb += handler_pc.eq((tvec_base << 2) + tvec_offset)

                    resume_pc = Mux(continue_pc_override, continue_pc, handler_pc)
                    m.d.sync += continue_pc_override.eq(0)

                    self.fetch_continue(m, pc=resume_pc)

                    # Release pending trap state - allow accepting new reports
                    self.exception_cause_clear(m)

                    m.next = "NORMAL"

        # Disable executing any side effects from instructions in core when it is flushed
        m.d.comb += side_fx.eq(~fsm.ongoing("TRAP_FLUSH"))

        @def_method(m, self.core_state, nonexclusive=True)
        def _():
            return {"flushing": core_flushing}

        rob_id_val = Signal(self.gen_params.rob_entries_bits)

        # The argument is only used in argument validation, it is not needed in the method body.
        # A dummy combiner is provided.
        @def_method(
            m,
            self.precommit,
            validate_arguments=lambda rob_id: rob_id == rob_id_val,
            nonexclusive=True,
            combiner=lambda m, args, runs: 0,
        )
        def _(rob_id):
            # Only the ROB head may precommit.
            m.d.top_comb += rob_id_val.eq(self.rob_peek(m).entries[0].rob_id)
            return {"side_fx": side_fx}

        return m