tfidf.py
1 """ 2 TFIDF module 3 """ 4 5 import math 6 import os 7 8 from collections import Counter 9 from multiprocessing.pool import ThreadPool 10 11 import numpy as np 12 13 from ..pipeline import Tokenizer 14 from ..serialize import Serializer 15 16 from .base import Scoring 17 from .normalize import Normalize 18 from .terms import Terms 19 20 21 class TFIDF(Scoring): 22 """ 23 Term frequency-inverse document frequency (TF-IDF) scoring. 24 """ 25 26 def __init__(self, config=None): 27 super().__init__(config) 28 29 # Document stats 30 self.total = 0 31 self.tokens = 0 32 self.avgdl = 0 33 34 # Word frequency 35 self.docfreq = Counter() 36 self.wordfreq = Counter() 37 self.avgfreq = 0 38 39 # IDF index 40 self.idf = {} 41 self.avgidf = 0 42 43 # Tag boosting 44 self.tags = Counter() 45 46 # Tokenizer, lazily loaded as needed 47 self.tokenizer = None 48 49 # Term index 50 self.terms = Terms(self.config["terms"], self.score, self.idf) if self.config.get("terms") else None 51 52 # Document data 53 self.documents = {} if self.config.get("content") else None 54 55 # Normalize scores 56 self.normalize = self.config.get("normalize") 57 self.normalizer = Normalize(self.normalize) if self.normalize else None 58 self.avgscore = None 59 60 def insert(self, documents, index=None, checkpoint=None): 61 # Insert documents, calculate word frequency, total tokens and total documents 62 for uid, document, tags in documents: 63 # Extract text, if necessary 64 if isinstance(document, dict): 65 document = document.get(self.text, document.get(self.object)) 66 67 if document is not None: 68 # If index is passed, use indexid, otherwise use id 69 uid = index if index is not None else uid 70 71 # Add entry to index if the data type is accepted 72 if isinstance(document, (str, list)): 73 # Store content 74 if self.documents is not None: 75 self.documents[uid] = document 76 77 # Convert to tokens, if necessary 78 tokens = self.tokenize(document) if isinstance(document, str) else document 79 80 # Add tokens for id to term index 81 if self.terms is not None: 82 self.terms.insert(uid, tokens) 83 84 # Add tokens and tags to stats 85 self.addstats(tokens, tags) 86 87 # Increment index 88 index = index + 1 if index is not None else None 89 90 def delete(self, ids): 91 # Delete from terms index 92 if self.terms: 93 self.terms.delete(ids) 94 95 # Delete content 96 if self.documents: 97 for uid in ids: 98 self.documents.pop(uid) 99 100 def index(self, documents=None): 101 # Call base method 102 super().index(documents) 103 104 # Build index if tokens parsed 105 if self.wordfreq: 106 # Calculate total token frequency 107 self.tokens = sum(self.wordfreq.values()) 108 109 # Calculate average frequency per token 110 self.avgfreq = self.tokens / len(self.wordfreq.values()) 111 112 # Calculate average document length in tokens 113 self.avgdl = self.tokens / self.total 114 115 # Compute IDF scores 116 idfs = self.computeidf(np.array(list(self.docfreq.values()))) 117 for x, word in enumerate(self.docfreq): 118 self.idf[word] = float(idfs[x]) 119 120 # Average IDF score per token 121 self.avgidf = float(np.mean(idfs)) 122 123 # Calculate average score across index 124 self.avgscore = self.score(self.avgfreq, self.avgidf, self.avgdl) 125 126 # Filter for tags that appear in at least 1% of the documents 127 self.tags = Counter({tag: number for tag, number in self.tags.items() if number >= self.total * 0.005}) 128 129 # Index terms, if available 130 if self.terms: 131 self.terms.index() 132 133 def weights(self, tokens): 134 # Document length 
        # Document length
        length = len(tokens)

        # Calculate token counts
        freq = self.computefreq(tokens)
        freq = np.array([freq[token] for token in tokens])

        # Get idf scores
        idf = np.array([self.idf[token] if token in self.idf else self.avgidf for token in tokens])

        # Calculate score for each token, use as weight
        weights = self.score(freq, idf, length).tolist()

        # Boost weights of tag tokens to match the largest weight in the list
        if self.tags:
            tags = {token: self.tags[token] for token in tokens if token in self.tags}
            if tags:
                maxWeight = max(weights)
                maxTag = max(tags.values())

                weights = [max(maxWeight * (tags[tokens[x]] / maxTag), weight) if tokens[x] in tags else weight for x, weight in enumerate(weights)]

        return weights

    def search(self, query, limit=3):
        # Check if term index available
        if self.terms:
            # Escape query operators
            query = self.terms.escape(query) if isinstance(query, str) else [self.terms.escape(q) for q in query]

            # Parse query into terms
            query = self.tokenize(query) if isinstance(query, str) else query

            # Get topn term query matches
            scores = self.terms.search(query, limit)

            # Normalize scores, if enabled
            if self.normalizer and scores:
                scores = self.normalizer(scores, self.avgscore)

            # Add content, if available
            return self.results(scores)

        return None

    def batchsearch(self, queries, limit=3, threads=True):
        # Calculate the number of threads, using one thread per 25k records in the index
        threads = math.ceil(self.count() / 25000) if isinstance(threads, bool) and threads else int(threads)
        threads = min(max(threads, 1), os.cpu_count())

        # This method can run across multiple threads since a number of the underlying regex and numpy calls drop the GIL.
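        # Run the queries across a shared thread pool. Workers share this object's in-memory state
        # and starmap preserves input order, so results line up with the input queries.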
        results = []
        with ThreadPool(threads) as pool:
            for result in pool.starmap(self.search, [(x, limit) for x in queries]):
                results.append(result)

        return results

    def count(self):
        return self.terms.count() if self.terms else self.total

    def load(self, path):
        # Load scoring
        state = Serializer.load(path)

        # Convert to Counter instances
        for key in ["docfreq", "wordfreq", "tags"]:
            state[key] = Counter(state[key])

        # Convert documents to dict
        state["documents"] = dict(state["documents"]) if state["documents"] else state["documents"]

        # Set parameters on this object
        self.__dict__.update(state)

        # Recreate normalizer
        self.normalizer = Normalize(self.normalize) if self.normalize else None

        # Load terms
        if self.config.get("terms"):
            self.terms = Terms(self.config["terms"], self.score, self.idf)
            self.terms.load(path + ".terms")

    def save(self, path):
        # Don't serialize the following fields
        skipfields = ("config", "terms", "tokenizer", "normalizer")

        # Get object state
        state = {key: value for key, value in self.__dict__.items() if key not in skipfields}

        # Convert documents to tuples
        state["documents"] = list(state["documents"].items()) if state["documents"] else state["documents"]

        # Save scoring
        Serializer.save(state, path)

        # Save terms
        if self.terms:
            self.terms.save(path + ".terms")

    def close(self):
        if self.terms:
            self.terms.close()

    def issparse(self):
        return self.terms is not None

    def isnormalized(self):
        return self.normalize

    def isbayes(self):
        return self.normalizer is not None and self.normalizer.isbayes()

    def computefreq(self, tokens):
        """
        Computes token frequency. Used for token weighting.

        Args:
            tokens: input tokens

        Returns:
            {token: count}
        """

        return Counter(tokens)

    def computeidf(self, freq):
        """
        Computes an idf score for word frequency.

        Args:
            freq: word frequency

        Returns:
            idf score
        """

        return np.log((self.total + 1) / (freq + 1)) + 1

    # pylint: disable=W0613
    def score(self, freq, idf, length):
        """
        Calculates a score for each token.

        Args:
            freq: token frequency
            idf: token idf score
            length: total number of tokens in source document

        Returns:
            token score
        """

        return idf * np.sqrt(freq) * (1 / np.sqrt(length))
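
    # Worked example of the formula above: with freq=4, idf=2.0 and length=100, the score is
    # 2.0 * sqrt(4) * (1 / sqrt(100)) = 2.0 * 2 * 0.1 = 0.4. Weight grows sublinearly with token
    # frequency and is dampened by document length.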

    def addstats(self, tokens, tags):
        """
        Add tokens and tags to stats.

        Args:
            tokens: list of tokens
            tags: list of tags
        """

        # Total number of times token appears, count all tokens
        self.wordfreq.update(tokens)

        # Total number of documents a token is in, count unique tokens
        self.docfreq.update(set(tokens))

        # Get list of unique tags
        if tags:
            self.tags.update(tags.split())

        # Total document count
        self.total += 1

    def tokenize(self, text):
        """
        Tokenizes text using default tokenizer.

        Args:
            text: input text

        Returns:
            tokens
        """

        # Load tokenizer
        if not self.tokenizer:
            self.tokenizer = self.loadtokenizer()

        return self.tokenizer(text)

    def loadtokenizer(self):
        """
        Load default tokenizer.

        Returns:
            tokenize method
        """

        # Custom tokenizer settings
        if self.config.get("tokenizer"):
            return Tokenizer(**self.config.get("tokenizer"))

        # A terms index uses a standard tokenizer
        if self.config.get("terms"):
            return Tokenizer()

        # A standard scoring index without a terms index uses the backwards-compatible static tokenize method
        return Tokenizer.tokenize

    def results(self, scores):
        """
        Resolves a list of (id, score) with document content, if available. Otherwise, the original input is returned.

        Args:
            scores: list of (id, score)

        Returns:
            resolved results
        """

        # Convert to Python values
        scores = [(x, float(score)) for x, score in scores]

        if self.documents:
            return [{"id": x, "text": self.documents[x], "score": score} for x, score in scores]

        return scores
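
# ----------------------------------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module). Assumes the base Scoring.index call inserts any
# documents passed to it and that Terms accepts a truthy "terms" config value; search requires a terms
# index and returns None without one.
#
#   scoring = TFIDF({"terms": True, "content": True})
#   scoring.index([(0, "a day in the life", None), (1, "life is what happens", None)])
#
#   # With content storage enabled, results resolve to [{"id": ..., "text": ..., "score": ...}]
#   print(scoring.search("life", limit=1))
# ----------------------------------------------------------------------------------------------------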