/ src / python / txtai / app / base.py
base.py
  1  """
  2  Application module
  3  """
  4  
  5  import os
  6  
  7  from multiprocessing.pool import ThreadPool
  8  from threading import RLock
  9  
 10  import yaml
 11  
 12  from ..agent import Agent
 13  from ..embeddings import Documents, Embeddings
 14  from ..pipeline import PipelineFactory
 15  from ..workflow import WorkflowFactory
 16  
 17  
 18  # pylint: disable=R0904
 19  class Application:
 20      """
 21      Builds YAML-configured txtai applications.
 22      """
 23  
 24      @staticmethod
 25      def read(data):
 26          """
 27          Reads a YAML configuration file.
 28  
 29          Args:
 30              data: input data
 31  
 32          Returns:
 33              yaml
 34          """
 35  
 36          if isinstance(data, str):
 37              if os.path.exists(data):
 38                  # Read yaml from file
 39                  with open(data, "r", encoding="utf-8") as f:
 40                      # Read configuration
 41                      return yaml.safe_load(f)
 42  
 43              # Attempt to read yaml from input
 44              data = yaml.safe_load(data)
 45              if not isinstance(data, str):
 46                  return data
 47  
 48              # File not found and input is not yaml, raise error
 49              raise FileNotFoundError(f"Unable to load file '{data}'")
 50  
 51          # Return unmodified
 52          return data
 53  
 54      def __init__(self, config, loaddata=True):
 55          """
 56          Creates an Application instance, which encapsulates embeddings, pipelines and workflows.
 57  
 58          Args:
 59              config: index configuration
 60              loaddata: If True (default), load existing index data, if available. Otherwise, only load models.
 61          """
 62  
 63          # Initialize member variables
 64          self.config, self.documents, self.embeddings = Application.read(config), None, None
 65  
 66          # Write lock - allows only a single thread to update embeddings
 67          self.lock = RLock()
 68  
 69          # ThreadPool - runs scheduled workflows
 70          self.pool = None
 71  
 72          # Create pipelines
 73          self.createpipelines()
 74  
 75          # Create workflows
 76          self.createworkflows()
 77  
 78          # Create agents
 79          self.createagents()
 80  
 81          # Create embeddings index
 82          self.indexes(loaddata)
 83  
 84      def __del__(self):
 85          """
 86          Close threadpool when this object is garbage collected.
 87          """
 88  
 89          if hasattr(self, "pool") and self.pool:
 90              self.pool.close()
 91              self.pool = None
 92  
    def createpipelines(self):
        """
        Create pipelines defined in the application configuration.

        Default pipeline names come from PipelineFactory; any top-level config key containing
        a "." is treated as a custom pipeline class path. Pipelines that consume other pipelines
        (similarity, extractor, rag, reranker) are created last so their dependencies already
        exist in self.pipelines when they are resolved.
        """

        # Pipeline definitions
        self.pipelines = {}

        # Default pipelines
        pipelines = list(PipelineFactory.list().keys())

        # Add custom pipelines
        for key in self.config:
            if "." in key:
                pipelines.append(key)

        # Move dependent pipelines to end of list
        dependent = ["similarity", "extractor", "rag", "reranker"]
        # sorted is stable, so dependents keep their relative order: similarity < extractor < rag < reranker
        pipelines = sorted(pipelines, key=lambda x: dependent.index(x) + 1 if x in dependent else 0)

        # Create pipelines
        for pipeline in pipelines:
            if pipeline in self.config:
                # NOTE: when the config section is truthy, config aliases self.config[pipeline],
                # so mutations below (e.g. the similarity placeholder) persist into self.config
                # and are later read by indexes()
                config = self.config[pipeline] if self.config[pipeline] else {}

                # Add application reference, if requested
                if "application" in config:
                    config["application"] = self

                # Custom pipeline parameters
                if pipeline in ["extractor", "rag"]:
                    if "similarity" not in config:
                        # Add placeholder, will be set to embeddings index once initialized
                        config["similarity"] = None

                    # Resolve reference pipelines
                    if config.get("similarity") in self.pipelines:
                        config["similarity"] = self.pipelines[config["similarity"]]

                    if config.get("path") in self.pipelines:
                        config["path"] = self.pipelines[config["path"]]

                elif pipeline == "similarity" and "path" not in config and "labels" in self.pipelines:
                    # Reuse the labels pipeline model when similarity has no explicit path
                    config["model"] = self.pipelines["labels"]

                elif pipeline == "reranker":
                    # Embeddings attached later in indexes()
                    # NOTE(review): assumes a similarity pipeline is configured - raises KeyError otherwise
                    config["embeddings"] = None
                    config["similarity"] = self.pipelines["similarity"]

                elif pipeline == "textractor":
                    # Default to safeopen enabled
                    config["safeopen"] = config.get("safeopen", True)

                self.pipelines[pipeline] = PipelineFactory.create(config, pipeline)
