humaneval_qwen2_base.py
import json
import os
from pathlib import Path

from human_eval.evaluation import evaluate_functional_correctness
from tqdm import tqdm
from utils.dataset import HumanEvalDataset
from utils.utils import cleanup_code
from vllm import LLM, SamplingParams

# Stop strings shared across models; vLLM halts generation at any of these.
COMMON_EOS = [
    "<|endoftext|>",
    "<|endofmask|>",
    "</s>",
]

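
# Illustration only, not used by this script: the human_eval evaluator called
# below reports pass@k computed with the unbiased estimator of Chen et al.
# (2021),
#     pass@k = 1 - C(n - c, k) / C(n, k),
# where n samples are drawn per task and c of them pass the tests.
# A minimal reference implementation (hypothetical helper, for clarity):
def _pass_at_k_reference(n: int, c: int, k: int) -> float:
    if n - c < k:
        # Fewer than k failing samples: every size-k subset contains a pass.
        return 1.0
    prod = 1.0
    for i in range(n - c + 1, n + 1):
        prod *= 1.0 - k / i
    return 1.0 - prod
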
class HumanEval:
    """
    HumanEval functional-correctness evaluation, with generation via vLLM.
    """

    def __init__(
        self,
        data_root,
        max_seq_len=2048,
        language="python",
        max_gen_len=200,
        log_dir=None,
        temperature=0,
        top_p=0.95,
        n_sample=40,
        k_sample=1,
        no_batching=False,
    ):
        self.data_root = data_root
        self.max_seq_len = max_seq_len
        self.max_gen_len = max_gen_len
        self.k = k_sample
        self.n_sample = n_sample
        self.language = language
        self.log_dir = log_dir
        self.temperature = temperature
        self.top_p = top_p
        self.sft = False  # this script targets base models, so SFT-style prompting is off
        self.no_batching = no_batching

        self.eos = COMMON_EOS
        print(f"EOS: {self.eos}")

        # Guard against the default log_dir=None; eval_model asserts it is set.
        if self.log_dir is not None:
            os.makedirs(self.log_dir, exist_ok=True)

    def eval_model(self, llm: LLM):
        assert self.log_dir is not None, "log_dir should not be None when evaluating humaneval"
        dataset = HumanEvalDataset(self.data_root, sample_num=self.n_sample, language=self.language, issft=self.sft)
        if self.k > 1:
            assert self.n_sample >= 100, f"HumanEval pass@{self.k} needs n_sample >= 100"

        sampling_params = SamplingParams(max_tokens=self.max_gen_len, temperature=self.temperature, top_p=self.top_p, stop=self.eos)

        # Generate completions and log one JSON record per sample.
        with Path(self.log_file_path).open("w") as f_log:
            prompts = [data["prompt"] + "\n" for data in dataset]

            if self.no_batching:
                print("Continuous batching disabled; this will be slow.")
                # llm.generate returns a list even for a single prompt, so unwrap it.
                generated = [llm.generate(prompt, sampling_params, use_tqdm=False)[0] for prompt in tqdm(prompts)]
            else:
                print("Continuous batching enabled.")
                generated = llm.generate(prompts, sampling_params, use_tqdm=True)

            for idx, output in enumerate(generated):
                data = dataset[idx]
                suffixprediction = output.outputs[0].text
                prediction = output.prompt + suffixprediction
                suffixprediction = cleanup_code(suffixprediction, self.language, "humaneval", self.sft, dataset.stopwords)
                original_prompt = data["original_prompt"]
                if not self.sft:
                    # For base models, prepend the original prompt so the
                    # logged generation is a complete, executable function.
                    suffixprediction = original_prompt + "\n" + suffixprediction
                res = {
                    "task_id": data["task_id"],
                    "generation": suffixprediction,
                    "prompt": original_prompt,
                    "wholecode": prediction,
                }
                f_log.write(json.dumps(res, ensure_ascii=False) + "\n")

        # Aggregate scores.
        self._calculate_final_score()

    @property
    def log_file_path(self) -> str:
        return os.path.join(self.log_dir, f"pred_{self.language}_output.jsonl")

    def _calculate_final_score(self):
        timeout = 10  # seconds allowed per test program
        res, details = evaluate_functional_correctness(
            input_file=self.log_file_path,
            problem_file=os.path.join(self.data_root, f"humaneval-{self.language}.jsonl"),
            tmp_dir=self.log_dir,
            timeout=timeout,
            language=self.language,
        )
        print(f"{self.language} pass@{self.k}:", res[f"pass@{self.k}"])

        details_file = os.path.join(self.log_dir, f"humaneval-{self.language}-details.json")
        with Path(details_file).open("w") as f:
            json.dump(details, f, ensure_ascii=False, indent=2)

        print(f"Details => {details_file}")
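

# --- Usage sketch, not part of the original file ---
# A minimal driver showing how this class is wired up. The checkpoint name
# and both paths below are hypothetical placeholders; data_root must contain
# humaneval-python.jsonl as expected by _calculate_final_score above.
if __name__ == "__main__":
    llm = LLM(model="Qwen/Qwen2-7B", max_model_len=2048)  # assumed checkpoint
    evaluator = HumanEval(
        data_root="data/humaneval",  # hypothetical dataset directory
        log_dir="logs/humaneval",    # hypothetical output directory
        language="python",
        n_sample=1,
        k_sample=1,
    )
    evaluator.eval_model(llm)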