test/func_blocks/lsu/test_lsu_atomic_wrapper.py
test_lsu_atomic_wrapper.py
  1  from collections import deque
  2  import random
  3  import pytest
  4  from amaranth import *
  5  from amaranth.utils import ceil_log2
  6  from transactron import Method, TModule
  7  from transactron.lib import Adapter, AdapterTrans
  8  from transactron.testing import MethodMock, SimpleTestCircuit, TestCaseWithSimulator, TestbenchIO, def_method_mock
  9  
 10  from coreblocks.arch.isa_consts import Funct3, Funct7
 11  from coreblocks.arch.optypes import OpType
 12  from coreblocks.func_blocks.fu.lsu.lsu_atomic_wrapper import LSUAtomicWrapper
 13  from coreblocks.func_blocks.interface.func_protocols import FuncUnit
 14  from coreblocks.interface.layouts import FuncUnitLayouts
 15  from coreblocks.params.configurations import test_core_config
 16  from coreblocks.params.genparams import GenParams
 17  
 18  
 19  class FuncUnitMock(FuncUnit, Elaboratable):
 20      def __init__(self, gen_params: GenParams):
 21          layouts = gen_params.get(FuncUnitLayouts)
 22  
 23          self.issue = Method(i=layouts.issue)
 24          self.push_result = Method(i=layouts.push_result)
 25  
 26          self.issue_tb = TestbenchIO(Adapter.create(self.issue))
 27          self.push_result_tb = TestbenchIO(AdapterTrans.create(self.push_result))
 28  
 29      def elaborate(self, platform):
 30          m = TModule()
 31  
 32          m.submodules.issue_tb = self.issue_tb
 33          m.submodules.push_result_tb = self.push_result_tb
 34  
 35          return m
 36  
 37  
 38  class TestLSUAtomicWrapper(TestCaseWithSimulator):
 39      @pytest.fixture(autouse=True)
 40      def setup(self, fixture_initialize_testing_env):
 41          random.seed(1258)
 42          self.inst_cnt = 255
 43          self.gen_params = GenParams(test_core_config.replace(rob_entries_bits=ceil_log2(self.inst_cnt)))
 44          self.lsu = FuncUnitMock(self.gen_params)
 45          self.dut = SimpleTestCircuit(LSUAtomicWrapper(self.gen_params, self.lsu))
 46  
 47          self.mem_cell = 0
 48          self.instr_q = deque()
 49          self.results = {}
 50          self.lsu_res_q = deque()
 51          self.lsu_except_q = deque()
 52  
 53          self.generate_instrs(self.inst_cnt)
 54  
 55      @def_method_mock(lambda self: self.lsu.issue_tb, enable=lambda _: random.random() < 0.9)
 56      def lsu_issue_mock(self, arg):
 57          @MethodMock.effect
 58          def _():
 59              res = 0
 60              addr = arg["s1_val"] + arg["imm"]
 61  
 62              exc = self.lsu_except_q[0]
 63              self.lsu_except_q.popleft()
 64              assert addr == exc["addr"]
 65  
 66              if not exc["exception"]:
 67                  match arg["exec_fn"]["op_type"]:
 68                      case OpType.STORE:
 69                          self.mem_cell = arg["s2_val"]
 70                      case OpType.LOAD:
 71                          res = self.mem_cell
 72                      case _:
 73                          assert False
 74  
 75              self.lsu_res_q.append(
 76                  {"rob_id": arg["rob_id"], "result": res, "rp_dst": arg["rp_dst"], "exception": exc["exception"]}
 77              )
 78  
 79      @def_method_mock(
 80          lambda self: self.lsu.push_result_tb, enable=lambda self: random.random() < 0.9 and len(self.lsu_res_q) > 0
 81      )
 82      def lsu_accept_mock(self, arg):
 83          res = self.lsu_res_q[0]
 84  
 85          @MethodMock.effect
 86          def _():
 87              self.lsu_res_q.popleft()
 88  
 89          return res
 90  
 91      def generate_instrs(self, cnt):
 92          generation_mem_cell = 0
 93          generation_reservation_valid = 0
 94          for i in range(cnt):
 95              optype = random.choice([OpType.LOAD, OpType.STORE, OpType.ATOMIC_MEMORY_OP, OpType.ATOMIC_LR_SC])
 96              funct7 = 0
 97  
 98              imm = random.randint(0, 1)
 99              s1_val = random.randrange(0, 2**self.gen_params.isa.xlen - 1)
