# auth.py
# Authentication / authorization dependencies for FastAPI routes: bearer-token
# extraction, static-token, JWT and session-based auth dependency classes, and
# helpers that resolve a ``user_id`` path parameter into a ``User``.
import secrets
from enum import Enum
from typing import Any, Awaitable, Callable

from fastapi import Depends, Request
from fastapi.openapi.models import HTTPBearer
from fastapi.security.base import SecurityBase
from sqlalchemy import Column

from .database import db
from .exceptions.auth import InvalidTokenError, PermissionDeniedError
from .exceptions.user import UserNotFoundError
from .models import Session, User
from .utils.jwt import decode_jwt


def get_token(request: Request) -> str:
    """Return the credentials carried in the request's ``Authorization`` header.

    A leading ``Bearer`` scheme is stripped; the scheme name is matched
    case-insensitively, as RFC 7235 requires for auth-scheme names. A header
    without that scheme is returned unchanged, and a missing header yields ``""``.
    """
    authorization: str = request.headers.get("Authorization", "")
    scheme, separator, credentials = authorization.partition(" ")
    if separator and scheme.lower() == "bearer":
        return credentials
    # No bearer scheme: hand back the raw header value, as before.
    return authorization


class PermissionLevel(Enum):
    """Minimum privilege a ``UserAuth`` dependency enforces."""

    PUBLIC = 0  # session optional; anonymous requests get None
    USER = 1  # a valid session is required
    ADMIN = 2  # a valid session whose user has ``admin`` set is required


class HTTPAuth(SecurityBase):
    """Base class for bearer-token auth dependencies.

    Exposes an OpenAPI HTTP bearer security model (``self.model``) under a
    scheme name taken from the concrete subclass; subclasses implement
    ``__call__`` to validate the incoming request.
    """

    def __init__(self) -> None:
        self.model = HTTPBearer()
        self.scheme_name = self.__class__.__name__

    async def __call__(self, request: Request) -> Any:
        raise NotImplementedError


class StaticTokenAuth(HTTPAuth):
    """Accept only requests whose bearer token equals a fixed secret."""

    def __init__(self, token: str) -> None:
        super().__init__()

        self._token = token

    async def _check_token(self, token: str) -> bool:
        # Constant-time comparison so response timing does not reveal how much
        # of a guessed token was correct. Both sides are encoded to bytes
        # because compare_digest rejects str arguments containing non-ASCII.
        return secrets.compare_digest(token.encode(), self._token.encode())

    async def __call__(self, request: Request) -> bool:
        """Return ``True`` when the request's token matches.

        Raises:
            InvalidTokenError: if the token does not match.
        """
        if not await self._check_token(get_token(request)):
            raise InvalidTokenError
        return True


class JWTAuth(HTTPAuth):
    """Decode the request's bearer token as a JWT."""

    def __init__(self, *, audience: list[str] | None = None, force_valid: bool = True) -> None:
        super().__init__()
        # Forwarded to decode_jwt to restrict accepted audiences.
        self.audience: list[str] | None = audience
        # When true, a token that fails to decode raises instead of yielding None.
        self.force_valid: bool = force_valid

    async def __call__(self, request: Request) -> dict[Any, Any] | None:
        """Return the decoded token payload.

        ``None`` is returned when ``decode_jwt`` returns ``None`` (presumably an
        invalid token or audience mismatch — see ``utils.jwt``) and
        ``force_valid`` is false.

        Raises:
            InvalidTokenError: if decoding yields ``None`` and ``force_valid`` is true.
        """
        data = decode_jwt(get_token(request), audience=self.audience)
        if data is None and self.force_valid:
            raise InvalidTokenError
        return data


class InternalAuth(JWTAuth):
    """JWT auth that always requires a valid token, optionally audience-restricted."""

    def __init__(self, audience: list[str] | None = None):
        super().__init__(audience=audience, force_valid=True)


