# _utils.py
1 import importlib 2 import inspect 3 import runpy 4 import sys 5 from contextlib import asynccontextmanager 6 from functools import partial 7 from types import ModuleType 8 from typing import AsyncContextManager, AsyncIterator, Callable, TypeVar, cast 9 from unittest.mock import MagicMock 10 11 12 T = TypeVar("T") 13 14 15 def mock_list(size: int) -> list[MagicMock]: 16 return [MagicMock() for _ in range(size)] 17 18 19 def mock_dict(size: int, string_keys: bool = False) -> dict[MagicMock | str, MagicMock]: 20 return {(str(MagicMock()) if string_keys else MagicMock()): MagicMock() for _ in range(size)} 21 22 23 def reload_module(module: ModuleType) -> ModuleType: 24 return importlib.reload(module) 25 26 27 def import_module(name: str | ModuleType) -> ModuleType: 28 if isinstance(name, ModuleType): 29 return import_module(name.__name__) 30 31 old_module = sys.modules.pop(name, None) 32 new_module = importlib.import_module(name) 33 if old_module: 34 sys.modules[name] = old_module 35 return new_module 36 37 38 def run_module(module: ModuleType) -> None: 39 runpy.run_path(inspect.getfile(module), {}, "__main__") 40 41 42 def mock_call_assertions(n: int) -> tuple[list[Callable[[], None]], Callable[[], None]]: 43 events: list[int] = [] 44 45 def assert_calls() -> None: 46 assert events == [*range(n)] 47 48 callbacks = [cast(Callable[[], None], partial(events.append, i)) for i in range(n)] 49 50 return callbacks, assert_calls 51 52 53 def mock_asynccontextmanager( 54 n: int, value: T 55 ) -> tuple[Callable[[], AsyncContextManager[T]], list[Callable[[], None]], Callable[[], None]]: 56 [enter_callback, *callbacks, exit_callback], assert_calls = mock_call_assertions(n + 2) 57 58 async def context_manager() -> AsyncIterator[T]: 59 enter_callback() 60 yield value 61 exit_callback() 62 63 return asynccontextmanager(context_manager), callbacks, assert_calls