evaluate.py
import argparse
import json
import multiprocessing
import os
import pickle
import threading
import time
from collections import Counter, defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import datetime
from typing import Any, Dict, List, Tuple
from warnings import warn

import numpy as np
from termcolor import cprint
from tqdm import tqdm

from evalplus.data import (
    get_human_eval_plus,
    get_human_eval_plus_hash,
    get_mbpp_plus,
    get_mbpp_plus_hash,
    load_solutions,
)
from evalplus.data.mbpp import mbpp_serialize_inputs

# Override the packaged cache location (normally imported via
# `from evalplus.data.utils import CACHE_DIR`) with a machine-specific
# directory for the ground-truth pickle cache.
CACHE_DIR = "/home/data/xixi.yjx/eval_cache/evalplus"
from evalplus.eval import (
    PASS,
    compatible_eval_result,
    estimate_pass_at_k,
    untrusted_check,
)
from evalplus.eval._special_oracle import MBPP_OUTPUT_NOT_NONE_TASKS
from evalplus.gen.util import trusted_exec

# 1st item: the status
# 2nd item (optional): the detailed pass/fail boolean for each input
Result = Tuple[str, List[bool]]
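# Illustrative values (statuses come from evalplus.eval): (PASS, [True, True])
# for a solution passing both inputs, or a failing status with per-input
# details such as (status, [True, False]) when the second input fails.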


def get_groundtruth(problems, hashcode, tasks_only_output_not_none):
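    """Compute, or load from cache, the expected outputs for every problem.

    Each problem's canonical solution is executed on its base and plus inputs
    via trusted_exec, recording outputs and reference runtimes. The result is
    cached as a pickle keyed by the dataset hash, so later runs skip
    re-execution.
    """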
    cache_file = os.path.join(CACHE_DIR, f"{hashcode}.pkl")
    if os.path.exists(cache_file):
        print(f"Loading cached ground truth from {cache_file}")
        with open(cache_file, "rb") as f:
            return pickle.load(f)

    os.makedirs(CACHE_DIR, exist_ok=True)
    print("Computing expected output...")
    tbegin = time.time()
    expected_output = {}
    for task_id, problem in problems.items():
        oracle = {}
        oracle["base"], oracle["base_time"] = trusted_exec(
            problem["prompt"] + problem["canonical_solution"],
            problem["base_input"],
            problem["entry_point"],
            record_time=True,
            output_not_none=problem["entry_point"] in tasks_only_output_not_none,
        )

        oracle["plus"], oracle["plus_time"] = trusted_exec(
            problem["prompt"] + problem["canonical_solution"],
            problem["plus_input"],
            problem["entry_point"],
            record_time=True,
            output_not_none=problem["entry_point"] in tasks_only_output_not_none,
        )
        expected_output[task_id] = oracle
    print(f"Expected outputs computed in {time.time() - tbegin:.2f}s")

    with open(cache_file, "wb") as f:
        pickle.dump(expected_output, f)

    return expected_output


def check_correctness(
    dataset: str,
    completion_id: int,
    problem: Dict[str, Any],
    solution: str,
    expected_output: Dict[str, List],
    base_only=False,
    fast_check=False,
    identifier=None,
    min_time_limit: float = 0.1,
    gt_time_limit_factor: float = 2.0,
) -> Dict[str, Result]:  # "base" (and "plus" unless base_only) -> (status, details)
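    """Check one solution against one problem's base and (optionally) plus tests.

    Returns a dict carrying the completion metadata plus a "base" entry and,
    unless base_only is set, a "plus" entry; each entry is a (status, details)
    Result produced by untrusted_check.
    """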
    ret = {
        "completion_id": completion_id,
        "task_id": problem["task_id"],
        "_identifier": identifier,
        "solution": solution,
    }
    ret["base"] = untrusted_check(
        dataset,
        solution,
        problem["base_input"],
        problem["entry_point"],
        expected=expected_output["base"],
        atol=problem["atol"],
        ref_time=expected_output["base_time"],
        fast_check=fast_check,
        min_time_limit=min_time_limit,
        gt_time_limit_factor=gt_time_limit_factor,
    )

    if not base_only:
        ret["plus"] = untrusted_check(
            dataset,
            solution,
            problem["plus_input"],
            problem["entry_point"],
            expected=expected_output["plus"],
            atol=problem["atol"],
            ref_time=expected_output["plus_time"],
            fast_check=fast_check,
            min_time_limit=min_time_limit,
            gt_time_limit_factor=gt_time_limit_factor,
        )

    return ret


def evaluate(flags):
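    """Evaluate all loaded samples against the chosen dataset and report pass@k.

    Previous results are reused if an eval_results JSON already exists (unless
    --i-just-wanna-run is set); otherwise every sample is checked in a process
    pool, per-task results are aggregated, and pass@k is printed for the base
    tests and for base + extra tests.
    """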
    if flags.parallel is None:
        n_workers = max(1, multiprocessing.cpu_count() // 2)
    else:
        n_workers = flags.parallel

    if os.path.isdir(flags.samples):
        result_path = os.path.join(flags.samples, "eval_results.json")
    else:
        assert flags.samples.endswith(".jsonl")
        result_path = flags.samples.replace(".jsonl", "_eval_results.json")

    if os.path.isfile(result_path) and not flags.i_just_wanna_run:
        print(f"Loading previous results from {result_path}")
        with open(result_path, "r") as f:
            results = json.load(f)

        results = compatible_eval_result(results)
    else:
        if flags.dataset == "humaneval":
            problems = get_human_eval_plus(
                mini=flags.mini, noextreme=flags.noextreme, version=flags.version
            )
            dataset_hash = get_human_eval_plus_hash(
                mini=flags.mini, noextreme=flags.noextreme, version=flags.version
            )
            expected_output = get_groundtruth(problems, dataset_hash, [])
        elif flags.dataset == "mbpp":
            problems = get_mbpp_plus(
                mini=flags.mini, noextreme=flags.noextreme, version=flags.version
            )
            dataset_hash = get_mbpp_plus_hash(
                mini=flags.mini, noextreme=flags.noextreme, version=flags.version
            )
            expected_output = get_groundtruth(
                problems,
                dataset_hash,
                MBPP_OUTPUT_NOT_NONE_TASKS,
            )

        results = {
            "date": datetime.now().strftime("%Y-%m-%d %H:%M"),
            "hash": dataset_hash,
            "eval": {},
        }

        with ProcessPoolExecutor(max_workers=n_workers) as executor:
            futures = []
            completion_id = Counter()
            n_samples = 0
            eval_results = defaultdict(list)  # task_id -> list of check results
            remainings = set()

            print("Reading samples...")
            for sample in tqdm(load_solutions(flags.samples)):
                task_id = sample["task_id"]
                if task_id not in problems:
                    warn(f"Task {task_id} is found in the samples but not in the dataset")
                    continue
                solution = (
                    sample["solution"]
                    if "solution" in sample
                    else problems[task_id]["prompt"] + sample["completion"]
                )
                remainings.add(sample["_identifier"])
                args = (
                    flags.dataset,
                    completion_id[task_id],
                    problems[task_id],
                    solution,
                    expected_output[task_id],
                    flags.base_only,
                    not flags.test_details,  # fast_check
                    sample["_identifier"],
                    flags.min_time_limit,
                    flags.gt_time_limit_factor,
                )
                futures.append(executor.submit(check_correctness, *args))
                completion_id[task_id] += 1
                n_samples += 1

            assert n_samples == len(remainings), "Sample identifiers must be unique"
            assert len(completion_id) == len(problems), "Missing problems in samples"

            def stall_checker():
                # Warn if no sample finishes within 20s, which usually
                # signals a stuck worker.
                while remainings:
                    last_size = len(remainings)
                    time.sleep(20)
                    if last_size != len(remainings) or len(remainings) == 0:
                        continue
                    warn("No samples have finished testing in the last 20s")
                    warn(f"{len(remainings)} samples to be tested: {remainings}")

            # daemon=True so a lingering checker cannot block interpreter exit
            threading.Thread(target=stall_checker, daemon=True).start()

            for future in tqdm(as_completed(futures), total=n_samples):
                result = future.result()
                remainings.remove(result["_identifier"])
                eval_results[result["task_id"]].append(result)

        # sort the results for each problem by completion_id
        for task_id, task_results in eval_results.items():
            task_results.sort(key=lambda x: x["completion_id"])
            results["eval"][task_id] = []
            for res in task_results:

                def get_failed_tests(stat, details, inputs) -> List[Any]:
                    if stat == PASS or not details:
                        return []

                    if flags.test_details:
                        return [inputs[i] for i in range(len(details)) if not details[i]]

                    # otherwise fast_check stopped at the first failure, so the
                    # last recorded input is the (single) failing test
                    return [inputs[len(details) - 1]]

                base_stat, base_details = res["base"]
                base_fail_tests = get_failed_tests(
                    base_stat, base_details, problems[task_id]["base_input"]
                )

                # initialize plus tests
                plus_stat = None
                plus_fail_tests = []

                # with plus tests
                if not flags.base_only:
                    plus_stat, plus_details = res["plus"]
                    plus_fail_tests = get_failed_tests(
                        plus_stat, plus_details, problems[task_id]["plus_input"]
                    )

                if flags.dataset == "mbpp":
                    base_fail_tests = mbpp_serialize_inputs(task_id, base_fail_tests)
                    plus_fail_tests = mbpp_serialize_inputs(task_id, plus_fail_tests)

                results["eval"][task_id].append(
                    {
                        "task_id": task_id,
                        "solution": res["solution"],
                        "base_status": base_stat,
                        "plus_status": plus_stat,
                        "base_fail_tests": base_fail_tests,
                        "plus_fail_tests": plus_fail_tests,
                    }
                )

    # Calculate pass@k.
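    # For reference (Chen et al., 2021): with n samples per task of which c
    # pass, the unbiased estimator used by estimate_pass_at_k is
    #   pass@k = E_tasks[1 - C(n - c, k) / C(n, k)]
    # i.e., one minus the probability that a uniformly chosen size-k subset
    # of the n samples contains no passing sample.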
    total = np.array([len(r) for r in results["eval"].values()])
    base_correct = []
    new_correct = []

    for res in results["eval"].values():
        bc = sum(r["base_status"] == PASS for r in res)
        base_correct.append(bc)
        if not flags.base_only:
            new_correct.append(
                sum(r["base_status"] == r["plus_status"] == PASS for r in res)
            )
    base_correct = np.array(base_correct)

    pass_at_k = {
        f"pass@{k}": estimate_pass_at_k(total, base_correct, k).mean()
        for k in [1, 10, 100]
        if total.min() >= k
    }
    cprint(f"{flags.dataset} (base tests)", "red")
    for k, v in pass_at_k.items():
        cprint(f"{k}:\t{v:.3f}", "red")

    if new_correct:
        cprint(f"{flags.dataset}+ (base + extra tests)", "green")
        pass_at_k = {
            f"pass@{k}": estimate_pass_at_k(total, np.array(new_correct), k).mean()
            for k in [1, 10, 100]
            if total.min() >= k
        }
        for k, v in pass_at_k.items():
            cprint(f"{k}:\t{v:.3f}", "green")

    # save results
    if os.path.isfile(result_path) and flags.i_just_wanna_run:
        decision = ""
        while decision.lower() not in ["y", "n"]:
            print(f"{result_path} already exists. Enter Y to back it up and overwrite, or N to keep it...")
            decision = input()

        if decision.lower() == "y":
            # move the existing file to a .bak backup before overwriting
            new_path = result_path + ".bak"
            while os.path.isfile(new_path):
                new_path += ".bak"
            os.rename(result_path, new_path)
            print(f"Backed up {result_path} to {new_path}")

    if not os.path.isfile(result_path):
        with open(result_path, "w") as f:
            json.dump(results, f)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset", required=True, type=str, choices=["humaneval", "mbpp"]
    )
    parser.add_argument("--samples", required=True, type=str)
    parser.add_argument("--base-only", action="store_true")
    parser.add_argument("--parallel", default=None, type=int)
    parser.add_argument("--i-just-wanna-run", action="store_true")
    parser.add_argument("--test-details", action="store_true")
    parser.add_argument("--min-time-limit", default=1.0, type=float)
    parser.add_argument("--gt-time-limit-factor", default=4.0, type=float)
    parser.add_argument("--mini", action="store_true")
    parser.add_argument("--noextreme", action="store_true", help="Omit extreme test inputs")
    parser.add_argument("--version", default="default", type=str, help="Version of the dataset")
    args = parser.parse_args()

    evaluate(args)


if __name__ == "__main__":
    main()
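
# Illustrative invocations (sample paths are placeholders):
#   python evaluate.py --dataset humaneval --samples samples.jsonl
#   python evaluate.py --dataset mbpp --samples ./samples_dir --parallel 8 --test-details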