# aggr_results.py
import json
from pathlib import Path
import argparse
from rich.console import Console
from rich.table import Table


# Cache of parsed JSON result files keyed by path, so the same file
# (e.g. humaneval/result.json, read twice) is only loaded from disk once.
file_cache = {}


def ret(file, *fields):
    """Load the JSON document at *file* (cached) and drill into *fields*.

    Each element of *fields* is applied as a successive key into the parsed
    JSON; with no fields, the whole document is returned.

    Deliberately lets KeyError/TypeError (missing field), OSError (missing
    file) and json.JSONDecodeError propagate — the try_collect_* callers
    catch them and substitute an "Err" placeholder.
    """
    if file not in file_cache:
        # Explicit encoding: result files are JSON, which is UTF-8 by spec;
        # don't depend on the locale's default encoding.
        with Path(file).open("r", encoding="utf-8") as f:
            file_cache[file] = json.load(f)

    cur = file_cache[file]
    for field in fields:
        cur = cur[field]

    return cur


def err_payload():
    """Placeholder row shown when a benchmark's results are missing or broken."""
    return {"-": "Err"}


def try_collect_evalplus(path):
    """Collect EvalPlus (HumanEval/MBPP, plus their '+' variants) pass@1 scores.

    Returns {"EvalPlus": {metric: score}}; on any read/parse/key failure the
    payload is the "Err" placeholder instead of crashing the aggregation.
    """
    try:
        payload = {
            "HumanEval": ret(path / "humaneval" / "result.json", "humaneval", "pass@1"),
            "HumanEval+": ret(path / "humaneval" / "result.json", "humaneval+", "pass@1"),
            "MBPP": ret(path / "mbpp" / "result.json", "mbpp", "pass@1"),
            "MBPP+": ret(path / "mbpp" / "result.json", "mbpp+", "pass@1"),
        }
    except Exception:
        # Narrowed from a bare `except:` so Ctrl-C / SystemExit still work.
        payload = err_payload()

    return {"EvalPlus": payload}


def try_collect_multipl_e(path):
    """Collect MultiPL-E per-language scores, mapping short codes to display names.

    Returns {"MultiPL-E": {language: score}} or an "Err" placeholder on failure.
    """
    try:
        lang_mapper = {"java": "Java", "cpp": "C++", "js": "JavaScript", "cs": "C#", "php": "php", "sh": "Shell", "ts": "TypeScript"}
        d = ret(path / "results.json")
        payload = {lang_mapper[k]: v for k, v in d.items()}
    except Exception:
        # (was `except Exception as e:` with `e` never used)
        payload = err_payload()

    return {"MultiPL-E": payload}


def try_collect_cruxeval(path):
    """Collect CRUX-Eval input/output CoT pass@1 scores.

    Returns {"CRUX-Eval": {metric: score}} or an "Err" placeholder on failure.
    """
    try:
        payload = {
            "Input-CoT": ret(path / "input-cot" / "results.json", "pass@1"),
            "Output-CoT": ret(path / "output-cot" / "results.json", "pass@1"),
        }
    except Exception:
        # Narrowed from a bare `except:` so Ctrl-C / SystemExit still work.
        payload = err_payload()

    return {"CRUX-Eval": payload}


def try_collect_bigcodebench(path):
    """Collect BigCodeBench full/hard pass@1 scores.

    Returns {"BigCodeBench": {split: score}} or an "Err" placeholder on failure.
    """
    try:
        payload = {
            "full": ret(path / "full" / "results.json", "pass@1"),
            "hard": ret(path / "hard" / "results.json", "pass@1"),
        }
    except Exception:
        # Narrowed from a bare `except:` so Ctrl-C / SystemExit still work.
        payload = err_payload()

    return {"BigCodeBench": payload}


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("folder")
    args = parser.parse_args()

    folder = Path(args.folder)
    # `assert` is stripped under `python -O`; validate user input explicitly.
    if not folder.is_dir():
        parser.error(f"{folder} is not dir")

    # Gather every benchmark family; failures degrade to "Err" rows.
    collected = {}
    collected.update(try_collect_evalplus(folder / "evalplus"))
    collected.update(try_collect_multipl_e(folder / "multipl-e"))
    collected.update(try_collect_cruxeval(folder / "cruxeval"))
    collected.update(try_collect_bigcodebench(folder / "bigcodebench"))

    table = Table(title=f"[b][u][red]{args.folder}[/red][/u][/b]")
    table.add_column("Benchmark", justify="left", no_wrap=True)
    table.add_column("Metric", justify="left", no_wrap=True)
    table.add_column("Score", justify="right", no_wrap=True)

    for bench, details in collected.items():
        # Print the benchmark name only on the visually-centered metric row.
        center = (len(details) - 1) // 2
        for idx, (k, v) in enumerate(details.items()):
            if not isinstance(v, str):
                # Numeric scores rendered as fixed-width one-decimal floats.
                v = f"{float(v):5.1f}"
            table.add_row(f"[b]{bench}[/b]" if idx == center else "", k, v)
        table.add_section()

    print()
    Console().print(table)
    print()

    # `saved` is already a Path; no need to re-wrap it (was Path(saved).open).
    saved = folder.joinpath("all_results.json")
    with saved.open("w", encoding="utf-8") as f:
        json.dump(collected, f, indent=2, ensure_ascii=False)

    print(f"All results => {saved}")