/ src / python / txtai / database / rdbms.py
rdbms.py
  1  """
  2  RDBMS module
  3  """
  4  
  5  import datetime
  6  import json
  7  
  8  from .base import Database
  9  from .schema import Statement
 10  
 11  
 12  # pylint: disable=R0904
 13  class RDBMS(Database):
 14      """
 15      Base relational database class. A relational database uses SQL to insert, update, delete and select from a
 16      database instance.
 17      """
 18  
 19      def __init__(self, config):
 20          """
 21          Creates a new Database.
 22  
 23          Args:
 24              config: database configuration parameters
 25          """
 26  
 27          super().__init__(config)
 28  
 29          # Database connection
 30          self.connection = None
 31          self.cursor = None
 32  
 33      def load(self, path):
 34          # Load an existing database. Thread locking must be handled externally.
 35          self.session(path)
 36  
 37      def insert(self, documents, index=0):
 38          # Initialize connection if not open
 39          self.initialize()
 40  
 41          # Get entry date
 42          entry = datetime.datetime.now(datetime.timezone.utc)
 43  
 44          # Insert documents
 45          for uid, document, tags in documents:
 46              if isinstance(document, dict):
 47                  # Insert document and use return value for sections table
 48                  document = self.loaddocument(uid, document, tags, entry)
 49  
 50              if document is not None:
 51                  if isinstance(document, list):
 52                      # Join tokens to text
 53                      document = " ".join(document)
 54                  elif not isinstance(document, str):
 55                      # If object support is enabled, save object
 56                      self.loadobject(uid, document, tags, entry)
 57  
 58                      # Clear section text for objects, even when objects aren't inserted
 59                      document = None
 60  
 61                  # Save text section
 62                  self.loadsection(index, uid, document, tags, entry)
 63                  index += 1
 64  
 65          # Post processing logic
 66          self.finalize()
 67  
 68      def delete(self, ids):
 69          if self.connection:
 70              # Batch ids
 71              self.batch(ids=ids)
 72  
 73              # Delete all documents, objects and sections by id
 74              self.cursor.execute(Statement.DELETE_DOCUMENTS)
 75              self.cursor.execute(Statement.DELETE_OBJECTS)
 76              self.cursor.execute(Statement.DELETE_SECTIONS)
 77  
 78      def reindex(self, config):
 79          if self.connection:
 80              # Set new configuration
 81              self.configure(config)
 82  
 83              # Resolve text column
 84              select = self.resolve(self.text)
 85  
 86              # Initialize reindex operation
 87              name = self.reindexstart()
 88  
 89              # Copy data over
 90              self.cursor.execute(Statement.COPY_SECTIONS % (name, select))
 91  
 92              # Stream new results
 93              self.cursor.execute(Statement.STREAM_SECTIONS % name)
 94              for uid, text, data, obj, tags in self.rows():
 95                  if not text and self.encoder and obj:
 96                      yield (uid, self.encoder.decode(obj), tags)
 97                  else:
 98                      # Read JSON data, if provided
 99                      data = json.loads(data) if data and isinstance(data, str) else data
