session.py
from __future__ import annotations

import hashlib
import secrets
from datetime import datetime, timedelta, timezone
from typing import Any
from uuid import uuid4

from sqlalchemy import Column, DateTime, ForeignKey, String, Text
from sqlalchemy.orm import Mapped, relationship

from .user import User
from ..database import Base, db, db_wrapper, delete
from ..logger import get_logger
from ..redis import redis
from ..settings import settings
from ..utils.jwt import decode_jwt, encode_jwt
 18  
 19  
# Module-level logger, obtained from the project's ``get_logger`` helper keyed by this module's ``__name__``.
logger = get_logger(__name__)
 21  
 22  
 23  class SessionExpiredError(Exception):
 24      pass
 25  
 26  
 27  def _hash_token(token: str) -> str:
 28      return hashlib.sha256(token.encode()).hexdigest()
 29  
 30  
 31  class Session(Base):
 32      __tablename__ = "session"
 33  
 34      id: Mapped[str] = Column(String(36), primary_key=True, unique=True)
 35      user_id: Mapped[str] = Column(String(36), ForeignKey("user.id"))
 36      user: User = relationship("User", back_populates="sessions")
 37      device_name: Mapped[str] = Column(Text)
 38      last_update: Mapped[datetime] = Column(DateTime)
 39      refresh_token: Mapped[str] = Column(String(64), unique=True)
 40  
 41      @property
 42      def serialize(self) -> dict[str, Any]:
 43          return {
 44              "id": self.id,
 45              "user_id": self.user_id,
 46              "device_name": self.device_name,
 47              "last_update": self.last_update.timestamp(),
 48          }
 49  
 50      @staticmethod
 51      async def create(user_id: str, device_name: str) -> tuple[Session, str, str]:
 52          refresh_token = secrets.token_urlsafe(64)
 53          session = Session(
 54              id=str(uuid4()),
 55              user_id=user_id,
 56              device_name=device_name,
 57              last_update=datetime.utcnow(),
 58              refresh_token=_hash_token(refresh_token),
 59          )
 60          await db.add(session)
 61          return session, session._generate_access_token(), refresh_token
 62  
 63      def _generate_access_token(self) -> str:
 64          return encode_jwt(
 65              {"uid": self.user_id, "sid": self.id, "rt": self.refresh_token},
 66              timedelta(seconds=settings.access_token_ttl),
 67          )
 68  
 69      @staticmethod
 70      async def from_access_token(access_token: str) -> Session | None:
 71          if (data := decode_jwt(access_token, require=["uid", "sid", "rt"])) is None:
 72              return None
 73          if await redis.exists(f"session_logout:{data['rt']}"):
 74              return None
 75  
 76          return await db.get(Session, Session.user, id=data["sid"])
 77  
 78      @staticmethod
 79      async def refresh(refresh_token: str) -> tuple[Session, str, str]:
 80          token_hash = _hash_token(refresh_token)
 81          session: Session | None = await db.get(Session, Session.user, refresh_token=token_hash)
 82          if not session:
 83              raise ValueError("Invalid refresh token")
 84          if datetime.utcnow() > session.last_update + timedelta(seconds=settings.refresh_token_ttl):
 85              await session.logout()
 86              raise SessionExpiredError
 87  
 88          await redis.setex(f"session_logout:{session.refresh_token}", settings.access_token_ttl, 1)
 89          refresh_token = secrets.token_urlsafe(64)
 90          session.refresh_token = _hash_token(refresh_token)
 91          session.last_update = datetime.utcnow()
 92          return session, session._generate_access_token(), refresh_token
 93  
 94      async def logout(self) -> None:
 95          await redis.setex(f"session_logout:{self.refresh_token}", settings.access_token_ttl, 1)
 96          await db.delete(self)
 97  
 98  
 99  @db_wrapper
100  async def clean_expired_sessions() -> None:
101      await db.exec(
102          delete(Session).where(Session.last_update < datetime.utcnow() - timedelta(seconds=settings.refresh_token_ttl))
103      )