/ test / python / testdatabase / testrdbms.py
testrdbms.py
  1  """
  2  Common file database module tests
  3  """
  4  
  5  import contextlib
  6  import io
  7  import os
  8  import tempfile
  9  import unittest
 10  
 11  from unittest.mock import patch
 12  
 13  from txtai.embeddings import Embeddings, IndexNotFoundError
 14  from txtai.database import Embedded, RDBMS, SQLError
 15  
 16  
 17  class Common:
 18      """
 19      Wraps common file database tests to prevent unit test discovery for this class.
 20      """
 21  
 22      # pylint: disable=R0904
 23      class TestRDBMS(unittest.TestCase):
 24          """
 25          Embeddings with content stored in a file database tests.
 26          """
 27  
        @classmethod
        def setUpClass(cls):
            """
            Initialize test data.

            Creates the shared document list and a shared Embeddings instance used
            by the tests in this class. cls.backend is the content storage backend;
            it is None here and presumably overridden by database-specific
            subclasses - TODO confirm against subclass setUpClass implementations.
            """

            # Shared list of test documents; index 4 is the expected best match
            # for the "feel good story" queries used throughout this class
            cls.data = [
                "US tops 5 million confirmed virus cases",
                "Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg",
                "Beijing mobilises invasion craft along coast as Taiwan tensions escalate",
                "The National Park Service warns against sacrificing slower friends in a bear attack",
                "Maine man wins $1M from $25 lottery ticket",
                "Make huge profits without work, earn up to $100,000 a day",
            ]

            # Content backend
            cls.backend = None

            # Create embeddings model, backed by sentence-transformers & transformers
            cls.embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2", "content": cls.backend})
 48  
 49          @classmethod
 50          def tearDownClass(cls):
 51              """
 52              Cleanup data.
 53              """
 54  
 55              if cls.embeddings:
 56                  cls.embeddings.close()
 57  
 58          def testArchive(self):
 59              """
 60              Test embeddings index archiving
 61              """
 62  
 63              for extension in ["tar.bz2", "tar.gz", "tar.xz", "zip"]:
 64                  # Create an index for the list of text
 65                  self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
 66  
 67                  # Generate temp file path
 68                  index = os.path.join(tempfile.gettempdir(), f"embeddings.{self.category()}.{extension}")
 69  
 70                  self.embeddings.save(index)
 71                  self.embeddings.load(index)
 72  
 73                  # Search for best match
 74                  result = self.embeddings.search("feel good story", 1)[0]
 75  
 76                  self.assertEqual(result["text"], self.data[4])
 77  
 78                  # Test offsets still work after save/load
 79                  self.embeddings.upsert([(0, "Looking out into the dreadful abyss", None)])
 80                  self.assertEqual(self.embeddings.count(), len(self.data))
 81  
 82          def testAutoId(self):
 83              """
 84              Test auto id generation
 85              """
 86  
 87              # Default sequence id
 88              embeddings = Embeddings(path="sentence-transformers/nli-mpnet-base-v2", content=self.backend)
 89              embeddings.index(self.data)
 90  
 91              result = embeddings.search("feel good story", 1)[0]
 92              self.assertEqual(result["text"], self.data[4])
 93  
 94              # UUID
 95              embeddings.config["autoid"] = "uuid4"
 96              embeddings.index(self.data)
 97  
 98              result = embeddings.search(self.data[4], 1)[0]
 99              self.assertEqual(len(result["id"]), 36)
