"""generate.py

Generate code solutions for BigCodeBench tasks with a configurable model
backend and write them to a JSONL file.
"""

import argparse
import gc
import json
import os

import torch
from vllm.distributed.parallel_state import (
    destroy_distributed_environment,
    destroy_model_parallel,
)

from data import get_bigcodebench, write_jsonl
from model import DecoderBase, make_model


def codegen(
    model: DecoderBase,
    save_path: str,
    split: str,
    subset: str = "full",
    greedy: bool = False,
    strip_newlines: bool = False,
    n_samples: int = 1,
    id_range=None,
    resume: bool = True,
):
    """Generate a solution for every selected BigCodeBench task and write
    the results to ``save_path`` as JSONL.

    Args:
        model: Wrapped decoder model used to generate completions.
        save_path: Output JSONL path; missing parent directories are created.
        split: Which prompt variant to feed the model ("complete" or
            "instruct") — selects ``task[f"{split}_prompt"]``.
        subset: BigCodeBench subset to load ("full" or "hard").
        greedy: If True, sampling is disabled (``do_sample=False``).
        strip_newlines: Strip leading/trailing newlines from each prompt.
        n_samples: Number of samples requested per task.
        id_range: Optional ``(start, end)`` half-open range of task
            *indices* to generate for; ``None`` means all tasks.
        resume: Accepted for interface compatibility; resuming from an
            existing save file is not implemented in this function.

    Raises:
        Exception: If a direct-completion (base) model is used with the
            "instruct" split.
        ValueError: If the model returns no outputs.
    """
    dataset = get_bigcodebench(subset=subset)

    if model.is_direct_completion() and split == "instruct":
        raise Exception("Base model does not support direct completion for instruct tasks")

    # Create the parent directory of save_path if needed, e.g., a/b.jsonl.
    # exist_ok avoids a race between an existence check and creation.
    dirname = os.path.dirname(save_path)
    if dirname:
        os.makedirs(dirname, exist_ok=True)

    task_ids, prompts, complete_prompts = [], [], []
    for index, (task_id, task) in enumerate(dataset.items()):
        # Apply the id_range filter (previously accepted but ignored):
        # keep only tasks whose enumeration index falls in [start, end).
        if id_range is not None and not (id_range[0] <= index < id_range[1]):
            continue
        task_ids.append(task_id)
        complete_prompts.append(task["complete_prompt"])
        prompt = task[f"{split}_prompt"]
        if strip_newlines:
            prompt = prompt.strip("\n")
        prompts.append(prompt)

    outputs = model.codegen(prompts, do_sample=not greedy, num_samples=n_samples)
    if not outputs:
        # Raise rather than assert so the check survives `python -O`.
        raise ValueError("No outputs from model!")

    samples = []
    for task_id, complete_prompt, completion in zip(task_ids, complete_prompts, outputs):
        # Base models emit only the continuation; prepend the complete
        # prompt so the stored solution is a full program.
        if model.is_direct_completion():
            generated = complete_prompt + completion
        else:
            generated = completion
        samples.append(dict(task_id=task_id, solution=generated))

    print(f"Generated {len(samples)} samples")
    write_jsonl(save_path, samples)


def main():
    """Parse CLI arguments, run generation, and clean up the vLLM backend."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, type=str)
    parser.add_argument("--split", required=True, type=str, choices=["complete", "instruct"])
    parser.add_argument("--subset", default="full", type=str, choices=["full", "hard"])
    parser.add_argument("--save_path", default=None, type=str)
    parser.add_argument("--bs", default=1, type=int)
    parser.add_argument("--n_samples", default=1, type=int)
    parser.add_argument("--temperature", default=0.0, type=float)
    parser.add_argument("--greedy", action="store_true")
    parser.add_argument("--strip_newlines", action="store_true")
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--chat_mode", action="store_true", default=False)
    parser.add_argument("--id_range", nargs=2, type=int)
    parser.add_argument("--backend", default="vllm", type=str, choices=["vllm", "hf", "openai", "mistral", "anthropic", "google"])
    parser.add_argument("--base_url", default=None, type=str)
    parser.add_argument("--tp", default=1, type=int)
    parser.add_argument("--tokenizer_legacy", action="store_true")
    parser.add_argument("--tokenizer_name", default=None, type=str)

    args = parser.parse_args()

    # Greedy decoding implies exactly one deterministic sample.
    if args.greedy or (args.temperature == 0 and args.n_samples == 1):
        args.temperature = 0
        args.bs = 1
        args.n_samples = 1
        args.greedy = True
        print("Greedy decoding ON (--greedy): setting bs=1, n_samples=1, temperature=0")

    if args.id_range is not None:
        # parser.error reports cleanly and exits, and is not stripped
        # under `python -O` the way assert would be.
        if len(args.id_range) != 2:
            parser.error("id_range must be a list of length 2")
        if args.id_range[0] >= args.id_range[1]:
            parser.error("id_range must be increasing")
        args.id_range = tuple(args.id_range)

    # Build the model runner for the chosen backend.
    model_runner = make_model(
        model=args.model,
        backend=args.backend,
        batch_size=args.bs,
        temperature=args.temperature,
        base_url=args.base_url,
        tp=args.tp,
        tokenizer_name=args.tokenizer_name,
        tokenizer_legacy=args.tokenizer_legacy,
        chat_mode=args.chat_mode,
    )

    # Derive a default save path from the model name when none is given.
    extra = "-" + args.subset if args.subset != "full" else ""
    if not args.save_path:
        save_path = args.model.replace("/", "--") + f"--bigcodebench{extra}-{args.split}--{args.backend}-{args.temperature}-{args.n_samples}.jsonl"
    else:
        save_path = args.save_path

    codegen(
        model=model_runner,
        save_path=save_path,
        split=args.split,
        subset=args.subset,
        greedy=args.greedy,
        strip_newlines=args.strip_newlines,
        n_samples=args.n_samples,
        resume=args.resume,
        id_range=args.id_range,
    )

    # vLLM holds distributed state and GPU memory; tear it down explicitly
    # so the process can exit cleanly.
    if args.backend == "vllm":
        print("Try cleanup...")
        destroy_model_parallel()
        destroy_distributed_environment()
        del model_runner.llm.llm_engine.model_executor
        del model_runner.llm
        gc.collect()
        torch.cuda.empty_cache()

    print("===============================")
    print("----- END OF BigCodeBench -----")


if __name__ == "__main__":
    main()