/ src / python / txtai / scoring / pgtext.py
pgtext.py
  1  """
  2  PGText module
  3  """
  4  
  5  import os
  6  import re
  7  
  8  # Conditional import
  9  try:
 10      from sqlalchemy import create_engine, desc, delete, func, text
 11      from sqlalchemy import Column, Computed, Index, Integer, MetaData, StaticPool, Table, Text
 12      from sqlalchemy.dialects.postgresql import TSVECTOR
 13      from sqlalchemy.orm import Session
 14      from sqlalchemy.schema import CreateSchema
 15  
 16      PGTEXT = True
 17  except ImportError:
 18      PGTEXT = False
 19  
 20  from .base import Scoring
 21  
 22  
 23  class PGText(Scoring):
 24      """
 25      Postgres full text search (FTS) based scoring.
 26      """
 27  
 28      def __init__(self, config=None):
 29          super().__init__(config)
 30  
 31          if not PGTEXT:
 32              raise ImportError('PGText is not available - install "scoring" extra to enable')
 33  
 34          # Database connection
 35          self.engine, self.database, self.connection, self.table = None, None, None, None
 36  
 37          # Language
 38          self.language = self.config.get("language", "english")
 39  
 40      def insert(self, documents, index=None, checkpoint=None):
 41          # Initialize tables
 42          self.initialize(recreate=True)
 43  
 44          # Collection of rows to insert
 45          rows = []
 46  
 47          # Collect rows
 48          for uid, document, _ in documents:
 49              # Extract text, if necessary
 50              if isinstance(document, dict):
 51                  document = document.get(self.text, document.get(self.object))
 52  
 53              if document is not None:
 54                  # If index is passed, use indexid, otherwise use id
 55                  uid = index if index is not None else uid
 56  
 57                  # Add row if the data type is accepted
 58                  if isinstance(document, (str, list)):
 59                      rows.append((uid, " ".join(document) if isinstance(document, list) else document))
 60  
 61                  # Increment index
 62                  index = index + 1 if index is not None else None
 63  
 64          # Insert rows
 65          self.database.execute(self.table.insert(), [{"indexid": x, "text": text} for x, text in rows])
 66  
 67      def delete(self, ids):
 68          self.database.execute(delete(self.table).where(self.table.c["indexid"].in_(ids)))
 69  
 70      def weights(self, tokens):
 71          # Not supported
 72          return None
 73  
 74      def search(self, query, limit=3):
 75          # Replace wildcards with prefix wildcard term
 76          query = re.sub(r"(?<!\:)\*", ":*", query)
 77  
 78          # Run query
 79          query = (
 80              self.database.query(self.table.c["indexid"], text("ts_rank(vector, plainto_tsquery(:language, :query)) rank"))
 81              .order_by(desc(text("rank")))
 82              .limit(limit)
 83              .params({"language": self.language, "query": query})
 84          )
 85  
 86          return [(uid, score) for uid, score in query if score > 1e-5]
 87  
 88      def batchsearch(self, queries, limit=3, threads=True):
 89          return [self.search(query, limit) for query in queries]
 90  
 91      def count(self):
 92          # pylint: disable=E1102
 93          return self.database.query(func.count(self.table.c["indexid"])).scalar()
 94  
 95      def load(self, path):
 96          # Reset database to original checkpoint
 97          if self.database:
 98              self.database.rollback()
 99              self.connection.rollback()
100  
101          # Initialize tables
102          self.initialize()
103  
104      def save(self, path):
105          # Commit session and connection
106          if self.database:
107              self.database.commit()
108              self.connection.commit()
109  
110      def close(self):
111          if self.database:
112              self.database.close()
113              self.engine.dispose()
114  
115      def issparse(self):
116          return True
117  
118      def isnormalized(self):
119          return True
120  
121      def isbayes(self):
122          return False
123  
124      def initialize(self, recreate=False):
125          """
126          Initializes a new database session.
127  
128          Args:
129              recreate: Recreates the database tables if True
130          """
131  
132          if not self.database:
133              # Create engine, connection and session
134              self.engine = create_engine(self.config.get("url", os.environ.get("SCORING_URL")), poolclass=StaticPool, echo=False)
135              self.connection = self.engine.connect()
136              self.database = Session(self.connection)
137  
138              # Set default schema, if necessary
139              schema = self.config.get("schema")
140              if schema:
141                  with self.engine.begin():
142                      self.sqldialect(CreateSchema(schema, if_not_exists=True))
143  
144                  self.sqldialect(text("SET search_path TO :schema"), {"schema": schema})
145  
146              # Table name
147              table = self.config.get("table", "scoring")
148  
149              # Create vectors table
150              self.table = Table(
151                  table,
152                  MetaData(),
153                  Column("indexid", Integer, primary_key=True, autoincrement=False),
154                  Column("text", Text),
155                  (
156                      Column("vector", TSVECTOR, Computed(f"to_tsvector('{self.language}', text)", persisted=True))
157                      if self.engine.dialect.name == "postgresql"
158                      else Column("vector", Integer)
159                  ),
160              )
161  
162              # Create text index
163              index = Index(
164                  f"{table}-index",
165                  self.table.c["vector"],
166                  postgresql_using="gin",
167              )
168  
169              # Drop and recreate table
170              if recreate:
171                  self.table.drop(self.connection, checkfirst=True)
172                  index.drop(self.connection, checkfirst=True)
173  
174              # Create table and index
175              self.table.create(self.connection, checkfirst=True)
176              index.create(self.connection, checkfirst=True)
177  
178      def sqldialect(self, sql, parameters=None):
179          """
180          Executes a SQL statement based on the current SQL dialect.
181  
182          Args:
183              sql: SQL to execute
184              parameters: optional bind parameters
185          """
186  
187          args = (sql, parameters) if self.engine.dialect.name == "postgresql" else (text("SELECT 1"),)
188          self.database.execute(*args)