# fastapi-template-users/tests/test_database.py
  1  from contextvars import ContextVar
  2  from typing import Any, AsyncIterator
  3  from unittest.mock import AsyncMock, MagicMock, call
  4  
  5  import pytest
  6  from _pytest.monkeypatch import MonkeyPatch
  7  from pytest_mock import MockerFixture
  8  from sqlalchemy.orm import DeclarativeMeta, registry
  9  
 10  from ._utils import import_module, mock_asynccontextmanager, mock_dict, mock_list
 11  from api import database
 12  from api.settings import settings
 13  
 14  
 15  @pytest.mark.parametrize(
 16      "entity,args,expected",
 17      [
 18          ("my entity", [], call.sa_select("my entity")),
 19          (
 20              "asdf",
 21              ["col1", "col2"],
 22              call.sa_select("asdf").options(call.selectinload("col1"), call.selectinload("col2")),
 23          ),
 24          (
 25              "FooBar42",
 26              [(1, "foo"), ["bar", 2]],
 27              call.sa_select("FooBar42").options(
 28                  call.selectinload(1).selectinload("foo"), call.selectinload("bar").selectinload(2)
 29              ),
 30          ),
 31          (
 32              "qwertz",
 33              ["xyz", ["A", "B", "C", 42], "foo", (42, [1337])],
 34              call.sa_select("qwertz").options(
 35                  call.selectinload("xyz"),
 36                  call.selectinload("A").selectinload("B").selectinload("C").selectinload(42),
 37                  call.selectinload("foo"),
 38                  call.selectinload(42).selectinload([1337]),
 39              ),
 40          ),
 41      ],
 42  )
 43  async def test__select(mocker: MockerFixture, entity: Any, args: list[Any], expected: Any) -> None:
 44      sa_select_patch = mocker.patch("api.database.database.sa_select")
 45      selectinload_patch = mocker.patch("api.database.database.selectinload")
 46      sa_select_patch.side_effect = call.sa_select
 47      selectinload_patch.side_effect = call.selectinload
 48  
 49      assert database.select(entity, *args) == expected
 50  
 51  
 52  def test__filter_by(mocker: MockerFixture) -> None:
 53      select_patch = mocker.patch("api.database.database.select")
 54  
 55      cls = MagicMock()
 56      args = mock_list(5)
 57      kwargs = mock_dict(5, string_keys=True)
 58  
 59      result = database.filter_by(cls, *args, **kwargs)
 60  
 61      select_patch.assert_called_once_with(cls, *args)
 62      select_patch().filter_by.assert_called_once_with(**kwargs)
 63      assert result == select_patch().filter_by()
 64  
 65  
 66  async def test__exists(mocker: MockerFixture) -> None:
 67      sa_exists_patch = mocker.patch("api.database.database.sa_exists")
 68  
 69      args = mock_list(5)
 70      kwargs = mock_dict(5, True)
 71  
 72      result = database.database.exists(*args, **kwargs)
 73  
 74      sa_exists_patch.assert_called_once_with(*args, **kwargs)
 75      assert result == sa_exists_patch()
 76  
 77  
 78  async def test__delete(mocker: MockerFixture) -> None:
 79      sa_delete_patch = mocker.patch("api.database.database.sa_delete")
 80  
 81      table = MagicMock()
 82  
 83      result = database.database.delete(table)
 84  
 85      sa_delete_patch.assert_called_once_with(table)
 86      assert result == sa_delete_patch()
 87  
 88  
 89  async def test__base() -> None:
 90      assert isinstance(database.Base, DeclarativeMeta)
 91      assert database.Base.__abstract__ is True
 92      assert isinstance(database.Base.registry, registry)
 93      assert database.Base.registry.metadata == database.Base.metadata
 94  
 95  
 96  async def test__base_constructor() -> None:
 97      base = MagicMock()
 98      kwargs = mock_dict(5, string_keys=True)
 99  
