# oauth.py
"""OAuth single sign-on for RESTai.

Registers the configured OAuth providers with Authlib, drives the login
redirect and the callback, maps the provider identity onto a local user
(optionally auto-creating it) and issues the ``restai_token`` session cookie.
"""

from datetime import timedelta
import logging

import aiohttp
from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo
from fastapi import (
    HTTPException,
    status,
)
from starlette.responses import RedirectResponse


from restai import config
from restai.auth import create_access_token
from restai.config import (
    LOG_LEVEL,
    OAUTH_PROVIDERS,
)
from restai.constants import ERROR_MESSAGES


logging.basicConfig(level=LOG_LEVEL)

# GitHub endpoint listing the authenticated user's e-mail addresses; used when
# the profile carries no (public) e-mail.
_GITHUB_EMAILS_URL = "https://api.github.com/user/emails"
# Upper bound for the GitHub e-mail lookup so a slow API cannot hang the callback.
_GITHUB_EMAILS_TIMEOUT_SECONDS = 10
# Lifetime of both the session JWT and the cookie that carries it (24h).
_SESSION_LIFETIME = timedelta(minutes=1440)


class OAuthManager:
    """Glue between Authlib's Starlette OAuth registry and RESTai's user store.

    ``app`` is stored but not used by the methods of this class; ``db_wrapper``
    is used during the callback to look up, create and team-assign users.

    Provider settings are read through ``config.OAUTH_PROVIDERS`` (not the name
    imported at module load) so that a rebound settings dict is picked up, in
    line with how the other ``config.*`` settings are read here.
    """

    def __init__(self, app, db_wrapper):
        self.app = app
        self.db_wrapper = db_wrapper
        self.reinit()

    def reinit(self):
        """Re-register OAuth providers from current config.

        Builds a fresh Authlib registry and calls every provider's ``register``
        hook on it, so providers no longer configured are dropped.
        """
        self.oauth = OAuth()
        for provider_config in config.OAUTH_PROVIDERS.values():
            provider_config["register"](self.oauth)

    def get_client(self, provider_name):
        """Return the client registered under *provider_name*, or ``None``."""
        return self.oauth.create_client(provider_name)

    async def handle_login(self, request, provider):
        """Start the OAuth flow by redirecting the browser to *provider*.

        Raises:
            HTTPException(404): unknown provider, or no client registered for it.
        """
        providers = config.OAUTH_PROVIDERS
        if provider not in providers:
            raise HTTPException(404)
        # If the provider has a custom redirect URL, use that, otherwise
        # automatically generate one.
        redirect_uri = providers[provider].get("redirect_uri") or request.url_for(
            "oauth_callback", provider=provider
        )
        client = self.get_client(provider)
        if client is None:
            raise HTTPException(404)
        return await client.authorize_redirect(request, redirect_uri)

    async def handle_callback(self, request, provider, response):
        """Complete the OAuth flow and log the user in.

        Exchanges the authorization response for a token, resolves the user's
        subject and e-mail, enforces ``config.OAUTH_ALLOWED_DOMAINS``, looks up
        (or, when ``config.AUTO_CREATE_USER`` is set, creates) the local user and
        returns a redirect to ``/admin`` carrying a ``restai_token`` cookie.

        Args:
            request: the incoming request for the callback URL.
            provider: key of the provider in ``config.OAUTH_PROVIDERS``.
            response: response the cookie is set on; its headers are copied onto
                the returned redirect.

        Raises:
            HTTPException(404): unknown provider, or no client registered for it.
            HTTPException(400): token exchange failed, or the identity is
                incomplete or from a disallowed e-mail domain.
            HTTPException(403): no matching user and auto-creation disabled.
        """
        providers = config.OAUTH_PROVIDERS
        if provider not in providers:
            raise HTTPException(404)
        client = self.get_client(provider)
        if client is None:
            # Same answer as handle_login: a missing client is "not found",
            # not a credentials problem.
            raise HTTPException(404)

        try:
            token = await client.authorize_access_token(request)
        except Exception as e:
            logging.warning("OAuth callback error: %s", e)
            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) from e

        user_data = await self._get_user_info(client, token)

        # A subject identifier is required even though only the e-mail is used
        # to match the local account below.
        sub = user_data.get(providers[provider].get("sub_claim", "sub"))
        if not sub:
            logging.warning("OAuth callback failed, sub is missing: %s", user_data)
            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

        email = (await self._get_email(provider, token, user_data)).lower()
        # Text after the last "@" (the whole string if there is none).
        domain = email.rpartition("@")[2]
        allowed_domains = config.OAUTH_ALLOWED_DOMAINS
        if "*" not in allowed_domains and domain not in allowed_domains:
            logging.warning(
                "OAuth callback failed, e-mail domain is not in the list of "
                "allowed domains: %s",
                user_data,
            )
            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

        user = self._get_or_create_user(email)

        jwt_token = create_access_token(
            data={"username": user.username}, expires_delta=_SESSION_LIFETIME
        )
        response.set_cookie(
            key="restai_token",
            value=jwt_token,
            samesite="strict",
            expires=int(_SESSION_LIFETIME.total_seconds()),
            httponly=True,
        )

        base_url = config.RESTAI_URL or ""
        if base_url and not base_url.startswith("http"):
            base_url = "https://" + base_url
        return RedirectResponse(base_url + "/admin", headers=response.headers)

    async def _get_user_info(self, client, token) -> UserInfo:
        """Return identity claims for *token*, querying the userinfo endpoint if needed.

        The claims Authlib attached as ``token["userinfo"]`` are used when they
        already contain ``config.OAUTH_EMAIL_CLAIM``; otherwise the provider's
        userinfo endpoint is called.

        Raises:
            HTTPException(400): the lookup failed or returned no claims.
        """
        user_data = token.get("userinfo")
        if not user_data or config.OAUTH_EMAIL_CLAIM not in user_data:
            try:
                user_data = await client.userinfo(token=token)
            except Exception as e:
                logging.warning("OAuth userinfo request failed: %s", e)
                raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) from e
        if not user_data:
            # Never log the token itself: it carries access/ID (and possibly
            # refresh) tokens. The field names are enough to debug.
            logging.warning(
                "OAuth callback failed, user data is missing (token fields: %s)",
                sorted(token),
            )
            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
        return user_data

    async def _get_email(self, provider, token, user_data) -> str:
        """Return the user's e-mail from the configured claim, or GitHub's API.

        Raises:
            HTTPException(400): no e-mail could be determined.
        """
        email = user_data.get(config.OAUTH_EMAIL_CLAIM, "")
        if email:
            return email
        if provider == "github":
            return await self._fetch_github_primary_email(token)
        logging.warning("OAuth callback failed, email is missing: %s", user_data)
        raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

    async def _fetch_github_primary_email(self, token) -> str:
        """Return the verified primary address listed by GitHub's ``/user/emails``.

        Raises:
            HTTPException(400): the request failed, or no verified primary
                address is listed.
        """
        headers = {"Authorization": f"Bearer {token.get('access_token')}"}
        timeout = aiohttp.ClientTimeout(total=_GITHUB_EMAILS_TIMEOUT_SECONDS)
        # Only the network round-trip is guarded; the 400s below are raised
        # outside the try so they are not re-caught and logged twice.
        try:
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.get(_GITHUB_EMAILS_URL, headers=headers) as resp:
                    ok = resp.ok
                    emails = await resp.json() if ok else None
        except Exception as e:
            logging.warning("Error fetching GitHub email: %s", e)
            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) from e
        if not ok:
            logging.warning("Failed to fetch GitHub email")
            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)

        # The e-mail becomes the local username used to find existing accounts,
        # so only an address GitHub reports as verified is trusted.
        primary_email = None
        if isinstance(emails, list):
            primary_email = next(
                (
                    entry.get("email")
                    for entry in emails
                    if isinstance(entry, dict)
                    and entry.get("primary")
                    and entry.get("verified")
                ),
                None,
            )
        if not primary_email:
            logging.warning("No verified primary email found in GitHub response")
            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
        return primary_email

    def _get_or_create_user(self, email):
        """Return the local user whose username is *email*, auto-creating it if allowed.

        Auto-created users get ``restricted`` from the ``sso_auto_restricted``
        setting (default ``"true"``) and, when ``sso_auto_team_id`` is set, are
        added to that team on a best-effort basis.

        Raises:
            HTTPException(403): no such user and ``config.AUTO_CREATE_USER`` is off.
        """
        db = self.db_wrapper
        user = db.get_user_by_username(email)
        if user is not None:
            return user
        if not config.AUTO_CREATE_USER:
            raise HTTPException(
                status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
            )

        restricted = db.get_setting_value("sso_auto_restricted", "true").lower() in (
            "true",
            "1",
        )
        # NOTE(review): the positional None/False/False presumably mean
        # "no password, not admin, not private" — confirm against create_user.
        user = db.create_user(email, None, False, False, restricted=restricted)
        db.db.commit()

        team_id = db.get_setting_value("sso_auto_team_id", "")
        if team_id:
            try:
                team = db.get_team_by_id(int(team_id))
                if team:
                    db.add_user_to_team(team, user)
            except (ValueError, TypeError) as e:
                # Best effort: a misconfigured team id must not block login,
                # but it should not fail silently either.
                logging.warning(
                    "Could not add SSO user to team %r: %s", team_id, e
                )
        return user