database.py
"""SQLAlchemy database setup."""

import contextlib
import datetime
import logging
import shutil
import sqlite3
from collections.abc import AsyncGenerator
from os import PathLike
from pathlib import Path

from scheduler.asyncio.scheduler import Scheduler
from sqlalchemy import URL
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine

from spoolman import env
from spoolman.prometheus.metrics import filament_metrics, spool_metrics

logger = logging.getLogger(__name__)


def get_connection_url() -> URL:
    """Construct the connection URL for the database based on environment variables.

    When no database type is configured, SQLite is used with the file ``spoolman.db``
    inside the data directory; any configured database name is ignored in that case.
    When SQLite is configured explicitly, the file location is likewise derived from the
    data directory, so an explicit database name is rejected.

    Returns:
        The SQLAlchemy URL to connect to.

    Raises:
        ValueError: If a database name is configured together with the SQLite type.

    """
    db_type = env.get_database_type()
    host = env.get_host()
    port = env.get_port()
    database = env.get_database()
    query = env.get_query()
    username = env.get_username()
    password = env.get_password()

    if db_type is None:
        db_type = env.DatabaseType.SQLITE

        database = str(env.get_data_dir().joinpath("spoolman.db"))
        logger.info('No database type specified, using a default SQLite database located at "%s"', database)
    elif db_type is env.DatabaseType.SQLITE:
        if database is not None:
            raise ValueError("Cannot specify a database name when using SQLite.")

        database = str(env.get_data_dir().joinpath("spoolman.db"))
        logger.info('Using SQLite database located at "%s"', database)
    else:
        logger.info('Connecting to database of type "%s" at "%s:%s"', db_type, host, port)

    return URL.create(
        drivername=db_type.to_drivername(),
        host=host,
        port=port,
        database=database,
        query=query or {},
        username=username,
        password=password,
    )


class Database:
    """Wrapper holding the connection URL, async engine and session factory for one database."""

    connection_url: URL
    engine: AsyncEngine | None
    session_maker: async_sessionmaker[AsyncSession] | None

    def __init__(self, connection_url: URL) -> None:
        """Construct the Database wrapper and set config parameters.

        No connection is made here; call ``connect`` to create the engine and session maker.

        Args:
            connection_url: The URL of the database to connect to.

        """
        self.connection_url = connection_url
        # Initialise explicitly so the ``is None`` guards used by callers work before
        # ``connect`` has run, instead of raising AttributeError.
        self.engine = None
        self.session_maker = None

    def is_file_based_sqlite(self) -> bool:
        """Return True if the database is a SQLite database stored in a file.

        In-memory (``:memory:``) and unnamed (empty name, i.e. temporary) SQLite databases
        are not considered file based.
        """
        database = self.connection_url.database
        return (
            self.connection_url.drivername.startswith("sqlite")
            and bool(database)
            and database != ":memory:"
        )

    def connect(self) -> None:
        """Connect to the database by creating the async engine and session maker."""
        if env.get_logging_level() == logging.DEBUG:
            # Surface SQLAlchemy's SQL statement logging when the app runs in debug mode.
            logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO)

        connect_args = {}
        if self.connection_url.drivername == "sqlite+aiosqlite":
            # Wait up to 60 seconds for a locked SQLite database instead of failing fast.
            connect_args["timeout"] = 60
        connection_options = {}
        if self.connection_url.drivername == "mysql+aiomysql":
            # Recycle pooled connections hourly so they are replaced before the server drops them.
            connection_options["pool_recycle"] = 3600
        self.engine = create_async_engine(
            self.connection_url,
            connect_args=connect_args,
            pool_pre_ping=True,
            **connection_options,
        )
        self.session_maker = async_sessionmaker(self.engine, autocommit=False, autoflush=True, expire_on_commit=False)

    def backup(self, target_path: str | PathLike[str]) -> None:
        """Backup the database.

        Copies the SQLite database file to ``target_path`` using SQLite's online backup
        API. Does nothing when the database is not a file-based SQLite database.

        Args:
            target_path: Path of the backup file to create. It must not exist yet.

        Raises:
            ValueError: If the target is the database file itself or already exists.

        """
        if not self.is_file_based_sqlite() or self.connection_url.database is None:
            return

        logger.info("Backing up SQLite database to %s", target_path)

        def progress(_: int, remaining: int, total: int) -> None:
            logger.info("Copied %d of %d pages.", total - remaining, total)

        source_path = self.connection_url.database
        # Compare resolved paths: a plain ``==`` between the ``str`` database name and a
        # ``Path`` target is always False, so the guard never fired for ``Path`` targets.
        if Path(source_path).resolve() == Path(target_path).resolve():
            raise ValueError("Cannot backup database to itself.")
        if Path(target_path).exists():
            raise ValueError("Backup target file already exists.")

        # A ``sqlite3.Connection`` used as a context manager only commits/rolls back and
        # does NOT close the connection, so ``closing`` is needed to avoid leaking both.
        with contextlib.closing(sqlite3.connect(source_path)) as src, contextlib.closing(
            sqlite3.connect(target_path),
        ) as dst:
            src.backup(dst, pages=1, progress=progress)

        logger.info("Backup complete.")

    def backup_and_rotate(
        self,
        backup_folder: str | PathLike[str],
        num_backups: int = 5,
    ) -> Path | None:
        """Backup the database and rotate existing backups.

        The newest backup is written to ``spoolman.db`` inside ``backup_folder``. Older
        backups are shifted to ``spoolman.db.1`` … ``spoolman.db.{num_backups}``. So after
        a run, the newest backup plus up to ``num_backups`` older ones exist.

        Args:
            backup_folder: The folder to store the backups in. Created if missing.
            num_backups: The number of older (rotated) backups to keep.

        Returns:
            The path to the created backup or None if no backup was created.

        """
        if not self.is_file_based_sqlite() or self.connection_url.database is None:
            logger.info("Skipping backup as the database is not SQLite.")
            return None

        backup_folder = Path(backup_folder)
        backup_folder.mkdir(parents=True, exist_ok=True)

        # Delete the oldest backup to free the last slot for the rotation below.
        oldest = backup_folder.joinpath(f"spoolman.db.{num_backups}")
        if oldest.exists():
            logger.info("Deleting oldest backup %s", oldest)
            oldest.unlink()

        # Shift each backup one slot older. Work from the highest index down so every
        # move targets a slot that has just been vacated.
        for i in range(num_backups - 1, -1, -1):
            current = backup_folder.joinpath("spoolman.db" if i == 0 else f"spoolman.db.{i}")
            if current.exists():
                older = backup_folder.joinpath(f"spoolman.db.{i + 1}")
                logger.debug("Rotating backup %s to %s", current, older)
                shutil.move(current, older)

        # Create the new backup in the now-free newest slot.
        backup_path = backup_folder.joinpath("spoolman.db")
        self.backup(backup_path)

        return backup_path