class UserAuth(HTTPAuth):
    """Authenticate via a session access token and enforce a permission level."""

    def __init__(self, min_level: PermissionLevel) -> None:
        super().__init__()

        self.min_level: PermissionLevel = min_level

    async def __call__(self, request: Request) -> Session | None:
        """Return the session for the request's access token.

        With ``PermissionLevel.PUBLIC`` the session may be ``None``. Higher
        levels require a session; ``ADMIN`` also requires ``session.user.admin``.

        Raises:
            InvalidTokenError: no session was found and the level is not PUBLIC.
            PermissionDeniedError: ADMIN is required but the user is not an admin.
        """
        session: Session | None = await Session.from_access_token(get_token(request))

        if self.min_level == PermissionLevel.PUBLIC:
            return session

        if not session:
            raise InvalidTokenError

        if self.min_level == PermissionLevel.ADMIN and not session.user.admin:
            raise PermissionDeniedError

        return session


# NOTE(review): this secret is hard-coded in source; consider loading it from
# configuration or the environment instead.
static_token_auth = Depends(StaticTokenAuth("secret token"))
jwt_auth = Depends(JWTAuth())
internal_auth = Depends(InternalAuth(audience=["service_xyz"]))

# Session-based dependencies: PUBLIC yields Session | None, USER requires a
# session, ADMIN requires a session of an admin user.
public_auth = Depends(UserAuth(PermissionLevel.PUBLIC))
user_auth = Depends(UserAuth(PermissionLevel.USER))
admin_auth = Depends(UserAuth(PermissionLevel.ADMIN))


# Decorating with Depends turns ``is_admin`` into a Depends object that routes
# can use directly as a parameter default.
@Depends
async def is_admin(session: Session | None = public_auth) -> bool:
    """Return whether the request carries a session belonging to an admin user."""
    return session is not None and session.user.admin


def _get_user_dependency(*args: Column[Any]) -> Callable[[str, Session | None], Awaitable[User]]:
    """Build a dependency that resolves the ``user_id`` path parameter to a ``User``.

    ``args`` are forwarded to ``db.get`` (presumably extra columns/relations to
    load — confirm against ``.database``). ``"me"``/``"self"`` (any case) map
    to the current session's user when a session is present.
    """

    async def default_dependency(user_id: str, session: Session | None = public_auth) -> User:
        """Load the requested user.

        Raises:
            UserNotFoundError: if ``db.get`` returns no user.
        """
        if user_id.lower() in ["me", "self"] and session:
            user_id = session.user_id
        if not (user := await db.get(User, *args, id=user_id)):
            raise UserNotFoundError

        return user

    return default_dependency


def _get_user_privileged_dependency(*args: Column[Any]) -> Callable[[str, Session], Awaitable[User]]:
    """Build a dependency that loads a user only for that user themself or an admin.

    ``args`` are forwarded to ``db.get`` when loading the user.
    """
    # Built once per factory call rather than on every request.
    load_user = _get_user_dependency(*args)

    async def self_or_admin_dependency(user_id: str, session: Session = user_auth) -> User:
        """Load the requested user after checking the caller may see it.

        Raises:
            PermissionDeniedError: the caller is neither the user nor an admin.
            UserNotFoundError: the user does not exist.
        """
        if user_id.lower() in ["me", "self"]:
            user_id = session.user_id
        if session.user_id != user_id and not session.user.admin:
            raise PermissionDeniedError

        # user_id is already resolved, so no session is needed for "me"/"self".
        return await load_user(user_id, None)

    return self_or_admin_dependency


def get_user(*args: Column[Any], require_self_or_admin: bool = False) -> Any:
    """Return a ``Depends`` that resolves the ``user_id`` path parameter to a ``User``.

    Args:
        *args: forwarded to ``db.get`` when loading the user.
        require_self_or_admin: if true, the caller must be authenticated and be
            either the requested user or an admin.
    """
    if require_self_or_admin:
        return Depends(_get_user_privileged_dependency(*args))
    return Depends(_get_user_dependency(*args))