# test_auth.py
"""Unit tests for the authentication helpers and FastAPI dependencies in ``api.auth``."""

from typing import Type
from unittest.mock import AsyncMock, MagicMock

import pytest
from fastapi.security.base import SecurityBase
from pytest_mock import MockerFixture

from ._utils import mock_list
from api import auth
from api.auth import PermissionLevel
from api.exceptions.auth import InvalidTokenError, PermissionDeniedError
from api.exceptions.user import UserNotFoundError
from api.models import User


@pytest.mark.parametrize("auth_header,token", [("test", "test"), (None, ""), ("Bearer asDF1234", "asDF1234")])
def test__get_token(auth_header: str | None, token: str) -> None:
    """Check that ``auth.get_token`` yields the expected token for each ``Authorization`` header case."""
    request = MagicMock()
    # A missing header is modelled by an empty headers mapping rather than a ``None`` value.
    request.headers = {"Authorization": auth_header} if auth_header is not None else {}

    assert auth.get_token(request) == token


async def test__httpauth_constructor(mocker: MockerFixture) -> None:
    """Check that ``HTTPAuth()`` builds its model from ``HTTPBearer`` and names its scheme after the class."""
    httpbearer_patch = mocker.patch("api.auth.HTTPBearer")

    http_auth = auth.HTTPAuth()

    httpbearer_patch.assert_called_once_with()
    assert http_auth.model == httpbearer_patch()
    assert http_auth.scheme_name == http_auth.__class__.__name__
    assert issubclass(auth.HTTPAuth, SecurityBase)


async def test__httpauth_call() -> None:
    """Check that the base ``HTTPAuth.__call__`` raises ``NotImplementedError``."""
    request = MagicMock()
    http_auth = MagicMock()
    with pytest.raises(NotImplementedError):
        await auth.HTTPAuth.__call__(http_auth, request)


@pytest.mark.parametrize("token,ok", [("S3cr3t Token!", True), ("asdf1234", False)])
async def test__statictokenauth_check_token(token: str, ok: bool) -> None:
    """Check that ``StaticTokenAuth._check_token`` accepts only the configured token."""
    http_auth = MagicMock()
    http_auth._token = "S3cr3t Token!"
    assert await auth.StaticTokenAuth._check_token(http_auth, token) == ok


async def test__statictokenauth_call__invalid_token(mocker: MockerFixture) -> None:
    """Check that ``StaticTokenAuth.__call__`` raises ``InvalidTokenError`` when the token check fails."""
    get_token = mocker.patch("api.auth.get_token")

    request = MagicMock()
    http_auth = MagicMock()
    http_auth._check_token = AsyncMock(return_value=False)

    with pytest.raises(InvalidTokenError):
        await auth.StaticTokenAuth.__call__(http_auth, request)

    get_token.assert_called_once_with(request)
    http_auth._check_token.assert_called_once_with(get_token())


async def test__statictokenauth_call__valid_token(mocker: MockerFixture) -> None:
    """Check that ``StaticTokenAuth.__call__`` returns ``True`` when the token check succeeds."""
    get_token = mocker.patch("api.auth.get_token")

    request = MagicMock()
    http_auth = MagicMock()
    http_auth._check_token = AsyncMock(return_value=True)

    assert await auth.StaticTokenAuth.__call__(http_auth, request) is True

    get_token.assert_called_once_with(request)
    http_auth._check_token.assert_called_once_with(get_token())


async def test__jwtauth_call__invalid_token(mocker: MockerFixture) -> None:
    """Check that ``JWTAuth.__call__`` returns ``None`` for an undecodable token when ``force_valid`` is off."""
    get_token = mocker.patch("api.auth.get_token")
    mocker.patch("api.auth.decode_jwt", MagicMock(return_value=None))

    request = MagicMock()
    http_auth = MagicMock(force_valid=False)

    assert await auth.JWTAuth.__call__(http_auth, request) is None

    get_token.assert_called_once_with(request)


