session.py
  1  """Endpoints for session management"""
  2  
import hashlib
import hmac
from typing import Any
from uuid import uuid4

from fastapi import APIRouter, Body, Request

from .oauth import resolve_code
from .. import models
from ..auth import admin_auth, get_user, user_auth
from ..database import db, filter_by
from ..exceptions.auth import admin_responses, user_responses
from ..exceptions.oauth import InvalidOAuthCodeError, ProviderNotFoundError
from ..exceptions.session import (
    InvalidCredentialsError,
    InvalidRefreshTokenError,
    SessionNotFoundError,
    UserDisabledError,
)
from ..exceptions.user import InvalidCodeError, RecaptchaError, UserNotFoundError
from ..models.session import SessionExpiredError
from ..redis import redis
from ..schemas.oauth import OAuthLogin
from ..schemas.session import Login, LoginResponse, OAuthLoginResponse, Session
from ..settings import settings
from ..utils.docs import responses
from ..utils.mfa import check_mfa_code
from ..utils.recaptcha import check_recaptcha, recaptcha_enabled
 30  
 31  
 32  router = APIRouter()
 33  
 34  
 35  async def _check_mfa(user: models.User, mfa_code: str | None, recovery_code: str | None) -> bool:
 36      if not user.mfa_enabled or not user.mfa_secret:
 37          return True
 38  
 39      if recovery_code:
 40          if hashlib.sha256(recovery_code.encode()).hexdigest() != user.mfa_recovery_code:
 41              return False
 42  
 43          user.mfa_secret = None
 44          user.mfa_enabled = False
 45          user.mfa_recovery_code = None
 46          return True
 47  
 48      return mfa_code is not None and await check_mfa_code(mfa_code, user.mfa_secret)
 49  
 50  
 51  @router.get("/session", responses=user_responses(Session))
 52  async def get_current_session(session: models.Session = user_auth) -> Any:
 53      """
 54      Return the current session.
 55  
 56      *Requirements:* **USER**
 57      """
 58  
 59      return session.serialize
 60  
 61  
 62  @router.get("/sessions/{user_id}", responses=admin_responses(list[Session], UserNotFoundError))
 63  async def get_sessions(user: models.User = get_user(require_self_or_admin=True)) -> Any:
 64      """
 65      Return all sessions of a user.
 66  
 67      *Requirements:* **SELF** or **ADMIN**
 68      """
 69  
 70      return [session.serialize async for session in await db.stream(filter_by(models.Session, user_id=user.id))]
 71  
 72  
 73  @router.post(
 74      "/sessions",
 75      responses=responses(LoginResponse, RecaptchaError, InvalidCredentialsError, UserDisabledError, InvalidCodeError),
 76  )
 77  async def login(data: Login, request: Request) -> Any:
 78      """
 79      Create a new session via username/password authentication.
 80  
 81      The client should use the following procedure to login:
 82      1. Try to login with name and password only.
 83      2. If a `RecaptchaError` is raised, ask the user to solve the captcha (get the recaptcha sitekey from
 84         `GET /recaptcha`) and repeat the request with the obtained recaptcha response. Go back to step 2.
 85      3. If a `InvalidCredentialsError` is raised, try again with a different username or password. Go back to step 2.
 86      4. If a `InvalidCodeError` is raised, MFA is enabled. Try again with the current MFA code or a recovery code.
 87         Go back to step 2.
 88      5. If a `UserDisabledError` is raised, the user is disabled and a session cannot be created.
 89      6. If the request was successful, the response contains an access token and a refresh token for authentication.
 90  
 91      The value of the `User-agent` header is used as the device name of the created session.
 92      """
 93  
 94      name_hash: str = hashlib.sha256(data.name.lower().encode()).hexdigest()
 95      failed_attempts = int(await redis.get(key := f"failed_login_attempts:{name_hash}") or "0")
 96      if (
 97          recaptcha_enabled()
 98          and (0 <= settings.login_fails_before_captcha <= failed_attempts)
 99          and not (data.recaptcha_response and await check_recaptcha(data.recaptcha_response))
100      ):
101          raise RecaptchaError
102  
103      user: models.User | None = await db.first(models.User.filter_by_name(data.name))
104      if not user or not await user.check_password(data.password):
105          await redis.incr(key)
106          raise InvalidCredentialsError
107  
108      if user.mfa_enabled and not await _check_mfa(user, data.mfa_code, data.recovery_code):
109          await redis.incr(key)
110          raise InvalidCodeError
111  
112      await redis.delete(key)
113  
114      if not user.enabled:
115          raise UserDisabledError
116  
117      session, access_token, refresh_token = await user.create_session(request.headers.get("User-agent", "")[:256])
118      return {
119          "user": user.serialize,
120          "session": session.serialize,
121          "access_token": access_token,
122          "refresh_token": refresh_token,
123      }
124  
125  
@router.post(
    "/sessions/oauth",
    responses=responses(OAuthLoginResponse, ProviderNotFoundError, InvalidOAuthCodeError, UserDisabledError),
)
async def oauth_login(data: OAuthLogin, request: Request) -> Any:
    """
    Create a new session via OAuth.

    The client should use the following procedure to login:
    1. Get the list of available OAuth providers from `GET /oauth/providers`.
    2. Redirect to the `authorize_url` of the provider after adding the following parameters to the query string:
        - `redirect_uri`: The URL to redirect to after the authorization process is complete.
        - `state`: The `id` of the OAuth provider.
    3. After the authorization process is complete, the client will be redirected to the `redirect_uri` with the
       following query parameters:
        - `code`: The authorization code.
        - `state`: The `id` of the OAuth provider.
    4. Send the authorization code to this endpoint.
    5. If a `UserDisabledError` is raised, the user is disabled and a session cannot be created.
    6. If the request was successful, the response contains either
        - an access token and a refresh token for authentication, or
        - a registration token for the `POST /users` endpoint to create a new user that is linked to the OAuth provider.

    The value of the `User-agent` header is used as the device name of the created session.
    """

    remote_user_id, display_name = await resolve_code(data)
    connection: models.OAuthUserConnection | None = await db.get(
        models.OAuthUserConnection,
        models.OAuthUserConnection.user,
        provider_id=data.provider_id,
        remote_user_id=remote_user_id,
    )

    if connection:
        # The remote account is already linked to a local user: log that user in.
        user = connection.user
        if not user.enabled:
            raise UserDisabledError

        device_name = request.headers.get("User-agent", "")[:256]
        session, access_token, refresh_token = await user.create_session(device_name)
        login_data = {
            "user": user.serialize,
            "session": session.serialize,
            "access_token": access_token,
            "refresh_token": refresh_token,
        }
        return {"login": login_data}

    # Unknown remote account: store the OAuth identity in redis under a fresh token (with a TTL) that the
    # client passes to `POST /users` to register a linked user.
    register_token = str(uuid4())
    key_prefix = f"oauth_register_token:{register_token}"
    pending_fields = {
        "provider": data.provider_id,
        "user_id": remote_user_id,
        "display_name": display_name or "",
    }
    async with redis.pipeline() as pipe:
        ttl = settings.oauth_register_token_ttl
        for field_name, field_value in pending_fields.items():
            await pipe.setex(f"{key_prefix}:{field_name}", ttl, field_value)
        await pipe.execute()

    return {"register_token": register_token}
