/ fastapi-template-users / tests / test_auth.py
test_auth.py
  1  from typing import Type
  2  from unittest.mock import AsyncMock, MagicMock
  3  
  4  import pytest
  5  from fastapi.security.base import SecurityBase
  6  from pytest_mock import MockerFixture
  7  
  8  from ._utils import mock_list
  9  from api import auth
 10  from api.auth import PermissionLevel
 11  from api.exceptions.auth import InvalidTokenError, PermissionDeniedError
 12  from api.exceptions.user import UserNotFoundError
 13  from api.models import User
 14  
 15  
 16  @pytest.mark.parametrize("auth_header,token", [("test", "test"), (None, ""), ("Bearer asDF1234", "asDF1234")])
 17  def test__get_token(auth_header: str | None, token: str) -> None:
 18      request = MagicMock()
 19      request.headers = {"Authorization": auth_header} if auth_header is not None else {}
 20  
 21      assert auth.get_token(request) == token
 22  
 23  
 24  async def test__httpauth_constructor(mocker: MockerFixture) -> None:
 25      httpbearer_patch = mocker.patch("api.auth.HTTPBearer")
 26  
 27      http_auth = auth.HTTPAuth()
 28  
 29      httpbearer_patch.assert_called_once_with()
 30      assert http_auth.model == httpbearer_patch()
 31      assert http_auth.scheme_name == http_auth.__class__.__name__
 32      assert issubclass(auth.HTTPAuth, SecurityBase)
 33  
 34  
 35  async def test__httpauth_call() -> None:
 36      request = MagicMock()
 37      http_auth = MagicMock()
 38      with pytest.raises(NotImplementedError):
 39          await auth.HTTPAuth.__call__(http_auth, request)
 40  
 41  
 42  @pytest.mark.parametrize("token,ok", [("S3cr3t Token!", True), ("asdf1234", False)])
 43  async def test__statictokenauth_check_token(token: str, ok: bool) -> None:
 44      http_auth = MagicMock()
 45      http_auth._token = "S3cr3t Token!"
 46      assert await auth.StaticTokenAuth._check_token(http_auth, token) == ok
 47  
 48  
 49  async def test__statictokenauth_call__invalid_token(mocker: MockerFixture) -> None:
 50      get_token = mocker.patch("api.auth.get_token")
 51  
 52      request = MagicMock()
 53      http_auth = MagicMock()
 54      http_auth._check_token = AsyncMock(return_value=False)
 55  
 56      with pytest.raises(InvalidTokenError):
 57          await auth.StaticTokenAuth.__call__(http_auth, request)
 58  
 59      get_token.assert_called_once_with(request)
 60      http_auth._check_token.assert_called_once_with(get_token())
 61  
 62  
 63  async def test__statictokenauth_call__valid_token(mocker: MockerFixture) -> None:
 64      get_token = mocker.patch("api.auth.get_token")
 65  
 66      request = MagicMock()
 67      http_auth = MagicMock()
 68      http_auth._check_token = AsyncMock(return_value=True)
 69  
 70      assert await auth.StaticTokenAuth.__call__(http_auth, request) is True
 71  
 72      get_token.assert_called_once_with(request)
 73      http_auth._check_token.assert_called_once_with(get_token())
 74  
 75  
 76  async def test__jwtauth_call__invalid_token(mocker: MockerFixture) -> None:
 77      get_token = mocker.patch("api.auth.get_token")
 78      mocker.patch("api.auth.decode_jwt", MagicMock(return_value=None))
 79  
 80      request = MagicMock()
 81      http_auth = MagicMock(force_valid=False)
 82  
 83      assert await auth.JWTAuth.__call__(http_auth, request) is None
 84  
 85      get_token.assert_called_once_with(request)
 86  
 87  
 88  async def test__jwtauth_call__invalid_token__force_valid(mocker: MockerFixture) -> None:
 89      get_token = mocker.patch("api.auth.get_token")
 90      mocker.patch("api.auth.decode_jwt", MagicMock(return_value=None))
 91  
 92      request = MagicMock()
 93      http_auth = MagicMock(force_valid=True)
 94  
 95      with pytest.raises(InvalidTokenError):
 96          await auth.JWTAuth.__call__(http_auth, request)
 97  
 98      get_token.assert_called_once_with(request)
 99  
