__init__.py
  1  # The MIT License
  2  #
  3  # Copyright (c) OpenAI (https://openai.com)
  4  #
  5  # Permission is hereby granted, free of charge, to any person obtaining a copy
  6  # of this software and associated documentation files (the "Software"), to deal
  7  # in the Software without restriction, including without limitation the rights
  8  # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  9  # copies of the Software, and to permit persons to whom the Software is
 10  # furnished to do so, subject to the following conditions:
 11  #
 12  # The above copyright notice and this permission notice shall be included in
 13  # all copies or substantial portions of the Software.
 14  #
 15  # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 16  # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 17  # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 18  # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 19  # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 20  # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 21  # THE SOFTWARE.
 22  
 23  import itertools
 24  import multiprocessing
 25  import os
 26  import sys
 27  import time
 28  import types
 29  import unittest
 30  from multiprocessing import Array, Value, Manager
 31  from typing import Any, Dict, List, Tuple, Union
 32  
 33  import numpy as np
 34  
 35  from eval._special_oracle import (
 36      _poly,)
 37  from eval.utils import (
 38      create_tempdir,
 39      reliability_guard,
 40      swallow_io,
 41      time_limit,
 42      safe_environment,
 43      TIMEOUT_LIMIT,
 44  )
 45  
 46  
 47  def compatible_eval_result(results: Dict) -> Dict:
 48      # compatibility
 49      for task_results in results["eval"].values():
 50          # update the "files" field to "nfiles"
 51          if "files" in task_results and "nfiles" not in task_results:
 52              task_results["nfiles"] = len(task_results.pop("files"))
 53      return results
 54  
 55  
 56  # unbiased estimator from https://github.com/openai/human-eval
 57  def estimate_pass_at_k(
 58      num_samples: Union[int, List[int], np.ndarray],
 59      num_correct: Union[List[int], np.ndarray],
 60      k: int,
 61  ) -> np.ndarray:
 62      """
 63      Estimates pass@k of each problem and returns them in an array.
 64      """
 65  
 66      def estimator(n: int, c: int, k: int) -> float:
 67          """
 68          Calculates 1 - comb(n - c, k) / comb(n, k).
 69          """
 70          if n - c < k:
 71              return 1.0
 72          return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
 73  
 74      if isinstance(num_samples, int):
 75          num_samples_it = itertools.repeat(num_samples, len(num_correct))
 76      else:
 77          assert len(num_samples) == len(num_correct)
 78          num_samples_it = iter(num_samples)
 79  
 80      return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])
 81  
 82  
 83  PASS = "pass"
 84  FAIL = "fail"
 85  TIMEOUT = "timeout"
 86  
 87  _SUCCESS = 0
 88  _FAILED = 1
 89  _TIMEOUT = 2
 90  _UNKNOWN = 3
 91  
 92  _mapping = {_SUCCESS: PASS, _FAILED: FAIL, _TIMEOUT: TIMEOUT, _UNKNOWN: None}
 93  
 94  
 95  def is_floats(x) -> bool:
 96      # check if it is float; List[float]; Tuple[float]
 97      if isinstance(x, float):
 98          return True
 99      if isinstance(x, (list, tuple)):
