/ src / python / txtai / scoring / tfidf.py
tfidf.py
  1  """
  2  TFIDF module
  3  """
  4  
  5  import math
  6  import os
  7  
  8  from collections import Counter
  9  from multiprocessing.pool import ThreadPool
 10  
 11  import numpy as np
 12  
 13  from ..pipeline import Tokenizer
 14  from ..serialize import Serializer
 15  
 16  from .base import Scoring
 17  from .normalize import Normalize
 18  from .terms import Terms
 19  
 20  
 21  class TFIDF(Scoring):
 22      """
 23      Term frequency-inverse document frequency (TF-IDF) scoring.
 24      """
 25  
 26      def __init__(self, config=None):
 27          super().__init__(config)
 28  
 29          # Document stats
 30          self.total = 0
 31          self.tokens = 0
 32          self.avgdl = 0
 33  
 34          # Word frequency
 35          self.docfreq = Counter()
 36          self.wordfreq = Counter()
 37          self.avgfreq = 0
 38  
 39          # IDF index
 40          self.idf = {}
 41          self.avgidf = 0
 42  
 43          # Tag boosting
 44          self.tags = Counter()
 45  
 46          # Tokenizer, lazily loaded as needed
 47          self.tokenizer = None
 48  
 49          # Term index
 50          self.terms = Terms(self.config["terms"], self.score, self.idf) if self.config.get("terms") else None
 51  
 52          # Document data
 53          self.documents = {} if self.config.get("content") else None
 54  
 55          # Normalize scores
 56          self.normalize = self.config.get("normalize")
 57          self.normalizer = Normalize(self.normalize) if self.normalize else None
 58          self.avgscore = None
 59  
 60      def insert(self, documents, index=None, checkpoint=None):
 61          # Insert documents, calculate word frequency, total tokens and total documents
 62          for uid, document, tags in documents:
 63              # Extract text, if necessary
 64              if isinstance(document, dict):
 65                  document = document.get(self.text, document.get(self.object))
 66  
 67              if document is not None:
 68                  # If index is passed, use indexid, otherwise use id
 69                  uid = index if index is not None else uid
 70  
 71                  # Add entry to index if the data type is accepted
 72                  if isinstance(document, (str, list)):
 73                      # Store content
 74                      if self.documents is not None:
 75                          self.documents[uid] = document
 76  
 77                      # Convert to tokens, if necessary
 78                      tokens = self.tokenize(document) if isinstance(document, str) else document
 79  
 80                      # Add tokens for id to term index
 81                      if self.terms is not None:
 82                          self.terms.insert(uid, tokens)
 83  
 84                      # Add tokens and tags to stats
 85                      self.addstats(tokens, tags)
 86  
 87                  # Increment index
 88                  index = index + 1 if index is not None else None
 89  
 90      def delete(self, ids):
 91          # Delete from terms index
 92          if self.terms:
 93              self.terms.delete(ids)
 94  
 95          # Delete content
 96          if self.documents:
 97              for uid in ids:
 98                  self.documents.pop(uid)
 99  
