# exception.py
"""Exception functional unit.

Handles instructions whose only effect is to raise a synchronous exception:
ECALL/EBREAK, plus exceptions detected before the functional units and
forwarded here as ``OpType.EXCEPTION`` (instruction access fault, illegal
instruction, breakpoint, instruction page fault). For every issued
instruction the unit reports the exception (cause, pc, mtval) and pushes a
result flagged as an exception.
"""

from dataclasses import dataclass, KW_ONLY
from typing import Sequence
from amaranth import *
from coreblocks.arch.isa_consts import PrivilegeLevel
from transactron.utils.dependencies import DependencyContext

from transactron import *

from coreblocks.params import GenParams, FunctionalComponentParams
from coreblocks.arch import OpType, Funct3, ExceptionCause
from coreblocks.interface.layouts import FetchLayouts
from transactron.utils import OneHotSwitch
from coreblocks.interface.keys import ExceptionReportKey, CSRInstancesKey

from coreblocks.func_blocks.fu.common import DecoderManager, FuncUnitBase
from enum import IntFlag, auto

from coreblocks.func_blocks.interface.func_protocols import FuncUnit

__all__ = ["ExceptionFuncUnit", "ExceptionUnitComponent"]


class ExceptionUnitFn(DecoderManager):
    """Decoder manager for the exception unit.

    Maps each ``Fn`` flag to the decoded-instruction pattern that selects it
    (see ``get_instructions``).
    """

    class Fn(IntFlag):
        """Function codes handled by the exception unit.

        Members are created with ``auto()`` on an ``IntFlag``, so each one is a
        distinct single bit; ``ExceptionFuncUnit`` dispatches on them with
        ``OneHotSwitch``.
        """

        ECALL = auto()
        EBREAK = auto()
        # Exceptions detected before the functional units; they reach this unit
        # encoded as OpType.EXCEPTION, with Funct3 selecting the kind.
        INSTR_ACCESS_FAULT = auto()
        ILLEGAL_INSTRUCTION = auto()
        BREAKPOINT = auto()
        INSTR_PAGE_FAULT = auto()

    def get_instructions(self) -> Sequence[tuple]:
        """Return ``(Fn, OpType[, Funct3])`` tuples, one per handled function.

        ECALL and EBREAK are identified by their OpType alone. The exceptions
        forwarded from earlier stages all share ``OpType.EXCEPTION`` and are
        told apart by their Funct3 value.
        """
        return [
            (self.Fn.ECALL, OpType.ECALL),
            (self.Fn.EBREAK, OpType.EBREAK),
            (self.Fn.INSTR_ACCESS_FAULT, OpType.EXCEPTION, Funct3._EINSTRACCESSFAULT),
            (self.Fn.ILLEGAL_INSTRUCTION, OpType.EXCEPTION, Funct3._EILLEGALINSTR),
            (self.Fn.BREAKPOINT, OpType.EXCEPTION, Funct3._EBREAKPOINT),
            (self.Fn.INSTR_PAGE_FAULT, OpType.EXCEPTION, Funct3._EINSTRPAGEFAULT),
        ]


