testlitellm.py
"""
LiteLLM module tests
"""

import json
import os
import unittest

from http.server import HTTPServer, BaseHTTPRequestHandler
from threading import Thread

import numpy as np

from txtai.vectors import VectorsFactory


class RequestHandler(BaseHTTPRequestHandler):
    """
    Test HTTP handler that mocks a LiteLLM-compatible embeddings endpoint.
    """

    def do_POST(self):
        """
        POST request handler.

        Always replies 200 with a single 768-dimensional zero vector encoded
        as JSON, regardless of the request body.
        """

        # Generate mock response: one embedding of dimension 768
        response = [[0.0] * 768]
        response = json.dumps(response).encode("utf-8")

        self.send_response(200)
        self.send_header("content-type", "application/json")
        # Header values are strings; pass str explicitly instead of relying
        # on implicit %s formatting of the int
        self.send_header("content-length", str(len(response)))
        self.end_headers()

        self.wfile.write(response)
        self.wfile.flush()


class TestLiteLLM(unittest.TestCase):
    """
    LiteLLM vectors tests
    """

    @classmethod
    def setUpClass(cls):
        """
        Create mock http server.
        """

        cls.httpd = HTTPServer(("127.0.0.1", 8004), RequestHandler)

        # Daemon thread so the process can exit even if shutdown is skipped
        server = Thread(target=cls.httpd.serve_forever, daemon=True)
        server.start()

    @classmethod
    def tearDownClass(cls):
        """
        Shutdown mock http server.
        """

        cls.httpd.shutdown()

    def testIndex(self):
        """
        Test indexing with LiteLLM vectors
        """

        # LiteLLM vectors instance pointed at the local mock endpoint
        model = VectorsFactory.create(
            {"path": "huggingface/sentence-transformers/all-MiniLM-L6-v2", "vectors": {"api_base": "http://127.0.0.1:8004"}}, None
        )

        ids, dimension, batches, stream = model.index([(0, "test", None)])

        self.assertEqual(len(ids), 1)
        self.assertEqual(dimension, 768)
        self.assertEqual(batches, 1)
        # os.path.exists returns a bool which is never None, so the original
        # assertIsNotNone always passed; assertTrue checks the file exists
        self.assertTrue(os.path.exists(stream))

        # Test shape of serialized embeddings
        with open(stream, "rb") as queue:
            self.assertEqual(np.load(queue).shape, (1, 768))