147  
    def createworkflows(self):
        """
        Create workflows defined in the application configuration.

        Each workflow config is copied, task actions and stream functions are resolved to
        callables and the workflow is registered in self.workflows. Workflows with a
        "schedule" section are additionally started on a background thread pool.
        """

        # Workflow definitions
        self.workflows = {}

        # Create workflows
        if "workflow" in self.config:
            for workflow, config in self.config["workflow"].items():
                # Create copy of config - resolution below replaces entries with callables
                config = config.copy()

                # Resolve callable functions
                config["tasks"] = [self.resolvetask(task) for task in config["tasks"]]

                # Resolve stream functions
                if "stream" in config:
                    config["stream"] = self.resolvetask(config["stream"])

                # Get scheduler config - popped so it isn't passed to the workflow itself
                schedule = config.pop("schedule", None)

                # Create workflow
                self.workflows[workflow] = WorkflowFactory.create(config, workflow)

                # Schedule job if necessary
                if schedule:
                    # Create pool if necessary - lazily created on the first scheduled workflow
                    if not self.pool:
                        self.pool = ThreadPool()

                    # Run the scheduler in the background, one pool task per scheduled workflow
                    self.pool.apply_async(self.workflows[workflow].schedule, kwds=schedule)
182  
    def createagents(self):
        """
        Create agents defined in the application configuration.

        Each agent config is copied, the LLM and dict tool targets are resolved to application
        functions (pipelines/workflows) and the agent is registered in self.agents.
        """

        # Agent definitions
        self.agents = {}

        # Create agents
        if "agent" in self.config:
            for agent, config in self.config["agent"].items():
                # Create copy of config
                config = config.copy()

                # Resolve LLM
                # NOTE(review): resolves the application-level "llm" pipeline and overwrites any
                # "llm" value already present in the agent config - confirm this is intended
                config["llm"] = self.function("llm")

                # Resolve tools
                # NOTE: config.copy() is shallow, so these tool dicts are mutated in place
                # within self.config as well
                for tool in config.get("tools", []):
                    if isinstance(tool, dict) and "target" in tool:
                        tool["target"] = self.function(tool["target"])

                # Create agent
                self.agents[agent] = Agent(**config)
207  
    def indexes(self, loaddata):
        """
        Initialize an embeddings index.

        Loads an existing index from path/cloud storage when loaddata is True and one exists,
        otherwise creates a new (empty) embeddings instance from the "embeddings" config section.
        Also wires the resulting index into dependent pipelines (extractor/rag similarity and
        reranker embeddings).

        Args:
            loaddata: If True (default), load existing index data, if available. Otherwise, only load models.
        """

        # Get embeddings configuration
        config = self.config.get("embeddings")
        if config:
            # Resolve application functions in embeddings config - resolveconfig works on a
            # copy so self.config keeps the original (string) references
            config = self.resolveconfig(config.copy())

        # Load embeddings index if loaddata and index exists
        if loaddata and Embeddings().exists(self.config.get("path"), self.config.get("cloud")):
            # Initialize empty embeddings
            self.embeddings = Embeddings()

            # Pass path and cloud settings. Set application functions as config overrides.
            self.embeddings.load(
                self.config.get("path"),
                self.config.get("cloud"),
                {key: config[key] for key in ["functions", "transform"] if key in config} if config else None,
            )

        elif "embeddings" in self.config:
            # Create new embeddings with config
            self.embeddings = Embeddings(config)

        # If an extractor pipeline is defined and the similarity attribute is None, set to embeddings index
        for key in ["extractor", "rag"]:
            pipeline = self.pipelines.get(key)
            config = self.config.get(key)

            # NOTE(review): assumes a truthy config section is a dict with a "similarity" key
            # (createpipelines adds the placeholder in-place); a falsy section is skipped here
            if pipeline and config is not None and config["similarity"] is None:
                pipeline.similarity = self.embeddings

        # Attach embeddings to reranker
        if "reranker" in self.pipelines:
            self.pipelines["reranker"].embeddings = self.embeddings
