/ src / python / txtai / embeddings / base.py
base.py
   1  """
   2  Embeddings module
   3  """
   4  
   5  import json
   6  import os
   7  import tempfile
   8  
   9  from ..ann import ANNFactory
  10  from ..archive import ArchiveFactory
  11  from ..cloud import CloudFactory
  12  from ..database import DatabaseFactory
  13  from ..graph import GraphFactory
  14  from ..scoring import ScoringFactory
  15  from ..vectors import VectorsFactory
  16  
  17  from .index import Action, Configuration, Functions, Indexes, IndexIds, Reducer, Stream, Transform
  18  from .search import Explain, Ids, Query, Search, Terms
  19  
  20  
  21  # pylint: disable=C0302,R0904
  22  class Embeddings:
  23      """
  24      Embeddings databases are the engine that delivers semantic search. Data is transformed into embeddings vectors where similar concepts
  25      will produce similar vectors. Indexes both large and small are built with these vectors. The indexes are used to find results
  26      that have the same meaning, not necessarily the same keywords.
  27      """
  28  
  29      # pylint: disable=W0231
  30      def __init__(self, config=None, models=None, **kwargs):
  31          """
  32          Creates a new embeddings index. Embeddings indexes are thread-safe for read operations but writes must be synchronized.
  33  
  34          Args:
  35              config: embeddings configuration
  36              models: models cache, used for model sharing between embeddings
  37              kwargs: additional configuration as keyword args
  38          """
  39  
  40          # Index configuration
  41          self.config = None
  42  
  43          # Dimensionality reduction - word vectors only
  44          self.reducer = None
  45  
  46          # Dense vector model - transforms data into similarity vectors
  47          self.model = None
  48  
  49          # Approximate nearest neighbor index
  50          self.ann = None
  51  
  52          # Index ids when content is disabled
  53          self.ids = None
  54  
  55          # Document database
  56          self.database = None
  57  
  58          # Resolvable functions
  59          self.functions = None
  60  
  61          # Graph network
  62          self.graph = None
  63  
  64          # Sparse vectors
  65          self.scoring = None
  66  
  67          # Query model
  68          self.query = None
  69  
  70          # Index archive
  71          self.archive = None
  72  
  73          # Subindexes for this embeddings instance
  74          self.indexes = None
  75  
  76          # Models cache
  77          self.models = models
  78  
  79          # Merge configuration into single dictionary
  80          config = {**config, **kwargs} if config and kwargs else kwargs if kwargs else config
  81  
  82          # Set initial configuration
  83          self.configure(config)
  84  
  85      def __enter__(self):
  86          return self
  87  
  88      def __exit__(self, *args):
  89          self.close()
  90  
  91      def score(self, documents):
  92          """
  93          Builds a term weighting scoring index. Only used by word vectors models.
  94  
  95          Args:
  96              documents: iterable of (id, data, tags), (id, data) or data
  97          """
  98  
  99          # Build scoring index for word vectors term weighting
 100          if self.isweighted():
 101              self.scoring.index(Stream(self)(documents))
 102  
 103      def index(self, documents, reindex=False, checkpoint=None):
 104          """
 105          Builds an embeddings index. This method overwrites an existing index.
 106  
 107          Args:
 108              documents: iterable of (id, data, tags), (id, data) or data
 109              reindex: if this is a reindex operation in which case database creation is skipped, defaults to False
 110              checkpoint: optional checkpoint directory, enables indexing restart
 111          """
 112  
 113          # Initialize index
 114          self.initindex(reindex)
 115  
 116          # Create transform and stream
 117          transform = Transform(self, Action.REINDEX if reindex else Action.INDEX, checkpoint)
 118          stream = Stream(self, Action.REINDEX if reindex else Action.INDEX)
 119  
 120          with tempfile.NamedTemporaryFile(mode="wb", suffix=".npy") as buffer:
 121              # Load documents into database and transform to vectors
 122              ids, dimensions, embeddings = transform(stream(documents), buffer)
 123              if embeddings is not None:
 124                  # Build LSA model (if enabled). Remove principal components from embeddings.
 125                  if self.config.get("pca"):
 126                      self.reducer = Reducer(embeddings, self.config["pca"])
 127                      self.reducer(embeddings)
 128  
 129                  # Save index dimensions
 130                  self.config["dimensions"] = dimensions
 131  
 132                  # Create approximate nearest neighbor index
 133                  self.ann = self.createann()
 134  
 135                  # Add embeddings to the index
 136                  self.ann.index(embeddings)
 137  
 138              # Save indexids-ids mapping for indexes with no database, except when this is a reindex
 139              if ids and not reindex and not self.database:
 140                  self.ids = self.createids(ids)
 141  
 142          # Index scoring, if necessary
 143          # This must occur before graph index in order to be available to the graph
 144          if self.issparse():
 145              self.scoring.index()
 146  
 147          # Index subindexes, if necessary
 148          if self.indexes:
 149              self.indexes.index()
 150  
 151          # Index graph, if necessary
 152          if self.graph:
 153              self.graph.index(Search(self, indexonly=True), Ids(self), self.batchsimilarity)
 154  
 155      def upsert(self, documents, checkpoint=None):
 156          """
 157          Runs an embeddings upsert operation. If the index exists, new data is
 158          appended to the index, existing data is updated. If the index doesn't exist,
 159          this method runs a standard index operation.
 160  
 161          Args:
 162              documents: iterable of (id, data, tags), (id, data) or data
 163              checkpoint: optional checkpoint directory, enables indexing restart
 164          """
 165  
 166          # Run standard insert if index doesn't exist or it has no records
 167          if not self.count():
 168              self.index(documents, checkpoint=checkpoint)
 169              return
 170  
 171          # Create transform and stream
 172          transform = Transform(self, Action.UPSERT, checkpoint=checkpoint)
 173          stream = Stream(self, Action.UPSERT)
 174  
 175          with tempfile.NamedTemporaryFile(mode="wb", suffix=".npy") as buffer:
 176              # Load documents into database and transform to vectors
 177              ids, _, embeddings = transform(stream(documents), buffer)
 178              if embeddings is not None:
 179                  # Remove principal components from embeddings, if necessary
 180                  if self.reducer:
 181                      self.reducer(embeddings)
 182  
 183                  # Append embeddings to the index
 184                  self.ann.append(embeddings)
 185  
 186              # Save indexids-ids mapping for indexes with no database
 187              if ids and not self.database:
 188                  self.ids = self.createids(self.ids + ids)
 189  
 190          # Scoring upsert, if necessary
 191          # This must occur before graph upsert in order to be available to the graph
 192          if self.issparse():
 193              self.scoring.upsert()
 194  
 195          # Subindexes upsert, if necessary
 196          if self.indexes:
 197              self.indexes.upsert()
 198  
 199          # Graph upsert, if necessary
 200          if self.graph:
 201              self.graph.upsert(Search(self, indexonly=True), Ids(self), self.batchsimilarity)
 202  
 203      def delete(self, ids):
 204          """
 205          Deletes from an embeddings index. Returns list of ids deleted.
 206  
 207          Args:
 208              ids: list of ids to delete
 209  
 210          Returns:
 211              list of ids deleted
 212          """
 213  
 214          # List of internal indices for each candidate id to delete
 215          indices = []
 216  
 217          # List of deleted ids
 218          deletes = []
 219  
 220          if self.database:
 221              # Retrieve indexid-id mappings from database
 222              ids = self.database.ids(ids)
 223  
 224              # Parse out indices and ids to delete
 225              indices = [i for i, _ in ids]
 226              deletes = sorted(set(uid for _, uid in ids))
 227  
 228              # Delete ids from database
 229              self.database.delete(deletes)
 230          elif self.ann or self.scoring:
 231              # Find existing ids
 232              for uid in ids:
 233                  indices.extend([index for index, value in enumerate(self.ids) if uid == value])
 234  
 235              # Clear embeddings ids
 236              for index in indices:
 237                  deletes.append(self.ids[index])
 238                  self.ids[index] = None
 239  
 240          # Delete indices for all indexes and data stores
 241          if indices:
 242              # Delete ids from ann
 243              if self.isdense():
 244                  self.ann.delete(indices)
 245  
 246              # Delete ids from scoring
 247              if self.issparse():
 248                  self.scoring.delete(indices)
 249  
 250              # Delete ids from subindexes
 251              if self.indexes:
 252                  self.indexes.delete(indices)
 253  
 254              # Delete ids from graph
 255              if self.graph:
 256                  self.graph.delete(indices)
 257  
 258          return deletes
 259  
 260      def reindex(self, config=None, function=None, **kwargs):
 261          """
 262          Recreates embeddings index using config. This method only works if document content storage is enabled.
 263  
 264          Args:
 265              config: new config
 266              function: optional function to prepare content for indexing
 267              kwargs: additional configuration as keyword args
 268          """
 269  
 270          if self.database:
 271              # Merge configuration into single dictionary
 272              config = {**config, **kwargs} if config and kwargs else config if config else kwargs
 273  
 274              # Keep content and objects parameters to ensure database is preserved
 275              config["content"] = self.config["content"]
 276              if "objects" in self.config:
 277                  config["objects"] = self.config["objects"]
 278  
 279              # Reset configuration
 280              self.configure(config)
 281  
 282              # Reset function references
 283              if self.functions:
 284                  self.functions.reset()
 285  
 286              # Reindex
 287              if function:
 288                  self.index(function(self.database.reindex(self.config)), True)
 289              else:
 290                  self.index(self.database.reindex(self.config), True)
 291  
 292      def transform(self, document, category=None, index=None):
 293          """
 294          Transforms document into an embeddings vector.
 295  
 296          Args:
 297              documents: iterable of (id, data, tags), (id, data) or data
 298              category: category for instruction-based embeddings
 299              index: index name, if applicable
 300  
 301          Returns:
 302              embeddings vector
 303          """
 304  
 305          return self.batchtransform([document], category, index)[0]
 306  
 307      def batchtransform(self, documents, category=None, index=None):
 308          """
 309          Transforms documents into embeddings vectors.
 310  
 311          Args:
 312              documents: iterable of (id, data, tags), (id, data) or data
 313              category: category for instruction-based embeddings
 314              index: index name, if applicable
 315  
 316          Returns:
 317              embeddings vectors
 318          """
 319  
 320          # Initialize default parameters, if necessary
 321          self.defaults()
 322  
 323          # Get vector model
 324          model = self.findmodel(index)
 325  
 326          # Convert documents into embeddings
 327          embeddings = model.batchtransform(Stream(self)(documents), category)
 328  
 329          # Reduce the dimensionality of the embeddings. Scale the embeddings using this
 330          # model to reduce the noise of common but less relevant terms.
 331          if self.reducer:
 332              self.reducer(embeddings)
 333  
 334          return embeddings
 335  
 336      def count(self):
 337          """
 338          Total number of elements in this embeddings index.
 339  
 340          Returns:
 341              number of elements in this embeddings index
 342          """
 343  
 344          if self.ann:
 345              return self.ann.count()
 346          if self.scoring:
 347              return self.scoring.count()
 348          if self.database:
 349              return self.database.count()
 350          if self.ids:
 351              return len([uid for uid in self.ids if uid is not None])
 352  
 353          # Default to 0 when no suitable method found
 354          return 0
 355  
 356      def search(self, query, limit=None, weights=None, index=None, parameters=None, graph=False):
 357          """
 358          Finds documents most similar to the input query. This method runs an index search, index + database search
 359          or a graph search, depending on the embeddings configuration and query.
 360  
 361          Args:
 362              query: input query
 363              limit: maximum results
 364              weights: hybrid score weights, if applicable
 365              index: index name, if applicable
 366              parameters: dict of named parameters to bind to placeholders
 367              graph: return graph results if True
 368  
 369          Returns:
 370              list of (id, score) for index search
 371              list of dict for an index + database search
 372              graph when graph is set to True
 373          """
 374  
 375          results = self.batchsearch([query], limit, weights, index, [parameters], graph)
 376          return results[0] if results else results
 377  
 378      def batchsearch(self, queries, limit=None, weights=None, index=None, parameters=None, graph=False):
 379          """
 380          Finds documents most similar to the input query. This method runs an index search, index + database search
 381          or a graph search, depending on the embeddings configuration and query.
 382  
 383          Args:
 384              queries: input queries
 385              limit: maximum results
 386              weights: hybrid score weights, if applicable
 387              index: index name, if applicable
 388              parameters: list of dicts of named parameters to bind to placeholders
 389              graph: return graph results if True
 390  
 391          Returns:
 392              list of (id, score) per query for index search
 393              list of dict per query for an index + database search
 394              list of graph per query when graph is set to True
 395          """
 396  
 397          # Determine if graphs should be returned
 398          graph = graph if self.graph else False
 399  
 400          # Execute search
 401          results = Search(self, indexids=graph)(queries, limit, weights, index, parameters)
 402  
 403          # Create subgraphs using results, if necessary
 404          return [self.graph.filter(x) if isinstance(x, list) else x for x in results] if graph else results
 405  
 406      def similarity(self, query, data):
 407          """
 408          Computes the similarity between query and list of data. Returns a list of
 409          (id, score) sorted by highest score, where id is the index in data.
 410  
 411          Args:
 412              query: input query
 413              data: list of data
 414  
 415          Returns:
 416              list of (id, score)
 417          """
 418  
 419          return self.batchsimilarity([query], data)[0]
 420  
 421      def batchsimilarity(self, queries, data):
 422          """
 423          Computes the similarity between list of queries and list of data. Returns a list
 424          of (id, score) sorted by highest score per query, where id is the index in data.
 425  
 426          Args:
 427              queries: input queries
 428              data: list of data
 429  
 430          Returns:
 431              list of (id, score) per query
 432          """
 433  
 434          # Convert queries to embedding vectors
 435          queries = self.batchtransform(((None, query, None) for query in queries), "query")
 436          data = self.batchtransform(((None, row, None) for row in data), "data")
 437  
 438          # Get vector model
 439          model = self.findmodel()
 440  
 441          # Dot product on normalized vectors is equal to cosine similarity
 442          scores = model.dot(queries, data)
 443  
 444          # Add index and sort desc based on score
 445          return [sorted(enumerate(score), key=lambda x: x[1], reverse=True) for score in scores]
 446  
 447      def explain(self, query, texts=None, limit=None):
 448          """
 449          Explains the importance of each input token in text for a query. This method requires either content to be enabled
 450          or texts to be provided.
 451  
 452          Args:
 453              query: input query
 454              texts: optional list of (text|list of tokens), otherwise runs search query
 455              limit: optional limit if texts is None
 456  
 457          Returns:
 458              list of dict per input text where a higher token scores represents higher importance relative to the query
 459          """
 460  
 461          results = self.batchexplain([query], texts, limit)
 462          return results[0] if results else results
 463  
 464      def batchexplain(self, queries, texts=None, limit=None):
 465          """
 466          Explains the importance of each input token in text for a list of queries. This method requires either content to be enabled
 467          or texts to be provided.
 468  
 469          Args:
 470              queries: input queries
 471              texts: optional list of (text|list of tokens), otherwise runs search queries
 472              limit: optional limit if texts is None
 473  
 474          Returns:
 475              list of dict per input text per query where a higher token scores represents higher importance relative to the query
 476          """
 477  
 478          return Explain(self)(queries, texts, limit)
 479  
 480      def terms(self, query):
 481          """
 482          Extracts keyword terms from a query.
 483  
 484          Args:
 485              query: input query
 486  
 487          Returns:
 488              query reduced down to keyword terms
 489          """
 490  
 491          return self.batchterms([query])[0]
 492  
 493      def batchterms(self, queries):
 494          """
 495          Extracts keyword terms from a list of queries.
 496  
 497          Args:
 498              queries: list of queries
 499  
 500          Returns:
 501              list of queries reduced down to keyword term strings
 502          """
 503  
 504          return Terms(self)(queries)
 505  
 506      def exists(self, path=None, cloud=None, **kwargs):
 507          """
 508          Checks if an index exists at path.
 509  
 510          Args:
 511              path: input path
 512              cloud: cloud storage configuration
 513              kwargs: additional configuration as keyword args
 514  
 515          Returns:
 516              True if index exists, False otherwise
 517          """
 518  
 519          # Check if this exists in a cloud instance
 520          cloud = self.createcloud(cloud=cloud, **kwargs)
 521          if cloud:
 522              return cloud.exists(path)
 523  
 524          # Check if this is an archive file and exists
 525          path, apath = self.checkarchive(path)
 526          if apath:
 527              return os.path.exists(apath)
 528  
 529          # Return true if path has a config.json or config file with an offset set
 530          return path and (os.path.exists(f"{path}/config.json") or os.path.exists(f"{path}/config")) and "offset" in Configuration().load(path)
 531  
 532      def load(self, path=None, cloud=None, config=None, **kwargs):
 533          """
 534          Loads an existing index from path.
 535  
 536          Args:
 537              path: input path
 538              cloud: cloud storage configuration
 539              config: configuration overrides
 540              kwargs: additional configuration as keyword args
 541  
 542          Returns:
 543              Embeddings
 544          """
 545  
 546          # Load from cloud, if configured
 547          cloud = self.createcloud(cloud=cloud, **kwargs)
 548          if cloud:
 549              path = cloud.load(path)
 550  
 551          # Check if this is an archive file and extract
 552          path, apath = self.checkarchive(path)
 553          if apath:
 554              self.archive.load(apath)
 555  
 556          # Load index configuration
 557          self.config = Configuration().load(path)
 558  
 559          # Apply config overrides
 560          self.config = {**self.config, **config} if config else self.config
 561  
 562          # Approximate nearest neighbor index - stores dense vectors
 563          self.ann = self.createann()
 564          if self.ann:
 565              self.ann.load(f"{path}/embeddings")
 566  
 567          # Dimensionality reduction model - word vectors only
 568          if self.config.get("pca"):
 569              self.reducer = Reducer()
 570              self.reducer.load(f"{path}/lsa")
 571  
 572          # Index ids when content is disabled
 573          self.ids = self.createids()
 574          if self.ids:
 575              self.ids.load(f"{path}/ids")
 576  
 577          # Document database - stores document content
 578          self.database = self.createdatabase()
 579          if self.database:
 580              self.database.load(f"{path}/documents")
 581  
 582          # Sparse vectors - stores term sparse arrays
 583          self.scoring = self.createscoring()
 584          if self.scoring:
 585              self.scoring.load(f"{path}/scoring")
 586  
 587          # Subindexes
 588          self.indexes = self.createindexes()
 589          if self.indexes:
 590              self.indexes.load(f"{path}/indexes")
 591  
 592          # Graph network - stores relationships
 593          self.graph = self.creategraph()
 594          if self.graph:
 595              self.graph.load(f"{path}/graph")
 596  
 597          # Dense vectors - transforms data to embeddings vectors
 598          self.model = self.loadvectors()
 599  
 600          # Query model
 601          self.query = self.loadquery()
 602  
 603          return self
 604  
 605      def save(self, path, cloud=None, **kwargs):
 606          """
 607          Saves an index in a directory at path unless path ends with tar.gz, tar.bz2, tar.xz or zip.
 608          In those cases, the index is stored as a compressed file.
 609  
 610          Args:
 611              path: output path
 612              cloud: cloud storage configuration
 613              kwargs: additional configuration as keyword args
 614          """
 615  
 616          if self.config:
 617              # Check if this is an archive file
 618              path, apath = self.checkarchive(path)
 619  
 620              # Create output directory, if necessary
 621              os.makedirs(path, exist_ok=True)
 622  
 623              # Save index configuration
 624              Configuration().save(self.config, path)
 625  
 626              # Save approximate nearest neighbor index
 627              if self.ann:
 628                  self.ann.save(f"{path}/embeddings")
 629  
 630              # Save dimensionality reduction model (word vectors only)
 631              if self.reducer:
 632                  self.reducer.save(f"{path}/lsa")
 633  
 634              # Save index ids
 635              if self.ids:
 636                  self.ids.save(f"{path}/ids")
 637  
 638              # Save document database
 639              if self.database:
 640                  self.database.save(f"{path}/documents")
 641  
 642              # Save scoring index
 643              if self.scoring:
 644                  self.scoring.save(f"{path}/scoring")
 645  
 646              # Save subindexes
 647              if self.indexes:
 648                  self.indexes.save(f"{path}/indexes")
 649  
 650              # Save graph
 651              if self.graph:
 652                  self.graph.save(f"{path}/graph")
 653  
 654              # If this is an archive, save it
 655              if apath:
 656                  self.archive.save(apath)
 657  
 658              # Save to cloud, if configured
 659              cloud = self.createcloud(cloud=cloud, **kwargs)
 660              if cloud:
 661                  cloud.save(apath if apath else path)
 662  
 663      def close(self):
 664          """
 665          Closes this embeddings index and frees all resources.
 666          """
 667  
 668          self.config, self.archive = None, None
 669          self.reducer, self.query = None, None
 670          self.ids = None
 671  
 672          # Close ANN
 673          if self.ann:
 674              self.ann.close()
 675              self.ann = None
 676  
 677          # Close database
 678          if self.database:
 679              self.database.close()
 680              self.database, self.functions = None, None
 681  
 682          # Close scoring
 683          if self.scoring:
 684              self.scoring.close()
 685              self.scoring = None
 686  
 687          # Close graph
 688          if self.graph:
 689              self.graph.close()
 690              self.graph = None
 691  
 692          # Close indexes
 693          if self.indexes:
 694              self.indexes.close()
 695              self.indexes = None
 696  
 697          # Close vectors model
 698          if self.model:
 699              self.model.close()
 700              self.model = None
 701  
 702          self.models = None
 703  
 704      def info(self):
 705          """
 706          Prints the current embeddings index configuration.
 707          """
 708  
 709          if self.config:
 710              # Print configuration
 711              print(json.dumps(self.config, sort_keys=True, default=str, indent=2))
 712  
 713      def issparse(self):
 714          """
 715          Checks if this instance has an associated sparse keyword or sparse vectors scoring index.
 716  
 717          Returns:
 718              True if scoring has an associated sparse keyword/vector index, False otherwise
 719          """
 720  
 721          return self.scoring and self.scoring.issparse()
 722  
 723      def isdense(self):
 724          """
 725          Checks if this instance has an associated ANN instance.
 726  
 727          Returns:
 728              True if this instance has an associated ANN, False otherwise
 729          """
 730  
 731          return self.ann is not None
 732  
 733      def isweighted(self):
 734          """
 735          Checks if this instance has an associated scoring instance with term weighting enabled.
 736  
 737          Returns:
 738              True if term weighting is enabled, False otherwise
 739          """
 740  
 741          return self.scoring and self.scoring.isweighted()
 742  
 743      def findmodel(self, index=None):
 744          """
 745          Finds the primary vector model used by this instance.
 746  
 747          Returns:
 748              Vectors
 749          """
 750  
 751          return (
 752              self.indexes.findmodel(index)
 753              if index and self.indexes
 754              else (
 755                  self.model
 756                  if self.model
 757                  else self.scoring.findmodel() if self.scoring and self.scoring.findmodel() else self.indexes.findmodel() if self.indexes else None
 758              )
 759          )
 760  
 761      def configure(self, config):
 762          """
 763          Sets the configuration for this embeddings index and loads config-driven models.
 764  
 765          Args:
 766              config: embeddings configuration
 767          """
 768  
 769          # Configuration
 770          self.config = config
 771  
 772          # Dimensionality reduction model
 773          self.reducer = None
 774  
 775          # Create scoring instance for word vectors term weighting
 776          scoring = self.config.get("scoring") if self.config else None
 777          self.scoring = self.createscoring() if scoring and not self.hassparse() else None
 778  
 779          # Dense vectors - transforms data to embeddings vectors
 780          self.model = self.loadvectors() if self.config else None
 781  
 782          # Query model
 783          self.query = self.loadquery() if self.config else None
 784  
 785      def initindex(self, reindex):
 786          """
 787          Initialize new index.
 788  
 789          Args:
 790              reindex: if this is a reindex operation in which case database creation is skipped, defaults to False
 791          """
 792  
 793          # Initialize default parameters, if necessary
 794          self.defaults()
 795  
 796          # Initialize index ids, only created when content is disabled
 797          self.ids = None
 798  
 799          # Create document database, if necessary
 800          if not reindex:
 801              self.database = self.createdatabase()
 802  
 803              # Reset archive since this is a new index
 804              self.archive = None
 805  
 806          # Close existing ANN, if necessary
 807          if self.ann:
 808              self.ann.close()
 809  
 810          # Initialize ANN, will be created after index transformations complete
 811          self.ann = None
 812  
 813          # Create scoring only if the scoring config is for a sparse index
 814          if self.hassparse():
 815              self.scoring = self.createscoring()
 816  
 817          # Create subindexes, if necessary
 818          self.indexes = self.createindexes()
 819  
 820          # Create graph, if necessary
 821          self.graph = self.creategraph()
 822  
 823      def defaults(self):
 824          """
 825          Apply default parameters to current configuration.
 826  
 827          Returns:
 828              configuration with default parameters set
 829          """
 830  
 831          self.config = self.config if self.config else {}
 832  
 833          # Expand sparse index shortcuts
 834          if not self.config.get("scoring") and any(self.config.get(key) for key in ["keyword", "sparse", "hybrid"]):
 835              self.defaultsparse()
 836  
 837          # Expand graph shortcuts
 838          if self.config.get("graph") is True:
 839              self.config["graph"] = {}
 840  
 841          # Check if default model should be loaded
 842          if not self.model and (self.defaultallowed() or self.config.get("dense")):
 843              self.config["path"] = "sentence-transformers/all-MiniLM-L6-v2"
 844  
 845              # Load dense vectors model
 846              self.model = self.loadvectors()
 847  
 848      def defaultsparse(self):
 849          """
 850          Logic to derive default sparse index configuration.
 851          """
 852  
 853          # Check for keyword and hybrid parameters
 854          method = None
 855          for x in ["keyword", "hybrid"]:
 856              value = self.config.get(x)
 857              if value:
 858                  method = value if isinstance(value, str) else "bm25"
 859  
 860                  # Enable dense index when hybrid enabled
 861                  if x == "hybrid":
 862                      self.config["dense"] = True
 863  
 864          sparse = self.config.get("sparse", {})
 865          if sparse or method == "sparse":
 866              # Sparse vector configuration
 867              sparse = {"path": self.config.get("sparse")} if isinstance(sparse, str) else {} if isinstance(sparse, bool) else sparse
 868              sparse["path"] = sparse.get("path", "opensearch-project/opensearch-neural-sparse-encoding-doc-v2-mini")
 869  
 870              # Merge in sparse parameters
 871              self.config["scoring"] = {**{"method": "sparse"}, **sparse}
 872  
 873          elif method:
 874              # Sparse keyword configuration
 875              self.config["scoring"] = {"method": method, "terms": True, "normalize": True}
 876  
 877      def defaultallowed(self):
 878          """
 879          Tests if this embeddings instance can use a default model if not otherwise provided.
 880  
 881          Returns:
 882              True if a default model is allowed, False otherwise
 883          """
 884  
 885          params = [("keyword", False), ("sparse", False), ("defaults", True)]
 886          return all(self.config.get(key, default) == default for key, default in params)
 887  
 888      def loadvectors(self):
 889          """
 890          Loads a vector model set in config.
 891  
 892          Returns:
 893              vector model
 894          """
 895  
 896          # Create model cache if subindexes are enabled
 897          if "indexes" in self.config and self.models is None:
 898              self.models = {}
 899  
 900          # Support path via dense parameter
 901          dense = self.config.get("dense")
 902          if not self.config.get("path") and dense and isinstance(dense, str):
 903              self.config["path"] = dense
 904  
 905          # Load vector model
 906          return VectorsFactory.create(self.config, self.scoring, self.models)
 907  
 908      def loadquery(self):
 909          """
 910          Loads a query model set in config.
 911  
 912          Returns:
 913              query model
 914          """
 915  
 916          if "query" in self.config:
 917              return Query(**self.config["query"])
 918  
 919          return None
 920  
 921      def checkarchive(self, path):
 922          """
 923          Checks if path is an archive file.
 924  
 925          Args:
 926              path: path to check
 927  
 928          Returns:
 929              (working directory, current path) if this is an archive, original path otherwise
 930          """
 931  
 932          # Create archive instance, if necessary
 933          self.archive = ArchiveFactory.create()
 934  
 935          # Check if path is an archive file
 936          if self.archive.isarchive(path):
 937              # Return temporary archive working directory and original path
 938              return self.archive.path(), path
 939  
 940          return path, None
 941  
 942      def createcloud(self, **cloud):
 943          """
 944          Creates a cloud instance from config.
 945  
 946          Args:
 947              cloud: cloud configuration
 948          """
 949  
 950          # Merge keyword args and keys under the cloud parameter
 951          config = cloud
 952          if "cloud" in config and config["cloud"]:
 953              config.update(config.pop("cloud"))
 954  
 955          # Create cloud instance from config and return
 956          return CloudFactory.create(config) if config else None
 957  
 958      def createann(self):
 959          """
 960          Creates an ANN from config.
 961  
 962          Returns:
 963              new ANN, if enabled in config
 964          """
 965  
 966          # Free existing resources
 967          if self.ann:
 968              self.ann.close()
 969  
 970          return ANNFactory.create(self.config) if self.config.get("path") or self.defaultallowed() else None
 971  
 972      def createdatabase(self):
 973          """
 974          Creates a database from config. This method will also close any existing database connection.
 975  
 976          Returns:
 977              new database, if enabled in config
 978          """
 979  
 980          # Free existing resources
 981          if self.database:
 982              self.database.close()
 983  
 984          config = self.config.copy()
 985  
 986          # Create references to callable functions
 987          self.functions = Functions(self) if "functions" in config else None
 988          if self.functions:
 989              config["functions"] = self.functions(config)
 990  
 991          # Create database from config and return
 992          return DatabaseFactory.create(config)
 993  
 994      def creategraph(self):
 995          """
 996          Creates a graph from config.
 997  
 998          Returns:
 999              new graph, if enabled in config
1000          """
1001  
1002          # Free existing resources
1003          if self.graph:
1004              self.graph.close()
1005  
1006          if "graph" in self.config:
1007              # Get or create graph configuration
1008              config = self.config["graph"] if "graph" in self.config else {}
1009  
1010              # Create configuration with custom columns, if necessary
1011              config = self.columns(config)
1012              return GraphFactory.create(config)
1013  
1014          return None
1015  
1016      def createids(self, ids=None):
1017          """
1018          Creates indexids when content is disabled.
1019  
1020          Args:
1021              ids: optional ids to add
1022  
1023          Returns:
1024              new indexids, if content disabled
1025          """
1026  
1027          # Load index ids when content is disabled
1028          return IndexIds(self, ids) if not self.config.get("content") else None
1029  
1030      def createindexes(self):
1031          """
1032          Creates subindexes from config.
1033  
1034          Returns:
1035              list of subindexes
1036          """
1037  
1038          # Free existing resources
1039          if self.indexes:
1040              self.indexes.close()
1041  
1042          # Load subindexes
1043          if "indexes" in self.config:
1044              indexes = {}
1045              for index, config in self.config["indexes"].items():
1046                  # Create index with shared model cache
1047                  indexes[index] = Embeddings(config, models=self.models)
1048  
1049              # Wrap as Indexes object
1050              return Indexes(self, indexes)
1051  
1052          return None
1053  
1054      def createscoring(self):
1055          """
1056          Creates a scoring from config.
1057  
1058          Returns:
1059              new scoring, if enabled in config
1060          """
1061  
1062          # Free existing resources
1063          if self.scoring:
1064              self.scoring.close()
1065  
1066          if "scoring" in self.config:
1067              # Expand scoring to a dictionary, if necessary
1068              config = self.config["scoring"]
1069              config = config if isinstance(config, dict) else {"method": config}
1070  
1071              # Create configuration with custom columns, if necessary
1072              config = self.columns(config)
1073              return ScoringFactory.create(config, self.models)
1074  
1075          return None
1076  
1077      def hassparse(self):
1078          """
1079          Checks is this embeddings database has an associated sparse index.
1080  
1081          Returns:
1082              True if this embeddings has an associated scoring index
1083          """
1084  
1085          # Create scoring only if scoring is a sparse keyword/vector index
1086          return ScoringFactory.issparse(self.config.get("scoring"))
1087  
1088      def columns(self, config):
1089          """
1090          Adds custom text/object column information if it's provided.
1091  
1092          Args:
1093              config: input configuration
1094  
1095          Returns:
1096              config with column information added
1097          """
1098  
1099          # Add text/object columns if custom
1100          if "columns" in self.config:
1101              # Work on copy of configuration
1102              config = config.copy()
1103  
1104              # Copy columns to config
1105              config["columns"] = self.config["columns"]
1106  
1107          return config