"""Unit tests for the ``api.database`` module (test_database.py).

Collaborators of ``api.database`` (SQLAlchemy helpers, the async engine, the
session/close-event context variables, application settings) are replaced
with mocks, so these tests check how the module wires its dependencies
together rather than talking to a real database.
"""

from contextvars import ContextVar
from typing import Any, AsyncIterator
from unittest.mock import AsyncMock, MagicMock, call

import pytest
from _pytest.monkeypatch import MonkeyPatch
from pytest_mock import MockerFixture
from sqlalchemy.orm import DeclarativeMeta, registry

from ._utils import import_module, mock_asynccontextmanager, mock_dict, mock_list
from api import database
from api.settings import settings


@pytest.mark.parametrize(
    "entity,args,expected",
    [
        ("my entity", [], call.sa_select("my entity")),
        (
            "asdf",
            ["col1", "col2"],
            call.sa_select("asdf").options(call.selectinload("col1"), call.selectinload("col2")),
        ),
        (
            "FooBar42",
            [(1, "foo"), ["bar", 2]],
            call.sa_select("FooBar42").options(
                call.selectinload(1).selectinload("foo"), call.selectinload("bar").selectinload(2)
            ),
        ),
        (
            "qwertz",
            ["xyz", ["A", "B", "C", 42], "foo", (42, [1337])],
            call.sa_select("qwertz").options(
                call.selectinload("xyz"),
                call.selectinload("A").selectinload("B").selectinload("C").selectinload(42),
                call.selectinload("foo"),
                call.selectinload(42).selectinload([1337]),
            ),
        ),
    ],
)
async def test__select(mocker: MockerFixture, entity: Any, args: list[Any], expected: Any) -> None:
    """``select`` builds ``sa_select(entity)`` and adds one ``selectinload`` option per extra arg.

    A scalar arg becomes a single ``selectinload``; a list/tuple arg becomes a chained
    ``selectinload(...).selectinload(...)`` path; with no args no ``.options`` call is made.
    """
    sa_select_patch = mocker.patch("api.database.database.sa_select")
    selectinload_patch = mocker.patch("api.database.database.selectinload")
    # Route both patches through ``call`` so the built expression can be compared structurally.
    sa_select_patch.side_effect = call.sa_select
    selectinload_patch.side_effect = call.selectinload

    assert database.select(entity, *args) == expected


def test__filter_by(mocker: MockerFixture) -> None:
    """``filter_by`` applies ``.filter_by(**kwargs)`` to ``select(cls, *args)``."""
    select_patch = mocker.patch("api.database.database.select")

    cls = MagicMock()
    args = mock_list(5)
    kwargs = mock_dict(5, string_keys=True)

    result = database.filter_by(cls, *args, **kwargs)

    select_patch.assert_called_once_with(cls, *args)
    select_patch().filter_by.assert_called_once_with(**kwargs)
    assert result == select_patch().filter_by()


async def test__exists(mocker: MockerFixture) -> None:
    """``exists`` forwards all args/kwargs to ``sa_exists`` and returns its result."""
    sa_exists_patch = mocker.patch("api.database.database.sa_exists")

    args = mock_list(5)
    kwargs = mock_dict(5, True)

    result = database.database.exists(*args, **kwargs)

    sa_exists_patch.assert_called_once_with(*args, **kwargs)
    assert result == sa_exists_patch()


async def test__delete(mocker: MockerFixture) -> None:
    """``delete`` returns ``sa_delete(table)``."""
    sa_delete_patch = mocker.patch("api.database.database.sa_delete")

    table = MagicMock()

    result = database.database.delete(table)

    sa_delete_patch.assert_called_once_with(table)
    assert result == sa_delete_patch()


async def test__base() -> None:
    """``Base`` is an abstract declarative base whose registry shares its metadata."""
    assert isinstance(database.Base, DeclarativeMeta)
    assert database.Base.__abstract__ is True
    assert isinstance(database.Base.registry, registry)
    assert database.Base.registry.metadata == database.Base.metadata


