evaluation.py
import os
import sys
import fire
import json
import gzip
import regex
import numpy as np
import itertools

from typing import *
from tqdm.auto import tqdm
from collections import defaultdict, Counter
from concurrent.futures import ThreadPoolExecutor, as_completed
from .data import stream_jsonl
from .execution import check_correctness

IMPORT_HELPER = {
    "python": [
        "import math",
        "import re",
        "import sys",
        "import copy",
        "import datetime",
        "import itertools",
        "import collections",
        "import heapq",
        "import functools",
        "import hashlib",
        "import numpy",
        "import numpy as np",
        "import string",
        "from typing import *",
        "from collections import *",
    ],
    "go": [
        "math",
        "strings",
        "fmt",
        "strconv",
        "time",
        "bytes",
        "regexp",
        "sort",
        "math/rand",
        "crypto/md5",
    ],
    "cpp": [
        "#include<stdlib.h>",
        "#include<algorithm>",
        "#include<math.h>",
        "#include<stdio.h>",
        "#include<vector>",
        "#include<string>",
        "#include<climits>",
        "#include<cstring>",
        "#include<iostream>",
        "#include<cassert>",
    ],
    "cs": [
        "using System.Numerics;",
        "using System.Diagnostics;",
        "using System.Collections.Generic;",
        "using System.Linq;",
        "using System.Text;",
        "using System.Security.Cryptography;",
    ],
}


LANGUAGE_NAME = {
    "cpp": "CPP",
    "go": "Go",
    "java": "Java",
    "js": "JavaScript",
    "python": "Python",
}


def read_dataset(
    data_file: str = None,
    dataset_type: str = "humaneval",
    num_shot=None,
) -> Dict:
    """
    Reads a dataset and returns a dictionary mapping task IDs to tasks.
    """
    if num_shot is not None:
        print(f"{num_shot}-shot setting...")
    if "humaneval" in dataset_type.lower():
        if data_file is None:
            current_path = os.path.dirname(os.path.abspath(__file__))
            data_file = os.path.join(current_path, "..", "humaneval-x", "python", "data", "humaneval_python.jsonl.gz")
        dataset = {task["task_id"]: task for task in stream_jsonl(data_file)}
    else:
        raise ValueError(f"Dataset: {dataset_type} not supported.")

    return dataset
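
# Minimal usage sketch for read_dataset (falling back to the bundled
# HumanEval-X Python split; pass data_file explicitly for anything else):
#
#   problems = read_dataset(dataset_type="humaneval")
#   print(len(problems), next(iter(problems)))   # e.g. number of tasks, first task_id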


def estimate_pass_at_k(
    num_samples: Union[int, List[int], np.ndarray],
    num_correct: Union[List[int], np.ndarray],
    k: int,
) -> np.ndarray:
    """
    Estimates pass@k of each problem and returns them in an array.
    """

    def estimator(n: int, c: int, k: int) -> float:
        """
        Calculates 1 - comb(n - c, k) / comb(n, k).
        """
        if n - c < k:
            return 1.0
        return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

    if isinstance(num_samples, int):
        num_samples_it = itertools.repeat(num_samples, len(num_correct))
    else:
        assert len(num_samples) == len(num_correct)
        num_samples_it = iter(num_samples)

    return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])
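
# Worked example of the unbiased estimator above: with n = 5 samples per task
# and c correct samples, pass@1 reduces to c / n, e.g.
#
#   estimate_pass_at_k(5, [2, 0], k=1)   # -> array([0.4, 0.0])
#
# and whenever n - c < k the estimate is exactly 1.0, since every size-k
# subset of the samples must then contain at least one correct one.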


def process_humaneval_test(sample, problems, example_test=False, is_mbpp=False, language="python"):
    """
    Builds the full test program (solution plus test suite) for a sample.
    """
    task_id = sample["task_id"]
    if is_mbpp:
        return sample["generation"] + "\n" + "\n".join(problems[task_id]["test"])

    prompt = sample["prompt"]
    if example_test and "example_test" in problems[task_id] and problems[task_id]["example_test"] != "":
        test = problems[task_id]["example_test"]
    else:
        test = problems[task_id]["test"]
    code = sample["generation"]

    # Pre-process for different languages
    if language == "python":
        test_setup = "\n".join(IMPORT_HELPER["python"]) + "\n"
        test_string = test_setup + code + "\n" + test + "\n"
    elif language == "cpp":
        # Only inject headers the prompt does not already include.
        test_set_up = ""
        for s in IMPORT_HELPER["cpp"]:
            if s not in prompt:
                test_set_up += s + "\n"
        test_string = test_set_up + "\n" + code + "\n" + test
    elif language == "java":
        test_string = code + "\n" + test
    elif language == "cs":
        test_set_up = ""
        for s in IMPORT_HELPER["cs"]:
            test_set_up += s + "\n"
        test_string = test_set_up + "\n" + code + "\n" + test
    elif language in ["js", "javascript", "ts", "sh"]:
        test_string = code + "\n" + test
    elif language == "go":
        # Go needs its import block rebuilt: strip the original imports from
        # the prompt, then re-import any helper package the generation uses.
        import_string = problems[task_id]["import"]
        prompt = prompt.replace(import_string, "")
        if example_test and "example_test" in problems[task_id]:
            test = problems[task_id]["example_test"]
        else:
            test = problems[task_id]["test"]
        test_setup = problems[task_id]["test_setup"]
        other_pkgs = []
        for pkg in IMPORT_HELPER["go"]:
            if pkg not in test_setup:
                p = pkg.split("/")[-1]
                if p + "." in code:
                    other_pkgs.append(f'"{pkg}"')
        if other_pkgs:
            import_other_pkgs = "import (\n" + "".join([f"    {p}\n" for p in other_pkgs]) + ")"
            test_string = test_setup + "\n" + import_other_pkgs + "\n" + prompt + code + "\n" + test
        else:
            test_string = test_setup + "\n" + prompt + code + "\n" + test
    elif language == "rs":
        main = "\nfn main(){ \n } \n"
        declaration = problems[task_id]["declaration"]
        test_string = main + declaration + prompt + code + test
    elif language == "php":
        if not code.startswith("<?php"):
            code = "<?php\n" + code
        test_string = code + "\n" + test + "?>"
    else:
        # Unsupported language: callers treat None as "skip this sample".
        return None
    return test_string
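
# Sketch of the Python path above (field names follow HumanEval-X; the
# concrete values are hypothetical illustrations):
#
#   sample = {"task_id": "Python/0", "prompt": "...", "generation": "def f(): ..."}
#   test_code = process_humaneval_test(sample, problems)
#   # test_code == "<IMPORT_HELPER['python'] lines>\n<generation>\n<test suite>\n"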


def stream_jsonl_all(filename: str) -> Iterable[Dict]:
    """
    Reads all records of a JSONL (optionally gzip-compressed) file into a list.
    """
    results = []
    print(f"Reading {filename}...")
    if filename.endswith(".gz"):
        fp = gzip.open(filename, "rt")
    else:
        fp = open(filename, "r")

    for line in fp:
        if any(not x.isspace() for x in line):
            results.append(json.loads(line))
    fp.close()

    return results
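
# Usage sketch: despite the "stream" in its name, this materializes the whole
# file (the filename below is hypothetical):
#
#   samples = stream_jsonl_all("samples.jsonl")   # .jsonl or .jsonl.gz
#   print(len(samples), samples[0]["task_id"])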


def evaluate_functional_correctness(
    input_file: str = None,
    tmp_dir: str = "./",
    n_workers: int = 32,
    timeout: float = 10.0,
    problem_file: str = "../data/humaneval_python.jsonl.gz",
    out_dir: str = None,
    k: List[int] = [1, 10, 100],
    test_groundtruth: bool = False,
    example_test: bool = False,
    is_mbpp: bool = False,
    language: str = "python",
):
    """
    Evaluates the functional correctness of generated samples and reports pass@k.
    """
    if example_test:
        print("Example test...")

    problems = read_dataset(problem_file, dataset_type="humaneval")
    sample_jsonl = stream_jsonl_all(input_file)

    with ThreadPoolExecutor(max_workers=n_workers) as executor:

        futures = []
        completion_id = Counter()
        n_samples = 0
        results = defaultdict(list)

        if test_groundtruth:
            print("Testing ground truth...")
            for sample in tqdm(problems.values()):
                task_id = sample["task_id"]
                lang = task_id.split("/")[0].lower()
                if lang == "javascript":
                    lang = "js"
                tmp_dir_ = os.path.join(tmp_dir, "evaluation", lang)
                sample["generation"] = sample["canonical_solution"]
                # Pass the per-task language by keyword; positionally it would
                # land in the is_mbpp slot.
                sample["test_code"] = process_humaneval_test(sample, problems, example_test, language=lang)
                if sample["test_code"] is None:
                    continue
                args = (task_id, sample, lang, timeout, tmp_dir_, completion_id[task_id])
                future = executor.submit(check_correctness, *args)
                futures.append(future)
                completion_id[task_id] += 1
                n_samples += 1
        else:
            print("Reading samples...")
            for sample in sample_jsonl:
                task_id = sample["task_id"]
                if is_mbpp:
                    lang = "python"
                else:
                    lang = language
                    if lang == "javascript":
                        lang = "js"
                tmp_dir_ = os.path.join(tmp_dir, "evaluation", lang)
                sample["test_code"] = process_humaneval_test(sample, problems, example_test, is_mbpp, language)
                if sample["test_code"] is None:
                    continue
                if "completion_id" in sample:
                    completion_id_ = sample["completion_id"]
                else:
                    completion_id_ = completion_id[task_id]
                args = (task_id, sample, lang, timeout, tmp_dir_, completion_id_)
                future = executor.submit(check_correctness, *args)
                futures.append(future)
                completion_id[task_id] += 1
                n_samples += 1

        # pass@k is only meaningful when every problem received samples.
        evaluate_pass_at_k = len(completion_id) == len(problems)

        print("Running test suites...")
        for future in tqdm(as_completed(futures), total=len(futures)):
            result = future.result()
            results[result["task_id"]].append((result["completion_id"], result))

    # Calculate pass@k.
    total, correct = [], []
    for result in results.values():
        passed = [r[1]["passed"] for r in result]
        total.append(len(passed))
        correct.append(sum(passed))
    total = np.array(total)
    correct = np.array(correct)
    pass_at_k = {}
    if evaluate_pass_at_k:
        ks = k
        pass_at_k = {f"pass@{k_}": estimate_pass_at_k(total, correct, k_).mean() for k_ in ks if (total >= k_).all()}
        print(pass_at_k)
    else:
        print("Total:", np.sum(total))
        print("Correct:", np.sum(correct))
    return pass_at_k, results
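
# Minimal CLI sketch (fire is already imported above; the entry point itself
# is an assumption, since this module does not define one, and <package> is a
# placeholder for wherever this file lives):
#
#   if __name__ == "__main__":
#       fire.Fire(evaluate_functional_correctness)
#
#   python -m <package>.evaluation --input_file samples.jsonl --language python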