/ src / python / txtai / scoring / sparse.py
sparse.py
  1  """
  2  Sparse module
  3  """
  4  
  5  from queue import Queue
  6  from threading import Thread
  7  
  8  from ..ann import SparseANNFactory
  9  from ..vectors import SparseVectorsFactory
 10  
 11  from .base import Scoring
 12  from .normalize import Normalize
 13  
 14  
 15  class Sparse(Scoring):
 16      """
 17      Sparse vector scoring.
 18      """
 19  
 20      # End of stream message
 21      COMPLETE = 1
 22  
 23      def __init__(self, config=None, models=None):
 24          super().__init__(config)
 25  
 26          # Vector configuration
 27          mapping = {"vectormethod": "method", "vectornormalize": "normalize"}
 28          config = {k: v for k, v in config.items() if k not in mapping.values()}
 29          for k, v in mapping.items():
 30              if k in config:
 31                  config[v] = config[k]
 32  
 33          # Load the SparseVectors model
 34          self.model = SparseVectorsFactory.create(config, models)
 35  
 36          # Normalize search outputs if vectors are not normalized already
 37          # Supports: True (default linear, scale=30.0), float (custom scale),
 38          #           "bb25"/"bayes" (Bayesian sigmoid calibration), False (disabled)
 39          self.isnormalize = self.config.get("normalize", True)
 40  
 41          # Create Bayesian normalizer when a Bayesian method is configured
 42          self.normalizer = None
 43          if isinstance(self.isnormalize, (str, dict)):
 44              normalizer = Normalize(self.isnormalize)
 45              if normalizer.isbayes():
 46                  self.normalizer = normalizer
 47  
 48          # Sparse ANN
 49          self.ann = None
 50  
 51          # Encoding processing parameters
 52          self.batch = self.config.get("batch", 1024)
 53          self.thread, self.queue, self.data = None, None, None
 54  
 55      def insert(self, documents, index=None, checkpoint=None):
 56          # Start processing thread, if necessary
 57          self.start(checkpoint)
 58  
 59          data = []
 60          for uid, document, tags in documents:
 61              # Extract text, if necessary
 62              if isinstance(document, dict):
 63                  document = document.get(self.text, document.get(self.object))
 64  
 65              if document is not None:
 66                  # Add data
 67                  data.append((uid, " ".join(document) if isinstance(document, list) else document, tags))
 68  
 69          # Add batch of data
 70          self.queue.put(data)
 71  
 72      def delete(self, ids):
 73          self.ann.delete(ids)
 74  
 75      def index(self, documents=None):
 76          # Insert documents, if provided
 77          if documents:
 78              self.insert(documents)
 79  
 80          # Create ANN, if there is pending data
 81          embeddings = self.stop()
 82          if embeddings is not None:
 83              self.ann = SparseANNFactory.create(self.config)
 84              self.ann.index(embeddings)
 85  
 86      def upsert(self, documents=None):
 87          # Insert documents, if provided
 88          if documents:
 89              self.insert(documents)
 90  
 91          # Check for existing index and pending data
 92          if self.ann:
 93              embeddings = self.stop()
 94              if embeddings is not None:
 95                  self.ann.append(embeddings)
 96          else:
 97              self.index()
 98  
 99      def weights(self, tokens):
100          # Not supported
101          return None
102  
103      def search(self, query, limit=3):
104          return self.batchsearch([query], limit)[0]
105  
106      def batchsearch(self, queries, limit=3, threads=True):
107          # Convert queries to embedding vectors
108          embeddings = self.model.batchtransform((None, query, None) for query in queries)
109  
110          # Run ANN search
111          scores = self.ann.search(embeddings, limit)
112  
113          # Normalize scores if normalization IS enabled AND vector normalization IS NOT enabled
114          return self.normalize(embeddings, scores) if self.isnormalize and not self.model.isnormalize else scores
115  
116      def count(self):
117          return self.ann.count()
118  
119      def load(self, path):
120          self.ann = SparseANNFactory.create(self.config)
121          self.ann.load(path)
122  
123      def save(self, path):
124          # Save Sparse ANN
125          if self.ann:
126              self.ann.save(path)
127  
128      def close(self):
129          # Close Sparse ANN
130          if self.ann:
131              self.ann.close()
132  
133          # Clear parameters
134          self.model, self.ann, self.thread, self.queue = None, None, None, None
135  
136      def issparse(self):
137          return True
138  
139      def isnormalized(self):
140          return self.isnormalize or self.model.isnormalize
141  
142      def isbayes(self):
143          return self.normalizer is not None
144  
145      def start(self, checkpoint):
146          """
147          Starts an encoding processing thread.
148  
149          Args:
150              checkpoint: checkpoint directory
151          """
152  
153          if not self.thread:
154              self.queue = Queue(5)
155              self.thread = Thread(target=self.encode, args=(checkpoint,))
156              self.thread.start()
157  
158      def stop(self):
159          """
160          Stops an encoding processing thread. Return processed results.
161  
162          Returns:
163              results
164          """
165  
166          results = None
167          if self.thread:
168              # Send EOS message
169              self.queue.put(Sparse.COMPLETE)
170  
171              self.thread.join()
172              self.thread, self.queue = None, None
173  
174              # Get return value
175              results = self.data
176              self.data = None
177  
178          return results
179  
180      def encode(self, checkpoint):
181          """
182          Encodes streaming data.
183  
184          Args:
185              checkpoint: checkpoint directory
186          """
187  
188          # Streaming encoding of data
189          _, dimensions, self.data = self.model.vectors(self.stream(), self.batch, checkpoint)
190  
191          # Save number of dimensions
192          self.config["dimensions"] = dimensions
193  
194      def stream(self):
195          """
196          Streams data from an input queue until end of stream message received.
197          """
198  
199          batch = self.queue.get()
200          while batch != Sparse.COMPLETE:
201              yield from batch
202              batch = self.queue.get()
203  
204      def normalize(self, queries, scores):
205          """
206          Normalize query results.
207  
208          When Bayesian normalization is configured, applies sigmoid calibration
209          with per-query adaptive parameters (beta=median, alpha=1/std) to produce
210          calibrated probabilities in [0, 1]. Otherwise, applies linear normalization
211          using the max query score.
212  
213          Args:
214              queries: query vectors
215              scores: query results
216  
217          Returns:
218              normalized query results
219          """
220  
221          # Bayesian sigmoid calibration
222          if self.normalizer:
223              return [self.normalizer.bayes(result) if result else [] for result in scores]
224  
225          # Default linear normalization
226          scale = 30.0 if isinstance(self.isnormalize, bool) else self.isnormalize
227  
228          # Normalize scores using max scores
229          maxscores = self.model.dot(queries, queries)
230  
231          # Normalize results and return
232          results = []
233          for x, result in enumerate(scores):
234              maxscore = max(maxscores[x][x] / scale, scale)
235              maxscore = max(maxscore, result[0][1]) if result else maxscore
236  
237              results.append([(uid, score / maxscore) for uid, score in result])
238  
239          return results