# __init__.py
#
# The MIT License
#
# Copyright (c) OpenAI (https://openai.com)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
import itertools
import multiprocessing
import os
import sys
import time
import types
import unittest
from multiprocessing import Array, Value, Manager
from typing import Any, Dict, List, Tuple, Union

import numpy as np

from eval._special_oracle import (
    _poly,)
from eval.utils import (
    create_tempdir,
    reliability_guard,
    swallow_io,
    time_limit,
    safe_environment,
    TIMEOUT_LIMIT,
)


def compatible_eval_result(results: Dict) -> Dict:
    """Upgrade a legacy evaluation-result dict in place and return it.

    Older result files stored the raw list of generated files under "files";
    the current format keeps only the count under "nfiles".
    """
    for task_results in results["eval"].values():
        # update the "files" field to "nfiles"
        if "files" in task_results and "nfiles" not in task_results:
            task_results["nfiles"] = len(task_results.pop("files"))
    return results


# unbiased estimator from https://github.com/openai/human-eval
def estimate_pass_at_k(
    num_samples: Union[int, List[int], np.ndarray],
    num_correct: Union[List[int], np.ndarray],
    k: int,
) -> np.ndarray:
    """
    Estimates pass@k of each problem and returns them in an array.

    ``num_samples`` is either a single int shared by every problem or a
    per-problem sequence aligned with ``num_correct``.
    """

    def estimator(n: int, c: int, k: int) -> float:
        """
        Calculates 1 - comb(n - c, k) / comb(n, k) without large factorials.
        """
        if n - c < k:
            # Fewer than k incorrect samples exist, so any draw of k samples
            # must contain at least one correct solution.
            return 1.0
        return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

    if isinstance(num_samples, int):
        num_samples_it = itertools.repeat(num_samples, len(num_correct))
    else:
        assert len(num_samples) == len(num_correct)
        num_samples_it = iter(num_samples)

    return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])


# Public string labels for a task's outcome.
PASS = "pass"
FAIL = "fail"
TIMEOUT = "timeout"

# Integer codes written by the sandboxed child process into a shared Value.
_SUCCESS = 0
_FAILED = 1
_TIMEOUT = 2
_UNKNOWN = 3

# Child integer code -> public label (None means the child never reported).
_mapping = {_SUCCESS: PASS, _FAILED: FAIL, _TIMEOUT: TIMEOUT, _UNKNOWN: None}


def is_floats(x) -> bool:
    """Return True if ``x`` is a float, a list/tuple of floats, or a float ndarray."""
    if isinstance(x, float):
        return True
    if isinstance(x, (list, tuple)):
        return all(isinstance(i, float) for i in x)
    if isinstance(x, np.ndarray):
        return x.dtype == np.float64 or x.dtype == np.float32
    return False


def unsafe_execute(
    entry_point: str,
    code: str,
    test_code: str,
    timeout: float,
    max_as_limit: float,
    max_data_limit: float,
    max_stack_limit: float,
    stat,  # multiprocessing.Value("i"); receives a _SUCCESS/_FAILED code
    details,  # Manager().dict(); receives per-test failure tracebacks
):
    """Run ``code`` + ``test_code`` under resource limits and report via shared state.

    Intended as the target of a ``multiprocessing.Process``.  ``entry_point``
    is currently unused but kept for interface stability.  The concatenated
    source must define a ``TestCases`` unittest.TestCase class.
    """
    with safe_environment(), create_tempdir():
        # These system calls are needed when cleaning up tempdir; keep the
        # originals because reliability_guard() disables destructive os calls.
        import os
        import shutil
        import builtins

        rmtree = shutil.rmtree
        rmdir = os.rmdir
        chdir = os.chdir
        # Disable functionalities that can make destructive changes to the test.
        reliability_guard(max_as_limit, max_data_limit, max_stack_limit)
        module_name = "__test__"
        new_module = types.ModuleType(module_name)
        # Seed the namespace so the executed code sees a normal module.
        new_module.__dict__.update({
            '__builtins__': builtins,
            '__file__': f"{module_name}.py",
            '__package__': None,
            '__doc__': None,
            'sys': sys,
            'os': os,
            'environ': os.environ,
        })

        try:
            full_code = code + "\n" + test_code

            with swallow_io():
                exec(compile(full_code, f"{module_name}.py", 'exec'), new_module.__dict__)
                sys.modules[module_name] = new_module
                TestCases = getattr(new_module, 'TestCases')
                loader = unittest.TestLoader()
                suite = loader.loadTestsFromTestCase(TestCases)
                test_result = unittest.TestResult()
                with time_limit(timeout):
                    suite.run(test_result)

            issues = test_result.failures + test_result.errors
            for test, trace in issues:
                details[test.id().split(".")[-1]] = trace
            # _SUCCESS only means the harness completed; the caller downgrades
            # the status to FAIL when ``details`` is non-empty.
            stat.value = _SUCCESS
        except BaseException as e:
            # Any escape (compile error, import crash, time_limit timeout)
            # counts as a whole-task failure.
            details["ALL"] = str(e)
            stat.value = _FAILED
        # Needed for cleaning up.
        shutil.rmtree = rmtree
        os.rmdir = rmdir
        os.chdir = chdir


