"""Tests for the ``api.app`` module (test_app.py)."""

from typing import Any, Callable
from unittest.mock import AsyncMock, MagicMock

from _pytest.monkeypatch import MonkeyPatch
from httpx import AsyncClient
from pytest_mock import MockerFixture

from ._utils import import_module, mock_asynccontextmanager
from api import app
from api.settings import settings


def get_decorated_function(
    fastapi_patch: MagicMock, decorator_name: str, *decorator_args: Any, **decorator_kwargs: Any
) -> tuple[Any, Callable[..., Any]]:
    """Re-import ``api.app`` and capture the one function registered through a FastAPI decorator.

    Calling ``fastapi_patch().<decorator_name>(*decorator_args, **decorator_kwargs)`` is made to
    return a capturing decorator. The same method called with any other arguments returns a
    throwaway ``MagicMock``, so unrelated registrations are ignored.

    Args:
        fastapi_patch: Mock patched over ``fastapi.FastAPI``.
        decorator_name: Name of the app method used as a decorator (e.g. ``"middleware"``, ``"on_event"``).
        *decorator_args: Positional arguments the decorator factory must be called with.
        **decorator_kwargs: Keyword arguments the decorator factory must be called with.

    Returns:
        The re-imported module and the single function passed to the matching decorator.

    Raises:
        AssertionError: If the matching decorator was not applied exactly once.
    """
    functions: list[Callable[..., Any]] = []
    # ``list.append`` returns None, so the decorated name inside the re-imported module ends up bound
    # to None; callers must use the function returned from here instead.
    decorator = MagicMock(side_effect=functions.append)
    getattr(fastapi_patch(), decorator_name).side_effect = (
        lambda *args, **kwargs: decorator if (args, kwargs) == (decorator_args, decorator_kwargs) else MagicMock()
    )
    # Forget the ``fastapi_patch()`` call made above; reset_mock keeps the side_effect just configured.
    fastapi_patch.reset_mock()

    # assumes import_module re-executes api.app's top level so its decorators run against the patch
    module = import_module(app)

    decorator.assert_called_once()
    assert len(functions) == 1
    return module, functions[0]


async def test__setup_app__sentry(mocker: MockerFixture, monkeypatch: MonkeyPatch) -> None:
    """With a Sentry DSN set, ``setup_app`` hands the app, DSN and version description to ``setup_sentry``."""
    get_version_mock = mocker.patch("api.app.get_version")
    setup_sentry_mock = mocker.patch("api.app.setup_sentry")
    app_mock = mocker.patch("api.app.app")
    sentry_dsn_mock = MagicMock()
    monkeypatch.setattr(settings, "sentry_dsn", sentry_dsn_mock)
    monkeypatch.setattr(settings, "debug", False)

    # ``app`` here is the api.app module; its FastAPI instance has been replaced by ``app_mock``.
    app.setup_app()

    get_version_mock.assert_called_once_with()
    # Use the mock's return value instead of calling it again, so the call count above stays meaningful.
    setup_sentry_mock.assert_called_once_with(
        app_mock, sentry_dsn_mock, "FastAPI", get_version_mock.return_value.description
    )


async def test__setup_app__debug(mocker: MockerFixture, monkeypatch: MonkeyPatch) -> None:
    """In debug mode without Sentry, ``setup_app`` adds a fully permissive CORS middleware."""
    app_mock = mocker.patch("api.app.app")
    cors_middleware_mock = mocker.patch("api.app.CORSMiddleware")
    monkeypatch.setattr(settings, "sentry_dsn", None)
    monkeypatch.setattr(settings, "debug", True)

    app.setup_app()

    app_mock.add_middleware.assert_called_once_with(
        cors_middleware_mock, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"]
    )


async def test__db_session(mocker: MockerFixture) -> None:
    """The ``http`` middleware awaits ``call_next(request)`` with ``db_context`` mocked and returns its result."""
    fastapi_patch = mocker.patch("fastapi.FastAPI")
    expected = MagicMock()
    request = MagicMock()

    module, db_session = get_decorated_function(fastapi_patch, "middleware", "http")

    # presumably func_callback verifies it runs while the mocked db_context is entered — see _utils
    module.db_context, [func_callback], assert_calls = mock_asynccontextmanager(1, None)
    call_next = AsyncMock(side_effect=lambda _: func_callback() or expected)

    result = await db_session(request, call_next)

    assert_calls()
    call_next.assert_awaited_once_with(request)
    assert result == expected


async def test__rollback_on_exception(mocker: MockerFixture) -> None:
    """The HTTPException handler rolls back the DB session and delegates to FastAPI's default handler."""
    fastapi_patch = mocker.patch("fastapi.FastAPI")
    db_patch = mocker.patch("api.database.db")
    db_patch.session.rollback = AsyncMock()
    http_exception_patch = mocker.patch("starlette.exceptions.HTTPException")
    http_exception_handler_patch = mocker.patch("fastapi.exception_handlers.http_exception_handler", AsyncMock())

    _, rollback_on_exception = get_decorated_function(fastapi_patch, "exception_handler", http_exception_patch)

    request = MagicMock()
    exc = MagicMock()
    result = await rollback_on_exception(request, exc)

    # assert_awaited_* rather than assert_called_*: a rollback coroutine that is created but never
    # awaited does nothing, yet would still satisfy assert_called_once_with.
    db_patch.session.rollback.assert_awaited_once_with()
    http_exception_handler_patch.assert_awaited_once_with(request, exc)
    # Compare against the mock's return value instead of awaiting it again (that would add a second call).
    assert result == http_exception_handler_patch.return_value


async def test__on_startup(mocker: MockerFixture, monkeypatch: MonkeyPatch) -> None:
    """The startup handler calls ``setup_app`` and does not create tables itself."""
    fastapi_patch = mocker.patch("fastapi.FastAPI")
    db_patch = mocker.patch("api.database.db")

    module, on_startup = get_decorated_function(fastapi_patch, "on_event", "startup")
    db_patch.create_tables = AsyncMock()
    monkeypatch.setattr(module, "setup_app", MagicMock())

    await on_startup()

    module.setup_app.assert_called_once_with()
    db_patch.create_tables.assert_not_called()  # use alembic migrations instead


async def test__on_shutdown(mocker: MockerFixture) -> None:
    """A shutdown handler is registered and can be awaited without raising."""
    fastapi_patch = mocker.patch("fastapi.FastAPI")

    _, on_shutdown = get_decorated_function(fastapi_patch, "on_event", "shutdown")

    # NOTE(review): nothing is asserted beyond "does not raise"; add checks if on_shutdown gains side effects.
    await on_shutdown()


async def test__status(client: AsyncClient) -> None:
    """``HEAD /status`` responds with HTTP 200."""
    response = await client.head("/status")
    assert response.status_code == 200