  1  """
  2  Runs benchmark evaluations with the BEIR dataset.
  3  
  4  Install txtai and the following dependencies to run:
  5      pip install txtai pytrec_eval rank-bm25 bm25s elasticsearch psutil
  6  """

import argparse
import csv
import json
import os
import pickle
import sqlite3
import time

import psutil
import yaml

import numpy as np

from bm25s import BM25 as BM25Sparse
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk
from pytrec_eval import RelevanceEvaluator
from rank_bm25 import BM25Okapi
from tqdm.auto import tqdm

from txtai.embeddings import Embeddings
from txtai.pipeline import LLM, RAG, Similarity, Tokenizer
from txtai.scoring import ScoringFactory


class Index:
    """
    Base index definition. Defines methods to index and search a dataset.
    """

    def __init__(self, path, config, output, refresh):
        """
        Creates a new index.

        Args:
            path: path to dataset
            config: path to config file
            output: path to store index
            refresh: overwrites existing index if True, otherwise existing index is loaded
        """

        self.path = path
        self.config = config
        self.output = output
        self.refresh = refresh

        # Build and save index
        self.backend = self.index()

    def __call__(self, limit, filterscores=True):
        """
        Main evaluation logic. Loads an index, runs the dataset queries and returns the results.

        Args:
            limit: maximum results
            filterscores: if exact matches should be filtered out

        Returns:
            search results
        """

        uids, queries = self.load()

        # Run queries in batches
        offset, results = 0, {}
        for batch in self.batch(queries, 256):
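            # Request one extra result per query so an exact id match (query id == document id)
            # can be dropped below when filterscores is True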
            for i, r in enumerate(self.search(batch, limit + 1)):
                # Get result as list of (id, score) tuples
                r = list(r)
                r = [(x["id"], x["score"]) for x in r] if r and isinstance(r[0], dict) else r

                if filterscores:
                    r = [(uid, score) for uid, score in r if uid != uids[offset + i]][:limit]

                results[uids[offset + i]] = dict(r)

            # Increment offset
            offset += len(batch)

        return results

    def search(self, queries, limit):
        """
        Runs a search for a set of queries.

        Args:
            queries: list of queries to run
            limit: maximum results

        Returns:
            search results
        """

        return self.backend.batchsearch(queries, limit)

    def index(self):
        """
        Indexes a dataset.
        """

        raise NotImplementedError

    def rows(self):
        """
        Iterates over the dataset yielding a row at a time for indexing.
        """

        # Data file
        path = f"{self.path}/corpus.jsonl"

        # Get total count
        with open(path, encoding="utf-8") as f:
            total = sum(1 for _ in f)

        # Yield data
        with open(path, encoding="utf-8") as f:
            for line in tqdm(f, total=total):
                row = json.loads(line)
                text = f'{row["title"]}. {row["text"]}' if row.get("title") else row["text"]
                if text:
                    yield (row["_id"], text, None)

    def load(self):
        """
        Loads queries for the dataset. Returns a list of expected result ids and input queries.

        Returns:
            (result ids, input queries)
        """

        with open(f"{self.path}/queries.jsonl", encoding="utf-8") as f:
            data = [json.loads(query) for query in f]
            uids, queries = [x["_id"] for x in data], [x["text"] for x in data]

        return uids, queries

    def batch(self, data, size):
        """
        Splits data into equal sized batches.

        Args:
            data: input data
            size: batch size

        Returns:
            data split into equal size batches
        """

        return [data[x : x + size] for x in range(0, len(data), size)]

    def readconfig(self, key, default):
        """
        Reads configuration from a config file. Returns default configuration
        if config file is not found or config key isn't present.

        Args:
            key: configuration key to lookup
            default: default configuration

        Returns:
            config if found, otherwise returns default config
        """

        if self.config and os.path.exists(self.config):
            # Read configuration
            with open(self.config, "r", encoding="utf-8") as f:
                # Check for config
                config = yaml.safe_load(f)
                if key in config:
                    return config[key]

        return default


class Embed(Index):
    """
    Embeddings index using txtai.
    """

    def index(self):
        if os.path.exists(self.output) and not self.refresh:
            embeddings = Embeddings()
            embeddings.load(self.output)
        else:
            # Read configuration
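            # Defaults below: quantized Faiss index trained on a 5% data sample, 8192 rows
            # per index batch and 128 texts per encode batch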
            config = self.readconfig("embeddings", {"batch": 8192, "encodebatch": 128, "faiss": {"quantize": True, "sample": 0.05}})

            # Build index
            embeddings = Embeddings(config)
            embeddings.index(self.rows())
            embeddings.save(self.output)

        return embeddings


class Hybrid(Index):
    """
    Hybrid embeddings + BM25 index using txtai.
    """

    def index(self):
        if os.path.exists(self.output) and not self.refresh:
            embeddings = Embeddings()
            embeddings.load(self.output)
        else:
            # Read configuration
            config = self.readconfig(
                "hybrid",
                {
                    "batch": 8192,
                    "encodebatch": 128,
                    "faiss": {"quantize": True, "sample": 0.05},
                    "scoring": {"method": "bm25", "terms": True, "normalize": True},
                },
            )

            # Build index
            embeddings = Embeddings(config)
            embeddings.index(self.rows())
            embeddings.save(self.output)

        return embeddings


class RetrievalAugmentedGeneration(Embed):
    """
    Retrieval augmented generation (RAG) using txtai.
    """

    def __init__(self, path, config, output, refresh):
        # Parent logic
        super().__init__(path, config, output, refresh)

        # Read LLM configuration
        llm = self.readconfig("llm", {})

        # Read RAG configuration
        rag = self.readconfig("rag", {})

        # Load RAG pipeline
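        # output="reference" includes the id of the best supporting context element with each
        # result; search() below uses that id as the retrieval hit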
        self.rag = RAG(self.backend, LLM(**llm), output="reference", **rag)

    def search(self, queries, limit):
        # Set context window size to limit and run
        self.rag.context = limit
        return [[(x["reference"], 1)] for x in self.rag(queries, maxlength=4096)]


class Score(Index):
    """
    BM25 index using txtai.
    """

    def index(self):
        # Read configuration
        config = self.readconfig("scoring", {"method": "bm25", "terms": True})

        # Create scoring instance
        scoring = ScoringFactory.create(config)

        output = os.path.join(self.output, "scoring")
        if os.path.exists(output) and not self.refresh:
            scoring.load(output)
        else:
            scoring.index(self.rows())
            scoring.save(output)

        return scoring


class Similar(Index):
    """
    Search data using a similarity pipeline.
    """

    def index(self):
        # Load similarity pipeline
        model = Similarity(**self.readconfig("similar", {}))

        # Get datasets
        data = list(self.rows())
        ids = [x[0] for x in data]
        texts = [x[1] for x in data]

        # Encode texts
        data = model.encode(texts, "data")

        return (ids, data, model)

    def search(self, queries, limit):
        # Unpack backend
        ids, data, model = self.backend

        # Run model inference
        results = []
        for result in model(queries, data, limit=limit):
            results.append([(ids[x], score) for x, score in result])

        return results


class Rerank(Embed):
    """
    Embeddings index using txtai, combined with a similarity pipeline for reranking.
    """

    def index(self):
        # Build embeddings index
        embeddings = super().index()

        # Combine similar pipeline with embeddings
        # Reuse the backend already built by the Similar constructor rather than re-encoding the corpus
        model = Similar(self.path, self.config, self.output, self.refresh)
        return model.backend + (embeddings,)

    def search(self, queries, limit):
        # Unpack backend
        ids, data, model, embeddings = self.backend

        # Run initial query
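        # Pull 10x the requested limit as candidates from the embeddings index, then rerank
        # the candidates with the similarity model below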
        indices = []
        for r in embeddings.batchsearch(queries, limit * 10):
            indices.append({x: ids.index(uid) for x, (uid, _) in enumerate(r)})

        # Run model inference
        results = []
        for x, query in enumerate(queries):
            queue = data[list(indices[x].values())]
            if len(queue) > 0:
                result = model(query, queue, limit=limit)
                results.append([(ids[indices[x][i]], score) for i, score in result])

        return results


class RankBM25(Index):
    """
    BM25 index using rank-bm25.
    """

    def search(self, queries, limit):
        ids, backend = self.backend
        tokenizer, results = Tokenizer(), []
        for query in queries:
            scores = backend.get_scores(tokenizer(query))
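            # Sort scores descending and keep the top limit document positions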
            topn = np.argsort(scores)[::-1][:limit]
            results.append([(ids[x], scores[x]) for x in topn])

        return results

    def index(self):
        output = os.path.join(self.output, "rank")
        if os.path.exists(output) and not self.refresh:
            with open(output, "rb") as f:
                ids, model = pickle.load(f)
        else:
            # Tokenize data
            tokenizer, data = Tokenizer(), []
            for uid, text, _ in self.rows():
                data.append((uid, tokenizer(text)))

            ids = [uid for uid, _ in data]
            model = BM25Okapi([text for _, text in data])

            # Save ids and model so the saved index can be reloaded above
            with open(output, "wb") as out:
                pickle.dump((ids, model), out)

        return ids, model


class BM25S(Index):
    """
    BM25 index as implemented by the bm25s library.
    """

    def __init__(self, path, config, output, refresh):
        # Corpus ids
        self.ids = None

        # Parent logic
        super().__init__(path, config, output, refresh)

    def search(self, queries, limit):
        tokenizer = Tokenizer()
        results, scores = self.backend.retrieve([tokenizer(x) for x in queries], corpus=self.ids, k=limit)

        # List of queries => list of matches (id, score)
        x = []
        for a, b in zip(results, scores):
            x.append([(str(c), float(d)) for c, d in zip(a, b)])

        return x

    def index(self):
        tokenizer = Tokenizer()
        ids, texts = [], []

        for uid, text, _ in self.rows():
            ids.append(uid)
            texts.append(text)

        self.ids = ids

        if os.path.exists(self.output) and not self.refresh:
            model = BM25Sparse.load(self.output)
        else:
            model = BM25Sparse(method="lucene", k1=1.2, b=0.75)
            model.index([tokenizer(x) for x in texts], leave_progress=False)
            model.save(self.output)

        return model


class SQLiteFTS(Index):
    """
    BM25 index using SQLite's FTS extension.
    """

    def search(self, queries, limit):
        tokenizer, results = Tokenizer(), []
        for query in queries:
            query = tokenizer(query)
            query = " OR ".join([f'"{q}"' for q in query])

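            # SQLite FTS5 bm25() returns more negative values for better matches; ORDER BY keeps
            # best matches first and the * -1 converts scores to a higher-is-better convention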
            self.backend.execute(
                f"SELECT id, bm25(textindex) * -1 score FROM textindex WHERE text MATCH ? ORDER BY bm25(textindex) LIMIT {limit}", [query]
            )

            results.append(list(self.backend))

        return results

    def index(self):
        if os.path.exists(self.output) and not self.refresh:
            # Load existing database
            connection = sqlite3.connect(self.output)
        else:
            # Delete existing database
            if os.path.exists(self.output):
                os.remove(self.output)

            # Create new database
            connection = sqlite3.connect(self.output)

            # Tokenize data
            tokenizer, data = Tokenizer(), []
            for uid, text, _ in self.rows():
                data.append((uid, " ".join(tokenizer(text))))

            # Create table
            connection.execute("CREATE VIRTUAL TABLE textindex using fts5(id, text)")

            # Load data and build index
            connection.executemany("INSERT INTO textindex VALUES (?, ?)", data)

            connection.commit()

        return connection.cursor()


class Elastic(Index):
    """
    BM25 index using Elasticsearch.
    """

    def search(self, queries, limit):
        # Generate bulk queries
        request = []
        for query in queries:
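            # dfs_query_then_fetch computes global term statistics across shards for more
            # consistent BM25 scoring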
            req_head = {"index": "textindex", "search_type": "dfs_query_then_fetch"}
            req_body = {
                "_source": False,
                "query": {"multi_match": {"query": query, "type": "best_fields", "fields": ["text"], "tie_breaker": 0.5}},
                "size": limit,
            }
            request.extend([req_head, req_body])

        # Run ES query
        response = self.backend.msearch(body=request, request_timeout=600)

        # Read responses
        results = []
        for resp in response["responses"]:
            result = resp["hits"]["hits"]
            results.append([(r["_id"], r["_score"]) for r in result])

        return results

    def index(self):
        es = Elasticsearch("http://localhost:9200")

        # Delete existing index
        # pylint: disable=W0702
        try:
            es.indices.delete(index="textindex")
        except:
            pass

        bulk(es, ({"_index": "textindex", "_id": uid, "text": text} for uid, text, _ in self.rows()))
        es.indices.refresh(index="textindex")

        return es


def relevance(path):
    """
    Loads relevance data for evaluation.

    Args:
        path: path to dataset test file

    Returns:
        relevance data
    """

    rel = {}
    with open(f"{path}/qrels/test.tsv", encoding="utf-8") as f:
        reader = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_MINIMAL)
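        # Skip header row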
        next(reader)

        for row in reader:
            queryid, corpusid, score = row[0], row[1], int(row[2])
            if queryid not in rel:
                rel[queryid] = {corpusid: score}
            else:
                rel[queryid][corpusid] = score

    return rel


def create(method, path, config, output, refresh):
    """
    Creates a new index.

    Args:
        method: indexing method
        path: path to dataset
        config: path to config file
        output: path to store index
        refresh: overwrites existing index if True, otherwise existing index is loaded

    Returns:
        Index
    """

    if method == "hybrid":
        return Hybrid(path, config, output, refresh)
    if method == "rag":
        return RetrievalAugmentedGeneration(path, config, output, refresh)
    if method == "scoring":
        return Score(path, config, output, refresh)
    if method == "rank":
        return RankBM25(path, config, output, refresh)
    if method == "bm25s":
        return BM25S(path, config, output, refresh)
    if method == "sqlite":
        return SQLiteFTS(path, config, output, refresh)
    if method == "es":
        return Elastic(path, config, output, refresh)
    if method == "similar":
        return Similar(path, config, output, refresh)
    if method == "rerank":
        return Rerank(path, config, output, refresh)

    # Default
    return Embed(path, config, output, refresh)


def compute(results):
    """
    Computes metrics using the results from an evaluation run.

    Args:
        results: evaluation results

    Returns:
        metrics
    """

    metrics = {}
    for r in results:
        for metric in results[r]:
            if metric not in metrics:
                metrics[metric] = []

            metrics[metric].append(results[r][metric])

    return {metric: round(np.mean(values), 5) for metric, values in metrics.items()}


def evaluate(methods, path, args):
    """
    Runs an evaluation.

    Args:
        methods: list of indexing methods to test
        path: path to dataset
        args: command line arguments

    Returns:
        {calculated performance metrics}
    """

    print(f"------ {os.path.basename(path)} ------")

    # Performance stats
    performance = {}

    # Calculate stats for each model type
    topk = args.topk
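    # pytrec_eval metric names use dot-separated cutoffs here (ndcg_cut.10) and are reported
    # back with underscores (ndcg_cut_10)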
    evaluator = RelevanceEvaluator(relevance(path), {f"ndcg_cut.{topk}", f"map_cut.{topk}", f"recall.{topk}", f"P.{topk}"})
    for method in methods:
        # Stats for this source
        stats = {}
        performance[method] = stats

        # Create index and get results
        start = time.time()
        output = args.output if args.output else f"{path}/{method}"
        index = create(method, path, args.config, output, args.refresh)

        # Add indexing metrics
        stats["index"] = round(time.time() - start, 2)
        stats["memory"] = int(psutil.Process().memory_info().rss / (1024 * 1024))
        stats["disk"] = int(sum(d.stat().st_size for d in os.scandir(output) if d.is_file()) / 1024) if os.path.isdir(output) else 0

        print("INDEX TIME =", time.time() - start)
        print(f"MEMORY USAGE = {stats['memory']} MB")
        print(f"DISK USAGE = {stats['disk']} KB")

        start = time.time()
        results = index(topk)

        # Add search metrics
        stats["search"] = round(time.time() - start, 2)
        print("SEARCH TIME =", time.time() - start)

        # Calculate stats
        metrics = compute(evaluator.evaluate(results))

        # Add accuracy metrics
        for stat in [f"ndcg_cut_{topk}", f"map_cut_{topk}", f"recall_{topk}", f"P_{topk}"]:
            stats[stat] = metrics[stat]

        # Print model stats
        print(f"------ {method} ------")
        print(f"NDCG@{topk} =", metrics[f"ndcg_cut_{topk}"])
        print(f"MAP@{topk} =", metrics[f"map_cut_{topk}"])
        print(f"Recall@{topk} =", metrics[f"recall_{topk}"])
        print(f"P@{topk} =", metrics[f"P_{topk}"])

    print()
    return performance


def benchmarks(args):
    """
    Main benchmark execution method.

    Args:
        args: command line arguments
    """

    # Directory where BEIR datasets are stored
    directory = args.directory if args.directory else "beir"

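    # Append to any existing benchmarks.json when specific sources and methods are passed,
    # otherwise run the full suite and overwrite it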
    if args.sources and args.methods:
        sources, methods = args.sources.split(","), args.methods.split(",")
        mode = "a"
    else:
        # Default sources and methods
        sources = [
            "trec-covid",
            "nfcorpus",
            "nq",
            "hotpotqa",
            "fiqa",
            "arguana",
            "webis-touche2020",
            "quora",
            "dbpedia-entity",
            "scidocs",
            "fever",
            "climate-fever",
            "scifact",
        ]
        methods = ["embed", "hybrid", "rag", "scoring", "rank", "bm25s", "sqlite", "es", "similar", "rerank"]
        mode = "w"

    # Run and save benchmarks
    with open("benchmarks.json", mode, encoding="utf-8") as f:
        for source in sources:
            # Run evaluations
            results = evaluate(methods, f"{directory}/{source}", args)

            # Save as JSON lines output
            for method, stats in results.items():
                stats["source"] = source
                stats["method"] = method
                stats["name"] = args.name if args.name else method

                json.dump(stats, f)
                f.write("\n")


if __name__ == "__main__":
    # Command line parser
    parser = argparse.ArgumentParser(description="Benchmarks")
    parser.add_argument("-c", "--config", help="path to config file", metavar="CONFIG")
    parser.add_argument("-d", "--directory", help="root directory path with datasets", metavar="DIRECTORY")
    parser.add_argument("-m", "--methods", help="comma separated list of methods", metavar="METHODS")
    parser.add_argument("-n", "--name", help="name to assign to this run, defaults to method name", metavar="NAME")
    parser.add_argument("-o", "--output", help="index output directory path", metavar="OUTPUT")
    parser.add_argument(
        "-r",
        "--refresh",
        help="refreshes index if set, otherwise uses existing index if available",
        action="store_true",
    )
    parser.add_argument("-s", "--sources", help="comma separated list of data sources", metavar="SOURCES")
    parser.add_argument("-t", "--topk", help="top k results to use for the evaluation", metavar="TOPK", type=int, default=10)

    # Calculate benchmarks
    benchmarks(parser.parse_args())