/ src / python / txtai / ann / dense / sqlite.py
sqlite.py
  1  """
  2  SQLite module
  3  """
  4  
  5  import os
  6  import sqlite3
  7  
  8  # Conditional import
  9  try:
 10      import sqlite_vec
 11  
 12      SQLITEVEC = True
 13  except ImportError:
 14      SQLITEVEC = False
 15  
 16  from ..base import ANN
 17  
 18  
 19  class SQLite(ANN):
 20      """
 21      Builds an ANN index backed by a SQLite database.
 22      """
 23  
 24      def __init__(self, config):
 25          super().__init__(config)
 26  
 27          if not SQLITEVEC:
 28              raise ImportError('sqlite-vec is not available - install "ann" extra to enable')
 29  
 30          # Database parameters
 31          self.connection, self.cursor, self.path = None, None, ""
 32  
 33          # Quantization setting
 34          self.quantize = self.setting("quantize")
 35          self.quantize = 8 if isinstance(self.quantize, bool) else int(self.quantize) if self.quantize else None
 36  
 37      def load(self, path):
 38          self.path = path
 39  
 40      def index(self, embeddings):
 41          # Initialize tables
 42          self.initialize(recreate=True)
 43  
 44          # Add vectors
 45          self.database().executemany(self.insertsql(), enumerate(embeddings))
 46  
 47          # Add id offset and index build metadata
 48          self.config["offset"] = embeddings.shape[0]
 49          self.metadata(self.settings())
 50  
 51      def append(self, embeddings):
 52          self.database().executemany(self.insertsql(), [(x + self.config["offset"], row) for x, row in enumerate(embeddings)])
 53  
 54          self.config["offset"] += embeddings.shape[0]
 55          self.metadata()
 56  
 57      def delete(self, ids):
 58          self.database().executemany(self.deletesql(), [(x,) for x in ids])
 59  
 60      def search(self, queries, limit):
 61          results = []
 62          for query in queries:
 63              # Execute query
 64              self.database().execute(self.searchsql(), [query, limit])
 65  
 66              # Add query results
 67              results.append(list(self.database()))
 68  
 69          return results
 70  
 71      def count(self):
 72          self.database().execute(self.countsql())
 73          return self.cursor.fetchone()[0]
 74  
 75      def save(self, path):
 76          # Temporary database
 77          if not self.path:
 78              # Save temporary database
 79              self.connection.commit()
 80  
 81              # Copy data from current to new
 82              connection = self.copy(path)
 83  
 84              # Close temporary database
 85              self.connection.close()
 86  
 87              # Point connection to new connection
 88              self.connection = connection
 89              self.cursor = self.connection.cursor()
 90              self.path = path
 91  
 92          # Paths are equal, commit changes
 93          elif self.path == path:
 94              self.connection.commit()
 95  
 96          # New path is different from current path, copy data and continue using current connection
 97          else:
 98              self.copy(path).close()
 99  