def untrusted_check(
    code: str,
    test_code: str,
    entry_point: str,
    max_as_limit: float,
    max_data_limit: float,
    max_stack_limit: float,
    min_time_limit: float = 10,
    gt_time_limit: float = 60,
) -> Tuple[str, Dict]:
    """Run untrusted ``code`` + ``test_code`` in a subprocess that is killed on timeout.

    Returns ``(status, details)`` where ``status`` is PASS/FAIL/TIMEOUT and
    ``details`` maps failing test names to their tracebacks.
    """
    # Renamed from ``time_limit`` so the imported context manager is not shadowed.
    task_time_limit = max(min_time_limit, gt_time_limit)
    # BUGFIX: os.getenv returns a *str* when the variable is set, which made
    # max() compare str with float and raise TypeError; coerce to float first.
    timeout = max(float(os.getenv("BIGCODEBENCH_TIMEOUT_PER_TASK", TIMEOUT_LIMIT)), task_time_limit) + 1
    # shared memory objects
    stat = Value("i", _UNKNOWN)
    manager = Manager()
    details = manager.dict()

    p = multiprocessing.Process(
        target=unsafe_execute,
        args=(
            entry_point,
            code,
            test_code,
            timeout,
            max_as_limit,
            max_data_limit,
            max_stack_limit,
            stat,
            details,
        ),
    )
    p.start()
    p.join(timeout=timeout + 1)
    if p.is_alive():
        # Escalate: polite terminate first, hard kill if the child ignores it.
        p.terminate()
        time.sleep(0.1)
    if p.is_alive():
        p.kill()
        time.sleep(0.1)

    stat = _mapping[stat.value]
    # Snapshot the managed dict into a plain dict before the manager goes away.
    details = dict(details)

    if not stat:
        # _UNKNOWN: the child never reported back, so treat it as a timeout.
        stat = TIMEOUT
    if stat == PASS:
        if details:
            # The harness ran to completion but individual tests failed.
            stat = FAIL

    return stat, details


def evaluate_files(
    files: List[str],
    inputs: List,
    entry_point: str,
    min_time_limit: float = 0.1,
    gt_time_limit_factor: float = 2.0,
    max_as_limit: float = 30 * 1024,
    max_data_limit: float = 30 * 1024,
    max_stack_limit: float = 10,
) -> List[Tuple[str, Dict]]:
    """Evaluate each solution file and return a list of (status, details) pairs.

    BUGFIX: the original call passed only three arguments to
    ``untrusted_check`` (which requires the three resource limits) and called
    ``.tolist()`` on the returned dict — both raised at runtime.  The limits
    are now accepted as backward-compatible keyword defaults and forwarded,
    and the details dict is returned as-is.
    """
    ret = []
    # sort files by the id in name (i.e., "../n.py")
    files = sorted(files, key=lambda x: int(x.split("/")[-1].split(".")[0]))
    for file in files:
        with open(file, "r") as f:
            code = f.read()
        stat, det = untrusted_check(
            code,
            inputs,
            entry_point,
            max_as_limit,
            max_data_limit,
            max_stack_limit,
            min_time_limit=min_time_limit,
            # NOTE(review): no per-task ground-truth time is available here, so
            # the factor scales min_time_limit — confirm against callers.
            gt_time_limit=min_time_limit * gt_time_limit_factor,
        )
        ret.append((stat, det))
    return ret