249  
250      def resolvetask(self, task):
251          """
252          Resolves callable functions for a task.
253  
254          Args:
255              task: input task config
256          """
257  
258          # Check for task shorthand syntax
259          task = {"action": task} if isinstance(task, (str, list)) else task
260  
261          if "action" in task:
262              action = task["action"]
263              values = [action] if not isinstance(action, list) else action
264  
265              actions = []
266              for a in values:
267                  if a in ["index", "upsert"]:
268                      # Add queue action to buffer documents to index
269                      actions.append(self.add)
270  
271                      # Override and disable unpacking for indexing actions
272                      task["unpack"] = False
273  
274                      # Add finalize to trigger indexing
275                      task["finalize"] = self.upsert if a == "upsert" else self.index
276                  elif a == "search":
277                      actions.append(self.batchsearch)
278                  elif a == "transform":
279                      # Transform vectors
280                      actions.append(self.batchtransform)
281  
282                      # Override and disable one-to-many transformations
283                      task["onetomany"] = False
284                  else:
285                      # Resolve action to callable function
286                      actions.append(self.function(a))
287  
288              # Save resolved action(s)
289              task["action"] = actions[0] if not isinstance(action, list) else actions
290  
291          # Resolve initializer
292          if "initialize" in task and isinstance(task["initialize"], str):
293              task["initialize"] = self.function(task["initialize"])
294  
295          # Resolve finalizer
296          if "finalize" in task and isinstance(task["finalize"], str):
297              task["finalize"] = self.function(task["finalize"])
298  
299          return task
300  
301      def resolveconfig(self, config):
302          """
303          Resolves callable functions stored in embeddings configuration.
304  
305          Args:
306              config: embeddings config
307  
308          Returns:
309              resolved config
310          """
311  
312          if "functions" in config:
313              # Resolve callable functions
314              functions = []
315              for fn in config["functions"]:
316                  original = fn
317                  try:
318                      if isinstance(fn, dict):
319                          fn = fn.copy()
320                          fn["function"] = self.function(fn["function"])
321                      else:
322                          fn = self.function(fn)
323  
324                  # pylint: disable=W0703
325                  except Exception:
326                      # Not a resolvable function, pipeline or workflow - further resolution will happen in embeddings
327                      fn = original
328  
329                  functions.append(fn)
330  
331              config["functions"] = functions
332  
333          if "transform" in config:
334              # Resolve transform function
335              config["transform"] = self.function(config["transform"])
336  
337          return config
338  
339      def function(self, function):
340          """
341          Get a handle to a callable function.
342  
343          Args:
344              function: function name
345  
346          Returns:
347              resolved function
348          """
349  
350          # Check if function is a pipeline
351          if function in self.pipelines:
352              return self.pipelines[function]
353  
354          # Check if function is a workflow
355          if function in self.workflows:
356              return self.workflows[function]
357  
358          # Attempt to resolve action as a callable function
359          return PipelineFactory.create({}, function)
360  
361      def search(self, query, limit=10, weights=None, index=None, parameters=None, graph=False):
362          """
363          Finds documents most similar to the input query. This method will run either an index search
364          or an index + database search depending on if a database is available.
365  
366          Args:
367              query: input query
368              limit: maximum results
369              weights: hybrid score weights, if applicable
370              index: index name, if applicable
371              parameters: dict of named parameters to bind to placeholders
372              graph: return graph results if True
373  
374          Returns:
375              list of {id: value, score: value} for index search, list of dict for an index + database search
376          """
377  
378          if self.embeddings:
379              with self.lock:
380                  results = self.embeddings.search(query, limit, weights, index, parameters, graph)
381  
382              # Unpack (id, score) tuple, if necessary. Otherwise, results are dictionaries.
383              return results if graph else [{"id": r[0], "score": float(r[1])} if isinstance(r, tuple) else r for r in results]
384  
385          return None
386  
387      def batchsearch(self, queries, limit=10, weights=None, index=None, parameters=None, graph=False):
388          """
389          Finds documents most similar to the input queries. This method will run either an index search
390          or an index + database search depending on if a database is available.
391  
392          Args:
393              queries: input queries
394              limit: maximum results
395              weights: hybrid score weights, if applicable
396              index: index name, if applicable
397              parameters: list of dicts of named parameters to bind to placeholders
398              graph: return graph results if True
399  
400          Returns:
401              list of {id: value, score: value} per query for index search, list of dict per query for an index + database search
402          """
403  
404          if self.embeddings:
405              with self.lock:
406                  search = self.embeddings.batchsearch(queries, limit, weights, index, parameters, graph)
407  
408              results = []
409              for result in search:
410                  # Unpack (id, score) tuple, if necessary. Otherwise, results are dictionaries.
411                  results.append(result if graph else [{"id": r[0], "score": float(r[1])} if isinstance(r, tuple) else r for r in result])
412              return results
413  
414          return None
415  
416      def add(self, documents):
417          """
418          Adds a batch of documents for indexing.
419  
420          Args:
421              documents: list of {id: value, data: value, tags: value}
422  
423          Returns:
424              unmodified input documents
425          """
426  
427          # Raise error if index is not writable
428          if not self.config.get("writable"):
429              raise ReadOnlyError("Attempting to add documents to a read-only index (writable != True)")
430  
431          if self.embeddings:
432              with self.lock:
433                  # Create documents file if not already open
434                  if not self.documents:
435                      self.documents = Documents()
436  
437                  # Add documents
438                  self.documents.add(list(documents))
439  
440          # Return unmodified input documents
441          return documents
442  
443      def addobject(self, data, uid, field):
444          """
445          Helper method that builds a batch of object documents.
446  
447          Args:
448              data: object content
449              uid: optional list of corresponding uids
450              field: optional field to set
451  
452          Returns:
453              documents
454          """
455  
456          # Raise error if index is not writable
457          if not self.config.get("writable"):
458              raise ReadOnlyError("Attempting to add documents to a read-only index (writable != True)")
459  
460          documents = []
461          for x, content in enumerate(data):
462              if field:
463                  row = {"id": uid[x], field: content} if uid else {field: content}
464              elif uid:
465                  row = (uid[x], content)
466              else:
467                  row = content
468  
469              documents.append(row)
470  
471          return self.add(documents)
472  
    def index(self):
        """
        Builds an embeddings index for previously batched documents.

        Rebuilds from scratch: resets the embeddings instance (without reloading existing data),
        optionally builds a term-weighting scoring index, indexes the buffered documents, persists
        the index when a path is configured, then clears the document buffer.

        Raises:
            ReadOnlyError: if the index is not writable
        """

        # Raise error if index is not writable
        if not self.config.get("writable"):
            raise ReadOnlyError("Attempting to index a read-only index (writable != True)")

        # No-op unless there is an embeddings index and buffered documents
        if self.embeddings and self.documents:
            with self.lock:
                # Reset index - recreate the embeddings instance, skipping the data load
                self.indexes(False)

                # Build scoring index if term weighting is enabled
                if self.embeddings.isweighted():
                    self.embeddings.score(self.documents)

                # Build embeddings index - documents buffer is iterated again after scoring
                self.embeddings.index(self.documents)

                # Save index if path available, otherwise this is a memory-only index
                if self.config.get("path"):
                    self.embeddings.save(self.config["path"], self.config.get("cloud"))

                # Reset document stream
                self.documents.close()
                self.documents = None
