benchmarks.py
1 """ 2 Runs benchmark evaluations with the BEIR dataset. 3 4 Install txtai and the following dependencies to run: 5 pip install txtai pytrec_eval rank-bm25 bm25s elasticsearch psutil 6 """ 7 8 import argparse 9 import csv 10 import json 11 import os 12 import pickle 13 import sqlite3 14 import time 15 16 import psutil 17 import yaml 18 19 import numpy as np 20 21 from bm25s import BM25 as BM25Sparse 22 from elasticsearch import Elasticsearch 23 from elasticsearch.helpers import bulk 24 from pytrec_eval import RelevanceEvaluator 25 from rank_bm25 import BM25Okapi 26 from tqdm.auto import tqdm 27 28 from txtai.embeddings import Embeddings 29 from txtai.pipeline import LLM, RAG, Similarity, Tokenizer 30 from txtai.scoring import ScoringFactory 31 32 33 class Index: 34 """ 35 Base index definition. Defines methods to index and search a dataset. 36 """ 37 38 def __init__(self, path, config, output, refresh): 39 """ 40 Creates a new index. 41 42 Args: 43 path: path to dataset 44 config: path to config file 45 output: path to store index 46 refresh: overwrites existing index if True, otherwise existing index is loaded 47 """ 48 49 self.path = path 50 self.config = config 51 self.output = output 52 self.refresh = refresh 53 54 # Build and save index 55 self.backend = self.index() 56 57 def __call__(self, limit, filterscores=True): 58 """ 59 Main evaluation logic. Loads an index, runs the dataset queries and returns the results. 60 61 Args: 62 limit: maximum results 63 filterscores: if exact matches should be filtered out 64 65 Returns: 66 search results 67 """ 68 69 uids, queries = self.load() 70 71 # Run queries in batches 72 offset, results = 0, {} 73 for batch in self.batch(queries, 256): 74 for i, r in enumerate(self.search(batch, limit + 1)): 75 # Get result as list of (id, score) tuples 76 r = list(r) 77 r = [(x["id"], x["score"]) for x in r] if r and isinstance(r[0], dict) else r 78 79 if filterscores: 80 r = [(uid, score) for uid, score in r if uid != uids[offset + i]][:limit] 81 82 results[uids[offset + i]] = dict(r) 83 84 # Increment offset 85 offset += len(batch) 86 87 return results 88 89 def search(self, queries, limit): 90 """ 91 Runs a search for a set of queries. 92 93 Args: 94 queries: list of queries to run 95 limit: maximum results 96 97 Returns: 98 search results 99 """ 100 101 return self.backend.batchsearch(queries, limit) 102 103 def index(self): 104 """ 105 Indexes a dataset. 106 """ 107 108 raise NotImplementedError 109 110 def rows(self): 111 """ 112 Iterates over the dataset yielding a row at a time for indexing. 113 """ 114 115 # Data file 116 path = f"{self.path}/corpus.jsonl" 117 118 # Get total count 119 with open(path, encoding="utf-8") as f: 120 total = sum(1 for _ in f) 121 122 # Yield data 123 with open(path, encoding="utf-8") as f: 124 for line in tqdm(f, total=total): 125 row = json.loads(line) 126 text = f'{row["title"]}. {row["text"]}' if row.get("title") else row["text"] 127 if text: 128 yield (row["_id"], text, None) 129 130 def load(self): 131 """ 132 Loads queries for the dataset. Returns a list of expected result ids and input queries. 133 134 Returns: 135 (result ids, input queries) 136 """ 137 138 with open(f"{self.path}/queries.jsonl", encoding="utf-8") as f: 139 data = [json.loads(query) for query in f] 140 uids, queries = [x["_id"] for x in data], [x["text"] for x in data] 141 142 return uids, queries 143 144 def batch(self, data, size): 145 """ 146 Splits data into equal sized batches. 147 148 Args: 149 data: input data 150 size: batch size 151 152 Returns: 153 data split into equal size batches 154 """ 155 156 return [data[x : x + size] for x in range(0, len(data), size)] 157 158 def readconfig(self, key, default): 159 """ 160 Reads configuration from a config file. Returns default configuration 161 if config file is not found or config key isn't present. 162 163 Args: 164 key: configuration key to lookup 165 default: default configuration 166 167 Returns: 168 config if found, otherwise returns default config 169 """ 170 171 if self.config and os.path.exists(self.config): 172 # Read configuration 173 with open(self.config, "r", encoding="utf-8") as f: 174 # Check for config 175 config = yaml.safe_load(f) 176 if key in config: 177 return config[key] 178 179 return default 180 181 182 class Embed(Index): 183 """ 184 Embeddings index using txtai. 185 """ 186 187 def index(self): 188 if os.path.exists(self.output) and not self.refresh: 189 embeddings = Embeddings() 190 embeddings.load(self.output) 191 else: 192 # Read configuration 193 config = self.readconfig("embeddings", {"batch": 8192, "encodebatch": 128, "faiss": {"quantize": True, "sample": 0.05}}) 194 195 # Build index 196 embeddings = Embeddings(config) 197 embeddings.index(self.rows()) 198 embeddings.save(self.output) 199 200 return embeddings 201 202 203 class Hybrid(Index): 204 """ 205 Hybrid embeddings + BM25 index using txtai. 206 """ 207 208 def index(self): 209 if os.path.exists(self.output) and not self.refresh: 210 embeddings = Embeddings() 211 embeddings.load(self.output) 212 else: 213 # Read configuration 214 config = self.readconfig( 215 "hybrid", 216 { 217 "batch": 8192, 218 "encodebatch": 128, 219 "faiss": {"quantize": True, "sample": 0.05}, 220 "scoring": {"method": "bm25", "terms": True, "normalize": True}, 221 }, 222 ) 223 224 # Build index 225 embeddings = Embeddings(config) 226 embeddings.index(self.rows()) 227 embeddings.save(self.output) 228 229 return embeddings 230 231 232 class RetrievalAugmentedGeneration(Embed): 233 """ 234 Retrieval augmented generation (RAG) using txtai. 235 """ 236 237 def __init__(self, path, config, output, refresh): 238 # Parent logic 239 super().__init__(path, config, output, refresh) 240 241 # Read LLM configuration 242 llm = self.readconfig("llm", {}) 243 244 # Read RAG configuration 245 rag = self.readconfig("rag", {}) 246 247 # Load RAG pipeline 248 self.rag = RAG(self.backend, LLM(**llm), output="reference", **rag) 249 250 def search(self, queries, limit): 251 # Set context window size to limit and run 252 self.rag.context = limit 253 return [[(x["reference"], 1)] for x in self.rag(queries, maxlength=4096)] 254 255 256 class Score(Index): 257 """ 258 BM25 index using txtai. 259 """ 260 261 def index(self): 262 # Read configuration 263 config = self.readconfig("scoring", {"method": "bm25", "terms": True}) 264 265 # Create scoring instance 266 scoring = ScoringFactory.create(config) 267 268 output = os.path.join(self.output, "scoring") 269 if os.path.exists(output) and not self.refresh: 270 scoring.load(output) 271 else: 272 scoring.index(self.rows()) 273 scoring.save(output) 274 275 return scoring 276 277 278 class Similar(Index): 279 """ 280 Search data using a similarity pipeline. 281 """ 282 283 def index(self): 284 # Load similarity pipeline 285 model = Similarity(**self.readconfig("similar", {})) 286 287 # Get datasets 288 data = list(self.rows()) 289 ids = [x[0] for x in data] 290 texts = [x[1] for x in data] 291 292 # Encode texts 293 data = model.encode(texts, "data") 294 295 return (ids, data, model) 296 297 def search(self, queries, limit): 298 # Unpack backend 299 ids, data, model = self.backend 300 301 # Run model inference 302 results = [] 303 for result in model(queries, data, limit=limit): 304 results.append([(ids[x], score) for x, score in result]) 305 306 return results 307 308 309 class Rerank(Embed): 310 """ 311 Embeddings index using txtai combined with a similarity pipeline 312 """ 313 314 def index(self): 315 # Build embeddings index 316 embeddings = super().index() 317 318 # Combine similar pipeline with embeddings 319 model = Similar(self.path, self.config, self.output, self.refresh) 320 return model.index() + (embeddings,) 321 322 def search(self, queries, limit): 323 # Unpack backend 324 ids, data, model, embeddings = self.backend 325 326 # Run initial query 327 indices = [] 328 for r in embeddings.batchsearch(queries, limit * 10): 329 indices.append({x: ids.index(uid) for x, (uid, _) in enumerate(r)}) 330 331 # Run model inference 332 results = [] 333 for x, query in enumerate(queries): 334 queue = data[list(indices[x].values())] 335 if len(queue) > 0: 336 result = model(query, queue, limit=limit) 337 results.append([(ids[indices[x][i]], score) for i, score in result]) 338 339 return results 340 341 342 class RankBM25(Index): 343 """ 344 BM25 index using rank-bm25. 345 """ 346 347 def search(self, queries, limit): 348 ids, backend = self.backend 349 tokenizer, results = Tokenizer(), [] 350 for query in queries: 351 scores = backend.get_scores(tokenizer(query)) 352 topn = np.argsort(scores)[::-1][:limit] 353 results.append([(ids[x], scores[x]) for x in topn]) 354 355 return results 356 357 def index(self): 358 output = os.path.join(self.output, "rank") 359 if os.path.exists(output) and not self.refresh: 360 with open(output, "rb") as f: 361 ids, model = pickle.load(f) 362 else: 363 # Tokenize data 364 tokenizer, data = Tokenizer(), [] 365 for uid, text, _ in self.rows(): 366 data.append((uid, tokenizer(text))) 367 368 ids = [uid for uid, _ in data] 369 model = BM25Okapi([text for _, text in data]) 370 371 # Save model 372 with open(output, "wb") as out: 373 pickle.dump(model, out) 374 375 return ids, model 376 377 378 class BM25S(Index): 379 """ 380 BM25 as implemented by bm25s 381 """ 382 383 def __init__(self, path, config, output, refresh): 384 # Corpus ids 385 self.ids = None 386 387 # Parent logic 388 super().__init__(path, config, output, refresh) 389 390 def search(self, queries, limit): 391 tokenizer = Tokenizer() 392 results, scores = self.backend.retrieve([tokenizer(x) for x in queries], corpus=self.ids, k=limit) 393 394 # List of queries => list of matches (id, score) 395 x = [] 396 for a, b in zip(results, scores): 397 x.append([(str(c), float(d)) for c, d in zip(a, b)]) 398 399 return x 400 401 def index(self): 402 tokenizer = Tokenizer() 403 ids, texts = [], [] 404 405 for uid, text, _ in self.rows(): 406 ids.append(uid) 407 texts.append(text) 408 409 self.ids = ids 410 411 if os.path.exists(self.output) and not self.refresh: 412 model = BM25Sparse.load(self.output) 413 else: 414 model = BM25Sparse(method="lucene", k1=1.2, b=0.75) 415 model.index([tokenizer(x) for x in texts], leave_progress=False) 416 model.save(self.output) 417 418 return model 419 420 421 class SQLiteFTS(Index): 422 """ 423 BM25 index using SQLite's FTS extension. 424 """ 425 426 def search(self, queries, limit): 427 tokenizer, results = Tokenizer(), [] 428 for query in queries: 429 query = tokenizer(query) 430 query = " OR ".join([f'"{q}"' for q in query]) 431 432 self.backend.execute( 433 f"SELECT id, bm25(textindex) * -1 score FROM textindex WHERE text MATCH ? ORDER BY bm25(textindex) LIMIT {limit}", [query] 434 ) 435 436 results.append(list(self.backend)) 437 438 return results 439 440 def index(self): 441 if os.path.exists(self.output) and not self.refresh: 442 # Load existing database 443 connection = sqlite3.connect(self.output) 444 else: 445 # Delete existing database 446 if os.path.exists(self.output): 447 os.remove(self.output) 448 449 # Create new database 450 connection = sqlite3.connect(self.output) 451 452 # Tokenize data 453 tokenizer, data = Tokenizer(), [] 454 for uid, text, _ in self.rows(): 455 data.append((uid, " ".join(tokenizer(text)))) 456 457 # Create table 458 connection.execute("CREATE VIRTUAL TABLE textindex using fts5(id, text)") 459 460 # Load data and build index 461 connection.executemany("INSERT INTO textindex VALUES (?, ?)", data) 462 463 connection.commit() 464 465 return connection.cursor() 466 467 468 class Elastic(Index): 469 """ 470 BM25 index using Elasticsearch. 471 """ 472 473 def search(self, queries, limit): 474 # Generate bulk queries 475 request = [] 476 for query in queries: 477 req_head = {"index": "textindex", "search_type": "dfs_query_then_fetch"} 478 req_body = { 479 "_source": False, 480 "query": {"multi_match": {"query": query, "type": "best_fields", "fields": ["text"], "tie_breaker": 0.5}}, 481 "size": limit, 482 } 483 request.extend([req_head, req_body]) 484 485 # Run ES query 486 response = self.backend.msearch(body=request, request_timeout=600) 487 488 # Read responses 489 results = [] 490 for resp in response["responses"]: 491 result = resp["hits"]["hits"] 492 results.append([(r["_id"], r["_score"]) for r in result]) 493 494 return results 495 496 def index(self): 497 es = Elasticsearch("http://localhost:9200") 498 499 # Delete existing index 500 # pylint: disable=W0702 501 try: 502 es.indices.delete(index="textindex") 503 except: 504 pass 505 506 bulk(es, ({"_index": "textindex", "_id": uid, "text": text} for uid, text, _ in self.rows())) 507 es.indices.refresh(index="textindex") 508 509 return es 510 511 512 def relevance(path): 513 """ 514 Loads relevance data for evaluation. 515 516 Args: 517 path: path to dataset test file 518 519 Returns: 520 relevance data 521 """ 522 523 rel = {} 524 with open(f"{path}/qrels/test.tsv", encoding="utf-8") as f: 525 reader = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_MINIMAL) 526 next(reader) 527 528 for row in reader: 529 queryid, corpusid, score = row[0], row[1], int(row[2]) 530 if queryid not in rel: 531 rel[queryid] = {corpusid: score} 532 else: 533 rel[queryid][corpusid] = score 534 535 return rel 536 537 538 def create(method, path, config, output, refresh): 539 """ 540 Creates a new index. 541 542 Args: 543 method: indexing method 544 path: path to dataset 545 config: path to config file 546 output: path to store index 547 refresh: overwrites existing index if True, otherwise existing index is loaded 548 549 Returns: 550 Index 551 """ 552 553 if method == "hybrid": 554 return Hybrid(path, config, output, refresh) 555 if method == "rag": 556 return RetrievalAugmentedGeneration(path, config, output, refresh) 557 if method == "scoring": 558 return Score(path, config, output, refresh) 559 if method == "rank": 560 return RankBM25(path, config, output, refresh) 561 if method == "bm25s": 562 return BM25S(path, config, output, refresh) 563 if method == "sqlite": 564 return SQLiteFTS(path, config, output, refresh) 565 if method == "es": 566 return Elastic(path, config, output, refresh) 567 if method == "similar": 568 return Similar(path, config, output, refresh) 569 if method == "rerank": 570 return Rerank(path, config, output, refresh) 571 572 # Default 573 return Embed(path, config, output, refresh) 574 575 576 def compute(results): 577 """ 578 Computes metrics using the results from an evaluation run. 579 580 Args: 581 results: evaluation results 582 583 Returns: 584 metrics 585 """ 586 587 metrics = {} 588 for r in results: 589 for metric in results[r]: 590 if metric not in metrics: 591 metrics[metric] = [] 592 593 metrics[metric].append(results[r][metric]) 594 595 return {metric: round(np.mean(values), 5) for metric, values in metrics.items()} 596 597 598 def evaluate(methods, path, args): 599 """ 600 Runs an evaluation. 601 602 Args: 603 methods: list of indexing methods to test 604 path: path to dataset 605 args: command line arguments 606 607 Returns: 608 {calculated performance metrics} 609 """ 610 611 print(f"------ {os.path.basename(path)} ------") 612 613 # Performance stats 614 performance = {} 615 616 # Calculate stats for each model type 617 topk = args.topk 618 evaluator = RelevanceEvaluator(relevance(path), {f"ndcg_cut.{topk}", f"map_cut.{topk}", f"recall.{topk}", f"P.{topk}"}) 619 for method in methods: 620 # Stats for this source 621 stats = {} 622 performance[method] = stats 623 624 # Create index and get results 625 start = time.time() 626 output = args.output if args.output else f"{path}/{method}" 627 index = create(method, path, args.config, output, args.refresh) 628 629 # Add indexing metrics 630 stats["index"] = round(time.time() - start, 2) 631 stats["memory"] = int(psutil.Process().memory_info().rss / (1024 * 1024)) 632 stats["disk"] = int(sum(d.stat().st_size for d in os.scandir(output) if d.is_file()) / 1024) if os.path.isdir(output) else 0 633 634 print("INDEX TIME =", time.time() - start) 635 print(f"MEMORY USAGE = {stats['memory']} MB") 636 print(f"DISK USAGE = {stats['disk']} KB") 637 638 start = time.time() 639 results = index(topk) 640 641 # Add search metrics 642 stats["search"] = round(time.time() - start, 2) 643 print("SEARCH TIME =", time.time() - start) 644 645 # Calculate stats 646 metrics = compute(evaluator.evaluate(results)) 647 648 # Add accuracy metrics 649 for stat in [f"ndcg_cut_{topk}", f"map_cut_{topk}", f"recall_{topk}", f"P_{topk}"]: 650 stats[stat] = metrics[stat] 651 652 # Print model stats 653 print(f"------ {method} ------") 654 print(f"NDCG@{topk} =", metrics[f"ndcg_cut_{topk}"]) 655 print(f"MAP@{topk} =", metrics[f"map_cut_{topk}"]) 656 print(f"Recall@{topk} =", metrics[f"recall_{topk}"]) 657 print(f"P@{topk} =", metrics[f"P_{topk}"]) 658 659 print() 660 return performance 661 662 663 def benchmarks(args): 664 """ 665 Main benchmark execution method. 666 667 Args: 668 args: command line arguments 669 """ 670 671 # Directory where BEIR datasets are stored 672 directory = args.directory if args.directory else "beir" 673 674 if args.sources and args.methods: 675 sources, methods = args.sources.split(","), args.methods.split(",") 676 mode = "a" 677 else: 678 # Default sources and methods 679 sources = [ 680 "trec-covid", 681 "nfcorpus", 682 "nq", 683 "hotpotqa", 684 "fiqa", 685 "arguana", 686 "webis-touche2020", 687 "quora", 688 "dbpedia-entity", 689 "scidocs", 690 "fever", 691 "climate-fever", 692 "scifact", 693 ] 694 methods = ["embed", "hybrid", "rag", "scoring", "rank", "bm25s", "sqlite", "es", "similar", "rerank"] 695 mode = "w" 696 697 # Run and save benchmarks 698 with open("benchmarks.json", mode, encoding="utf-8") as f: 699 for source in sources: 700 # Run evaluations 701 results = evaluate(methods, f"{directory}/{source}", args) 702 703 # Save as JSON lines output 704 for method, stats in results.items(): 705 stats["source"] = source 706 stats["method"] = method 707 stats["name"] = args.name if args.name else method 708 709 json.dump(stats, f) 710 f.write("\n") 711 712 713 if __name__ == "__main__": 714 # Command line parser 715 parser = argparse.ArgumentParser(description="Benchmarks") 716 parser.add_argument("-c", "--config", help="path to config file", metavar="CONFIG") 717 parser.add_argument("-d", "--directory", help="root directory path with datasets", metavar="DIRECTORY") 718 parser.add_argument("-m", "--methods", help="comma separated list of methods", metavar="METHODS") 719 parser.add_argument("-n", "--name", help="name to assign to this run, defaults to method name", metavar="NAME") 720 parser.add_argument("-o", "--output", help="index output directory path", metavar="OUTPUT") 721 parser.add_argument( 722 "-r", 723 "--refresh", 724 help="refreshes index if set, otherwise uses existing index if available", 725 action="store_true", 726 ) 727 parser.add_argument("-s", "--sources", help="comma separated list of data sources", metavar="SOURCES") 728 parser.add_argument("-t", "--topk", help="top k results to use for the evaluation", metavar="TOPK", type=int, default=10) 729 730 # Calculate benchmarks 731 benchmarks(parser.parse_args())