oauth.py
  1  """OAuth related endpoints"""
  2  
  3  from typing import Any
  4  from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
  5  
  6  from aiohttp import BasicAuth, ClientSession
  7  from fastapi import APIRouter
  8  
  9  from .. import models
 10  from ..auth import get_user
 11  from ..database import db, filter_by
 12  from ..exceptions.auth import admin_responses
 13  from ..exceptions.oauth import (
 14      ConnectionNotFoundError,
 15      InvalidOAuthCodeError,
 16      ProviderNotFoundError,
 17      RemoteAlreadyLinkedError,
 18  )
 19  from ..exceptions.user import CannotDeleteLastLoginMethodError, UserNotFoundError
 20  from ..schemas.oauth import OAuthConnection, OAuthLogin, OAuthProvider
 21  from ..settings import settings
 22  from ..utils.docs import responses
 23  
 24  
# Router holding every OAuth endpoint defined in this module
# (presumably included into the application elsewhere — confirm).
router = APIRouter()
 26  
 27  
 28  def add_qs(url: str, q: dict[str, str]) -> str:
 29      scheme, netloc, path, params, query, fragment = urlparse(url)
 30      return urlunparse((scheme, netloc, path, params, urlencode(dict(parse_qsl(query)) | q), fragment))
 31  
 32  
 33  async def resolve_code(login: OAuthLogin) -> tuple[str, str | None]:
 34      if login.provider_id not in settings.oauth_providers:
 35          raise ProviderNotFoundError
 36  
 37      provider = settings.oauth_providers[login.provider_id]
 38      async with ClientSession() as session, session.post(
 39          provider.token_url,
 40          data={
 41              "grant_type": "authorization_code",
 42              "code": login.code,
 43              "redirect_uri": login.redirect_uri,
 44              "client_id": provider.client_id,
 45          },
 46          headers={"Accept": "application/json"},
 47          auth=BasicAuth(provider.client_id, provider.client_secret),
 48      ) as response:
 49          if response.status != 200:
 50              raise InvalidOAuthCodeError
 51  
 52          data = await response.json()
 53  
 54      access_token: str | None = data.get("access_token")
 55      if not access_token:
 56          raise InvalidOAuthCodeError
 57  
 58      def fmt(x: str) -> str:
 59          return x.format(access_token=access_token)
 60  
 61      async with ClientSession() as session, session.get(
 62          fmt(provider.userinfo_url), headers={k: fmt(v) for k, v in provider.userinfo_headers.items()}
 63      ) as response:
 64          if response.status != 200:
 65              raise InvalidOAuthCodeError
 66  
 67          data = await response.json()
 68  
 69      remote_user_id = provider.userinfo_id_path.input(data).first()
 70      display_name = provider.userinfo_name_path.input(data).first()
 71  
 72      return str(remote_user_id), str(display_name) if display_name else None
 73  
 74  
 75  @router.get("/oauth/providers", responses=responses(list[OAuthProvider]))
 76  async def get_oauth_providers() -> Any:
 77      """Return a list of all available OAuth providers."""
 78  
 79      return [
 80          {
 81              "id": k,
 82              "name": v.name,
 83              "authorize_url": add_qs(v.authorize_url, {"response_type": "code", "client_id": v.client_id}),
 84          }
 85          for k, v in settings.oauth_providers.items()
 86      ]
 87  
 88  
 89  @router.get("/oauth/links/{user_id}", responses=admin_responses(list[OAuthConnection], UserNotFoundError))
 90  async def get_oauth_connections(
 91      user: models.User = get_user(models.User.oauth_connections, require_self_or_admin=True)
 92  ) -> Any:
 93      """
 94      Return a list of all OAuth connections for the given user.
 95  
 96      *Requirements:* **SELF** or **ADMIN**
 97      """
 98  
 99      return [connection.serialize for connection in user.oauth_connections]
100  
101  
@router.post(
    "/oauth/links/{user_id}",
    responses=admin_responses(
        OAuthConnection, UserNotFoundError, RemoteAlreadyLinkedError, ProviderNotFoundError, InvalidOAuthCodeError
    ),
)
async def create_oauth_connection(login: OAuthLogin, user: models.User = get_user(require_self_or_admin=True)) -> Any:
    """
    Create a new OAuth connection for the given user.

    The client can use almost the same procedure as described in the `POST /sessions/oauth` endpoint documentation.

    *Requirements:* **SELF** or **ADMIN**
    """

    # Exchange the authorization code for the remote identity. This is the id on the
    # provider's side, not the local user_id from the path.
    remote_user_id, remote_display_name = await resolve_code(login)

    # A remote account may be linked to at most one local user (the lookup is not
    # restricted to the current user).
    already_linked = await db.exists(
        filter_by(models.OAuthUserConnection, provider_id=login.provider_id, remote_user_id=remote_user_id)
    )
    if already_linked:
        raise RemoteAlreadyLinkedError

    # NOTE(review): check-then-create is not atomic; two concurrent requests could both pass
    # the check unless the database enforces uniqueness — confirm.
    new_connection = await models.OAuthUserConnection.create(
        user.id, login.provider_id, remote_user_id, remote_display_name
    )
    return new_connection.serialize
125  
126  
@router.delete(
    "/oauth/links/{user_id}/{connection_id}",
    responses=admin_responses(bool, UserNotFoundError, CannotDeleteLastLoginMethodError, ConnectionNotFoundError),
)
async def delete_oauth_connection(
    connection_id: str, user: models.User = get_user(models.User.oauth_connections, require_self_or_admin=True)
) -> Any:
    """
    Delete an existing OAuth connection.

    *Requirements:* **SELF** or **ADMIN**
    """

    # Resolve the connection first. An unknown id, or one belonging to another user, must
    # yield ConnectionNotFoundError rather than CannotDeleteLastLoginMethodError.
    if not (connection := await db.get(models.OAuthUserConnection, id=connection_id, user_id=user.id)):
        raise ConnectionNotFoundError

    # Refuse to remove the last way to log in: the user has no password and this is
    # their only OAuth connection.
    if not user.password and len(user.oauth_connections) <= 1:
        raise CannotDeleteLastLoginMethodError

    await db.delete(connection)
    return True