async def test__jwtauth_call__invalid_token__force_valid(mocker: MockerFixture) -> None:
    """Check that ``JWTAuth.__call__`` raises ``InvalidTokenError`` for an undecodable token with ``force_valid``."""
    get_token = mocker.patch("api.auth.get_token")
    mocker.patch("api.auth.decode_jwt", MagicMock(return_value=None))

    request = MagicMock()
    http_auth = MagicMock(force_valid=True)

    with pytest.raises(InvalidTokenError):
        await auth.JWTAuth.__call__(http_auth, request)

    get_token.assert_called_once_with(request)


async def test__jwtauth_call__valid_token(mocker: MockerFixture) -> None:
    """Check that ``JWTAuth.__call__`` returns the decoded payload of a valid token."""
    get_token = mocker.patch("api.auth.get_token")
    mocker.patch("api.auth.decode_jwt", MagicMock(return_value={"foo": "bar"}))

    request = MagicMock()
    http_auth = MagicMock()

    assert await auth.JWTAuth.__call__(http_auth, request) == {"foo": "bar"}

    get_token.assert_called_once_with(request)


async def test__userauth_constructor() -> None:
    """Check that ``UserAuth`` stores its minimum level and derives from ``HTTPAuth``."""
    min_level = MagicMock()

    user_auth = auth.UserAuth(min_level)

    assert user_auth.min_level == min_level
    assert issubclass(auth.UserAuth, auth.HTTPAuth)


@pytest.mark.parametrize("valid", [True, False])
async def test__userauth_call__public(valid: bool, mocker: MockerFixture) -> None:
    """Check that a ``PUBLIC`` ``UserAuth`` returns whatever session lookup yields, valid or not.

    ``valid`` selects whether ``Session.from_access_token`` resolves to a
    session object or to ``None``.
    """
    get_token = mocker.patch("api.auth.get_token")
    # Drive the mocked lookup from ``valid`` so the two parametrized cases
    # actually differ (session present vs. no session).
    session = MagicMock() if valid else None
    from_access_token = mocker.patch("api.auth.Session.from_access_token", AsyncMock(return_value=session))
    request = MagicMock()

    result = await auth.UserAuth(PermissionLevel.PUBLIC)(request)

    get_token.assert_called_once_with(request)
    from_access_token.assert_called_once_with(get_token())
    assert result == session


@pytest.mark.parametrize(
    "permission_level,valid,admin,exc",
    [
        (PermissionLevel.USER, False, False, InvalidTokenError),
        (PermissionLevel.ADMIN, False, False, InvalidTokenError),
        (PermissionLevel.USER, True, False, None),
        (PermissionLevel.ADMIN, True, False, PermissionDeniedError),
        (PermissionLevel.USER, True, True, None),
        (PermissionLevel.ADMIN, True, True, None),
    ],
)
async def test__userauth_call(
    permission_level: PermissionLevel, valid: bool, admin: bool, exc: Type[Exception] | None, mocker: MockerFixture
) -> None:
    """Check ``UserAuth.__call__`` for the ``USER`` and ``ADMIN`` levels.

    ``exc`` is the exception expected for the given combination of session
    validity and admin flag, or ``None`` when the session must be returned.
    """
    get_token = mocker.patch("api.auth.get_token")
    from_access_token = mocker.patch("api.auth.Session.from_access_token")

    request = MagicMock()
    session = MagicMock()
    session.user.admin = admin

    from_access_token.return_value = session if valid else None

    user_auth = auth.UserAuth(permission_level)

    if exc is None:
        assert await user_auth(request) == session
    else:
        with pytest.raises(exc):
            await user_auth(request)

    get_token.assert_called_once_with(request)
    from_access_token.assert_called_once_with(get_token())


@pytest.mark.parametrize("valid,admin", [(False, False), (True, False), (True, True)])
async def test__is_admin(valid: bool, admin: bool) -> None:
    """Check that the ``is_admin`` dependency is true only for an existing session of an admin user."""
    session = MagicMock(user=MagicMock(admin=admin)) if valid else None

    assert await auth.is_admin.dependency(session) == (valid and admin)
    # The session argument must default to the public auth dependency.
    assert auth.is_admin.dependency.__defaults__ == (auth.public_auth,)


