  1  """
  2  Embeddings module tests
  3  """
  4  
  5  import json
  6  import os
  7  import tempfile
  8  import unittest
  9  
 10  from unittest.mock import patch
 11  
 12  import numpy as np
 13  
 14  from txtai.embeddings import Embeddings, Reducer
 15  from txtai.serialize import SerializeFactory
 16  
 17  
 18  # pylint: disable=R0904
 19  class TestEmbeddings(unittest.TestCase):
 20      """
 21      Embeddings tests.
 22      """
 23  
 24      @classmethod
 25      def setUpClass(cls):
 26          """
 27          Initialize test data.
 28          """
 29  
 30          cls.data = [
 31              "US tops 5 million confirmed virus cases",
 32              "Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg",
 33              "Beijing mobilises invasion craft along coast as Taiwan tensions escalate",
 34              "The National Park Service warns against sacrificing slower friends in a bear attack",
 35              "Maine man wins $1M from $25 lottery ticket",
 36              "Make huge profits without work, earn up to $100,000 a day",
 37          ]
 38  
 39          # Create embeddings model, backed by sentence-transformers & transformers
 40          cls.embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2"})
 41  
 42      @classmethod
 43      def tearDownClass(cls):
 44          """
 45          Cleanup data.
 46          """
 47  
 48          if cls.embeddings:
 49              cls.embeddings.close()
 50  
 51      def testAutoId(self):
 52          """
 53          Test auto id generation
 54          """
 55  
 56          # Default sequence id
 57          embeddings = Embeddings()
 58          embeddings.index(self.data)
 59  
 60          uid = embeddings.search(self.data[4], 1)[0][0]
 61          self.assertEqual(uid, 4)
 62  
 63          # UUID
 64          embeddings = Embeddings(autoid="uuid4")
 65          embeddings.index(self.data)
 66  
 67          uid = embeddings.search(self.data[4], 1)[0][0]
 68          self.assertEqual(len(uid), 36)
 69  
 70      def testColumns(self):
 71          """
 72          Test custom text/object columns
 73          """
 74  
 75          embeddings = Embeddings({"keyword": True, "columns": {"text": "value"}})
 76          data = [{"value": x} for x in self.data]
 77          embeddings.index([(uid, text, None) for uid, text in enumerate(data)])
 78  
 79          # Run search
 80          uid = embeddings.search("lottery", 1)[0][0]
 81          self.assertEqual(uid, 4)
 82  
 83      def testContext(self):
 84          """
 85          Test embeddings context manager
 86          """
 87  
 88          # Generate temp file path
 89          index = os.path.join(tempfile.gettempdir(), "embeddings.context")
 90  
 91          with Embeddings() as embeddings:
 92              embeddings.index(self.data)
 93              embeddings.save(index)
 94  
 95          with Embeddings().load(index) as embeddings:
 96              uid = embeddings.search(self.data[4], 1)[0][0]
 97              self.assertEqual(uid, 4)
 98  
 99      def testDefaults(self):