100      def index(self, documents=None):
101          # Call base method
102          super().index(documents)
103  
104          # Build index if tokens parsed
105          if self.wordfreq:
106              # Calculate total token frequency
107              self.tokens = sum(self.wordfreq.values())
108  
109              # Calculate average frequency per token
110              self.avgfreq = self.tokens / len(self.wordfreq.values())
111  
112              # Calculate average document length in tokens
113              self.avgdl = self.tokens / self.total
114  
115              # Compute IDF scores
116              idfs = self.computeidf(np.array(list(self.docfreq.values())))
117              for x, word in enumerate(self.docfreq):
118                  self.idf[word] = float(idfs[x])
119  
120              # Average IDF score per token
121              self.avgidf = float(np.mean(idfs))
122  
123              # Calculate average score across index
124              self.avgscore = self.score(self.avgfreq, self.avgidf, self.avgdl)
125  
126              # Filter for tags that appear in at least 1% of the documents
127              self.tags = Counter({tag: number for tag, number in self.tags.items() if number >= self.total * 0.005})
128  
129          # Index terms, if available
130          if self.terms:
131              self.terms.index()
132  
133      def weights(self, tokens):
134          # Document length
135          length = len(tokens)
136  
137          # Calculate token counts
138          freq = self.computefreq(tokens)
139          freq = np.array([freq[token] for token in tokens])
140  
141          # Get idf scores
142          idf = np.array([self.idf[token] if token in self.idf else self.avgidf for token in tokens])
143  
144          # Calculate score for each token, use as weight
145          weights = self.score(freq, idf, length).tolist()
146  
147          # Boost weights of tag tokens to match the largest weight in the list
148          if self.tags:
149              tags = {token: self.tags[token] for token in tokens if token in self.tags}
150              if tags:
151                  maxWeight = max(weights)
152                  maxTag = max(tags.values())
153  
154                  weights = [max(maxWeight * (tags[tokens[x]] / maxTag), weight) if tokens[x] in tags else weight for x, weight in enumerate(weights)]
155  
156          return weights
157  
158      def search(self, query, limit=3):
159          # Check if term index available
160          if self.terms:
161              # Escape query operators
162              query = self.terms.escape(query) if isinstance(query, str) else [self.terms.escape(q) for q in query]
163  
164              # Parse query into terms
165              query = self.tokenize(query) if isinstance(query, str) else query
166  
167              # Get topn term query matches
168              scores = self.terms.search(query, limit)
169  
170              # Normalize scores, if enabled
171              if self.normalizer and scores:
172                  scores = self.normalizer(scores, self.avgscore)
173  
174              # Add content, if available
175              return self.results(scores)
176  
177          return None
178  
179      def batchsearch(self, queries, limit=3, threads=True):
180          # Calculate number of threads using a thread per 25k records in index
181          threads = math.ceil(self.count() / 25000) if isinstance(threads, bool) and threads else int(threads)
182          threads = min(max(threads, 1), os.cpu_count())
183  
184          # This method is able to run as multiple threads due to a number of regex and numpy method calls that drop the GIL.
185          results = []
186          with ThreadPool(threads) as pool:
187              for result in pool.starmap(self.search, [(x, limit) for x in queries]):
188                  results.append(result)
189  
190          return results
191  
192      def count(self):
193          return self.terms.count() if self.terms else self.total
194  
195      def load(self, path):
196          # Load scoring
197          state = Serializer.load(path)
198  
199          # Convert to Counter instances
200          for key in ["docfreq", "wordfreq", "tags"]:
201              state[key] = Counter(state[key])
202  
203          # Convert documents to dict
204          state["documents"] = dict(state["documents"]) if state["documents"] else state["documents"]
205  
206          # Set parameters on this object
207          self.__dict__.update(state)
208  
209          # Recreate normalizer
210          self.normalizer = Normalize(self.normalize) if self.normalize else None
211  
212          # Load terms
213          if self.config.get("terms"):
214              self.terms = Terms(self.config["terms"], self.score, self.idf)
215              self.terms.load(path + ".terms")
216  
217      def save(self, path):
218          # Don't serialize following fields
219          skipfields = ("config", "terms", "tokenizer", "normalizer")
220  
221          # Get object state
222          state = {key: value for key, value in self.__dict__.items() if key not in skipfields}
223  
224          # Update documents to tuples
225          state["documents"] = list(state["documents"].items()) if state["documents"] else state["documents"]
226  
227          # Save scoring
228          Serializer.save(state, path)
229  
230          # Save terms
231          if self.terms:
232              self.terms.save(path + ".terms")
233  
234      def close(self):
235          if self.terms:
236              self.terms.close()
237  
238      def issparse(self):
239          return self.terms is not None
240  
241      def isnormalized(self):
242          return self.normalize
243  
244      def isbayes(self):
245          return self.normalizer is not None and self.normalizer.isbayes()
246  
247      def computefreq(self, tokens):
248          """
249          Computes token frequency. Used for token weighting.
250  
251          Args:
252              tokens: input tokens
253  
254          Returns:
255              {token: count}
256          """
257  
258          return Counter(tokens)
259  
260      def computeidf(self, freq):
261          """
262          Computes an idf score for word frequency.
263  
264          Args:
265              freq: word frequency
266  
267          Returns:
268              idf score
269          """
270  
271          return np.log((self.total + 1) / (freq + 1)) + 1
272  
273      # pylint: disable=W0613
274      def score(self, freq, idf, length):
275          """
276          Calculates a score for each token.
277  
278          Args:
279              freq: token frequency
280              idf: token idf score
281              length: total number of tokens in source document
282  
283          Returns:
284              token score
285          """
286  
287          return idf * np.sqrt(freq) * (1 / np.sqrt(length))
288  
289      def addstats(self, tokens, tags):
290          """
291          Add tokens and tags to stats.
292  
293          Args:
294              tokens: list of tokens
295              tags: list of tags
296          """
297  
298          # Total number of times token appears, count all tokens
299          self.wordfreq.update(tokens)
300  
301          # Total number of documents a token is in, count unique tokens
302          self.docfreq.update(set(tokens))
303  
304          # Get list of unique tags
305          if tags:
306              self.tags.update(tags.split())
307  
308          # Total document count
309          self.total += 1
310  
311      def tokenize(self, text):
312          """
313          Tokenizes text using default tokenizer.
314  
315          Args:
316              text: input text
317  
318          Returns:
319              tokens
320          """
321  
322          # Load tokenizer
323          if not self.tokenizer:
324              self.tokenizer = self.loadtokenizer()
325  
326          return self.tokenizer(text)
327  
328      def loadtokenizer(self):
329          """
330          Load default tokenizer.
331  
332          Returns:
333              tokenize method
334          """
335  
336          # Custom tokenizer settings
337          if self.config.get("tokenizer"):
338              return Tokenizer(**self.config.get("tokenizer"))
339  
340          # Terms index use a standard tokenizer
341          if self.config.get("terms"):
342              return Tokenizer()
343  
344          # Standard scoring index without a terms index uses backwards compatible static tokenize method
345          return Tokenizer.tokenize
346  
347      def results(self, scores):
348          """
349          Resolves a list of (id, score) with document content, if available. Otherwise, the original input is returned.
350  
351          Args:
352              scores: list of (id, score)
353  
354          Returns:
355              resolved results
356          """
357  
358          # Convert to Python values
359          scores = [(x, float(score)) for x, score in scores]
360  
361          if self.documents:
362              return [{"id": x, "text": self.documents[x], "score": score} for x, score in scores]
363  
364          return scores