_utils.py
 1  import importlib
 2  import inspect
 3  import runpy
 4  import sys
 5  from contextlib import asynccontextmanager
 6  from functools import partial
 7  from types import ModuleType
 8  from typing import AsyncContextManager, AsyncIterator, Callable, TypeVar, cast
 9  from unittest.mock import MagicMock
10  
11  
# Type variable tying the ``value`` passed to mock_asynccontextmanager to the
# type yielded by the async context manager it builds.
T = TypeVar("T")
13  
14  
def mock_list(size: int) -> list[MagicMock]:
    """Return a new list holding *size* independent ``MagicMock`` objects.

    A zero or negative *size* produces an empty list.
    """
    mocks: list[MagicMock] = []
    for _ in range(size):
        mocks.append(MagicMock())
    return mocks
17  
18  
19  def mock_dict(size: int, string_keys: bool = False) -> dict[MagicMock | str, MagicMock]:
20      return {(str(MagicMock()) if string_keys else MagicMock()): MagicMock() for _ in range(size)}
21  
22  
def reload_module(module: ModuleType) -> ModuleType:
    """Re-execute an already-imported *module* in place via ``importlib.reload``.

    Returns:
        Whatever ``importlib.reload`` hands back, which is normally *module*
        itself with its top-level code executed again.
    """
    reloaded = importlib.reload(module)
    return reloaded
25  
26  
27  def import_module(name: str | ModuleType) -> ModuleType:
28      if isinstance(name, ModuleType):
29          return import_module(name.__name__)
30  
31      old_module = sys.modules.pop(name, None)
32      new_module = importlib.import_module(name)
33      if old_module:
34          sys.modules[name] = old_module
35      return new_module
36  
37  
def run_module(module: ModuleType) -> None:
    """Execute the source file behind *module* as a top-level script.

    The file is located with ``inspect.getfile`` and run through
    ``runpy.run_path`` under the name ``"__main__"``. Any ``if __name__ ==
    "__main__":`` section therefore runs. The resulting globals are discarded.
    """
    script_path = inspect.getfile(module)
    runpy.run_path(script_path, init_globals={}, run_name="__main__")
40  
41  
def mock_call_assertions(n: int) -> tuple[list[Callable[[], None]], Callable[[], None]]:
    """Create *n* ordered no-argument callbacks and a checker for their call order.

    Calling the i-th callback records ``i`` in a shared list. The returned
    checker asserts that the recorded sequence is exactly ``0, 1, ..., n - 1``:
    every callback ran once, in list order, and nothing else was recorded.
    """
    recorded: list[int] = []
    ordered_callbacks = [
        cast(Callable[[], None], partial(recorded.append, index)) for index in range(n)
    ]

    def assert_calls() -> None:
        assert recorded == list(range(n))

    return ordered_callbacks, assert_calls
50      return callbacks, assert_calls
51  
52  
def mock_asynccontextmanager(
    n: int, value: T
) -> tuple[Callable[[], AsyncContextManager[T]], list[Callable[[], None]], Callable[[], None]]:
    """Build an async context-manager factory instrumented with ordered callbacks.

    Returns ``(factory, callbacks, assert_calls)``:

    * ``factory()`` gives an async context manager. Entering it fires an
      internal "enter" callback and yields *value*. A normal, non-raising exit
      fires an internal "exit" callback.
    * ``callbacks`` are the *n* middle callbacks, presumably meant to be
      invoked from inside the ``async with`` body, in order.
    * ``assert_calls()`` checks that enter, the *n* callbacks, and exit were
      each called once, in that order.
    """
    every_callback, assert_calls = mock_call_assertions(n + 2)
    enter_callback, *callbacks, exit_callback = every_callback

    @asynccontextmanager
    async def context_manager() -> AsyncIterator[T]:
        enter_callback()
        yield value
        # Deliberately not in ``finally``: if the body raises, exit is skipped.
        exit_callback()

    return context_manager, callbacks, assert_calls
63      return asynccontextmanager(context_manager), callbacks, assert_calls