/ test / regression / cocotb.py
cocotb.py
  1  from decimal import Decimal
  2  import inspect
  3  import re
  4  import os
  5  from typing import Any
  6  from collections.abc import Coroutine
  7  from dataclasses import dataclass
  8  
  9  import cocotb
 10  from cocotb.clock import Clock, Timer
 11  from cocotb.handle import ModifiableObject
 12  from cocotb.triggers import FallingEdge, Event, RisingEdge, with_timeout
 13  from cocotb_bus.bus import Bus
 14  from cocotb.result import SimTimeoutError
 15  
 16  from .memory import *
 17  from .common import SimulationBackend, SimulationExecutionResult
 18  
 19  from transactron.profiler import CycleProfile, MethodSamples, Profile, ProfileSamples, TransactionSamples
 20  from transactron.utils.gen import GenerationInfo
 21  
 22  
 23  @dataclass
 24  class WishboneMasterSignals:
 25      adr: Any = 0
 26      we: Any = 0
 27      sel: Any = 0
 28      dat_w: Any = 0
 29  
 30  
 31  @dataclass
 32  class WishboneSlaveSignals:
 33      dat_r: Any = 0
 34      ack: Any = 0
 35      err: Any = 0
 36      rty: Any = 0
 37  
 38  
 39  class WishboneBus(Bus):
 40      _signals = ["cyc", "stb", "we", "adr", "dat_r", "dat_w", "ack"]
 41      _optional_signals = ["sel", "err", "rty"]
 42  
 43      cyc: ModifiableObject
 44      stb: ModifiableObject
 45      we: ModifiableObject
 46      adr: ModifiableObject
 47      dat_r: ModifiableObject
 48      dat_w: ModifiableObject
 49      ack: ModifiableObject
 50      sel: ModifiableObject
 51      err: ModifiableObject
 52      rty: ModifiableObject
 53  
 54      def __init__(self, entity, name):
 55          # case_insensitive is a workaround for cocotb_bus/verilator problem
 56          # see https://github.com/cocotb/cocotb/issues/3259
 57          super().__init__(
 58              entity, name, self._signals, self._optional_signals, bus_separator="__", case_insensitive=False
 59          )
 60  
 61  
 62  class WishboneSlave:
 63      def __init__(
 64          self, entity, name: str, clock, model: CoreMemoryModel, is_instr_bus: bool, word_bits: int = 2, delay: int = 0
 65      ):
 66          self.entity = entity
 67          self.name = name
 68          self.clock = clock
 69          self.model = model
 70          self.is_instr_bus = is_instr_bus
 71          self.word_size = 2**word_bits
 72          self.word_bits = word_bits
 73          self.delay = delay
 74          self.bus = WishboneBus(entity, name)
 75          self.bus.drive(WishboneSlaveSignals())
 76  
 77      async def start(self):
 78          clock_edge_event = FallingEdge(self.clock)
 79  
 80          while True:
 81              while not (self.bus.stb.value and self.bus.cyc.value):
 82                  await clock_edge_event  # type: ignore
 83  
 84              sig_m = WishboneMasterSignals()
 85              self.bus.sample(sig_m)
 86  
 87              addr = sig_m.adr << self.word_bits
 88  
 89              sig_s = WishboneSlaveSignals()
 90              if sig_m.we:
 91                  resp = self.model.write(
 92                      WriteRequest(
 93                          addr=addr,
 94                          data=sig_m.dat_w,
 95                          byte_count=self.word_size,
 96                          byte_sel=sig_m.sel,
 97                      )
 98                  )
 99              else:
