# cluster.py
"""
Cluster module
"""

import asyncio
import json
import random
import urllib.parse
import zlib

import aiohttp

from ..database.sql import Aggregate


class Cluster:
    """
    Aggregates multiple embeddings shards into a single logical embeddings instance.
    """

    # pylint: disable=W0231
    def __init__(self, config=None):
        """
        Creates a new Cluster.

        Args:
            config: cluster configuration
        """

        # Configuration
        self.config = config

        # Embeddings shard urls. Guard against config=None (the documented default),
        # which previously raised a TypeError on the membership test.
        self.shards = None
        if self.config and "shards" in self.config:
            self.shards = self.config["shards"]

        # Query aggregator
        self.aggregate = Aggregate()

    def search(self, query, limit=None, weights=None, index=None, parameters=None, graph=False):
        """
        Finds documents most similar to the input query. This method will run either an index search
        or an index + database search depending on if a database is available.

        Args:
            query: input query
            limit: maximum results
            weights: hybrid score weights, if applicable
            index: index name, if applicable
            parameters: dict of named parameters to bind to placeholders
            graph: return graph results if True

        Returns:
            list of {id: value, score: value} for index search, list of dict for an index + database search
        """

        # Build URL
        action = f"search?query={urllib.parse.quote_plus(query)}"
        if limit:
            action += f"&limit={limit}"
        if weights:
            action += f"&weights={weights}"
        if index:
            action += f"&index={index}"
        if parameters:
            # Fix: previous code emitted a literal pilcrow ("¶meters") where "&parameters"
            # belonged ("&para" was mangled into "¶"), so the argument never reached shards
            action += f"&parameters={json.dumps(parameters) if isinstance(parameters, dict) else parameters}"
        if graph:
            # Only pass graph when enabled. The previous "graph is not None" check was
            # always true (default is False) and appended "&graph=False", which is a
            # truthy string when parsed server-side
            action += f"&graph={graph}"

        # Run query against every shard and flatten results into a single list
        results = []
        for result in self.execute("get", action):
            results.extend(result)

        # Combine aggregate functions and sort
        results = self.aggregate(query, results)

        # Limit results, defaulting to 10 when no limit given
        return results[: (limit if limit else 10)]

    def batchsearch(self, queries, limit=None, weights=None, index=None, parameters=None, graph=False):
        """
        Finds documents most similar to the input queries. This method will run either an index search
        or an index + database search depending on if a database is available.

        Args:
            queries: input queries
            limit: maximum results
            weights: hybrid score weights, if applicable
            index: index name, if applicable
            parameters: list of dicts of named parameters to bind to placeholders
            graph: return graph results if True

        Returns:
            list of {id: value, score: value} per query for index search, list of dict per query for an index + database search
        """

        # POST parameters
        params = {"queries": queries}
        if limit:
            params["limit"] = limit
        if weights:
            params["weights"] = weights
        if index:
            params["index"] = index
        if parameters:
            params["parameters"] = parameters
        if graph:
            # Only pass graph when enabled, matching search() - the prior
            # "graph is not None" check always fired since the default is False
            params["graph"] = graph

        # Run the same payload against every shard
        batch = self.execute("post", "batchsearch", [params] * len(self.shards))

        # Combine results per query
        results = []
        for x, query in enumerate(queries):
            result = []
            for section in batch:
                result.extend(section[x])

            # Aggregate, sort and limit results (default limit is 10)
            results.append(self.aggregate(query, result)[: (limit if limit else 10)])

        return results

    def add(self, documents):
        """
        Adds a batch of documents for indexing.

        Args:
            documents: list of {id: value, text: value}
        """

        # Partition documents across shards, then POST each partition to its shard
        self.execute("post", "add", self.shard(documents))

    def index(self):
        """
        Builds an embeddings index for previously batched documents.
        """

        self.execute("get", "index")

    def upsert(self):
        """
        Runs an embeddings upsert operation for previously batched documents.
        """

        self.execute("get", "upsert")

    def delete(self, ids):
        """
        Deletes from an embeddings cluster. Returns list of ids deleted.

        Args:
            ids: list of ids to delete

        Returns:
            ids deleted
        """

        # Flatten the per-shard responses into a single list. The loop variable is
        # renamed so it no longer shadows the ids parameter.
        return [uid for result in self.execute("post", "delete", [ids] * len(self.shards)) for uid in result]

    def reindex(self, config, function=None):
        """
        Recreates this embeddings index using config. This method only works if document content storage is enabled.

        Args:
            config: new config
            function: optional function to prepare content for indexing
        """

        self.execute("post", "reindex", [{"config": config, "function": function}] * len(self.shards))

    def count(self):
        """
        Total number of elements in this embeddings cluster.

        Returns:
            number of elements in embeddings cluster
        """

        # Sum per-shard counts
        return sum(self.execute("get", "count"))

    def shard(self, documents):
        """
        Splits documents into equal sized shards.

        Args:
            documents: input documents

        Returns:
            list of evenly sized shards with the last shard having the remaining elements
        """

        shards = [[] for _ in range(len(self.shards))]
        for document in documents:
            uid = document.get("id") if isinstance(document, dict) else document
            if isinstance(uid, str):
                # Quick int hash of string to help derive shard id. The previous
                # truthiness check let an empty-string id fall through and raise a
                # TypeError on the modulo below; adler32 of b"" is well-defined.
                uid = zlib.adler32(uid.encode("utf-8"))
            elif uid is None:
                # Get random shard id when uid isn't set
                uid = random.randint(0, len(shards) - 1)

            # NOTE(review): assumes uid is an int at this point - a non-dict, non-str
            # document (e.g. a tuple) would fail here; confirm expected document types
            shards[uid % len(self.shards)].append(document)

        return shards

    def execute(self, method, action, data=None):
        """
        Executes a HTTP action asynchronously against all shards.

        Args:
            method: get or post
            action: url action to perform
            data: post parameters

        Returns:
            json results if any
        """

        # Get urls
        urls = [f"{shard}/{action}" for shard in self.shards]
        close = False

        # Use existing loop if available, otherwise create one.
        # NOTE(review): asyncio.get_event_loop is deprecated outside a running loop in
        # newer Python versions; the RuntimeError fallback covers that path, but
        # consider asyncio.run if no caller depends on reusing a loop
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            close = True

        try:
            return loop.run_until_complete(self.run(urls, method, data))
        finally:
            # Close loop if it was created in this method
            if close:
                loop.close()

    async def run(self, urls, method, data):
        """
        Runs an async action.

        Args:
            urls: run against this list of urls
            method: get or post
            data: list of data for each url or None

        Returns:
            json results if any
        """

        # raise_for_status surfaces HTTP errors from any shard as exceptions
        async with aiohttp.ClientSession(raise_for_status=True) as session:
            tasks = []

            for x, url in enumerate(urls):
                if method == "post":
                    # Skip shards with an empty data partition (e.g. no documents
                    # were routed to that shard)
                    if not data or data[x]:
                        tasks.append(asyncio.ensure_future(self.post(session, url, data[x] if data else None)))
                else:
                    tasks.append(asyncio.ensure_future(self.get(session, url)))

            return await asyncio.gather(*tasks)

    async def get(self, session, url):
        """
        Runs an async HTTP GET request.

        Args:
            session: ClientSession
            url: url

        Returns:
            json results if any
        """

        async with session.get(url) as resp:
            return await resp.json()

    async def post(self, session, url, data):
        """
        Runs an async HTTP POST request.

        Args:
            session: ClientSession
            url: url
            data: data to POST

        Returns:
            json results if any
        """

        async with session.post(url, json=data) as resp:
            return await resp.json()