/ sussro_services / services / token_service.py
token_service.py
  1  """Token service for handling authentication tokens."""
  2  from datetime import datetime, timedelta
  3  from typing import Any, Dict, Optional, Union
  4  
  5  from fastapi import Depends, HTTPException, status
  6  from fastapi.security import OAuth2PasswordBearer
  7  from jose import JWTError, jwt
  8  from pydantic import ValidationError
  9  from sqlalchemy.orm import Session
 10  
 11  from .. import models, schemas
 12  from ..core.config import settings
 13  from ..core.security import create_access_token, verify_password
 14  from .user_service import get_user_by_email
 15  
 16  # OAuth2 scheme for token authentication
 17  oauth2_scheme = OAuth2PasswordBearer(
 18      tokenUrl=f"{settings.API_V1_STR}/auth/login"
 19  )
 20  
 21  
 22  def authenticate_user(
 23      db: Session, email: str, password: str
 24  ) -> Optional[models.User]:
 25      """Authenticate a user with email and password.
 26      
 27      Args:
 28          db: Database session
 29          email: User's email
 30          password: Plain text password
 31          
 32      Returns:
 33          Optional[models.User]: The authenticated user if successful, None otherwise
 34      """
 35      user = get_user_by_email(db, email=email)
 36      if not user:
 37          return None
 38      if not verify_password(password, user.hashed_password):
 39          return None
 40      return user
 41  
 42  
 43  def create_user_token(user: models.User) -> Dict[str, str]:
 44      """Create access and refresh tokens for a user.
 45      
 46      Args:
 47          user: The user to create tokens for
 48          
 49      Returns:
 50          Dict[str, str]: Dictionary containing access and refresh tokens
 51      """
 52      access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
 53      refresh_token_expires = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
 54      
 55      access_token = create_access_token(
 56          data={"sub": str(user.id)},
 57          expires_delta=access_token_expires
 58      )
 59      
 60      refresh_token = create_access_token(
 61          data={"sub": str(user.id), "type": "refresh"},
 62          expires_delta=refresh_token_expires
 63      )
 64      
 65      return {
 66          "access_token": access_token,
 67          "refresh_token": refresh_token,
 68          "token_type": "bearer",
 69      }
 70  
 71  
 72  def refresh_access_token(
 73      db: Session, refresh_token: str
 74  ) -> Dict[str, str]:
 75      """Refresh an access token using a refresh token.
 76      
 77      Args:
 78          db: Database session
 79          refresh_token: The refresh token
 80          
 81      Returns:
 82          Dict[str, str]: New access and refresh tokens
 83          
 84      Raises:
 85          HTTPException: If the refresh token is invalid
 86      """
 87      try:
 88          payload = jwt.decode(
 89              refresh_token,
 90              settings.SECRET_KEY,
 91              algorithms=[settings.ALGORITHM]
 92          )
 93          token_data = schemas.TokenPayload(**payload)
 94          
 95          if token_data.type != "refresh":
 96              raise HTTPException(
 97                  status_code=status.HTTP_403_FORBIDDEN,
 98                  detail="Invalid token type",
 99              )
100              
101          user = get_user_by_email(db, email=token_data.sub)
102          if not user:
103              raise HTTPException(
104                  status_code=status.HTTP_404_NOT_FOUND,
105                  detail="User not found",
106              )
107              
108          return create_user_token(user)
109          
110      except (JWTError, ValidationError) as e:
111          raise HTTPException(
112              status_code=status.HTTP_403_FORBIDDEN,
113              detail="Could not validate credentials",
114          ) from e
115  
116  
# NOTE(review): ``get_db`` is not imported in this module's import block, so
# evaluating the default below raises NameError when the module is imported.
# Import it from the module that defines the DB-session dependency.
def get_current_user(
    db: Session = Depends(get_db), token: str = Depends(oauth2_scheme)
) -> models.User:
    """Resolve the user identified by a bearer access token.

    Args:
        db: Database session
        token: Encoded JWT taken from the ``Authorization: Bearer`` header

    Returns:
        models.User: The user whose email matches the token's ``sub`` claim

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) if the token
            cannot be decoded or validated, is a refresh token, has no
            subject, or does not match any user.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        payload = jwt.decode(
            token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
        )
        token_data = schemas.TokenPayload(**payload)
    except (JWTError, ValidationError) as e:
        raise credentials_exception from e

    # Refresh tokens decode with the same key and algorithm; without this
    # check a long-lived refresh token would be accepted as an access token.
    if token_data.type == "refresh":
        raise credentials_exception
    # No subject means there is nothing to look the user up by.
    if not token_data.sub:
        raise credentials_exception

    user = get_user_by_email(db, email=token_data.sub)
    if user is None:
        raise credentials_exception
    return user
150  
151  
def get_current_active_user(
    current_user: models.User = Depends(get_current_user),
) -> models.User:
    """Return the authenticated user, rejecting accounts flagged inactive.

    Args:
        current_user: User resolved by ``get_current_user``

    Returns:
        models.User: The same user, when ``is_active`` is truthy

    Raises:
        HTTPException: 400 "Inactive user" when ``is_active`` is falsy
    """
    if current_user.is_active:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail="Inactive user",
    )
172  
173  
def get_current_active_superuser(
    current_user: models.User = Depends(get_current_active_user),
) -> models.User:
    """Return the current user if it is an active superuser.

    Depends on ``get_current_active_user`` (not ``get_current_user``) so that
    inactive accounts are rejected before the privilege check. Previously an
    inactive superuser passed this dependency despite its name.

    Args:
        current_user: The current user, already checked to be active

    Returns:
        models.User: The current active superuser

    Raises:
        HTTPException: 400 "Inactive user" (from ``get_current_active_user``)
            if the user is inactive; 403 if the user is not a superuser.
    """
    if not current_user.is_superuser:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="The user doesn't have enough privileges",
        )
    return current_user
194  
195  
def verify_token(token: str) -> bool:
    """Report whether ``token`` decodes with the configured key and algorithm.

    Only decoding is attempted (signature and the checks ``jwt.decode``
    performs); no claims such as ``type`` or ``sub`` are inspected.

    Args:
        token: Encoded JWT to check

    Returns:
        bool: ``False`` if ``jwt.decode`` raises ``JWTError``, else ``True``
    """
    try:
        jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except JWTError:
        return False
    else:
        return True