cocotb.py
1 from decimal import Decimal 2 import inspect 3 import re 4 import os 5 from typing import Any 6 from collections.abc import Coroutine 7 from dataclasses import dataclass 8 9 import cocotb 10 from cocotb.clock import Clock, Timer 11 from cocotb.handle import ModifiableObject 12 from cocotb.triggers import FallingEdge, Event, RisingEdge, with_timeout 13 from cocotb_bus.bus import Bus 14 from cocotb.result import SimTimeoutError 15 16 from .memory import * 17 from .common import SimulationBackend, SimulationExecutionResult 18 19 from transactron.profiler import CycleProfile, MethodSamples, Profile, ProfileSamples, TransactionSamples 20 from transactron.utils.gen import GenerationInfo 21 22 23 @dataclass 24 class WishboneMasterSignals: 25 adr: Any = 0 26 we: Any = 0 27 sel: Any = 0 28 dat_w: Any = 0 29 30 31 @dataclass 32 class WishboneSlaveSignals: 33 dat_r: Any = 0 34 ack: Any = 0 35 err: Any = 0 36 rty: Any = 0 37 38 39 class WishboneBus(Bus): 40 _signals = ["cyc", "stb", "we", "adr", "dat_r", "dat_w", "ack"] 41 _optional_signals = ["sel", "err", "rty"] 42 43 cyc: ModifiableObject 44 stb: ModifiableObject 45 we: ModifiableObject 46 adr: ModifiableObject 47 dat_r: ModifiableObject 48 dat_w: ModifiableObject 49 ack: ModifiableObject 50 sel: ModifiableObject 51 err: ModifiableObject 52 rty: ModifiableObject 53 54 def __init__(self, entity, name): 55 # case_insensitive is a workaround for cocotb_bus/verilator problem 56 # see https://github.com/cocotb/cocotb/issues/3259 57 super().__init__( 58 entity, name, self._signals, self._optional_signals, bus_separator="__", case_insensitive=False 59 ) 60 61 62 class WishboneSlave: 63 def __init__( 64 self, entity, name: str, clock, model: CoreMemoryModel, is_instr_bus: bool, word_bits: int = 2, delay: int = 0 65 ): 66 self.entity = entity 67 self.name = name 68 self.clock = clock 69 self.model = model 70 self.is_instr_bus = is_instr_bus 71 self.word_size = 2**word_bits 72 self.word_bits = word_bits 73 self.delay = delay 74 self.bus = WishboneBus(entity, name) 75 self.bus.drive(WishboneSlaveSignals()) 76 77 async def start(self): 78 clock_edge_event = FallingEdge(self.clock) 79 80 while True: 81 while not (self.bus.stb.value and self.bus.cyc.value): 82 await clock_edge_event # type: ignore 83 84 sig_m = WishboneMasterSignals() 85 self.bus.sample(sig_m) 86 87 addr = sig_m.adr << self.word_bits 88 89 sig_s = WishboneSlaveSignals() 90 if sig_m.we: 91 resp = self.model.write( 92 WriteRequest( 93 addr=addr, 94 data=sig_m.dat_w, 95 byte_count=self.word_size, 96 byte_sel=sig_m.sel, 97 ) 98 ) 99 else: 100 resp = self.model.read( 101 ReadRequest( 102 addr=addr, 103 byte_count=self.word_size, 104 byte_sel=sig_m.sel, 105 exec=self.is_instr_bus, 106 ) 107 ) 108 sig_s.dat_r = resp.data 109 110 match resp.status: 111 case ReplyStatus.OK: 112 sig_s.ack = 1 113 case ReplyStatus.ERROR: 114 if not self.bus.err: 115 raise ValueError("Bus doesn't support err") 116 sig_s.err = 1 117 case ReplyStatus.RETRY: 118 if not self.bus.rty: 119 raise ValueError("Bus doesn't support rty") 120 sig_s.rty = 1 121 122 for _ in range(self.delay): 123 await clock_edge_event # type: ignore 124 125 self.bus.drive(sig_s) 126 await clock_edge_event # type: ignore 127 self.bus.drive(WishboneSlaveSignals()) 128 129 130 class CocotbSimulation(SimulationBackend): 131 def __init__(self, dut): 132 self.dut = dut 133 self.finish_event = Event() 134 135 try: 136 gen_info_path = os.environ["_COREBLOCKS_GEN_INFO"] 137 except KeyError: 138 raise RuntimeError("No core generation info provided") 139 140 self.gen_info = GenerationInfo.decode(gen_info_path) 141 142 self.log_level = os.environ["__TRANSACTRON_LOG_LEVEL"] 143 self.log_filter = os.environ["__TRANSACTRON_LOG_FILTER"] 144 145 cocotb.logging.getLogger().setLevel(self.log_level) 146 147 def get_cocotb_handle(self, path_components: list[str]) -> ModifiableObject: 148 obj = self.dut 149 # Skip the first component, as it is already referenced in "self.dut" 150 for component in path_components[1:]: 151 try: 152 # As the component may start with '_' character, we need to use '_id' 153 # function instead of 'getattr' - this is required by cocotb. 154 obj = obj._id(component, extended=False) 155 except AttributeError: 156 # Try with escaped or unescaped name 157 if component[0] != "\\" and component[-1] != " ": 158 obj = obj._id("\\" + component + " ", extended=False) 159 elif component[0] == "\\": 160 obj = obj._id(component[1:], extended=False) 161 else: 162 raise 163 164 return obj 165 166 async def profile_handler(self, clock, profile: Profile): 167 clock_edge_event = RisingEdge(clock) 168 169 while True: 170 samples = ProfileSamples() 171 172 for transaction_id, location in self.gen_info.transaction_signals_location.items(): 173 request_val = self.get_cocotb_handle(location.ready) 174 runnable_val = self.get_cocotb_handle(location.runnable) 175 grant_val = self.get_cocotb_handle(location.run) 176 samples.transactions[transaction_id] = TransactionSamples( 177 bool(request_val.value), bool(runnable_val.value), bool(grant_val.value) 178 ) 179 180 for method_id, location in self.gen_info.method_signals_location.items(): 181 run_val = self.get_cocotb_handle(location.run) 182 samples.methods[method_id] = MethodSamples(bool(run_val.value)) 183 184 cprof = CycleProfile.make(samples, self.gen_info.profile_data) 185 profile.cycles.append(cprof) 186 187 await clock_edge_event # type: ignore 188 189 async def logging_handler(self, clock): 190 clock_edge_event = FallingEdge(clock) 191 192 log_level = cocotb.logging.getLogger().level 193 194 logs = [ 195 (rec, self.get_cocotb_handle(rec.trigger_location)) 196 for rec in self.gen_info.logs 197 if rec.level >= log_level and re.search(self.log_filter, rec.logger_name) 198 ] 199 200 while True: 201 for rec, trigger_handle in logs: 202 if not trigger_handle.value: 203 continue 204 205 values: list[int] = [] 206 for field in rec.fields_location: 207 values.append(int(self.get_cocotb_handle(field).value)) 208 209 formatted_msg = rec.format(*values) 210 211 cocotb_log = cocotb.logging.getLogger(rec.logger_name) 212 213 cocotb_log.log( 214 rec.level, 215 "%s:%d] %s", 216 rec.location[0], 217 rec.location[1], 218 formatted_msg, 219 ) 220 221 if rec.level >= cocotb.logging.ERROR: 222 assert False, f"Assertion failed at {rec.location[0], rec.location[1]}: {formatted_msg}" 223 224 await clock_edge_event # type: ignore 225 226 async def run(self, mem_model: CoreMemoryModel, timeout_cycles: int = 5000) -> SimulationExecutionResult: 227 clk = Clock(self.dut.clk, 1, "ns") 228 cocotb.start_soon(clk.start()) 229 230 self.dut.rst.value = 1 231 await Timer(Decimal(1), "ns") 232 self.dut.rst.value = 0 233 234 instr_wb = WishboneSlave(self.dut, "wb_instr", self.dut.clk, mem_model, is_instr_bus=True) 235 cocotb.start_soon(instr_wb.start()) 236 237 data_wb = WishboneSlave(self.dut, "wb_data", self.dut.clk, mem_model, is_instr_bus=False) 238 cocotb.start_soon(data_wb.start()) 239 240 profile = None 241 if "__TRANSACTRON_PROFILE" in os.environ: 242 profile = Profile() 243 profile.transactions_and_methods = self.gen_info.profile_data.transactions_and_methods 244 cocotb.start_soon(self.profile_handler(self.dut.clk, profile)) 245 246 cocotb.start_soon(self.logging_handler(self.dut.clk)) 247 248 success = True 249 try: 250 await with_timeout(self.finish_event.wait(), timeout_cycles, "ns") 251 except SimTimeoutError: 252 success = False 253 254 result = SimulationExecutionResult(success) 255 256 result.profile = profile 257 258 for metric_name, metric_loc in self.gen_info.metrics_location.items(): 259 result.metric_values[metric_name] = {} 260 for reg_name, reg_loc in metric_loc.regs.items(): 261 value = int(self.get_cocotb_handle(reg_loc)) 262 result.metric_values[metric_name][reg_name] = value 263 cocotb.logging.info(f"Metric {metric_name}/{reg_name}={value}") 264 265 return result 266 267 def stop(self): 268 self.finish_event.set() 269 270 271 def _create_test(function, name, mod, *args, **kwargs): 272 async def _my_test(dut): 273 await function(dut, *args, **kwargs) 274 275 _my_test.__name__ = name 276 _my_test.__qualname__ = name 277 _my_test.__module__ = mod.__name__ 278 279 return cocotb.test()(_my_test) # type: ignore 280 281 282 def generate_tests(test_function: Callable[[Any, Any], Coroutine[Any, Any, None]], test_names: list[str]): 283 frm = inspect.stack()[1] 284 mod = inspect.getmodule(frm[0]) 285 286 for test_name in test_names: 287 setattr(mod, test_name, _create_test(test_function, test_name, mod, test_name))