generate.py
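"""Generate LLM solutions for BigCodeBench tasks and write them to a JSONL file.

Example invocation (the model name is a placeholder):
    python generate.py --model org/model --split complete --subset hard --greedy --backend vllm
"""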
import os
import json
import argparse

from model import DecoderBase, make_model
from data import get_bigcodebench, write_jsonl


def codegen(
    model: DecoderBase,
    save_path: str,
    split: str,
    subset="full",
    greedy=False,
    strip_newlines=False,
    n_samples=1,
    id_range=None,
    resume=True,
):
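    """Generate solutions for the selected BigCodeBench tasks and write them to save_path as JSONL.

    id_range (when given) limits generation to a slice of the task list, and
    resume=True skips tasks whose solutions already exist in save_path.
    """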

    dataset = get_bigcodebench(subset=subset)

    if model.is_direct_completion() and split == "instruct":
        raise ValueError("Base (direct-completion) models cannot be used for the instruct split")

    # Create the parent directory of save_path if it doesn't exist, e.g., a/b.jsonl -> a/
    dirname = os.path.dirname(save_path)
    if dirname and not os.path.exists(dirname):
        os.makedirs(dirname)
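
    # Resume support: collect solutions already present in save_path so that
    # finished tasks are skipped below and re-emitted on the final write. This
    # assumes write_jsonl rewrites save_path with exactly the records it is given.
    done_samples, done_task_ids = [], set()
    if resume and os.path.exists(save_path):
        with open(save_path, "r") as f:
            for line in f:
                if not line.strip():
                    continue
                record = json.loads(line)
                done_samples.append(record)
                done_task_ids.add(record["task_id"])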

    task_ids, prompts, complete_prompts = [], [], []
    for idx, (task_id, task) in enumerate(dataset.items()):
        # id_range restricts generation to tasks whose position in the dataset
        # falls inside the half-open window [id_range[0], id_range[1]).
        if id_range is not None and not (id_range[0] <= idx < id_range[1]):
            continue
        if task_id in done_task_ids:
            continue
        task_ids.append(task_id)
        complete_prompts.append(task["complete_prompt"])
        prompt = task[f"{split}_prompt"]
        prompt = prompt.strip("\n") if strip_newlines else prompt
        prompts.append(prompt)

    if not prompts:
        print(f"All {len(done_samples)} samples already generated; nothing to do")
        return

    outputs = model.codegen(prompts, do_sample=not greedy, num_samples=n_samples)
    assert outputs, "No outputs from model!"

    samples = list(done_samples)
    for task_id, complete_prompt, completion in zip(task_ids, complete_prompts, outputs):
        # Base models return a raw continuation, so the original prompt is prepended
        # to form a full solution; chat models already return complete code.
        if model.is_direct_completion():
            samples.append(dict(task_id=task_id, solution=complete_prompt + completion))
        else:
            samples.append(dict(task_id=task_id, solution=completion))

    print(f"Generated {len(samples) - len(done_samples)} new samples; {len(samples)} total")
    write_jsonl(save_path, samples)


def main():
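    """Parse CLI arguments, build the decoding backend, and run generation."""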
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, type=str)
    parser.add_argument("--split", required=True, type=str, choices=["complete", "instruct"])
    parser.add_argument("--subset", default="full", type=str, choices=["full", "hard"])
    parser.add_argument("--save_path", default=None, type=str)
    parser.add_argument("--bs", default=1, type=int)
    parser.add_argument("--n_samples", default=1, type=int)
    parser.add_argument("--temperature", default=0.0, type=float)
    parser.add_argument("--greedy", action="store_true")
    parser.add_argument("--strip_newlines", action="store_true")
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--id_range", nargs=2, type=int)
    parser.add_argument("--backend", default="vllm", type=str, choices=["vllm", "hf", "openai", "mistral", "anthropic", "google"])
    parser.add_argument("--base_url", default=None, type=str)
    parser.add_argument("--tp", default=1, type=int)
    parser.add_argument("--trust_remote_code", action="store_true")
    parser.add_argument("--tokenizer_legacy", action="store_true")
    parser.add_argument("--tokenizer_name", default=None, type=str)

    args = parser.parse_args()

    # Greedy decoding is forced either explicitly (--greedy) or implicitly when
    # temperature is 0 and only a single sample is requested.
    if args.greedy or (args.temperature == 0 and args.n_samples == 1):
        args.temperature = 0
        args.bs = 1
        args.n_samples = 1
        args.greedy = True
        print("Greedy decoding ON: setting bs=1, n_samples=1, temperature=0")

    if args.id_range is not None:
        assert len(args.id_range) == 2, "id_range must be a list of length 2"
        assert args.id_range[0] < args.id_range[1], "id_range must be increasing"
        args.id_range = tuple(args.id_range)

    # Instantiate the decoding backend for the requested model.
    model_runner = make_model(
        model=args.model,
        backend=args.backend,
        batch_size=args.bs,
        temperature=args.temperature,
        base_url=args.base_url,
        tp=args.tp,
        trust_remote_code=args.trust_remote_code,
        tokenizer_name=args.tokenizer_name,
        tokenizer_legacy=args.tokenizer_legacy,
    )
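
    # The default save path encodes model, subset, split, backend, temperature, and
    # sample count, e.g. org--model--bigcodebench-hard-instruct--vllm-0.0-1.jsonl.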
    extra = "-" + args.subset if args.subset != "full" else ""
    if not args.save_path:
        save_path = args.model.replace("/", "--") + f"--bigcodebench{extra}-{args.split}--{args.backend}-{args.temperature}-{args.n_samples}.jsonl"
    else:
        save_path = args.save_path
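
    # Generate solutions for the selected split/subset and write them to save_path.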
    codegen(
        model=model_runner,
        save_path=save_path,
        split=args.split,
        subset=args.subset,
        greedy=args.greedy,
        strip_newlines=args.strip_newlines,
        n_samples=args.n_samples,
        resume=args.resume,
        id_range=args.id_range,
    )


if __name__ == "__main__":
    main()