100                  resp = self.model.read(
101                      ReadRequest(
102                          addr=addr,
103                          byte_count=self.word_size,
104                          byte_sel=sig_m.sel,
105                          exec=self.is_instr_bus,
106                      )
107                  )
108                  sig_s.dat_r = resp.data
109  
110              match resp.status:
111                  case ReplyStatus.OK:
112                      sig_s.ack = 1
113                  case ReplyStatus.ERROR:
114                      if not self.bus.err:
115                          raise ValueError("Bus doesn't support err")
116                      sig_s.err = 1
117                  case ReplyStatus.RETRY:
118                      if not self.bus.rty:
119                          raise ValueError("Bus doesn't support rty")
120                      sig_s.rty = 1
121  
122              for _ in range(self.delay):
123                  await clock_edge_event  # type: ignore
124  
125              self.bus.drive(sig_s)
126              await clock_edge_event  # type: ignore
127              self.bus.drive(WishboneSlaveSignals())
128  
129  
class CocotbSimulation(SimulationBackend):
    """Simulation backend which runs the design under cocotb.

    Configuration is read from environment variables:

    - ``_COREBLOCKS_GEN_INFO``: path decoded with ``GenerationInfo.decode`` (required),
    - ``__TRANSACTRON_LOG_LEVEL``: level set on the root logger (required),
    - ``__TRANSACTRON_LOG_FILTER``: regular expression matched against logger names (required),
    - ``__TRANSACTRON_PROFILE``: when present, a profile is collected during ``run``.
    """

    def __init__(self, dut):
        """
        Parameters
        ----------
        dut
            cocotb handle of the simulated design.

        Raises
        ------
        RuntimeError
            If ``_COREBLOCKS_GEN_INFO`` is not set in the environment.
        KeyError
            If ``__TRANSACTRON_LOG_LEVEL`` or ``__TRANSACTRON_LOG_FILTER`` is not set.
        """
        self.dut = dut
        self.finish_event = Event()

        gen_info_path = os.environ.get("_COREBLOCKS_GEN_INFO")
        if gen_info_path is None:
            raise RuntimeError("No core generation info provided")

        self.gen_info = GenerationInfo.decode(gen_info_path)

        self.log_level = os.environ["__TRANSACTRON_LOG_LEVEL"]
        self.log_filter = os.environ["__TRANSACTRON_LOG_FILTER"]

        cocotb.logging.getLogger().setLevel(self.log_level)

    def get_cocotb_handle(self, path_components: list[str]) -> ModifiableObject:
        """Resolve a hierarchical path to a cocotb handle.

        Parameters
        ----------
        path_components: list[str]
            Path of the object. The first component is skipped, because it is
            already referenced by ``self.dut``.

        Raises
        ------
        AttributeError
            If a component cannot be found under any of the tried names.
        """
        obj = self.dut
        # Skip the first component, as it is already referenced in "self.dut"
        for component in path_components[1:]:
            try:
                # As the component may start with '_' character, we need to use '_id'
                # function instead of 'getattr' - this is required by cocotb.
                obj = obj._id(component, extended=False)
            except AttributeError:
                # Try with escaped or unescaped name
                if component[0] != "\\" and component[-1] != " ":
                    obj = obj._id("\\" + component + " ", extended=False)
                elif component[0] == "\\":
                    # NOTE(review): only the leading backslash is removed here; a trailing
                    # space, if present in the component, is kept - confirm this is intended.
                    obj = obj._id(component[1:], extended=False)
                else:
                    raise

        return obj

    async def profile_handler(self, clock, profile: Profile):
        """Append one ``CycleProfile`` to ``profile.cycles`` per rising edge of ``clock``.

        The first sample is taken immediately, before waiting for any clock edge.
        """
        clock_edge_event = RisingEdge(clock)

        # Resolve the handles once instead of on every clock cycle; the logging handler
        # treats its trigger handles the same way.
        transaction_handles = [
            (
                transaction_id,
                self.get_cocotb_handle(location.ready),
                self.get_cocotb_handle(location.runnable),
                self.get_cocotb_handle(location.run),
            )
            for transaction_id, location in self.gen_info.transaction_signals_location.items()
        ]
        method_handles = [
            (method_id, self.get_cocotb_handle(location.run))
            for method_id, location in self.gen_info.method_signals_location.items()
        ]

        while True:
            samples = ProfileSamples()

            for transaction_id, request_val, runnable_val, grant_val in transaction_handles:
                samples.transactions[transaction_id] = TransactionSamples(
                    bool(request_val.value), bool(runnable_val.value), bool(grant_val.value)
                )

            for method_id, run_val in method_handles:
                samples.methods[method_id] = MethodSamples(bool(run_val.value))

            cprof = CycleProfile.make(samples, self.gen_info.profile_data)
            profile.cycles.append(cprof)

            await clock_edge_event  # type: ignore

    async def logging_handler(self, clock):
        """Emit the log records of the design whose trigger signal is set.

        Only records with a level not lower than the root logger level and with
        a logger name matching ``self.log_filter`` are observed. Triggers are
        checked immediately and then after every falling edge of ``clock``.

        Raises
        ------
        AssertionError
            If a triggered record has level ``ERROR`` or higher.
        """
        clock_edge_event = FallingEdge(clock)

        log_level = cocotb.logging.getLogger().level
        log_filter = re.compile(self.log_filter)

        logs = [
            (rec, self.get_cocotb_handle(rec.trigger_location))
            for rec in self.gen_info.logs
            if rec.level >= log_level and log_filter.search(rec.logger_name)
        ]

        while True:
            for rec, trigger_handle in logs:
                if not trigger_handle.value:
                    continue

                values: list[int] = [int(self.get_cocotb_handle(field).value) for field in rec.fields_location]

                formatted_msg = rec.format(*values)

                cocotb_log = cocotb.logging.getLogger(rec.logger_name)

                cocotb_log.log(
                    rec.level,
                    "%s:%d] %s",
                    rec.location[0],
                    rec.location[1],
                    formatted_msg,
                )

                if rec.level >= cocotb.logging.ERROR:
                    # Raised explicitly, as an "assert" statement would be removed when
                    # Python runs with optimizations enabled (-O).
                    raise AssertionError(
                        f"Assertion failed at {rec.location[0], rec.location[1]}: {formatted_msg}"
                    )

            await clock_edge_event  # type: ignore

    async def run(self, mem_model: CoreMemoryModel, timeout_cycles: int = 5000) -> SimulationExecutionResult:
        """Run the simulation until ``stop`` is called or the timeout expires.

        Parameters
        ----------
        mem_model: CoreMemoryModel
            Memory model served on both the ``wb_instr`` and ``wb_data`` buses.
        timeout_cycles: int
            Timeout in nanoseconds; the clock is started with a period of 1 ns.

        Returns
        -------
        SimulationExecutionResult
            Result constructed with ``False`` on timeout and ``True`` otherwise; it
            also carries the profile (or ``None``) and the values of the metric registers.
        """
        clk = Clock(self.dut.clk, 1, "ns")
        cocotb.start_soon(clk.start())

        # Hold the reset for 1 ns.
        self.dut.rst.value = 1
        await Timer(Decimal(1), "ns")
        self.dut.rst.value = 0

        instr_wb = WishboneSlave(self.dut, "wb_instr", self.dut.clk, mem_model, is_instr_bus=True)
        cocotb.start_soon(instr_wb.start())

        data_wb = WishboneSlave(self.dut, "wb_data", self.dut.clk, mem_model, is_instr_bus=False)
        cocotb.start_soon(data_wb.start())

        profile = None
        if "__TRANSACTRON_PROFILE" in os.environ:
            profile = Profile()
            profile.transactions_and_methods = self.gen_info.profile_data.transactions_and_methods
            cocotb.start_soon(self.profile_handler(self.dut.clk, profile))

        cocotb.start_soon(self.logging_handler(self.dut.clk))

        success = True
        try:
            await with_timeout(self.finish_event.wait(), timeout_cycles, "ns")
        except SimTimeoutError:
            success = False

        result = SimulationExecutionResult(success)

        result.profile = profile

        for metric_name, metric_loc in self.gen_info.metrics_location.items():
            result.metric_values[metric_name] = {}
            for reg_name, reg_loc in metric_loc.regs.items():
                # Read through ".value", like every other signal read in this class.
                value = int(self.get_cocotb_handle(reg_loc).value)
                result.metric_values[metric_name][reg_name] = value
                cocotb.logging.info(f"Metric {metric_name}/{reg_name}={value}")

        return result

    def stop(self):
        """Set ``finish_event``, which ``run`` waits for."""
        self.finish_event.set()
