/ src / python / txtai / api / cluster.py
cluster.py
  1  """
  2  Cluster module
  3  """
  4  
  5  import asyncio
  6  import json
  7  import random
  8  import urllib.parse
  9  import zlib
 10  
 11  import aiohttp
 12  
 13  from ..database.sql import Aggregate
 14  
 15  
 16  class Cluster:
 17      """
 18      Aggregates multiple embeddings shards into a single logical embeddings instance.
 19      """
 20  
 21      # pylint: disable=W0231
 22      def __init__(self, config=None):
 23          """
 24          Creates a new Cluster.
 25  
 26          Args:
 27              config: cluster configuration
 28          """
 29  
 30          # Configuration
 31          self.config = config
 32  
 33          # Embeddings shard urls
 34          self.shards = None
 35          if "shards" in self.config:
 36              self.shards = self.config["shards"]
 37  
 38          # Query aggregator
 39          self.aggregate = Aggregate()
 40  
 41      def search(self, query, limit=None, weights=None, index=None, parameters=None, graph=False):
 42          """
 43          Finds documents most similar to the input query. This method will run either an index search
 44          or an index + database search depending on if a database is available.
 45  
 46          Args:
 47              query: input query
 48              limit: maximum results
 49              weights: hybrid score weights, if applicable
 50              index: index name, if applicable
 51              parameters: dict of named parameters to bind to placeholders
 52              graph: return graph results if True
 53  
 54          Returns:
 55              list of {id: value, score: value} for index search, list of dict for an index + database search
 56          """
 57  
 58          # Build URL
 59          action = f"search?query={urllib.parse.quote_plus(query)}"
 60          if limit:
 61              action += f"&limit={limit}"
 62          if weights:
 63              action += f"&weights={weights}"
 64          if index:
 65              action += f"&index={index}"
 66          if parameters:
 67              action += f"&parameters={json.dumps(parameters) if isinstance(parameters, dict) else parameters}"
 68          if graph is not None:
 69              action += f"&graph={graph}"
 70  
 71          # Run query and flatten results into single results list
 72          results = []
 73          for result in self.execute("get", action):
 74              results.extend(result)
 75  
 76          # Combine aggregate functions and sort
 77          results = self.aggregate(query, results)
 78  
 79          # Limit results
 80          return results[: (limit if limit else 10)]
 81  
 82      def batchsearch(self, queries, limit=None, weights=None, index=None, parameters=None, graph=False):
 83          """
 84          Finds documents most similar to the input queries. This method will run either an index search
 85          or an index + database search depending on if a database is available.
 86  
 87          Args:
 88              queries: input queries
 89              limit: maximum results
 90              weights: hybrid score weights, if applicable
 91              index: index name, if applicable
 92              parameters: list of dicts of named parameters to bind to placeholders
 93              graph: return graph results if True
 94  
 95          Returns:
 96              list of {id: value, score: value} per query for index search, list of dict per query for an index + database search
 97          """
 98  
 99          # POST parameters