100  
101                      # Stream data if available, otherwise use section text
102                      yield (uid, data if data else text, tags)
103  
104              # Swap as new table
105              self.cursor.execute(Statement.DROP_SECTIONS)
106              self.cursor.execute(Statement.RENAME_SECTIONS % name)
107  
108              # Finish reindex operation
109              self.reindexend(name)
110  
111      def save(self, path):
112          if self.connection:
113              self.connection.commit()
114  
115      def close(self):
116          # Close connection
117          if self.connection:
118              self.connection.close()
119  
120      def ids(self, ids):
121          # Batch ids and run query
122          self.batch(ids=ids)
123          self.cursor.execute(Statement.SELECT_IDS)
124  
125          # Format and return results
126          return self.cursor.fetchall()
127  
128      def count(self):
129          self.cursor.execute(Statement.COUNT_IDS)
130          return self.cursor.fetchone()[0]
131  
132      def resolve(self, name, alias=None):
133          # Standard column names
134          sections = ["indexid", "id", "tags", "entry"]
135          noprefix = ["data", "object", "score", "text"]
136  
137          # Alias expression
138          if alias:
139              # Skip if name matches alias or alias is a standard column name
140              if name == alias or alias in sections:
141                  return name
142  
143              # Build alias clause
144              return f'{name} as "{alias}"'
145  
146          # Resolve expression
147          if self.expressions and name in self.expressions:
148              return self.expressions[name]["expression"]
149  
150          # Name is already resolved, skip
151          if name.startswith(self.jsonprefix()) or any(f"s.{s}" == name for s in sections):
152              return name
153  
154          # Standard columns - need prefixes
155          if name.lower() in sections:
156              return f"s.{name}"
157  
158          # Standard columns - no prefixes
159          if name.lower() in noprefix:
160              return name
161  
162          # Other columns come from documents.data JSON
163          return self.jsoncolumn(name)
164  
165      def embed(self, similarity, batch):
166          # Load similarity results id batch
167          self.batch(indexids=[i for i, _ in similarity[batch]], batch=batch)
168  
169          # Average and load all similarity scores with first batch
170          if not batch:
171              self.scores(similarity)
172  
173          # Return ids clause placeholder
174          return Statement.IDS_CLAUSE % batch
175  
176      # pylint: disable=R0912
177      def query(self, query, limit, parameters, indexids):
178          # Extract query components
179          select = query.get("select", self.defaults())
180          where = query.get("where")
181          groupby, having = query.get("groupby"), query.get("having")
182          orderby, qlimit, offset = query.get("orderby"), query.get("limit"), query.get("offset")
183          similarity = query.get("similar")
184  
185          # Select "indexid, score" when indexids is True
186          if indexids:
187              select = f"{self.resolve('indexid')}, {self.resolve('score')}"
188  
189          # Use JOIN when documents table is used to utilize indexes, default to LEFT JOIN
190          join = "JOIN" if any(x and self.jsonprefix() in x for x in [where, groupby, orderby]) else "LEFT JOIN"
191  
192          # Build query text
193          query = Statement.TABLE_CLAUSE % (select, join)
194          if where is not None:
195              query += f" WHERE {where}"
196          if groupby is not None:
197              query += f" GROUP BY {groupby}"
198          if having is not None:
199              query += f" HAVING {having}"
200          if orderby is not None:
201              query += f" ORDER BY {orderby}"
202  
203          # Default ORDER BY if not provided and similarity scores are available
204          if similarity and orderby is None:
205              query += " ORDER BY score DESC"
206  
207          # Apply query limit
208          if qlimit is not None or limit:
209              query += f" LIMIT {qlimit if qlimit else limit}"
210  
211              # Apply offset
212              if offset is not None:
213                  query += f" OFFSET {offset}"
214  
215          # Clear scores when no similar clauses present
216          if not similarity:
217              self.scores(None)
218  
219          # Runs a user query through execute method, which has common user query handling logic
220          args = (query, parameters) if parameters else (query,)
221          self.execute(self.cursor.execute, *args)
222  
223          # Retrieve column list from query
224          columns = [c[0] for c in self.cursor.description]
225  
226          # Map results and return
227          results = []
228          for row in self.rows():
229              result = {}
230  
231              # Copy columns to result. In cases with duplicate column names, find one with a value
232              for x, column in enumerate(columns):
233                  if column not in result or result[column] is None:
234                      # Decode object
235                      if self.encoder and column == self.object:
236                          result[column] = self.encoder.decode(row[x])
237                      else:
238                          result[column] = row[x]
239  
240              results.append(result)
241  
242          # Transform results, if necessary
243          return [(x["indexid"], x["score"]) for x in results] if indexids else results
244  
245      def initialize(self):
246          """
247          Creates connection and initial database schema if no connection exists.
248          """
249  
250          if not self.connection:
251              # Create database session. Thread locking must be handled externally.
252              self.session()
253  
254              # Create initial table schema
255              self.createtables()
256  
257              # Create indexes
258              self.createindexes()
259  
260      def session(self, path=None, connection=None):
261          """
262          Starts a new database session.
263  
264          Args:
265              path: path to database file
266              connection: existing connection to use
267          """
268  
269          # Create database connection and cursor
270          self.connection = connection if connection else self.connect(path) if path else self.connect()
271          self.cursor = self.getcursor()
272  
273          # Register custom functions - session scope
274          self.addfunctions()
275  
276          # Create temporary tables - session scope
277          self.createbatch()
278          self.createscores()
279  
280      def createtables(self):
281          """
282          Creates the initial table schema.
283          """
284  
285          self.cursor.execute(Statement.CREATE_DOCUMENTS)
286          self.cursor.execute(Statement.CREATE_OBJECTS)
287          self.cursor.execute(Statement.CREATE_SECTIONS % "sections")
288          self.cursor.execute(Statement.CREATE_SECTIONS_INDEX)
289  
290      def createindexes(self):
291          """
292          Creates expression indexes
293          """
294  
295          if self.expressions:
296              for key, values in self.expressions.items():
297                  # Create index for expression, if enabled
298                  if values["index"]:
299                      # Get parameters
300                      name = f"expression_{key}".lower()
301                      expression = values["expression"]
302                      table = "documents" if expression.startswith(self.jsonprefix()) else "sections"
303  
304                      # Execute statement
305                      self.cursor.execute(Statement.CREATE_EXPRESSION_INDEX % (name, table, expression))
306  
307      def finalize(self):
308          """
309          Post processing logic run after inserting a batch of documents. Default method is no-op.
310          """
311  
312      def loaddocument(self, uid, document, tags, entry):
313          """
314          Applies pre-processing logic and inserts a document.
315  
316          Args:
317              uid: unique id
318              document: input document dictionary
319              tags: document tags
320              entry: generated entry date
321  
322          Returns:
323              section value
324          """
325  
326          # Make a copy of document before changing
327          document = document.copy()
328  
329          # Get and remove object field from document
330          obj = document.pop(self.object) if self.object in document else None
331  
332          if document:
333              # Apply data filters, if necessary
334              data = {key: value for key, value in document.items() if key in self.store} if self.store is not None else document
335  
336              # Insert document as JSON
337              if data:
338                  self.insertdocument(uid, json.dumps(data, allow_nan=False), tags, entry)
339  
340          # If text and object are both available, load object as it won't otherwise be used
341          if self.text in document and obj:
342              self.loadobject(uid, obj, tags, entry)
343  
344          # Return value to use for section - use text if available otherwise use object
345          return document[self.text] if self.text in document else obj
346  
347      def insertdocument(self, uid, data, tags, entry):
348          """
349          Inserts a document.
350  
351          Args:
352              uid: unique id
353              data: document data
354              tags: document tags
355              entry: generated entry date
356          """
357  
358          self.cursor.execute(Statement.INSERT_DOCUMENT, [uid, data, tags, entry])
359  
360      def loadobject(self, uid, obj, tags, entry):
361          """
362          Applies pre-preprocessing logic and inserts an object.
363  
364          Args:
365              uid: unique id
366              obj: input object
367              tags: object tags
368              entry: generated entry date
369          """
370  
371          # If object support is enabled, save object
372          if self.encoder:
373              self.insertobject(uid, self.encoder.encode(obj), tags, entry)
374  
375      def insertobject(self, uid, data, tags, entry):
376          """
377          Inserts an object.
378  
379          Args:
380              uid: unique id
381              data: encoded data
382              tags: object tags
383              entry: generated entry date
384          """
385  
386          self.cursor.execute(Statement.INSERT_OBJECT, [uid, data, tags, entry])
387  
388      def loadsection(self, index, uid, text, tags, entry):
389          """
390          Applies pre-processing logic and inserts a section.
391  
392          Args:
393              index: index id
394              uid: unique id
395              text: section text
396              tags: section tags
397              entry: generated entry date
398          """
399  
400          self.insertsection(index, uid, text, tags, entry)
401  
402      def insertsection(self, index, uid, text, tags, entry):
403          """
404          Inserts a section.
405  
406          Args:
407              index: index id
408              uid: unique id
409              text: section text
410              tags: section tags
411              entry: generated entry date
412          """
413  
414          # Save text section
415          self.cursor.execute(Statement.INSERT_SECTION, [index, uid, text, tags, entry])
416  
417      def reindexstart(self):
418          """
419          Starts a reindex operation.
420  
421          Returns:
422              temporary working table name
423          """
424  
425          # Working table name
426          name = "rebuild"
427  
428          # Create new table to hold reordered sections
429          self.cursor.execute(Statement.CREATE_SECTIONS % name)
430  
431          return name
432  
433      # pylint: disable=W0613
434      def reindexend(self, name):
435          """
436          Ends a reindex operation.
437  
438          Args:
439              name: working table name
440          """
441  
442          self.cursor.execute(Statement.CREATE_SECTIONS_INDEX)
443  
444      def batch(self, indexids=None, ids=None, batch=None):
445          """
446          Loads ids to a temporary batch table for efficient query processing.
447  
448          Args:
449              indexids: list of indexids
450              ids: list of ids
451              batch: batch index, used when statement has multiple subselects
452          """
453  
454          # Delete batch when batch id is empty or for batch 0
455          if not batch:
456              self.cursor.execute(Statement.DELETE_BATCH)
457  
458          # Add batch
459          self.insertbatch(indexids, ids, batch)
460  
461      def createbatch(self):
462          """
463          Creates temporary batch table.
464          """
465  
466          # Create or Replace temporary batch table
467          self.cursor.execute(Statement.CREATE_BATCH)
468  
469      def insertbatch(self, indexids, ids, batch):
470          """
471          Inserts batch of ids.
472          """
473  
474          if indexids:
475              self.cursor.executemany(Statement.INSERT_BATCH_INDEXID, [(i, batch) for i in indexids])
476          if ids:
477              self.cursor.executemany(Statement.INSERT_BATCH_ID, [(str(uid), batch) for uid in ids])
478  
479      def scores(self, similarity):
480          """
481          Loads a batch of similarity scores to a temporary table for efficient query processing.
482  
483          Args:
484              similarity: similarity results as [(indexid, score)]
485          """
486  
487          # Delete scores
488          self.cursor.execute(Statement.DELETE_SCORES)
489  
490          if similarity:
491              # Average scores per id, needed for multiple similar() clauses
492              scores = {}
493              for s in similarity:
494                  for i, score in s:
495                      if i not in scores:
496                          scores[i] = []
497                      scores[i].append(score)
498  
499              # Add scores
500              self.insertscores(scores)
501  
502      def createscores(self):
503          """
504          Creates temporary scores table.
505          """
506  
507          # Create or Replace temporary scores table
508          self.cursor.execute(Statement.CREATE_SCORES)
509  
510      def insertscores(self, scores):
511          """
512          Inserts a batch of scores.
513  
514          Args:
515              scores: scores to add
516          """
517  
518          # Average scores by id
519          if scores:
520              self.cursor.executemany(Statement.INSERT_SCORE, [(i, sum(s) / len(s)) for i, s in scores.items()])
521  
522      def defaults(self):
523          """
524          Returns a list of default columns when there is no select clause.
525  
526          Returns:
527              list of default columns
528          """
529  
530          return "s.id, text, score"
531  
532      def connect(self, path=None):
533          """
534          Creates a new database connection.
535  
536          Args:
537              path: path to database file
538  
539          Returns:
540              connection
541          """
542  
543          raise NotImplementedError
544  
545      def getcursor(self):
546          """
547          Opens a cursor for current connection.
548  
549          Returns:
550              cursor
551          """
552  
553          raise NotImplementedError
554  
555      def jsonprefix(self):
556          """
557          Returns json column prefix to test for.
558  
559          Returns:
560              dynamic column prefix
561          """
562  
563          raise NotImplementedError
564  
565      def jsoncolumn(self, name):
566          """
567          Builds a json extract column expression for name.
568  
569          Args:
570              name: column name
571  
572          Returns:
573              dynamic column expression
574          """
575  
576          raise NotImplementedError
577  
578      def rows(self):
579          """
580          Returns current cursor row iterator for last executed query.
581  
582          Args:
583              cursor: cursor
584  
585          Returns:
586              iterable collection of rows
587          """
588  
589          raise NotImplementedError
590  
591      def addfunctions(self):
592          """
593          Adds custom functions in current connection.
594          """
595  
596          raise NotImplementedError