# sqlite.py
"""
SQLite module
"""

import os
import sqlite3

# Conditional import
try:
    import sqlite_vec

    SQLITEVEC = True
except ImportError:
    SQLITEVEC = False

from ..base import ANN


class SQLite(ANN):
    """
    Builds an ANN index backed by a SQLite database.
    """

    def __init__(self, config):
        """
        Creates a new SQLite ANN instance.

        Args:
            config: index configuration

        Raises:
            ImportError: if the sqlite-vec package is not installed
        """

        super().__init__(config)

        if not SQLITEVEC:
            raise ImportError('sqlite-vec is not available - install "ann" extra to enable')

        # Database parameters
        self.connection, self.cursor, self.path = None, None, ""

        # Quantization setting
        # True maps to INT8 (8 bits), an integer is used as is and any falsy value disables quantization.
        # False is handled explicitly, bool is checked first so that it never enables quantization.
        quantize = self.setting("quantize")
        if isinstance(quantize, bool):
            self.quantize = 8 if quantize else None
        else:
            self.quantize = int(quantize) if quantize else None

    def load(self, path):
        """
        Sets the database path. The connection is opened lazily by database().

        Args:
            path: path to database file
        """

        self.path = path

    def index(self, embeddings):
        """
        Builds a new index, replacing all existing rows.

        Args:
            embeddings: embeddings array, row position is used as the index id
        """

        # Initialize tables
        self.initialize(recreate=True)

        # Add vectors
        self.database().executemany(self.insertsql(), enumerate(embeddings))

        # Add id offset and index build metadata
        self.config["offset"] = embeddings.shape[0]
        self.metadata(self.settings())

    def append(self, embeddings):
        """
        Appends rows to the index. Ids continue from the stored offset.

        Args:
            embeddings: embeddings array to append
        """

        # Start numbering at the current offset
        self.database().executemany(self.insertsql(), enumerate(embeddings, self.config["offset"]))

        self.config["offset"] += embeddings.shape[0]
        self.metadata()

    def delete(self, ids):
        """
        Deletes rows from the index.

        Args:
            ids: index ids to delete
        """

        self.database().executemany(self.deletesql(), ((x,) for x in ids))

    def search(self, queries, limit):
        """
        Runs a search for each query.

        Args:
            queries: iterable of query vectors
            limit: maximum number of rows to return per query

        Returns:
            list with one entry per query, each a list of (indexid, 1 - distance) rows
        """

        # Cursor and SQL statement are the same for every query
        cursor, sql = self.database(), self.searchsql()

        results = []
        for query in queries:
            # Execute query
            cursor.execute(sql, [query, limit])

            # Add query results
            results.append(cursor.fetchall())

        return results

    def count(self):
        """
        Counts the rows in the index.

        Returns:
            number of rows
        """

        return self.database().execute(self.countsql()).fetchone()[0]

    def save(self, path):
        """
        Saves the index to path.

        Args:
            path: path to save database file
        """

        # Make sure a connection exists. An index that was loaded but not yet used has no connection.
        self.database()

        # Temporary database
        if not self.path:
            # Save temporary database
            self.connection.commit()

            # Copy data from current to new
            connection = self.copy(path)

            # Close temporary database
            self.connection.close()

            # Point connection to new connection
            self.connection = connection
            self.cursor = self.connection.cursor()
            self.path = path

        # Paths are equal, commit changes
        elif self.path == path:
            self.connection.commit()

        # New path is different from current path, copy data and continue using current connection
        else:
            self.copy(path).close()

    def close(self):
        """
        Closes this index and the database connection, if open.
        """

        # Parent logic
        super().close()

        # Close database connection
        if self.connection:
            self.connection.close()

            # Clear the cursor as well, it is bound to the closed connection
            self.connection, self.cursor = None, None

    def initialize(self, recreate=False):
        """
        Initializes a new database session.

        Args:
            recreate: Recreates the database tables if True
        """

        # Create table
        self.database().execute(self.tablesql())

        # Clear data
        if recreate:
            self.database().execute(self.tosql("DELETE FROM {table}"))

    def settings(self):
        """
        Returns settings for this index.

        Returns:
            dict
        """

        sqlite, sqlitevec = self.database().execute("SELECT sqlite_version(), vec_version()").fetchone()

        return {"sqlite": sqlite, "sqlite-vec": sqlitevec}

    def database(self):
        """
        Gets the current database cursor. Creates a new connection
        if there isn't one.

        Returns:
            cursor
        """

        if not self.connection:
            self.connection = self.connect(self.path)
            self.cursor = self.connection.cursor()

        return self.cursor

    def connect(self, path):
        """
        Creates a new database connection.

        Args:
            path: path to database file

        Returns:
            database connection
        """

        # Create connection
        connection = sqlite3.connect(path, check_same_thread=False)

        # Load sqlite-vec extension
        connection.enable_load_extension(True)
        sqlite_vec.load(connection)
        connection.enable_load_extension(False)

        # Return connection
        return connection

    def copy(self, path):
        """
        Copies content from the current database into target.

        Args:
            path: target database path

        Returns:
            new database connection
        """

        # Delete existing file, if necessary
        if os.path.exists(path):
            os.remove(path)

        # Create new connection
        connection = self.connect(path)

        if self.connection.in_transaction:
            # Initialize connection
            connection.execute(self.tablesql())

            # Dumped SQL is lowercased for matching, the prefix must be lowercased too
            # or a mixed-case table name would never match
            prefix = self.tosql('insert into "{table}"').lower()

            # The backup call will hang if there are uncommitted changes, need to copy over
            # with iterdump (which is much slower)
            for sql in self.connection.iterdump():
                if prefix in sql.lower():
                    connection.execute(sql)

            # Only INSERT statements are replayed, the dump's COMMIT is filtered out.
            # Commit here, otherwise the rows are lost when the caller closes the connection.
            connection.commit()
        else:
            # Database is up to date, can do a more efficient copy with SQLite C API
            self.connection.backup(connection)

        return connection

    def tablesql(self):
        """
        Builds a CREATE table statement for table.

        Returns:
            CREATE TABLE
        """

        dimensions = self.config["dimensions"]

        # Binary quantization
        if self.quantize == 1:
            embedding = f"embedding BIT[{dimensions}]"

        # INT8 quantization
        elif self.quantize == 8:
            embedding = f"embedding INT8[{dimensions}] distance=cosine"

        # Standard FLOAT32
        else:
            embedding = f"embedding FLOAT[{dimensions}] distance=cosine"

        # Return CREATE TABLE sql
        return self.tosql(("CREATE VIRTUAL TABLE IF NOT EXISTS {table} USING vec0" "(indexid INTEGER PRIMARY KEY, " f"{embedding})"))

    def insertsql(self):
        """
        Creates an INSERT SQL statement.

        Returns:
            INSERT
        """

        return self.tosql(f"INSERT INTO {{table}}(indexid, embedding) VALUES (?, {self.embeddingsql()})")

    def deletesql(self):
        """
        Creates a DELETE SQL statement.

        Returns:
            DELETE
        """

        return self.tosql("DELETE FROM {table} WHERE indexid = ?")

    def searchsql(self):
        """
        Creates a SELECT SQL statement for search.

        Returns:
            SELECT
        """

        return self.tosql(("SELECT indexid, 1 - distance FROM {table} " f"WHERE embedding MATCH {self.embeddingsql()} AND k = ? ORDER BY distance"))

    def countsql(self):
        """
        Creates a SELECT COUNT statement.

        Returns:
            SELECT COUNT
        """

        return self.tosql("SELECT count(indexid) FROM {table}")

    def embeddingsql(self):
        """
        Creates an embeddings column SQL snippet.

        Returns:
            embeddings column SQL
        """

        # Binary quantization
        if self.quantize == 1:
            return "vec_quantize_binary(?)"

        # INT8 quantization
        if self.quantize == 8:
            return "vec_quantize_int8(?, 'unit')"

        # Standard FLOAT32
        return "?"

    def tosql(self, sql):
        """
        Creates a SQL statement substituting in the configured table name.

        Args:
            sql: SQL statement with a {table} parameter

        Returns:
            fully resolved SQL statement
        """

        table = self.setting("table", "vectors")
        return sql.format(table=table)