# examples/rag/rag_bm.py
  1  # -*- coding: utf-8 -*-
  2  """RAG benchmark pipeline"""
  3  
  4  import asyncio
  5  
  6  from llama_index.core.node_parser import SentenceSplitter
  7  from llama_index.core.schema import NodeWithScore
  8  
  9  from metagpt.const import DATA_PATH, EXAMPLE_BENCHMARK_PATH, EXAMPLE_DATA_PATH
 10  from metagpt.logs import logger
 11  from metagpt.rag.benchmark import RAGBenchmark
 12  from metagpt.rag.engines import SimpleEngine
 13  from metagpt.rag.factories import get_rag_embedding, get_rag_llm
 14  from metagpt.rag.schema import (
 15      BM25RetrieverConfig,
 16      CohereRerankConfig,
 17      ColbertRerankConfig,
 18      FAISSIndexConfig,
 19      FAISSRetrieverConfig,
 20  )
 21  from metagpt.utils.common import write_json_file
 22  
# Sample document + query for the summary-writer demo.
DOC_PATH = EXAMPLE_DATA_PATH / "rag_bm/summary_writer.txt"
QUESTION = "2023年7月20日,应急管理部、财政部联合下发《因灾倒塌、损坏住房恢复重建救助工作规范》的通知,规范倒损住房恢复重建救助相关工作。"

# Sample document + query for the FAISS save/load demo (`rag_faissdb`).
TRAVEL_DOC_PATH = EXAMPLE_DATA_PATH / "rag_bm/documents.txt"
TRAVEL_QUESTION = "国家卫生健康委在2023年7月28日开展的“启明行动”是为了防控哪个群体的哪种健康问题,并请列出活动发布的指导性文件名称。"

# Benchmark dataset and output locations.
DATASET_PATH = EXAMPLE_DATA_PATH / "rag_bm/test.json"
SAVE_PATH = EXAMPLE_DATA_PATH / "rag_bm/result.json"
# Reference passage and expected answer for TRAVEL_QUESTION.
# NOTE(review): "TRANVEL" is presumably a typo for "TRAVEL"; the name is kept
# as-is because other modules may import this constant.
GROUND_TRUTH_TRANVEL = "2023-07-28 10:14:27作者:白剑峰来源:人民日报    ,正文:为在全社会形成重视儿童眼健康的良好氛围,持续推进综合防控儿童青少年近视工作落实,国家卫生健康委决定在全国持续开展“启明行动”——防控儿童青少年近视健康促进活动,并发布了《防控儿童青少年近视核心知识十条》。本次活动的主题为:重视儿童眼保健,守护孩子明眸“视”界。强调预防为主,推动关口前移,倡导和推动家庭及全社会共同行动起来,营造爱眼护眼的视觉友好环境,共同呵护好孩子的眼睛,让他们拥有一个光明的未来。国家卫生健康委要求,开展社会宣传和健康教育。充分利用网络、广播电视、报纸杂志、海报墙报、培训讲座等多种形式,广泛开展宣传倡导,向社会公众传播开展儿童眼保健、保护儿童视力健康的重要意义,以《防控儿童青少年近视核心知识十条》为重点普及预防近视科学知识。创新健康教育方式和载体,开发制作群众喜闻乐见的健康教育科普作品,利用互联网媒体扩大传播效果,提高健康教育的针对性、精准性和实效性。指导相关医疗机构将儿童眼保健和近视防控等科学知识纳入孕妇学校、家长课堂内容。开展儿童眼保健及视力检查咨询指导。医疗机构要以儿童家长和养育人为重点,结合眼保健和眼科临床服务,开展个性化咨询指导。要针对儿童常见眼病和近视防控等重点问题,通过面对面咨询指导,引导儿童家长树立近视防控意识,改变不良生活方式,加强户外活动,养成爱眼护眼健康行为习惯。提高儿童眼保健专科服务能力。各地要积极推进儿童眼保健专科建设,扎实组织好妇幼健康职业技能竞赛“儿童眼保健”项目,推动各层级开展比武练兵,提升业务能力。"
GROUND_TRUTH_ANSWER = "“启明行动”是为了防控儿童青少年的近视问题,并发布了《防控儿童青少年近视核心知识十条》。"

