evaluation.py
"""Functional-correctness evaluation for HumanEval-X style datasets.

Reads model generations from a JSONL file, assembles a runnable test
program per sample and language, executes them in a thread pool via
``check_correctness``, and reports pass@k.
"""
import os
import sys
import fire
import json
import gzip
import regex
import numpy as np
import itertools

from typing import *
from tqdm.auto import tqdm
from collections import Counter, defaultdict  # Counter is required by evaluate_functional_correctness
from concurrent.futures import ThreadPoolExecutor, as_completed
from .data import stream_jsonl
from .execution import check_correctness

# Per-language boilerplate prepended to generated code so that solutions
# relying on common imports/includes still compile and run.
IMPORT_HELPER = {
    "python": [
        "import math",
        "import re",
        "import sys",
        "import copy",
        "import datetime",
        "import itertools",
        "import collections",
        "import heapq",
        "import functools",
        "import hashlib",
        "import numpy",
        "import numpy as np",
        "import string",
        "from typing import *",
        "from collections import *",
    ],
    "go": [
        "math",
        "strings",
        "fmt",
        "strconv",
        "time",
        "bytes",
        "regexp",
        "sort",
        "math/rand",
        "crypto/md5",
    ],
    "cpp": [
        "#include<stdlib.h>",
        "#include<algorithm>",
        "#include<math.h>",
        "#include<stdio.h>",
        "#include<vector>",
        "#include<string>",
        "#include<climits>",
        "#include<cstring>",
        "#include<iostream>",
        "#include<cassert>",
    ],
    "cs": [
        "using System.Numerics;",
        "using System.Diagnostics;",
        "using System.Collections.Generic;",
        "using System.Linq;",
        "using System.Text;",
        "using System.Security.Cryptography;",
    ],
}


LANGUAGE_NAME = {
    "cpp": "CPP",
    "go": "Go",
    "java": "Java",
    "js": "JavaScript",
    "python": "Python",
}


def read_dataset(
    data_file: str = None,
    dataset_type: str = "humaneval",
    num_shot=None,
) -> Dict:
    """Read a dataset and return a mapping of ``task_id`` -> task dict.

    Args:
        data_file: Path to the dataset JSONL(.gz). When None, falls back to
            the bundled HumanEval-X Python split relative to this file.
        dataset_type: Only ``"humaneval"``-family datasets are supported.
        num_shot: Optional few-shot count; only echoed for logging.

    Raises:
        ValueError: If ``dataset_type`` is not a supported dataset.
    """
    if num_shot is not None:
        print(f"{num_shot}-shot setting...")
    if "humaneval" in dataset_type.lower():
        if data_file is None:
            current_path = os.path.dirname(os.path.abspath(__file__))
            data_file = os.path.join(current_path, "..", "humaneval-x", "python", "data", "humaneval_python.jsonl.gz")
        dataset = {task["task_id"]: task for task in stream_jsonl(data_file)}
    else:
        # BUG FIX: `raise <str>` is a TypeError in Python 3 — raise a real exception.
        raise ValueError(f"Dataset: {dataset_type} not supported.")

    return dataset


def estimate_pass_at_k(num_samples: Union[int, List[int], np.ndarray], num_correct: Union[List[int], np.ndarray], k: int) -> np.ndarray:
    """Estimate pass@k of each problem and return the estimates in an array.

    Uses the unbiased estimator 1 - C(n-c, k)/C(n, k) from the Codex paper,
    computed in product form for numerical stability.
    """

    def estimator(n: int, c: int, k: int) -> float:
        """Calculates 1 - comb(n - c, k) / comb(n, k)."""
        if n - c < k:
            # Fewer than k failures exist, so every k-subset contains a pass.
            return 1.0
        return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

    if isinstance(num_samples, int):
        # A single sample count applies to every problem.
        num_samples_it = itertools.repeat(num_samples, len(num_correct))
    else:
        assert len(num_samples) == len(num_correct)
        num_samples_it = iter(num_samples)

    return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])