100          return all(isinstance(i, float) for i in x)
101      if isinstance(x, np.ndarray):
102          return x.dtype == np.float64 or x.dtype == np.float32
103      return False
104  
105  
def unsafe_execute(
        entry_point: str,
        code: str,
        test_code: str,
        timeout: float,
        max_as_limit: float,
        max_data_limit: float,
        max_stack_limit: float,
        stat,  # multiprocessing.Value("i")
        details,  # dict-like; a Manager().dict() proxy when called from untrusted_check
):
    """Execute ``code`` + ``test_code`` and record the unittest outcome.

    Intended as the target of a separate process. Runs inside
    ``safe_environment()`` and ``create_tempdir()``, after
    ``reliability_guard`` has been applied with the given limits. The
    concatenated source is executed as a module named ``__test__`` and must
    define a ``unittest.TestCase`` subclass called ``TestCases``. Its suite is
    run under ``time_limit(timeout)`` with I/O swallowed.

    Args:
        entry_point: Not used by this function; part of the argument tuple
            built by ``untrusted_check``.
        code: Candidate solution source.
        test_code: Test source appended after ``code``.
        timeout: Value passed to ``time_limit`` around the suite run.
        max_as_limit, max_data_limit, max_stack_limit: Forwarded unchanged to
            ``reliability_guard`` (units are defined there).
        stat: Shared integer. Set to ``_SUCCESS`` when the suite ran to
            completion (individual test failures still count as completion),
            or to ``_FAILED`` when anything raised.
        details: Shared mapping. Each failing or erroring test adds a
            ``{test_method_name: traceback}`` entry. If compiling, loading or
            running raised, it receives ``{"ALL": str(exception)}``.
    """
    with safe_environment(), create_tempdir():
        # These system calls are needed when cleaning up tempdir. Import the
        # modules and keep references *before* reliability_guard runs; the
        # restore step below exists because the guard may replace them.
        import os
        import shutil
        import builtins

        rmtree = shutil.rmtree
        rmdir = os.rmdir
        chdir = os.chdir
        try:
            # Disable functionalities that can make destructive changes to the test.
            reliability_guard(max_as_limit, max_data_limit, max_stack_limit)
            module_name = "__test__"
            new_module = types.ModuleType(module_name)
            # Set necessary attributes for the module
            new_module.__dict__.update({
                '__builtins__': builtins,
                '__file__': f"{module_name}.py",
                '__package__': None,
                '__doc__': None,
                'sys': sys,
                'os': os,
                'environ': os.environ,
            })

            try:
                full_code = code + "\n" + test_code

                with swallow_io():
                    exec(compile(full_code, f"{module_name}.py", 'exec'), new_module.__dict__)
                    # Make the module reachable via sys.modules under "__test__".
                    sys.modules[module_name] = new_module
                    TestCases = getattr(new_module, 'TestCases')
                    loader = unittest.TestLoader()
                    suite = loader.loadTestsFromTestCase(TestCases)
                    test_result = unittest.TestResult()
                    with time_limit(timeout):
                        suite.run(test_result)

                # Key each failing/erroring test by its method name (last id segment).
                issues = test_result.failures + test_result.errors
                for test, trace in issues:
                    details[test.id().split(".")[-1]] = trace
                stat.value = _SUCCESS
            except BaseException as e:
                # Deliberately broad: anything raised while compiling, loading
                # or running the untrusted code is reported as a failure.
                # Presumably this includes whatever time_limit raises on expiry.
                details["ALL"] = str(e)
                stat.value = _FAILED
        finally:
            # Needed for cleaning up: restore even if something above raised,
            # so create_tempdir() can remove the directory.
            shutil.rmtree = rmtree
            os.rmdir = rmdir
            os.chdir = chdir