100              s2_val = random.randrange(0, 2**self.gen_params.isa.xlen)
101              rp_dst = random.randrange(0, 2**self.gen_params.phys_regs_bits)
102  
103              exception = 0
104              result = 0
105  
106              if optype == OpType.ATOMIC_MEMORY_OP:
107                  funct7 = random.choice(
108                      [
109                          Funct7.AMOSWAP,
110                          Funct7.AMOADD,
111                          Funct7.AMOMAXU,
112                          Funct7.AMOMIN,
113                          Funct7.AMOXOR,
114                          Funct7.AMOOR,
115                          Funct7.AMOAND,
116                          Funct7.AMOMAX,
117                          Funct7.AMOMINU,
118                      ]
119                  )
120  
121                  exception = random.random() < 0.3
122                  exception_on_load = exception and random.random() < 0.5
123                  self.lsu_except_q.append({"addr": s1_val, "exception": exception_on_load})
124  
125                  if not exception:
126                      result = generation_mem_cell
127  
128                      def twos(x):
129                          if x & (1 << (self.gen_params.isa.xlen - 1)):
130                              x ^= (1 << self.gen_params.isa.xlen) - 1
131                              x += 1
132                              x *= -1
133                          return x
134  
135                      match funct7:
136                          case Funct7.AMOSWAP:
137                              generation_mem_cell = s2_val
138                          case Funct7.AMOADD:
139                              generation_mem_cell += s2_val
140                              generation_mem_cell %= 2**self.gen_params.isa.xlen
141                          case Funct7.AMOAND:
142                              generation_mem_cell &= s2_val
143                          case Funct7.AMOOR:
144                              generation_mem_cell |= s2_val
145                          case Funct7.AMOXOR:
146                              generation_mem_cell ^= s2_val
147                          case Funct7.AMOMIN:
148                              generation_mem_cell = (
149                                  generation_mem_cell if twos(generation_mem_cell) < twos(s2_val) else s2_val
150                              )
151                          case Funct7.AMOMAX:
152                              generation_mem_cell = (
153                                  generation_mem_cell if twos(generation_mem_cell) > twos(s2_val) else s2_val
154                              )
155                          case Funct7.AMOMINU:
156                              generation_mem_cell = min(generation_mem_cell, s2_val)
157                          case Funct7.AMOMAXU:
158                              generation_mem_cell = max(generation_mem_cell, s2_val)
159  
160                  if not exception_on_load:
161                      self.lsu_except_q.append({"addr": s1_val, "exception": exception})
162              elif optype == OpType.ATOMIC_LR_SC:
163                  is_load = random.random() < 0.5
164                  exception = random.random() < 0.3
165                  sc_fail = False
166                  if is_load:
167                      funct7 = Funct7.LR
168                      generation_reservation_valid = not exception
169                      result = generation_mem_cell if not exception else 0
170                  else:
171                      funct7 = Funct7.SC
172                      if generation_reservation_valid:
173                          if not exception:
174                              generation_mem_cell = s2_val
175                      else:
176                          sc_fail = True
177                          exception = 0
178  
179                      result = 0 if generation_reservation_valid or exception else 1
180                      generation_reservation_valid = 0
181  
182                  if not sc_fail:
183                      self.lsu_except_q.append({"addr": s1_val, "exception": exception})
184  
185              elif optype == OpType.LOAD:
186                  result = generation_mem_cell
187                  self.lsu_except_q.append({"addr": s1_val + imm, "exception": 0})
188              elif optype == OpType.STORE:
189                  generation_mem_cell = s2_val
190                  result = 0
191                  self.lsu_except_q.append({"addr": s1_val + imm, "exception": 0})
192  
193              exec_fn = {"op_type": optype, "funct3": Funct3.W, "funct7": funct7}
194              rob_id = i
195              instr = {
196                  "rp_dst": rp_dst,
197                  "rob_id": rob_id,
198                  "exec_fn": exec_fn,
199                  "s1_val": s1_val,
200                  "s2_val": s2_val,
201                  "imm": imm,
202                  "pc": 0,
203              }
204              self.instr_q.append(instr)
205              self.results[rob_id] = {"rob_id": rob_id, "rp_dst": rp_dst, "result": result, "exception": exception}
206  
207      async def issue_process(self, sim):
208          while self.instr_q:
209              await self.dut.issue.call(sim, self.instr_q[0])
210              self.instr_q.popleft()
211              await self.random_wait_geom(sim, 0.9)
212  
213      async def accept_process(self, sim):
214          for _ in range(self.inst_cnt):
215              res = await self.dut.push_result.call(sim)
216              expected = self.results[res["rob_id"]]
217              assert res["rp_dst"] == expected["rp_dst"]
218              assert res["exception"] == expected["exception"]
219              assert res["result"] == expected["result"]
220  
221              await self.random_wait_geom(sim, 0.9)
222  
223      def test_randomized(self):
224          with self.run_simulation(self.dut, max_cycles=700) as sim:
225              sim.add_testbench(self.issue_process)
226              sim.add_testbench(self.accept_process)