generate.py
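"""Generate code solutions for BigCodeBench tasks.

Loads the chosen BigCodeBench subset, queries the selected model backend
(vLLM, HF, or a hosted API), and writes the generated solutions to a JSONL
file for downstream evaluation.
"""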
import os
import json
import argparse
import gc

import torch
from vllm.distributed.parallel_state import (
    destroy_distributed_environment,
    destroy_model_parallel,
)

from model import DecoderBase, make_model
from data import get_bigcodebench, write_jsonl


def codegen(
    model: DecoderBase,
    save_path: str,
    split: str,
    subset="full",
    greedy=False,
    strip_newlines=False,
    n_samples=1,
    id_range=None,
    resume=True,
):
    """Generate solutions for the selected BigCodeBench tasks and write them
    to save_path as JSONL."""
    dataset = get_bigcodebench(subset=subset)

    if model.is_direct_completion() and split == "instruct":
        raise ValueError("Base models do not support direct completion on instruct tasks")

    # Create the parent directory of save_path if needed, e.g., a/b.jsonl -> a/
    dirname = os.path.dirname(save_path)
    if dirname and not os.path.exists(dirname):
        os.makedirs(dirname)

    # Resume support: keep previously generated samples and skip their task
    # ids (this assumes write_jsonl overwrites save_path rather than appending).
    samples = []
    done_task_ids = set()
    if resume and os.path.exists(save_path):
        with open(save_path, "r") as f:
            for line in f:
                sample = json.loads(line)
                samples.append(sample)
                done_task_ids.add(sample["task_id"])
        print(f"Resuming from {save_path}: {len(done_task_ids)} tasks already done")

    task_ids, prompts, complete_prompts = [], [], []
    for idx, (task_id, task) in enumerate(dataset.items()):
        # id_range selects a half-open [start, end) index window over the
        # dataset's iteration order.
        if id_range is not None and not (id_range[0] <= idx < id_range[1]):
            continue
        if task_id in done_task_ids:
            continue
        task_ids.append(task_id)
        complete_prompts.append(task["complete_prompt"])
        prompt = task[f"{split}_prompt"]
        prompt = prompt.strip("\n") if strip_newlines else prompt
        prompts.append(prompt)

    if not prompts:
        print("No tasks left to generate")
        outputs = []
    else:
        outputs = model.codegen(prompts, do_sample=not greedy, num_samples=n_samples)
        assert outputs, "No outputs from model!"

    for task_id, complete_prompt, completion in zip(task_ids, complete_prompts, outputs):
        # Base models complete the prompt in place, so prepend it to form the
        # full solution; chat/instruct models return a standalone solution.
        if model.is_direct_completion():
            generated = complete_prompt + completion
        else:
            generated = completion
        samples.append(dict(task_id=task_id, solution=generated))

    print(f"Writing {len(samples)} samples to {save_path}")
    write_jsonl(save_path, samples)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, type=str)
    parser.add_argument("--split", required=True, type=str, choices=["complete", "instruct"])
    parser.add_argument("--subset", default="full", type=str, choices=["full", "hard"])
    parser.add_argument("--save_path", default=None, type=str)
    parser.add_argument("--bs", default=1, type=int)
    parser.add_argument("--n_samples", default=1, type=int)
    parser.add_argument("--temperature", default=0.0, type=float)
    parser.add_argument("--greedy", action="store_true")
    parser.add_argument("--strip_newlines", action="store_true")
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--chat_mode", action="store_true")
    parser.add_argument("--id_range", nargs=2, type=int)
    parser.add_argument("--backend", default="vllm", type=str, choices=["vllm", "hf", "openai", "mistral", "anthropic", "google"])
    parser.add_argument("--base_url", default=None, type=str)
    parser.add_argument("--tp", default=1, type=int)
    parser.add_argument("--tokenizer_legacy", action="store_true")
    parser.add_argument("--tokenizer_name", default=None, type=str)

    args = parser.parse_args()

    # --greedy is also implied by temperature 0 with a single sample; normalize
    # the sampling arguments in both cases.
    if args.greedy or (args.temperature == 0 and args.n_samples == 1):
        args.temperature = 0
        args.bs = 1
        args.n_samples = 1
        args.greedy = True
        print("Greedy decoding ON: setting bs=1, n_samples=1, temperature=0")

    if args.id_range is not None:
        assert len(args.id_range) == 2, "id_range must be a list of length 2"
        assert args.id_range[0] < args.id_range[1], "id_range must be increasing"
        args.id_range = tuple(args.id_range)

    # Build the model runner for the requested backend.
    model_runner = make_model(
        model=args.model,
        backend=args.backend,
        batch_size=args.bs,
        temperature=args.temperature,
        base_url=args.base_url,
        tp=args.tp,
        tokenizer_name=args.tokenizer_name,
        tokenizer_legacy=args.tokenizer_legacy,
        chat_mode=args.chat_mode,
    )

    # Default output name encodes model, subset, split, backend, and sampling
    # settings so runs don't overwrite each other.
    extra = "-" + args.subset if args.subset != "full" else ""
    if not args.save_path:
        save_path = args.model.replace("/", "--") + f"--bigcodebench{extra}-{args.split}--{args.backend}-{args.temperature}-{args.n_samples}.jsonl"
    else:
        save_path = args.save_path

    codegen(
        model=model_runner,
        save_path=save_path,
        split=args.split,
        subset=args.subset,
        greedy=args.greedy,
        strip_newlines=args.strip_newlines,
        n_samples=args.n_samples,
        resume=args.resume,
        id_range=args.id_range,
    )

    if args.backend == "vllm":
        # Explicitly tear down vLLM's distributed state and free GPU memory so
        # the process can exit cleanly.
        print("Cleaning up vLLM resources...")
        destroy_model_parallel()
        destroy_distributed_environment()
        del model_runner.llm.llm_engine.model_executor
        del model_runner.llm
        gc.collect()
        torch.cuda.empty_cache()

    print("===============================")
    print("----- END OF BigCodeBench -----")


if __name__ == "__main__":
    main()
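

# Example invocation (model name and output path are illustrative):
#   python generate.py --model bigcode/starcoder2-15b --split complete \
#       --subset hard --backend vllm --tp 1 --greedy \
#       --save_path results/starcoder2-15b--complete.jsonl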