100  
101          def testCheckpoint(self):
102              """
103              Test embeddings index checkpoints
104              """
105  
106              # Checkpoint directory
107              checkpoint = os.path.join(tempfile.gettempdir(), f"embeddings.{self.category()}.checkpoint")
108  
109              # Save embeddings checkpoint
110              self.embeddings.index(self.data, checkpoint=checkpoint)
111  
112              # Reindex with checkpoint
113              self.embeddings.index(self.data, checkpoint=checkpoint)
114  
115              # Search for best match
116              result = self.embeddings.search("feel good story", 1)[0]
117              self.assertEqual(result["text"], self.data[4])
118  
119          def testColumns(self):
120              """
121              Test custom text/object columns
122              """
123  
124              embeddings = Embeddings({"keyword": True, "content": self.backend, "columns": {"text": "value"}})
125              data = [{"value": x} for x in self.data]
126              embeddings.index([(uid, text, None) for uid, text in enumerate(data)])
127  
128              # Run search
129              result = embeddings.search("lottery", 1)[0]
130              self.assertEqual(result["text"], self.data[4])
131  
132          def testClose(self):
133              """
134              Test embeddings close
135              """
136  
137              embeddings = None
138  
139              # Create index twice to test open/close and ensure resources are freed
140              for _ in range(2):
141                  embeddings = Embeddings(
142                      {"path": "sentence-transformers/nli-mpnet-base-v2", "scoring": {"method": "bm25", "terms": True}, "content": self.backend}
143                  )
144  
145                  # Add record to index
146                  embeddings.index([(0, "Close test", None)])
147  
148                  # Save index
149                  index = os.path.join(tempfile.gettempdir(), f"embeddings.{self.category()}.close")
150                  embeddings.save(index)
151  
152                  # Close index
153                  embeddings.close()
154  
155              # Test embeddings is empty
156              self.assertIsNone(embeddings.ann)
157              self.assertIsNone(embeddings.database)
158  
159          def testData(self):
160              """
161              Test content storage and retrieval
162              """
163  
164              data = self.data + [{"date": "2021-01-01", "text": "Baby panda", "flag": 1}]
165  
166              # Create an index for the list of text
167              self.embeddings.index([(uid, text, None) for uid, text in enumerate(data)])
168  
169              # Search for best match
170              result = self.embeddings.search("feel good story", 1)[0]
171              self.assertEqual(result["text"], data[-1]["text"])
172  
173          def testDelete(self):
174              """
175              Test delete
176              """
177  
178              # Create an index for the list of text
179              self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
180  
181              # Delete best match
182              self.embeddings.delete([4])
183  
184              # Search for best match
185              result = self.embeddings.search("feel good story", 1)[0]
186  
187              self.assertEqual(self.embeddings.count(), 5)
188              self.assertEqual(result["text"], self.data[5])
189  
190          def testEmpty(self):
191              """
192              Test empty index
193              """
194  
195              # Test search against empty index
196              embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2", "content": self.backend})
197              self.assertEqual(embeddings.search("test"), [])
198  
199              # Test index with no data
200              embeddings.index([])
201              self.assertIsNone(embeddings.ann)
202  
203              # Test upsert with no data
204              embeddings.index([(0, "this is a test", None)])
205              embeddings.upsert([])
206              self.assertIsNotNone(embeddings.ann)
207  
208          def testEmptyString(self):
209              """
210              Test empty string indexing
211              """
212  
213              # Test empty string
214              self.embeddings.index([(0, "", None)])
215              self.assertTrue(self.embeddings.search("test"))
216  
217              # Test empty string with dict
218              self.embeddings.index([(0, {"text": ""}, None)])
219              self.assertTrue(self.embeddings.search("test"))
220  
221          def testExplain(self):
222              """
223              Test query explain
224              """
225  
226              # Test explain with similarity
227              result = self.embeddings.explain("feel good story", self.data)[0]
228              self.assertEqual(result["text"], self.data[4])
229              self.assertEqual(len(result.get("tokens")), 8)
230  
231          def testExplainBatch(self):
232              """
233              Test query explain batch
234              """
235  
236              # Test explain with query
237              self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
238  
239              result = self.embeddings.batchexplain(["feel good story"], limit=1)[0][0]
240              self.assertEqual(result["text"], self.data[4])
241              self.assertEqual(len(result.get("tokens")), 8)
242  
243          def testExplainEmpty(self):
244              """
245              Test query explain with no filtering criteria
246              """
247  
248              self.assertEqual(self.embeddings.explain("select * from txtai limit 1")[0]["id"], "0")
249  
250          def testExpressions(self):
251              """
252              Test expressions
253              """
254  
255              # Test indexed expressions
256              embeddings = Embeddings(
257                  path="sentence-transformers/nli-mpnet-base-v2",
258                  content=self.backend,
259                  expressions=[{"name": "textlength", "expression": "length(text)", "index": True}],
260              )
261              embeddings.index(self.data)
262  
263              result = embeddings.search("SELECT textlength FROM txtai WHERE id = 0", 1)[0]
264              self.assertEqual(result["textlength"], len(self.data[0]))
265  
266          def testGenerator(self):
267              """
268              Test index with a generator
269              """
270  
271              def documents():
272                  for uid, text in enumerate(self.data):
273                      yield (uid, text, None)
274  
275              # Create an index for the list of text
276              self.embeddings.index(documents())
277  
278              # Search for best match
279              result = self.embeddings.search("feel good story", 1)[0]
280  
281              self.assertEqual(result["text"], self.data[4])
282  
        def testHybrid(self):
            """
            Test hybrid search

            Exercises combined sparse + dense scoring, save/load, the
            scoring normalize settings and upserts against the same instance.
            The config mutations are order-dependent: each reindex runs with
            the normalize value set immediately above it.
            """

            # Build data array
            data = [(uid, text, None) for uid, text in enumerate(self.data)]

            # Index data with sparse + dense vectors.
            embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2", "hybrid": True, "content": self.backend})
            embeddings.index(data)

            # Run search
            result = embeddings.search("feel good story", 1)[0]
            self.assertEqual(result["text"], data[4][1])

            # Generate temp file path
            index = os.path.join(tempfile.gettempdir(), f"embeddings.{self.category()}.hybrid")

            # Test load/save
            embeddings.save(index)
            embeddings.load(index)

            # Run search
            result = embeddings.search("feel good story", 1)[0]
            self.assertEqual(result["text"], data[4][1])

            # Index data with sparse + dense vectors and unnormalized scores.
            embeddings.config["scoring"]["normalize"] = False
            embeddings.index(data)

            # Run search
            result = embeddings.search("feel good story", 1)[0]
            self.assertEqual(result["text"], data[4][1])

            # Index data with sparse + dense vectors and bb25 normalized scores
            embeddings.config["scoring"]["normalize"] = "bb25"
            embeddings.index(data)

            # Run search - note this mode is checked with a different query/result
            # pair, presumably because bb25 ranks the original query differently
            result = embeddings.search("canada intact iceberg a", 1)[0]
            self.assertEqual(result["text"], data[1][1])

            # Test upsert
            data[0] = (0, "Feel good story: baby panda born", None)
            embeddings.upsert([data[0]])

            result = embeddings.search("feel good story", 1)[0]
            self.assertEqual(result["text"], data[0][1])
