/ test / func_blocks / fu / test_fu_decoder.py
test_fu_decoder.py
  1  import random
  2  from collections.abc import Sequence
  3  
  4  from transactron.testing import SimpleTestCircuit, TestCaseWithSimulator, TestbenchContext
  5  
  6  from coreblocks.func_blocks.fu.common.fu_decoder import DecoderManager, Decoder
  7  from coreblocks.arch import OpType, Funct3, Funct7
  8  from coreblocks.params import GenParams
  9  from coreblocks.params.configurations import test_core_config
 10  
 11  from enum import IntFlag, auto
 12  
 13  
 14  class TestFuDecoder(TestCaseWithSimulator):
 15      def setup_method(self) -> None:
 16          self.gen_params = GenParams(test_core_config)
 17  
 18      # calculates expected decoder output
 19      def expected_results(self, instructions: Sequence[tuple], op_type_dependent: bool, inp: dict[str, int]) -> int:
 20          acc = 0
 21  
 22          for inst in instructions:
 23              op_type_match = inp["op_type"] == inst[1] if op_type_dependent else True
 24              funct3_match = inp["funct3"] == inst[2] if len(inst) >= 3 else True
 25              funct7_match = inp["funct7"] == inst[3] if len(inst) >= 4 else True
 26  
 27              if op_type_match and funct3_match and funct7_match:
 28                  assert inst  # TODO: better typing
 29                  acc |= inst[0]
 30  
 31          return acc
 32  
 33      async def handle_signals(self, sim: TestbenchContext, decoder: Decoder, exec_fn: dict[str, int]):
 34          sim.set(decoder.exec_fn.op_type, exec_fn["op_type"])
 35          sim.set(decoder.exec_fn.funct3, exec_fn["funct3"])
 36          sim.set(decoder.exec_fn.funct7, exec_fn["funct7"])
 37  
 38          return sim.get(decoder.decode_fn)
 39  
 40      def run_test_case(self, decoder_manager: DecoderManager, test_inputs: Sequence[tuple]) -> None:
 41          instructions = decoder_manager.get_instructions()
 42          decoder = decoder_manager.get_decoder(self.gen_params)
 43          op_type_dependent = len(decoder_manager.get_op_types()) != 1
 44  
 45          async def process(sim: TestbenchContext):
 46              for test_input in test_inputs:
 47                  exec_fn = {
 48                      "op_type": test_input[1],
 49                      "funct3": test_input[2] if len(test_input) >= 3 else 0,
 50                      "funct7": test_input[3] if len(test_input) >= 4 else 0,
 51                  }
 52  
 53                  returned = await self.handle_signals(sim, decoder, exec_fn)
 54                  expected = self.expected_results(instructions, op_type_dependent, exec_fn)
 55  
 56                  assert returned == expected
 57  
 58          test_circuit = SimpleTestCircuit(decoder)
 59  
 60          with self.run_simulation(test_circuit) as sim:
 61              sim.add_testbench(process)
 62  
 63      def generate_random_instructions(self) -> Sequence[tuple]:
 64          random.seed(42)
 65  
 66          return [(0, random.randint(0, 10), random.randint(0, 10), random.randint(0, 10)) for i in range(50)]
 67  
 68      def test_1(self) -> None:
 69          # same op type
 70          class DM(DecoderManager):
 71              class Fn(IntFlag):
 72                  INST1 = auto()
 73                  INST2 = auto()
 74                  INST3 = auto()
 75                  INST4 = auto()
 76                  INST5 = auto()
 77  
 78              def get_instructions(self) -> Sequence[tuple]:
 79                  return [
 80                      (self.Fn.INST1, OpType.ARITHMETIC, Funct3.ADD, Funct7.ADD),
 81                      (self.Fn.INST2, OpType.ARITHMETIC, Funct3.AND, Funct7.SUB),
 82                      (self.Fn.INST3, OpType.ARITHMETIC, Funct3.OR, Funct7.ADD),
 83                      (self.Fn.INST4, OpType.ARITHMETIC, Funct3.XOR, Funct7.ADD),
 84                      (self.Fn.INST5, OpType.ARITHMETIC, Funct3.BGEU, Funct7.ADD),
 85                  ]
 86  
 87          decoder_manager = DM()
 88  
 89          test_inputs = list(decoder_manager.get_instructions()) + list(self.generate_random_instructions())
 90  
 91          self.run_test_case(decoder_manager, test_inputs)
 92  
 93      def test_2(self) -> None:
 94          # same op type, different instruction length
 95          class DM(DecoderManager):
 96              class Fn(IntFlag):
 97                  INST1 = auto()
 98                  INST2 = auto()
 99                  INST3 = auto()
100                  INST4 = auto()
101                  INST5 = auto()
102  
103              def get_instructions(self) -> Sequence[tuple]:
104                  return [
105                      (self.Fn.INST1, OpType.ARITHMETIC, Funct3.ADD, Funct7.ADD),
106                      (self.Fn.INST2, OpType.ARITHMETIC, Funct3.AND),
107                      (self.Fn.INST3, OpType.ARITHMETIC, Funct3.OR, Funct7.BEXT),
108                      (self.Fn.INST4, OpType.ARITHMETIC, Funct3.XOR),
109                      (self.Fn.INST5, OpType.ARITHMETIC, Funct3.BGEU, Funct7.BSET),
110                  ]
111  
112          decoder_manager = DM()
113  
114          test_inputs = list(decoder_manager.get_instructions()) + list(self.generate_random_instructions())
115  
116          self.run_test_case(decoder_manager, test_inputs)
117  
118      def test_3(self) -> None:
119          # different op types, different instruction length
120          class DM(DecoderManager):
121              class Fn(IntFlag):
122                  INST1 = auto()
123                  INST2 = auto()
124                  INST3 = auto()
125                  INST4 = auto()
126                  INST5 = auto()
127  
128              def get_instructions(self) -> Sequence[tuple]:
129                  return [
130                      (self.Fn.INST1, OpType.AUIPC, Funct3.ADD, Funct7.ADD),
131                      (self.Fn.INST2, OpType.MUL, Funct3.AND),
132                      (self.Fn.INST3, OpType.ARITHMETIC, Funct3.OR, Funct7.BEXT),
133                      (self.Fn.INST4, OpType.COMPARE),
134                      (self.Fn.INST5, OpType.ARITHMETIC, Funct3.BGEU, Funct7.BSET),
135                  ]
136  
137          decoder_manager = DM()
138  
139          test_inputs = list(decoder_manager.get_instructions()) + list(self.generate_random_instructions())
140  
141          self.run_test_case(decoder_manager, test_inputs)