100      database.Base.__init__(base, **kwargs)
101  
102      base.registry.constructor.assert_called_once_with(base, **kwargs)
103  
104  
async def test__constructor(mocker: MockerFixture) -> None:
    """``DB(url, **kwargs)`` creates the async engine and two context variables holding None."""
    engine_factory = mocker.patch("api.database.database.create_async_engine")
    url = MagicMock()
    engine_kwargs = mock_dict(5, string_keys=True)

    db = database.database.DB(url, **engine_kwargs)

    engine_factory.assert_called_once_with(url, **engine_kwargs)
    assert db.engine == engine_factory.return_value

    # Both per-context slots must be named ContextVars whose current value is None.
    expected_vars = ((db._session, "session"), (db._close_event, "close_event"))
    for var, expected_name in expected_vars:
        assert isinstance(var, ContextVar)
        assert var.name == expected_name
        assert var.get() is None
123  
124  
async def test__create_tables(mocker: MockerFixture) -> None:
    """``create_tables`` hands ``Base.metadata.create_all`` to ``run_sync`` inside ``engine.begin()``."""
    base_mock = mocker.patch("api.database.database.Base")
    db = MagicMock()
    conn = MagicMock()

    async def fake_run_sync(fn: Any) -> None:
        assert fn == base_mock.metadata.create_all
        # Callback from mock_asynccontextmanager; presumably lets assert_calls() verify this
        # ran while the context was open (see tests/_utils).
        inside_context()

    conn.run_sync = fake_run_sync
    db.engine.begin, [inside_context], assert_calls = mock_asynccontextmanager(1, conn)

    await database.database.DB.create_tables(db)

    assert_calls()
141  
142  
async def test__add() -> None:
    """``DB.add`` adds the object to the current session and returns that same object."""
    db, obj = MagicMock(), MagicMock()

    returned = await database.database.DB.add(db, obj)

    db.session.add.assert_called_once_with(obj)
    assert obj == returned
151  
152  
async def test__db__delete() -> None:
    """``DB.delete`` deletes the object through the session and returns that same object."""
    db, obj = AsyncMock(), MagicMock()

    returned = await database.database.DB.delete(db, obj)

    db.session.delete.assert_called_once_with(obj)
    assert obj == returned
161  
162  
async def test__exec() -> None:
    """``DB.exec`` awaits ``session.execute(statement)`` and returns its result."""
    db = AsyncMock()
    stmt = MagicMock()

    returned = await database.database.DB.exec(db, stmt)

    execute = db.session.execute
    execute.assert_called_once_with(stmt)
    assert returned == execute.return_value
171  
172  
async def test__stream() -> None:
    """``DB.stream`` awaits ``session.stream(statement)`` and returns ``.scalars()`` of it."""
    db = AsyncMock()
    stmt = MagicMock()
    streamed = db.session.stream.return_value = MagicMock()

    returned = await database.database.DB.stream(db, stmt)

    db.session.stream.assert_called_once_with(stmt)
    streamed.scalars.assert_called_once_with()
    assert returned == streamed.scalars.return_value
183  
184  
async def test__all() -> None:
    """``DB.all`` drains the async iterator returned by ``DB.stream`` into a list."""
    db = AsyncMock()
    stmt = MagicMock()
    rows = mock_list(5)

    async def row_stream() -> AsyncIterator[Any]:
        for row in rows:
            yield row

    db.stream.return_value = row_stream()

    returned = await database.database.DB.all(db, stmt)

    db.stream.assert_called_once_with(stmt)
    assert returned == rows
200  
201  
async def test__first() -> None:
    """``DB.first`` awaits ``DB.exec(statement)`` and returns ``.scalar()`` of the result."""
    db = AsyncMock()
    stmt = MagicMock()
    executed = db.exec.return_value = MagicMock()

    returned = await database.database.DB.first(db, stmt)

    db.exec.assert_called_once_with(stmt)
    executed.scalar.assert_called_once_with()
    assert returned == executed.scalar.return_value
212  
213  
async def test__db__exists(mocker: MockerFixture) -> None:
    """``DB.exists`` builds ``exists(*args, **kwargs).select()`` and evaluates it via ``DB.first``."""
    exists_patch = mocker.patch("api.database.database.exists")

    db = AsyncMock()
    args = mock_list(5)
    # Keyword form matches the other mock_dict() call sites; keys must be str for ** below.
    kwargs = mock_dict(5, string_keys=True)

    result = await database.database.DB.exists(db, *args, **kwargs)

    exists_patch.assert_called_once_with(*args, **kwargs)
    exists_clause = exists_patch.return_value
    exists_clause.select.assert_called_once_with()
    db.first.assert_called_once_with(exists_clause.select.return_value)
    assert result == db.first.return_value