100  
async def test__jwtauth_call__valid_token(mocker: MockerFixture) -> None:
    """The payload produced by ``decode_jwt`` is what ``JWTAuth.__call__`` returns."""
    get_token = mocker.patch("api.auth.get_token")
    mocker.patch("api.auth.decode_jwt", MagicMock(return_value={"foo": "bar"}))
    fake_request = MagicMock()
    # force_valid is left as an auto-created (truthy) MagicMock attribute.
    fake_self = MagicMock()

    payload = await auth.JWTAuth.__call__(fake_self, fake_request)

    assert payload == {"foo": "bar"}
    get_token.assert_called_once_with(fake_request)
111  
112  
async def test__userauth_constructor() -> None:
    """``UserAuth`` stores the given minimum permission level and derives from ``HTTPAuth``."""
    level = MagicMock()

    assert auth.UserAuth(level).min_level == level
    assert issubclass(auth.UserAuth, auth.HTTPAuth)
120  
121  
@pytest.mark.parametrize("valid", [True, False])
async def test__userauth_call__public(valid: bool, mocker: MockerFixture) -> None:
    """With ``PermissionLevel.PUBLIC`` the session looked up from the token is returned as-is.

    ``valid`` controls whether ``Session.from_access_token`` finds a session.
    Previously this parameter was unused, so both cases ran the same scenario.
    Now the ``False`` case returns ``None`` and asserts that a public endpoint
    passes that ``None`` through instead of raising.
    """
    session = MagicMock() if valid else None
    get_token = mocker.patch("api.auth.get_token")
    from_access_token = mocker.patch("api.auth.Session.from_access_token", AsyncMock(return_value=session))
    request = MagicMock()

    result = await auth.UserAuth(PermissionLevel.PUBLIC)(request)

    get_token.assert_called_once_with(request)
    from_access_token.assert_called_once_with(get_token.return_value)
    # Identity check: the dependency must hand back exactly what the lookup produced.
    assert result is session
133  
134  
@pytest.mark.parametrize(
    "permission_level,valid,admin,exc",
    [
        # no session for the token -> rejected at every non-public level
        (PermissionLevel.USER, False, False, InvalidTokenError),
        (PermissionLevel.ADMIN, False, False, InvalidTokenError),
        # regular user
        (PermissionLevel.USER, True, False, None),
        (PermissionLevel.ADMIN, True, False, PermissionDeniedError),
        # admin user
        (PermissionLevel.USER, True, True, None),
        (PermissionLevel.ADMIN, True, True, None),
    ],
)
async def test__userauth_call(
    permission_level: PermissionLevel, valid: bool, admin: bool, exc: Type[Exception] | None, mocker: MockerFixture
) -> None:
    """``UserAuth.__call__`` enforces its minimum level against the looked-up session.

    ``exc`` is the exception expected for the combination, or ``None`` when the
    session itself should be returned.
    """
    get_token = mocker.patch("api.auth.get_token")
    from_access_token = mocker.patch("api.auth.Session.from_access_token")
    request = MagicMock()
    session = MagicMock()
    session.user.admin = admin
    from_access_token.return_value = session if valid else None

    dependency = auth.UserAuth(permission_level)

    if exc is not None:
        with pytest.raises(exc):
            await dependency(request)
    else:
        assert await dependency(request) == session

    get_token.assert_called_once_with(request)
    from_access_token.assert_called_once_with(get_token.return_value)
168  
169  
@pytest.mark.parametrize("valid,admin", [(False, False), (True, False), (True, True)])
async def test__is_admin(valid: bool, admin: bool) -> None:
    """``is_admin`` is truthy only for a present session whose user has ``admin`` set."""
    if valid:
        session = MagicMock(user=MagicMock(admin=admin))
    else:
        session = None
    dependency = auth.is_admin.dependency

    expected = valid and admin
    assert await dependency(session) == expected
    # The session parameter's default is the module's public_auth dependency.
    assert dependency.__defaults__ == (auth.public_auth,)
