/ coreblocks / backend / retirement.py
retirement.py
  1  from amaranth import *
  2  from coreblocks.interface.layouts import (
  3      CoreInstructionCounterLayouts,
  4      ExceptionRegisterLayouts,
  5      FetchLayouts,
  6      InternalInterruptControllerLayouts,
  7      RATLayouts,
  8      RFLayouts,
  9      ROBLayouts,
 10      RetirementLayouts,
 11  )
 12  
 13  from transactron.core import Method, Transaction, TModule, def_method
 14  from transactron.lib.simultaneous import condition
 15  from transactron.utils.dependencies import DependencyContext
 16  from transactron.lib.metrics import *
 17  
 18  from coreblocks.params.genparams import GenParams
 19  from coreblocks.arch import ExceptionCause, PrivilegeLevel
 20  from coreblocks.arch.csr_address import CounterEnableFieldOffsets
 21  from coreblocks.interface.keys import (
 22      CoreStateKey,
 23      CSRInstancesKey,
 24      InstructionPrecommitKey,
 25  )
 26  from coreblocks.priv.csr.csr_instances import CSRAddress, DoubleCounterCSR, counteren_access_filter
 27  from coreblocks.arch.isa_consts import TrapVectorMode
 28  
 29  
class Retirement(Elaboratable):
    """
    Retirement stage: processes the entry at the head of the reorder buffer (ROB).

    One ROB entry is handled per transaction. An entry without the `exception`
    flag is committed: its `rl_dst -> rp_dst` mapping is written with
    `r_rat_commit`, the physical register that mapping replaced is freed, and
    the `instret` CSR and the retired-instruction counter are incremented.

    An entry with the `exception` flag starts trap handling, driven by a
    three-state FSM built in `elaborate`:

    - `NORMAL` - regular retirement; decides from the exception cause whether the
      reporting instruction is committed or flushed and, for architectural
      traps, writes the cause/epc/tval CSRs of the target privilege level.
    - `TRAP_FLUSH` - drains the remaining ROB entries, undoing each of them with
      `flush_instr`.
    - `TRAP_RESUME` - calls `fetch_continue` with either the trap handler
      address or a saved continuation pc (misprediction case), clears the
      exception cause register and returns to `NORMAL`.

    The class also defines two methods which it registers in the
    `DependencyContext`: `core_state` (under `CoreStateKey`) and `precommit`
    (under `InstructionPrecommitKey`).

    All other `Method` attributes created in `__init__` are only called here,
    never defined; presumably they are bound to their implementations by the
    code that instantiates this module - verify against the core top level.
    """

    def __init__(
        self,
        gen_params: GenParams,
    ):
        """
        Declare the method interface, CSR and performance counters of the module.

        Parameters
        ----------
        gen_params : GenParams
            Core generation parameters. Used in this class to obtain the method
            layouts and the values `isa.xlen`, `phys_regs`, `rob_entries_bits`
            and `supervisor_mode`.
        """
        self.gen_params = gen_params

        # --- Methods called by this module (no body is defined for them here) ---
        # ROB access: read the head entries / retire `count` entries.
        self.rob_peek = Method(o=gen_params.get(ROBLayouts).peek_layout)
        self.rob_retire = Method(i=gen_params.get(ROBLayouts).retire_layout)
        # R-RAT access: commit a mapping (output provides `old_rp_dst`) / read the current mapping.
        self.r_rat_commit = Method(
            i=gen_params.get(RATLayouts).rrat_commit_in, o=gen_params.get(RATLayouts).rrat_commit_out
        )
        self.r_rat_peek = Method(i=gen_params.get(RATLayouts).rrat_peek_in, o=gen_params.get(RATLayouts).rrat_peek_out)
        # Physical register release: free-list insertion and register file notification.
        self.free_rf_put = Method(i=[("ident", range(gen_params.phys_regs))])
        self.rf_free = Method(i=gen_params.get(RFLayouts).rf_free)
        # Exception cause register: read the pending report / clear it.
        self.exception_cause_get = Method(o=gen_params.get(ExceptionRegisterLayouts).get)
        self.exception_cause_clear = Method()
        # Restores an `rl_dst -> rp_dst` mapping when an instruction is flushed.
        self.c_rat_restore = Method(i=gen_params.get(RATLayouts).crat_flush_restore)
        # Restarts fetch from the given pc after a trap has been handled.
        self.fetch_continue = Method(i=self.gen_params.get(FetchLayouts).resume)
        # Core instruction counter; its output is used below as the "core is empty" condition.
        self.instr_decrement = Method(
            i=gen_params.get(CoreInstructionCounterLayouts).decrement_in,
            o=gen_params.get(CoreInstructionCounterLayouts).decrement_out,
        )
        # Reports an architectural trap with an xlen-wide cause; returns the privilege level that handles it.
        self.trap_entry = Method(i=[("cause", gen_params.isa.xlen)], o=[("target_priv", PrivilegeLevel)])
        interrupt_controller_layouts = gen_params.get(InternalInterruptControllerLayouts)
        self.async_interrupt_cause = Method(o=interrupt_controller_layouts.interrupt_cause)
        # Called when a retired/flushed ROB entry has `tag_increment` set.
        self.checkpoint_tag_free = Method()
        # NOTE(review): declared, but not called anywhere in this class.
        self.checkpoint_get_active_tags = Method(o=gen_params.get(RATLayouts).get_active_tags_out)

        # Retired instruction counter CSR. The high-half addresses (MINSTRETH/INSTRETH)
        # are passed only when xlen == 32; otherwise None is passed in their place.
        self.instret_csr = DoubleCounterCSR(
            gen_params,
            CSRAddress.MINSTRET,
            CSRAddress.MINSTRETH if gen_params.isa.xlen == 32 else None,
            CSRAddress.INSTRET,
            CSRAddress.INSTRETH if gen_params.isa.xlen == 32 else None,
            shadow_access_filter=counteren_access_filter(gen_params, CounterEnableFieldOffsets.IR),
        )
        # Performance metrics.
        self.perf_instr_ret = HwCounter("backend.retirement.retired_instr", "Number of retired instructions")
        # Started when a trap is taken in NORMAL state and stopped in TRAP_RESUME.
        # max_latency is 2 * 2**rob_entries_bits (presumably twice the ROB capacity - TODO confirm).
        self.perf_trap_latency = FIFOLatencyMeasurer(
            "backend.retirement.trap_latency",
            "Cycles spent flushing the core after a trap",
            slots_number=1,
            max_latency=2 * 2**gen_params.rob_entries_bits,
        )

        # --- Methods defined by this module (bodies are in `elaborate`) ---
        layouts = self.gen_params.get(RetirementLayouts)
        self.dependency_manager = DependencyContext.get()
        # Exposes the `flushing` flag.
        self.core_state = Method(o=self.gen_params.get(RetirementLayouts).core_state)
        self.dependency_manager.add_dependency(CoreStateKey(), self.core_state)

        # Exposes the `side_fx` flag for the instruction at the head of the ROB.
        self.precommit = Method(i=layouts.precommit_in, o=layouts.precommit_out)
        self.dependency_manager.add_dependency(InstructionPrecommitKey(), self.precommit)

    def elaborate(self, platform):
        """
        Build the retirement logic.

        Creates a validation transaction, the NORMAL / TRAP_FLUSH / TRAP_RESUME
        FSM (one transaction per state) and the bodies of the `core_state` and
        `precommit` methods. Returns the constructed `TModule`.
        """
        m = TModule()

        m.submodules += [self.perf_instr_ret, self.perf_trap_latency]

        csr_instances = self.dependency_manager.get_dependency(CSRInstancesKey())
        m_csr = csr_instances.m_mode
        # Supervisor CSRs are looked up only when supervisor mode is enabled in gen_params.
        s_csr = csr_instances.s_mode if self.gen_params.supervisor_mode else None
        m.submodules.instret_csr = self.instret_csr

        # Value returned by `precommit`; driven at the end of this function
        # to be 0 exactly while the FSM is in TRAP_FLUSH.
        side_fx = Signal(init=1)

        def free_phys_reg(rp_dst: Value):
            """Report `rp_dst` as free to the register file and, if it is non-zero, put it on the free list."""
            # mark reg in Register File as free
            self.rf_free(m, rp_dst)
            # put to Free RF list
            with m.If(rp_dst):  # don't put rp0 to free list - reserved to no-return instructions
                self.free_rf_put(m, rp_dst)

        def retire_instr(rob_entry):
            """Commit `rob_entry`: update R-RAT, free the replaced physical register, count the instruction."""
            # set rl_dst -> rp_dst in R-RAT
            rat_out = self.r_rat_commit(m, rl_dst=rob_entry.rob_data.rl_dst, rp_dst=rob_entry.rob_data.rp_dst)

            # free old rp_dst from overwritten R-RAT mapping
            free_phys_reg(rat_out.old_rp_dst)

            self.instret_csr.increment(m)
            self.perf_instr_ret.incr(m)

        def flush_instr(rob_entry):
            """Undo `rob_entry`: free its own `rp_dst` and restore the mapping currently held in R-RAT."""
            # get original rp_dst mapped to instruction rl_dst in R-RAT
            rat_out = self.r_rat_peek(m, rl_dst=rob_entry.rob_data.rl_dst)

            # free the "new" instruction rp_dst - result is flushed
            free_phys_reg(rob_entry.rob_data.rp_dst)

            # restore original rl_dst->rp_dst mapping in F-RAT
            self.c_rat_restore(m, rl_dst=rob_entry.rob_data.rl_dst, rp_dst=rat_out.old_rp_dst)

        # `retire_valid` gates the retirement transaction in NORMAL state (used as its `ready`).
        # It is set when the head entry has no exception, or when the exception cause
        # register holds a valid report whose rob_id equals the head entry's rob_id.
        retire_valid = Signal()
        with Transaction().body(m) as validate_transaction:
            # Ensure that when exception is processed, correct entry is already in ExceptionCauseRegister
            rob_entries = self.rob_peek(m)
            rob_entry = rob_entries.entries[0]
            ecr_entry = self.exception_cause_get(m)
            m.d.comb += retire_valid.eq(
                ~rob_entry.exception | (rob_entry.exception & ecr_entry.valid & (ecr_entry.rob_id == rob_entry.rob_id))
            )

        # State shared between FSM states:
        # - continue_pc_override / continue_pc: set on misprediction in NORMAL, consumed in TRAP_RESUME
        #   to resume from `continue_pc` instead of the trap handler address.
        # - core_flushing: value of the `flushing` field returned by `core_state`.
        # - trap_target_priv: privilege level returned by `trap_entry`, selects the CSR set in TRAP_RESUME.
        continue_pc_override = Signal()
        continue_pc = Signal(self.gen_params.isa.xlen)
        core_flushing = Signal()
        trap_target_priv = Signal(PrivilegeLevel, init=PrivilegeLevel.MACHINE)

        with m.FSM("NORMAL") as fsm:
            with m.State("NORMAL"):
                with Transaction().body(m, ready=retire_valid) as retire_transaction:
                    # Take exactly one entry from the head of the ROB.
                    rob_entries = self.rob_peek(m)
                    rob_entry = rob_entries.entries[0]
                    self.rob_retire(m, count=1)

                    with m.If(rob_entry.rob_data.tag_increment):
                        self.checkpoint_tag_free(m)

                    # Output of the counter decrement; used below as the condition for skipping TRAP_FLUSH
                    # (presumably signals that no other instructions remain in the core - TODO confirm).
                    core_empty = self.instr_decrement(m, count=1)

                    # Selects between retire_instr (1) and flush_instr (0) for this entry.
                    commit = Signal()

                    with m.If(rob_entry.exception):
                        self.perf_trap_latency.start(m)

                        cause_register = self.exception_cause_get(m)

                        # Cause value passed to `trap_entry` and written to the xcause CSR.
                        cause_entry = Signal(self.gen_params.isa.xlen)

                        # 1 when trap CSRs should be updated; cleared only for mispredictions.
                        arch_trap = Signal(init=1)

                        with m.If(cause_register.cause == ExceptionCause._COREBLOCKS_ASYNC_INTERRUPT):
                            # Async interrupts are inserted only by JumpBranchUnit and conditionally by MRET and CSR
                            # The PC field is set to address of instruction to resume from interrupt (e.g. for jumps
                            # it is a jump result).
                            # Instruction that reported interrupt is the last one that is committed.
                            m.d.av_comb += commit.eq(1)

                            # Set MSB - the Interrupt bit
                            m.d.av_comb += cause_entry.eq(
                                (1 << (self.gen_params.isa.xlen - 1)) | self.async_interrupt_cause(m).cause
                            )
                        with m.Elif(cause_register.cause == ExceptionCause._COREBLOCKS_MISPREDICTION):
                            # Branch misprediction - commit jump, flush core and continue from correct pc.
                            m.d.av_comb += commit.eq(1)
                            # Do not modify trap related CSRs
                            m.d.av_comb += arch_trap.eq(0)

                            # Remember the resume address; it is used (and the override cleared) in TRAP_RESUME.
                            m.d.sync += continue_pc_override.eq(1)
                            m.d.sync += continue_pc.eq(cause_register.pc)
                        with m.Else():
                            # RISC-V synchronous exceptions - don't retire instruction that caused exception,
                            # and later resume from it.
                            # Value of ExceptionCauseRegister pc field is the instruction address.
                            m.d.av_comb += commit.eq(0)

                            m.d.av_comb += cause_entry.eq(cause_register.cause)

                        with m.If(arch_trap):
                            # Register RISC-V architectural trap in CSRs.
                            target_priv = self.trap_entry(m, cause=cause_entry).target_priv

                            def set_trap_csrs(cause_reg, epc_reg, tval_reg):
                                """Write the trap cause, the reported pc and the reported mtval to the given CSRs."""
                                cause_reg.write(m, cause_entry)
                                epc_reg.write(m, cause_register.pc)
                                tval_reg.write(m, cause_register.mtval)

                            # Write the CSR set that belongs to the privilege level handling the trap.
                            # The SUPERVISOR case is generated only when supervisor mode is enabled.
                            with m.Switch(target_priv):
                                if self.gen_params.supervisor_mode:
                                    with m.Case(PrivilegeLevel.SUPERVISOR):
                                        assert s_csr is not None
                                        set_trap_csrs(s_csr.scause, s_csr.sepc, s_csr.stval)
                                with m.Case(PrivilegeLevel.MACHINE):
                                    set_trap_csrs(m_csr.mcause, m_csr.mepc, m_csr.mtval)

                            # Remembered for TRAP_RESUME, which reads xtvec/xcause of this privilege level.
                            m.d.sync += trap_target_priv.eq(target_priv)

                        # Fetch is already stalled by ExceptionCauseRegister
                        with m.If(core_empty):
                            m.next = "TRAP_RESUME"
                        with m.Else():
                            m.next = "TRAP_FLUSH"

                    with m.Else():
                        # Normally retire all non-trap instructions
                        m.d.av_comb += commit.eq(1)

                    # Condition is used to avoid FRAT locking during normal operation
                    with condition(m) as cond:
                        with cond(commit):
                            retire_instr(rob_entry)
                        with cond():
                            # Not using default condition, because we want to block if branch is not ready
                            flush_instr(rob_entry)

                            m.d.comb += core_flushing.eq(1)

                    # NOTE(review): presumably orders the validation transaction ahead of the retirement
                    # transaction so that `retire_valid` can be used as its `ready` - confirm against
                    # Transactron scheduling documentation.
                    validate_transaction.schedule_before(retire_transaction)

            with m.State("TRAP_FLUSH"):
                with Transaction().body(m):
                    # Flush entire core
                    rob_entries = self.rob_peek(m)
                    rob_entry = rob_entries.entries[0]
                    self.rob_retire(m, count=1)

                    with m.If(rob_entry.rob_data.tag_increment):
                        self.checkpoint_tag_free(m)

                    core_empty = self.instr_decrement(m, count=1)

                    # Every entry handled in this state is undone, never committed.
                    flush_instr(rob_entry)

                    with m.If(core_empty):
                        m.next = "TRAP_RESUME"

                # Asserted for the whole state, not only when the transaction above runs.
                m.d.comb += core_flushing.eq(1)

            with m.State("TRAP_RESUME"):
                with Transaction().body(m):
                    # Resume core operation
                    self.perf_trap_latency.stop(m)

                    handler_pc = Signal(self.gen_params.isa.xlen)
                    tvec_offset = Signal(self.gen_params.isa.xlen)
                    tvec_base = Signal(self.gen_params.isa.xlen)
                    tvec_mode = Signal(TrapVectorMode)
                    tcause = Signal(self.gen_params.isa.xlen)

                    def set_vals(reg_base, reg_mode, reg_cause):
                        """Read the trap vector base, trap vector mode and cause from the given CSRs."""
                        m.d.av_comb += [
                            tvec_base.eq(reg_base.read(m).data),
                            tvec_mode.eq(reg_mode.read(m).data),
                            tcause.eq(reg_cause.read(m).data),
                        ]

                    # Use the CSR set of the privilege level stored when the trap was registered.
                    with m.Switch(trap_target_priv):
                        if self.gen_params.supervisor_mode:
                            with m.Case(PrivilegeLevel.SUPERVISOR):
                                assert s_csr is not None
                                set_vals(s_csr.stvec_base, s_csr.stvec_mode, s_csr.scause)
                        with m.Case(PrivilegeLevel.MACHINE):
                            set_vals(m_csr.mtvec_base, m_csr.mtvec_mode, m_csr.mcause)

                    # When mode is Vectored, interrupts set pc to base + 4 * cause_number
                    # (tcause[-1] is the MSB of the cause, i.e. the Interrupt bit set in NORMAL state;
                    # tvec_offset is not assigned in any other case)
                    with m.If(tcause[-1] & (tvec_mode == TrapVectorMode.VECTORED)):
                        m.d.av_comb += tvec_offset.eq(tcause << 2)

                    # (xtvec_base stores base[MXLEN-1:2])
                    m.d.av_comb += handler_pc.eq((tvec_base << 2) + tvec_offset)

                    # A pending misprediction override takes precedence over the trap handler address;
                    # the override flag is cleared so that it applies to this resume only.
                    resume_pc = Mux(continue_pc_override, continue_pc, handler_pc)
                    m.d.sync += continue_pc_override.eq(0)

                    self.fetch_continue(m, pc=resume_pc)

                    # Release pending trap state - allow accepting new reports
                    self.exception_cause_clear(m)

                    m.next = "NORMAL"

        # Disable executing any side effects from instructions in core when it is flushed
        m.d.comb += side_fx.eq(~fsm.ongoing("TRAP_FLUSH"))

        # `core_state` returns the current value of `core_flushing`.
        @def_method(m, self.core_state, nonexclusive=True)
        def _():
            return {"flushing": core_flushing}

        # rob_id of the entry at the head of the ROB; `precommit` arguments are validated against it.
        rob_id_val = Signal(self.gen_params.rob_entries_bits)

        # The argument is only used in argument validation, it is not needed in the method body.
        # A dummy combiner is provided.
        @def_method(
            m,
            self.precommit,
            validate_arguments=lambda rob_id: rob_id == rob_id_val,
            nonexclusive=True,
            combiner=lambda m, args, runs: 0,
        )
        def _(rob_id):
            m.d.top_comb += rob_id_val.eq(self.rob_peek(m).entries[0].rob_id)
            return {"side_fx": side_fx}

        return m