# testcluster.py
"""
Cluster API module tests
"""

import json
import os
import tempfile
import unittest
import urllib.parse

from http.server import HTTPServer, BaseHTTPRequestHandler
from threading import Thread
from unittest.mock import patch

from fastapi.testclient import TestClient

from txtai.api import application

# Configuration for an embeddings cluster
CLUSTER = """
cluster:
    shards:
        - http://127.0.0.1:8002
        - http://127.0.0.1:8003
"""


class RequestHandler(BaseHTTPRequestHandler):
    """
    Test HTTP handler. Plays the role of a single cluster shard by returning
    canned JSON responses keyed off the request path and the server port.
    """

    def do_GET(self):
        """
        GET request handler.

        Returns a fixed count for /count, canned aggregate rows for SQL-style
        /search queries, a single fixed hit for any other /search request and
        a generic ok result for everything else.
        """

        if self.path == "/count":
            response = 26
        elif self.path.startswith("/search?query=select"):
            if "group+by+id" in self.path:
                response = [{"count(*)": 26}]
            elif "group+by+text" in self.path:
                response = [{"count(*)": 12, "text": "This is a test"}, {"count(*)": 14, "text": "And another test"}]
            elif "group+by+txt" in self.path:
                response = [{"count(*)": 12, "txt": "This is a test"}, {"count(*)": 14, "txt": "And another test"}]
            else:
                # Each shard reports different aggregates so the test can verify they are merged
                if self.server.server_port == 8002:
                    response = [{"count(*)": 12, "min(indexid)": 0, "max(indexid)": 11, "avg(indexid)": 6.3}]
                else:
                    response = [{"count(*)": 16, "min(indexid)": 2, "max(indexid)": 14, "avg(indexid)": 6.7}]
        elif self.path.startswith("/search"):
            response = [{"id": 4, "score": 0.40}]
        else:
            response = {"result": "ok"}

        self.respond(response)

    def do_POST(self):
        """
        POST request handler.

        Returns two canned result lists for /batchsearch, a port-dependent id
        list for /delete and a generic ok result for everything else.
        """

        if self.path.startswith("/batchsearch"):
            response = [[{"id": 4, "score": 0.40}], [{"id": 1, "score": 0.40}]]
        elif self.path.startswith("/delete"):
            # Only the shard on port 8002 reports a deleted id
            if self.server.server_port == 8002:
                response = [0]
            else:
                response = []
        else:
            response = {"result": "ok"}

        self.respond(response)

    def respond(self, response):
        """
        Writes response as a JSON-encoded HTTP 200 reply.

        Args:
            response: JSON-serializable object to send as the response body
        """

        # Convert response to UTF-8 encoded JSON bytes
        body = json.dumps(response).encode("utf-8")

        self.send_response(200)
        self.send_header("content-type", "application/json")
        self.send_header("content-length", str(len(body)))
        self.end_headers()

        self.wfile.write(body)
        self.wfile.flush()