def process_humaneval_test(sample, problems, example_test=False, is_mbpp=False, language="python"):
    """Assemble the full test program (imports + generation + tests) for a sample.

    Args:
        sample: Dict with at least ``task_id``, ``prompt`` and ``generation``.
        problems: Mapping of ``task_id`` -> problem dict (provides the tests).
        example_test: Prefer the problem's ``example_test`` when available.
        is_mbpp: MBPP samples concatenate generation with a list of asserts.
        language: Target language key; selects the assembly strategy.

    Returns:
        The runnable test program as a string, or None for an unsupported
        language (callers skip None test code).
    """
    task_id = sample["task_id"]
    if is_mbpp:
        # MBPP stores its tests as a list of assert statements.
        return sample["generation"] + "\n" + "\n".join(problems[task_id]["test"])

    prompt = sample["prompt"]
    if example_test and "example_test" in problems[task_id] and problems[task_id]["example_test"] != "":
        test = problems[task_id]["example_test"]
    else:
        test = problems[task_id]["test"]
    code = sample["generation"]

    # Pre-process for different languages
    if language == "python":
        test_setup = "\n".join(IMPORT_HELPER["python"]) + "\n"
        test_string = test_setup + code + "\n" + test + "\n"
    elif language == "cpp":
        # Only add headers the prompt does not already include.
        test_set_up = ""
        for s in IMPORT_HELPER["cpp"]:
            if s not in prompt:
                test_set_up += s + "\n"
        test_string = test_set_up + "\n" + code + "\n" + test
    elif language == "java":
        test_string = code + "\n" + test
    elif language == "cs":
        test_set_up = ""
        for s in IMPORT_HELPER["cs"]:
            test_set_up += s + "\n"
        test_string = test_set_up + "\n" + code + "\n" + test
    elif language in ["js", "javascript", "ts", "sh"]:
        test_string = code + "\n" + test
    elif language == "go":
        # BUG FIX: this branch was unreachable ("go232", with "go" routed through
        # the plain-concatenation case above). Go needs its import block rebuilt:
        # strip the prompt's own import string, then add any helper packages the
        # generated code actually references.
        import_string = problems[task_id]["import"]
        prompt = prompt.replace(import_string, "")
        if example_test and "example_test" in problems[task_id]:
            test = problems[task_id]["example_test"]
        else:
            test = problems[task_id]["test"]
        test_setup = problems[task_id]["test_setup"]
        other_pkgs = []
        for pkg in IMPORT_HELPER["go"]:
            if pkg not in test_setup:
                p = pkg.split("/")[-1]
                # Heuristic: "pkg." appearing in the code means the package is used.
                if p + "." in code:
                    other_pkgs.append(f'"{pkg}"')
        if other_pkgs:
            import_other_pkgs = "import (\n" + " ".join([p + "\n" for p in other_pkgs]) + ")"
            test_string = test_setup + "\n" + import_other_pkgs + "\n" + prompt + code + "\n" + test
        else:
            test_string = test_setup + "\n" + prompt + code + "\n" + test
    elif language == "rs":
        main = "\nfn main(){ \n } \n"
        declaration = problems[task_id]["declaration"]
        test_string = main + declaration + prompt + code + test
    elif language == "php":
        if code[:5] != "<?php":
            code = "<?php\n" + code
        test_string = code + "\n" + test + "?>"
    else:
        # Unsupported language: return None so callers can skip the sample
        # (previously this fell through to an UnboundLocalError).
        test_string = None
    return test_string


def stream_jsonl_all(filename: str) -> Iterable[Dict]:
    """Read every JSON object from a (possibly gzip-compressed) JSONL file.

    Blank lines are skipped. Returns the parsed objects as a list.
    """
    results = []
    # BUG FIX: honor .gz inputs (the module's defaults point at .jsonl.gz files)
    # and close the handle deterministically via a context manager.
    open_fn = gzip.open if filename.endswith(".gz") else open
    with open_fn(filename, "rt") as fp:
        for line in fp:
            if any(not x.isspace() for x in line):
                results.append(json.loads(line))

    return results


