/ test / python / testapi / testcluster.py
testcluster.py
  1  """
  2  Cluster API module tests
  3  """
  4  
  5  import json
  6  import os
  7  import tempfile
  8  import unittest
  9  import urllib.parse
 10  
 11  from http.server import HTTPServer, BaseHTTPRequestHandler
 12  from threading import Thread
 13  from unittest.mock import patch
 14  
 15  from fastapi.testclient import TestClient
 16  
 17  from txtai.api import application
 18  
# YAML configuration for an embeddings cluster with two shards. The shard urls
# point at the ports the mock HTTP servers bind to in TestCluster.setUpClass.
CLUSTER = """
cluster:
    shards:
        - http://127.0.0.1:8002
        - http://127.0.0.1:8003
"""
 26  
 27  
 28  class RequestHandler(BaseHTTPRequestHandler):
 29      """
 30      Test HTTP handler.
 31      """
 32  
 33      def do_GET(self):
 34          """
 35          GET request handler.
 36          """
 37  
 38          if self.path == "/count":
 39              response = 26
 40          elif self.path.startswith("/search?query=select"):
 41              if "group+by+id" in self.path:
 42                  response = [{"count(*)": 26}]
 43              elif "group+by+text" in self.path:
 44                  response = [{"count(*)": 12, "text": "This is a test"}, {"count(*)": 14, "text": "And another test"}]
 45              elif "group+by+txt" in self.path:
 46                  response = [{"count(*)": 12, "txt": "This is a test"}, {"count(*)": 14, "txt": "And another test"}]
 47              else:
 48                  if self.server.server_port == 8002:
 49                      response = [{"count(*)": 12, "min(indexid)": 0, "max(indexid)": 11, "avg(indexid)": 6.3}]
 50                  else:
 51                      response = [{"count(*)": 16, "min(indexid)": 2, "max(indexid)": 14, "avg(indexid)": 6.7}]
 52          elif self.path.startswith("/search"):
 53              response = [{"id": 4, "score": 0.40}]
 54          else:
 55              response = {"result": "ok"}
 56  
 57          # Convert response to string
 58          response = json.dumps(response).encode("utf-8")
 59  
 60          self.send_response(200)
 61          self.send_header("content-type", "application/json")
 62          self.send_header("content-length", len(response))
 63          self.end_headers()
 64  
 65          self.wfile.write(response)
 66          self.wfile.flush()
 67  
 68      def do_POST(self):
 69          """
 70          POST request handler.
 71          """
 72  
 73          if self.path.startswith("/batchsearch"):
 74              response = [[{"id": 4, "score": 0.40}], [{"id": 1, "score": 0.40}]]
 75          elif self.path.startswith("/delete"):
 76              if self.server.server_port == 8002:
 77                  response = [0]
 78              else:
 79                  response = []
 80          else:
 81              response = {"result": "ok"}
 82  
 83          response = json.dumps(response).encode("utf-8")
 84  
 85          self.send_response(200)
 86          self.send_header("content-type", "application/json")
 87          self.send_header("content-length", len(response))
 88          self.end_headers()
 89  
 90          self.wfile.write(response)
 91          self.wfile.flush()
 92  
 93  
 94  @unittest.skipIf(os.name == "nt", "TestCluster skipped on Windows")
 95  class TestCluster(unittest.TestCase):
 96      """
 97      API tests for embeddings clusters
 98      """
 99  