100          """
101          Test default configuration
102          """
103  
104          # Run index with no config which will fall back to default configuration
105          embeddings = Embeddings()
106          embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
107  
108          self.assertEqual(embeddings.count(), 6)
109  
110      def testDelete(self):
111          """
112          Test delete
113          """
114  
115          # Create an index for the list of text
116          self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
117  
118          # Delete best match
119          self.embeddings.delete([4])
120  
121          # Search for best match
122          uid = self.embeddings.search("feel good story", 1)[0][0]
123  
124          self.assertEqual(self.embeddings.count(), 5)
125          self.assertEqual(uid, 5)
126  
127      def testDense(self):
128          """
129          Test dense alias
130          """
131  
132          # Dense flag is an alias for path
133          embeddings = Embeddings(dense="sentence-transformers/nli-mpnet-base-v2")
134          embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
135  
136          self.assertEqual(embeddings.count(), 6)
137  
138      def testEmpty(self):
139          """
140          Test empty index
141          """
142  
143          # Test search against empty index
144          embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2"})
145          self.assertEqual(embeddings.search("test"), [])
146  
147          # Test index with no data
148          embeddings.index([])
149          self.assertIsNone(embeddings.ann)
150  
151          # Test upsert with no data
152          embeddings.index([(0, "this is a test", None)])
153          embeddings.upsert([])
154          self.assertIsNotNone(embeddings.ann)
155  
156      def testEmptyString(self):
157          """
158          Test empty string indexing
159          """
160  
161          # Test empty string
162          self.embeddings.index([(0, "", None)])
163          self.assertTrue(self.embeddings.search("test"))
164  
165          # Test empty string with dict
166          self.embeddings.index([(0, {"text": ""}, None)])
167          self.assertTrue(self.embeddings.search("test"))
168  
169      def testExternal(self):
170          """
171          Test embeddings backed by external vectors
172          """
173  
174          def transform(data):
175              embeddings = []
176              for text in data:
177                  # Create dummy embedding using sum and mean of character ordinals
178                  ordinals = [ord(c) for c in text]
179                  embeddings.append(np.array([sum(ordinals), np.mean(ordinals)]))
180  
181              return embeddings
182  
183          # Index data using simple embeddings transform method
184          embeddings = Embeddings({"method": "external", "transform": transform})
185          embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
186  
187          # Run search
188          uid = embeddings.search(self.data[4], 1)[0][0]
189          self.assertEqual(uid, 4)
190  
191      def testExternalPrecomputed(self):
192          """
193          Test embeddings backed by external pre-computed vectors
194          """
195  
196          # Test with no transform function
197          data = np.random.rand(5, 10).astype(np.float32)
198  
199          embeddings = Embeddings({"method": "external"})
200          embeddings.index([(uid, row, None) for uid, row in enumerate(data)])
201  
202          # Run search
203          uid = embeddings.search(data[4], 1)[0][0]
204          self.assertEqual(uid, 4)
205  
206      def testHybrid(self):
207          """
208          Test hybrid search
209          """
210  
211          # Build data array
212          data = [(uid, text, None) for uid, text in enumerate(self.data)]
213  
214          # Index data with sparse + dense vectors
215          embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2", "hybrid": True})
216          embeddings.index(data)
217  
218          # Run search
219          uid = embeddings.search("feel good story", 1)[0][0]
220          self.assertEqual(uid, 4)
221  
222          # Generate temp file path
223          index = os.path.join(tempfile.gettempdir(), "embeddings.hybrid")
224  
225          # Test load/save
226          embeddings.save(index)
227          embeddings.load(index)
228  
229          # Run search
230          uid = embeddings.search("feel good story", 1)[0][0]
231          self.assertEqual(uid, 4)
232  
233          # Index data with sparse + dense vectors and unnormalized scores
234          embeddings.config["scoring"]["normalize"] = False
235          embeddings.index(data)
236  
237          # Run search
238          uid = embeddings.search("feel good story", 1)[0][0]
239          self.assertEqual(uid, 4)
240  
241          # Index data with sparse + dense vectors and bb25 normalization
242          embeddings.config["scoring"]["normalize"] = "bb25"
243          embeddings.index(data)
244  
245          # Run search
246          uid = embeddings.search("canada intact iceberg a", 1)[0][0]
247          self.assertEqual(uid, 1)
248  
249          # Test upsert
250          data[0] = (0, "Feel good story: baby panda born", None)
251          embeddings.upsert([data[0]])
252  
253          uid = embeddings.search("feel good story", 1)[0][0]
254          self.assertEqual(uid, 0)
255  
256      def testIds(self):
257          """
258          Test legacy config ids loading
259          """
260  
261          # Create an index for the list of text
262          self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
263  
264          # Generate temp file path
265          index = os.path.join(tempfile.gettempdir(), "embeddings.ids")
266  
267          # Save index
268          self.embeddings.save(index)
269  
270          # Set ids on config to simulate legacy ids format
271          with open(f"{index}/config.json", "r", encoding="utf-8") as handle:
272              config = json.load(handle)
273              config["ids"] = list(range(len(self.data)))
274  
275          with open(f"{index}/config.json", "w", encoding="utf-8") as handle:
276              json.dump(config, handle, default=str, indent=2)
277  
278          # Reload index
279          self.embeddings.load(index)
280  
281          # Run search
282          uid = self.embeddings.search("feel good story", 1)[0][0]
283          self.assertEqual(uid, 4)
284  
285          # Check that ids is not in config
286          self.assertTrue("ids" not in self.embeddings.config)
287  
288      @patch.dict(os.environ, {"ALLOW_PICKLE": "True"})
289      def testIdsPickle(self):
290          """
291          Test legacy pickle ids
292          """
293  
294          # Create an index for the list of text
295          self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
296  
297          # Generate temp file path
298          index = os.path.join(tempfile.gettempdir(), "embeddings.idspickle")
299  
300          # Save index
301          self.embeddings.save(index)
302  
303          # Create ids as pickle
304          path = os.path.join(tempfile.gettempdir(), "embeddings.idspickle", "ids")
305          serializer = SerializeFactory.create("pickle", allowpickle=True)
306          serializer.save(self.embeddings.ids.ids, path)
307  
308          with self.assertWarns(RuntimeWarning):
309              self.embeddings.load(index)
310  
311          # Run search
312          uid = self.embeddings.search("feel good story", 1)[0][0]
313          self.assertEqual(uid, 4)
314  
315      def testIndex(self):
316          """
317          Test index
318          """
319  
320          # Create an index for the list of text
321          self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
322  
323          # Search for best match
324          uid = self.embeddings.search("feel good story", 1)[0][0]
325  
326          self.assertEqual(uid, 4)
327  
328      def testKeyword(self):
329          """
330          Test keyword only (sparse) search
331          """
332  
333          # Build data array
334          data = [(uid, text, None) for uid, text in enumerate(self.data)]
335  
336          # Index data with sparse keyword vectors
337          embeddings = Embeddings({"keyword": True})
338          embeddings.index(data)
339  
340          # Run search
341          uid = embeddings.search("lottery ticket", 1)[0][0]
342          self.assertEqual(uid, 4)
343  
344          # Test count method
345          self.assertEqual(embeddings.count(), len(data))
346  
347          # Generate temp file path
348          index = os.path.join(tempfile.gettempdir(), "embeddings.keyword")
349  
350          # Test load/save
351          embeddings.save(index)
352          embeddings.load(index)
353  
354          # Run search
355          uid = embeddings.search("lottery ticket", 1)[0][0]
356          self.assertEqual(uid, 4)
357  
358          # Update data
359          data[0] = (0, "Feel good story: baby panda born", None)
360          embeddings.upsert([data[0]])
361  
362          # Search for best match
363          uid = embeddings.search("feel good story", 1)[0][0]
364          self.assertEqual(uid, 0)
365  
366      def testQuantize(self):
367          """
368          Test scalar quantization
369          """
370  
371          for ann in ["faiss", "numpy", "torch"]:
372              # Index data with 1-bit scalar quantization
373              embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2", "quantize": 1, "backend": ann})
374              embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
375  
376              # Search for best match
377              uid = embeddings.search("feel good story", 1)[0][0]
378              self.assertEqual(uid, 4)
379  
380      def testReducer(self):
381          """
382          Test reducer model
383          """
384  
385          # Test model with single PCA component
386          data = np.random.rand(5, 5).astype(np.float32)
387          reducer = Reducer(data, 1)
388  
389          # Generate query and keep original data to ensure it changes
390          query = np.random.rand(1, 5).astype(np.float32)
391          original = query.copy()
392  
393          # Run test
394          reducer(query)
395          self.assertFalse(np.array_equal(query, original))
396  
397          # Test model with multiple PCA components
398          reducer = Reducer(data, 3)
399  
400          # Generate query and keep original data to ensure it changes
401          query = np.random.rand(5).astype(np.float32)
402          original = query.copy()
403  
404          # Run test
405          reducer(query)
406          self.assertFalse(np.array_equal(query, original))
407  
408      @patch.dict(os.environ, {"ALLOW_PICKLE": "True"})
409      def testReducerLegacy(self):
410          """
411          Test reducer model with legacy model format
412          """
413  
414          # Test model with single PCA component
415          data = np.random.rand(5, 5).astype(np.float32)
416          reducer = Reducer(data, 1)
417  
418          # Save legacy format
419          path = os.path.join(tempfile.gettempdir(), "reducer")
420          serializer = SerializeFactory.create("pickle", allowpickle=True)
421          serializer.save(reducer.model, path)
422  
423          # Load legacy format
424          reducer = Reducer()
425          reducer.load(path)
426  
427          # Generate query and keep original data to ensure it changes
428          query = np.random.rand(1, 5).astype(np.float32)
429          original = query.copy()
430  
431          # Run test
432          reducer(query)
433          self.assertFalse(np.array_equal(query, original))
434  
435      def testSave(self):
436          """
437          Test save
438          """
439  
440          # Create an index for the list of text
441          self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
442  
443          # Generate temp file path
444          index = os.path.join(tempfile.gettempdir(), "embeddings.base")
445  
446          self.embeddings.save(index)
447          self.embeddings.load(index)
448  
449          # Search for best match
450          uid = self.embeddings.search("feel good story", 1)[0][0]
451  
452          self.assertEqual(uid, 4)
453  
454          # Test offsets still work after save/load
455          self.embeddings.upsert([(0, "Looking out into the dreadful abyss", None)])
456          self.assertEqual(self.embeddings.count(), len(self.data))
457  
458      def testShortcuts(self):
459          """
460          Test embeddings creation shortcuts
461          """
462  
463          tests = [
464              ({"keyword": True}, ["scoring"]),
465              ({"keyword": "sif"}, ["scoring"]),
466              ({"sparse": True}, ["scoring"]),
467              ({"dense": True}, ["ann"]),
468              ({"hybrid": True}, ["ann", "scoring"]),
469              ({"hybrid": "tfidf"}, ["ann", "scoring"]),
470              ({"hybrid": "sparse"}, ["ann", "scoring"]),
471              ({"graph": True}, ["graph"]),
472          ]
473  
474          for config, checks in tests:
475              embeddings = Embeddings(config)
476              embeddings.index(["test"])
477  
478              for attr in checks:
479                  self.assertIsNotNone(getattr(embeddings, attr))
480  
481      def testSimilarity(self):
482          """
483          Test similarity
484          """
485  
486          # Get best matching id
487          uid = self.embeddings.similarity("feel good story", self.data)[0][0]
488  
489          self.assertEqual(uid, 4)
490  
491      def testSparse(self):
492          """
493          Test sparse vector search
494          """
495  
496          # Build data array
497          data = [(uid, text, None) for uid, text in enumerate(self.data)]
498  
499          # Index data with sparse vectors
500          embeddings = Embeddings({"sparse": "sparse-encoder-testing/splade-bert-tiny-nq"})
501          embeddings.index(data)
502  
503          # Run search
504          uid = embeddings.search("lottery ticket", 1)[0][0]
505          self.assertEqual(uid, 4)
506  
507          # Test count method
508          self.assertEqual(embeddings.count(), len(data))
509  
510          # Generate temp file path
511          index = os.path.join(tempfile.gettempdir(), "embeddings.sparse")
512  
513          # Test load/save
514          embeddings.save(index)
515          embeddings.load(index)
516  
517          # Run search
518          uid = embeddings.search("lottery ticket", 1)[0][0]
519          self.assertEqual(uid, 4)
520  
521          # Test similarity
522          uid = embeddings.similarity("lottery ticket", self.data)[0][0]
523          self.assertEqual(uid, 4)
524  
525          # Update data
526          data[0] = (0, "Feel good story: baby panda born", None)
527          embeddings.upsert([data[0]])
528  
529          # Search for best match
530          uid = embeddings.search("feel good story", 1)[0][0]
531          self.assertEqual(uid, 0)
532  
533      def testSubindex(self):
534          """
535          Test subindex
536          """
537  
538          # Build data array
539          data = [(uid, text, None) for uid, text in enumerate(self.data)]
540  
541          # Disable top-level indexing and create subindex
542          embeddings = Embeddings({"defaults": False, "indexes": {"index1": {"path": "sentence-transformers/nli-mpnet-base-v2"}}})
543          embeddings.index(data)
544  
545          # Test transform
546          self.assertEqual(embeddings.transform("feel good story").shape, (768,))
547          self.assertEqual(embeddings.transform("feel good story", index="index1").shape, (768,))
548          with self.assertRaises(KeyError):
549              embeddings.transform("feel good story", index="index2")
550  
551          # Run search
552          uid = embeddings.search("feel good story", 1)[0][0]
553          self.assertEqual(uid, 4)
554  
555          # Generate temp file path
556          index = os.path.join(tempfile.gettempdir(), "embeddings.subindex")
557  
558          # Test load/save
559          embeddings.save(index)
560          embeddings.load(index)
561  
562          # Run search
563          uid = embeddings.search("feel good story", 1)[0][0]
564          self.assertEqual(uid, 4)
565  
566          # Update data
567          data[0] = (0, "Feel good story: baby panda born", None)
568          embeddings.upsert([data[0]])
569  
570          # Search for best match
571          uid = embeddings.search("feel good story", 10)[0][0]
572          self.assertEqual(uid, 0)
573  
574          # Check missing text is set to id when top-level indexing is disabled
575          embeddings.upsert([(embeddings.count(), {"content": "empty text"}, None)])
576          uid = embeddings.search(f"{embeddings.count() - 1}", 1)[0][0]
577          self.assertEqual(uid, embeddings.count() - 1)
578  
579          # Close embeddings
580          embeddings.close()
581  
582      def testTruncate(self):
583          """
584          Test dimensionality truncation
585          """
586  
587          # Truncate vectors to a specified number of dimensions
588          embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2", "dimensionality": 750, "vectors": {"revision": "main"}})
589          embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
590  
591          # Search for best match
592          uid = embeddings.search("feel good story", 1)[0][0]
593          self.assertEqual(uid, 4)
594  
595      def testUpsert(self):
596          """
597          Test upsert
598          """
599  
600          # Build data array
601          data = [(uid, text, None) for uid, text in enumerate(self.data)]
602  
603          # Reset embeddings for test
604          self.embeddings.ann = None
605          self.embeddings.ids = None
606  
607          # Create an index for the list of text
608          self.embeddings.upsert(data)
609  
610          # Update data
611          data[0] = (0, "Feel good story: baby panda born", None)
612          self.embeddings.upsert([data[0]])
613  
614          # Search for best match
615          uid = self.embeddings.search("feel good story", 1)[0][0]
616  
617          self.assertEqual(uid, 0)
618  
619      @patch("os.cpu_count")
620      def testWords(self, cpucount):
621          """
622          Test embeddings backed by word vectors
623          """
624  
625          # Mock CPU count
626          cpucount.return_value = 1
627  
628          # Create dataset
629          data = [(x, row.split(), None) for x, row in enumerate(self.data)]
630  
631          # Create embeddings model, backed by word vectors
632          embeddings = Embeddings({"path": "neuml/glove-6B-quantized", "scoring": "bm25", "pca": 3, "quantize": True})
633  
634          # Call scoring and index methods
635          embeddings.score(data)
636          embeddings.index(data)
637  
638          # Test search
639          self.assertIsNotNone(embeddings.search("win", 1))
640  
641          # Generate temp file path
642          index = os.path.join(tempfile.gettempdir(), "embeddings.wordvectors")
643  
644          # Test save/load
645          embeddings.save(index)
646          embeddings.load(index)
647  
648          # Test search
649          self.assertIsNotNone(embeddings.search("win", 1))
650  
651      @patch("os.cpu_count")
652      def testWordsUpsert(self, cpucount):
653          """
654          Test embeddings backed by word vectors with upserts
655          """
656  
657          # Mock CPU count
658          cpucount.return_value = 1
659  
660          # Create dataset
661          data = [(x, row.split(), None) for x, row in enumerate(self.data)]
662  
663          # Create embeddings model, backed by word vectors
664          embeddings = Embeddings({"path": "neuml/glove-6B/model.sqlite", "scoring": "bm25", "pca": 3})
665  
666          # Call scoring and index methods
667          embeddings.score(data)
668          embeddings.index(data)
669  
670          # Now upsert and override record
671          data = [(0, "win win", None)]
672  
673          # Update scoring and run upsert
674          embeddings.score(data)
675          embeddings.upsert(data)
676  
677          # Test search after upsert
678          uid = embeddings.search("win", 1)[0][0]
679          self.assertEqual(uid, 0)