humaneval_qwen2_base.py
import json
import os
from pathlib import Path

from human_eval.evaluation import evaluate_functional_correctness
from tqdm import tqdm
from utils.dataset import HumanEvalDataset
from utils.utils import cleanup_code
from vllm import LLM, SamplingParams

# End-of-sequence markers shared across the model families we evaluate.
COMMON_EOS = [
    "<|endoftext|>",
    "<|endofmask|>",
    "</s>",
]


class HumanEval:
    """
    HumanEval evaluation class: generates completions with vLLM and scores
    them for functional correctness.
    """

    def __init__(
        self,
        data_root,
        max_seq_len=2048,
        language="python",
        max_gen_len=200,
        log_dir=None,
        temperature=0,
        top_p=0.95,
        n_sample=40,
        k_sample=1,
        no_batching=False,
    ):
        self.data_root = data_root
        self.max_seq_len = max_seq_len
        self.max_gen_len = max_gen_len
        self.k = k_sample
        self.n_sample = n_sample
        self.language = language
        self.log_dir = log_dir
        self.temperature = temperature
        self.top_p = top_p
        self.sft = False
        self.no_batching = no_batching

        self.eos = COMMON_EOS
        print(f"EOS: {self.eos}")

        # log_dir may legitimately be None at construction time; eval_model
        # asserts it is set before any file is written.
        if self.log_dir is not None:
            os.makedirs(self.log_dir, exist_ok=True)

    def eval_model(self, llm: LLM):
        """Generate completions for every HumanEval problem and log them as JSONL."""
        assert self.log_dir is not None, "log_dir should not be None when evaluating humaneval"
        dataset = HumanEvalDataset(self.data_root, sample_num=self.n_sample, language=self.language, issft=self.sft)
        if self.k > 1:
            assert self.n_sample >= 100, "HumanEval pass@k with k > 1 needs n_sample >= 100"

        sampling_params = SamplingParams(
            max_tokens=self.max_gen_len,
            temperature=self.temperature,
            top_p=self.top_p,
            stop=self.eos,
        )

        # Generate completions and write one JSON record per sample.
        with Path(self.log_file_path).open("w") as f_log:
            prompts = [data["prompt"] + "\n" for data in dataset]

            if self.no_batching:
                print("Continuous batching disabled; generation will be slow.")
                # llm.generate always returns a list, so unwrap the single output.
                generated = [llm.generate(prompt, sampling_params, use_tqdm=False)[0] for prompt in tqdm(prompts)]
            else:
                print("Continuous batching enabled.")
                generated = llm.generate(prompts, sampling_params, use_tqdm=True)

            for idx, output in enumerate(generated):
                data = dataset[idx]
                suffixprediction = output.outputs[0].text
                prediction = output.prompt + suffixprediction
                # Trim the completion at stopwords and markers so only code remains.
                suffixprediction = cleanup_code(suffixprediction, self.language, "humaneval", self.sft, dataset.stopwords)
                original_prompt = data["original_prompt"]
                if not self.sft:
                    # Base models only complete the prompt, so prepend it to form a full program.
                    suffixprediction = original_prompt + "\n" + suffixprediction
                res = {
                    "task_id": data["task_id"],
                    "generation": suffixprediction,
                    "prompt": original_prompt,
                    "wholecode": prediction,
                }
                f_log.write(json.dumps(res, ensure_ascii=False) + "\n")

        # Aggregate scores.
        self._calculate_final_score()

    @property
    def log_file_path(self) -> str:
        return os.path.join(self.log_dir, f"pred_{self.language}_output.jsonl")

    def _calculate_final_score(self):
        """Execute the generated programs against the reference tests and report pass@k."""
        timeout = 10
        res, details = evaluate_functional_correctness(
            input_file=self.log_file_path,
            problem_file=os.path.join(self.data_root, f"humaneval-{self.language}.jsonl"),
            tmp_dir=self.log_dir,
            timeout=timeout,
            language=self.language,
        )
        print(f"{self.language} pass@{self.k}:", res[f"pass@{self.k}"])

        details_file = os.path.join(self.log_dir, f"humaneval-{self.language}-details.json")
        with Path(details_file).open("w") as f:
            json.dump(details, f, ensure_ascii=False, indent=2)

        print(f"Details => {details_file}")
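
# --- Usage sketch (illustrative; not part of the original file) ---
# A minimal way to drive the evaluator end to end, assuming a Qwen2 base
# checkpoint and a data_root containing humaneval-python.jsonl in the layout
# HumanEvalDataset expects. The model name and both paths are placeholders.
if __name__ == "__main__":
    llm = LLM(model="Qwen/Qwen2-7B", max_model_len=2048)
    evaluator = HumanEval(
        data_root="data/humaneval",      # placeholder: directory with humaneval-<lang>.jsonl
        log_dir="logs/humaneval_qwen2",  # predictions and details are written here
        language="python",
        temperature=0,  # greedy decoding, so one sample suffices for pass@1
        n_sample=1,
        k_sample=1,
    )
    evaluator.eval_model(llm)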