227  
228  
async def test__db__count(mocker: MockerFixture) -> None:
    """``DB.count`` runs ``select(count()).select_from(query.subquery())`` through ``DB.first``."""
    count_mock = mocker.patch("api.database.database.count")
    select_mock = mocker.patch("api.database.database.select")
    db = AsyncMock()
    query = MagicMock()

    returned = await database.database.DB.count(db, query)

    count_mock.assert_called_once_with()
    select_mock.assert_called_once_with(count_mock.return_value)
    query.subquery.assert_called_once_with()
    count_select = select_mock.return_value
    count_select.select_from.assert_called_once_with(query.subquery.return_value)
    db.first.assert_called_once_with(count_select.select_from.return_value)
    assert returned == db.first.return_value
244  
245  
async def test__get(mocker: MockerFixture) -> None:
    """``DB.get`` evaluates ``filter_by(*args, **kwargs)`` via ``DB.first`` and returns the row."""
    filter_by_patch = mocker.patch("api.database.database.filter_by")

    db = AsyncMock()
    args = mock_list(5)
    # Keyword form matches the other mock_dict() call sites; keys must be str for ** below.
    kwargs = mock_dict(5, string_keys=True)

    result = await database.database.DB.get(db, *args, **kwargs)  # type: ignore

    filter_by_patch.assert_called_once_with(*args, **kwargs)
    db.first.assert_called_once_with(filter_by_patch.return_value)
    assert result == db.first.return_value
258  
259  
async def test__commit__no_session() -> None:
    """``DB.commit`` does not commit when ``_session.get()`` returns None."""
    db = MagicMock(session=AsyncMock())
    db._session.get.return_value = None  # no session bound in this context

    await database.database.DB.commit(db)

    db._session.get.assert_called_once_with()
    db.session.commit.assert_not_called()
269  
270  
async def test__commit__with_session() -> None:
    """``DB.commit`` commits the session returned by ``_session.get()``."""
    db = MagicMock()
    session = MagicMock()
    db._session.get.return_value = session
    db.session = session
    session.commit = AsyncMock()

    await database.database.DB.commit(db)

    db._session.get.assert_called_once_with()
    session.commit.assert_called_once_with()
280  
281  
async def test__close__no_session() -> None:
    """``DB.close`` neither closes a session nor sets the close event when no session is bound."""
    db = MagicMock(session=AsyncMock())
    db._session.get.return_value = None  # no session bound in this context

    await database.database.DB.close(db)

    db._session.get.assert_called_once_with()
    db.session.close.assert_not_called()
    db._close_event.get.return_value.set.assert_not_called()
292  
293  
async def test__close__with_session_no_close_event() -> None:
    """With a session but no close event, ``DB.close`` closes the session and looks up the event once."""
    db = MagicMock()
    session = MagicMock()
    db._session.get.return_value = session
    db.session = session
    session.close = AsyncMock()
    db._close_event.get.return_value = None  # no close event stored

    await database.database.DB.close(db)

    db._session.get.assert_called_once_with()
    session.close.assert_called_once_with()
    db._close_event.get.assert_called_once_with()
305  
306  
async def test__close__with_session() -> None:
    """With a session and a close event, ``DB.close`` closes the session and sets the event."""
    db = MagicMock()
    session = MagicMock()
    db._session.get.return_value = session
    db.session = session
    session.close = AsyncMock()

    await database.database.DB.close(db)

    db._session.get.assert_called_once_with()
    session.close.assert_called_once_with()
    close_event_get = db._close_event.get
    close_event_get.assert_called_once_with()
    close_event_get.return_value.set.assert_called_once_with()
318  
319  
async def test__create_session(mocker: MockerFixture) -> None:
    """``create_session`` opens an ``AsyncSession`` on the engine and stores it with a new ``Event``."""
    session_cls = mocker.patch("api.database.database.AsyncSession")
    event_cls = mocker.patch("api.database.database.Event")
    db = MagicMock()

    session = database.database.DB.create_session(db)

    session_cls.assert_called_once_with(db.engine)
    db._session.set.assert_called_with(session_cls.return_value)
    event_cls.assert_called_once_with()
    db._close_event.set.assert_called_with(event_cls.return_value)
    assert session == session_cls.return_value
