rag_bm.py
1 # -*- coding: utf-8 -*- 2 """RAG benchmark pipeline""" 3 4 import asyncio 5 6 from llama_index.core.node_parser import SentenceSplitter 7 from llama_index.core.schema import NodeWithScore 8 9 from metagpt.const import DATA_PATH, EXAMPLE_BENCHMARK_PATH, EXAMPLE_DATA_PATH 10 from metagpt.logs import logger 11 from metagpt.rag.benchmark import RAGBenchmark 12 from metagpt.rag.engines import SimpleEngine 13 from metagpt.rag.factories import get_rag_embedding, get_rag_llm 14 from metagpt.rag.schema import ( 15 BM25RetrieverConfig, 16 CohereRerankConfig, 17 ColbertRerankConfig, 18 FAISSIndexConfig, 19 FAISSRetrieverConfig, 20 ) 21 from metagpt.utils.common import write_json_file 22 23 DOC_PATH = EXAMPLE_DATA_PATH / "rag_bm/summary_writer.txt" 24 QUESTION = "2023年7月20日,应急管理部、财政部联合下发《因灾倒塌、损坏住房恢复重建救助工作规范》的通知,规范倒损住房恢复重建救助相关工作。" 25 26 TRAVEL_DOC_PATH = EXAMPLE_DATA_PATH / "rag_bm/documents.txt" 27 TRAVEL_QUESTION = "国家卫生健康委在2023年7月28日开展的“启明行动”是为了防控哪个群体的哪种健康问题,并请列出活动发布的指导性文件名称。" 28 29 DATASET_PATH = EXAMPLE_DATA_PATH / "rag_bm/test.json" 30 SAVE_PATH = EXAMPLE_DATA_PATH / "rag_bm/result.json" 31 GROUND_TRUTH_TRANVEL = "2023-07-28 10:14:27作者:白剑峰来源:人民日报 ,正文:为在全社会形成重视儿童眼健康的良好氛围,持续推进综合防控儿童青少年近视工作落实,国家卫生健康委决定在全国持续开展“启明行动”——防控儿童青少年近视健康促进活动,并发布了《防控儿童青少年近视核心知识十条》。本次活动的主题为:重视儿童眼保健,守护孩子明眸“视”界。强调预防为主,推动关口前移,倡导和推动家庭及全社会共同行动起来,营造爱眼护眼的视觉友好环境,共同呵护好孩子的眼睛,让他们拥有一个光明的未来。国家卫生健康委要求,开展社会宣传和健康教育。充分利用网络、广播电视、报纸杂志、海报墙报、培训讲座等多种形式,广泛开展宣传倡导,向社会公众传播开展儿童眼保健、保护儿童视力健康的重要意义,以《防控儿童青少年近视核心知识十条》为重点普及预防近视科学知识。创新健康教育方式和载体,开发制作群众喜闻乐见的健康教育科普作品,利用互联网媒体扩大传播效果,提高健康教育的针对性、精准性和实效性。指导相关医疗机构将儿童眼保健和近视防控等科学知识纳入孕妇学校、家长课堂内容。开展儿童眼保健及视力检查咨询指导。医疗机构要以儿童家长和养育人为重点,结合眼保健和眼科临床服务,开展个性化咨询指导。要针对儿童常见眼病和近视防控等重点问题,通过面对面咨询指导,引导儿童家长树立近视防控意识,改变不良生活方式,加强户外活动,养成爱眼护眼健康行为习惯。提高儿童眼保健专科服务能力。各地要积极推进儿童眼保健专科建设,扎实组织好妇幼健康职业技能竞赛“儿童眼保健”项目,推动各层级开展比武练兵,提升业务能力。" 32 GROUND_TRUTH_ANSWER = "“启明行动”是为了防控儿童青少年的近视问题,并发布了《防控儿童青少年近视核心知识十条》。" 33 34 LLM_TIP = "If you not sure, just answer I don't know." 
# Sentinel strings written into "generated_text" so failures can be
# categorized later by _print_bm_result.
LLM_ERROR = "Retrieve failed due to LLM wasn't follow instruction"
EMPTY_ERROR = "Empty Response"