async def test__get_user_dependency__not_found(mocker: MockerFixture) -> None:
    """Check that the user dependency raises ``UserNotFoundError`` when the lookup returns ``None``."""
    db = mocker.patch("api.auth.db")
    db.get = AsyncMock(return_value=None)

    args = mock_list(5)

    with pytest.raises(UserNotFoundError):
        await auth._get_user_dependency(*args)("some_user_id", None)

    db.get.assert_called_once_with(User, *args, id="some_user_id")


async def test__get_user_dependency__by_id(mocker: MockerFixture) -> None:
    """Check that the user dependency returns the user found for an explicit id."""
    user = MagicMock()
    db = mocker.patch("api.auth.db")
    db.get = AsyncMock(return_value=user)

    args = mock_list(5)

    assert await auth._get_user_dependency(*args)("some_user_id", None) == user

    db.get.assert_called_once_with(User, *args, id="some_user_id")


@pytest.mark.parametrize("alias", ["self", "me"])
async def test__get_user_dependency__self(alias: str, mocker: MockerFixture) -> None:
    """Check that the aliases ``self`` and ``me`` are looked up using the session's user id."""
    user = MagicMock()
    session = MagicMock(user_id="some_user_id")
    db = mocker.patch("api.auth.db")
    db.get = AsyncMock(return_value=user)

    args = mock_list(5)

    assert await auth._get_user_dependency(*args)(alias, session) == user

    db.get.assert_called_once_with(User, *args, id="some_user_id")


@pytest.mark.parametrize(
    "user_id,session_user_id,admin,ok",
    [
        ("me", "some_user_id", False, True),
        ("self", "some_user_id", False, True),
        ("some_user_id", "some_user_id", False, True),
        ("some_user_id", "other_user_id", True, True),
        ("some_user_id", "other_user_id", False, False),
    ],
)
async def test__get_user_privileged(
    session_user_id: str, user_id: str, admin: bool, ok: bool, mocker: MockerFixture
) -> None:
    """Check that the privileged user dependency allows self/admin access and denies everything else.

    ``ok`` states whether the request is expected to succeed; otherwise
    ``PermissionDeniedError`` must be raised.
    """
    user = MagicMock(id=user_id, admin=admin)
    session = MagicMock(user_id=session_user_id, user=user)

    get_user_dependency = mocker.patch("api.auth._get_user_dependency")
    get_user_dependency.return_value = AsyncMock(return_value=user)

    args = mock_list(5)

    if ok:
        assert await auth._get_user_privileged_dependency(*args)(user_id, session) == user
        get_user_dependency.assert_called_once_with(*args)
        # Aliases are expected to be resolved to the session's user id before delegating.
        get_user_dependency().assert_called_once_with(session_user_id if user_id in ["self", "me"] else user_id, None)
    else:
        with pytest.raises(PermissionDeniedError):
            await auth._get_user_privileged_dependency(*args)(user_id, session)


@pytest.mark.parametrize("require_self_or_admin", [True, False])
async def test__get_user(require_self_or_admin: bool, mocker: MockerFixture) -> None:
    """Check that ``get_user`` wraps the plain or privileged dependency in ``Depends`` as requested."""
    get_user_dependency = mocker.patch("api.auth._get_user_dependency")
    get_user_privileged_dependency = mocker.patch("api.auth._get_user_privileged_dependency")
    depends = mocker.patch("api.auth.Depends")

    args = mock_list(5)

    result = auth.get_user(*args, require_self_or_admin=require_self_or_admin)

    if require_self_or_admin:
        get_user_privileged_dependency.assert_called_once_with(*args)
        depends.assert_called_once_with(get_user_privileged_dependency())
    else:
        get_user_dependency.assert_called_once_with(*args)
        depends.assert_called_once_with(get_user_dependency())

    assert result == depends()