100      @staticmethod
101      @patch.dict(os.environ, {"CONFIG": os.path.join(tempfile.gettempdir(), "testapi.yml"), "API_CLASS": "txtai.api.API"})
102      def start():
103          """
104          Starts a mock FastAPI client.
105          """
106  
107          config = os.path.join(tempfile.gettempdir(), "testapi.yml")
108  
109          with open(config, "w", encoding="utf-8") as output:
110              output.write(CLUSTER)
111  
112          # Create new application and set on client
113          application.app = application.create()
114          client = TestClient(application.app)
115          application.start()
116  
117          return client
118  
119      @classmethod
120      def setUpClass(cls):
121          """
122          Create API client on creation of class.
123          """
124  
125          cls.client = TestCluster.start()
126  
127          cls.httpd1 = HTTPServer(("127.0.0.1", 8002), RequestHandler)
128  
129          server1 = Thread(target=cls.httpd1.serve_forever, daemon=True)
130          server1.start()
131  
132          cls.httpd2 = HTTPServer(("127.0.0.1", 8003), RequestHandler)
133  
134          server2 = Thread(target=cls.httpd2.serve_forever, daemon=True)
135          server2.start()
136  
137          # Index data
138          cls.client.post("add", json=[{"id": 0, "text": "test"}])
139          cls.client.get("index")
140  
141      @classmethod
142      def tearDownClass(cls):
143          """
144          Shutdown mock http server.
145          """
146  
147          cls.httpd1.shutdown()
148          cls.httpd2.shutdown()
149  
150      def testCount(self):
151          """
152          Test cluster count
153          """
154  
155          self.assertEqual(self.client.get("count").json(), 52)
156  
157      def testDelete(self):
158          """
159          Test cluster delete
160          """
161  
162          self.assertEqual(self.client.post("delete", json=[0]).json(), [0])
163  
164      def testDeleteString(self):
165          """
166          Test cluster delete with string id
167          """
168  
169          self.assertEqual(self.client.post("delete", json=["0"]).json(), [0])
170  
171      def testIds(self):
172          """
173          Test id configurations
174          """
175  
176          # String ids
177          self.client.post("add", json=[{"id": "0", "text": "test"}])
178          self.assertEqual(self.client.get("index").status_code, 200)
179  
180          # Auto ids
181          self.client.post("add", json=[{"text": "test"}])
182          self.assertEqual(self.client.get("index").status_code, 200)
183  
184      def testReindex(self):
185          """
186          Test cluster reindex
187          """
188  
189          self.assertEqual(self.client.post("reindex", json={"config": {"path": "sentence-transformers/nli-mpnet-base-v2"}}).status_code, 200)
190  
191      def testSearch(self):
192          """
193          Test cluster search
194          """
195  
196          # Encode parameters
197          params = json.dumps({"x": 1})
198  
199          query = urllib.parse.quote("feel good story")
200          uid = self.client.get(f"search?query={query}&limit=1&weights=0.5&index=default&parameters={params}&graph=False").json()[0]["id"]
201          self.assertEqual(uid, 4)
202  
203      def testSearchBatch(self):
204          """
205          Test cluster batch search
206          """
207  
208          results = self.client.post(
209              "batchsearch",
210              json={
211                  "queries": ["feel good story", "climate change"],
212                  "limit": 1,
213                  "weights": 0.5,
214                  "index": "default",
215                  "parameters": [{"x": 1}, {"x": 2}],
216                  "graph": False,
217              },
218          ).json()
219  
220          uids = [result[0]["id"] for result in results]
221          self.assertEqual(uids, [4, 1])
222  
223      def testSQL(self):
224          """
225          Test cluster SQL statement
226          """
227  
228          query = urllib.parse.quote("select count(*), min(indexid), max(indexid), avg(indexid) from txtai where text='This is a test'")
229          self.assertEqual(
230              self.client.get(f"search?query={query}").json(), [{"count(*)": 28, "min(indexid)": 0, "max(indexid)": 14, "avg(indexid)": 6.5}]
231          )
232  
233          query = urllib.parse.quote("select count(*), text txt from txtai group by txt order by count(*) desc")
234          self.assertEqual(
235              self.client.get(f"search?query={query}").json(),
236              [{"count(*)": 28, "txt": "And another test"}, {"count(*)": 24, "txt": "This is a test"}],
237          )
238  
239          query = urllib.parse.quote("select count(*), text from txtai group by text order by count(*) asc")
240          self.assertEqual(
241              self.client.get(f"search?query={query}").json(),
242              [{"count(*)": 24, "text": "This is a test"}, {"count(*)": 28, "text": "And another test"}],
243          )
244  
245          query = urllib.parse.quote("select count(*) from txtai group by id order by count(*)")
246          self.assertEqual(self.client.get(f"search?query={query}").json(), [{"count(*)": 52}])
247  
248      def testUpsert(self):
249          """
250          Test cluster upsert
251          """
252  
253          # Update data
254          self.client.post("add", json=[{"id": 4, "text": "Feel good story: baby panda born"}])
255          self.client.get("upsert")
256  
257          # Search for best match
258          query = "feel good story"
259          uid = self.client.get(f"search?query={query}&limit=1").json()[0]["id"]
260  
261          self.assertEqual(uid, 4)