async def test__base_constructor() -> None:
    """``Base.__init__`` delegates to ``self.registry.constructor``."""
    base = MagicMock()
    kwargs = mock_dict(5, string_keys=True)

    database.Base.__init__(base, **kwargs)

    base.registry.constructor.assert_called_once_with(base, **kwargs)


async def test__constructor(mocker: MockerFixture) -> None:
    """``DB(url, **kwargs)`` creates the async engine and two context vars defaulting to ``None``."""
    create_async_engine_patch = mocker.patch("api.database.database.create_async_engine")

    url = MagicMock()
    kwargs = mock_dict(5, string_keys=True)

    result = database.database.DB(url, **kwargs)

    create_async_engine_patch.assert_called_once_with(url, **kwargs)
    assert result.engine == create_async_engine_patch()

    assert isinstance(result._session, ContextVar)
    assert result._session.name == "session"
    assert result._session.get() is None

    assert isinstance(result._close_event, ContextVar)
    assert result._close_event.name == "close_event"
    assert result._close_event.get() is None


async def test__create_tables(mocker: MockerFixture) -> None:
    """``create_tables`` runs ``Base.metadata.create_all`` via ``conn.run_sync`` inside ``engine.begin()``."""
    base_patch = mocker.patch("api.database.database.Base")

    db = MagicMock()

    async def run_sync(func: Any) -> None:
        assert func == base_patch.metadata.create_all
        # ``func_callback`` is bound below; it is resolved when run_sync is awaited.
        func_callback()

    conn = MagicMock()
    conn.run_sync = run_sync
    # Project helper: presumably returns (context-manager factory yielding ``conn``, callbacks that
    # must run inside the context, checker) — see ._utils.mock_asynccontextmanager to confirm.
    db.engine.begin, [func_callback], assert_calls = mock_asynccontextmanager(1, conn)

    await database.database.DB.create_tables(db)

    assert_calls()


async def test__add() -> None:
    """``DB.add`` adds the object to the session and returns it."""
    db = MagicMock()
    obj = MagicMock()

    result = await database.database.DB.add(db, obj)

    db.session.add.assert_called_once_with(obj)
    assert obj == result


async def test__db__delete() -> None:
    """``DB.delete`` deletes the object through the session and returns it."""
    db = AsyncMock()
    obj = MagicMock()

    result = await database.database.DB.delete(db, obj)

    db.session.delete.assert_called_once_with(obj)
    assert obj == result


async def test__exec() -> None:
    """``DB.exec`` returns the awaited ``session.execute(statement)``."""
    db = AsyncMock()
    statement = MagicMock()

    result = await database.database.DB.exec(db, statement)

    db.session.execute.assert_called_once_with(statement)
    assert result == await db.session.execute()


async def test__stream() -> None:
    """``DB.stream`` returns ``scalars()`` of the awaited ``session.stream(statement)``."""
    db = AsyncMock()
    statement = MagicMock()
    db.session.stream.return_value = MagicMock()

    result = await database.database.DB.stream(db, statement)

    db.session.stream.assert_called_once_with(statement)
    (await db.session.stream()).scalars.assert_called_once_with()
    assert result == (await db.session.stream()).scalars()


async def test__all() -> None:
    """``DB.all`` collects everything yielded by ``DB.stream(statement)`` into a list."""
    db = AsyncMock()
    statement = MagicMock()
    expected = mock_list(5)

    async def async_iterator() -> AsyncIterator[Any]:
        for x in expected:
            yield x

    db.stream.return_value = async_iterator()

    result = await database.database.DB.all(db, statement)

    db.stream.assert_called_once_with(statement)
    assert result == expected


async def test__first() -> None:
    """``DB.first`` returns ``scalar()`` of the awaited ``DB.exec(statement)``."""
    db = AsyncMock()
    statement = MagicMock()
    db.exec.return_value = MagicMock()

    result = await database.database.DB.first(db, statement)

    db.exec.assert_called_once_with(statement)
    (await db.exec()).scalar.assert_called_once_with()
    assert result == (await db.exec()).scalar()