# Prompt suffix encouraging the LLM to admit uncertainty.
LLM_TIP = "If you not sure, just answer I don't know."
# Sentinel "generated_text" values; RAGExample._print_bm_result compares
# results against these exact strings, so they must not be altered
# (the LLM_ERROR grammar is wrong but behaviorally significant as a value).
LLM_ERROR = "Retrieve failed due to LLM wasn't follow instruction"
EMPTY_ERROR = "Empty Response"
 37  
 38  
 39  class RAGExample:
 40      """Show how to use RAG for evaluation."""
 41  
 42      def __init__(self):
 43          self.benchmark = RAGBenchmark()
 44          self.embedding = get_rag_embedding()
 45          self.llm = get_rag_llm()
 46  
 47      async def rag_evaluate_pipeline(self, dataset_name: list[str] = ["all"]):
 48          dataset_config = self.benchmark.load_dataset(dataset_name)
 49  
 50          for dataset in dataset_config.datasets:
 51              if "all" in dataset_name or dataset.name in dataset_name:
 52                  output_dir = DATA_PATH / f"{dataset.name}"
 53  
 54                  if output_dir.exists():
 55                      logger.info("Loading Existed index!")
 56                      logger.info(f"Index Path:{output_dir}")
 57                      self.engine = SimpleEngine.from_index(
 58                          index_config=FAISSIndexConfig(persist_path=output_dir),
 59                          ranker_configs=[ColbertRerankConfig()],
 60                          retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()],
 61                      )
 62                  else:
 63                      logger.info("Loading index from documents!")
 64                      self.engine = SimpleEngine.from_docs(
 65                          input_files=dataset.document_files,
 66                          retriever_configs=[FAISSRetrieverConfig()],
 67                          ranker_configs=[CohereRerankConfig()],
 68                          transformations=[SentenceSplitter(chunk_size=1024, chunk_overlap=0)],
 69                      )
 70                  results = []
 71                  for gt_info in dataset.gt_info:
 72                      result = await self.rag_evaluate_single(
 73                          question=gt_info["question"],
 74                          reference=gt_info["gt_reference"],
 75                          ground_truth=gt_info["gt_answer"],
 76                      )
 77                      results.append(result)
 78                  logger.info(f"=====The {dataset.name} Benchmark dataset assessment is complete!=====")
 79                  self._print_bm_result(results)
 80  
 81                  write_json_file((EXAMPLE_BENCHMARK_PATH / dataset.name / "bm_result.json").as_posix(), results, "utf-8")
 82  
 83      async def rag_evaluate_single(self, question, reference, ground_truth, print_title=True):
 84          """This example run rag pipeline, use faiss&bm25 retriever and llm ranker, will print something like:
 85  
 86          Retrieve Result:
 87          0. Productivi..., 10.0
 88          1. I wrote cu..., 7.0
 89          2. I highly r..., 5.0
 90  
 91          Query Result:
 92          Passion, adaptability, open-mindedness, creativity, discipline, and empathy are key qualities to be a good writer.
 93  
 94          RAG BenchMark result:
 95          {
 96          'metrics':
 97           {
 98           'bleu-avg': 0.48318624982047,
 99           'bleu-1': 0.5609756097560976,
100           'bleu-2': 0.5,
101           'bleu-3': 0.46153846153846156,
102           'bleu-4': 0.42105263157894735,
103           'rouge-L': 0.6865671641791045,
104           'semantic similarity': 0.9487444961487591,
105           'length': 74
106           },
107           'log': {
108           'generated_text':
109           '国家卫生健康委在2023年7月28日开展的“启明行动”是为了防控儿童青少年的近视问题。活动发布的指导性文件名称为《防控儿童青少年近视核心知识十条》。',
110           'ground_truth_text':
111           '“启明行动”是为了防控儿童青少年的近视问题,并发布了《防控儿童青少年近视核心知识十条》。'
112              }
113           }
114          """
115          if print_title:
116              self._print_title("RAG Pipeline")
117          try:
118              nodes = await self.engine.aretrieve(question)
119              self._print_result(nodes, state="Retrieve")
120  
121              answer = await self.engine.aquery(question)
122              self._print_result(answer, state="Query")
123  
124          except Exception as e:
125              logger.error(e)
126              return self.benchmark.set_metrics(
127                  generated_text=LLM_ERROR, ground_truth_text=ground_truth, question=question
128              )
129  
130          result = await self.evaluate_result(answer.response, ground_truth, nodes, reference, question)
131  
132          logger.info("==========RAG BenchMark result demo as follows==========")
133          logger.info(result)
134  
135          return result
136  
137      async def rag_faissdb(self):
138          """This example show how to use FAISS. how to save and load index. will print something like:
139  
140          Query Result:
141          Bob likes traveling.
142          """
143          self._print_title("RAG FAISS")
144  
145          # save index
146          output_dir = DATA_PATH / "rag_faiss"
147  
148          SimpleEngine.from_docs(
149              input_files=[TRAVEL_DOC_PATH],
150              retriever_configs=[FAISSRetrieverConfig(dimensions=512, persist_path=output_dir)],
151          )
152  
153          # load index
154          engine = SimpleEngine.from_index(
155              index_config=FAISSIndexConfig(persist_path=output_dir),
156          )
157  
158          # query
159          nodes = engine.retrieve(QUESTION)
160          self._print_result(nodes, state="Retrieve")
161  
162          answer = engine.query(TRAVEL_QUESTION)
163          self._print_result(answer, state="Query")
164  
165      async def evaluate_result(
166          self,
167          response: str = None,
168          reference: str = None,
169          nodes: list[NodeWithScore] = None,
170          reference_doc: list[str] = None,
171          question: str = None,
172      ):
173          result = await self.benchmark.compute_metric(response, reference, nodes, reference_doc, question)
174  
175          return result
176  
177      @staticmethod
178      def _print_title(title):
179          logger.info(f"{'#'*30} {title} {'#'*30}")
180  
181      @staticmethod
182      def _print_result(result, state="Retrieve"):
183          """print retrieve or query result"""
184          logger.info(f"{state} Result:")
185  
186          if state == "Retrieve":
187              for i, node in enumerate(result):
188                  logger.info(f"{i}. {node.text[:10]}..., {node.score}")
189              logger.info("======Retrieve Finished======")
190              return
191  
192          logger.info(f"{result}\n")
193  
194      @staticmethod
195      def _print_bm_result(result):
196          import pandas as pd
197  
198          metrics = [
199              item["metrics"]
200              for item in result
201              if item["log"]["generated_text"] != LLM_ERROR and item["log"]["generated_text"] != EMPTY_ERROR
202          ]
203  
204          data = pd.DataFrame(metrics)
205          logger.info(f"\n {data.mean()}")
206  
207          llm_errors = [item for item in result if item["log"]["generated_text"] == LLM_ERROR]
208          retrieve_errors = [item for item in result if item["log"]["generated_text"] == EMPTY_ERROR]
209          logger.info(
210              f"Percentage of retrieval failures due to incorrect LLM instruction following:"
211              f" {100.0 * len(llm_errors) / len(result)}%"
212          )
213          logger.info(
214              f"Percentage of retrieval failures due to retriever not recalling any documents is:"
215              f" {100.0 * len(retrieve_errors) / len(result)}%"
216          )
217  
218      async def _retrieve_and_print(self, question):
219          nodes = await self.engine.aretrieve(question)
220          self._print_result(nodes, state="Retrieve")
221          return nodes
222  
223  
async def main():
    """Entry point: run the full RAG benchmark evaluation pipeline."""
    example = RAGExample()
    await example.rag_evaluate_pipeline()


if __name__ == "__main__":
    asyncio.run(main())