auth.py
import os
import secrets
from enum import Enum
from typing import Any, Awaitable, Callable

from fastapi import Depends, Request
from fastapi.openapi.models import HTTPBearer
from fastapi.security.base import SecurityBase
from sqlalchemy import Column

from .database import db
from .exceptions.auth import InvalidTokenError, PermissionDeniedError
from .exceptions.user import UserNotFoundError
from .models import Session, User
from .utils.jwt import decode_jwt
 14  
 15  
 16  def get_token(request: Request) -> str:
 17      authorization: str = request.headers.get("Authorization", "")
 18      return authorization.removeprefix("Bearer ")
 19  
 20  
 21  class PermissionLevel(Enum):
 22      PUBLIC = 0
 23      USER = 1
 24      ADMIN = 2
 25  
 26  
 27  class HTTPAuth(SecurityBase):
 28      def __init__(self) -> None:
 29          self.model = HTTPBearer()
 30          self.scheme_name = self.__class__.__name__
 31  
 32      async def __call__(self, request: Request) -> Any:
 33          raise NotImplementedError
 34  
 35  
 36  class StaticTokenAuth(HTTPAuth):
 37      def __init__(self, token: str) -> None:
 38          super().__init__()
 39  
 40          self._token = token
 41  
 42      async def _check_token(self, token: str) -> bool:
 43          return token == self._token
 44  
 45      async def __call__(self, request: Request) -> bool:
 46          if not await self._check_token(get_token(request)):
 47              raise InvalidTokenError
 48          return True
 49  
 50  
 51  class JWTAuth(HTTPAuth):
 52      def __init__(self, *, audience: list[str] | None = None, force_valid: bool = True):
 53          super().__init__()
 54          self.audience: list[str] | None = audience
 55          self.force_valid: bool = force_valid
 56  
 57      async def __call__(self, request: Request) -> dict[Any, Any] | None:
 58          if (data := decode_jwt(get_token(request), audience=self.audience)) is None and self.force_valid:
 59              raise InvalidTokenError
 60          return data
 61  
 62  
 63  class InternalAuth(JWTAuth):
 64      def __init__(self, audience: list[str] | None = None):
 65          super().__init__(audience=audience, force_valid=True)
 66  
 67  
 68  class UserAuth(HTTPAuth):
 69      def __init__(self, min_level: PermissionLevel) -> None:
 70          super().__init__()
 71  
 72          self.min_level: PermissionLevel = min_level
 73  
 74      async def __call__(self, request: Request) -> Session | None:
 75          session: Session | None = await Session.from_access_token(get_token(request))
 76  
 77          if self.min_level == PermissionLevel.PUBLIC:
 78              return session
 79  
 80          if not session:
 81              raise InvalidTokenError
 82  
 83          if self.min_level == PermissionLevel.ADMIN and not session.user.admin:
 84              raise PermissionDeniedError
 85  
 86          return session
 87  
 88  
 89  static_token_auth = Depends(StaticTokenAuth("secret token"))
 90  jwt_auth = Depends(JWTAuth())
 91  internal_auth = Depends(InternalAuth(audience=["service_xyz"]))
 92  
 93  public_auth = Depends(UserAuth(PermissionLevel.PUBLIC))
 94  user_auth = Depends(UserAuth(PermissionLevel.USER))
 95  admin_auth = Depends(UserAuth(PermissionLevel.ADMIN))
 96  
 97  
 98  @Depends
 99  async def is_admin(session: Session | None = public_auth) -> bool:
100      return session is not None and session.user.admin
101  
102  
def _get_user_dependency(*args: Column[Any]) -> Callable[[str, Session | None], Awaitable[User]]:
    """Build a dependency that loads the ``User`` identified by ``user_id``.

    The aliases ``"me"`` and ``"self"`` (case-insensitive) resolve to the
    current session's user, but only when a session is present. ``args`` are
    forwarded unchanged to ``db.get`` (presumably column/loader options;
    confirm against ``db.get``).

    The returned coroutine raises ``UserNotFoundError`` if ``db.get`` yields
    a falsy result.
    """

    async def default_dependency(user_id: str, session: Session | None = public_auth) -> User:
        refers_to_caller = user_id.lower() in ("me", "self")
        if refers_to_caller and session:
            user_id = session.user_id

        user = await db.get(User, *args, id=user_id)
        if not user:
            raise UserNotFoundError
        return user

    return default_dependency
113  
114  
def _get_user_privileged_dependency(*args: Column[Any]) -> Callable[[str, Session], Awaitable[User]]:
    """Build a dependency that loads a user only for that user or an admin.

    Requires an authenticated session (via ``user_auth``). ``"me"`` and
    ``"self"`` (case-insensitive) resolve to the session's own user.
    ``args`` are forwarded to the plain user loader.

    The returned coroutine raises ``PermissionDeniedError`` when a non-admin
    requests a different user. It raises ``UserNotFoundError`` (from the
    loader) when no user matches.
    """
    load_user = _get_user_dependency(*args)

    async def self_or_admin_dependency(user_id: str, session: Session = user_auth) -> User:
        target_id = session.user_id if user_id.lower() in ("me", "self") else user_id

        if session.user_id != target_id and not session.user.admin:
            raise PermissionDeniedError

        # The id is already resolved, so the loader needs no session.
        return await load_user(target_id, None)

    return self_or_admin_dependency
125  
126  
def get_user(*args: Column[Any], require_self_or_admin: bool = False) -> Any:
    """Return a ``Depends`` marker that resolves the ``user_id`` parameter to a ``User``.

    Args:
        *args: Forwarded to ``db.get`` through the underlying dependency.
        require_self_or_admin: When true, only the user themself or an admin
            may load the user (otherwise ``PermissionDeniedError``). When false,
            any caller may.
    """
    if require_self_or_admin:
        factory = _get_user_privileged_dependency
    else:
        factory = _get_user_dependency
    return Depends(factory(*args))