/ coreblocks / func_blocks / fu / exception.py
exception.py
  1  from dataclasses import dataclass, KW_ONLY
  2  from typing import Sequence
  3  from amaranth import *
  4  from coreblocks.arch.isa_consts import PrivilegeLevel
  5  from transactron.utils.dependencies import DependencyContext
  6  
  7  from transactron import *
  8  
  9  from coreblocks.params import GenParams, FunctionalComponentParams
 10  from coreblocks.arch import OpType, Funct3, ExceptionCause
 11  from coreblocks.interface.layouts import FetchLayouts
 12  from transactron.utils import OneHotSwitch
 13  from coreblocks.interface.keys import ExceptionReportKey, CSRInstancesKey
 14  
 15  from coreblocks.func_blocks.fu.common import DecoderManager, FuncUnitBase
 16  from enum import IntFlag, auto
 17  
 18  from coreblocks.func_blocks.interface.func_protocols import FuncUnit
 19  
 20  __all__ = ["ExceptionFuncUnit", "ExceptionUnitComponent"]
 21  
 22  
 23  class ExceptionUnitFn(DecoderManager):
 24      class Fn(IntFlag):
 25          ECALL = auto()
 26          EBREAK = auto()
 27          # Fns representing exceptions caused by instructions before FUs
 28          INSTR_ACCESS_FAULT = auto()
 29          ILLEGAL_INSTRUCTION = auto()
 30          BREAKPOINT = auto()
 31          INSTR_PAGE_FAULT = auto()
 32  
 33      def get_instructions(self) -> Sequence[tuple]:
 34          return [
 35              (self.Fn.ECALL, OpType.ECALL),
 36              (self.Fn.EBREAK, OpType.EBREAK),
 37              (self.Fn.INSTR_ACCESS_FAULT, OpType.EXCEPTION, Funct3._EINSTRACCESSFAULT),
 38              (self.Fn.ILLEGAL_INSTRUCTION, OpType.EXCEPTION, Funct3._EILLEGALINSTR),
 39              (self.Fn.BREAKPOINT, OpType.EXCEPTION, Funct3._EBREAKPOINT),
 40              (self.Fn.INSTR_PAGE_FAULT, OpType.EXCEPTION, Funct3._EINSTRPAGEFAULT),
 41          ]
 42  
 43  
 44  class ExceptionFuncUnit(FuncUnitBase[ExceptionUnitFn]):
 45      def __init__(self, gen_params: GenParams, fn=ExceptionUnitFn()):
 46          super().__init__(gen_params, fn)
 47  
 48          self.dm = DependencyContext.get()
 49          self.report = self.dm.get_dependency(ExceptionReportKey())()
 50  
 51      def elaborate(self, platform):
 52          m = super().elaborate(platform)
 53  
 54          @def_method(m, self.issue_decoded)
 55          def _(arg):
 56              cause = Signal(ExceptionCause)
 57              mtval = Signal(self.gen_params.isa.xlen)
 58  
 59              priv_level = self.dm.get_dependency(CSRInstancesKey()).m_mode.priv_mode.read(m).data
 60  
 61              instr_exc_on_second_half = Signal()
 62              m.d.av_comb += instr_exc_on_second_half.eq(
 63                  (arg.imm & FetchLayouts.FaultFlag.EXCEPTION_ON_SECOND_HALF).any()
 64              )
 65  
 66              with OneHotSwitch(m, arg.decode_fn) as OneHotCase:
 67                  with OneHotCase(ExceptionUnitFn.Fn.EBREAK):
 68                      m.d.av_comb += cause.eq(ExceptionCause.BREAKPOINT)
 69                      m.d.av_comb += mtval.eq(arg.pc)
 70                  with OneHotCase(ExceptionUnitFn.Fn.ECALL):
 71                      with m.Switch(priv_level):
 72                          with m.Case(PrivilegeLevel.MACHINE):
 73                              m.d.av_comb += cause.eq(ExceptionCause.ENVIRONMENT_CALL_FROM_M)
 74                          with m.Case(PrivilegeLevel.SUPERVISOR):
 75                              m.d.av_comb += cause.eq(ExceptionCause.ENVIRONMENT_CALL_FROM_S)
 76                          with m.Case(PrivilegeLevel.USER):
 77                              m.d.av_comb += cause.eq(ExceptionCause.ENVIRONMENT_CALL_FROM_U)
 78                      m.d.av_comb += mtval.eq(0)  # by SPEC
 79                  with OneHotCase(ExceptionUnitFn.Fn.INSTR_ACCESS_FAULT):
 80                      m.d.av_comb += cause.eq(ExceptionCause.INSTRUCTION_ACCESS_FAULT)
 81                      # With C extension access fault can be only on the second half of instruction, and mepc != mtval.
 82                      # This information is passed in imm field
 83                      m.d.av_comb += mtval.eq(arg.pc + (instr_exc_on_second_half << 1))
 84                  with OneHotCase(ExceptionUnitFn.Fn.ILLEGAL_INSTRUCTION):
 85                      m.d.av_comb += cause.eq(ExceptionCause.ILLEGAL_INSTRUCTION)
 86                      m.d.av_comb += mtval.eq(arg.imm)  # passed instruction bytes
 87                  with OneHotCase(ExceptionUnitFn.Fn.BREAKPOINT):
 88                      m.d.av_comb += cause.eq(ExceptionCause.BREAKPOINT)
 89                      m.d.av_comb += mtval.eq(arg.pc)
 90                  with OneHotCase(ExceptionUnitFn.Fn.INSTR_PAGE_FAULT):
 91                      m.d.av_comb += cause.eq(ExceptionCause.INSTRUCTION_PAGE_FAULT)
 92                      m.d.av_comb += mtval.eq(arg.pc + (instr_exc_on_second_half << 1))
 93  
 94              self.report(m, rob_id=arg.rob_id, cause=cause, pc=arg.pc, mtval=mtval)
 95  
 96              self.push_result(m, result=0, exception=1, rob_id=arg.rob_id, rp_dst=arg.rp_dst)
 97  
 98          return m
 99  
100  
@dataclass(frozen=True)
class ExceptionUnitComponent(FunctionalComponentParams):
    """Configuration component that instantiates an :class:`ExceptionFuncUnit`.

    All fields are keyword-only.

    Attributes
    ----------
    result_fifo : bool
        Not read by this class; presumably consumed by ``FunctionalComponentParams``
        or the surrounding framework (TODO confirm).
    decoder_manager : ExceptionUnitFn
        Decoder manager handed to the created unit. The default instance is
        created once at class definition and shared by all components.
    """

    _: KW_ONLY
    result_fifo: bool = True
    decoder_manager: ExceptionUnitFn = ExceptionUnitFn()

    def get_module(self, gen_params: GenParams) -> FuncUnit:
        """Build the exception functional unit for ``gen_params``."""
        return ExceptionFuncUnit(gen_params, fn=self.decoder_manager)