async def test__db__exists(mocker: MockerFixture) -> None:
    """``DB.exists`` evaluates ``exists(*args, **kwargs).select()`` through ``DB.first``."""
    exists_patch = mocker.patch("api.database.database.exists")

    db = AsyncMock()
    args = mock_list(5)
    kwargs = mock_dict(5, True)

    result = await database.database.DB.exists(db, *args, **kwargs)

    exists_patch.assert_called_once_with(*args, **kwargs)
    exists_patch().select.assert_called_once_with()
    db.first.assert_called_once_with(exists_patch().select())
    assert result == await db.first(exists_patch().select())


async def test__db__count(mocker: MockerFixture) -> None:
    """``DB.count`` runs ``select(count()).select_from(arg.subquery())`` through ``DB.first``."""
    count_patch = mocker.patch("api.database.database.count")
    select_patch = mocker.patch("api.database.database.select")

    db = AsyncMock()
    arg = MagicMock()

    result = await database.database.DB.count(db, arg)

    count_patch.assert_called_once_with()
    select_patch.assert_called_once_with(count_patch())
    arg.subquery.assert_called_once_with()
    select_patch().select_from.assert_called_once_with(arg.subquery())
    db.first.assert_called_once_with(select_patch().select_from())
    assert result == await db.first()


async def test__get(mocker: MockerFixture) -> None:
    """``DB.get`` returns ``DB.first(filter_by(*args, **kwargs))``."""
    filter_by_patch = mocker.patch("api.database.database.filter_by")

    db = AsyncMock()
    args = mock_list(5)
    kwargs = mock_dict(5, True)

    result = await database.database.DB.get(db, *args, **kwargs)  # type: ignore

    filter_by_patch.assert_called_once_with(*args, **kwargs)
    db.first.assert_called_once_with(filter_by_patch())
    assert result == await db.first()


async def test__commit__no_session() -> None:
    """``DB.commit`` does not commit when the ``_session`` context var holds ``None``."""
    db = MagicMock()
    db._session.get.return_value = None
    db.session = AsyncMock()

    await database.database.DB.commit(db)

    db._session.get.assert_called_once_with()
    db.session.commit.assert_not_called()


async def test__commit__with_session() -> None:
    """``DB.commit`` awaits ``commit()`` on the bound session."""
    db = MagicMock()
    session = db._session.get.return_value = db.session = MagicMock()
    session.commit = AsyncMock()

    await database.database.DB.commit(db)

    db._session.get.assert_called_once_with()
    session.commit.assert_called_once_with()


async def test__close__no_session() -> None:
    """``DB.close`` neither closes a session nor sets the close event when no session is bound."""
    db = MagicMock()
    db._session.get.return_value = None
    db.session = AsyncMock()

    await database.database.DB.close(db)

    db._session.get.assert_called_once_with()
    db.session.close.assert_not_called()
    db._close_event.get.return_value.set.assert_not_called()


async def test__close__with_session_no_close_event() -> None:
    """``DB.close`` closes the session even when the close-event context var holds ``None``."""
    db = MagicMock()
    session = db._session.get.return_value = db.session = MagicMock()
    session.close = AsyncMock()
    db._close_event.get.return_value = None

    await database.database.DB.close(db)

    db._session.get.assert_called_once_with()
    session.close.assert_called_once_with()
    db._close_event.get.assert_called_once_with()


async def test__close__with_session() -> None:
    """``DB.close`` closes the session and then sets the stored close event."""
    db = MagicMock()
    session = db._session.get.return_value = db.session = MagicMock()
    session.close = AsyncMock()
    close_event = db._close_event.get.return_value

    await database.database.DB.close(db)

    db._session.get.assert_called_once_with()
    session.close.assert_called_once_with()
    db._close_event.get.assert_called_once_with()
    close_event.set.assert_called_once_with()


