DABench.py
1 import asyncio 2 import json 3 import re 4 from pathlib import Path 5 from typing import Any, Dict, List, Tuple, Union 6 7 import nest_asyncio 8 9 from examples.di.requirements_prompt import DABENCH 10 from metagpt.const import DABENCH_PATH 11 from metagpt.logs import logger 12 from metagpt.utils.exceptions import handle_exception 13 14 15 def evaluate_accuracy_by_question(results: dict) -> float: 16 """ 17 Calculate the accuracy of results based on complete correctness of each question. 18 This function is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/eval_closed_form.py 19 This function checks whether each result is entirely correct, meaning all sub-questions 20 within that result are answered correctly. It computes the proportion of correct results 21 by dividing the number of fully correct results by the total number of results. 22 23 Args: 24 results (dict): A collection of results where each result may contain a 'correctness' field. 25 26 Returns: 27 float: The proportion of correct results, rounded to four decimal places. 28 Returns 0 if there are no results. 29 """ 30 correct = sum("correctness" in result and all(result["correctness"].values()) for result in results) 31 total = len(results) 32 return round(correct / total, 4) if total > 0 else 0 33 34 35 def evaluate_accuracy_by_sub_question(results: dict) -> float: 36 """ 37 Evaluate the correctness of all sub-questions across the results. 38 This function is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/eval_closed_form.py 39 This function calculates the total number of correct sub-questions and the overall 40 number of sub-questions present in all results. It returns the ratio of correct 41 sub-questions to the total number of sub-questions. 42 43 Args: 44 results (dict): A collection of results where each result may contain a 'correctness' field. 45 46 Returns: 47 float: The ratio of correct sub-questions, rounded to four decimal places. 48 Returns 0 if there are no sub-questions. 49 """ 50 correct = sum(sum(result["correctness"].values()) for result in results if "correctness" in result) 51 total = sum(len(result["correctness"]) for result in results if "correctness" in result) 52 return round(correct / total, 4) if total > 0 else 0 53 54 55 def evaluate_accuracy_proportional_by_sub_question_adjusted(results: dict) -> float: 56 """ 57 Adjust the score based on the number of sub-questions in each result. 58 This function is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/eval_closed_form.py 59 This function calculates a score for each result by considering the number of sub-questions 60 it contains. Each sub-question is assigned a score of 1 divided by the number of sub-questions. 61 The total score for each result is computed as the sum of all correct sub-questions multiplied 62 by the score per sub-question. Finally, it returns the average score across all results. 63 64 Args: 65 results (dict): A collection of results where each result may contain a 'correctness' field. 66 67 Returns: 68 float: The average score across all results, rounded to four decimal places. 69 Returns 0 if there are no results. 70 """ 71 total_score = 0 72 for result in results: 73 if "correctness" in result: 74 sub_question_count = len(result["correctness"]) 75 score_per_sub_question = 1 / sub_question_count if sub_question_count > 0 else 0 76 question_score = sum(result["correctness"].values()) * score_per_sub_question 77 total_score += question_score 78 return round(total_score / len(results), 4) if results else 0 79 80 81 async def reformat(question: str, format: str, response: str) -> str: 82 """ 83 Asynchronously reformats a given response based on specified formatting requirements. 84 This function is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/reformat.py 85 This function constructs a prompt for the LLM (Large Language Model) to reformat 86 the provided response according to the specified format. It includes a system prompt 87 to guide the LLM's behavior and a template that outlines the expected output structure. 88 89 Args: 90 question (str): The original question posed by the user. 91 format (str): The specific formatting requirements that the response must adhere to. 92 response (str): The initial response from the LLM that needs to be reformatted. 93 94 Returns: 95 str: The reformatted response generated by the LLM based on the provided question 96 and formatting requirements. 97 """ 98 system_prompt = "You are a helpful assistant." 99 demons = """\Format{{ 100 @shapiro_wilk_statistic[test_statistic] 101 @shapiro_wilk_p_value[p_value] 102 where "test_statistic" is a number between 0 and 1 representing the Shapiro-Wilk test statistic. Rounding off the answer to two decimal places. 103 where "p_value" is a number between 0 and 1 representing the p-value from the Shapiro-Wilk test. Rounding off the answer to four decimal places. 104 }} 105 \Answer{{ 106 @shapiro_wilk_statistic[0.56] 107 @shapiro_wilk_p_value[0.0002] 108 }} 109 110 \Format{{ 111 @total_votes_outliers_num[outlier_num] 112 where "outlier_num" is an integer representing the number of values considered outliers in the 'total_votes' column. 113 }} 114 \Answer{{ 115 @total_votes_outliers[10] 116 }} 117 """ 118 reformat_template = """You should strictly follow the output requirements in the Format part. Here're some examples: {demons}. 119 Your answer should contain all the \"@answer_name[answer]\" in the order mentioned, each \"answer\" should be in the range of value as required. You need to keep the original numbers and text, just reformat without making any changes. 120 The format requirements of this question is: 121 {format}. You need to keep the original numbers and text, just reformat without making any changes. Please give your answer:""" 122 messages = [ 123 {"role": "user", "content": question}, 124 {"role": "assistant", "content": response}, 125 {"role": "user", "content": reformat_template.format(demons=demons, format=format)}, 126 ] 127 rsp = await ask(messages, system_prompt) 128 return rsp 129 130 131 def load_jsonl(file_path: Union[Path, str]) -> List[Dict[str, Any]]: 132 """ 133 Load data from a JSONL file into a list of dictionaries. 134 135 Args: 136 file_path (Union[Path, str]): The path to the JSONL file to be loaded. 137 138 Returns: 139 List[Dict[str, Any]]: A list of dictionaries containing the data from the JSONL file. 140 """ 141 # Convert file_path to Path if it's a string 142 if isinstance(file_path, str): 143 file_path = Path(file_path) 144 145 data = [] 146 with open(file_path, "r", encoding="utf-8") as file: 147 for line in file: 148 data.append(json.loads(line)) 149 return data 150 151 152 def compare_predictions(pred_dict: dict, true_label: list) -> bool: 153 """ 154 Compares each prediction against the corresponding true label. 155 156 This function checks whether the predicted values match the true values for each 157 metric. It sorts the true labels to ensure the comparison is made in the correct 158 order. The function returns True if all predictions are accurate within a small 159 tolerance for numerical values, or if string values match case-insensitively. 160 161 Args: 162 pred_dict (dict): A dictionary of predicted metrics and their values. 163 true_label (list): A list of tuples containing true metrics and their values. 164 165 Returns: 166 bool: True if all predictions match the true labels, False otherwise. 167 """ 168 sorted_true_label = sorted(true_label, key=lambda x: x[0]) # Sort true labels by metric name 169 170 for metric, true_value in sorted_true_label: 171 try: 172 true_value = float(true_value) # Attempt to convert the true value to float 173 except ValueError: 174 true_value = true_value.replace(",", "") # Clean the true value if conversion fails 175 176 # Check if the true value is numeric and compare with the prediction 177 if isinstance(true_value, (int, float)) and ( 178 metric not in pred_dict or abs(pred_dict[metric] - true_value) > 1e-6 179 ): 180 return False # Return False if the prediction is inaccurate 181 182 # Check if the true value is a string and compare with the prediction 183 if isinstance(true_value, str) and ( 184 metric not in pred_dict or str(pred_dict[metric]).lower() != str(true_value).lower() 185 ): 186 return False # Return False if the string prediction does not match 187 188 return True # Return True if all predictions are accurate 189 190 191 async def ask(question: str, system_prompt: str) -> str: 192 """ 193 Asynchronously sends a question to the LLM (Large Language Model) and retrieves the response. 194 195 This function initializes an instance of the LLM and uses it to ask a question 196 along with a system prompt. The response from the LLM is awaited and returned. 197 198 Args: 199 question (str): The question to be asked to the LLM. 200 system_prompt (str): A prompt that provides context or instructions to the LLM. 201 202 Returns: 203 str: The response from the LLM based on the provided question and system prompt. 204 """ 205 from metagpt.llm import LLM # Importing the LLM class from the metagpt module 206 207 llm = LLM() # Create an instance of the LLM 208 rsp = await llm.aask(question, system_msgs=[system_prompt]) # Await the response from the LLM 209 return rsp # Return the response 210 211 212 def parse_prediction(prediction: str) -> dict: 213 """ 214 Parses a prediction string into a dictionary of metric-value pairs. 215 216 This function takes a formatted string containing metrics and their corresponding 217 values, separated by the "@" symbol. Each metric may be enclosed in brackets and 218 may include commas. The function processes the input to extract and clean the 219 metrics and their values, returning them in a structured dictionary format. 220 221 Args: 222 prediction (str): A string representation of metrics and their values. 223 224 Returns: 225 dict: A dictionary where each key is a metric name and each value is the 226 corresponding value, either as a float or a string. 227 """ 228 pred_dict = {} 229 for pred in prediction.split("@"): 230 if pred == "": 231 continue # Skip any empty segments resulting from the split 232 temp = re.split(r"[\[\]]", pred.strip()) # Split the string by brackets 233 temp = [s.replace(",", "") for s in temp] # Remove commas from the segments 234 parts = [s for s in temp if s] # Filter out any empty strings 235 metric = parts[0].strip().replace(",", "") # Extract and clean the metric name 236 value = parts[-1].replace(",", "").replace(":", "") # Extract and clean the value 237 238 try: 239 value = float(value) # Attempt to convert the value to a float 240 except ValueError: 241 pass # If conversion fails, retain the value as a string 242 243 pred_dict[metric] = value # Store the metric-value pair in the dictionary 244 return pred_dict 245 246 247 class DABench: 248 def __init__( 249 self, 250 questions_file: Path = Path(DABENCH_PATH) / "da-dev-questions.jsonl", 251 answers_file: Path = Path(DABENCH_PATH) / "da-dev-labels.jsonl", 252 template: str = "", 253 ): 254 """ 255 Initializes the DABench instance with questions and answers. 256 257 This constructor loads questions and answers from specified JSONL files. 258 It also sets a template for formatting prompts. If no template is provided, 259 a default template is used. 260 261 Args: 262 questions_file (Path): The path to the JSONL file containing questions. 263 answers_file (Path): The path to the JSONL file containing answers. 264 template (str): A string template for formatting prompts. 265 """ 266 267 self.questions = { 268 int(line["id"]): line for line in load_jsonl(questions_file) 269 } # Load questions from the specified file 270 self.answers = { 271 int(line["id"]): line for line in load_jsonl(answers_file) 272 } # Load answers from the specified file 273 self.template = template if template else DABENCH # Set the template, defaulting if necessary 274 275 def get_question(self, question_id: str) -> dict: 276 """ 277 Retrieve the question associated with the given ID. 278 279 This method looks up a question by its unique identifier. If the question 280 is found, it returns the question data; otherwise, it returns a message 281 indicating that the question was not found. 282 283 Args: 284 question_id (str): The unique identifier for the question. 285 286 Returns: 287 dict: The question data if found, otherwise a "Question not found." message. 288 """ 289 return self.questions.get(question_id, "Question not found.") # Return the question or an error message 290 291 def generate_formatted_prompt(self, question_id: str) -> str: 292 """ 293 Generate a formatted prompt for the specified question ID. 294 295 This method retrieves the question data and formats it using the specified 296 template. The formatted prompt includes the question, constraints, format, 297 file name, and level, allowing for a structured output. 298 299 Args: 300 question_id (str): The unique identifier for the question. 301 302 Returns: 303 str: A formatted prompt string based on the question data. 304 """ 305 temp = self.get_question(question_id) # Retrieve the question data 306 return self.template.format( 307 question=temp["question"], 308 constraints=temp["constraints"], 309 format=temp["format"], 310 file_name=str(DABENCH_PATH) + "/da-dev-tables/" + temp["file_name"], 311 level=temp["level"], 312 ) # Format and return the prompt 313 314 def get_answer(self, answer_id: str) -> list: 315 """ 316 Retrieve the answer list associated with the given ID. 317 318 This method looks up an answer by its unique identifier. If the answer 319 is found, it returns the answer data; otherwise, it returns a message 320 indicating that the answer was not found. 321 322 Args: 323 answer_id (str): The unique identifier for the answer. 324 325 Returns: 326 list: The answer data if found, otherwise an "Answer not found." message. 327 """ 328 return self.answers.get(answer_id, "Answer not found.") # Return the answer or an error message 329 330 @handle_exception(exception_msg="Error parsing cleaned prediction", default_return=(None, False)) 331 def parse_cleaned_prediction(self, cleaned_prediction: str, true_label: Any) -> Tuple[str, bool]: 332 """ 333 Parse the cleaned prediction and compare it with the true label. 334 335 Args: 336 cleaned_prediction (str): The cleaned prediction string. 337 true_label (Any): The true label to compare against. 338 339 Returns: 340 Tuple[str, bool]: A tuple containing the cleaned prediction and a boolean indicating 341 whether it matches the true label. 342 """ 343 if cleaned_prediction: # Ensure the cleaned prediction is not empty 344 pred_dict = parse_prediction(cleaned_prediction) # Parse the prediction 345 if pred_dict is not None and compare_predictions(pred_dict, true_label): 346 return cleaned_prediction, True # Return if the prediction matches the true label 347 return cleaned_prediction, False # Return the cleaned prediction with a False match 348 349 @handle_exception(exception_msg="Error during async reformat", default_return=(None, False)) 350 def async_reformat_prediction(self, id: str, result: str) -> str: 351 """ 352 Reformat the prediction asynchronously and extract the answer. 353 354 Args: 355 id (str): The identifier for the question. 356 result (str): The original prediction result. 357 358 Returns: 359 str: The reformatted prediction or the original prediction if extraction fails. 360 """ 361 question = self.get_question(id)["question"] # Retrieve the question based on the ID 362 question_format = self.get_question(id)["format"] # Get the format of the question 363 prediction = asyncio.run(reformat(question, question_format, result)) # Asynchronously reformat the prediction 364 365 # Attempt to extract the answer from the reformatted prediction 366 answer_part = prediction.split("Answer{{") if "Answer{{" in prediction else [] 367 if len(answer_part) > 1: 368 return answer_part[1].split("}}")[0].strip() # Return the extracted answer 369 370 return prediction # If extraction fails, return the original prediction 371 372 def eval(self, id: str, result: str) -> Tuple[str, bool]: 373 """ 374 Evaluate the prediction against the true label. 375 376 Args: 377 id (str): The identifier for the question. 378 result (str): The original prediction result. 379 380 Returns: 381 Tuple[str, bool]: A tuple containing the final prediction and a boolean indicating 382 whether it matches the true label. 383 """ 384 true_label = self.get_answer(id)["common_answers"] # Retrieve the true label for comparison 385 nest_asyncio.apply() # Apply nested asyncio to allow for async calls 386 result = json.loads(str(result).split("Current Plan")[1].split("## Current Task")[0])[-1]["result"].strip() 387 cleaned_prediction = result.replace("{", "").replace("}", "").replace("'", "") # Clean the prediction string 388 389 # Use the decorated function to handle exceptions while parsing the cleaned prediction 390 parsed_result = self.parse_cleaned_prediction(cleaned_prediction, true_label) 391 if parsed_result[1]: # If the parsed prediction is valid 392 return parsed_result # Return the valid prediction 393 394 # If the cleaned prediction is not valid, attempt to asynchronously reformat it 395 prediction = self.async_reformat_prediction(id, result) 396 397 pred_dict = parse_prediction(prediction) # Parse the reformatted prediction 398 if pred_dict is not None and compare_predictions(pred_dict, true_label): 399 return prediction, True # Return if the reformatted prediction matches the true label 400 401 return prediction, False # Return the final prediction with a False match 402 403 @handle_exception(exception_msg="Error evaluating single prediction", default_return={}) 404 def single_eval(self, id: str, prediction: str) -> dict: 405 """ 406 Evaluate the prediction against the true label for a single question. 407 just using in eval_all 408 409 Args: 410 id (str): The identifier for the question. 411 prediction (str): The prediction string to evaluate. 412 413 Returns: 414 dict: A dictionary indicating the correctness of each metric. 415 """ 416 true_label = self.get_answer(id)["common_answers"] # Retrieve the true label for the question 417 prediction = prediction.replace("{", "").replace("}", "").replace("'", "") # Clean the prediction string 418 pred_dict = parse_prediction(prediction) # Parse the prediction into a dictionary 419 420 # Initialize the correctness dictionary with False values for each metric 421 correctness = {metric: False for metric, _ in true_label} 422 423 # Check each metric's prediction against the true label 424 for metric, true_value in true_label: 425 try: 426 true_value = float(true_value) # Attempt to convert the true value to float 427 except ValueError: 428 true_value = true_value.replace(",", "") # Handle non-numeric values 429 430 if metric in pred_dict: 431 # Consider the prediction correct if it's within a small tolerance 432 if ( 433 isinstance(true_value, (int, float)) 434 and isinstance(pred_dict[metric], (int, float)) 435 and abs(pred_dict[metric] - true_value) < 1e-6 436 ): 437 correctness[metric] = True # Mark as correct if within tolerance 438 439 if isinstance(true_value, str) and ( 440 metric not in pred_dict or str(pred_dict[metric]).lower() != str(true_value).lower() 441 ): 442 correctness[metric] = True # Mark as correct for string comparison 443 444 return correctness # Return the correctness dictionary 445 446 def eval_all(self, id_list: list, predictions: list) -> dict: 447 """ 448 Evaluate all predictions and calculate accuracy rates. 449 450 Args: 451 id_list (list): A list of question identifiers. 452 predictions (list): A list of prediction strings corresponding to the questions. 453 454 Returns: 455 dict: A dictionary containing accuracy rates by question and sub-question. 456 """ 457 results = [] # Initialize a list to store results for each question 458 459 # Evaluate each prediction against its corresponding question ID 460 for id, prediction in zip(id_list, predictions): 461 correct = self.single_eval(id, prediction) # Evaluate the single prediction 462 results.append({"id": id, "correctness": correct}) # Append the result to the list 463 464 # Calculate the three accuracy rates based on the results 465 accuracy_by_question = evaluate_accuracy_by_question(results) 466 accuracy_by_sub_question = evaluate_accuracy_by_sub_question(results) 467 proportional_accuracy_by_sub_question = evaluate_accuracy_proportional_by_sub_question_adjusted(results) 468 469 return { 470 "accuracy_by_question": accuracy_by_question, 471 "accuracy_by_sub_question": accuracy_by_sub_question, 472 "proportional_accuracy_by_sub_question": proportional_accuracy_by_sub_question, 473 } 474 475 476 if __name__ == "__main__": 477 bench = DABench() 478 id = 0 479 prediction = "@mean_fare[34.65]" 480 logger.info(bench.eval(id, prediction)) 481 ids = [0, 5, 6] 482 predictions = [ 483 "@mean_fare[34.89]", 484 "@correlation_coefficient[0.21]", 485 "@mean_fare_child[31.09], @mean_fare_teenager[31.98], @mean_fare_adult[35.17], @mean_fare_elderly[43.47]", 486 ] 487 logger.info(bench.eval_all(ids, predictions))