332  
333          def testIndex(self):
334              """
335              Test index
336              """
337  
338              # Create an index for the list of text
339              self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
340  
341              # Search for best match
342              result = self.embeddings.search("feel good story", 1)[0]
343  
344              self.assertEqual(result["text"], self.data[4])
345  
346          def testIndexTokens(self):
347              """
348              Test index with tokens
349              """
350  
351              # Create an index for the list of text
352              self.embeddings.index([(uid, text.split(), None) for uid, text in enumerate(self.data)])
353  
354              # Search for best match
355              result = self.embeddings.search("feel good story", 1)[0]
356  
357              self.assertEqual(result["text"], self.data[4])
358  
359          def testInfo(self):
360              """
361              Test info
362              """
363  
364              # Create an index for the list of text
365              self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
366  
367              output = io.StringIO()
368              with contextlib.redirect_stdout(output):
369                  self.embeddings.info()
370  
371              self.assertIn("txtai", output.getvalue())
372  
373          def testInstructions(self):
374              """
375              Test indexing with instruction prefixes.
376              """
377  
378              embeddings = Embeddings(
379                  {
380                      "path": "sentence-transformers/nli-mpnet-base-v2",
381                      "content": self.backend,
382                      "instructions": {"query": "query: ", "data": "passage: "},
383                  }
384              )
385  
386              embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
387  
388              # Search for best match
389              result = embeddings.search("feel good story", 1)[0]
390  
391              self.assertEqual(result["text"], self.data[4])
392  
393          def testInvalidData(self):
394              """
395              Test invalid JSON data
396              """
397  
398              # Test invalid JSON value
399              with self.assertRaises(ValueError):
400                  self.embeddings.index([(0, {"text": "This is a test", "flag": float("NaN")}, None)])
401  
402          def testKeyword(self):
403              """
404              Test keyword only (sparse) search
405              """
406  
407              # Build data array
408              data = [(uid, text, None) for uid, text in enumerate(self.data)]
409  
410              # Index data with sparse keyword vectors
411              embeddings = Embeddings({"keyword": True, "content": self.backend})
412              embeddings.index(data)
413  
414              # Run search
415              result = embeddings.search("lottery ticket", 1)[0]
416              self.assertEqual(result["text"], data[4][1])
417  
418              # Test count method
419              self.assertEqual(embeddings.count(), len(data))
420  
421              # Generate temp file path
422              index = os.path.join(tempfile.gettempdir(), f"embeddings.{self.category()}.keyword")
423  
424              # Test load/save
425              embeddings.save(index)
426              embeddings.load(index)
427  
428              # Run search
429              result = embeddings.search("lottery ticket", 1)[0]
430              self.assertEqual(result["text"], data[4][1])
431  
432              # Update data
433              data[0] = (0, "Feel good story: baby panda born", None)
434              embeddings.upsert([data[0]])
435  
436              # Search for best match
437              result = embeddings.search("feel good story", 1)[0]
438              self.assertEqual(result["text"], data[0][1])
439  
440          def testMultiData(self):
441              """
442              Test indexing with multiple data types (text, documents)
443              """
444  
445              embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2", "content": self.backend, "batch": len(self.data)})
446  
447              # Create an index using mixed data (text and documents)
448              data = []
449              for uid, text in enumerate(self.data):
450                  data.append((uid, text, None))
451                  data.append((uid, {"content": text}, None))
452  
453              embeddings.index(data)
454  
455              # Search for best match
456              result = embeddings.search("feel good story", 1)[0]
457  
458              self.assertEqual(result["text"], self.data[4])
459  
460          def testMultiSave(self):
461              """
462              Test multiple successive saves
463              """
464  
465              # Create an index for the list of text
466              self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
467  
468              # Save original index
469              index = os.path.join(tempfile.gettempdir(), f"embeddings.{self.category()}.insert")
470              self.embeddings.save(index)
471  
472              # Modify index
473              self.embeddings.upsert([(0, "Looking out into the dreadful abyss", None)])
474  
475              # Save to a different location
476              indexupdate = os.path.join(tempfile.gettempdir(), f"embeddings.{self.category()}.update")
477              self.embeddings.save(indexupdate)
478  
479              # Save to same location
480              self.embeddings.save(index)
481  
482              # Test all indexes match
483              result = self.embeddings.search("feel good story", 1)[0]
484              self.assertEqual(result["text"], self.data[4])
485  
486              self.embeddings.load(index)
487              result = self.embeddings.search("feel good story", 1)[0]
488              self.assertEqual(result["text"], self.data[4])
489  
490              self.embeddings.load(indexupdate)
491              result = self.embeddings.search("feel good story", 1)[0]
492              self.assertEqual(result["text"], self.data[4])
493  
494          def testNoIndex(self):
495              """
496              Test an embeddings instance with no available indexes
497              """
498  
499              # Disable top-level indexing
500              embeddings = Embeddings(
501                  {
502                      "content": self.backend,
503                      "defaults": False,
504                  }
505              )
506              embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
507  
508              with self.assertRaises(IndexNotFoundError):
509                  embeddings.search("select id, text, score from txtai where similar('feel good story')")
510  
511          def testNotImplemented(self):
512              """
513              Test exceptions for non-implemented methods
514              """
515  
516              db = RDBMS({})
517  
518              self.assertRaises(NotImplementedError, db.connect, None)
519              self.assertRaises(NotImplementedError, db.getcursor)
520              self.assertRaises(NotImplementedError, db.jsonprefix)
521              self.assertRaises(NotImplementedError, db.jsoncolumn, None)
522              self.assertRaises(NotImplementedError, db.rows)
523              self.assertRaises(NotImplementedError, db.addfunctions)
524  
525              db = Embedded({})
526              self.assertRaises(NotImplementedError, db.copy, None)
527  
528          def testObject(self):
529              """
530              Test object field
531              """
532  
533              # Encode object
534              embeddings = Embeddings({"defaults": False, "content": self.backend, "objects": True})
535              embeddings.index([{"object": "binary data".encode("utf-8")}])
536  
537              # Decode and test extracted object
538              obj = embeddings.search("select object from txtai where id = 0")[0]["object"]
539              self.assertEqual(str(obj.getvalue(), "utf-8"), "binary data")
540  
541          @patch.dict(os.environ, {"ALLOW_PICKLE": "True"})
542          def testPickle(self):
543              """
544              Test pickle configuration
545              """
546  
547              embeddings = Embeddings(
548                  {
549                      "format": "pickle",
550                      "path": "sentence-transformers/nli-mpnet-base-v2",
551                      "content": self.backend,
552                  }
553              )
554  
555              embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
556  
557              # Generate temp file path
558              index = os.path.join(tempfile.gettempdir(), f"embeddings.{self.category()}.pickle")
559  
560              embeddings.save(index)
561  
562              # Check that config exists
563              self.assertTrue(os.path.exists(os.path.join(index, "config")))
564  
565              # Check that index can be reloaded
566              embeddings.load(index)
567              self.assertEqual(embeddings.count(), 6)
568  
569          def testQuantize(self):
570              """
571              Test scalar quantization
572              """
573  
574              # Index data with 1-bit scalar quantization
575              embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2", "quantize": 1, "content": self.backend})
576              embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
577  
578              # Search for best match
579              result = self.embeddings.search("feel good story", 1)[0]
580              self.assertEqual(result["text"], self.data[4])
581  
582          def testQueryModel(self):
583              """
584              Test index
585              """
586  
587              embeddings = Embeddings(
588                  {"path": "sentence-transformers/nli-mpnet-base-v2", "content": self.backend, "query": {"path": "neuml/t5-small-txtsql"}}
589              )
590  
591              # Create an index for the list of text
592              embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
593  
594              # Search for best match
595              result = embeddings.search("feel good story with win in text", 1)[0]
596  
597              self.assertEqual(result["text"], self.data[4])
598  
599          def testReindex(self):
600              """
601              Test reindex
602              """
603  
604              # Create an index for the list of text
605              self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
606  
607              # Delete records to test indexids still match
608              self.embeddings.delete(([0, 1]))
609  
610              # Reindex
611              self.embeddings.reindex({"path": "sentence-transformers/nli-mpnet-base-v2"})
612  
613              # Search for best match
614              result = self.embeddings.search("feel good story", 1)[0]
615  
616              self.assertEqual(result["text"], self.data[4])
617  
618          def testSave(self):
619              """
620              Test save
621              """
622  
623              # Create an index for the list of text
624              self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
625  
626              # Generate temp file path
627              index = os.path.join(tempfile.gettempdir(), f"embeddings.{self.category()}")
628  
629              self.embeddings.save(index)
630              self.embeddings.load(index)
631  
632              # Search for best match
633              result = self.embeddings.search("feel good story", 1)[0]
634  
635              self.assertEqual(result["text"], self.data[4])
636  
637              # Test offsets still work after save/load
638              self.embeddings.upsert([(0, "Looking out into the dreadful abyss", None)])
639              self.assertEqual(self.embeddings.count(), len(self.data))
640  
641          def testSettings(self):
642              """
643              Test custom SQLite settings
644              """
645  
646              # Index with write-ahead logging enabled
647              embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2", "content": self.backend, "sqlite": {"wal": True}})
648  
649              # Create an index for the list of text
650              embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
651  
652              # Search for best match
653              result = embeddings.search("feel good story", 1)[0]
654  
655              self.assertEqual(result["text"], self.data[4])
656  
        def testSQL(self):
            """
            Test running a SQL query

            Covers similar() clauses, limits/offsets, where filters, aggregates,
            column selection, metadata column filtering and parse errors.
            """

            # Create an index with extra metadata fields (length, attribute) per document
            self.embeddings.index([(uid, {"text": text, "length": len(text), "attribute": f"ID{uid}"}, None) for uid, text in enumerate(self.data)])

            # Test similar clause combined with group by / having / order by
            result = self.embeddings.search(
                "select text, score from txtai where similar('feel good story') group by text, score having count(*) > 0 order by score desc", 1
            )[0]
            self.assertEqual(result["text"], self.data[4])

            # Test similar with limits
            result = self.embeddings.search("select * from txtai where similar('feel good story', 1) limit 1")[0]
            self.assertEqual(result["text"], self.data[4])

            # Test similar with offset - skips the best match
            result = self.embeddings.search("select * from txtai where similar('feel good story') offset 1")[0]
            self.assertEqual(result["text"], self.data[5])

            # Test where
            result = self.embeddings.search("select * from txtai where text like '%iceberg%'", 1)[0]
            self.assertEqual(result["text"], self.data[1])

            # Test count
            result = self.embeddings.search("select count(*) from txtai")[0]
            self.assertEqual(list(result.values())[0], len(self.data))

            # Test columns
            result = self.embeddings.search("select id, text, length, data, entry from txtai")[0]
            self.assertEqual(sorted(result.keys()), ["data", "entry", "id", "length", "text"])

            # Test column filtering on a metadata attribute
            result = self.embeddings.search("select text from txtai where attribute = 'ID4'", 1)[0]
            self.assertEqual(result["text"], self.data[4])

            # Test SQL parse error
            with self.assertRaises(SQLError):
                self.embeddings.search("select * from txtai where bad,query")