async def test__create_session(mocker: MockerFixture) -> None:
    """``create_session`` stores a new ``AsyncSession`` and ``Event`` in the context vars and returns the session."""
    async_session_patch = mocker.patch("api.database.database.AsyncSession")
    event_patch = mocker.patch("api.database.database.Event")

    db = MagicMock()

    result = database.database.DB.create_session(db)

    async_session_patch.assert_called_once_with(db.engine)
    db._session.set.assert_called_with(async_session_patch())
    event_patch.assert_called_once_with()
    db._close_event.set.assert_called_with(event_patch())
    assert result == async_session_patch()


async def test__session() -> None:
    """The ``session`` property returns the value of the ``_session`` context var."""
    db = MagicMock()

    result = database.database.DB.session.fget(db)  # type: ignore

    db._session.get.assert_called_once_with()
    assert result == db._session.get()


async def test__wait_for_close_event__not_set() -> None:
    """``wait_for_close_event`` returns when the close-event context var holds ``None``."""
    db = MagicMock()
    db._close_event.get.return_value = None

    await database.database.DB.wait_for_close_event(db)

    db._close_event.get.assert_called_once_with()


async def test__wait_for_close_event() -> None:
    """``wait_for_close_event`` awaits ``wait()`` on the stored close event."""
    db = MagicMock()
    close_event = db._close_event.get.return_value = MagicMock()
    close_event.wait = AsyncMock()

    await database.database.DB.wait_for_close_event(db)

    db._close_event.get.assert_called_once_with()
    close_event.wait.assert_called_once_with()


async def test__get_database(mocker: MockerFixture, monkeypatch: MonkeyPatch) -> None:
    """``get_database`` constructs ``DB`` from the connection/pool settings with ``pool_pre_ping=True``."""
    db_patch = mocker.patch("api.database.database.DB")

    monkeypatch.setattr(settings, "database_url", url_patch := MagicMock())
    monkeypatch.setattr(settings, "pool_recycle", pool_recycle_patch := MagicMock())
    monkeypatch.setattr(settings, "pool_size", pool_size_patch := MagicMock())
    monkeypatch.setattr(settings, "max_overflow", max_overflow_patch := MagicMock())
    monkeypatch.setattr(settings, "sql_show_statements", sql_show_statements_patch := MagicMock())

    result = database.database.get_database()

    db_patch.assert_called_once_with(
        url=url_patch,
        pool_pre_ping=True,
        pool_recycle=pool_recycle_patch,
        pool_size=pool_size_patch,
        max_overflow=max_overflow_patch,
        echo=sql_show_statements_patch,
    )
    assert result == db_patch()


async def test__db_context(mocker: MockerFixture) -> None:
    """``db_context`` creates a session on entry and commits before closing on exit."""
    db_patch = mocker.patch("api.database.db")

    db_patch.commit = AsyncMock()
    db_patch.close = AsyncMock()
    # When close() runs, commit() must already have been awaited exactly once (ordering check).
    db_patch.close.side_effect = lambda: db_patch.commit.assert_called_once_with()

    async with database.db_context():
        db_patch.create_session.assert_called_once_with()

    db_patch.close.assert_called_once_with()


async def test__db_wrapper(mocker: MockerFixture) -> None:
    """``db_wrapper`` runs the wrapped coroutine inside ``db_context`` and preserves its ``__name__``."""
    db_context_patch = mocker.patch("api.database.db_context")
    # Project helper; presumably ``func_callback`` must be called while the context is active — see ._utils.
    db_context_patch.side_effect, [func_callback], assert_calls = mock_asynccontextmanager(1, None)

    args = mock_list(5)
    kwargs = mock_dict(5, True)
    expected = MagicMock()

    @database.db_wrapper
    async def test(*_args: Any, **_kwargs: Any) -> Any:
        assert args == list(_args)
        assert kwargs == _kwargs
        func_callback()
        return expected

    result = await test(*args, **kwargs)

    assert result == expected
    db_context_patch.assert_called_once_with()
    assert_calls()
    assert test.__name__ == "test"


async def test__db(mocker: MockerFixture) -> None:
    """Importing ``api.database`` builds the module-level ``db`` via ``get_database()``."""
    get_database_mock = mocker.patch("api.database.database.get_database")

    db = import_module("api.database")

    get_database_mock.assert_called_once_with()
    assert get_database_mock() == db.db