def evaluate_functional_correctness(
    input_file: str = None,
    tmp_dir: str = "./",
    n_workers: int = 32,
    timeout: float = 10.0,
    problem_file: str = "../data/humaneval_python.jsonl.gz",
    out_dir: str = None,
    k: List[int] = [1, 10, 100],
    test_groundtruth: bool = False,
    example_test: bool = False,
    is_mbpp: bool = False,
    language: str = "python",
):
    """Evaluate the functional correctness of model generations.

    Schedules one ``check_correctness`` job per sample on a thread pool,
    collects the results, and computes pass@k when every problem received
    at least one completion.

    Args:
        input_file: JSONL of samples with ``task_id`` and ``generation``.
        tmp_dir: Scratch directory for per-language execution sandboxes.
        n_workers: Thread-pool size.
        timeout: Per-sample execution timeout in seconds.
        problem_file: Dataset file passed to ``read_dataset``.
        out_dir: Unused here; kept for interface compatibility.
        k: The k values for pass@k.
        test_groundtruth: Evaluate canonical solutions instead of generations.
        example_test: Use the problems' example tests when available.
        is_mbpp: Treat samples as MBPP (always Python, assert-list tests).
        language: Language of the generations.

    Returns:
        Tuple of (pass_at_k dict — empty when pass@k was not computable —
        and the per-task result lists).
    """
    if example_test:
        print("Example test...")

    problems = read_dataset(problem_file, dataset_type="humaneval")
    sample_jsonl = stream_jsonl_all(input_file)

    with ThreadPoolExecutor(max_workers=n_workers) as executor:

        futures = []
        completion_id = Counter()
        n_samples = 0
        results = defaultdict(list)

        if test_groundtruth:
            print("Testing ground truth...")
            for sample in tqdm(problems.values()):
                task_id = sample["task_id"]
                lang = task_id.split("/")[0].lower()
                if lang == "javascript":
                    lang = "js"
                tmp_dir_ = os.path.join(tmp_dir, "evaluation", lang)
                sample["generation"] = sample["canonical_solution"]
                # BUG FIX: `language` was passed positionally into the `is_mbpp`
                # slot; pass it by keyword so the right code path is taken.
                sample["test_code"] = process_humaneval_test(sample, problems, example_test, language=language)
                if sample["test_code"] is None:
                    continue
                args = (task_id, sample, lang, timeout, tmp_dir_, completion_id[task_id])
                future = executor.submit(check_correctness, *args)
                futures.append(future)
                completion_id[task_id] += 1
                n_samples += 1
        else:
            print("Reading samples...")
            for sample in sample_jsonl:
                task_id = sample["task_id"]
                lang = "python" if is_mbpp else language
                if not is_mbpp and lang == "javascript":
                    lang = "js"
                tmp_dir_ = os.path.join(tmp_dir, "evaluation", lang)
                sample["task_id"] = task_id
                sample["test_code"] = process_humaneval_test(sample, problems, example_test, is_mbpp, language)
                if sample["test_code"] is None:
                    continue
                if "completion_id" in sample:
                    completion_id_ = sample["completion_id"]
                else:
                    completion_id_ = completion_id[task_id]
                args = (task_id, sample, lang, timeout, tmp_dir_, completion_id_)
                future = executor.submit(check_correctness, *args)
                futures.append(future)
                completion_id[task_id] += 1
                n_samples += 1

        # pass@k is only meaningful when every problem got at least one sample.
        evaluate_pass_at_k = len(completion_id) == len(problems)

        print("Running test suites...")
        for future in tqdm(as_completed(futures), total=len(futures)):
            result = future.result()
            results[result["task_id"]].append((result["completion_id"], result))

        # Calculate pass@k.
        total, correct = [], []
        for result in results.values():
            passed = [r[1]["passed"] for r in result]
            total.append(len(passed))
            correct.append(sum(passed))
        total = np.array(total)
        correct = np.array(correct)

        # BUG FIX: pass_at_k was unbound (UnboundLocalError on return) when
        # pass@k could not be evaluated; default to an empty dict.
        pass_at_k = {}
        if evaluate_pass_at_k:
            # Only report pass@k_ for k_ values every problem has enough samples for.
            pass_at_k = {
                f"pass@{k_}": estimate_pass_at_k(total, correct, k_).mean()
                for k_ in k
                if (total >= k_).all()
            }
            print(pass_at_k)
        else:
            print("Total:", np.sum(total))
            print("Correct:", np.sum(correct))
    return pass_at_k, results