session.py
1 from __future__ import annotations 2 3 import hashlib 4 import secrets 5 from datetime import datetime, timedelta 6 from typing import Any 7 from uuid import uuid4 8 9 from sqlalchemy import Column, DateTime, ForeignKey, String, Text 10 from sqlalchemy.orm import Mapped, relationship 11 12 from .user import User 13 from ..database import Base, db, db_wrapper, delete 14 from ..logger import get_logger 15 from ..redis import redis 16 from ..settings import settings 17 from ..utils.jwt import decode_jwt, encode_jwt 18 19 20 logger = get_logger(__name__) 21 22 23 class SessionExpiredError(Exception): 24 pass 25 26 27 def _hash_token(token: str) -> str: 28 return hashlib.sha256(token.encode()).hexdigest() 29 30 31 class Session(Base): 32 __tablename__ = "session" 33 34 id: Mapped[str] = Column(String(36), primary_key=True, unique=True) 35 user_id: Mapped[str] = Column(String(36), ForeignKey("user.id")) 36 user: User = relationship("User", back_populates="sessions") 37 device_name: Mapped[str] = Column(Text) 38 last_update: Mapped[datetime] = Column(DateTime) 39 refresh_token: Mapped[str] = Column(String(64), unique=True) 40 41 @property 42 def serialize(self) -> dict[str, Any]: 43 return { 44 "id": self.id, 45 "user_id": self.user_id, 46 "device_name": self.device_name, 47 "last_update": self.last_update.timestamp(), 48 } 49 50 @staticmethod 51 async def create(user_id: str, device_name: str) -> tuple[Session, str, str]: 52 refresh_token = secrets.token_urlsafe(64) 53 session = Session( 54 id=str(uuid4()), 55 user_id=user_id, 56 device_name=device_name, 57 last_update=datetime.utcnow(), 58 refresh_token=_hash_token(refresh_token), 59 ) 60 await db.add(session) 61 return session, session._generate_access_token(), refresh_token 62 63 def _generate_access_token(self) -> str: 64 return encode_jwt( 65 {"uid": self.user_id, "sid": self.id, "rt": self.refresh_token}, 66 timedelta(seconds=settings.access_token_ttl), 67 ) 68 69 @staticmethod 70 async def from_access_token(access_token: str) -> Session | None: 71 if (data := decode_jwt(access_token, require=["uid", "sid", "rt"])) is None: 72 return None 73 if await redis.exists(f"session_logout:{data['rt']}"): 74 return None 75 76 return await db.get(Session, Session.user, id=data["sid"]) 77 78 @staticmethod 79 async def refresh(refresh_token: str) -> tuple[Session, str, str]: 80 token_hash = _hash_token(refresh_token) 81 session: Session | None = await db.get(Session, Session.user, refresh_token=token_hash) 82 if not session: 83 raise ValueError("Invalid refresh token") 84 if datetime.utcnow() > session.last_update + timedelta(seconds=settings.refresh_token_ttl): 85 await session.logout() 86 raise SessionExpiredError 87 88 await redis.setex(f"session_logout:{session.refresh_token}", settings.access_token_ttl, 1) 89 refresh_token = secrets.token_urlsafe(64) 90 session.refresh_token = _hash_token(refresh_token) 91 session.last_update = datetime.utcnow() 92 return session, session._generate_access_token(), refresh_token 93 94 async def logout(self) -> None: 95 await redis.setex(f"session_logout:{self.refresh_token}", settings.access_token_ttl, 1) 96 await db.delete(self) 97 98 99 @db_wrapper 100 async def clean_expired_sessions() -> None: 101 await db.exec( 102 delete(Session).where(Session.last_update < datetime.utcnow() - timedelta(seconds=settings.refresh_token_ttl)) 103 )