class RAGExample:
    """Show how to use RAG for evaluation."""

    def __init__(self):
        # Benchmark helper loads datasets and computes metrics; embedding and
        # llm are produced by the project's RAG factories.
        self.benchmark = RAGBenchmark()
        self.embedding = get_rag_embedding()
        self.llm = get_rag_llm()

    async def rag_evaluate_pipeline(self, dataset_name: list[str] = None):
        """Evaluate RAG over the named benchmark datasets.

        Builds (or reloads) an index per dataset, answers each ground-truth
        question, logs aggregate metrics, and writes the per-question results
        to EXAMPLE_BENCHMARK_PATH/<dataset>/bm_result.json.

        Args:
            dataset_name: names of datasets to run; None (the default) behaves
                like ["all"] and evaluates every configured dataset.
        """
        # Bug fix: the original used a mutable default argument (["all"]);
        # use a None sentinel with the same effective default.
        dataset_name = ["all"] if dataset_name is None else dataset_name
        dataset_config = self.benchmark.load_dataset(dataset_name)

        for dataset in dataset_config.datasets:
            if "all" in dataset_name or dataset.name in dataset_name:
                output_dir = DATA_PATH / f"{dataset.name}"

                if output_dir.exists():
                    # Reuse a previously persisted FAISS index.
                    logger.info("Loading Existed index!")
                    logger.info(f"Index Path:{output_dir}")
                    # NOTE(review): this path uses Colbert rerank with
                    # FAISS+BM25 retrievers, while the from-docs path below
                    # uses Cohere rerank with FAISS only — confirm the
                    # asymmetry is intentional.
                    self.engine = SimpleEngine.from_index(
                        index_config=FAISSIndexConfig(persist_path=output_dir),
                        ranker_configs=[ColbertRerankConfig()],
                        retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()],
                    )
                else:
                    # Build a fresh index from the dataset's documents.
                    logger.info("Loading index from documents!")
                    self.engine = SimpleEngine.from_docs(
                        input_files=dataset.document_files,
                        retriever_configs=[FAISSRetrieverConfig()],
                        ranker_configs=[CohereRerankConfig()],
                        transformations=[SentenceSplitter(chunk_size=1024, chunk_overlap=0)],
                    )

                results = []
                for gt_info in dataset.gt_info:
                    result = await self.rag_evaluate_single(
                        question=gt_info["question"],
                        reference=gt_info["gt_reference"],
                        ground_truth=gt_info["gt_answer"],
                    )
                    results.append(result)

                logger.info(f"=====The {dataset.name} Benchmark dataset assessment is complete!=====")
                self._print_bm_result(results)

                write_json_file((EXAMPLE_BENCHMARK_PATH / dataset.name / "bm_result.json").as_posix(), results, "utf-8")

    async def rag_evaluate_single(self, question, reference, ground_truth, print_title=True):
        """Run retrieval + query for a single question and score the answer.

        Uses the faiss&bm25 retriever and llm ranker configured on self.engine;
        on any retrieval/query failure the question is recorded with the
        LLM_ERROR sentinel instead of raising. Will print something like:

        Retrieve Result:
        0. Productivi..., 10.0
        1. I wrote cu..., 7.0
        2. I highly r..., 5.0

        Query Result:
        Passion, adaptability, open-mindedness, creativity, discipline, and empathy are key qualities to be a good writer.

        RAG BenchMark result:
        {
            'metrics':
            {
                'bleu-avg': 0.48318624982047,
                'bleu-1': 0.5609756097560976,
                'bleu-2': 0.5,
                'bleu-3': 0.46153846153846156,
                'bleu-4': 0.42105263157894735,
                'rouge-L': 0.6865671641791045,
                'semantic similarity': 0.9487444961487591,
                'length': 74
            },
            'log': {
                'generated_text':
                '国家卫生健康委在2023年7月28日开展的“启明行动”是为了防控儿童青少年的近视问题。活动发布的指导性文件名称为《防控儿童青少年近视核心知识十条》。',
                'ground_truth_text':
                '“启明行动”是为了防控儿童青少年的近视问题,并发布了《防控儿童青少年近视核心知识十条》。'
            }
        }
        """
        if print_title:
            self._print_title("RAG Pipeline")
        try:
            nodes = await self.engine.aretrieve(question)
            self._print_result(nodes, state="Retrieve")

            answer = await self.engine.aquery(question)
            self._print_result(answer, state="Query")
        except Exception as e:
            # Record the failure under the LLM_ERROR sentinel so the
            # percentages in _print_bm_result stay meaningful.
            logger.error(e)
            return self.benchmark.set_metrics(
                generated_text=LLM_ERROR, ground_truth_text=ground_truth, question=question
            )

        result = await self.evaluate_result(answer.response, ground_truth, nodes, reference, question)

        logger.info("==========RAG BenchMark result demo as follows==========")
        logger.info(result)

        return result

    async def rag_faissdb(self):
        """This example show how to use FAISS. how to save and load index. will print something like:

        Query Result:
        Bob likes traveling.
        """
        self._print_title("RAG FAISS")

        # Build and persist an index over the travel document.
        output_dir = DATA_PATH / "rag_faiss"

        SimpleEngine.from_docs(
            input_files=[TRAVEL_DOC_PATH],
            retriever_configs=[FAISSRetrieverConfig(dimensions=512, persist_path=output_dir)],
        )

        # Reload the persisted index.
        engine = SimpleEngine.from_index(
            index_config=FAISSIndexConfig(persist_path=output_dir),
        )

        # Query the reloaded index. Bug fix: the original retrieved with the
        # unrelated QUESTION constant; both calls should target the travel
        # document loaded above.
        nodes = engine.retrieve(TRAVEL_QUESTION)
        self._print_result(nodes, state="Retrieve")

        answer = engine.query(TRAVEL_QUESTION)
        self._print_result(answer, state="Query")

    async def evaluate_result(
        self,
        response: str = None,
        reference: str = None,
        nodes: list[NodeWithScore] = None,
        reference_doc: list[str] = None,
        question: str = None,
    ):
        """Delegate metric computation (bleu/rouge/similarity) to the benchmark."""
        result = await self.benchmark.compute_metric(response, reference, nodes, reference_doc, question)

        return result

    @staticmethod
    def _print_title(title):
        # Visual separator in the log output.
        logger.info(f"{'#'*30} {title} {'#'*30}")

    @staticmethod
    def _print_result(result, state="Retrieve"):
        """print retrieve or query result"""
        logger.info(f"{state} Result:")

        if state == "Retrieve":
            # A retrieve result is a list of scored nodes; show a text preview.
            for i, node in enumerate(result):
                logger.info(f"{i}. {node.text[:10]}..., {node.score}")
            logger.info("======Retrieve Finished======")
            return

        logger.info(f"{result}\n")

    @staticmethod
    def _print_bm_result(result):
        """Log mean metrics and failure percentages for one dataset run."""
        import pandas as pd

        if not result:
            # Bug fix: guard against ZeroDivisionError (and an empty-frame
            # mean) when no questions were evaluated.
            logger.info("No benchmark results to summarize.")
            return

        # Exclude failed entries from the metric averages.
        metrics = [
            item["metrics"]
            for item in result
            if item["log"]["generated_text"] not in (LLM_ERROR, EMPTY_ERROR)
        ]

        data = pd.DataFrame(metrics)
        logger.info(f"\n {data.mean()}")

        llm_errors = [item for item in result if item["log"]["generated_text"] == LLM_ERROR]
        retrieve_errors = [item for item in result if item["log"]["generated_text"] == EMPTY_ERROR]
        logger.info(
            f"Percentage of retrieval failures due to incorrect LLM instruction following:"
            f" {100.0 * len(llm_errors) / len(result)}%"
        )
        logger.info(
            f"Percentage of retrieval failures due to retriever not recalling any documents is:"
            f" {100.0 * len(retrieve_errors) / len(result)}%"
        )

    async def _retrieve_and_print(self, question):
        # Convenience helper: retrieve for a question and log the nodes.
        nodes = await self.engine.aretrieve(question)
        self._print_result(nodes, state="Retrieve")
        return nodes


async def main():
    """RAG pipeline"""
    e = RAGExample()
    await e.rag_evaluate_pipeline()


if __name__ == "__main__":
    asyncio.run(main())