698  
699          def testSQLBind(self):
700              """
701              Test SQL statements with bind parameters
702              """
703  
704              # Create an index for the list of text
705              self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
706  
707              # Test similar clause bind parameters
708              result = self.embeddings.search("select id, text, score from txtai where similar(:x)", parameters={"x": "feel good story"})[0]
709              self.assertEqual(result["text"], self.data[4])
710  
711              # Test similar clause bind and non-bind parameters
712              result = self.embeddings.search("select id, text, score from txtai where similar(:x, 0.5)", parameters={"x": "feel good story"})[0]
713              self.assertEqual(result["text"], self.data[4])
714  
715              # Test where filtering with bind parameters
716              result = self.embeddings.search("select * from txtai where text like :x", parameters={"x": "%iceberg%"})[0]
717              self.assertEqual(result["text"], self.data[1])
718  
719          def testSparse(self):
720              """
721              Test sparse vector search
722              """
723  
724              # Build data array
725              data = [(uid, text, None) for uid, text in enumerate(self.data)]
726  
727              # Index data with sparse vectors
728              embeddings = Embeddings({"sparse": "sparse-encoder-testing/splade-bert-tiny-nq", "content": self.backend})
729              embeddings.index(data)
730  
731              # Run search
732              result = embeddings.search("lottery ticket", 1)[0]
733              self.assertEqual(result["text"], data[4][1])
734  
735              # Test count method
736              self.assertEqual(embeddings.count(), len(data))
737  
738              # Generate temp file path
739              index = os.path.join(tempfile.gettempdir(), f"embeddings.{self.category()}.sparse")
740  
741              # Test load/save
742              embeddings.save(index)
743              embeddings.load(index)
744  
745              # Run search
746              result = embeddings.search("lottery ticket", 1)[0]
747              self.assertEqual(result["text"], data[4][1])
748  
749              # Update data
750              data[0] = (0, "Feel good story: baby panda born", None)
751              embeddings.upsert([data[0]])
752  
753              # Search for best match
754              result = embeddings.search("feel good story", 1)[0]
755              self.assertEqual(result["text"], data[0][1])
756  
        def testSubindex(self):
            """
            Test subindex

            Covers searching, SQL similar() routing to a named subindex, missing
            subindex errors, save/load, upserts and the id-as-text fallback when
            top-level indexing is disabled.
            """

            # Build data array
            data = [(uid, text, None) for uid, text in enumerate(self.data)]

            # Disable top-level indexing and create subindex
            embeddings = Embeddings(
                {"content": self.backend, "defaults": False, "indexes": {"index1": {"path": "sentence-transformers/nli-mpnet-base-v2"}}}
            )
            embeddings.index(data)

            # Test transform - vectors come from the subindex model
            self.assertEqual(embeddings.transform("feel good story").shape, (768,))

            # Run search
            result = embeddings.search("feel good story", 1)[0]
            self.assertEqual(result["text"], data[4][1])

            # Run SQL search with explicit limit and minimum score arguments
            result = embeddings.search("select id, text, score from txtai where similar('feel good story', 10, 0.5)")[0]
            self.assertEqual(result["text"], data[4][1])

            # Test missing index - similar() naming an unknown subindex raises
            with self.assertRaises(IndexNotFoundError):
                embeddings.search("select id, text, score from txtai where similar('feel good story', 'notindex')")

            # Generate temp file path
            index = os.path.join(tempfile.gettempdir(), f"embeddings.{self.category()}.subindex")

            # Test load/save
            embeddings.save(index)
            embeddings.load(index)

            # Run search
            result = embeddings.search("feel good story", 1)[0]
            self.assertEqual(result["text"], data[4][1])

            # Update data
            data[0] = (0, "Feel good story: baby panda born", None)
            embeddings.upsert([data[0]])

            # Search for best match
            result = embeddings.search("feel good story", 1)[0]
            self.assertEqual(result["text"], data[0][1])

            # Check missing text is set to id when top-level indexing is disabled
            embeddings.upsert([(embeddings.count(), {"content": "empty text"}, None)])
            result = embeddings.search(f"{embeddings.count() - 1}", 1)[0]
            self.assertEqual(result["text"], str(embeddings.count() - 1))

            # Close embeddings
            embeddings.close()
