/ fastapi-template / api / auth.py
auth.py
 1  from typing import Any
 2  
 3  from fastapi import Depends, Request
 4  from fastapi.openapi.models import HTTPBearer
 5  from fastapi.security.base import SecurityBase
 6  
 7  from .exceptions.auth import InvalidTokenError
 8  from .utils.jwt import decode_jwt
 9  
10  
def get_token(request: Request) -> str:
    """Extract the bearer credentials from the request's ``Authorization`` header.

    The auth-scheme is matched case-insensitively (RFC 7235 §2.1), so
    ``Bearer abc``, ``bearer abc`` and ``BEARER abc`` all yield ``"abc"``.

    Returns:
        The text after ``"<scheme> "`` when the scheme is ``bearer``.
        ``""`` when the header is absent.
        Otherwise the raw header value, unchanged. The auth classes in this
        module treat that value as the token, so clients sending a bare
        token keep working.
    """
    authorization: str = request.headers.get("Authorization", "")
    # partition on the first space only: everything after it (including any
    # extra leading spaces) is kept verbatim, matching the previous
    # removeprefix("Bearer ") behavior for the canonical spelling.
    scheme, sep, credentials = authorization.partition(" ")
    if sep and scheme.lower() == "bearer":
        return credentials
    return authorization
14  
15  
class HTTPAuth(SecurityBase):
    """Common base for the header-based auth dependencies in this module.

    Instances advertise an HTTP bearer security model and use the concrete
    subclass name as their scheme name. These attributes are presumably
    consumed by FastAPI when building the OpenAPI security schemes; confirm
    against the FastAPI version in use.

    Subclasses must override ``__call__``, which receives the incoming
    request and either returns the dependency value or raises.
    """

    def __init__(self) -> None:
        # Name the scheme after the concrete subclass (e.g. "JWTAuth").
        self.scheme_name = type(self).__name__
        self.model = HTTPBearer()

    async def __call__(self, request: Request) -> Any:
        # Abstract hook: concrete auth classes implement the actual check.
        raise NotImplementedError()
23  
24  
class StaticTokenAuth(HTTPAuth):
    """Authenticate requests against a single pre-shared token.

    As a dependency, resolves to ``True`` when the request's token (see
    ``get_token``) matches the configured one, and raises
    ``InvalidTokenError`` otherwise.
    """

    def __init__(self, token: str) -> None:
        super().__init__()

        self._token = token

    async def _check_token(self, token: str) -> bool:
        """Return whether *token* matches the configured token.

        The comparison is constant-time, so response timing does not leak how
        many leading characters of a guess were correct. An empty presented
        token (e.g. a missing ``Authorization`` header) never matches, even if
        the configured token is itself empty.
        """
        import secrets

        if not token:
            return False
        # compare_digest rejects non-ASCII str arguments with TypeError, so
        # compare the UTF-8 encodings instead.
        return secrets.compare_digest(token.encode("utf-8"), self._token.encode("utf-8"))

    async def __call__(self, request: Request) -> bool:
        if not await self._check_token(get_token(request)):
            raise InvalidTokenError
        return True
38  
39  
class JWTAuth(HTTPAuth):
    """Decode the request's bearer token as a JWT via ``decode_jwt``.

    Args:
        audience: Forwarded unchanged to ``decode_jwt``.
        force_valid: When true (the default), a ``None`` result from
            ``decode_jwt`` raises ``InvalidTokenError``. When false, ``None``
            is returned to the caller instead, which lets an endpoint accept
            unauthenticated requests.

    Presumably ``decode_jwt`` returns ``None`` for missing, invalid or
    wrong-audience tokens; its exact contract lives in ``.utils.jwt``.
    """

    def __init__(self, *, audience: list[str] | None = None, force_valid: bool = True):
        super().__init__()
        self.audience: list[str] | None = audience
        self.force_valid: bool = force_valid

    async def __call__(self, request: Request) -> dict[Any, Any] | None:
        token = get_token(request)
        payload = decode_jwt(token, audience=self.audience)
        if payload is None and self.force_valid:
            raise InvalidTokenError
        return payload
50  
51  
class InternalAuth(JWTAuth):
    """``JWTAuth`` variant that always requires a valid token.

    Accepts only ``audience``; ``force_valid`` is pinned to ``True``, so a
    token that ``decode_jwt`` rejects always raises ``InvalidTokenError``.
    """

    def __init__(self, audience: list[str] | None = None):
        super().__init__(force_valid=True, audience=audience)
55  
56  
# Shared FastAPI ``Depends`` markers wrapping one instance of each auth class;
# presumably used as route parameter defaults or in ``dependencies=[...]``.
# NOTE(review): "secret token" is a hard-coded placeholder secret; it should
# likely be loaded from configuration before any real deployment — confirm.
static_token_auth = Depends(dependency=StaticTokenAuth("secret token"))
jwt_auth = Depends(dependency=JWTAuth())
internal_auth = Depends(dependency=InternalAuth(audience=["service_xyz"]))