501  
502      def upsert(self):
503          """
504          Runs an embeddings upsert operation for previously batched documents.
505          """
506  
507          # Raise error if index is not writable
508          if not self.config.get("writable"):
509              raise ReadOnlyError("Attempting to upsert a read-only index (writable != True)")
510  
511          if self.embeddings and self.documents:
512              with self.lock:
513                  # Run upsert
514                  self.embeddings.upsert(self.documents)
515  
516                  # Save index if path available, otherwise this is an memory-only index
517                  if self.config.get("path"):
518                      self.embeddings.save(self.config["path"], self.config.get("cloud"))
519  
520                  # Reset document stream
521                  self.documents.close()
522                  self.documents = None
523  
524      def delete(self, ids):
525          """
526          Deletes from an embeddings index. Returns list of ids deleted.
527  
528          Args:
529              ids: list of ids to delete
530  
531          Returns:
532              ids deleted
533          """
534  
535          # Raise error if index is not writable
536          if not self.config.get("writable"):
537              raise ReadOnlyError("Attempting to delete from a read-only index (writable != True)")
538  
539          if self.embeddings:
540              with self.lock:
541                  # Run delete operation
542                  deleted = self.embeddings.delete(ids)
543  
544                  # Save index if path available, otherwise this is an memory-only index
545                  if self.config.get("path"):
546                      self.embeddings.save(self.config["path"], self.config.get("cloud"))
547  
548                  # Return deleted ids
549                  return deleted
550  
551          return None
552  
553      def reindex(self, config, function=None):
554          """
555          Recreates embeddings index using config. This method only works if document content storage is enabled.
556  
557          Args:
558              config: new config
559              function: optional function to prepare content for indexing
560          """
561  
562          # Raise error if index is not writable
563          if not self.config.get("writable"):
564              raise ReadOnlyError("Attempting to reindex a read-only index (writable != True)")
565  
566          if self.embeddings:
567              with self.lock:
568                  # Resolve function, if necessary
569                  function = self.function(function) if function and isinstance(function, str) else function
570  
571                  # Reindex
572                  self.embeddings.reindex(config, function)
573  
574                  # Save index if path available, otherwise this is an memory-only index
575                  if self.config.get("path"):
576                      self.embeddings.save(self.config["path"], self.config.get("cloud"))
577  
578      def count(self):
579          """
580          Total number of elements in this embeddings index.
581  
582          Returns:
583              number of elements in embeddings index
584          """
585  
586          if self.embeddings:
587              return self.embeddings.count()
588  
589          return None
590  
591      def similarity(self, query, texts):
592          """
593          Computes the similarity between query and list of text. Returns a list of
594          {id: value, score: value} sorted by highest score, where id is the index
595          in texts.
596  
597          Args:
598              query: query text
599              texts: list of text
600  
601          Returns:
602              list of {id: value, score: value}
603          """
604  
605          # Use similarity instance if available otherwise fall back to embeddings model
606          if "similarity" in self.pipelines:
607              return [{"id": uid, "score": float(score)} for uid, score in self.pipelines["similarity"](query, texts)]
608          if self.embeddings:
609              return [{"id": uid, "score": float(score)} for uid, score in self.embeddings.similarity(query, texts)]
610  
611          return None
612  
613      def batchsimilarity(self, queries, texts):
614          """
615          Computes the similarity between list of queries and list of text. Returns a list
616          of {id: value, score: value} sorted by highest score per query, where id is the
617          index in texts.
618  
619          Args:
620              queries: queries text
621              texts: list of text
622  
623          Returns:
624              list of {id: value, score: value} per query
625          """
626  
627          # Use similarity instance if available otherwise fall back to embeddings model
628          if "similarity" in self.pipelines:
629              return [[{"id": uid, "score": float(score)} for uid, score in r] for r in self.pipelines["similarity"](queries, texts)]
630          if self.embeddings:
631              return [[{"id": uid, "score": float(score)} for uid, score in r] for r in self.embeddings.batchsimilarity(queries, texts)]
632  
633          return None
634  
635      def explain(self, query, texts=None, limit=10):
636          """
637          Explains the importance of each input token in text for a query.
638  
639          Args:
640              query: query text
641              texts: optional list of text, otherwise runs search query
642              limit: optional limit if texts is None
643  
644          Returns:
645              list of dict per input text where a higher token scores represents higher importance relative to the query
646          """
647  
648          if self.embeddings:
649              with self.lock:
650                  return self.embeddings.explain(query, texts, limit)
651  
652          return None
653  
654      def batchexplain(self, queries, texts=None, limit=10):
655          """
656          Explains the importance of each input token in text for a list of queries.
657  
658          Args:
659              query: queries text
660              texts: optional list of text, otherwise runs search queries
661              limit: optional limit if texts is None
662  
663          Returns:
664              list of dict per input text per query where a higher token scores represents higher importance relative to the query
665          """
666  
667          if self.embeddings:
668              with self.lock:
669                  return self.embeddings.batchexplain(queries, texts, limit)
670  
671          return None
672  
673      def transform(self, text, category=None, index=None):
674          """
675          Transforms text into embeddings arrays.
676  
677          Args:
678              text: input text
679              category: category for instruction-based embeddings
680              index: index name, if applicable
681  
682          Returns:
683              embeddings array
684          """
685  
686          if self.embeddings:
687              return [float(x) for x in self.embeddings.transform(text, category, index)]
688  
689          return None
690  
691      def batchtransform(self, texts, category=None, index=None):
692          """
693          Transforms list of text into embeddings arrays.
694  
695          Args:
696              texts: list of text
697              category: category for instruction-based embeddings
698              index: index name, if applicable
699  
700          Returns:
701              embeddings arrays
702          """
703  
704          if self.embeddings:
705              return [[float(x) for x in result] for result in self.embeddings.batchtransform(texts, category, index)]
706  
707          return None
708  
709      def extract(self, queue, texts=None):
710          """
711          Extracts answers to input questions.
712  
713          Args:
714              queue: list of {name: value, query: value, question: value, snippet: value}
715              texts: optional list of text
716  
717          Returns:
718              list of {name: value, answer: value}
719          """
720  
721          if self.embeddings and "extractor" in self.pipelines:
722              # Get extractor instance
723              extractor = self.pipelines["extractor"]
724  
725              # Run extractor and return results as dicts
726              return extractor(queue, texts)
727  
728          return None
729  
730      def label(self, text, labels):
731          """
732          Applies a zero shot classifier to text using a list of labels. Returns a list of
733          {id: value, score: value} sorted by highest score, where id is the index in labels.
734  
735          Args:
736              text: text|list
737              labels: list of labels
738  
739          Returns:
740              list of {id: value, score: value} per text element
741          """
742  
743          if "labels" in self.pipelines:
744              # Text is a string
745              if isinstance(text, str):
746                  return [{"id": uid, "score": float(score)} for uid, score in self.pipelines["labels"](text, labels)]
747  
748              # Text is a list
749              return [[{"id": uid, "score": float(score)} for uid, score in result] for result in self.pipelines["labels"](text, labels)]
750  
751          return None
752  
753      def pipeline(self, name, *args, **kwargs):
754          """
755          Generic pipeline execution method.
756  
757          Args:
758              name: pipeline name
759              args: pipeline positional arguments
760              kwargs: pipeline keyword arguments
761  
762          Returns:
763              pipeline results
764          """
765  
766          # Backwards compatible with previous pipeline function arguments
767          args = args[0] if args and len(args) == 1 and isinstance(args[0], tuple) else args
768  
769          if name in self.pipelines:
770              return self.pipelines[name](*args, **kwargs)
771  
772          return None
773  
774      def workflow(self, name, elements):
775          """
776          Executes a workflow.
777  
778          Args:
779              name: workflow name
780              elements: elements to process
781  
782          Returns:
783              processed elements
784          """
785  
786          if hasattr(elements, "__len__") and hasattr(elements, "__getitem__"):
787              # Convert to tuples and return as a list since input is sized
788              elements = [tuple(element) if isinstance(element, list) else element for element in elements]
789          else:
790              # Convert to tuples and return as a generator since input is not sized
791              elements = (tuple(element) if isinstance(element, list) else element for element in elements)
792  
793          # Execute workflow
794          return self.workflows[name](elements)
795  
796      def agent(self, name, *args, **kwargs):
797          """
798          Executes an agent.
799  
800          Args:
801              name: agent name
802              args: agent positional arguments
803              kwargs: agent keyword arguments
804          """
805  
806          if name in self.agents:
807              return self.agents[name](*args, **kwargs)
808  
809          return None
810  
811      def wait(self):
812          """
813          Closes threadpool and waits for completion.
814          """
815  
816          if self.pool:
817              self.pool.close()
818              self.pool.join()
819              self.pool = None
820  
821  
class ReadOnlyError(Exception):
    """
    Raised when a write operation (add, index, upsert, delete, reindex) targets a read-only index.
    """