# user.py
from __future__ import annotations

from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any
from uuid import uuid4

from sqlalchemy import Boolean, Column, DateTime, String, func
from sqlalchemy.orm import Mapped, relationship
from sqlalchemy.sql import Select

from ..database import Base, db, select
from ..logger import get_logger
from ..redis import redis
from ..settings import settings
from ..utils.jwt import decode_jwt
from ..utils.passwords import hash_password, verify_password


if TYPE_CHECKING:
    from .oauth_user_connection import OAuthUserConnection
    from .session import Session

logger = get_logger(__name__)


def _utcnow() -> datetime:
    """Return the current UTC time as a *naive* datetime.

    Same value as the deprecated ``datetime.utcnow()``. It is kept naive because the
    ``DateTime`` columns of ``User`` are declared without ``timezone=True``.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _utc_timestamp(value: datetime) -> float:
    """Return the POSIX timestamp of *value*, interpreting naive datetimes as UTC.

    ``datetime.timestamp()`` treats naive datetimes as *local* time. The naive values
    stored on ``User`` hold UTC wall-clock time (see ``_utcnow``). Calling
    ``.timestamp()`` on them directly would therefore skew the result by the host's
    UTC offset whenever the process does not run in UTC. Aware datetimes are
    converted as-is.
    """
    if value.tzinfo is None:
        value = value.replace(tzinfo=timezone.utc)
    return value.timestamp()


class User(Base):
    """ORM model for a row of the ``user`` table.

    Besides the column mapping, the class bundles the account-level operations:
    - creation;
    - case-insensitive lookup by name;
    - bootstrapping the initial admin account;
    - password checks and changes;
    - session creation and logout;
    - resolving a user from an access token.
    """

    __tablename__ = "user"

    # UUID4 string generated in create().
    id: Mapped[str] = Column(String(36), primary_key=True, unique=True)
    name: Mapped[str] = Column(String(32), unique=True)
    # Hash produced by hash_password(); None when the account has no password.
    password: Mapped[str | None] = Column(String(128), nullable=True)
    # Naive datetimes holding UTC wall-clock time (written via _utcnow()).
    registration: Mapped[datetime] = Column(DateTime)
    last_login: Mapped[datetime | None] = Column(DateTime, nullable=True)
    enabled: Mapped[bool] = Column(Boolean, default=True)
    admin: Mapped[bool] = Column(Boolean, default=False)
    # MFA state; create() initialises it as disabled with no secret / recovery code.
    mfa_secret: Mapped[str | None] = Column(String(32), nullable=True)
    mfa_enabled: Mapped[bool] = Column(Boolean, default=False)
    mfa_recovery_code: Mapped[str | None] = Column(String(64), nullable=True)
    # ORM-level delete cascade: related rows are removed when the user is deleted.
    sessions: list[Session] = relationship("Session", back_populates="user", cascade="all, delete")
    oauth_connections: list[OAuthUserConnection] = relationship(
        "OAuthUserConnection", back_populates="user", cascade="all, delete"
    )

    @property
    def serialize(self) -> dict[str, Any]:
        """Return a JSON-friendly dict describing this user.

        Datetimes are emitted as POSIX timestamps (float seconds). Stored naive values
        are interpreted as UTC, so the result does not depend on the host's timezone.
        The password hash is never exposed: ``"password"`` only reports whether one
        is set.
        """
        return {
            "id": self.id,
            "name": self.name,
            "registration": _utc_timestamp(self.registration),
            "last_login": _utc_timestamp(self.last_login) if self.last_login else None,
            "enabled": self.enabled,
            "admin": self.admin,
            "password": bool(self.password),
            "mfa_enabled": self.mfa_enabled,
        }

    @staticmethod
    async def create(name: str, password: str | None, enabled: bool, admin: bool) -> User:
        """Create a new user and hand it to ``db.add``.

        :param name: Account name.
        :param password: Plain-text password to hash. A falsy value (``None`` or ``""``)
            creates the account without a password.
        :param enabled: Whether the account starts enabled.
        :param admin: Whether the account has admin rights.
        :return: The new ``User``. Whether it is committed here depends on ``db.add``;
            NOTE(review): confirm.
        """
        user = User(
            id=str(uuid4()),
            name=name,
            password=await hash_password(password) if password else None,
            registration=_utcnow(),
            last_login=None,
            enabled=enabled,
            admin=admin,
            mfa_secret=None,
            mfa_enabled=False,
            mfa_recovery_code=None,
        )
        await db.add(user)
        return user

    @staticmethod
    def filter_by_name(name: str) -> Select:
        """Build a query selecting users whose name equals *name*, ignoring case."""
        return select(User).where(func.lower(User.name) == name.lower())

    @staticmethod
    async def initialize() -> None:
        """Create the initial admin account if no user exists yet.

        If at least one user exists, this does nothing. Otherwise it creates an
        enabled admin account from ``settings.admin_username`` and
        ``settings.admin_password``, then logs that it did so.
        """
        if await db.exists(select(User)):
            return

        await User.create(settings.admin_username, settings.admin_password, True, True)
        logger.info(f"Admin user '{settings.admin_username}' has been created!")

    async def check_password(self, password: str) -> bool:
        """Return whether *password* matches the stored hash.

        Always ``False`` for accounts without a password.
        """
        if not self.password:
            return False

        return await verify_password(password, self.password)

    async def change_password(self, password: str | None) -> None:
        """Replace the stored password hash.

        A falsy *password* removes the password from the account.
        """
        self.password = await hash_password(password) if password else None

    async def create_session(self, device_name: str) -> tuple[Session, str, str]:
        """Record a login and open a new session for this user.

        Sets ``last_login`` to the current UTC time and returns the result of
        ``Session.create(self.id, device_name)`` unchanged. That result is a
        ``Session`` plus two strings, presumably the issued tokens; confirm in
        ``Session.create``.
        """
        # Imported here because at module level Session is only imported under
        # TYPE_CHECKING (presumably to avoid a circular import).
        from .session import Session

        self.last_login = _utcnow()
        return await Session.create(self.id, device_name)

    @staticmethod
    async def from_access_token(access_token: str) -> User | None:
        """Resolve the enabled user an access token belongs to.

        Returns ``None`` in any of these cases:
        - ``decode_jwt`` rejects the token (it is asked to require the ``uid``,
          ``sid`` and ``rt`` claims);
        - a ``session_logout:<rt>`` key exists in Redis, which presumably marks a
          logged-out session;
        - ``db.get`` finds no enabled user with id ``uid``.
        """
        if (data := decode_jwt(access_token, require=["uid", "sid", "rt"])) is None:
            return None
        if await redis.exists(f"session_logout:{data['rt']}"):
            return None

        return await db.get(User, id=data["uid"], enabled=True)

    async def logout(self) -> None:
        """Log out every session of this user, one after another."""
        for session in self.sessions:
            await session.logout()