/ test / python / testvectors / testdense / testlitellm.py
testlitellm.py
 1  """
 2  LiteLLM module tests
 3  """
 4  
 5  import json
 6  import os
 7  import unittest
 8  
 9  from http.server import HTTPServer, BaseHTTPRequestHandler
10  from threading import Thread
11  
12  import numpy as np
13  
14  from txtai.vectors import VectorsFactory
15  
16  
class RequestHandler(BaseHTTPRequestHandler):
    """
    Test HTTP handler.

    Mock endpoint that answers every POST with a fixed embeddings payload.
    """

    def do_POST(self):
        """
        POST request handler.

        Responds with HTTP 200 and a JSON body holding a single 768-dimensional
        zero vector (``[[0.0, ..., 0.0]]``), regardless of the request payload.
        """

        # Drain the request body before replying. BaseHTTPRequestHandler defaults to HTTP/1.0,
        # so the connection is closed after this response; closing a socket while unread request
        # data is still pending can make the client observe a connection reset instead of the reply.
        length = int(self.headers.get("content-length", 0))
        if length > 0:
            self.rfile.read(length)

        # Generate mock response
        response = [[0.0] * 768]
        response = json.dumps(response).encode("utf-8")

        self.send_response(200)
        self.send_header("content-type", "application/json")
        self.send_header("content-length", str(len(response)))
        self.end_headers()

        self.wfile.write(response)
        self.wfile.flush()
38  
39  
class TestLiteLLM(unittest.TestCase):
    """
    LiteLLM vectors tests
    """

    @classmethod
    def setUpClass(cls):
        """
        Create mock http server.
        """

        # Bind to port 0 so the OS picks a free port; a fixed port can collide with
        # other local services or with parallel test runs
        cls.httpd = HTTPServer(("127.0.0.1", 0), RequestHandler)
        cls.port = cls.httpd.server_address[1]

        # Serve in a background daemon thread so the test methods can issue requests
        cls.server = Thread(target=cls.httpd.serve_forever, daemon=True)
        cls.server.start()

    @classmethod
    def tearDownClass(cls):
        """
        Shutdown mock http server.
        """

        # Stop the serve_forever loop, release the listening socket, then wait for the thread
        cls.httpd.shutdown()
        cls.httpd.server_close()
        cls.server.join()

    def testIndex(self):
        """
        Test indexing with LiteLLM vectors
        """

        # LiteLLM vectors instance, pointed at the mock server started in setUpClass
        model = VectorsFactory.create(
            {"path": "huggingface/sentence-transformers/all-MiniLM-L6-v2", "vectors": {"api_base": f"http://127.0.0.1:{self.port}"}}, None
        )

        ids, dimension, batches, stream = model.index([(0, "test", None)])

        self.assertEqual(len(ids), 1)
        self.assertEqual(dimension, 768)
        self.assertEqual(batches, 1)

        # os.path.exists returns a bool (never None), so it must be checked with assertTrue
        self.assertTrue(os.path.exists(stream))

        # Test shape of serialized embeddings
        with open(stream, "rb") as queue:
            self.assertEqual(np.load(queue).shape, (1, 768))