pgvector.py
"""
PGVector module
"""

import os

import numpy as np

# Conditional import
try:
    from pgvector.sqlalchemy import BIT, HALFVEC, VECTOR

    from sqlalchemy import create_engine, delete, func, text, Column, Index, Integer, MetaData, StaticPool, Table
    from sqlalchemy.orm import Session
    from sqlalchemy.schema import CreateSchema

    PGVECTOR = True
except ImportError:
    PGVECTOR = False

from ..base import ANN


# pylint: disable=R0904
class PGVector(ANN):
    """
    Builds an ANN index backed by a Postgres database.
    """

    def __init__(self, config):
        super().__init__(config)

        if not PGVECTOR:
            raise ImportError('PGVector is not available - install "ann" extra to enable')

        # Database connection
        self.engine, self.database, self.connection, self.table = None, None, None, None

        # Scalar quantization - only a plain int enables bit vectors (bool is an int subclass, so it is excluded explicitly)
        quantize = self.config.get("quantize")
        self.qbits = quantize if quantize and isinstance(quantize, int) and not isinstance(quantize, bool) else None

    def load(self, path):
        """
        Connects to the database and binds to the existing vectors table. The path argument is unused,
        vectors are stored in the database.

        Args:
            path: unused
        """

        # Initialize tables
        self.initialize()

    def index(self, embeddings):
        """
        Builds a new index. Drops and recreates the vectors table, inserts all embeddings with ids
        0..N-1 and creates the HNSW index.

        Args:
            embeddings: embeddings array - assumes rows are vectors (row count is read from shape[0])
        """

        # Initialize tables
        self.initialize(recreate=True)

        # Prepare embeddings and insert rows
        self._insert(embeddings, 0)

        # Create index
        self.createindex()

        # Add id offset and index build metadata
        self.config["offset"] = embeddings.shape[0]
        self.metadata(self.settings())

    def append(self, embeddings):
        """
        Appends embeddings to the index. New ids start at the current id offset.

        Args:
            embeddings: embeddings array - assumes rows are vectors (row count is read from shape[0])
        """

        # Prepare embeddings and insert rows
        self._insert(embeddings, self.config["offset"])

        # Update id offset and index metadata
        self.config["offset"] += embeddings.shape[0]
        self.metadata()

    def delete(self, ids):
        """
        Deletes all rows with an indexid in ids.

        Args:
            ids: list of index ids to delete
        """

        self.database.execute(delete(self.table).where(self.table.c["indexid"].in_(ids)))

    def search(self, queries, limit):
        """
        Runs a nearest neighbor search for each query.

        Args:
            queries: query embeddings
            limit: maximum number of results per query

        Returns:
            list with one list of (indexid, score) tuples per query
        """

        results = []
        for query in queries:
            # Run query - ascending order since scores are hamming distances or negative inner products
            rows = self.database.query(self.table.c["indexid"], self.query(query)).order_by("score").limit(limit)

            # Calculate and collect scores
            results.append([(indexid, self.score(score)) for indexid, score in rows])

        return results

    def count(self):
        """
        Counts the number of rows in the vectors table.

        Returns:
            row count
        """

        # pylint: disable=E1102
        return self.database.query(func.count(self.table.c["indexid"])).scalar()

    def save(self, path):
        """
        Commits pending changes. The path argument is unused, vectors are stored in the database.

        Args:
            path: unused
        """

        # Commit session and connection
        self.database.commit()
        self.connection.commit()

    def close(self):
        """
        Closes this ANN and releases the database session, connection and engine. Safe to call repeatedly.
        """

        # Parent logic
        super().close()

        # Release each handle independently - a partially failed connect() can leave an engine without a session
        if self.database is not None:
            self.database.close()

        if self.connection is not None:
            self.connection.close()

        if self.engine is not None:
            self.engine.dispose()

        # Drop stale references so closed handles are never reused
        self.engine, self.database, self.connection = None, None, None

    def initialize(self, recreate=False):
        """
        Initializes a new database session.

        Args:
            recreate: Recreates the database tables if True
        """

        # Connect to database
        self.connect()

        # Set the database schema
        self.schema()

        # Table name
        table = self.setting("table", self.defaulttable())

        # Create vectors table object
        self.table = Table(table, MetaData(), Column("indexid", Integer, primary_key=True, autoincrement=False), Column("embedding", self.column()))

        # Drop table, if necessary
        if recreate:
            self.table.drop(self.connection, checkfirst=True)

        # Create table, if necessary
        self.table.create(self.connection, checkfirst=True)

    def createindex(self):
        """
        Creates an index with the current settings.
        """

        # Table name
        table = self.setting("table", self.defaulttable())

        # Create ANN index - inner product is equal to cosine similarity on normalized vectors
        index = Index(
            f"{table}-index",
            self.table.c["embedding"],
            postgresql_using="hnsw",
            postgresql_with=self.settings(),
            postgresql_ops={"embedding": self.operation()},
        )

        # Create or recreate index
        index.drop(self.connection, checkfirst=True)
        index.create(self.connection, checkfirst=True)

    def connect(self):
        """
        Establishes a database connection. Cleans up any existing database connection first.
        """

        # Close existing connection - the engine is the first handle created, so its presence means cleanup is needed
        if self.engine is not None:
            self.close()

        # Create engine
        self.engine = create_engine(self.url(), poolclass=StaticPool, echo=False)
        self.connection = self.engine.connect()

        # Start database session
        self.database = Session(self.connection)

        # Initialize pgvector extension
        self.sqldialect(text("CREATE EXTENSION IF NOT EXISTS vector"))

    def schema(self):
        """
        Sets the database schema, if available.
        """

        # Set default schema, if necessary
        schema = self.setting("schema")
        if schema:
            # NOTE(review): presumably relies on StaticPool handing engine.begin() the same DBAPI connection as the
            # session, so exiting this block commits the CREATE SCHEMA - confirm before changing the pool class
            with self.engine.begin():
                self.sqldialect(CreateSchema(schema, if_not_exists=True))

            self.sqldialect(text("SET search_path TO :schema,public"), {"schema": schema})

    def settings(self):
        """
        Returns settings for this index.

        Returns:
            dict
        """

        return {"m": self.setting("m", 16), "ef_construction": self.setting("efconstruction", 200)}

    def sqldialect(self, sql, parameters=None):
        """
        Executes a SQL statement based on the current SQL dialect. On non-Postgres dialects, a no-op
        SELECT 1 runs instead.

        Args:
            sql: SQL to execute
            parameters: optional bind parameters
        """

        # NOTE(review): the SELECT 1 fallback presumably lets non-Postgres engines be used (e.g. in tests) - confirm
        args = (sql, parameters) if self.engine.dialect.name == "postgresql" else (text("SELECT 1"),)
        self.database.execute(*args)

    def defaulttable(self):
        """
        Returns the default table name.

        Returns:
            default table name
        """

        return "vectors"

    def url(self):
        """
        Reads the database url parameter. Falls back to the ANN_URL environment variable.

        Returns:
            database url
        """

        return self.setting("url", os.environ.get("ANN_URL"))

    def column(self):
        """
        Gets embedding column for the current settings.

        Returns:
            embedding column definition
        """

        if self.qbits:
            # If quantization is set, always return BIT vectors - dimensions counts packed bytes, so 8 bits each
            return BIT(self.config["dimensions"] * 8)

        if self.setting("precision") == "half":
            # 16-bit HALF precision vectors
            return HALFVEC(self.config["dimensions"])

        # Default is full 32-bit FULL precision vectors
        return VECTOR(self.config["dimensions"])

    def operation(self):
        """
        Gets the index operation for the current settings.

        Returns:
            index operation
        """

        if self.qbits:
            # If quantization is set, always return BIT vectors
            return "bit_hamming_ops"

        if self.setting("precision") == "half":
            # 16-bit HALF precision vectors
            return "halfvec_ip_ops"

        # Default is full 32-bit FULL precision vectors
        return "vector_ip_ops"

    def prepare(self, data):
        """
        Prepares data for the embeddings column. This method returns a bit string for bit vectors and
        the input data unmodified for float vectors.

        Args:
            data: input data

        Returns:
            data ready for the embeddings column
        """

        # Transform to a bit string when vector quantization is enabled
        if self.qbits:
            return "".join(np.where(np.unpackbits(data), "1", "0"))

        # Return original data
        return data

    def query(self, query):
        """
        Creates a query statement from an input query. This method uses hamming distance for bit vectors and
        the max_inner_product for float vectors.

        Args:
            query: input query

        Returns:
            query statement
        """

        # Prepare query embeddings
        query = self.prepare(query)

        # Bit vector query
        if self.qbits:
            return self.table.c["embedding"].hamming_distance(query).label("score")

        # Float vector query
        return self.table.c["embedding"].max_inner_product(query).label("score")

    def score(self, score):
        """
        Calculates the index score from the input score. This method returns the hamming score
        (1.0 - (hamming distance / total number of bits)) for bit vectors and the -score for
        float vectors.

        Args:
            score: input score

        Returns:
            index score
        """

        # Calculate hamming score as 1.0 - (hamming distance / total number of bits)
        # Bound score from 0 to 1
        if self.qbits:
            return min(max(0.0, 1.0 - (score / (self.config["dimensions"] * 8))), 1.0)

        # pgvector returns negative inner product since Postgres only supports ASC order index scans on operators
        return -score

    def _insert(self, embeddings, offset):
        """
        Inserts embeddings into the vectors table with ids starting at offset. Does nothing for empty input.
        An INSERT executed with an empty parameter list would otherwise run once with no values and violate
        the NOT NULL indexid primary key.

        Args:
            embeddings: embeddings to insert
            offset: id assigned to the first embedding
        """

        rows = [{"indexid": offset + x, "embedding": self.prepare(row)} for x, row in enumerate(embeddings)]
        if rows:
            self.database.execute(self.table.insert(), rows)