user.py
from __future__ import annotations

from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any
from uuid import uuid4

from sqlalchemy import Boolean, Column, DateTime, String, func
from sqlalchemy.orm import Mapped, relationship
from sqlalchemy.sql import Select

from ..database import Base, db, select
from ..logger import get_logger
from ..redis import redis
from ..settings import settings
from ..utils.jwt import decode_jwt
from ..utils.passwords import hash_password, verify_password
 17  
 18  
 19  if TYPE_CHECKING:
 20      from .oauth_user_connection import OAuthUserConnection
 21      from .session import Session
 22  
 23  logger = get_logger(__name__)
 24  
 25  
 26  class User(Base):
 27      __tablename__ = "user"
 28  
 29      id: Mapped[str] = Column(String(36), primary_key=True, unique=True)
 30      name: Mapped[str] = Column(String(32), unique=True)
 31      password: Mapped[str | None] = Column(String(128), nullable=True)
 32      registration: Mapped[datetime] = Column(DateTime)
 33      last_login: Mapped[datetime | None] = Column(DateTime, nullable=True)
 34      enabled: Mapped[bool] = Column(Boolean, default=True)
 35      admin: Mapped[bool] = Column(Boolean, default=False)
 36      mfa_secret: Mapped[str | None] = Column(String(32), nullable=True)
 37      mfa_enabled: Mapped[bool] = Column(Boolean, default=False)
 38      mfa_recovery_code: Mapped[str | None] = Column(String(64), nullable=True)
 39      sessions: list[Session] = relationship("Session", back_populates="user", cascade="all, delete")
 40      oauth_connections: list[OAuthUserConnection] = relationship(
 41          "OAuthUserConnection", back_populates="user", cascade="all, delete"
 42      )
 43  
 44      @property
 45      def serialize(self) -> dict[str, Any]:
 46          return {
 47              "id": self.id,
 48              "name": self.name,
 49              "registration": self.registration.timestamp(),
 50              "last_login": self.last_login.timestamp() if self.last_login else None,
 51              "enabled": self.enabled,
 52              "admin": self.admin,
 53              "password": bool(self.password),
 54              "mfa_enabled": self.mfa_enabled,
 55          }
 56  
 57      @staticmethod
 58      async def create(name: str, password: str | None, enabled: bool, admin: bool) -> User:
 59          user = User(
 60              id=str(uuid4()),
 61              name=name,
 62              password=await hash_password(password) if password else None,
 63              registration=datetime.utcnow(),
 64              last_login=None,
 65              enabled=enabled,
 66              admin=admin,
 67              mfa_secret=None,
 68              mfa_enabled=False,
 69              mfa_recovery_code=None,
 70          )
 71          await db.add(user)
 72          return user
 73  
 74      @staticmethod
 75      def filter_by_name(name: str) -> Select:
 76          return select(User).where(func.lower(User.name) == name.lower())
 77  
 78      @staticmethod
 79      async def initialize() -> None:
 80          if await db.exists(select(User)):
 81              return
 82  
 83          await User.create(settings.admin_username, settings.admin_password, True, True)
 84          logger.info(f"Admin user '{settings.admin_username}' has been created!")
 85  
 86      async def check_password(self, password: str) -> bool:
 87          if not self.password:
 88              return False
 89  
 90          return await verify_password(password, self.password)
 91  
 92      async def change_password(self, password: str | None) -> None:
 93          self.password = await hash_password(password) if password else None
 94  
 95      async def create_session(self, device_name: str) -> tuple[Session, str, str]:
 96          from .session import Session
 97  
 98          self.last_login = datetime.utcnow()
 99          return await Session.create(self.id, device_name)
100  
101      @staticmethod
102      async def from_access_token(access_token: str) -> User | None:
103          if (data := decode_jwt(access_token, require=["uid", "sid", "rt"])) is None:
104              return None
105          if await redis.exists(f"session_logout:{data['rt']}"):
106              return None
107  
108          return await db.get(User, id=data["uid"], enabled=True)
109  
110      async def logout(self) -> None:
111          for session in self.sessions:
112              await session.logout()