100          params = {"queries": queries}
101          if limit:
102              params["limit"] = limit
103          if weights:
104              params["weights"] = weights
105          if index:
106              params["index"] = index
107          if parameters:
108              params["parameters"] = parameters
109          if graph is not None:
110              params["graph"] = graph
111  
112          # Run query
113          batch = self.execute("post", "batchsearch", [params] * len(self.shards))
114  
115          # Combine results per query
116          results = []
117          for x, query in enumerate(queries):
118              result = []
119              for section in batch:
120                  result.extend(section[x])
121  
122              # Aggregate, sort and limit results
123              results.append(self.aggregate(query, result)[: (limit if limit else 10)])
124  
125          return results
126  
127      def add(self, documents):
128          """
129          Adds a batch of documents for indexing.
130  
131          Args:
132              documents: list of {id: value, text: value}
133          """
134  
135          self.execute("post", "add", self.shard(documents))
136  
137      def index(self):
138          """
139          Builds an embeddings index for previously batched documents.
140          """
141  
142          self.execute("get", "index")
143  
144      def upsert(self):
145          """
146          Runs an embeddings upsert operation for previously batched documents.
147          """
148  
149          self.execute("get", "upsert")
150  
151      def delete(self, ids):
152          """
153          Deletes from an embeddings cluster. Returns list of ids deleted.
154  
155          Args:
156              ids: list of ids to delete
157  
158          Returns:
159              ids deleted
160          """
161  
162          return [uid for ids in self.execute("post", "delete", [ids] * len(self.shards)) for uid in ids]
163  
164      def reindex(self, config, function=None):
165          """
166          Recreates this embeddings index using config. This method only works if document content storage is enabled.
167  
168          Args:
169              config: new config
170              function: optional function to prepare content for indexing
171          """
172  
173          self.execute("post", "reindex", [{"config": config, "function": function}] * len(self.shards))
174  
175      def count(self):
176          """
177          Total number of elements in this embeddings cluster.
178  
179          Returns:
180              number of elements in embeddings cluster
181          """
182  
183          return sum(self.execute("get", "count"))
184  
185      def shard(self, documents):
186          """
187          Splits documents into equal sized shards.
188  
189          Args:
190              documents: input documents
191  
192          Returns:
193              list of evenly sized shards with the last shard having the remaining elements
194          """
195  
196          shards = [[] for _ in range(len(self.shards))]
197          for document in documents:
198              uid = document.get("id") if isinstance(document, dict) else document
199              if uid and isinstance(uid, str):
200                  # Quick int hash of string to help derive shard id
201                  uid = zlib.adler32(uid.encode("utf-8"))
202              elif uid is None:
203                  # Get random shard id when uid isn't set
204                  uid = random.randint(0, len(shards) - 1)
205  
206              shards[uid % len(self.shards)].append(document)
207  
208          return shards
209  
210      def execute(self, method, action, data=None):
211          """
212          Executes a HTTP action asynchronously.
213  
214          Args:
215              method: get or post
216              action: url action to perform
217              data: post parameters
218  
219          Returns:
220              json results if any
221          """
222  
223          # Get urls
224          urls = [f"{shard}/{action}" for shard in self.shards]
225          close = False
226  
227          # Use existing loop if available, otherwise create one
228          try:
229              loop = asyncio.get_event_loop()
230          except RuntimeError:
231              loop = asyncio.new_event_loop()
232              close = True
233  
234          try:
235              return loop.run_until_complete(self.run(urls, method, data))
236          finally:
237              # Close loop if it was created in this method
238              if close:
239                  loop.close()
240  
241      async def run(self, urls, method, data):
242          """
243          Runs an async action.
244  
245          Args:
246              urls: run against this list of urls
247              method: get or post
248              data: list of data for each url or None
249  
250          Returns:
251              json results if any
252          """
253  
254          async with aiohttp.ClientSession(raise_for_status=True) as session:
255              tasks = []
256  
257              for x, url in enumerate(urls):
258                  if method == "post":
259                      if not data or data[x]:
260                          tasks.append(asyncio.ensure_future(self.post(session, url, data[x] if data else None)))
261                  else:
262                      tasks.append(asyncio.ensure_future(self.get(session, url)))
263  
264              return await asyncio.gather(*tasks)
265  
266      async def get(self, session, url):
267          """
268          Runs an async HTTP GET request.
269  
270          Args:
271              session: ClientSession
272              url: url
273  
274          Returns:
275              json results if any
276          """
277  
278          async with session.get(url) as resp:
279              return await resp.json()
280  
281      async def post(self, session, url, data):
282          """
283          Runs an async HTTP POST request.
284  
285          Args:
286              session: ClientSession
287              url: url
288              data: data to POST
289  
290          Returns:
291              json results if any
292          """
293  
294          async with session.post(url, json=data) as resp:
295              return await resp.json()