@unittest.skipIf(os.name == "nt", "TestCluster skipped on Windows")
class TestCluster(unittest.TestCase):
    """
    API tests for embeddings clusters
    """

    @staticmethod
    @patch.dict(os.environ, {"CONFIG": os.path.join(tempfile.gettempdir(), "testapi.yml"), "API_CLASS": "txtai.api.API"})
    def start():
        """
        Starts a mock FastAPI client.

        Writes the cluster configuration to the path named by the patched
        CONFIG environment variable, then creates and starts the application.

        Returns:
            TestClient bound to the newly created application
        """

        config = os.path.join(tempfile.gettempdir(), "testapi.yml")

        with open(config, "w", encoding="utf-8") as output:
            output.write(CLUSTER)

        # Create new application and set on client
        application.app = application.create()
        client = TestClient(application.app)
        application.start()

        return client

    @classmethod
    def setUpClass(cls):
        """
        Create API client on creation of class.

        Also starts one mock shard server per port listed in CLUSTER, each
        served from its own daemon thread, then sends an initial add + index.
        """

        cls.client = TestCluster.start()

        cls.httpd1 = HTTPServer(("127.0.0.1", 8002), RequestHandler)

        server1 = Thread(target=cls.httpd1.serve_forever, daemon=True)
        server1.start()

        cls.httpd2 = HTTPServer(("127.0.0.1", 8003), RequestHandler)

        server2 = Thread(target=cls.httpd2.serve_forever, daemon=True)
        server2.start()

        # Index data
        cls.client.post("add", json=[{"id": 0, "text": "test"}])
        cls.client.get("index")

    @classmethod
    def tearDownClass(cls):
        """
        Shutdown mock http servers.
        """

        for httpd in (cls.httpd1, cls.httpd2):
            # shutdown() only stops the serve_forever loop, server_close() releases the listening socket
            httpd.shutdown()
            httpd.server_close()

    def testCount(self):
        """
        Test cluster count
        """

        self.assertEqual(self.client.get("count").json(), 52)

    def testDelete(self):
        """
        Test cluster delete
        """

        self.assertEqual(self.client.post("delete", json=[0]).json(), [0])

    def testDeleteString(self):
        """
        Test cluster delete with string id
        """

        self.assertEqual(self.client.post("delete", json=["0"]).json(), [0])

    def testIds(self):
        """
        Test id configurations
        """

        # String ids
        self.client.post("add", json=[{"id": "0", "text": "test"}])
        self.assertEqual(self.client.get("index").status_code, 200)

        # Auto ids
        self.client.post("add", json=[{"text": "test"}])
        self.assertEqual(self.client.get("index").status_code, 200)

    def testReindex(self):
        """
        Test cluster reindex
        """

        self.assertEqual(self.client.post("reindex", json={"config": {"path": "sentence-transformers/nli-mpnet-base-v2"}}).status_code, 200)

    def testSearch(self):
        """
        Test cluster search
        """

        # Encode parameters
        params = json.dumps({"x": 1})

        query = urllib.parse.quote("feel good story")
        uid = self.client.get(f"search?query={query}&limit=1&weights=0.5&index=default&parameters={params}&graph=False").json()[0]["id"]
        self.assertEqual(uid, 4)

    def testSearchBatch(self):
        """
        Test cluster batch search
        """

        results = self.client.post(
            "batchsearch",
            json={
                "queries": ["feel good story", "climate change"],
                "limit": 1,
                "weights": 0.5,
                "index": "default",
                "parameters": [{"x": 1}, {"x": 2}],
                "graph": False,
            },
        ).json()

        uids = [result[0]["id"] for result in results]
        self.assertEqual(uids, [4, 1])

    def testSQL(self):
        """
        Test cluster SQL statement
        """

        query = urllib.parse.quote("select count(*), min(indexid), max(indexid), avg(indexid) from txtai where text='This is a test'")
        self.assertEqual(
            self.client.get(f"search?query={query}").json(), [{"count(*)": 28, "min(indexid)": 0, "max(indexid)": 14, "avg(indexid)": 6.5}]
        )

        query = urllib.parse.quote("select count(*), text txt from txtai group by txt order by count(*) desc")
        self.assertEqual(
            self.client.get(f"search?query={query}").json(),
            [{"count(*)": 28, "txt": "And another test"}, {"count(*)": 24, "txt": "This is a test"}],
        )

        query = urllib.parse.quote("select count(*), text from txtai group by text order by count(*) asc")
        self.assertEqual(
            self.client.get(f"search?query={query}").json(),
            [{"count(*)": 24, "text": "This is a test"}, {"count(*)": 28, "text": "And another test"}],
        )

        query = urllib.parse.quote("select count(*) from txtai group by id order by count(*)")
        self.assertEqual(self.client.get(f"search?query={query}").json(), [{"count(*)": 52}])

    def testUpsert(self):
        """
        Test cluster upsert
        """

        # Update data
        self.client.post("add", json=[{"id": 4, "text": "Feel good story: baby panda born"}])
        self.client.get("upsert")

        # Search for best match
        query = "feel good story"
        uid = self.client.get(f"search?query={query}&limit=1").json()[0]["id"]

        self.assertEqual(uid, 4)