# qwencoder-eval/base/aggr_results.py
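"""Aggregate per-benchmark result JSONs (EvalPlus, MultiPL-E, CRUX-Eval,
BigCodeBench) from one evaluation folder into a rich summary table and a
combined all_results.json file."""
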
import json
from pathlib import Path
import argparse
from rich.console import Console
from rich.table import Table


# Cache of parsed JSON files, keyed by path, so each file is read only once.
file_cache = {}


def ret(file, *fields):
    """Load a JSON file (cached) and walk down the given nested fields."""
    if file not in file_cache:
        with Path(file).open("r") as f:
            content = json.load(f)
        file_cache[file] = content

    cur = file_cache[file]
    for field in fields:
        cur = cur[field]

    return cur
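
# Expected result-file layouts, inferred from the lookups in the collectors
# below, e.g.:
#   evalplus/humaneval/result.json  -> {"humaneval": {"pass@1": ...}, "humaneval+": {"pass@1": ...}}
#   cruxeval/input-cot/results.json -> {"pass@1": ...}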


def err_payload():
    """Fallback payload shown when a benchmark's results are missing or malformed."""
    return {"-": "Err"}


def try_collect_evalplus(path):
    """Collect EvalPlus pass@1 scores (HumanEval/+, MBPP/+)."""
    try:
        payload = {
            "HumanEval": ret(path / "humaneval" / "result.json", "humaneval", "pass@1"),
            "HumanEval+": ret(path / "humaneval" / "result.json", "humaneval+", "pass@1"),
            "MBPP": ret(path / "mbpp" / "result.json", "mbpp", "pass@1"),
            "MBPP+": ret(path / "mbpp" / "result.json", "mbpp+", "pass@1"),
        }
    except Exception:  # missing files or fields fall back to the Err placeholder
        payload = err_payload()

    return {"EvalPlus": payload}


def try_collect_multipl_e(path):
    """Collect MultiPL-E per-language scores."""
    try:
        lang_mapper = {
            "java": "Java",
            "cpp": "C++",
            "js": "JavaScript",
            "cs": "C#",
            "php": "PHP",
            "sh": "Shell",
            "ts": "TypeScript",
        }
        d = ret(path / "results.json")
        payload = {lang_mapper[k]: v for k, v in d.items()}
    except Exception:
        payload = err_payload()

    return {"MultiPL-E": payload}


def try_collect_cruxeval(path):
    """Collect CRUX-Eval input/output CoT pass@1 scores."""
    try:
        payload = {
            "Input-CoT": ret(path / "input-cot" / "results.json", "pass@1"),
            "Output-CoT": ret(path / "output-cot" / "results.json", "pass@1"),
        }
    except Exception:
        payload = err_payload()

    return {"CRUX-Eval": payload}


def try_collect_bigcodebench(path):
    """Collect BigCodeBench full/hard pass@1 scores."""
    try:
        payload = {
            "full": ret(path / "full" / "results.json", "pass@1"),
            "hard": ret(path / "hard" / "results.json", "pass@1"),
        }
    except Exception:
        payload = err_payload()

    return {"BigCodeBench": payload}
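
# The results folder passed on the command line is expected to hold one
# subdirectory per benchmark: evalplus/, multipl-e/, cruxeval/, bigcodebench/.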


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("folder")
    args = parser.parse_args()

    folder = Path(args.folder)
    assert folder.is_dir(), f"{folder} is not a directory"

    collected = {}
    collected.update(try_collect_evalplus(folder / "evalplus"))
    collected.update(try_collect_multipl_e(folder / "multipl-e"))
    collected.update(try_collect_cruxeval(folder / "cruxeval"))
    collected.update(try_collect_bigcodebench(folder / "bigcodebench"))

    table = Table(title=f"[b][u][red]{args.folder}[/red][/u][/b]")
    table.add_column("Benchmark", justify="left", no_wrap=True)
    table.add_column("Metric", justify="left", no_wrap=True)
    table.add_column("Score", justify="right", no_wrap=True)

    for bench, details in collected.items():
        # Label the benchmark only on the middle row of its section.
        center = (len(details) - 1) // 2
        for idx, (k, v) in enumerate(details.items()):
            if not isinstance(v, str):
                v = f"{float(v):5.1f}"  # numeric scores: one decimal place, right-aligned
            table.add_row(f"[b]{bench}[/b]" if idx == center else "", k, v)
        table.add_section()

    print()
    Console().print(table)
    print()

    saved = folder / "all_results.json"
    with saved.open("w") as f:  # `saved` is already a Path; no need to re-wrap it
        json.dump(collected, f, indent=2, ensure_ascii=False)

    print(f"All results => {saved}")
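
# Example invocation (the folder name is hypothetical; pass whichever directory
# holds your per-benchmark result subfolders):
#   python aggr_results.py outputs/my-model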