database.py
from asyncio import Event
from contextvars import ContextVar
from typing import Any, AsyncIterator, Type, TypeVar, cast

from sqlalchemy import Column, text
from sqlalchemy.engine import Result
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy.future import select as sa_select
from sqlalchemy.orm import DeclarativeMeta, registry, selectinload
from sqlalchemy.sql import Executable
from sqlalchemy.sql.expression import Delete
from sqlalchemy.sql.expression import delete as sa_delete
from sqlalchemy.sql.expression import exists as sa_exists
from sqlalchemy.sql.functions import count
from sqlalchemy.sql.selectable import Exists, Select

from ..logger import get_logger
from ..settings import settings


T = TypeVar("T")

logger = get_logger(__name__)


def _to_executable(statement: Executable | str) -> Executable:
    """
    Coerce a raw SQL string into a :func:`sqlalchemy.text` construct.

    SQLAlchemy does not accept plain strings as statements (deprecated in 1.4,
    an error in 2.0), so strings are wrapped explicitly. Anything else is
    returned unchanged.
    """

    return text(statement) if isinstance(statement, str) else statement


def select(entity: Any, *args: Column[Any]) -> Select:
    """
    Shortcut for :meth:`sqlalchemy.future.select` with optional eager loading.

    Each extra argument becomes a ``selectinload`` option. A tuple or list is
    treated as a relationship path: ``(A.b, B.c)`` becomes
    ``selectinload(A.b).selectinload(B.c)``.

    :param entity: the entity (or column expression) to select
    :param args: relationships, or relationship paths, to load eagerly
    :return: the select statement
    """

    if not args:
        return sa_select(entity)

    options = []
    for arg in args:
        if isinstance(arg, (tuple, list)):
            # chain selectinload along the path: the first hop, then each following hop
            head, *tail = arg
            opt = selectinload(head)
            for x in tail:
                opt = opt.selectinload(x)
            options.append(opt)
        else:
            options.append(selectinload(arg))

    return sa_select(entity).options(*options)


def filter_by(cls: Any, *args: Column[Any], **kwargs: Any) -> Select:
    """
    Shortcut for :meth:`sqlalchemy.future.Select.filter_by`

    :param cls: the entity to select
    :param args: relationships to eagerly load (see :func:`select`)
    :param kwargs: equality criteria passed to ``filter_by``
    :return: the filtered select statement
    """

    return select(cls, *args).filter_by(**kwargs)


def exists(statement: Executable, *entities: Column[Any], **kwargs: Any) -> Exists:
    """Shortcut for :func:`sqlalchemy.sql.expression.exists`"""

    return sa_exists(statement, *entities, **kwargs)


def delete(table: Any) -> Delete:
    """Shortcut for :meth:`sqlalchemy.sql.expression.delete`"""

    return sa_delete(table)


class Base(metaclass=DeclarativeMeta):
    """Abstract declarative base class shared by all ORM models."""

    __abstract__ = True
    registry = registry()
    metadata = registry.metadata

    # table options inherited by every model that does not override __table_args__
    __table_args__ = {"mysql_collate": "utf8mb4_bin"}

    def __init__(self, **kwargs: Any) -> None:
        # delegate to the registry's declarative constructor (an unbound function,
        # hence the explicit ``self`` argument)
        self.registry.constructor(self, **kwargs)


class DB:
    """
    Async database access object.

    Holds one :class:`AsyncEngine`. The "current" session and its close event
    are stored in context variables, so each context that calls
    :meth:`create_session` sees its own session.
    """

    def __init__(self, url: str, **kwargs: Any):
        """
        :param url: the database URL passed to ``create_async_engine``
        :param kwargs: extra keyword arguments for ``create_async_engine``
        """

        self.engine: AsyncEngine = create_async_engine(url, **kwargs)
        self._session: ContextVar[AsyncSession | None] = ContextVar("session", default=None)
        self._close_event: ContextVar[Event | None] = ContextVar("close_event", default=None)

    async def create_tables(self) -> None:
        """Create all tables registered on ``Base.metadata``."""

        logger.debug("creating tables")
        async with self.engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    async def add(self, obj: T) -> T:
        """
        Add a new row to the database

        :param obj: the row to insert
        :return: the same row
        """

        self.session.add(obj)
        return obj

    async def delete(self, obj: T) -> T:
        """
        Remove a row from the database

        :param obj: the row to remove
        :return: the same row
        """

        await self.session.delete(obj)
        return obj

    async def exec(self, statement: Executable | str) -> Result:
        """
        Execute an sql statement and return the result.

        :param statement: a SQLAlchemy statement, or a raw SQL string (wrapped in ``text()``)
        :return: the result object
        """

        return await self.session.execute(_to_executable(statement))

    async def stream(self, statement: Executable | str) -> AsyncIterator[Any]:
        """
        Execute an sql statement and stream the result.

        :param statement: a SQLAlchemy statement, or a raw SQL string (wrapped in ``text()``)
        :return: an async iterator over the first column of each row
        """

        return cast(AsyncIterator[Any], (await self.session.stream(_to_executable(statement))).scalars())

    async def all(self, statement: Executable | str) -> list[Any]:
        """Execute an sql statement and return all results as a list."""

        return [x async for x in await self.stream(statement)]

    async def first(self, statement: Executable | str) -> Any | None:
        """Execute an sql statement and return the first column of the first row (or None)."""

        return (await self.exec(statement)).scalar()

    async def exists(self, statement: Executable | str, *args: Column[Any], **kwargs: Any) -> bool:
        """Execute an sql statement and return whether it returned at least one row."""

        return cast(bool, await self.first(exists(_to_executable(statement), *args, **kwargs).select()))

    async def count(self, statement: Select) -> int:
        """Execute an sql statement and return the number of returned rows."""

        return cast(int, await self.first(select(count()).select_from(statement.subquery())))

    async def get(self, cls: Type[T], *args: Column[Any], **kwargs: Any) -> T | None:
        """Shortcut for first(filter_by(...))"""

        return await self.first(filter_by(cls, *args, **kwargs))

    async def commit(self) -> None:
        """Shortcut for :meth:`sqlalchemy.ext.asyncio.AsyncSession.commit`"""

        if self._session.get():
            await self.session.commit()

    async def close(self) -> None:
        """
        Close the current session and signal the close event.

        The event is set even if closing the session raises, so tasks blocked
        in :meth:`wait_for_close_event` are always released.
        """

        try:
            if self._session.get():
                await self.session.close()
        finally:
            if close_event := self._close_event.get():
                close_event.set()

    def create_session(self) -> AsyncSession:
        """Create a new async session and store it, with a fresh close event, in the context variables."""

        self._session.set(session := AsyncSession(self.engine))
        self._close_event.set(Event())
        return session

    @property
    def session(self) -> AsyncSession:
        """Get the session stored in the current context (None if none was created)."""

        return cast(AsyncSession, self._session.get())

    async def wait_for_close_event(self) -> None:
        """Block until :meth:`close` is called for the current context's session, if one exists."""

        if close_event := self._close_event.get():
            await close_event.wait()


def get_database() -> DB:
    """
    Create a database connection object using the environment variables

    :return: The DB object
    """

    return DB(
        url=settings.database_url,
        pool_pre_ping=True,
        pool_recycle=settings.pool_recycle,
        pool_size=settings.pool_size,
        max_overflow=settings.max_overflow,
        echo=settings.sql_show_statements,
    )