812  
813          def testSubindexEmpty(self):
814              """
815              Test loading an empty subindex
816              """
817  
818              # Build data array
819              data = [(uid, {"column1": text}, None) for uid, text in enumerate(self.data)]
820  
821              # Disable top-level indexing and create subindexes
822              embeddings = Embeddings(
823                  {
824                      "content": self.backend,
825                      "defaults": False,
826                      "indexes": {
827                          "index1": {"path": "sentence-transformers/nli-mpnet-base-v2", "columns": {"text": "column1"}},
828                          "index2": {"path": "sentence-transformers/nli-mpnet-base-v2", "columns": {"text": "column2"}},
829                      },
830                  }
831              )
832              embeddings.index(data)
833  
834              # Generate temp file path
835              index = os.path.join(tempfile.gettempdir(), f"embeddings.{self.category()}.subindexempty")
836  
837              # Save index
838              embeddings.save(index)
839  
840              # Test exists
841              self.assertTrue(embeddings.exists(index))
842  
843              # Load index
844              embeddings.load(index)
845  
846              # Test search
847              result = embeddings.search("feel good story", 1)[0]
848              self.assertEqual(result["text"], data[4][1]["text"])
849  
850          def testTerms(self):
851              """
852              Test extracting keyword terms from query
853              """
854  
855              result = self.embeddings.terms("select * from txtai where similar('keyword terms')")
856              self.assertEqual(result, "keyword terms")
857  
858          def testTruncate(self):
859              """
860              Test dimensionality truncation
861              """
862  
863              # Truncate vectors to a specified number of dimensions
864              embeddings = Embeddings(
865                  {"path": "sentence-transformers/nli-mpnet-base-v2", "dimensionality": 750, "content": self.backend, "vectors": {"revision": "main"}}
866              )
867              embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
868  
869              # Search for best match
870              result = self.embeddings.search("feel good story", 1)[0]
871              self.assertEqual(result["text"], self.data[4])
872  
873          def testUpsert(self):
874              """
875              Test upsert
876              """
877  
878              # Build data array
879              data = [(uid, text, None) for uid, text in enumerate(self.data)]
880  
881              # Reset embeddings for test
882              self.embeddings.ann = None
883              self.embeddings.database = None
884  
885              # Create an index for the list of text
886              self.embeddings.upsert(data)
887  
888              # Update data
889              data[0] = (0, "Feel good story: baby panda born", None)
890              self.embeddings.upsert([data[0]])
891  
892              # Search for best match
893              result = self.embeddings.search("feel good story", 1)[0]
894              self.assertEqual(result["text"], data[0][1])
895  
896          def testUpsertBatch(self):
897              """
898              Test upsert batch
899              """
900  
901              try:
902                  # Build data array
903                  data = [(uid, text, None) for uid, text in enumerate(self.data)]
904  
905                  # Reset embeddings for test
906                  self.embeddings.ann = None
907                  self.embeddings.database = None
908  
909                  # Create an index for the list of text
910                  self.embeddings.upsert(data)
911  
912                  # Set batch size to 1
913                  self.embeddings.config["batch"] = 1
914  
915                  # Update data
916                  data[0] = (0, "Feel good story: baby panda born", None)
917                  data[1] = (0, "Not good news", None)
918                  self.embeddings.upsert([data[0], data[1]])
919  
920                  # Search for best match
921                  result = self.embeddings.search("feel good story", 1)[0]
922  
923                  self.assertEqual(result["text"], data[0][1])
924              finally:
925                  del self.embeddings.config["batch"]
926  
927          def category(self):
928              """
929              Content backend category.
930  
931              Returns:
932                  category
933              """
934  
935              return self.__class__.__name__.lower().replace("test", "")