333  
334  
async def test__session() -> None:
    """The ``session`` property returns ``_session.get()``."""
    db = MagicMock()

    current = database.database.DB.session.fget(db)  # type: ignore

    db._session.get.assert_called_once_with()
    assert current == db._session.get.return_value
342  
343  
async def test__wait_for_close_event__not_set() -> None:
    """``wait_for_close_event`` completes when ``_close_event.get()`` returns None."""
    db = MagicMock()
    db._close_event.get.return_value = None  # nothing to wait on

    await database.database.DB.wait_for_close_event(db)

    db._close_event.get.assert_called_once_with()
351  
352  
async def test__wait_for_close_event() -> None:
    """``wait_for_close_event`` awaits ``.wait()`` on the stored close event."""
    db = MagicMock()
    close_event = MagicMock(wait=AsyncMock())
    db._close_event.get.return_value = close_event

    await database.database.DB.wait_for_close_event(db)

    db._close_event.get.assert_called_once_with()
    close_event.wait.assert_called_once_with()
362  
363  
async def test__get_database(mocker: MockerFixture, monkeypatch: MonkeyPatch) -> None:
    """``get_database`` builds ``DB`` from the settings, always passing ``pool_pre_ping=True``."""
    db_cls = mocker.patch("api.database.database.DB")

    # DB(...) keyword -> settings attribute that is expected to feed it.
    setting_for_kwarg = {
        "url": "database_url",
        "pool_recycle": "pool_recycle",
        "pool_size": "pool_size",
        "max_overflow": "max_overflow",
        "echo": "sql_show_statements",
    }
    expected_kwargs: dict[str, Any] = {}
    for kwarg, setting_name in setting_for_kwarg.items():
        sentinel = MagicMock()
        monkeypatch.setattr(settings, setting_name, sentinel)
        expected_kwargs[kwarg] = sentinel

    db = database.database.get_database()

    db_cls.assert_called_once_with(pool_pre_ping=True, **expected_kwargs)
    assert db == db_cls.return_value
384  
385  
async def test__db_context(mocker: MockerFixture) -> None:
    """``db_context`` creates a session on entry, then commits and closes it on exit."""
    db_mock = mocker.patch("api.database.db")
    db_mock.commit = AsyncMock()

    def close_after_commit() -> None:
        # Runs when close() is awaited: commit must already have happened exactly once.
        db_mock.commit.assert_called_once_with()

    db_mock.close = AsyncMock(side_effect=close_after_commit)

    async with database.db_context():
        db_mock.create_session.assert_called_once_with()

    db_mock.close.assert_called_once_with()
397  
398  
async def test__db_wrapper(mocker: MockerFixture) -> None:
    """``db_wrapper`` runs the wrapped coroutine inside ``db_context`` and keeps its ``__name__``."""
    db_context_patch = mocker.patch("api.database.db_context")
    db_context_patch.side_effect, [func_callback], assert_calls = mock_asynccontextmanager(1, None)

    args = mock_list(5)
    # Keyword form matches the other mock_dict() call sites; keys must be str for ** below.
    kwargs = mock_dict(5, string_keys=True)
    expected = MagicMock()

    @database.db_wrapper
    async def test(*_args: Any, **_kwargs: Any) -> Any:
        assert args == list(_args)
        assert kwargs == _kwargs
        func_callback()
        return expected

    result = await test(*args, **kwargs)

    assert result == expected
    db_context_patch.assert_called_once_with()
    assert_calls()
    # The decorator must preserve the wrapped function's name.
    assert test.__name__ == "test"
420  
421  
async def test__db(mocker: MockerFixture) -> None:
    """Importing ``api.database`` binds the module-level ``db`` to ``get_database()``."""
    get_database_mock = mocker.patch("api.database.database.get_database")

    # NOTE(review): if import_module re-executes api.database in place, ``db`` may stay bound
    # to the mock after this test; confirm tests/_utils.import_module restores the module.
    module = import_module("api.database")

    get_database_mock.assert_called_once_with()
    assert get_database_mock.return_value == module.db