# oauth.py
"""OAuth related endpoints"""

from typing import Any
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse

from aiohttp import BasicAuth, ClientResponse, ClientSession
from fastapi import APIRouter

from .. import models
from ..auth import get_user
from ..database import db, filter_by
from ..exceptions.auth import admin_responses
from ..exceptions.oauth import (
    ConnectionNotFoundError,
    InvalidOAuthCodeError,
    ProviderNotFoundError,
    RemoteAlreadyLinkedError,
)
from ..exceptions.user import CannotDeleteLastLoginMethodError, UserNotFoundError
from ..schemas.oauth import OAuthConnection, OAuthLogin, OAuthProvider
from ..settings import settings
from ..utils.docs import responses


router = APIRouter()


def add_qs(url: str, q: dict[str, str]) -> str:
    """
    Return `url` with the parameters from `q` merged into its query string.

    Parameters already present in `url` are preserved (including ones with an empty value); keys that also
    appear in `q` are overwritten with the value from `q`. If a key occurs several times in `url`, only its
    last occurrence is kept.
    """

    scheme, netloc, path, params, query, fragment = urlparse(url)
    # keep_blank_values=True: otherwise parse_qsl silently drops existing parameters such as `?prompt=`.
    merged = dict(parse_qsl(query, keep_blank_values=True)) | q
    return urlunparse((scheme, netloc, path, params, urlencode(merged), fragment))


async def _read_provider_json(response: ClientResponse) -> Any:
    """
    Return the decoded JSON body of a response from an OAuth provider.

    Raises `InvalidOAuthCodeError` if the status is not 200 or the body cannot be decoded as JSON.
    """

    if response.status != 200:
        raise InvalidOAuthCodeError

    try:
        # content_type=None disables aiohttp's mimetype check: some providers label JSON bodies with a
        # non-JSON content type, which would otherwise surface as an unhandled ContentTypeError (HTTP 500).
        return await response.json(content_type=None)
    except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
        raise InvalidOAuthCodeError from None


async def resolve_code(login: OAuthLogin) -> tuple[str, str | None]:
    """
    Exchange an OAuth authorization code for the identity of the remote user.

    The code is redeemed at the provider's token endpoint; the returned access token is then used to query
    the provider's userinfo endpoint.

    Returns a tuple `(remote_user_id, display_name)`; `display_name` is `None` if the provider's name path
    yields an empty value.

    Raises `ProviderNotFoundError` if `login.provider_id` is not a configured provider, and
    `InvalidOAuthCodeError` if either request fails, returns an undecodable body, yields no access token,
    or yields no remote user id.
    """

    if login.provider_id not in settings.oauth_providers:
        raise ProviderNotFoundError

    provider = settings.oauth_providers[login.provider_id]

    # A single session for both requests so they share one connection pool.
    async with ClientSession() as session:
        async with session.post(
            provider.token_url,
            data={
                "grant_type": "authorization_code",
                "code": login.code,
                "redirect_uri": login.redirect_uri,
                "client_id": provider.client_id,
            },
            headers={"Accept": "application/json"},
            auth=BasicAuth(provider.client_id, provider.client_secret),
        ) as response:
            token_data = await _read_provider_json(response)

        # Guard against a non-object body (e.g. a JSON list or null) before calling `.get`.
        access_token: str | None = token_data.get("access_token") if isinstance(token_data, dict) else None
        if not access_token:
            raise InvalidOAuthCodeError

        def fmt(template: str) -> str:
            # The userinfo URL and header values are templates that may reference `{access_token}`.
            return template.format(access_token=access_token)

        async with session.get(
            fmt(provider.userinfo_url),
            headers={name: fmt(value) for name, value in provider.userinfo_headers.items()},
        ) as response:
            userinfo = await _read_provider_json(response)

    raw_user_id = provider.userinfo_id_path.input(userinfo).first()
    # Without this check a missing id would be stringified to "None", and every such login would resolve
    # to the same remote account.
    remote_user_id = "" if raw_user_id is None else str(raw_user_id)
    if not remote_user_id:
        raise InvalidOAuthCodeError

    display_name = provider.userinfo_name_path.input(userinfo).first()
    return remote_user_id, str(display_name) if display_name else None


@router.get("/oauth/providers", responses=responses(list[OAuthProvider]))
async def get_oauth_providers() -> Any:
    """Return a list of all available OAuth providers."""

    # `response_type=code` requests the authorization code flow that `resolve_code` later redeems.
    return [
        {
            "id": provider_id,
            "name": provider.name,
            "authorize_url": add_qs(
                provider.authorize_url, {"response_type": "code", "client_id": provider.client_id}
            ),
        }
        for provider_id, provider in settings.oauth_providers.items()
    ]


@router.get("/oauth/links/{user_id}", responses=admin_responses(list[OAuthConnection], UserNotFoundError))
async def get_oauth_connections(
    user: models.User = get_user(models.User.oauth_connections, require_self_or_admin=True)
) -> Any:
    """
    Return a list of all OAuth connections for the given user.

    *Requirements:* **SELF** or **ADMIN**
    """

    return [connection.serialize for connection in user.oauth_connections]


@router.post(
    "/oauth/links/{user_id}",
    responses=admin_responses(
        OAuthConnection, UserNotFoundError, RemoteAlreadyLinkedError, ProviderNotFoundError, InvalidOAuthCodeError
    ),
)
async def create_oauth_connection(login: OAuthLogin, user: models.User = get_user(require_self_or_admin=True)) -> Any:
    """
    Create a new OAuth connection for the given user.

    The client can use almost the same procedure as described in the `POST /sessions/oauth` endpoint documentation.

    *Requirements:* **SELF** or **ADMIN**
    """

    # Named `remote_user_id` to avoid confusion with the local `user_id` path parameter.
    remote_user_id, display_name = await resolve_code(login)

    # A remote account may be linked to at most one local user (the filter is not scoped to `user`).
    # NOTE(review): check-then-create is not atomic; concurrent requests could both pass unless the
    # table has a unique constraint on (provider_id, remote_user_id) — not visible here, confirm.
    if await db.exists(
        filter_by(models.OAuthUserConnection, provider_id=login.provider_id, remote_user_id=remote_user_id)
    ):
        raise RemoteAlreadyLinkedError

    connection = await models.OAuthUserConnection.create(user.id, login.provider_id, remote_user_id, display_name)
    return connection.serialize


@router.delete(
    "/oauth/links/{user_id}/{connection_id}",
    responses=admin_responses(bool, UserNotFoundError, CannotDeleteLastLoginMethodError, ConnectionNotFoundError),
)
async def delete_oauth_connection(
    connection_id: str, user: models.User = get_user(models.User.oauth_connections, require_self_or_admin=True)
) -> Any:
    """
    Delete an existing OAuth connection.

    *Requirements:* **SELF** or **ADMIN**
    """

    # Look the connection up first, so an unknown id is reported as ConnectionNotFoundError rather than
    # CannotDeleteLastLoginMethodError. Filtering by user_id keeps users from deleting others' connections.
    if not (connection := await db.get(models.OAuthUserConnection, id=connection_id, user_id=user.id)):
        raise ConnectionNotFoundError

    # A user without a password whose only login method is this connection must keep it.
    if not user.password and len(user.oauth_connections) <= 1:
        raise CannotDeleteLastLoginMethodError

    await db.delete(connection)
    return True