/ restai / oauth.py
oauth.py
  1  from datetime import timedelta
  2  import logging
  3  
  4  import aiohttp
  5  from authlib.integrations.starlette_client import OAuth
  6  from authlib.oidc.core import UserInfo
  7  from fastapi import (
  8      HTTPException,
  9      status,
 10  )
 11  from starlette.responses import RedirectResponse
 12  
 13  
 14  from restai import config
 15  from restai.auth import create_access_token
 16  from restai.config import (
 17      LOG_LEVEL,
 18      OAUTH_PROVIDERS,
 19  )
 20  from restai.constants import ERROR_MESSAGES
 21  
 22  
 23  logging.basicConfig(level=LOG_LEVEL)
 24  
 25  
 26  class OAuthManager:
 27      def __init__(self, app, db_wrapper):
 28          self.oauth = OAuth()
 29          self.app = app
 30          self.db_wrapper = db_wrapper
 31          for _, provider_config in OAUTH_PROVIDERS.items():
 32              provider_config["register"](self.oauth)
 33  
 34      def reinit(self):
 35          """Re-register OAuth providers from current config."""
 36          self.oauth = OAuth()
 37          for _, provider_config in OAUTH_PROVIDERS.items():
 38              provider_config["register"](self.oauth)
 39  
 40      def get_client(self, provider_name):
 41          return self.oauth.create_client(provider_name)
 42  
 43      async def handle_login(self, request, provider):
 44          if provider not in OAUTH_PROVIDERS:
 45              raise HTTPException(404)
 46          # If the provider has a custom redirect URL, use that, otherwise automatically generate one
 47          redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
 48              "oauth_callback", provider=provider
 49          )
 50          client = self.get_client(provider)
 51          if client is None:
 52              raise HTTPException(404)
 53          return await client.authorize_redirect(request, redirect_uri)
 54  
 55      async def handle_callback(self, request, provider, response):
 56          if provider not in OAUTH_PROVIDERS:
 57              raise HTTPException(404)
 58          client = self.get_client(provider)
 59          try:
 60              token = await client.authorize_access_token(request)
 61          except Exception as e:
 62              logging.warning(f"OAuth callback error: {e}")
 63              raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
 64          user_data: UserInfo = token.get("userinfo")
 65          if not user_data or config.OAUTH_EMAIL_CLAIM not in user_data:
 66              user_data: UserInfo = await client.userinfo(token=token)
 67          if not user_data:
 68              logging.warning(f"OAuth callback failed, user data is missing: {token}")
 69              raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
 70  
 71          sub = user_data.get(OAUTH_PROVIDERS[provider].get("sub_claim", "sub"))
 72          if not sub:
 73              logging.warning(f"OAuth callback failed, sub is missing: {user_data}")
 74              raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
 75          provider_sub = f"{provider}@{sub}"
 76          email_claim = config.OAUTH_EMAIL_CLAIM
 77          email = user_data.get(email_claim, "")
 78          if not email:
 79              if provider == "github":
 80                  try:
 81                      access_token = token.get("access_token")
 82                      headers = {"Authorization": f"Bearer {access_token}"}
 83                      async with aiohttp.ClientSession() as session:
 84                          async with session.get(
 85                              "https://api.github.com/user/emails", headers=headers
 86                          ) as resp:
 87                              if resp.ok:
 88                                  emails = await resp.json()
 89                                  # use the primary email as the user's email
 90                                  primary_email = next(
 91                                      (e["email"] for e in emails if e.get("primary")),
 92                                      None,
 93                                  )
 94                                  if primary_email:
 95                                      email = primary_email
 96                                  else:
 97                                      logging.warning(
 98                                          "No primary email found in GitHub response"
 99                                      )
100                                      raise HTTPException(
101                                          400, detail=ERROR_MESSAGES.INVALID_CRED
102                                      )
103                              else:
104                                  logging.warning("Failed to fetch GitHub email")
105                                  raise HTTPException(
106                                      400, detail=ERROR_MESSAGES.INVALID_CRED
107                                  )
108                  except Exception as e:
109                      logging.warning(f"Error fetching GitHub email: {e}")
110                      raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
111              else:
112                  logging.warning(f"OAuth callback failed, email is missing: {user_data}")
113                  raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
114          email = email.lower()
115          if (
116              "*" not in config.OAUTH_ALLOWED_DOMAINS
117              and email.split("@")[-1] not in config.OAUTH_ALLOWED_DOMAINS
118          ):
119              logging.warning(
120                  f"OAuth callback failed, e-mail domain is not in the list of allowed domains: {user_data}"
121              )
122              raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
123  
124          user = self.db_wrapper.get_user_by_username(email)
125          if user is None and config.AUTO_CREATE_USER:
126              sso_restricted = self.db_wrapper.get_setting_value("sso_auto_restricted", "true").lower() in ("true", "1")
127              user = self.db_wrapper.create_user(email, None, False, False, restricted=sso_restricted)
128              self.db_wrapper.db.commit()
129              sso_team_id = self.db_wrapper.get_setting_value("sso_auto_team_id", "")
130              if sso_team_id:
131                  try:
132                      team = self.db_wrapper.get_team_by_id(int(sso_team_id))
133                      if team:
134                          self.db_wrapper.add_user_to_team(team, user)
135                  except (ValueError, TypeError):
136                      pass
137          elif user is None:
138              raise HTTPException(
139                  status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
140              )
141  
142          jwt_token = create_access_token(
143              data={"username": user.username}, expires_delta=timedelta(minutes=1440)
144          )
145  
146          response.set_cookie(
147              key="restai_token",
148              value=jwt_token,
149              samesite="strict",
150              expires=86400,
151              httponly=True,
152          )
153  
154          base_url = config.RESTAI_URL or ""
155          if base_url and not base_url.startswith("http"):
156              base_url = "https://" + base_url
157          return RedirectResponse(base_url + "/admin", headers=response.headers)