# spoolman/database/database.py
  1  """SQLAlchemy database setup."""
  2  
  3  import datetime
  4  import logging
  5  import shutil
  6  import sqlite3
  7  from collections.abc import AsyncGenerator
  8  from os import PathLike
  9  from pathlib import Path
 10  
 11  from scheduler.asyncio.scheduler import Scheduler
 12  from sqlalchemy import URL
 13  from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
 14  
 15  from spoolman import env
 16  from spoolman.prometheus.metrics import filament_metrics, spool_metrics
 17  
 18  logger = logging.getLogger(__name__)
 19  
 20  
 21  def get_connection_url() -> URL:
 22      """Construct the connection URL for the database based on environment variables."""
 23      db_type = env.get_database_type()
 24      host = env.get_host()
 25      port = env.get_port()
 26      database = env.get_database()
 27      query = env.get_query()
 28      username = env.get_username()
 29      password = env.get_password()
 30  
 31      if db_type is None:
 32          db_type = env.DatabaseType.SQLITE
 33  
 34          database = str(env.get_data_dir().joinpath("spoolman.db"))
 35          logger.info('No database type specified, using a default SQLite database located at "%s"', database)
 36      elif db_type is env.DatabaseType.SQLITE:
 37          if database is not None:
 38              raise ValueError("Cannot specify a database name when using SQLite.")
 39  
 40          database = str(env.get_data_dir().joinpath("spoolman.db"))
 41          logger.info('Using SQLite database located at "%s"', database)
 42      else:
 43          logger.info('Connecting to database of type "%s" at "%s:%s"', db_type, host, port)
 44  
 45      return URL.create(
 46          drivername=db_type.to_drivername(),
 47          host=host,
 48          port=port,
 49          database=database,
 50          query=query or {},
 51          username=username,
 52          password=password,
 53      )
 54  
 55  
 56  class Database:
 57      connection_url: URL
 58      engine: AsyncEngine | None
 59      session_maker: async_sessionmaker[AsyncSession] | None
 60  
 61      def __init__(self, connection_url: URL) -> None:
 62          """Construct the Database wrapper and set config parameters."""
 63          self.connection_url = connection_url
 64  
 65      def is_file_based_sqlite(self) -> bool:
 66          """Return True if the database is file based."""
 67          return (
 68              self.connection_url.drivername[:6] == "sqlite"
 69              and self.connection_url.database is not None
 70              and self.connection_url.database != ":memory:"
 71          )
 72  
 73      def connect(self) -> None:
 74          """Connect to the database."""
 75          if env.get_logging_level() == logging.DEBUG:
 76              logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO)
 77  
 78          connect_args = {}
 79          if self.connection_url.drivername == "sqlite+aiosqlite":
 80              connect_args["timeout"] = 60
 81          connection_options = {}
 82          if self.connection_url.drivername == "mysql+aiomysql":
 83              connection_options["pool_recycle"] = 3600
 84          self.engine = create_async_engine(
 85              self.connection_url,
 86              connect_args=connect_args,
 87              pool_pre_ping=True,
 88              **connection_options,
 89          )
 90          self.session_maker = async_sessionmaker(self.engine, autocommit=False, autoflush=True, expire_on_commit=False)
 91  
 92      def backup(self, target_path: str | PathLike[str]) -> None:
 93          """Backup the database."""
 94          if not self.is_file_based_sqlite() or self.connection_url.database is None:
 95              return
 96  
 97          logger.info("Backing up SQLite database to %s", target_path)
 98  
 99          def progress(_: int, remaining: int, total: int) -> None:
100              logger.info("Copied %d of %d pages.", total - remaining, total)
101  
102          if self.connection_url.database == target_path:
103              raise ValueError("Cannot backup database to itself.")
104          if Path(target_path).exists():
105              raise ValueError("Backup target file already exists.")
106  
107          with sqlite3.connect(self.connection_url.database) as src, sqlite3.connect(target_path) as dst:
108              src.backup(dst, pages=1, progress=progress)
109  
110          logger.info("Backup complete.")
111  
112      def backup_and_rotate(
113          self,
114          backup_folder: str | PathLike[str],
115          num_backups: int = 5,
116      ) -> Path | None:
117          """Backup the database and rotate existing backups.
118  
119          Args:
120              backup_folder: The folder to store the backups in.
121              num_backups: The number of backups to keep.
122  
123          Returns:
124              The path to the created backup or None if no backup was created.
125  
126          """
127          if not self.is_file_based_sqlite() or self.connection_url.database is None:
128              logger.info("Skipping backup as the database is not SQLite.")
129              return None
130  
131          backup_folder = Path(backup_folder)
132          backup_folder.mkdir(parents=True, exist_ok=True)
133  
134          # Delete oldest backup
135          backup_path = backup_folder.joinpath(f"spoolman.db.{num_backups}")
136          if backup_path.exists():
137              logger.info("Deleting oldest backup %s", backup_path)
138              backup_path.unlink()
139  
140          # Rotate existing backups
141          for i in range(num_backups - 1, -1, -1):
142              if i == 0:
143                  backup_path = backup_folder.joinpath("spoolman.db")
144              else:
145                  backup_path = backup_folder.joinpath(f"spoolman.db.{i}")
146              if backup_path.exists():
147                  logger.debug("Rotating backup %s to %s", backup_path, backup_folder.joinpath(f"spoolman.db.{i + 1}"))
148                  shutil.move(backup_path, backup_folder.joinpath(f"spoolman.db.{i + 1}"))
149  
150          # Create new backup
151          backup_path = backup_folder.joinpath("spoolman.db")
152          self.backup(backup_path)
153  
154          return backup_path
155  
156  
# Module-level singleton; populated by setup_db() and None until then.
__db: Database | None = None
158  
159  
def setup_db(connection_url: URL) -> None:
    """Connect to the database.

    Args:
        connection_url: The URL to connect to the database.

    """
    global __db  # noqa: PLW0603
    database = Database(connection_url)
    database.connect()
    # Publish the instance only after the engine has been created.
    __db = database
170  
171  
async def backup_global_db(num_backups: int = 5) -> Path | None:
    """Backup the database and rotate existing backups.

    Returns:
        The path to the created backup or None if no backup was created.

    """
    database = __db
    if database is None:
        raise RuntimeError("DB is not setup.")
    return database.backup_and_rotate(env.get_backups_dir(), num_backups=num_backups)
182  
183  
async def _backup_task() -> Path | None:
    """Perform scheduled backup of the database."""
    logger.info("Performing scheduled database backup.")
    database = __db
    if database is None:
        raise RuntimeError("DB is not setup.")
    # Scheduled backups always keep the default five rotations.
    return database.backup_and_rotate(env.get_backups_dir(), num_backups=5)
190  
191  
async def _metrics() -> None:
    """Create some useful prometheus metrics."""
    logger.debug("Start metrics collection")
    # get_db_session yields exactly one session; reuse it for both collectors.
    async for db_session in get_db_session():
        await filament_metrics(db_session)
        await spool_metrics(db_session)
    logger.debug("End metrics collection")
199  
200  
def schedule_tasks(scheduler: Scheduler) -> None:
    """Schedule tasks to be executed by the provided scheduler.

    Args:
        scheduler: The scheduler to use for scheduling tasks.

    """
    if __db is None:
        raise RuntimeError("DB is not setup.")
    if env.is_metrics_enabled():
        logger.info("Scheduling automatic metric collection.")
        # Collect once per minute, at the top of the minute.
        scheduler.minutely(datetime.time(second=0), _metrics)  # type: ignore[arg-type]
    # Automatic backups only apply to SQLite-backed installations.
    if env.is_automatic_backup_enabled() and "sqlite" in __db.connection_url.drivername:
        logger.info("Scheduling automatic database backup for midnight.")
        scheduler.daily(datetime.time(hour=0, minute=0, second=0), _backup_task)  # type: ignore[arg-type]
220  
221  
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """Get a DB session to be used with FastAPI's dependency system.

    Commits when the consumer finishes normally; rolls back and re-raises if
    the consumer raises.

    Yields:
        The database session.

    Raises:
        RuntimeError: If setup_db()/connect() has not been called yet.

    """
    if __db is None or __db.session_maker is None:
        raise RuntimeError("DB is not setup.")
    async with __db.session_maker() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            # Bare raise re-raises the active exception without adding this
            # handler frame to the traceback (unlike "raise exc").
            raise
        finally:
            await session.close()