class ExceptionFuncUnit(FuncUnitBase[ExceptionUnitFn]):
    """Functional unit that turns every issued instruction into an exception.

    The exception cause and trap value (mtval) are derived from the decoded
    function and, for ECALL, from the current privilege level. The exception
    is reported through the ``ExceptionReportKey`` dependency. A result with
    ``exception=1`` is then pushed for the instruction's ROB entry.
    """

    def __init__(self, gen_params: GenParams, fn: ExceptionUnitFn = ExceptionUnitFn()):
        """
        Parameters
        ----------
        gen_params : GenParams
            Core generation parameters, forwarded to ``FuncUnitBase``.
        fn : ExceptionUnitFn
            Decoder manager describing the handled functions. NOTE: the
            default instance is created once, when this ``def`` runs, and is
            shared by every unit constructed without an explicit ``fn``.
        """
        super().__init__(gen_params, fn)

        self.dm = DependencyContext.get()
        # The ExceptionReportKey dependency is a callable returning the report
        # method. It is invoked once here, and the returned method is what
        # ``elaborate`` calls.
        self.report = self.dm.get_dependency(ExceptionReportKey())()

    def elaborate(self, platform):
        """Build the unit.

        Defines ``issue_decoded`` so that each issued instruction reports an
        exception and completes with an exception-flagged result.
        """
        m = super().elaborate(platform)

        @def_method(m, self.issue_decoded)
        def _(arg):
            # Exception code and trap value computed for this instruction.
            cause = Signal(ExceptionCause)
            mtval = Signal(self.gen_params.isa.xlen)

            # Current privilege level, read from the machine-mode CSR instances.
            # It is only used to select the ECALL cause.
            priv_level = self.dm.get_dependency(CSRInstancesKey()).m_mode.priv_mode.read(m).data

            # Set when imm carries the FetchLayouts.FaultFlag.EXCEPTION_ON_SECOND_HALF bit,
            # i.e. the fault concerns the second half of the instruction.
            instr_exc_on_second_half = Signal()
            m.d.av_comb += instr_exc_on_second_half.eq(
                (arg.imm & FetchLayouts.FaultFlag.EXCEPTION_ON_SECOND_HALF).any()
            )

            # decode_fn holds one Fn bit; select cause and mtval accordingly.
            with OneHotSwitch(m, arg.decode_fn) as OneHotCase:
                with OneHotCase(ExceptionUnitFn.Fn.EBREAK):
                    m.d.av_comb += cause.eq(ExceptionCause.BREAKPOINT)
                    m.d.av_comb += mtval.eq(arg.pc)
                with OneHotCase(ExceptionUnitFn.Fn.ECALL):
                    # The environment-call cause depends on the privilege level the ECALL came from.
                    # NOTE(review): there is no Default case. If priv_level matches none of
                    # M/S/U, cause is not assigned on this path; confirm priv_mode can only
                    # hold these three levels.
                    with m.Switch(priv_level):
                        with m.Case(PrivilegeLevel.MACHINE):
                            m.d.av_comb += cause.eq(ExceptionCause.ENVIRONMENT_CALL_FROM_M)
                        with m.Case(PrivilegeLevel.SUPERVISOR):
                            m.d.av_comb += cause.eq(ExceptionCause.ENVIRONMENT_CALL_FROM_S)
                        with m.Case(PrivilegeLevel.USER):
                            m.d.av_comb += cause.eq(ExceptionCause.ENVIRONMENT_CALL_FROM_U)
                    m.d.av_comb += mtval.eq(0)  # mtval is zero for ECALL, by SPEC
                with OneHotCase(ExceptionUnitFn.Fn.INSTR_ACCESS_FAULT):
                    m.d.av_comb += cause.eq(ExceptionCause.INSTRUCTION_ACCESS_FAULT)
                    # With the C extension, the fault may affect only the second half of the
                    # instruction. mtval must then point 2 bytes past pc, so the reported pc
                    # (mepc) and mtval differ. This case is flagged in the imm field
                    # (instr_exc_on_second_half above): mtval = pc or pc + 2.
                    m.d.av_comb += mtval.eq(arg.pc + (instr_exc_on_second_half << 1))
                with OneHotCase(ExceptionUnitFn.Fn.ILLEGAL_INSTRUCTION):
                    m.d.av_comb += cause.eq(ExceptionCause.ILLEGAL_INSTRUCTION)
                    m.d.av_comb += mtval.eq(arg.imm)  # imm carries the passed instruction bytes
                with OneHotCase(ExceptionUnitFn.Fn.BREAKPOINT):
                    m.d.av_comb += cause.eq(ExceptionCause.BREAKPOINT)
                    m.d.av_comb += mtval.eq(arg.pc)
                with OneHotCase(ExceptionUnitFn.Fn.INSTR_PAGE_FAULT):
                    m.d.av_comb += cause.eq(ExceptionCause.INSTRUCTION_PAGE_FAULT)
                    # Same second-half adjustment as for access faults: mtval = pc or pc + 2.
                    m.d.av_comb += mtval.eq(arg.pc + (instr_exc_on_second_half << 1))

            # Report the exception for this ROB entry.
            self.report(m, rob_id=arg.rob_id, cause=cause, pc=arg.pc, mtval=mtval)

            # Complete the instruction with result value 0, flagged as an exception.
            self.push_result(m, result=0, exception=1, rob_id=arg.rob_id, rp_dst=arg.rp_dst)

        return m


@dataclass(frozen=True)
class ExceptionUnitComponent(FunctionalComponentParams):
    """Configuration of the exception functional unit.

    All fields after ``_: KW_ONLY`` are keyword-only. ``result_fifo`` is not
    read in this class; presumably it is consumed by the generic
    functional-component machinery (TODO confirm). ``decoder_manager`` is
    handed to ``ExceptionFuncUnit`` by ``get_module``. Its default instance is
    created once and shared by all components that do not override it.
    """

    _: KW_ONLY
    result_fifo: bool = True
    decoder_manager: ExceptionUnitFn = ExceptionUnitFn()

    def get_module(self, gen_params: GenParams) -> FuncUnit:
        """Instantiate an ``ExceptionFuncUnit`` that uses this component's decoder manager."""
        return ExceptionFuncUnit(gen_params, self.decoder_manager)