/ src / python / txtai / database / client.py
client.py
  1  """
  2  Client module
  3  """
  4  
  5  import os
  6  import time
  7  
  8  # Conditional import
  9  try:
 10      from sqlalchemy import StaticPool, Text, cast, create_engine, insert, text as textsql
 11      from sqlalchemy.orm import Session, aliased
 12      from sqlalchemy.schema import CreateSchema
 13  
 14      from .schema import Base, Batch, Document, Object, Section, SectionBase, Score
 15  
 16      ORM = True
 17  except ImportError:
 18      ORM = False
 19  
 20  from .rdbms import RDBMS
 21  
 22  
 23  class Client(RDBMS):
 24      """
 25      Database client instance. This class connects to an external database using SQLAlchemy. It supports any database
 26      that is supported by SQLAlchemy (PostgreSQL, MariaDB, etc) and has JSON support.
 27      """
 28  
 29      def __init__(self, config):
 30          """
 31          Creates a new Database.
 32  
 33          Args:
 34              config: database configuration parameters
 35          """
 36  
 37          super().__init__(config)
 38  
 39          if not ORM:
 40              raise ImportError('SQLAlchemy is not available - install "database" extra to enable')
 41  
 42          # SQLAlchemy parameters
 43          self.engine, self.dbconnection = None, None
 44  
 45      def save(self, path):
 46          # Commit session and database connection
 47          super().save(path)
 48  
 49          if self.dbconnection:
 50              self.dbconnection.commit()
 51  
 52      def close(self):
 53          super().close()
 54  
 55          # Dispose of engine, which also closes dbconnection
 56          if self.engine:
 57              self.engine.dispose()
 58  
 59      def reindexstart(self):
 60          # Working table name
 61          name = f"rebuild{round(time.time() * 1000)}"
 62  
 63          # Create working table metadata
 64          type("Rebuild", (SectionBase,), {"__tablename__": name})
 65          Base.metadata.tables[name].create(self.dbconnection)
 66  
 67          return name
 68  
 69      def reindexend(self, name):
 70          # Remove table object from metadata
 71          Base.metadata.remove(Base.metadata.tables[name])
 72  
 73      def jsonprefix(self):
 74          # JSON column prefix
 75          return "cast("
 76  
 77      def jsoncolumn(self, name):
 78          # Alias documents table
 79          d = aliased(Document, name="d")
 80  
 81          # Build JSON column expression for column
 82          return str(cast(d.data[name].as_string(), Text).compile(dialect=self.engine.dialect, compile_kwargs={"literal_binds": True}))
 83  
 84      def createtables(self):
 85          # Create tables
 86          Base.metadata.create_all(self.dbconnection, checkfirst=True)
 87  
 88          # Clear existing data - table schema is created upon connecting to database
 89          for table in ["sections", "documents", "objects"]:
 90              self.cursor.execute(f"DELETE FROM {table}")
 91  
 92      def finalize(self):
 93          # Flush cached objects
 94          self.connection.flush()
 95  
 96      def insertdocument(self, uid, data, tags, entry):
 97          self.connection.add(Document(id=uid, data=data, tags=tags, entry=entry))
 98  
 99      def insertobject(self, uid, data, tags, entry):
100          self.connection.add(Object(id=uid, object=data, tags=tags, entry=entry))
101  
102      def insertsection(self, index, uid, text, tags, entry):
103          # Save text section
104          self.connection.add(Section(indexid=index, id=uid, text=text, tags=tags, entry=entry))
105  
106      def createbatch(self):
107          # Create temporary batch table, if necessary
108          Base.metadata.tables["batch"].create(self.dbconnection, checkfirst=True)
109  
110      def insertbatch(self, indexids, ids, batch):
111          if indexids:
112              self.connection.execute(insert(Batch), [{"indexid": i, "batch": batch} for i in indexids])
113          if ids:
114              self.connection.execute(insert(Batch), [{"id": str(uid), "batch": batch} for uid in ids])
115  
116      def createscores(self):
117          # Create temporary scores table, if necessary
118          Base.metadata.tables["scores"].create(self.dbconnection, checkfirst=True)
119  
120      def insertscores(self, scores):
121          # Average scores by id
122          if scores:
123              self.connection.execute(insert(Score), [{"indexid": i, "score": sum(s) / len(s)} for i, s in scores.items()])
124  
125      def connect(self, path=None):
126          # Connection URL
127          content = self.config.get("content")
128  
129          # Read ENV variable, if necessary
130          content = os.environ.get("CLIENT_URL") if content == "client" else content
131  
132          # Create engine using database URL
133          self.engine = create_engine(content, poolclass=StaticPool, echo=False, json_serializer=lambda x: x)
134          self.dbconnection = self.engine.connect()
135  
136          # Create database session
137          database = Session(self.dbconnection)
138  
139          # Set default schema, if necessary
140          schema = self.config.get("schema")
141          if schema:
142              with self.engine.begin():
143                  self.sqldialect(database, CreateSchema(schema, if_not_exists=True))
144  
145              self.sqldialect(database, textsql("SET search_path TO :schema"), {"schema": schema})
146  
147          return database
148  
149      def getcursor(self):
150          return Cursor(self.connection)
151  
152      def rows(self):
153          return self.cursor
154  
155      def addfunctions(self):
156          return
157  
158      def sqldialect(self, database, sql, parameters=None):
159          """
160          Executes a SQL statement based on the current SQL dialect.
161  
162          Args:
163              database: current database
164              sql: SQL to execute
165              parameters: optional bind parameters
166          """
167  
168          args = (sql, parameters) if self.engine.dialect.name == "postgresql" else (textsql("SELECT 1"),)
169          database.execute(*args)
170  
171  
class Cursor:
    """
    Implements basic compatibility with the Python DB-API.

    Wraps a connection object that has an execute method and keeps the result of the most recently
    executed statement. Before any statement has been executed, the fetch methods and description
    return None and iteration yields no rows.
    """

    def __init__(self, connection):
        """
        Creates a new Cursor.

        Args:
            connection: object used to execute statements
        """

        self.connection = connection

        # Result of the most recent execute call, None until a statement has been executed
        self.result = None

    def __iter__(self):
        """
        Iterates over rows in the current result.

        Returns:
            iterator over the current result, empty iterator when no statement has been executed
        """

        # Returning None here would raise a TypeError, the other accessors tolerate a missing result
        return iter(self.result) if self.result is not None else iter(())

    def execute(self, statement, parameters=None):
        """
        Executes statement.

        Args:
            statement: statement to execute
            parameters: optional dictionary with bind parameters
        """

        # Plain strings are wrapped as SQL text constructs before execution
        if isinstance(statement, str):
            statement = textsql(statement)

        self.result = self.connection.execute(statement, parameters)

    def fetchall(self):
        """
        Fetches all rows from the current result.

        Returns:
            all rows from current result, None when no statement has been executed
        """

        return self.result.all() if self.result is not None else None

    def fetchone(self):
        """
        Fetches first row from current result.

        Returns:
            first row from current result, None when no statement has been executed
        """

        return self.result.first() if self.result is not None else None

    @property
    def description(self):
        """
        Returns columns for current result.

        Returns:
            list of single element tuples with the column name, None when no statement has been executed
        """

        return [(key,) for key in self.result.keys()] if self.result is not None else None