__db: Database | None = None


def setup_db(connection_url: URL) -> None:
    """Connect to the database.

    Args:
        connection_url: The URL to connect to the database.

    """
    global __db  # noqa: PLW0603
    __db = Database(connection_url)
    __db.connect()


async def backup_global_db(num_backups: int = 5) -> Path | None:
    """Backup the database and rotate existing backups.

    Args:
        num_backups: The number of older (rotated) backups to keep.

    Returns:
        The path to the created backup or None if no backup was created.

    Raises:
        RuntimeError: If ``setup_db`` has not been called.

    """
    if __db is None:
        raise RuntimeError("DB is not setup.")
    # NOTE(review): the backup is synchronous file/sqlite3 I/O and runs on the event loop
    # thread, blocking other tasks while it runs.
    return __db.backup_and_rotate(env.get_backups_dir(), num_backups=num_backups)


async def _backup_task() -> Path | None:
    """Perform scheduled backup of the database."""
    logger.info("Performing scheduled database backup.")
    return await backup_global_db(num_backups=5)


async def _metrics() -> None:
    """Create some useful prometheus metrics."""
    logger.debug("Start metrics collection")
    async for session in get_db_session():
        await filament_metrics(session)
        await spool_metrics(session)
    logger.debug("End metrics collection")


def schedule_tasks(scheduler: Scheduler) -> None:
    """Schedule tasks to be executed by the provided scheduler.

    Metrics collection is scheduled when metrics are enabled. The automatic backup is
    scheduled when enabled and the database is a file-based SQLite database.

    Args:
        scheduler: The scheduler to use for scheduling tasks.

    Raises:
        RuntimeError: If ``setup_db`` has not been called.

    """
    if __db is None:
        raise RuntimeError("DB is not setup.")
    if env.is_metrics_enabled():
        logger.info("Scheduling automatic metric collection.")
        # Collect metrics once per minute (at second 0); the interval is not configurable yet.
        scheduler.minutely(datetime.time(second=0), _metrics)  # type: ignore[arg-type]
    if not env.is_automatic_backup_enabled():
        return
    # Only file-based SQLite databases can be backed up; see Database.backup_and_rotate.
    if __db.is_file_based_sqlite():
        logger.info("Scheduling automatic database backup for midnight.")
        # Schedule for midnight
        scheduler.daily(datetime.time(hour=0, minute=0, second=0), _backup_task)  # type: ignore[arg-type]


async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """Get a DB session to be used with FastAPI's dependency system.

    The session is committed when the consumer finishes without error and rolled back
    if an exception is raised; it is closed in both cases.

    Yields:
        The database session.

    Raises:
        RuntimeError: If ``setup_db`` has not been called.

    """
    if __db is None or __db.session_maker is None:
        raise RuntimeError("DB is not setup.")
    # Leaving the ``async with`` closes the session, so no explicit close() is needed.
    async with __db.session_maker() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise