database.py
  1  from asyncio import Event
  2  from contextvars import ContextVar
  3  from typing import Any, AsyncIterator, Type, TypeVar, cast
  4  
  5  from sqlalchemy import Column
  6  from sqlalchemy.engine import Result
  7  from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
  8  from sqlalchemy.future import select as sa_select
  9  from sqlalchemy.orm import DeclarativeMeta, registry, selectinload
 10  from sqlalchemy.sql import Executable
 11  from sqlalchemy.sql.expression import Delete
 12  from sqlalchemy.sql.expression import delete as sa_delete
 13  from sqlalchemy.sql.expression import exists as sa_exists
 14  from sqlalchemy.sql.functions import count
 15  from sqlalchemy.sql.selectable import Exists, Select
 16  
 17  from ..logger import get_logger
 18  from ..settings import settings
 19  
 20  
# Generic type variable: lets DB.add / DB.delete / DB.get return the same type they were given.
T = TypeVar("T")

# Module-level logger (used by DB.create_tables for its debug message).
logger = get_logger(__name__)
 24  
 25  
 26  def select(entity: Any, *args: Column[Any]) -> Select:
 27      """Shortcut for :meth:`sqlalchemy.future.select`"""
 28  
 29      if not args:
 30          return sa_select(entity)
 31  
 32      options = []
 33      for arg in args:
 34          if isinstance(arg, (tuple, list)):
 35              head, *tail = arg
 36              opt = selectinload(head)
 37              for x in tail:
 38                  opt = opt.selectinload(x)
 39              options.append(opt)
 40          else:
 41              options.append(selectinload(arg))
 42  
 43      return sa_select(entity).options(*options)
 44  
 45  
 46  def filter_by(cls: Any, *args: Column[Any], **kwargs: Any) -> Select:
 47      """Shortcut for :meth:`sqlalchemy.future.Select.filter_by`"""
 48  
 49      return select(cls, *args).filter_by(**kwargs)
 50  
 51  
 52  def exists(statement: Executable, *entities: Column[Any], **kwargs: Any) -> Exists:
 53      """Shortcut for :meth:`sqlalchemy.future.select`"""
 54  
 55      return sa_exists(statement, *entities, **kwargs)
 56  
 57  
 58  def delete(table: Any) -> Delete:
 59      """Shortcut for :meth:`sqlalchemy.sql.expression.delete`"""
 60  
 61      return sa_delete(table)
 62  
 63  
class Base(metaclass=DeclarativeMeta):
    """
    Abstract declarative base class for the ORM models.

    Subclasses are mapped through :class:`DeclarativeMeta` using the registry
    created below; ``__abstract__ = True`` keeps this class itself unmapped.
    """

    __abstract__ = True
    # The right-hand side is evaluated before the name is rebound in the class
    # namespace, so this calls the imported ``sqlalchemy.orm.registry``.
    registry = registry()
    # ``registry`` here is the class attribute bound on the previous line.
    metadata = registry.metadata

    # Inherited by subclasses unless they define their own __table_args__.
    # NOTE(review): presumably utf8mb4_bin is chosen for byte-wise/case-sensitive
    # comparisons on MySQL — confirm intent.
    __table_args__ = {"mysql_collate": "utf8mb4_bin"}

    def __init__(self, **kwargs: Any) -> None:
        # Delegate to the registry's default constructor, which sets the given
        # keyword arguments as attributes on the new instance.
        self.registry.constructor(self, **kwargs)
 73  
 74  
 75  class DB:
 76      def __init__(self, url: str, **kwargs: Any):
 77          self.engine: AsyncEngine = create_async_engine(url, **kwargs)
 78          self._session: ContextVar[AsyncSession | None] = ContextVar("session", default=None)
 79          self._close_event: ContextVar[Event | None] = ContextVar("close_event", default=None)
 80  
 81      async def create_tables(self) -> None:
 82          """Create all tables defined in enabled cog packages."""
 83  
 84          logger.debug("creating tables")
 85          async with self.engine.begin() as conn:
 86              await conn.run_sync(Base.metadata.create_all)
 87  
 88      async def add(self, obj: T) -> T:
 89          """
 90          Add a new row to the database
 91  
 92          :param obj: the row to insert
 93          :return: the same row
 94          """
 95  
 96          self.session.add(obj)
 97          return obj
 98  
 99      async def delete(self, obj: T) -> T:
100          """
101          Remove a row from the database
102  
103          :param obj: the row to remove
104          :return: the same row
105          """
106  
107          await self.session.delete(obj)
108          return obj
109  
110      async def exec(self, statement: Executable | str) -> Result:
111          """Execute an sql statement and return the result."""
112  
113          return await self.session.execute(cast(Executable, statement))
114  
115      async def stream(self, statement: Executable | str) -> AsyncIterator[Any]:
116          """Execute an sql statement and stream the result."""
117  
118          return cast(AsyncIterator[Any], (await self.session.stream(statement)).scalars())
119  
120      async def all(self, statement: Executable | str) -> list[Any]:
121          """Execute an sql statement and return all results as a list."""
122  
123          return [x async for x in await self.stream(statement)]
124  
125      async def first(self, statement: Executable | str) -> Any | None:
126          """Execute an sql statement and return the first result."""
127  
128          return (await self.exec(statement)).scalar()
129  
130      async def exists(self, statement: Executable | str, *args: Column[Any], **kwargs: Any) -> bool:
131          """Execute an sql statement and return whether it returned at least one row."""
132  
133          return cast(bool, await self.first(exists(cast(Executable, statement), *args, **kwargs).select()))
134  
135      async def count(self, statement: Select) -> int:
136          """Execute an sql statement and return the number of returned rows."""
137  
138          return cast(int, await self.first(select(count()).select_from(statement.subquery())))
139  
140      async def get(self, cls: Type[T], *args: Column[Any], **kwargs: Any) -> T | None:
141          """Shortcut for first(filter_by(...))"""
142  
143          return await self.first(filter_by(cls, *args, **kwargs))
144  
145      async def commit(self) -> None:
146          """Shortcut for :meth:`sqlalchemy.ext.asyncio.AsyncSession.commit`"""
147  
148          if self._session.get():
149              await self.session.commit()
150  
151      async def close(self) -> None:
152          """Close the current session"""
153  
154          if self._session.get():
155              await self.session.close()
156              if close_event := self._close_event.get():
157                  close_event.set()
158  
159      def create_session(self) -> AsyncSession:
160          """Create a new async session and store it in the context variable."""
161  
162          self._session.set(session := AsyncSession(self.engine))
163          self._close_event.set(Event())
164          return session
165  
166      @property
167      def session(self) -> AsyncSession:
168          """Get the session object for the current task"""
169  
170          return cast(AsyncSession, self._session.get())
171  
172      async def wait_for_close_event(self) -> None:
173          if close_event := self._close_event.get():
174              await close_event.wait()
175  
176  
def get_database() -> DB:
    """
    Build the application's database object from the configured settings.

    The URL, pool recycling/sizing/overflow and SQL echo flag are read from
    ``settings``; ``pool_pre_ping`` is always enabled.

    :return: The DB object
    """

    database_url = settings.database_url
    engine_options: dict[str, Any] = dict(
        pool_pre_ping=True,
        pool_recycle=settings.pool_recycle,
        pool_size=settings.pool_size,
        max_overflow=settings.max_overflow,
        echo=settings.sql_show_statements,
    )
    return DB(url=database_url, **engine_options)
191      )