183  
184  
@router.post(
    "/sessions/{user_id}", dependencies=[admin_auth], responses=admin_responses(LoginResponse, UserNotFoundError)
)
async def impersonate(request: Request, user: models.User = get_user()) -> Any:
    """
    Impersonate a specific user by creating a new session for them.

    *Requirements:* **ADMIN**
    """

    # Unlike `login`, this handler itself performs no password, MFA or enabled check; access is gated by the
    # `admin_auth` dependency declared on the route.
    device_name = request.headers.get("User-agent", "")[:256]
    session, access_token, refresh_token = await user.create_session(device_name)
    return dict(
        user=user.serialize,
        session=session.serialize,
        access_token=access_token,
        refresh_token=refresh_token,
    )
202  
203  
@router.put("/session", responses=responses(LoginResponse, InvalidRefreshTokenError))
async def refresh(refresh_token: str = Body(embed=True, description="The refresh token of an existing session")) -> Any:
    """
    Refresh access token and refresh token of an existing session.

    *Note:* The old refresh token is invalidated. To refresh the session again later, use the new refresh token that is
    returned by this endpoint.
    """

    try:
        session, access_token, new_refresh_token = await models.Session.refresh(refresh_token)
    except (ValueError, SessionExpiredError):
        # `Session.refresh` rejects the token via ValueError (presumably unknown/invalid token) or
        # SessionExpiredError; both map to the same client-facing error. `from None` drops the internal
        # exception context, since this is an expected client error rather than a server fault.
        raise InvalidRefreshTokenError from None

    return {
        "user": session.user.serialize,
        "session": session.serialize,
        "access_token": access_token,
        "refresh_token": new_refresh_token,
    }
224  
225  
@router.delete("/session", responses=user_responses(bool))
async def logout_current_session(session: models.Session = user_auth) -> Any:
    """
    Delete the current session.

    *Requirements:* **USER**
    """

    # Log out the session resolved through the `user_auth` default; True is returned once logout() completes.
    await session.logout()
    return True
237  
@router.delete("/sessions/{user_id}", responses=admin_responses(bool, UserNotFoundError))
async def logout(user: models.User = get_user(models.User.sessions, require_self_or_admin=True)) -> Any:
    """
    Delete all sessions of a given user.

    *Requirements:* **SELF** or **ADMIN**
    """

    # `models.User.sessions` is handed to `get_user`, presumably so the relationship that `user.logout()`
    # works on is loaded up front - confirm against `get_user`.
    await user.logout()
    return True
248  
249  
@router.delete(
    "/sessions/{user_id}/{session_id}", responses=admin_responses(bool, UserNotFoundError, SessionNotFoundError)
)
async def logout_session(session_id: str, user: models.User = get_user(require_self_or_admin=True)) -> Any:
    """
    Delete a specific session of a given user.

    *Requirements:* **SELF** or **ADMIN**
    """

    # Filtering on user_id as well as id means a session id that belongs to another user is reported as not found.
    target: models.Session | None = await db.get(models.Session, id=session_id, user_id=user.id)
    if target:
        await target.logout()
        return True

    raise SessionNotFoundError