base.py
"""
Application module
"""

import os

from multiprocessing.pool import ThreadPool
from threading import RLock

import yaml

from ..agent import Agent
from ..embeddings import Documents, Embeddings
from ..pipeline import PipelineFactory
from ..workflow import WorkflowFactory


# pylint: disable=R0904
class Application:
    """
    Builds YAML-configured txtai applications.
    """

    @staticmethod
    def read(data):
        """
        Reads a YAML configuration file.

        Args:
            data: input data - a file path, a YAML string or an already-parsed object

        Returns:
            parsed yaml configuration
        """

        if isinstance(data, str):
            if os.path.exists(data):
                # Read yaml from file
                with open(data, "r", encoding="utf-8") as f:
                    # Read configuration
                    return yaml.safe_load(f)

            # Attempt to read yaml from input
            data = yaml.safe_load(data)
            if not isinstance(data, str):
                return data

            # File not found and input is not yaml, raise error
            raise FileNotFoundError(f"Unable to load file '{data}'")

        # Return unmodified
        return data

    def __init__(self, config, loaddata=True):
        """
        Creates an Application instance, which encapsulates embeddings, pipelines and workflows.

        Args:
            config: index configuration
            loaddata: If True (default), load existing index data, if available. Otherwise, only load models.
        """

        # Initialize member variables
        self.config, self.documents, self.embeddings = Application.read(config), None, None

        # Write lock - allows only a single thread to update embeddings
        self.lock = RLock()

        # ThreadPool - runs scheduled workflows
        self.pool = None

        # Create pipelines
        self.createpipelines()

        # Create workflows
        self.createworkflows()

        # Create agents
        self.createagents()

        # Create embeddings index
        self.indexes(loaddata)

    def __del__(self):
        """
        Close threadpool when this object is garbage collected.
        """

        # hasattr guard - __del__ can run on a partially-constructed instance if __init__ raised early
        if hasattr(self, "pool") and self.pool:
            self.pool.close()
            self.pool = None

    def createpipelines(self):
        """
        Create pipelines.
        """

        # Pipeline definitions
        self.pipelines = {}

        # Default pipelines
        pipelines = list(PipelineFactory.list().keys())

        # Add custom pipelines
        for key in self.config:
            if "." in key:
                pipelines.append(key)

        # Move dependent pipelines to end of list
        dependent = ["similarity", "extractor", "rag", "reranker"]
        pipelines = sorted(pipelines, key=lambda x: dependent.index(x) + 1 if x in dependent else 0)

        # Create pipelines
        for pipeline in pipelines:
            if pipeline in self.config:
                config = self.config[pipeline] if self.config[pipeline] else {}

                # Add application reference, if requested
                if "application" in config:
                    config["application"] = self

                # Custom pipeline parameters
                if pipeline in ["extractor", "rag"]:
                    if "similarity" not in config:
                        # Add placeholder, will be set to embeddings index once initialized
                        config["similarity"] = None

                    # Resolve reference pipelines
                    if config.get("similarity") in self.pipelines:
                        config["similarity"] = self.pipelines[config["similarity"]]

                    if config.get("path") in self.pipelines:
                        config["path"] = self.pipelines[config["path"]]

                elif pipeline == "similarity" and "path" not in config and "labels" in self.pipelines:
                    config["model"] = self.pipelines["labels"]

                elif pipeline == "reranker":
                    # NOTE(review): requires a "similarity" pipeline to be configured, raises KeyError otherwise - confirm intended
                    config["embeddings"] = None
                    config["similarity"] = self.pipelines["similarity"]

                elif pipeline == "textractor":
                    # Default to safeopen enabled
                    config["safeopen"] = config.get("safeopen", True)

                self.pipelines[pipeline] = PipelineFactory.create(config, pipeline)

    def createworkflows(self):
        """
        Create workflows.
        """

        # Workflow definitions
        self.workflows = {}

        # Create workflows
        if "workflow" in self.config:
            for workflow, config in self.config["workflow"].items():
                # Create copy of config
                config = config.copy()

                # Resolve callable functions
                config["tasks"] = [self.resolvetask(task) for task in config["tasks"]]

                # Resolve stream functions
                if "stream" in config:
                    config["stream"] = self.resolvetask(config["stream"])

                # Get scheduler config
                schedule = config.pop("schedule", None)

                # Create workflow
                self.workflows[workflow] = WorkflowFactory.create(config, workflow)

                # Schedule job if necessary
                if schedule:
                    # Create pool if necessary
                    if not self.pool:
                        self.pool = ThreadPool()

                    self.pool.apply_async(self.workflows[workflow].schedule, kwds=schedule)

    def createagents(self):
        """
        Create agents.
        """

        # Agent definitions
        self.agents = {}

        # Create agents
        if "agent" in self.config:
            for agent, config in self.config["agent"].items():
                # Create copy of config
                config = config.copy()

                # Resolve LLM
                # NOTE(review): always resolves the pipeline/function named "llm", regardless of any llm value in agent config - confirm intended
                config["llm"] = self.function("llm")

                # Resolve tools
                for tool in config.get("tools", []):
                    if isinstance(tool, dict) and "target" in tool:
                        tool["target"] = self.function(tool["target"])

                # Create agent
                self.agents[agent] = Agent(**config)

    def indexes(self, loaddata):
        """
        Initialize an embeddings index.

        Args:
            loaddata: If True (default), load existing index data, if available. Otherwise, only load models.
        """

        # Get embeddings configuration
        config = self.config.get("embeddings")
        if config:
            # Resolve application functions in embeddings config
            config = self.resolveconfig(config.copy())

        # Load embeddings index if loaddata and index exists
        if loaddata and Embeddings().exists(self.config.get("path"), self.config.get("cloud")):
            # Initialize empty embeddings
            self.embeddings = Embeddings()

            # Pass path and cloud settings. Set application functions as config overrides.
            self.embeddings.load(
                self.config.get("path"),
                self.config.get("cloud"),
                {key: config[key] for key in ["functions", "transform"] if key in config} if config else None,
            )

        elif "embeddings" in self.config:
            # Create new embeddings with config
            self.embeddings = Embeddings(config)

        # If an extractor pipeline is defined and the similarity attribute is None, set to embeddings index
        for key in ["extractor", "rag"]:
            pipeline = self.pipelines.get(key)
            config = self.config.get(key)

            # config["similarity"] exists here because createpipelines set the placeholder on this same dict
            if pipeline and config is not None and config["similarity"] is None:
                pipeline.similarity = self.embeddings

        # Attach embeddings to reranker
        if "reranker" in self.pipelines:
            self.pipelines["reranker"].embeddings = self.embeddings

    def resolvetask(self, task):
        """
        Resolves callable functions for a task.

        Args:
            task: input task config

        Returns:
            task config with actions, initializer and finalizer resolved to callables
        """

        # Check for task shorthand syntax
        task = {"action": task} if isinstance(task, (str, list)) else task

        if "action" in task:
            action = task["action"]
            values = [action] if not isinstance(action, list) else action

            actions = []
            for a in values:
                if a in ["index", "upsert"]:
                    # Add queue action to buffer documents to index
                    actions.append(self.add)

                    # Override and disable unpacking for indexing actions
                    task["unpack"] = False

                    # Add finalize to trigger indexing
                    task["finalize"] = self.upsert if a == "upsert" else self.index
                elif a == "search":
                    actions.append(self.batchsearch)
                elif a == "transform":
                    # Transform vectors
                    actions.append(self.batchtransform)

                    # Override and disable one-to-many transformations
                    task["onetomany"] = False
                else:
                    # Resolve action to callable function
                    actions.append(self.function(a))

            # Save resolved action(s)
            task["action"] = actions[0] if not isinstance(action, list) else actions

        # Resolve initializer
        if "initialize" in task and isinstance(task["initialize"], str):
            task["initialize"] = self.function(task["initialize"])

        # Resolve finalizer
        if "finalize" in task and isinstance(task["finalize"], str):
            task["finalize"] = self.function(task["finalize"])

        return task

    def resolveconfig(self, config):
        """
        Resolves callable functions stored in embeddings configuration.

        Args:
            config: embeddings config

        Returns:
            resolved config
        """

        if "functions" in config:
            # Resolve callable functions
            functions = []
            for fn in config["functions"]:
                original = fn
                try:
                    if isinstance(fn, dict):
                        fn = fn.copy()
                        fn["function"] = self.function(fn["function"])
                    else:
                        fn = self.function(fn)

                # pylint: disable=W0703
                except Exception:
                    # Not a resolvable function, pipeline or workflow - further resolution will happen in embeddings
                    fn = original

                functions.append(fn)

            config["functions"] = functions

        if "transform" in config:
            # Resolve transform function
            config["transform"] = self.function(config["transform"])

        return config

    def function(self, function):
        """
        Get a handle to a callable function.

        Args:
            function: function name

        Returns:
            resolved function
        """

        # Check if function is a pipeline
        if function in self.pipelines:
            return self.pipelines[function]

        # Check if function is a workflow
        if function in self.workflows:
            return self.workflows[function]

        # Attempt to resolve action as a callable function
        return PipelineFactory.create({}, function)

    def search(self, query, limit=10, weights=None, index=None, parameters=None, graph=False):
        """
        Finds documents most similar to the input query. This method will run either an index search
        or an index + database search depending on if a database is available.

        Args:
            query: input query
            limit: maximum results
            weights: hybrid score weights, if applicable
            index: index name, if applicable
            parameters: dict of named parameters to bind to placeholders
            graph: return graph results if True

        Returns:
            list of {id: value, score: value} for index search, list of dict for an index + database search
        """

        if self.embeddings:
            with self.lock:
                results = self.embeddings.search(query, limit, weights, index, parameters, graph)

            # Unpack (id, score) tuple, if necessary. Otherwise, results are dictionaries.
            return results if graph else [{"id": r[0], "score": float(r[1])} if isinstance(r, tuple) else r for r in results]

        return None

    def batchsearch(self, queries, limit=10, weights=None, index=None, parameters=None, graph=False):
        """
        Finds documents most similar to the input queries. This method will run either an index search
        or an index + database search depending on if a database is available.

        Args:
            queries: input queries
            limit: maximum results
            weights: hybrid score weights, if applicable
            index: index name, if applicable
            parameters: list of dicts of named parameters to bind to placeholders
            graph: return graph results if True

        Returns:
            list of {id: value, score: value} per query for index search, list of dict per query for an index + database search
        """

        if self.embeddings:
            with self.lock:
                search = self.embeddings.batchsearch(queries, limit, weights, index, parameters, graph)

            results = []
            for result in search:
                # Unpack (id, score) tuple, if necessary. Otherwise, results are dictionaries.
                results.append(result if graph else [{"id": r[0], "score": float(r[1])} if isinstance(r, tuple) else r for r in result])
            return results

        return None

    def add(self, documents):
        """
        Adds a batch of documents for indexing.

        Args:
            documents: list of {id: value, data: value, tags: value}

        Returns:
            unmodified input documents

        Raises:
            ReadOnlyError: if index is not writable
        """

        # Raise error if index is not writable
        if not self.config.get("writable"):
            raise ReadOnlyError("Attempting to add documents to a read-only index (writable != True)")

        if self.embeddings:
            with self.lock:
                # Create documents file if not already open
                if not self.documents:
                    self.documents = Documents()

                # Add documents
                self.documents.add(list(documents))

        # Return unmodified input documents
        return documents

    def addobject(self, data, uid, field):
        """
        Helper method that builds a batch of object documents.

        Args:
            data: object content
            uid: optional list of corresponding uids
            field: optional field to set

        Returns:
            documents

        Raises:
            ReadOnlyError: if index is not writable
        """

        # Raise error if index is not writable
        if not self.config.get("writable"):
            raise ReadOnlyError("Attempting to add documents to a read-only index (writable != True)")

        documents = []
        for x, content in enumerate(data):
            if field:
                row = {"id": uid[x], field: content} if uid else {field: content}
            elif uid:
                row = (uid[x], content)
            else:
                row = content

            documents.append(row)

        return self.add(documents)

    def index(self):
        """
        Builds an embeddings index for previously batched documents.
        """

        # Raise error if index is not writable
        if not self.config.get("writable"):
            raise ReadOnlyError("Attempting to index a read-only index (writable != True)")

        if self.embeddings and self.documents:
            with self.lock:
                # Reset index
                self.indexes(False)

                # Build scoring index if term weighting is enabled
                if self.embeddings.isweighted():
                    self.embeddings.score(self.documents)

                # Build embeddings index
                self.embeddings.index(self.documents)

                # Save index if path available, otherwise this is a memory-only index
                if self.config.get("path"):
                    self.embeddings.save(self.config["path"], self.config.get("cloud"))

                # Reset document stream
                self.documents.close()
                self.documents = None

    def upsert(self):
        """
        Runs an embeddings upsert operation for previously batched documents.
        """

        # Raise error if index is not writable
        if not self.config.get("writable"):
            raise ReadOnlyError("Attempting to upsert a read-only index (writable != True)")

        if self.embeddings and self.documents:
            with self.lock:
                # Run upsert
                self.embeddings.upsert(self.documents)

                # Save index if path available, otherwise this is a memory-only index
                if self.config.get("path"):
                    self.embeddings.save(self.config["path"], self.config.get("cloud"))

                # Reset document stream
                self.documents.close()
                self.documents = None

    def delete(self, ids):
        """
        Deletes from an embeddings index. Returns list of ids deleted.

        Args:
            ids: list of ids to delete

        Returns:
            ids deleted
        """

        # Raise error if index is not writable
        if not self.config.get("writable"):
            raise ReadOnlyError("Attempting to delete from a read-only index (writable != True)")

        if self.embeddings:
            with self.lock:
                # Run delete operation
                deleted = self.embeddings.delete(ids)

                # Save index if path available, otherwise this is a memory-only index
                if self.config.get("path"):
                    self.embeddings.save(self.config["path"], self.config.get("cloud"))

                # Return deleted ids
                return deleted

        return None

    def reindex(self, config, function=None):
        """
        Recreates embeddings index using config. This method only works if document content storage is enabled.

        Args:
            config: new config
            function: optional function to prepare content for indexing
        """

        # Raise error if index is not writable
        if not self.config.get("writable"):
            raise ReadOnlyError("Attempting to reindex a read-only index (writable != True)")

        if self.embeddings:
            with self.lock:
                # Resolve function, if necessary
                function = self.function(function) if function and isinstance(function, str) else function

                # Reindex
                self.embeddings.reindex(config, function)

                # Save index if path available, otherwise this is a memory-only index
                if self.config.get("path"):
                    self.embeddings.save(self.config["path"], self.config.get("cloud"))

    def count(self):
        """
        Total number of elements in this embeddings index.

        Returns:
            number of elements in embeddings index
        """

        if self.embeddings:
            return self.embeddings.count()

        return None

    def similarity(self, query, texts):
        """
        Computes the similarity between query and list of text. Returns a list of
        {id: value, score: value} sorted by highest score, where id is the index
        in texts.

        Args:
            query: query text
            texts: list of text

        Returns:
            list of {id: value, score: value}
        """

        # Use similarity instance if available otherwise fall back to embeddings model
        if "similarity" in self.pipelines:
            return [{"id": uid, "score": float(score)} for uid, score in self.pipelines["similarity"](query, texts)]
        if self.embeddings:
            return [{"id": uid, "score": float(score)} for uid, score in self.embeddings.similarity(query, texts)]

        return None

    def batchsimilarity(self, queries, texts):
        """
        Computes the similarity between list of queries and list of text. Returns a list
        of {id: value, score: value} sorted by highest score per query, where id is the
        index in texts.

        Args:
            queries: list of queries
            texts: list of text

        Returns:
            list of {id: value, score: value} per query
        """

        # Use similarity instance if available otherwise fall back to embeddings model
        if "similarity" in self.pipelines:
            return [[{"id": uid, "score": float(score)} for uid, score in r] for r in self.pipelines["similarity"](queries, texts)]
        if self.embeddings:
            return [[{"id": uid, "score": float(score)} for uid, score in r] for r in self.embeddings.batchsimilarity(queries, texts)]

        return None

    def explain(self, query, texts=None, limit=10):
        """
        Explains the importance of each input token in text for a query.

        Args:
            query: query text
            texts: optional list of text, otherwise runs search query
            limit: optional limit if texts is None

        Returns:
            list of dict per input text where a higher token scores represents higher importance relative to the query
        """

        if self.embeddings:
            with self.lock:
                return self.embeddings.explain(query, texts, limit)

        return None

    def batchexplain(self, queries, texts=None, limit=10):
        """
        Explains the importance of each input token in text for a list of queries.

        Args:
            queries: list of queries
            texts: optional list of text, otherwise runs search queries
            limit: optional limit if texts is None

        Returns:
            list of dict per input text per query where a higher token scores represents higher importance relative to the query
        """

        if self.embeddings:
            with self.lock:
                return self.embeddings.batchexplain(queries, texts, limit)

        return None

    def transform(self, text, category=None, index=None):
        """
        Transforms text into embeddings arrays.

        Args:
            text: input text
            category: category for instruction-based embeddings
            index: index name, if applicable

        Returns:
            embeddings array
        """

        if self.embeddings:
            return [float(x) for x in self.embeddings.transform(text, category, index)]

        return None

    def batchtransform(self, texts, category=None, index=None):
        """
        Transforms list of text into embeddings arrays.

        Args:
            texts: list of text
            category: category for instruction-based embeddings
            index: index name, if applicable

        Returns:
            embeddings arrays
        """

        if self.embeddings:
            return [[float(x) for x in result] for result in self.embeddings.batchtransform(texts, category, index)]

        return None

    def extract(self, queue, texts=None):
        """
        Extracts answers to input questions.

        Args:
            queue: list of {name: value, query: value, question: value, snippet: value}
            texts: optional list of text

        Returns:
            list of {name: value, answer: value}
        """

        if self.embeddings and "extractor" in self.pipelines:
            # Get extractor instance
            extractor = self.pipelines["extractor"]

            # Run extractor and return results as dicts
            return extractor(queue, texts)

        return None

    def label(self, text, labels):
        """
        Applies a zero shot classifier to text using a list of labels. Returns a list of
        {id: value, score: value} sorted by highest score, where id is the index in labels.

        Args:
            text: text|list
            labels: list of labels

        Returns:
            list of {id: value, score: value} per text element
        """

        if "labels" in self.pipelines:
            # Text is a string
            if isinstance(text, str):
                return [{"id": uid, "score": float(score)} for uid, score in self.pipelines["labels"](text, labels)]

            # Text is a list
            return [[{"id": uid, "score": float(score)} for uid, score in result] for result in self.pipelines["labels"](text, labels)]

        return None

    def pipeline(self, name, *args, **kwargs):
        """
        Generic pipeline execution method.

        Args:
            name: pipeline name
            args: pipeline positional arguments
            kwargs: pipeline keyword arguments

        Returns:
            pipeline results
        """

        # Backwards compatible with previous pipeline function arguments
        args = args[0] if args and len(args) == 1 and isinstance(args[0], tuple) else args

        if name in self.pipelines:
            return self.pipelines[name](*args, **kwargs)

        return None

    def workflow(self, name, elements):
        """
        Executes a workflow.

        Args:
            name: workflow name
            elements: elements to process

        Returns:
            processed elements
        """

        if hasattr(elements, "__len__") and hasattr(elements, "__getitem__"):
            # Convert to tuples and return as a list since input is sized
            elements = [tuple(element) if isinstance(element, list) else element for element in elements]
        else:
            # Convert to tuples and return as a generator since input is not sized
            elements = (tuple(element) if isinstance(element, list) else element for element in elements)

        # Execute workflow
        return self.workflows[name](elements)

    def agent(self, name, *args, **kwargs):
        """
        Executes an agent.

        Args:
            name: agent name
            args: agent positional arguments
            kwargs: agent keyword arguments

        Returns:
            agent results, None if agent not found
        """

        if name in self.agents:
            return self.agents[name](*args, **kwargs)

        return None

    def wait(self):
        """
        Closes threadpool and waits for completion.
        """

        if self.pool:
            self.pool.close()
            self.pool.join()
            self.pool = None


class ReadOnlyError(Exception):
    """
    Error raised when trying to modify a read-only index
    """