/ src / python / txtai / ann / dense / pgvector.py
pgvector.py
  1  """
  2  PGVector module
  3  """
  4  
  5  import os
  6  
  7  import numpy as np
  8  
  9  # Conditional import
 10  try:
 11      from pgvector.sqlalchemy import BIT, HALFVEC, VECTOR
 12  
 13      from sqlalchemy import create_engine, delete, func, text, Column, Index, Integer, MetaData, StaticPool, Table
 14      from sqlalchemy.orm import Session
 15      from sqlalchemy.schema import CreateSchema
 16  
 17      PGVECTOR = True
 18  except ImportError:
 19      PGVECTOR = False
 20  
 21  from ..base import ANN
 22  
 23  
 24  # pylint: disable=R0904
 25  class PGVector(ANN):
 26      """
 27      Builds an ANN index backed by a Postgres database.
 28      """
 29  
 30      def __init__(self, config):
 31          super().__init__(config)
 32  
 33          if not PGVECTOR:
 34              raise ImportError('PGVector is not available - install "ann" extra to enable')
 35  
 36          # Database connection
 37          self.engine, self.database, self.connection, self.table = None, None, None, None
 38  
 39          # Scalar quantization
 40          quantize = self.config.get("quantize")
 41          self.qbits = quantize if quantize and isinstance(quantize, int) and not isinstance(quantize, bool) else None
 42  
 43      def load(self, path):
 44          # Initialize tables
 45          self.initialize()
 46  
 47      def index(self, embeddings):
 48          # Initialize tables
 49          self.initialize(recreate=True)
 50  
 51          # Prepare embeddings and insert rows
 52          self.database.execute(self.table.insert(), [{"indexid": x, "embedding": self.prepare(row)} for x, row in enumerate(embeddings)])
 53  
 54          # Create index
 55          self.createindex()
 56  
 57          # Add id offset and index build metadata
 58          self.config["offset"] = embeddings.shape[0]
 59          self.metadata(self.settings())
 60  
 61      def append(self, embeddings):
 62          # Prepare embeddings and insert rows
 63          self.database.execute(
 64              self.table.insert(), [{"indexid": x + self.config["offset"], "embedding": self.prepare(row)} for x, row in enumerate(embeddings)]
 65          )
 66  
 67          # Update id offset and index metadata
 68          self.config["offset"] += embeddings.shape[0]
 69          self.metadata()
 70  
 71      def delete(self, ids):
 72          self.database.execute(delete(self.table).where(self.table.c["indexid"].in_(ids)))
 73  
 74      def search(self, queries, limit):
 75          results = []
 76          for query in queries:
 77              # Run query
 78              query = self.database.query(self.table.c["indexid"], self.query(query)).order_by("score").limit(limit)
 79  
 80              # Calculate and collect scores
 81              results.append([(indexid, self.score(score)) for indexid, score in query])
 82  
 83          return results
 84  
 85      def count(self):
 86          # pylint: disable=E1102
 87          return self.database.query(func.count(self.table.c["indexid"])).scalar()
 88  
 89      def save(self, path):
 90          # Commit session and connection
 91          self.database.commit()
 92          self.connection.commit()
 93  
 94      def close(self):
 95          # Parent logic
 96          super().close()
 97  
 98          # Close database connection
 99          if self.database:
100              self.database.close()
101              self.engine.dispose()
102  
103      def initialize(self, recreate=False):
104          """
105          Initializes a new database session.
106  
107          Args:
108              recreate: Recreates the database tables if True
109          """
110  
111          # Connect to database
112          self.connect()
113  
114          # Set the database schema
115          self.schema()
116  
117          # Table name
118          table = self.setting("table", self.defaulttable())
119  
120          # Create vectors table object
121          self.table = Table(table, MetaData(), Column("indexid", Integer, primary_key=True, autoincrement=False), Column("embedding", self.column()))
122  
123          # Drop table, if necessary
124          if recreate:
125              self.table.drop(self.connection, checkfirst=True)
126  
127          # Create table, if necessary
128          self.table.create(self.connection, checkfirst=True)
129  
130      def createindex(self):
131          """
132          Creates a index with the current settings.
133          """
134  
135          # Table name
136          table = self.setting("table", self.defaulttable())
137  
138          # Create ANN index - inner product is equal to cosine similarity on normalized vectors
139          index = Index(
140              f"{table}-index",
141              self.table.c["embedding"],
142              postgresql_using="hnsw",
143              postgresql_with=self.settings(),
144              postgresql_ops={"embedding": self.operation()},
145          )
146  
147          # Create or recreate index
148          index.drop(self.connection, checkfirst=True)
149          index.create(self.connection, checkfirst=True)
150  
151      def connect(self):
152          """
153          Establishes a database connection. Cleans up any existing database connection first.
154          """
155  
156          # Close existing connection
157          if self.database:
158              self.close()
159  
160          # Create engine
161          self.engine = create_engine(self.url(), poolclass=StaticPool, echo=False)
162          self.connection = self.engine.connect()
163  
164          # Start database session
165          self.database = Session(self.connection)
166  
167          # Initialize pgvector extension
168          self.sqldialect(text("CREATE EXTENSION IF NOT EXISTS vector"))
169  
170      def schema(self):
171          """
172          Sets the database schema, if available.
173          """
174  
175          # Set default schema, if necessary
176          schema = self.setting("schema")
177          if schema:
178              with self.engine.begin():
179                  self.sqldialect(CreateSchema(schema, if_not_exists=True))
180  
181              self.sqldialect(text("SET search_path TO :schema,public"), {"schema": schema})
182  
183      def settings(self):
184          """
185          Returns settings for this index.
186  
187          Returns:
188              dict
189          """
190  
191          return {"m": self.setting("m", 16), "ef_construction": self.setting("efconstruction", 200)}
192  
193      def sqldialect(self, sql, parameters=None):
194          """
195          Executes a SQL statement based on the current SQL dialect.
196  
197          Args:
198              sql: SQL to execute
199              parameters: optional bind parameters
200          """
201  
202          args = (sql, parameters) if self.engine.dialect.name == "postgresql" else (text("SELECT 1"),)
203          self.database.execute(*args)
204  
205      def defaulttable(self):
206          """
207          Returns the default table name.
208  
209          Returns:
210              default table name
211          """
212  
213          return "vectors"
214  
215      def url(self):
216          """
217          Reads the database url parameter.
218  
219          Returns:
220              database url
221          """
222  
223          return self.setting("url", os.environ.get("ANN_URL"))
224  
225      def column(self):
226          """
227          Gets embedding column for the current settings.
228  
229          Returns:
230              embedding column definition
231          """
232  
233          if self.qbits:
234              # If quantization is set, always return BIT vectors
235              return BIT(self.config["dimensions"] * 8)
236  
237          if self.setting("precision") == "half":
238              # 16-bit HALF precision vectors
239              return HALFVEC(self.config["dimensions"])
240  
241          # Default is full 32-bit FULL precision vectors
242          return VECTOR(self.config["dimensions"])
243  
244      def operation(self):
245          """
246          Gets the index operation for the current settings.
247  
248          Returns:
249              index operation
250          """
251  
252          if self.qbits:
253              # If quantization is set, always return BIT vectors
254              return "bit_hamming_ops"
255  
256          if self.setting("precision") == "half":
257              # 16-bit HALF precision vectors
258              return "halfvec_ip_ops"
259  
260          # Default is full 32-bit FULL precision vectors
261          return "vector_ip_ops"
262  
263      def prepare(self, data):
264          """
265          Prepares data for the embeddings column. This method returns a bit string for bit vectors and
266          the input data unmodified for float vectors.
267  
268          Args:
269              data: input data
270  
271          Returns:
272              data ready for the embeddings column
273          """
274  
275          # Transform to a bit string when vector quantization is enabled
276          if self.qbits:
277              return "".join(np.where(np.unpackbits(data), "1", "0"))
278  
279          # Return original data
280          return data
281  
282      def query(self, query):
283          """
284          Creates a query statement from an input query. This method uses hamming distance for bit vectors and
285          the max_inner_product for float vectors.
286  
287          Args:
288              query: input query
289  
290          Returns:
291              query statement
292          """
293  
294          # Prepare query embeddings
295          query = self.prepare(query)
296  
297          # Bit vector query
298          if self.qbits:
299              return self.table.c["embedding"].hamming_distance(query).label("score")
300  
301          # Float vector query
302          return self.table.c["embedding"].max_inner_product(query).label("score")
303  
304      def score(self, score):
305          """
306          Calculates the index score from the input score. This method returns the hamming score
307          (1.0 - (hamming distance / total number of bits)) for bit vectors and the -score for
308          float vectors.
309  
310          Args:
311              score: input score
312  
313          Returns:
314              index score
315          """
316  
317          # Calculate hamming score as 1.0 - (hamming distance / total number of bits)
318          # Bound score from 0 to 1
319          if self.qbits:
320              return min(max(0.0, 1.0 - (score / (self.config["dimensions"] * 8))), 1.0)
321  
322          # pgvector returns negative inner product since Postgres only supports ASC order index scans on operators
323          return -score