evaluate.py
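"""Evaluate LLM-generated code samples against the EvalPlus HumanEval+ / MBPP+ test suites.

Each sample is run in a process pool against the dataset's base inputs and the extra
"plus" inputs; expected outputs are (re)computed from the canonical solutions and cached
under CACHE_DIR, and pass@k is reported for the base and base+extra test sets. Results
are written to eval_results.json (or <samples>_eval_results.json for a .jsonl file).

Example invocation (paths are illustrative):

    python evaluate.py --dataset humaneval --samples /path/to/samples.jsonl
"""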
import argparse
import json
import multiprocessing
import os
import pickle
import threading
import time
from collections import Counter, defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import datetime
from typing import Any, Dict, List, Tuple
from warnings import warn

import numpy as np
from termcolor import cprint
from tqdm import tqdm

from evalplus.data import (
    get_human_eval_plus,
    get_human_eval_plus_hash,
    get_mbpp_plus,
    get_mbpp_plus_hash,
    load_solutions,
)
from evalplus.data.mbpp import mbpp_serialize_inputs

# from evalplus.data.utils import CACHE_DIR
CACHE_DIR = "/home/data/xixi.yjx/eval_cache/evalplus"
from evalplus.eval import (
    PASS,
    compatible_eval_result,
    estimate_pass_at_k,
    untrusted_check,
)
from evalplus.eval._special_oracle import MBPP_OUTPUT_NOT_NONE_TASKS
from evalplus.gen.util import trusted_exec

# 1st item: the status
# 2nd item (optional): the detailed pass/fail boolean for each input
Result = Tuple[str, List[bool]]


def get_groundtruth(problems, hashcode, tasks_only_output_not_none):
    cache_file = os.path.join(CACHE_DIR, f"{hashcode}.pkl")
    if os.path.exists(cache_file):
        print(f"Loading cached ground-truth from {cache_file}")
        with open(cache_file, "rb") as f:
            return pickle.load(f)

    os.makedirs(CACHE_DIR, exist_ok=True)
    print("Computing expected output...")
    tbegin = time.time()
    expected_output = {}
    for task_id, problem in problems.items():
        oracle = {}
        oracle["base"], oracle["base_time"] = trusted_exec(
            problem["prompt"] + problem["canonical_solution"],
            problem["base_input"],
            problem["entry_point"],
            record_time=True,
            output_not_none=problem["entry_point"] in tasks_only_output_not_none,
        )

        oracle["plus"], oracle["plus_time"] = trusted_exec(
            problem["prompt"] + problem["canonical_solution"],
            problem["plus_input"],
            problem["entry_point"],
            record_time=True,
            output_not_none=problem["entry_point"] in tasks_only_output_not_none,
        )
        expected_output[task_id] = oracle
    print(f"Expected outputs computed in {time.time() - tbegin:.2f}s")

    with open(cache_file, "wb") as f:
        pickle.dump(expected_output, f)

    return expected_output


def check_correctness(
    dataset: str,
    completion_id: int,
    problem: Dict[str, Any],
    solution: str,
    expected_output: Dict[str, List],
    base_only=False,
    fast_check=False,
    identifier=None,
    min_time_limit: float = 0.1,
    gt_time_limit_factor: float = 2.0,
) -> Dict[str, Result]:  # {..., "base" | "plus" -> (status, details)}
    ret = {
        "completion_id": completion_id,
        "task_id": problem["task_id"],
        "_identifier": identifier,
        "solution": solution,
    }
    ret["base"] = untrusted_check(
        dataset,
        solution,
        problem["base_input"],
        problem["entry_point"],
        expected=expected_output["base"],
        atol=problem["atol"],
        ref_time=expected_output["base_time"],
        fast_check=fast_check,
        min_time_limit=min_time_limit,
        gt_time_limit_factor=gt_time_limit_factor,
    )

    if not base_only:
        ret["plus"] = untrusted_check(
            dataset,
            solution,
            problem["plus_input"],
            problem["entry_point"],
            expected=expected_output["plus"],
            atol=problem["atol"],
            ref_time=expected_output["plus_time"],
            fast_check=fast_check,
            min_time_limit=min_time_limit,
            gt_time_limit_factor=gt_time_limit_factor,
        )

    return ret

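# For orientation, an illustrative (made-up) shape of one result returned by
# check_correctness for a HumanEval sample that passes both test sets:
#
#     {
#         "completion_id": 0,
#         "task_id": "HumanEval/0",
#         "_identifier": "samples.jsonl::0",
#         "solution": "def has_close_elements(...): ...",
#         "base": (PASS, [True, True, ...]),   # Result: (status, per-input details)
#         "plus": (PASS, [True, True, ...]),   # absent when base_only=True
#     }
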
def evaluate(flags):
    if flags.parallel is None:
        n_workers = max(1, multiprocessing.cpu_count() // 2)
    else:
        n_workers = flags.parallel

    if os.path.isdir(flags.samples):
        result_path = os.path.join(flags.samples, "eval_results.json")
    else:
        assert flags.samples.endswith(".jsonl")
        result_path = flags.samples.replace(".jsonl", "_eval_results.json")

    if os.path.isfile(result_path) and not flags.i_just_wanna_run:
        print(f"Loading previous results from {result_path}")
        with open(result_path, "r") as f:
            results = json.load(f)

        results = compatible_eval_result(results)
    else:
        if flags.dataset == "humaneval":
            problems = get_human_eval_plus(mini=flags.mini, noextreme=flags.noextreme, version=flags.version)
            dataset_hash = get_human_eval_plus_hash(mini=flags.mini, noextreme=flags.noextreme, version=flags.version)
            expected_output = get_groundtruth(problems, dataset_hash, [])
        elif flags.dataset == "mbpp":
            problems = get_mbpp_plus(mini=flags.mini, noextreme=flags.noextreme, version=flags.version)
            dataset_hash = get_mbpp_plus_hash(mini=flags.mini, noextreme=flags.noextreme, version=flags.version)
            expected_output = get_groundtruth(
                problems,
                dataset_hash,
                MBPP_OUTPUT_NOT_NONE_TASKS,
            )

        results = {
            "date": datetime.now().strftime("%Y-%m-%d %H:%M"),
            "hash": dataset_hash,
            "eval": {},
        }

        with ProcessPoolExecutor(max_workers=n_workers) as executor:
            futures = []
            completion_id = Counter()
            n_samples = 0
            eval_results = defaultdict(list)  # task_id -> list of check_correctness results
            remainings = set()

            print("Reading samples...")
            for sample in tqdm(load_solutions(flags.samples)):
                task_id = sample["task_id"]
                if task_id not in problems:
                    warn(f"Task {task_id} is found in the samples but not in the dataset")
                    continue
                solution = sample["solution"] if "solution" in sample else problems[task_id]["prompt"] + sample["completion"]
                remainings.add(sample["_identifier"])
                args = (
                    flags.dataset,
                    completion_id[task_id],
                    problems[task_id],
                    solution,
                    expected_output[task_id],
                    flags.base_only,
                    not flags.test_details,  # fast_check
                    sample["_identifier"],
                    flags.min_time_limit,
                    flags.gt_time_limit_factor,
                )
                futures.append(executor.submit(check_correctness, *args))
                completion_id[task_id] += 1
                n_samples += 1

            assert n_samples == len(remainings), "Missing problems in unfinished"
            assert len(completion_id) == len(problems), "Missing problems in samples"

            def stucking_checker():
                while remainings:
                    last_size = len(remainings)
                    time.sleep(20)
                    if last_size != len(remainings) or len(remainings) == 0:
                        continue
                    # Potentially stuck
                    warn("No samples have finished testing in the last 20s")
                    warn(f"{len(remainings)} samples to be tested: {remainings}")

            threading.Thread(target=stucking_checker).start()

            for future in tqdm(as_completed(futures), total=n_samples):
                result = future.result()
                remainings.remove(result["_identifier"])
                eval_results[result["task_id"]].append(result)

        # sort the results for each problem by completion_id
        for task_id, task_results in eval_results.items():
            task_results.sort(key=lambda x: x["completion_id"])
            results["eval"][task_id] = []
            for res in task_results:

                def get_failed_tests(stat, details, inputs) -> List[Any]:
                    if stat == PASS or not details:
                        return []

                    if flags.test_details:
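                        # --test-details: collect every failing input; details[i] is the
                        # per-input pass flag reported by untrusted_check.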
                        return [inputs[i] for i in range(len(details)) if not details[i]]

                    # else => simply return the last failing test
                    return [inputs[len(details) - 1]]

                base_stat, base_details = res["base"]
                base_fail_tests = get_failed_tests(base_stat, base_details, problems[task_id]["base_input"])

                # initialize plus tests
                plus_stat = None
                plus_fail_tests = []

                # with plus tests
                if not flags.base_only:
                    plus_stat, plus_details = res["plus"]
                    plus_fail_tests = get_failed_tests(plus_stat, plus_details, problems[task_id]["plus_input"])

                if flags.dataset == "mbpp":
                    base_fail_tests = mbpp_serialize_inputs(task_id, base_fail_tests)
                    plus_fail_tests = mbpp_serialize_inputs(task_id, plus_fail_tests)

                results["eval"][task_id].append(
                    {
                        "task_id": task_id,
                        "solution": res["solution"],
                        "base_status": base_stat,
                        "plus_status": plus_stat,
                        "base_fail_tests": base_fail_tests,
                        "plus_fail_tests": plus_fail_tests,
                    }
                )

    # Calculate pass@k.
    total = np.array([len(r) for r in results["eval"].values()])
    base_correct = []
    new_correct = []

    for res in results["eval"].values():
        bc = sum([r["base_status"] == PASS for r in res])
        base_correct.append(bc)
        if not flags.base_only:
            new_correct.append(sum([res[i]["base_status"] == res[i]["plus_status"] == PASS for i in range(len(res))]))
    base_correct = np.array(base_correct)

    pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, base_correct, k).mean() for k in [1, 10, 100] if total.min() >= k}
    cprint(f"{flags.dataset} (base tests)", "red")
    for k, v in pass_at_k.items():
        cprint(f"{k}:\t{v:.3f}", "red")

    if new_correct:
        cprint(f"{flags.dataset}+ (base + extra tests)", "green")
        pass_at_k = {f"pass@{k}": estimate_pass_at_k(total, np.array(new_correct), k).mean() for k in [1, 10, 100] if (total >= k).all()}
        for k, v in pass_at_k.items():
            cprint(f"{k}:\t{v:.3f}", "green")

    # save results
    if os.path.isfile(result_path) and flags.i_just_wanna_run:
        decision = ""
        while decision.lower() not in ["y", "n"]:
            print(f"{result_path} already exists. Press [Y/N] to overwrite or exit...")
            decision = input()

        if decision.lower() == "y":
            # mv the file to a backup
            new_path = result_path + ".bak"
            while os.path.isfile(new_path):
                new_path += ".bak"
            os.rename(result_path, new_path)
            print(f"Backed up {result_path} to {new_path}")

    if not os.path.isfile(result_path):
        with open(result_path, "w") as f:
            json.dump(results, f)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", required=True, type=str, choices=["humaneval", "mbpp"])
    parser.add_argument("--samples", required=True, type=str)
    parser.add_argument("--base-only", action="store_true")
    parser.add_argument("--parallel", default=None, type=int)
    parser.add_argument("--i-just-wanna-run", action="store_true")
    parser.add_argument("--test-details", action="store_true")
    parser.add_argument("--min-time-limit", default=1, type=float)
    parser.add_argument("--gt-time-limit-factor", default=4.0, type=float)
    parser.add_argument("--mini", action="store_true")
    parser.add_argument("--noextreme", action="store_true", help="Omit extreme test inputs")
    parser.add_argument("--version", default="default", type=str, help="Version of the dataset")
    args = parser.parse_args()

    evaluate(args)


if __name__ == "__main__":
    main()
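
# Rough shape of the saved eval_results.json (values illustrative; statuses are the
# PASS/fail strings defined in evalplus.eval):
#
#     {
#         "date": "2024-01-01 12:00",
#         "hash": "<dataset hash>",
#         "eval": {
#             "HumanEval/0": [
#                 {
#                     "task_id": "HumanEval/0",
#                     "solution": "...",
#                     "base_status": "pass",
#                     "plus_status": "pass",
#                     "base_fail_tests": [],
#                     "plus_fail_tests": [],
#                 }
#             ]
#         },
#     }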