# src/db/database.py
 1  """
 2  Async SQLite database connection for Ag3ntum API.
 3  
 4  Uses SQLAlchemy async with aiosqlite for non-blocking database operations.
 5  """
 6  import logging
 7  from pathlib import Path
 8  from typing import AsyncGenerator
 9  
10  from sqlalchemy.ext.asyncio import (
11      AsyncSession,
12      async_sessionmaker,
13      create_async_engine,
14  )
15  from sqlalchemy.orm import DeclarativeBase
16  
17  from ..config import AGENT_DIR
18  
logger = logging.getLogger(__name__)

# Location of the on-disk SQLite database. Nothing in this block creates
# DATA_DIR; init_db() does (other entry points presumably rely on the
# migration step having created it -- confirm).
DATA_DIR: Path = AGENT_DIR / "data"
DATABASE_PATH: Path = DATA_DIR / "ag3ntum.db"

# SQLAlchemy URL for the aiosqlite (async SQLite) driver.
DATABASE_URL = "sqlite+aiosqlite:///" + str(DATABASE_PATH)

# Process-wide async engine. Building it does not open a connection;
# connections are established lazily on first use.
engine = create_async_engine(
    DATABASE_URL,
    # Passed through to the underlying sqlite3 connection; lets a connection
    # be used from a thread other than the one that created it.
    connect_args=dict(check_same_thread=False),
    echo=False,  # True would log every emitted SQL statement
    pool_recycle=3600,  # replace pooled connections older than one hour
    pool_pre_ping=True,  # check pooled connections for liveness on checkout
)

# Session factory used by get_db() and by any other caller needing a session.
# expire_on_commit=False keeps ORM attributes loaded after commit, so objects
# remain readable without triggering an implicit (I/O-performing) refresh,
# which is not possible from plain attribute access under asyncio.
AsyncSessionLocal = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
42  
43  
class Base(DeclarativeBase):
    """
    Declarative base shared by all Ag3ntum ORM models.

    Mapped classes inherit from this so their tables are registered on
    ``Base.metadata``, which ``init_db()`` hands to ``create_all()``.
    """
47  
48  
async def init_db() -> None:
    """
    Create every table registered on ``Base.metadata`` using ``create_all()``.

    Ensures ``DATA_DIR`` exists, then runs ``create_all`` on a connection
    obtained from ``engine.begin()``, and finally logs the database path.

    NOTE: For production API startup, schema is managed by Alembic migrations
    in entrypoint-api.sh (see src/db/migrations.py). This function is retained
    for CLI tools (create_user.py, delete_user.py) and test fixtures.
    """
    # SQLite will create the database file but not its parent directory,
    # so make sure the directory is present before the first connection.
    DATA_DIR.mkdir(exist_ok=True, parents=True)

    table_metadata = Base.metadata
    async with engine.begin() as connection:
        # create_all is a synchronous API; run_sync bridges it onto the
        # async connection.
        await connection.run_sync(table_metadata.create_all)

    logger.info(f"Database initialized at {DATABASE_PATH}")
63  
64  
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """
    Dependency that yields database sessions.

    Opens one ``AsyncSession`` from ``AsyncSessionLocal`` per call and yields
    it. Leaving the ``async with`` block closes the session. This happens on
    normal completion, when an exception is thrown into the generator, or when
    the generator is closed. No separate ``close()`` call is needed; the
    previous explicit ``finally: await session.close()`` merely closed the
    session a second time.

    Closing a session discards work that was not committed (SQLAlchemy rolls
    back the open transaction), so callers must ``await db.commit()``
    explicitly.

    Usage in FastAPI:
        @router.get("/example")
        async def example(db: AsyncSession = Depends(get_db)):
            ...

    Yields:
        AsyncSession: a session created by ``AsyncSessionLocal``.
    """
    # AsyncSession.__aexit__ performs the close (shielded from cancellation
    # in SQLAlchemy 2.x), so no try/finally is required here.
    async with AsyncSessionLocal() as session:
        yield session