# examples/di/InfiAgent-DABench/DABench.py
import asyncio
import json
import re
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union

import nest_asyncio

from examples.di.requirements_prompt import DABENCH
from metagpt.const import DABENCH_PATH
from metagpt.logs import logger
from metagpt.utils.exceptions import handle_exception


def evaluate_accuracy_by_question(results: list) -> float:
    """
    Calculate accuracy based on the complete correctness of each question.
    This function is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/eval_closed_form.py
    A result counts as correct only if every sub-question within it is answered
    correctly. The function computes the proportion of fully correct results by
    dividing their number by the total number of results.

    Args:
        results (list): A list of result dicts, each of which may contain a 'correctness' field.

    Returns:
        float: The proportion of correct results, rounded to four decimal places.
               Returns 0 if there are no results.
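
    Example (illustrative values, not drawn from the dataset):
        >>> results = [
        ...     {"correctness": {"a": True, "b": True}},
        ...     {"correctness": {"a": True, "b": False}},
        ... ]
        >>> evaluate_accuracy_by_question(results)
        0.5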
    """
    correct = sum("correctness" in result and all(result["correctness"].values()) for result in results)
    total = len(results)
    return round(correct / total, 4) if total > 0 else 0


def evaluate_accuracy_by_sub_question(results: list) -> float:
    """
    Evaluate the correctness of all sub-questions across the results.
    This function is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/eval_closed_form.py
    This function counts the correct sub-questions and the total number of
    sub-questions present in all results, and returns the ratio of correct
    sub-questions to the total.

    Args:
        results (list): A list of result dicts, each of which may contain a 'correctness' field.

    Returns:
        float: The ratio of correct sub-questions, rounded to four decimal places.
               Returns 0 if there are no sub-questions.
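
    Example (illustrative; three of the four sub-questions are correct):
        >>> results = [
        ...     {"correctness": {"a": True, "b": True}},
        ...     {"correctness": {"a": True, "b": False}},
        ... ]
        >>> evaluate_accuracy_by_sub_question(results)
        0.75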
    """
    correct = sum(sum(result["correctness"].values()) for result in results if "correctness" in result)
    total = sum(len(result["correctness"]) for result in results if "correctness" in result)
    return round(correct / total, 4) if total > 0 else 0


def evaluate_accuracy_proportional_by_sub_question_adjusted(results: list) -> float:
    """
    Adjust the score based on the number of sub-questions in each result.
    This function is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/eval_closed_form.py
    Each sub-question in a result is worth 1 divided by that result's number of
    sub-questions, so every result contributes at most 1 regardless of how many
    sub-questions it contains. The score of a result is the number of correct
    sub-questions times the per-sub-question weight, and the function returns
    the average score across all results.

    Args:
        results (list): A list of result dicts, each of which may contain a 'correctness' field.

    Returns:
        float: The average score across all results, rounded to four decimal places.
               Returns 0 if there are no results.
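
    Example (illustrative; the results score 1.0 and 0.5, averaging 0.75):
        >>> results = [
        ...     {"correctness": {"a": True, "b": True}},
        ...     {"correctness": {"a": True, "b": False}},
        ... ]
        >>> evaluate_accuracy_proportional_by_sub_question_adjusted(results)
        0.75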
    """
    total_score = 0
    for result in results:
        if "correctness" in result:
            sub_question_count = len(result["correctness"])
            score_per_sub_question = 1 / sub_question_count if sub_question_count > 0 else 0
            question_score = sum(result["correctness"].values()) * score_per_sub_question
            total_score += question_score
    return round(total_score / len(results), 4) if results else 0


async def reformat(question: str, format: str, response: str) -> str:
    """
    Asynchronously reformat a given response according to specified formatting requirements.
    This function is referenced from https://github.com/InfiAgent/InfiAgent/blob/main/examples/DA-Agent/reformat.py
    It constructs a prompt that asks the LLM (Large Language Model) to reformat
    the provided response to match the required format. The prompt includes a
    system prompt to guide the LLM's behavior and few-shot examples that outline
    the expected output structure.

    Args:
        question (str): The original question posed by the user.
        format (str): The formatting requirements that the response must adhere to.
        response (str): The initial LLM response that needs to be reformatted.

    Returns:
        str: The reformatted response generated by the LLM based on the provided
             question and formatting requirements.
    """
    system_prompt = "You are a helpful assistant."
    demons = r"""\Format{{
        @shapiro_wilk_statistic[test_statistic]
        @shapiro_wilk_p_value[p_value]
        where "test_statistic" is a number between 0 and 1 representing the Shapiro-Wilk test statistic. Rounding off the answer to two decimal places.
        where "p_value" is a number between 0 and 1 representing the p-value from the Shapiro-Wilk test. Rounding off the answer to four decimal places.
        }}
        \Answer{{
        @shapiro_wilk_statistic[0.56]
        @shapiro_wilk_p_value[0.0002]
        }}

        \Format{{
        @total_votes_outliers_num[outlier_num]
        where "outlier_num" is an integer representing the number of values considered outliers in the 'total_votes' column.
        }}
        \Answer{{
        @total_votes_outliers_num[10]
        }}
        """
    reformat_template = """You should strictly follow the output requirements in the Format part. Here are some examples: {demons}.
    Your answer should contain all the "@answer_name[answer]" entries in the order mentioned, and each "answer" should be within the range of values required. You need to keep the original numbers and text, just reformat without making any changes.
    The format requirements of this question are:
    {format}. You need to keep the original numbers and text, just reformat without making any changes. Please give your answer:"""
    messages = [
        {"role": "user", "content": question},
        {"role": "assistant", "content": response},
        {"role": "user", "content": reformat_template.format(demons=demons, format=format)},
    ]
    rsp = await ask(messages, system_prompt)
    return rsp


def load_jsonl(file_path: Union[Path, str]) -> List[Dict[str, Any]]:
    """
    Load data from a JSONL file into a list of dictionaries.

    Args:
        file_path (Union[Path, str]): The path to the JSONL file to be loaded.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries, one per line of the JSONL file.
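
    Example (illustrative; the file name is hypothetical):
        >>> data = load_jsonl("da-dev-questions.jsonl")  # doctest: +SKIP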
    """
    # Convert file_path to Path if it's a string
    if isinstance(file_path, str):
        file_path = Path(file_path)

    data = []
    with open(file_path, "r", encoding="utf-8") as file:
        for line in file:
            data.append(json.loads(line))
    return data


def compare_predictions(pred_dict: dict, true_label: list) -> bool:
    """
    Compare each prediction against the corresponding true label.

    This function checks whether the predicted values match the true values for
    each metric. True labels are sorted by metric name so the comparison runs in
    a deterministic order. It returns True only if every prediction is accurate:
    within a small tolerance for numerical values, or a case-insensitive match
    for string values.

    Args:
        pred_dict (dict): A dictionary of predicted metrics and their values.
        true_label (list): A list of tuples containing true metrics and their values.

    Returns:
        bool: True if all predictions match the true labels, False otherwise.
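
    Example (illustrative values):
        >>> compare_predictions({"mean_fare": 34.65}, [("mean_fare", "34.65")])
        True
        >>> compare_predictions({"mean_fare": 35.0}, [("mean_fare", "34.65")])
        False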
    """
    sorted_true_label = sorted(true_label, key=lambda x: x[0])  # Sort true labels by metric name

    for metric, true_value in sorted_true_label:
        try:
            true_value = float(true_value)  # Attempt to convert the true value to float
        except ValueError:
            true_value = true_value.replace(",", "")  # Clean the true value if conversion fails

        pred_value = pred_dict.get(metric)

        # Numeric label: the prediction must exist, be numeric, and match within tolerance
        if isinstance(true_value, float):
            if not isinstance(pred_value, (int, float)) or abs(pred_value - true_value) > 1e-6:
                return False
        # String label: the prediction must exist and match case-insensitively
        elif pred_value is None or str(pred_value).lower() != str(true_value).lower():
            return False

    return True  # All predictions are accurate


async def ask(question: Union[str, List[dict]], system_prompt: str) -> str:
    """
    Asynchronously send a question to the LLM (Large Language Model) and retrieve the response.

    This function initializes an instance of the LLM and uses it to ask a question
    along with a system prompt. The question may also be a list of chat messages,
    as passed by `reformat`. The response from the LLM is awaited and returned.

    Args:
        question (Union[str, List[dict]]): The question (or message list) to send to the LLM.
        system_prompt (str): A prompt that provides context or instructions to the LLM.

    Returns:
        str: The response from the LLM based on the provided question and system prompt.
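
    Example (illustrative; requires a configured LLM backend):
        >>> rsp = asyncio.run(ask("What is 2 + 2?", "You are a helpful assistant."))  # doctest: +SKIP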
    """
    from metagpt.llm import LLM  # Import the LLM class from the metagpt module

    llm = LLM()  # Create an instance of the LLM
    rsp = await llm.aask(question, system_msgs=[system_prompt])  # Await the response from the LLM
    return rsp  # Return the response


def parse_prediction(prediction: str) -> dict:
    """
    Parse a prediction string into a dictionary of metric-value pairs.

    This function takes a formatted string containing metrics and their
    corresponding values, separated by the "@" symbol. Each value is enclosed in
    brackets and may include commas. The function extracts and cleans the
    metrics and their values, returning them in a structured dictionary.

    Args:
        prediction (str): A string representation of metrics and their values.

    Returns:
        dict: A dictionary where each key is a metric name and each value is the
              corresponding value, either as a float or a string.
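
    Example (illustrative values):
        >>> parse_prediction("@mean_fare[34.65] @sex[male]")
        {'mean_fare': 34.65, 'sex': 'male'}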
    """
    pred_dict = {}
    for pred in prediction.split("@"):
        if pred == "":
            continue  # Skip any empty segments resulting from the split
        temp = re.split(r"[\[\]]", pred.strip())  # Split the string on brackets
        temp = [s.replace(",", "") for s in temp]  # Remove commas from the segments
        parts = [s for s in temp if s]  # Filter out any empty strings
        if not parts:
            continue  # Skip segments that contain no metric name or value
        metric = parts[0].strip()  # Extract and clean the metric name
        value = parts[-1].replace(":", "")  # Extract and clean the value

        try:
            value = float(value)  # Attempt to convert the value to a float
        except ValueError:
            pass  # If conversion fails, retain the value as a string

        pred_dict[metric] = value  # Store the metric-value pair in the dictionary
    return pred_dict


class DABench:
    def __init__(
        self,
        questions_file: Path = Path(DABENCH_PATH) / "da-dev-questions.jsonl",
        answers_file: Path = Path(DABENCH_PATH) / "da-dev-labels.jsonl",
        template: str = "",
    ):
        """
        Initialize the DABench instance with questions and answers.

        This constructor loads questions and answers from the specified JSONL files
        and sets a template for formatting prompts. If no template is provided,
        the default DABENCH template is used.

        Args:
            questions_file (Path): The path to the JSONL file containing questions.
            answers_file (Path): The path to the JSONL file containing answers.
            template (str): A string template for formatting prompts.
        """

        self.questions = {
            int(line["id"]): line for line in load_jsonl(questions_file)
        }  # Load questions from the specified file
        self.answers = {
            int(line["id"]): line for line in load_jsonl(answers_file)
        }  # Load answers from the specified file
        self.template = template if template else DABENCH  # Set the template, defaulting to DABENCH

    def get_question(self, question_id: int) -> dict:
        """
        Retrieve the question associated with the given ID.

        This method looks up a question by its unique identifier. If the question
        is found, it returns the question data; otherwise, it returns a message
        indicating that the question was not found.

        Args:
            question_id (int): The unique identifier for the question.

        Returns:
            dict: The question data if found, otherwise a "Question not found." message.
        """
        return self.questions.get(question_id, "Question not found.")  # Return the question or an error message

    def generate_formatted_prompt(self, question_id: int) -> str:
        """
        Generate a formatted prompt for the specified question ID.

        This method retrieves the question data and formats it using the configured
        template. The formatted prompt includes the question, constraints, format,
        file name, and level, producing a structured output.

        Args:
            question_id (int): The unique identifier for the question.

        Returns:
            str: A formatted prompt string based on the question data.
        """
        temp = self.get_question(question_id)  # Retrieve the question data
        return self.template.format(
            question=temp["question"],
            constraints=temp["constraints"],
            format=temp["format"],
            file_name=str(Path(DABENCH_PATH) / "da-dev-tables" / temp["file_name"]),
            level=temp["level"],
        )  # Format and return the prompt

    def get_answer(self, answer_id: int) -> list:
        """
        Retrieve the answer list associated with the given ID.

        This method looks up an answer by its unique identifier. If the answer
        is found, it returns the answer data; otherwise, it returns a message
        indicating that the answer was not found.

        Args:
            answer_id (int): The unique identifier for the answer.

        Returns:
            list: The answer data if found, otherwise an "Answer not found." message.
        """
        return self.answers.get(answer_id, "Answer not found.")  # Return the answer or an error message

    @handle_exception(exception_msg="Error parsing cleaned prediction", default_return=(None, False))
    def parse_cleaned_prediction(self, cleaned_prediction: str, true_label: Any) -> Tuple[str, bool]:
        """
        Parse the cleaned prediction and compare it with the true label.

        Args:
            cleaned_prediction (str): The cleaned prediction string.
            true_label (Any): The true label to compare against.

        Returns:
            Tuple[str, bool]: A tuple containing the cleaned prediction and a boolean indicating
                              whether it matches the true label.
        """
        if cleaned_prediction:  # Ensure the cleaned prediction is not empty
            pred_dict = parse_prediction(cleaned_prediction)  # Parse the prediction
            if pred_dict is not None and compare_predictions(pred_dict, true_label):
                return cleaned_prediction, True  # Return if the prediction matches the true label
        return cleaned_prediction, False  # Return the cleaned prediction with a False match

    @handle_exception(exception_msg="Error during async reformat", default_return=None)
    def async_reformat_prediction(self, id: int, result: str) -> str:
        """
        Reformat the prediction via the LLM and extract the answer block.

        Despite its name, this method is synchronous: it drives the async
        `reformat` call with `asyncio.run`. On failure, the `handle_exception`
        decorator returns None.

        Args:
            id (int): The identifier for the question.
            result (str): The original prediction result.

        Returns:
            str: The extracted answer if an answer block is found, otherwise the
                 full reformatted prediction.
        """
        question = self.get_question(id)["question"]  # Retrieve the question text for the ID
        question_format = self.get_question(id)["format"]  # Get the required answer format
        prediction = asyncio.run(reformat(question, question_format, result))  # Run the async reformat call

        # Attempt to extract the answer block from the reformatted prediction
        answer_part = prediction.split("Answer{{") if "Answer{{" in prediction else []
        if len(answer_part) > 1:
            return answer_part[1].split("}}")[0].strip()  # Return the extracted answer

        return prediction  # If extraction fails, return the full reformatted prediction

    def eval(self, id: int, result: str) -> Tuple[str, bool]:
        """
        Evaluate the prediction against the true label.

        Args:
            id (int): The identifier for the question.
            result (str): The original prediction result.

        Returns:
            Tuple[str, bool]: A tuple containing the final prediction and a boolean indicating
                              whether it matches the true label.
        """
        true_label = self.get_answer(id)["common_answers"]  # Retrieve the true label for comparison
        nest_asyncio.apply()  # Allow nested event loops so asyncio.run can be used below
        result = str(result)
        if "Current Plan" in result and "## Current Task" in result:
            # Extract the final task result from a DataInterpreter plan dump;
            # plain prediction strings are evaluated as-is
            result = json.loads(result.split("Current Plan")[1].split("## Current Task")[0])[-1]["result"].strip()
        cleaned_prediction = result.replace("{", "").replace("}", "").replace("'", "")  # Clean the prediction string

        # Use the decorated method to handle exceptions while parsing the cleaned prediction
        parsed_result = self.parse_cleaned_prediction(cleaned_prediction, true_label)
        if parsed_result[1]:  # If the parsed prediction matches the true label
            return parsed_result  # Return the valid prediction

        # Otherwise, try to reformat the prediction via the LLM and re-evaluate
        prediction = self.async_reformat_prediction(id, result)
        if prediction is None:  # Reformatting failed; fall back to the cleaned prediction
            return cleaned_prediction, False

        pred_dict = parse_prediction(prediction)  # Parse the reformatted prediction
        if compare_predictions(pred_dict, true_label):
            return prediction, True  # Return if the reformatted prediction matches the true label

        return prediction, False  # Return the final prediction with a False match

    @handle_exception(exception_msg="Error evaluating single prediction", default_return={})
    def single_eval(self, id: int, prediction: str) -> dict:
        """
        Evaluate the prediction against the true label for a single question.
        Only used by eval_all.

        Args:
            id (int): The identifier for the question.
            prediction (str): The prediction string to evaluate.

        Returns:
            dict: A dictionary indicating the correctness of each metric.
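
        Example (illustrative; the metric names depend on the dataset labels):
            >>> bench = DABench()  # doctest: +SKIP
            >>> bench.single_eval(0, "@mean_fare[34.65]")  # doctest: +SKIP
            {'mean_fare': True}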
        """
        true_label = self.get_answer(id)["common_answers"]  # Retrieve the true label for the question
        prediction = prediction.replace("{", "").replace("}", "").replace("'", "")  # Clean the prediction string
        pred_dict = parse_prediction(prediction)  # Parse the prediction into a dictionary

        # Initialize the correctness dictionary with False values for each metric
        correctness = {metric: False for metric, _ in true_label}

        # Check each metric's prediction against the true label
        for metric, true_value in true_label:
            try:
                true_value = float(true_value)  # Attempt to convert the true value to float
            except ValueError:
                true_value = true_value.replace(",", "")  # Handle non-numeric values

            if metric in pred_dict:
                # A numeric prediction is correct if it is within a small tolerance
                if (
                    isinstance(true_value, (int, float))
                    and isinstance(pred_dict[metric], (int, float))
                    and abs(pred_dict[metric] - true_value) < 1e-6
                ):
                    correctness[metric] = True
                # A string prediction is correct if it matches case-insensitively
                elif isinstance(true_value, str) and str(pred_dict[metric]).lower() == str(true_value).lower():
                    correctness[metric] = True

        return correctness  # Return the correctness dictionary

    def eval_all(self, id_list: list, predictions: list) -> dict:
        """
        Evaluate all predictions and calculate accuracy rates.

        Args:
            id_list (list): A list of question identifiers.
            predictions (list): A list of prediction strings corresponding to the questions.

        Returns:
            dict: A dictionary containing accuracy rates by question and by sub-question.
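
        Example (illustrative; requires the DABench dataset files):
            >>> bench = DABench()  # doctest: +SKIP
            >>> bench.eval_all([0], ["@mean_fare[34.65]"])  # doctest: +SKIP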
        """
        results = []  # Collect the per-question correctness results

        # Evaluate each prediction against its corresponding question ID
        for id, prediction in zip(id_list, predictions):
            correct = self.single_eval(id, prediction)  # Evaluate the single prediction
            results.append({"id": id, "correctness": correct})  # Append the result to the list

        # Calculate the three accuracy rates based on the results
        accuracy_by_question = evaluate_accuracy_by_question(results)
        accuracy_by_sub_question = evaluate_accuracy_by_sub_question(results)
        proportional_accuracy_by_sub_question = evaluate_accuracy_proportional_by_sub_question_adjusted(results)

        return {
            "accuracy_by_question": accuracy_by_question,
            "accuracy_by_sub_question": accuracy_by_sub_question,
            "proportional_accuracy_by_sub_question": proportional_accuracy_by_sub_question,
        }


if __name__ == "__main__":
    bench = DABench()
    id = 0
    prediction = "@mean_fare[34.65]"
    logger.info(bench.eval(id, prediction))
    ids = [0, 5, 6]
    predictions = [
        "@mean_fare[34.89]",
        "@correlation_coefficient[0.21]",
        "@mean_fare_child[31.09], @mean_fare_teenager[31.98], @mean_fare_adult[35.17], @mean_fare_elderly[43.47]",
    ]
    logger.info(bench.eval_all(ids, predictions))