176  
177  
async def test__get_user_dependency__not_found(mocker: MockerFixture) -> None:
    """When ``db.get`` finds nothing, the dependency raises ``UserNotFoundError``."""
    db = mocker.patch("api.auth.db")
    db.get = AsyncMock(return_value=None)
    # Opaque positional args; the factory must forward them to db.get unchanged.
    extra_args = mock_list(5)
    dependency = auth._get_user_dependency(*extra_args)

    with pytest.raises(UserNotFoundError):
        await dependency("some_user_id", None)

    db.get.assert_called_once_with(User, *extra_args, id="some_user_id")
188  
189  
async def test__get_user_dependency__by_id(mocker: MockerFixture) -> None:
    """An explicit user id is looked up via ``db.get`` and the found user is returned."""
    expected_user = MagicMock()
    db = mocker.patch("api.auth.db")
    db.get = AsyncMock(return_value=expected_user)
    extra_args = mock_list(5)

    found = await auth._get_user_dependency(*extra_args)("some_user_id", None)

    assert found == expected_user
    db.get.assert_called_once_with(User, *extra_args, id="some_user_id")
200  
201  
@pytest.mark.parametrize("alias", ["self", "me"])
async def test__get_user_dependency__self(alias: str, mocker: MockerFixture) -> None:
    """The ``self``/``me`` aliases are resolved to the session's ``user_id`` before lookup."""
    expected_user = MagicMock()
    session = MagicMock(user_id="some_user_id")
    db = mocker.patch("api.auth.db")
    db.get = AsyncMock(return_value=expected_user)
    extra_args = mock_list(5)

    found = await auth._get_user_dependency(*extra_args)(alias, session)

    assert found == expected_user
    db.get.assert_called_once_with(User, *extra_args, id="some_user_id")
214  
215  
@pytest.mark.parametrize(
    "user_id,session_user_id,admin,ok",
    [
        # aliases and the caller's own id are always allowed
        ("me", "some_user_id", False, True),
        ("self", "some_user_id", False, True),
        ("some_user_id", "some_user_id", False, True),
        # accessing another user's id requires admin
        ("some_user_id", "other_user_id", True, True),
        ("some_user_id", "other_user_id", False, False),
    ],
)
async def test__get_user_privileged(
    session_user_id: str, user_id: str, admin: bool, ok: bool, mocker: MockerFixture
) -> None:
    """The privileged dependency allows self-access or admins, else raises ``PermissionDeniedError``.

    On success it delegates to the plain user dependency with aliases resolved
    to the session's user id and no session argument.
    """
    user = MagicMock(id=user_id, admin=admin)
    session = MagicMock(user_id=session_user_id, user=user)
    get_user_dependency = mocker.patch("api.auth._get_user_dependency")
    inner = AsyncMock(return_value=user)
    get_user_dependency.return_value = inner
    args = mock_list(5)

    if not ok:
        with pytest.raises(PermissionDeniedError):
            await auth._get_user_privileged_dependency(*args)(user_id, session)
        return

    assert await auth._get_user_privileged_dependency(*args)(user_id, session) == user
    get_user_dependency.assert_called_once_with(*args)
    resolved_id = session_user_id if user_id in ("self", "me") else user_id
    inner.assert_called_once_with(resolved_id, None)
244  
245  
@pytest.mark.parametrize("require_self_or_admin", [True, False])
async def test__get_user(require_self_or_admin: bool, mocker: MockerFixture) -> None:
    """``get_user`` wraps the factory selected by ``require_self_or_admin`` in ``Depends``."""
    plain_factory = mocker.patch("api.auth._get_user_dependency")
    privileged_factory = mocker.patch("api.auth._get_user_privileged_dependency")
    depends = mocker.patch("api.auth.Depends")
    args = mock_list(5)

    result = auth.get_user(*args, require_self_or_admin=require_self_or_admin)

    chosen = privileged_factory if require_self_or_admin else plain_factory
    chosen.assert_called_once_with(*args)
    depends.assert_called_once_with(chosen.return_value)
    assert result == depends.return_value
262  
263      assert result == depends()