269  
270  
def _create_test(function, name, mod, *args, **kwargs):
    """Wrap ``function`` in a cocotb test called ``name``.

    The wrapper takes ``dut`` and awaits ``function(dut, *args, **kwargs)``.
    Its ``__name__`` and ``__qualname__`` are set to ``name`` and its
    ``__module__`` to ``mod.__name__`` before it is decorated with ``cocotb.test()``.
    """

    async def _my_test(dut):
        await function(dut, *args, **kwargs)

    for attribute in ("__name__", "__qualname__"):
        setattr(_my_test, attribute, name)
    _my_test.__module__ = mod.__name__

    decorator = cocotb.test()
    return decorator(_my_test)  # type: ignore
280  
281  
def generate_tests(test_function: Callable[[Any, Any], Coroutine[Any, Any, None]], test_names: list[str]):
    """Create one test per name and define it in the module of the caller.

    For every entry of ``test_names`` a test is built with ``_create_test``; the
    name is used both as the test name and as the extra argument passed to
    ``test_function``. The test is then set as an attribute of the calling module.
    """
    # Entry 0 of the stack is this function; entry 1 is its caller.
    caller_frame = inspect.stack()[1].frame
    caller_module = inspect.getmodule(caller_frame)

    for current_name in test_names:
        generated_test = _create_test(test_function, current_name, caller_module, current_name)
        setattr(caller_module, current_name, generated_test)