evaluate.py
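"""Evaluate generated solutions against BigCodeBench and report pass@k.

Example invocation (samples.jsonl is a placeholder for your own generations):

    python evaluate.py --split complete --subset full --samples samples.jsonl
"""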
import argparse
import json
import multiprocessing
import os
import pickle
import threading
import time
from collections import Counter, defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import datetime
from typing import Any, Dict, List, Tuple
from warnings import warn

import numpy as np
from termcolor import cprint
from tqdm import tqdm

from data import (
    get_bigcodebench,
    get_bigcodebench_hash,
    load_solutions,
)
from data.utils import CACHE_DIR
from eval import (
    PASS,
    compatible_eval_result,
    estimate_pass_at_k,
    untrusted_check,
)
from gen.util import trusted_check

# 1st item: the status
# 2nd item (optional): the detailed pass/fail boolean for each input
Result = Tuple[str, List[bool]]


def get_groundtruth(n_workers, problems, hashcode, check_gt_only, max_as_limit, max_data_limit, max_stack_limit):
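    """Time each problem's canonical solution with trusted_check.

    Per-task expected runtimes are cached under CACHE_DIR, keyed by the
    dataset hash; with check_gt_only, any existing cache is removed so the
    ground truth is re-verified from scratch.
    """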
    cache_file = os.path.join(CACHE_DIR, f"{hashcode}.pkl")
    if os.path.exists(cache_file):
        if check_gt_only:
            os.remove(cache_file)
        else:
            print(f"Loading ground truth from {cache_file}")
            with open(cache_file, "rb") as f:
                return pickle.load(f)

    os.makedirs(CACHE_DIR, exist_ok=True)
    print("\nVerifying the ground truth...")
    tbegin = time.time()

    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        futures = []
        n_samples = 0
        expected_time = dict()

        for problem in problems.values():
            args = (
                problem["complete_prompt"] + "\n" + problem["canonical_solution"],
                problem["test"],
                problem["task_id"],
                max_as_limit,
                max_data_limit,
                max_stack_limit,
            )

            futures.append(executor.submit(trusted_check, *args))
            n_samples += 1

        for future in tqdm(as_completed(futures), total=n_samples):
            result = future.result()
            expected_time[result["task_id"]] = result["time"]

    print(f"Expected outputs computed in {time.time() - tbegin:.2f}s")

    # Cache only when at least one task produced a runtime.
    if any(expected_time.values()):
        with open(cache_file, "wb") as f:
            pickle.dump(expected_time, f)

    return expected_time


def check_correctness(
    completion_id: int,
    problem: Dict[str, Any],
    solution: str,
    max_as_limit: float,
    max_data_limit: float,
    max_stack_limit: float,
    identifier=None,
    min_time_limit: float = 0.1,
    gt_time_limit: float = 2.0,
) -> Dict[str, Result]:  # {...}, "base" -> (status, details)
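    """Run a single solution through untrusted_check under the given limits.

    Returns the completion metadata together with a "base" entry holding
    (status, per-input pass/fail details).
    """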
    ret = {
        "completion_id": completion_id,
        "task_id": problem["task_id"],
        "_identifier": identifier,
        "solution": solution,
    }
    ret["base"] = untrusted_check(
        solution,
        problem["test"],
        problem["entry_point"],
        max_as_limit,
        max_data_limit,
        max_stack_limit,
        min_time_limit,
        gt_time_limit,
    )
    return ret


def evaluate(flags):
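    """Evaluate the given samples end to end.

    Verifies the ground truth (unless --no-gt), runs every sample through
    check_correctness in a process pool, and reports pass@k plus the
    ground-truth pass rate.
    """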
    if flags.parallel is None:
        n_workers = max(1, multiprocessing.cpu_count() // 2)
    else:
        n_workers = flags.parallel

    if flags.check_gt_only:
        # bypass the samples
        flags.samples = "__dummy__.jsonl"

    extra = flags.subset + "_" if flags.subset != "full" else ""
    if os.path.isdir(flags.samples):
        result_path = os.path.join(flags.samples, f"{extra}eval_results.json")
    else:
        assert flags.samples.endswith(".jsonl")
        result_path = flags.samples.replace(".jsonl", f"_{extra}eval_results.json")

    problems = get_bigcodebench(subset=flags.subset)
    dataset_hash = get_bigcodebench_hash(subset=flags.subset)

    if not flags.no_gt:
        expected_time = get_groundtruth(
            n_workers, problems, dataset_hash, flags.check_gt_only, flags.max_as_limit, flags.max_data_limit, flags.max_stack_limit
        )
    else:
        expected_time = {task_id: None for task_id in problems}

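    # expected_time maps each task to its recorded ground-truth runtime; a None
    # entry means no runtime was recorded (all entries are None under --no-gt),
    # so the fraction of non-None entries below is the ground-truth pass rate.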
    gt_pass_rate = np.mean([1 if v is not None else 0 for k, v in expected_time.items() if k in problems])

    if os.path.isfile(result_path):
        print(f"Loading previous results from {result_path}")
        with open(result_path, "r") as f:
            results = json.load(f)

        results = compatible_eval_result(results)
    else:
        if flags.check_gt_only:
            if gt_pass_rate > 0.99:
                cprint(f"Groundtruth pass rate: {gt_pass_rate:.3f}", "green")
            else:
                cprint(f"Groundtruth pass rate: {gt_pass_rate:.3f}\nPlease be cautious!", "red")
            return

        results = {
            "date": datetime.now().strftime("%Y-%m-%d %H:%M"),
            "eval": {},
        }

        with ProcessPoolExecutor(max_workers=n_workers) as executor:
            futures = []
            completion_id = Counter()
            n_samples = 0
            eval_results = defaultdict(list)  # task_id -> list of check_correctness results
            remainings = set()

            print("Reading samples...")
            for sample in tqdm(load_solutions(flags.samples)):
                task_id = sample["task_id"]

                if task_id not in problems:
                    warn(f"Task {task_id} is found in the samples but not in the dataset")
                    continue
                solution = (
                    sample["solution"]
                    if "solution" in sample
                    else problems[task_id]["complete_prompt"] + sample["completion"]
                )
                if "sanitized-calibrated" in flags.samples:
                    # Calibrated samples are evaluated with the code prompt (plus a stub body) prepended.
                    solution = problems[task_id]["code_prompt"] + "\n    pass\n" + solution
                remainings.add(sample["_identifier"])
                args = (
                    completion_id[task_id],
                    problems[task_id],
                    solution,
                    flags.max_as_limit,
                    flags.max_data_limit,
                    flags.max_stack_limit,
                    sample["_identifier"],
                    flags.min_time_limit,
                    expected_time[task_id] if expected_time[task_id] else 20,  # fall back to 20s without a ground-truth timing
                )
                futures.append(executor.submit(check_correctness, *args))
                completion_id[task_id] += 1
                n_samples += 1

            assert n_samples == len(remainings), "Duplicate sample identifiers found"
            assert len(completion_id) == len(problems), "Missing problems in samples"

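            # Watchdog thread: warn if no sample finishes within a 30-second
            # window; the loop exits once `remainings` is empty.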
            def stucking_checker():
                while remainings:
                    last_size = len(remainings)
                    time.sleep(30)
                    if last_size != len(remainings) or len(remainings) == 0:
                        continue
                    # Potentially stuck
                    warn("No samples have finished testing in the last 30s")
                    warn(f"{len(remainings)} samples to be tested: {remainings}")

            threading.Thread(target=stucking_checker).start()

            for future in tqdm(as_completed(futures), total=n_samples):
                result = future.result()
                remainings.remove(result["_identifier"])
                eval_results[result["task_id"]].append(result)

        # sort the results for each problem by completion_id
        for task_id, task_results in eval_results.items():
            task_results.sort(key=lambda x: x["completion_id"])
            results["eval"][task_id] = []
            for res in task_results:
                stat, details = res["base"]
                results["eval"][task_id].append(
                    {
                        "task_id": task_id,
                        "solution": res["solution"],
                        "status": stat,
                        "details": details,
                    }
                )

    # Calculate pass@k.
    total = np.array([len(r) for k, r in results["eval"].items() if k in problems])
    base_correct = []

    for key, res in results["eval"].items():
        if key not in problems:
            continue
        bc = sum([r["status"] == PASS for r in res])
        base_correct.append(bc)

    base_correct = np.array(base_correct)

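    # Pass@k is only reported for k values where every task has at least k samples;
    # estimate_pass_at_k is expected to implement the unbiased estimator
    # pass@k = E[1 - C(n - c, k) / C(n, k)] from Chen et al. (2021).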
    pass_at_k = {
        f"pass@{k}": 100 * estimate_pass_at_k(total, base_correct, k).mean()
        for k in [1, 5, 10, 25, 100]
        if total.min() >= k
    }

    mode = "-calibrated" if "sanitized-calibrated" in flags.samples else ""
    subset_label = flags.subset.capitalize()
    flags.split = flags.split.capitalize()
    cprint(f"BigCodeBench-{flags.split}{mode} ({subset_label})", "green")

    if flags.no_gt:
        cprint("Groundtruth was not checked", "yellow")
    else:
        if gt_pass_rate > 0.99:
            cprint(f"Groundtruth pass rate: {gt_pass_rate:.3f}", "green")
        else:
            cprint(f"Groundtruth pass rate: {gt_pass_rate:.3f}\nPlease be cautious!", "red")

    for k, v in pass_at_k.items():
        cprint(f"{k}:\t{v:.3f}", "green")

    # save results
    with open(result_path, "w") as f:
        json.dump(results, f, indent=2)

    pass_at_k_path = os.path.join(os.path.dirname(result_path), "results.json")
    pass_at_k["model"] = os.path.basename(flags.samples).split("--bigcodebench-")[0]
    pass_at_k["calibrated"] = "sanitized-calibrated" in flags.samples
    pass_at_k["subset"] = flags.subset

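    # results.json thus holds the pass@k numbers together with the model name,
    # calibration flag, and subset, e.g.
    # {"pass@1": ..., "model": ..., "calibrated": ..., "subset": ...}.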
    def save_pass_at_k():
        with open(pass_at_k_path, "w") as f:
            json.dump(pass_at_k, f, indent=2)

    save_pass_at_k()


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--split", required=True, type=str, choices=["complete", "instruct"])
    parser.add_argument("--subset", default="full", type=str, choices=["full", "hard"])
    parser.add_argument("--samples", required=True, type=str)
    parser.add_argument("--parallel", default=None, type=int)
    parser.add_argument("--min-time-limit", default=1, type=float)
    parser.add_argument("--max-as-limit", default=128 * 1024, type=int)
    parser.add_argument("--max-data-limit", default=4 * 1024, type=int)
    parser.add_argument("--max-stack-limit", default=5, type=int)
    parser.add_argument("--check-gt-only", action="store_true", help="Only check the groundtruth, then exit")
    parser.add_argument("--no-gt", action="store_true", help="Skip the groundtruth check")
    args = parser.parse_args()

    evaluate(args)


if __name__ == "__main__":
    main()