db.py
  1  """
  2  Database connection layer.
  3  
  4  Supports PostgreSQL (production) via asyncpg and SQLite (local dev) via aiosqlite.
  5  Switch via DATABASE_URL env var:
  6    PostgreSQL: DATABASE_URL=postgresql://user:pass@host/db
  SQLite:     DATABASE_URL=sqlite:///./db/sites.db
              (or omit it to fall back to DATABASE_PATH, default ./db/sites.db,
              resolved against the project root)
  8  """
  9  
 10  import json
 11  import os
 12  import sqlite3
 13  from contextlib import asynccontextmanager
 14  from pathlib import Path
 15  
 16  from dotenv import load_dotenv
 17  
 18  load_dotenv()
 19  
 20  DATABASE_URL = os.getenv("DATABASE_URL", "")
 21  _PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
 22  
 23  # Detect backend
 24  if DATABASE_URL.startswith("postgresql"):
 25      import asyncpg
 26      _pg_pool = None
 27  
 28      async def _get_pg_pool():
 29          global _pg_pool
 30          if _pg_pool is None:
 31              _pg_pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
 32          return _pg_pool
 33  
 34      @asynccontextmanager
 35      async def get_conn():
 36          pool = await _get_pg_pool()
 37          async with pool.acquire() as conn:
 38              yield conn
 39  
 40      async def fetchall(query: str, *args) -> list[dict]:
 41          async with get_conn() as conn:
 42              rows = await conn.fetch(query, *args)
 43              return [dict(r) for r in rows]
 44  
 45      async def fetchone(query: str, *args) -> dict | None:
 46          async with get_conn() as conn:
 47              row = await conn.fetchrow(query, *args)
 48              return dict(row) if row else None
 49  
 50      async def fetchval(query: str, *args):
 51          async with get_conn() as conn:
 52              return await conn.fetchval(query, *args)
 53  
    # PostgreSQL uses $1, $2, ... placeholders — queries are written for PG
    # natively, so this backend passes the query text through unchanged.
    PLACEHOLDER = "$"
 56  
 57  else:
 58      # SQLite fallback — translate $N → ? for local dev
 59      import re
 60      import aiosqlite
 61  
 62      _sqlite_path = DATABASE_URL.replace("sqlite:///", "") if DATABASE_URL.startswith("sqlite") else \
 63          str(_PROJECT_ROOT / os.getenv("DATABASE_PATH", "./db/sites.db").lstrip("./"))
 64  
 65      def _pg_to_sqlite(query: str) -> str:
 66          """Convert PostgreSQL $1,$2 placeholders to SQLite ?"""
 67          return re.sub(r"\$\d+", "?", query)
 68  
 69      def _pg_date_to_sqlite(query: str) -> str:
 70          """Convert PostgreSQL date functions to SQLite equivalents."""
 71          q = query
 72          q = q.replace("NOW()", "datetime('now')")
 73          q = q.replace("INTERVAL '", "'-")
 74          q = q.replace("'30 days'", "30 days'")
 75          q = q.replace("'48 hours'", "48 hours'")
 76          q = q.replace("'7 days'", "7 days'")
 77          q = q.replace("date_trunc('hour',", "strftime('%Y-%m-%d %H:00:00',")
 78          q = q.replace("date_trunc('day',", "DATE(")
 79          q = q.replace("EXTRACT(EPOCH FROM", "strftime('%s',")
 80          return q
 81  
 82      def _adapt(query: str) -> str:
 83          return _pg_date_to_sqlite(_pg_to_sqlite(query))
 84  
 85      @asynccontextmanager
 86      async def _get_sqlite():
 87          async with aiosqlite.connect(_sqlite_path) as db:
 88              db.row_factory = aiosqlite.Row
 89              await db.execute("PRAGMA query_only = ON")
 90              yield db
 91  
 92      @asynccontextmanager
 93      async def get_conn():
 94          async with _get_sqlite() as db:
 95              yield db
 96  
 97      async def fetchall(query: str, *args) -> list[dict]:
 98          async with _get_sqlite() as db:
 99              async with db.execute(_adapt(query), args) as cur:
100                  rows = await cur.fetchall()
101                  return [dict(r) for r in rows]
102  
103      async def fetchone(query: str, *args) -> dict | None:
104          async with _get_sqlite() as db:
105              async with db.execute(_adapt(query), args) as cur:
106                  row = await cur.fetchone()
107                  return dict(row) if row else None
108  
109      async def fetchval(query: str, *args):
110          async with _get_sqlite() as db:
111              async with db.execute(_adapt(query), args) as cur:
112                  row = await cur.fetchone()
113                  return row[0] if row else None
114  
    # Placeholder marker for the SQLite backend (the PostgreSQL branch uses "$").
    PLACEHOLDER = "?"