166  
167  
def untrusted_check(code: str, test_code: str, entry_point: str, max_as_limit: float, max_data_limit: float, max_stack_limit: float, min_time_limit: float = 10, gt_time_limit: float = 60) -> Tuple[str, Dict[str, str]]:
    """Run ``code`` against ``test_code`` in a separate, time-limited process.

    Args:
        code: Candidate solution source.
        test_code: unittest source that defines ``TestCases``.
        entry_point: Forwarded to ``unsafe_execute``.
        max_as_limit, max_data_limit, max_stack_limit: Resource limits
            forwarded to the child process (applied by ``reliability_guard``).
        min_time_limit: Lower bound, in seconds, on the per-task time budget.
        gt_time_limit: Per-task time budget in seconds (presumably derived from
            the ground-truth solution). The larger of this and
            ``min_time_limit`` is used.

    The effective timeout is
    ``max(BIGCODEBENCH_TIMEOUT_PER_TASK or TIMEOUT_LIMIT, budget) + 1``
    seconds. The parent waits one second longer than that before terminating
    the child.

    Returns:
        ``(status, details)``. ``status`` is ``PASS``, ``FAIL`` or ``TIMEOUT``.
        ``details`` maps failing test method names (or ``"ALL"``) to a
        traceback or error message. A ``PASS`` with any details is downgraded
        to ``FAIL``. If the child never recorded a status, the result is
        ``TIMEOUT``.
    """
    # Named to avoid shadowing the imported ``time_limit`` helper.
    task_time_limit = max(min_time_limit, gt_time_limit)
    # os.getenv returns a *string* when the variable is set; without float()
    # the max() below would raise TypeError (str vs. number).
    env_timeout = float(os.getenv("BIGCODEBENCH_TIMEOUT_PER_TASK", TIMEOUT_LIMIT))
    timeout = max(env_timeout, task_time_limit) + 1

    # Shared objects the child writes its result into.
    stat = Value("i", _UNKNOWN)
    # The context manager shuts the manager's server process down on exit;
    # previously it was never shut down and leaked one process per call.
    with Manager() as manager:
        details = manager.dict()

        p = multiprocessing.Process(
            target=unsafe_execute,
            args=(
                entry_point,
                code,
                test_code,
                timeout,
                max_as_limit,
                max_data_limit,
                max_stack_limit,
                stat,
                details,
            ),
        )
        p.start()
        p.join(timeout=timeout + 1)
        # Escalate: terminate() first, then kill() if the child is still alive.
        if p.is_alive():
            p.terminate()
            time.sleep(0.1)
        if p.is_alive():
            p.kill()
            time.sleep(0.1)

        # Copy out of the proxy while the manager is still running.
        details = dict(details)

    status = _mapping[stat.value]
    # No status recorded (e.g. the child was terminated/killed): a timeout.
    if not status:
        status = TIMEOUT
    # The suite completed but some tests failed or errored.
    if status == PASS and details:
        status = FAIL

    return status, details
210  
211  
def evaluate_files(
    files: List[str],
    inputs: List,
    entry_point: str,
    min_time_limit: float = 0.1,
    gt_time_limit_factor: float = 2.0,
    *,
    max_as_limit: float,
    max_data_limit: float,
    max_stack_limit: float,
) -> List[Tuple[str, Dict[str, str]]]:
    """Evaluate every solution file in ``files`` with ``untrusted_check``.

    Files are processed in ascending order of the integer that forms their
    base name (``".../<n>.py"``).

    Args:
        files: Paths to solution files whose base name is ``<int>.<ext>``.
        inputs: Passed to ``untrusted_check`` as its ``test_code`` argument.
            NOTE(review): annotated as ``List`` but used as test source; confirm.
        entry_point: Forwarded to ``untrusted_check``.
        min_time_limit: Forwarded as ``untrusted_check``'s ``min_time_limit``.
        gt_time_limit_factor: Currently unused because no ground-truth timing
            is available here. Kept for interface compatibility.
        max_as_limit, max_data_limit, max_stack_limit: Resource limits
            forwarded to ``untrusted_check``. They are required there; the
            previous call omitted them and therefore always raised TypeError.

    Returns:
        One ``(status, details)`` pair per file, in sorted file order.
    """
    ret = []
    # Sort by the numeric id in the base name (i.e. ".../n.py").
    files = sorted(files, key=lambda path: int(os.path.basename(path).split(".")[0]))
    for file in files:
        # Close the handle promptly instead of leaking it.
        with open(file, "r") as f:
            code = f.read()
        stat, det = untrusted_check(
            code,
            inputs,
            entry_point,
            max_as_limit,
            max_data_limit,
            max_stack_limit,
            min_time_limit=min_time_limit,
        )
        # untrusted_check returns details as a plain dict (no .tolist()).
        ret.append((stat, det))
    return ret