auth.py
import hmac
from typing import Any

from fastapi import Depends, Request
from fastapi.openapi.models import HTTPBearer
from fastapi.security.base import SecurityBase

from .exceptions.auth import InvalidTokenError
from .utils.jwt import decode_jwt


def get_token(request: Request) -> str:
    """Return the token carried in the request's ``Authorization`` header.

    A leading ``"Bearer "`` prefix is stripped when present. If the header
    is missing, an empty string is returned.

    Args:
        request: The incoming request whose headers are inspected.

    Returns:
        The header value without the ``"Bearer "`` prefix.
    """
    authorization: str = request.headers.get("Authorization", "")
    # NOTE(review): a header value that does not start with "Bearer " is
    # returned unchanged, so a raw token without a scheme is also accepted
    # by the callers below. Confirm this leniency is intended.
    return authorization.removeprefix("Bearer ")


class HTTPAuth(SecurityBase):
    """Base class for the bearer-token dependencies defined in this module.

    Subclasses implement ``__call__`` to validate the incoming request.
    """

    def __init__(self) -> None:
        # Bearer scheme model from ``fastapi.openapi.models``.
        self.model = HTTPBearer()
        # The scheme is named after the concrete subclass.
        self.scheme_name = self.__class__.__name__

    async def __call__(self, request: Request) -> Any:
        """Validate *request*; must be overridden by subclasses."""
        raise NotImplementedError


class StaticTokenAuth(HTTPAuth):
    """Authenticate requests against a single pre-shared token."""

    def __init__(self, token: str) -> None:
        """Store the expected token.

        Args:
            token: The token a request must present to be accepted.
        """
        super().__init__()

        self._token = token

    async def _check_token(self, token: str) -> bool:
        """Return whether *token* equals the configured token.

        The comparison runs in constant time so that response timing does
        not reveal how many leading characters of a guess were correct.
        """
        # Compare as bytes: ``hmac.compare_digest`` raises ``TypeError`` for
        # ``str`` arguments that contain non-ASCII characters.
        return hmac.compare_digest(
            token.encode("utf-8"),
            self._token.encode("utf-8"),
        )

    async def __call__(self, request: Request) -> bool:
        """Check the request's token against the configured one.

        Returns:
            ``True`` when the presented token matches.

        Raises:
            InvalidTokenError: If the presented token does not match.
        """
        if not await self._check_token(get_token(request)):
            raise InvalidTokenError
        return True


class JWTAuth(HTTPAuth):
    """Authenticate requests by decoding the presented token with ``decode_jwt``."""

    def __init__(self, *, audience: list[str] | None = None, force_valid: bool = True) -> None:
        """Configure the JWT check.

        Args:
            audience: Value forwarded to ``decode_jwt`` as ``audience``.
            force_valid: When true, a token that ``decode_jwt`` cannot
                decode (it returns ``None``) is rejected with
                ``InvalidTokenError``; when false, ``None`` is returned to
                the caller instead.
        """
        super().__init__()
        self.audience: list[str] | None = audience
        self.force_valid: bool = force_valid

    async def __call__(self, request: Request) -> dict[Any, Any] | None:
        """Decode the request's token and return the result of ``decode_jwt``.

        Returns:
            Whatever ``decode_jwt`` returned; ``None`` is only possible
            when ``force_valid`` is false.

        Raises:
            InvalidTokenError: If ``decode_jwt`` returned ``None`` and
                ``force_valid`` is true.
        """
        data = decode_jwt(get_token(request), audience=self.audience)
        if data is None and self.force_valid:
            raise InvalidTokenError
        return data


class InternalAuth(JWTAuth):
    """``JWTAuth`` variant that always enforces a valid token."""

    def __init__(self, audience: list[str] | None = None) -> None:
        """Configure the check with ``force_valid`` fixed to ``True``.

        Args:
            audience: Value forwarded to ``decode_jwt`` as ``audience``.
        """
        super().__init__(audience=audience, force_valid=True)


# Ready-made dependency objects.
# NOTE(review): the static token below is hard-coded in source; consider
# loading it from configuration or the environment instead.
static_token_auth = Depends(StaticTokenAuth("secret token"))
jwt_auth = Depends(JWTAuth())
internal_auth = Depends(InternalAuth(audience=["service_xyz"]))