100      def close(self):
101          # Parent logic
102          super().close()
103  
104          # Close database connection
105          if self.connection:
106              self.connection.close()
107              self.connection = None
108  
109      def initialize(self, recreate=False):
110          """
111          Initializes a new database session.
112  
113          Args:
114              recreate: Recreates the database tables if True
115          """
116  
117          # Create table
118          self.database().execute(self.tablesql())
119  
120          # Clear data
121          if recreate:
122              self.database().execute(self.tosql("DELETE FROM {table}"))
123  
124      def settings(self):
125          """
126          Returns settings for this index.
127  
128          Returns:
129              dict
130          """
131  
132          sqlite, sqlitevec = self.database().execute("SELECT sqlite_version(), vec_version()").fetchone()
133  
134          return {"sqlite": sqlite, "sqlite-vec": sqlitevec}
135  
136      def database(self):
137          """
138          Gets the current database cursor. Creates a new connection
139          if there isn't one.
140  
141          Returns:
142              cursor
143          """
144  
145          if not self.connection:
146              self.connection = self.connect(self.path)
147              self.cursor = self.connection.cursor()
148  
149          return self.cursor
150  
151      def connect(self, path):
152          """
153          Creates a new database connection.
154  
155          Args:
156              path: path to database file
157  
158          Returns:
159              database connection
160          """
161  
162          # Create connection
163          connection = sqlite3.connect(path, check_same_thread=False)
164  
165          # Load sqlite-vec extension
166          connection.enable_load_extension(True)
167          sqlite_vec.load(connection)
168          connection.enable_load_extension(False)
169  
170          # Return connection and cursor
171          return connection
172  
173      def copy(self, path):
174          """
175          Copies content from the current database into target.
176  
177          Args:
178              path: target database path
179  
180          Returns:
181              new database connection
182          """
183  
184          # Delete existing file, if necessary
185          if os.path.exists(path):
186              os.remove(path)
187  
188          # Create new connection
189          connection = self.connect(path)
190  
191          if self.connection.in_transaction:
192              # Initialize connection
193              connection.execute(self.tablesql())
194  
195              # The backup call will hang if there are uncommitted changes, need to copy over
196              # with iterdump (which is much slower)
197              for sql in self.connection.iterdump():
198                  if self.tosql('insert into "{table}"') in sql.lower():
199                      connection.execute(sql)
200          else:
201              # Database is up to date, can do a more efficient copy with SQLite C API
202              self.connection.backup(connection)
203  
204          return connection
205  
206      def tablesql(self):
207          """
208          Builds a CREATE table statement for table.
209  
210          Returns:
211              CREATE TABLE
212          """
213  
214          # Binary quantization
215          if self.quantize == 1:
216              embedding = f"embedding BIT[{self.config['dimensions']}]"
217  
218          # INT8 quantization
219          elif self.quantize == 8:
220              embedding = f"embedding INT8[{self.config['dimensions']}] distance=cosine"
221  
222          # Standard FLOAT32
223          else:
224              embedding = f"embedding FLOAT[{self.config['dimensions']}] distance=cosine"
225  
226          # Return CREATE TABLE sql
227          return self.tosql(("CREATE VIRTUAL TABLE IF NOT EXISTS {table} USING vec0" "(indexid INTEGER PRIMARY KEY, " f"{embedding})"))
228  
229      def insertsql(self):
230          """
231          Creates an INSERT SQL statement.
232  
233          Returns:
234              INSERT
235          """
236  
237          return self.tosql(f"INSERT INTO {{table}}(indexid, embedding) VALUES (?, {self.embeddingsql()})")
238  
239      def deletesql(self):
240          """
241          Creates a DELETE SQL statement.
242  
243          Returns:
244              DELETE
245          """
246  
247          return self.tosql("DELETE FROM {table} WHERE indexid = ?")
248  
249      def searchsql(self):
250          """
251          Creates a SELECT SQL statement for search.
252  
253          Returns:
254              SELECT
255          """
256  
257          return self.tosql(("SELECT indexid, 1 - distance FROM {table} " f"WHERE embedding MATCH {self.embeddingsql()} AND k = ? ORDER BY distance"))
258  
259      def countsql(self):
260          """
261          Creates a SELECT COUNT statement.
262  
263          Returns:
264              SELECT COUNT
265          """
266  
267          return self.tosql("SELECT count(indexid) FROM {table}")
268  
269      def embeddingsql(self):
270          """
271          Creates an embeddings column SQL snippet.
272  
273          Returns:
274              embeddings column SQL
275          """
276  
277          # Binary quantization
278          if self.quantize == 1:
279              embedding = "vec_quantize_binary(?)"
280  
281          # INT8 quantization
282          elif self.quantize == 8:
283              embedding = "vec_quantize_int8(?, 'unit')"
284  
285          # Standard FLOAT32
286          else:
287              embedding = "?"
288  
289          return embedding
290  
291      def tosql(self, sql):
292          """
293          Creates a SQL statement substituting in the configured table name.
294  
295          Args:
296              sql: SQL statement with a {table} parameter
297  
298          Returns:
299              fully resolved SQL statement
300          """
301  
302          table = self.setting("table", "vectors")
303          return sql.format(table=table)