evaluate.py
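"""Evaluation driver: optionally verifies the BigCodeBench ground-truth
solutions, runs each sampled solution under the configured memory/data/stack
limits via untrusted_check, and reports pass@k for the chosen split/subset."""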
import argparse
import json
import multiprocessing
import os
import pickle
import threading
import time
from collections import Counter, defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import datetime
from typing import Any, Dict, List, Tuple
from warnings import warn

import numpy as np
from termcolor import cprint
from tqdm import tqdm

from data import (
    get_bigcodebench,
    get_bigcodebench_hash,
    load_solutions,
)
from data.utils import CACHE_DIR
from eval import (
    PASS,
    compatible_eval_result,
    estimate_pass_at_k,
    untrusted_check,
)
from gen.util import trusted_check

# 1st item: the status
# 2nd item (optional): the detailed pass/fail boolean for each input
Result = Tuple[str, List[bool]]


def get_groundtruth(n_workers, problems, hashcode, check_gt_only, max_as_limit, max_data_limit, max_stack_limit):
    cache_file = os.path.join(CACHE_DIR, f"{hashcode}.pkl")
    if os.path.exists(cache_file):
        if check_gt_only:
            os.remove(cache_file)
        else:
            print(f"Loading ground-truth from {cache_file}")
            with open(cache_file, "rb") as f:
                return pickle.load(f)

    os.makedirs(CACHE_DIR, exist_ok=True)
    print("\nAsserting the groundtruth...")
    tbegin = time.time()

    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        futures = []
        n_samples = 0
        expected_time = dict()

        for problem in problems.values():
            args = (
                problem["complete_prompt"] + "\n" + problem["canonical_solution"],
                problem["test"],
                problem["task_id"],
                max_as_limit,
                max_data_limit,
                max_stack_limit,
            )

            futures.append(executor.submit(trusted_check, *args))
            n_samples += 1

        for future in tqdm(as_completed(futures), total=n_samples):
            result = future.result()
            expected_time[result["task_id"]] = result["time"]

    print(f"Expected outputs computed in {time.time() - tbegin:.2f}s")

    if any(expected_time.values()):
        with open(cache_file, "wb") as f:
            pickle.dump(expected_time, f)

    return expected_time


def check_correctness(
    completion_id: int,
    problem: Dict[str, Any],
    solution: str,
    max_as_limit: float,
    max_data_limit: float,
    max_stack_limit: float,
    identifier=None,
    min_time_limit: float = 0.1,
    gt_time_limit: float = 2.0,
) -> Dict[str, Result]:  # {...}, "base" | "plus" -> (status, details)
    ret = {
        "completion_id": completion_id,
        "task_id": problem["task_id"],
        "_identifier": identifier,
        "solution": solution,
    }
    ret["base"] = untrusted_check(
        solution,
        problem["test"],
        problem["entry_point"],
        max_as_limit,
        max_data_limit,
        max_stack_limit,
        min_time_limit,
        gt_time_limit,
    )
    return ret


def evaluate(flags):
    if flags.parallel is None:
        n_workers = max(1, multiprocessing.cpu_count() // 2)
    else:
        n_workers = flags.parallel

    if flags.check_gt_only:
        # bypass the samples
        flags.samples = "__dummy__.jsonl"

    extra = flags.subset + "_" if flags.subset != "full" else ""
    if os.path.isdir(flags.samples):
        result_path = os.path.join(flags.samples, f"{extra}eval_results.json")
    else:
        assert flags.samples.endswith(".jsonl")
        result_path = flags.samples.replace(".jsonl", f"_{extra}eval_results.json")

    problems = get_bigcodebench(subset=flags.subset)
    dataset_hash = get_bigcodebench_hash(subset=flags.subset)
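
    # Ground-truth timing doubles as a health check: a task whose reference
    # solution does not pass maps to None in expected_time, which feeds
    # gt_pass_rate below (with --no-gt every entry is None).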
    if not flags.no_gt:
        expected_time = get_groundtruth(
            n_workers, problems, dataset_hash, flags.check_gt_only, flags.max_as_limit, flags.max_data_limit, flags.max_stack_limit
        )
    else:
        expected_time = {task_id: None for task_id in problems}

    gt_pass_rate = np.mean([1 if v is not None else 0 for k, v in expected_time.items() if k in problems])

    if os.path.isfile(result_path):
        print(f"Loading previous results from {result_path}")
        with open(result_path, "r") as f:
            results = json.load(f)

        results = compatible_eval_result(results)
    else:
        if flags.check_gt_only:
            if gt_pass_rate > 0.99:
                cprint(f"Groundtruth pass rate: {gt_pass_rate:.3f}", "green")
            else:
                cprint(f"Groundtruth pass rate: {gt_pass_rate:.3f}\nPlease be cautious!", "red")
            return

        results = {
            "date": datetime.now().strftime("%Y-%m-%d %H:%M"),
            "eval": {},
        }

        with ProcessPoolExecutor(max_workers=n_workers) as executor:
            futures = []
            completion_id = Counter()
            n_samples = 0
            eval_results = defaultdict(list)  # task_id -> list of per-sample results
            remainings = set()

            print("Reading samples...")
            for sample in tqdm(load_solutions(flags.samples)):
                task_id = sample["task_id"]

                if task_id not in problems:
                    warn(f"Task {task_id} is found in the samples but not in the dataset")
                    continue
                solution = sample["solution"] if "solution" in sample else problems[task_id]["complete_prompt"] + sample["completion"]
                if "sanitized-calibrated" in flags.samples:
                    solution = problems[task_id]["code_prompt"] + "\n    pass\n" + solution
                remainings.add(sample["_identifier"])
                args = (
                    completion_id[task_id],
                    problems[task_id],
                    solution,
                    flags.max_as_limit,
                    flags.max_data_limit,
                    flags.max_stack_limit,
                    sample["_identifier"],
                    flags.min_time_limit,
                    expected_time[task_id] if expected_time[task_id] else 20,
                )
                futures.append(executor.submit(check_correctness, *args))
                completion_id[task_id] += 1
                n_samples += 1

            assert n_samples == len(remainings), "Missing problems in unfinished"
            assert len(completion_id) == len(problems), "Missing problems in samples"

            def stucking_checker():
                while remainings:
                    last_size = len(remainings)
                    time.sleep(30)
                    if last_size != len(remainings) or len(remainings) == 0:
                        continue
                    # Potentially stuck
                    warn("No samples had finished testing in the last 30s")
                    warn(f"{len(remainings)} samples to be tested: {remainings}")

            threading.Thread(target=stucking_checker).start()

            for future in tqdm(as_completed(futures), total=n_samples):
                result = future.result()
                remainings.remove(result["_identifier"])
                eval_results[result["task_id"]].append(result)

        # sort the results for each problem by completion_id
        for task_id, task_results in eval_results.items():
            task_results.sort(key=lambda x: x["completion_id"])
            results["eval"][task_id] = []
            for res in task_results:
                stat, details = res["base"]
                results["eval"][task_id].append(
                    {
                        "task_id": task_id,
                        "solution": res["solution"],
                        "status": stat,
                        "details": details,
                    }
                )

    # Calculate pass@k.
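    # estimate_pass_at_k (imported from eval) is assumed to implement the
    # unbiased estimator 1 - C(n - c, k) / C(n, k); an entry is reported only
    # for k values with at least k samples per task (total.min() >= k).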
    total = np.array([len(r) for k, r in results["eval"].items() if k in problems])
    base_correct = []

    for key, res in results["eval"].items():
        if key not in problems:
            continue
        bc = sum([r["status"] == PASS for r in res])
        base_correct.append(bc)

    base_correct = np.array(base_correct)

    pass_at_k = {
        f"pass@{k}": 100 * estimate_pass_at_k(total, base_correct, k).mean()
        for k in [1, 5, 10, 25, 100]
        if total.min() >= k
    }

    mode = "-calibrated" if "sanitized-calibrated" in flags.samples else ""
    extra = flags.subset.capitalize()
    flags.split = flags.split.capitalize()
    cprint(f"BigCodeBench-{flags.split}{mode} ({extra})", "green")

    if flags.no_gt:
        cprint("Groundtruth is not checked", "yellow")
    else:
        if gt_pass_rate > 0.99:
            cprint(f"Groundtruth pass rate: {gt_pass_rate:.3f}", "green")
        else:
            cprint(f"Groundtruth pass rate: {gt_pass_rate:.3f}\nPlease be cautious!", "red")

    for k, v in pass_at_k.items():
        cprint(f"{k}:\t{v:.3f}", "green")

    # save results
    with open(result_path, "w") as f:
        json.dump(results, f, indent=2)

    pass_at_k_path = os.path.join(os.path.dirname(result_path), "results.json")
    pass_at_k["model"] = os.path.basename(flags.samples).split("--bigcodebench-")[0]
    pass_at_k["calibrated"] = "sanitized-calibrated" in flags.samples
    pass_at_k["subset"] = flags.subset

    def save_pass_at_k():
        with open(pass_at_k_path, "w") as f:
            json.dump(pass_at_k, f, indent=2)

    save_pass_at_k()


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--split", required=True, type=str, choices=["complete", "instruct"])
    parser.add_argument("--subset", default="full", type=str, choices=["full", "hard"])
    parser.add_argument("--samples", required=True, type=str)
    parser.add_argument("--parallel", default=None, type=int)
    parser.add_argument("--min-time-limit", default=1, type=float)
    parser.add_argument("--max-as-limit", default=128 * 1024, type=int)
    parser.add_argument("--max-data-limit", default=4 * 1024, type=int)
    parser.add_argument("--max-stack-limit", default=5, type=int)
    parser.add_argument("--check-gt-only", action="store_true", help="Check the groundtruth only")
    parser.add_argument("--no-gt", action="store_true", help="Skip the groundtruth check")
    args = parser.parse_args()

    evaluate(args)


if __name__ == "__main__":
    main()
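
# Example invocation (file name is illustrative; the model name and the
# "sanitized-calibrated" marker are parsed from it when present):
#   python evaluate.py --split complete --subset full \
#       --samples results/mymodel--bigcodebench-complete--sanitized-calibrated.jsonl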