# generate.py
import os
import json
import argparse

from model import DecoderBase, make_model
from data import get_bigcodebench, write_jsonl
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    TextColumn,
    TimeElapsedColumn,
)


def codegen(
    model: DecoderBase,
    save_path: str,
    split: str,
    subset="full",
    greedy=False,
    strip_newlines=False,
    n_samples=1,
    id_range=None,
    resume=True,
):
    """Generate solutions for BigCodeBench tasks and write them to a JSONL file.

    Args:
        model: Backend wrapper exposing ``codegen`` and ``is_direct_completion``.
        save_path: Output JSONL path; parent directories are created as needed.
        split: ``"complete"`` or ``"instruct"`` — selects ``task[f"{split}_prompt"]``.
        subset: Dataset subset passed to ``get_bigcodebench`` (``"full"``/``"hard"``).
        greedy: If True, decoding runs with ``do_sample=False``.
        strip_newlines: If True, strip leading/trailing newlines from each prompt.
        n_samples: Number of samples requested per prompt.
        id_range: Optional ``(lo, hi)`` — only tasks with index in ``[lo, hi)``
            are generated.
        resume: If True and ``save_path`` exists, tasks already present in it
            are skipped and their samples are preserved in the output.

    Raises:
        ValueError: If a direct-completion (base) model is used with the
            ``instruct`` split.
        RuntimeError: If the model returns no outputs.
    """
    dataset = get_bigcodebench(subset=subset)

    if model.is_direct_completion() and split == "instruct":
        raise ValueError("Base model does not support direct completion for instruct tasks")

    # Create save_path's parent directory if needed, e.g. a/b.jsonl -> a/.
    # exist_ok avoids a race between the existence check and makedirs.
    dirname = os.path.dirname(save_path)
    if dirname:
        os.makedirs(dirname, exist_ok=True)

    # BUG FIX: `resume` and `id_range` were accepted (and passed by main())
    # but silently ignored. Resume: collect task ids already generated so we
    # skip them and keep their samples.
    # NOTE(review): assumes write_jsonl overwrites save_path (so existing
    # samples must be re-emitted) — confirm against data.write_jsonl.
    done_task_ids = set()
    existing_samples = []
    if resume and os.path.isfile(save_path):
        with open(save_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                sample = json.loads(line)
                done_task_ids.add(sample["task_id"])
                existing_samples.append(sample)

    task_ids, prompts, complete_prompts = [], [], []
    for idx, (task_id, task) in enumerate(dataset.items()):
        # id_range restricts generation to tasks whose index is in [lo, hi).
        if id_range is not None and not (id_range[0] <= idx < id_range[1]):
            continue
        if task_id in done_task_ids:
            continue
        task_ids.append(task_id)
        complete_prompts.append(task["complete_prompt"])
        prompt = task[f"{split}_prompt"]
        # Some models are sensitive to leading/trailing newlines in prompts.
        prompt = prompt.strip("\n") if strip_newlines else prompt
        prompts.append(prompt)

    samples = list(existing_samples)
    if prompts:
        outputs = model.codegen(prompts, do_sample=not greedy, num_samples=n_samples)
        # Raise instead of assert: asserts are stripped under `python -O`.
        if not outputs:
            raise RuntimeError("No outputs from model!")

        for task_id, complete_prompt, completion in zip(task_ids, complete_prompts, outputs):
            # Direct-completion (base) models return only the continuation,
            # so the original prompt must be prepended to form a solution.
            if model.is_direct_completion():
                samples.append(dict(task_id=task_id, solution=complete_prompt + completion))
            else:
                samples.append(dict(task_id=task_id, solution=completion))

    print(f"Generated {len(samples)} samples")
    write_jsonl(save_path, samples)


def main():
    """CLI entry point: parse arguments, build the model backend, run codegen."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, type=str)
    parser.add_argument("--split", required=True, type=str, choices=["complete", "instruct"])
    parser.add_argument("--subset", default="full", type=str, choices=["full", "hard"])
    parser.add_argument("--save_path", default=None, type=str)
    parser.add_argument("--bs", default=1, type=int)
    parser.add_argument("--n_samples", default=1, type=int)
    parser.add_argument("--temperature", default=0.0, type=float)
    parser.add_argument("--greedy", action="store_true")
    parser.add_argument("--strip_newlines", action="store_true")
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--id_range", nargs=2, type=int)
    parser.add_argument("--backend", default="vllm", type=str, choices=["vllm", "hf", "openai", "mistral", "anthropic", "google"])
    parser.add_argument("--base_url", default=None, type=str)
    parser.add_argument("--tp", default=1, type=int)
    parser.add_argument("--trust_remote_code", action="store_true")
    parser.add_argument("--tokenizer_legacy", action="store_true")
    parser.add_argument("--tokenizer_name", default=None, type=str)

    args = parser.parse_args()

    # Greedy decoding is forced when --greedy is set, or when the settings
    # already imply it (temperature 0 with a single sample).
    if args.greedy or (args.temperature == 0 and args.n_samples == 1):
        args.temperature = 0
        args.bs = 1
        args.n_samples = 1
        args.greedy = True
        print("Greedy decoding ON (--greedy): setting bs=1, n_samples=1, temperature=0")

    if args.id_range is not None:
        # argparse's nargs=2 already enforces the length; validate ordering
        # with a real CLI error instead of an assert (stripped under -O).
        if args.id_range[0] >= args.id_range[1]:
            parser.error("id_range must be increasing")
        args.id_range = tuple(args.id_range)

    # Build the backend-specific model runner.
    model_runner = make_model(
        model=args.model,
        backend=args.backend,
        batch_size=args.bs,
        temperature=args.temperature,
        base_url=args.base_url,
        tp=args.tp,
        trust_remote_code=args.trust_remote_code,
        tokenizer_name=args.tokenizer_name,
        tokenizer_legacy=args.tokenizer_legacy,
    )

    # Default save path encodes model, dataset subset, split, and decoding settings.
    extra = "-" + args.subset if args.subset != "full" else ""
    if not args.save_path:
        save_path = args.model.replace("/", "--") + f"--bigcodebench{extra}-{args.split}--{args.backend}-{args.temperature}-{args.n_samples}.jsonl"
    else:
        save_path = args.save_path

    codegen(
        model=model_runner,
        save_path=save_path,
        split=args.split,
        subset=args.subset,
        greedy=args.greedy,
        strip_newlines=args.strip_newlines,
        n_samples=args.n_samples,
        resume=args.resume,
        id_range=args.id_range,
    )


if __name__ == "__main__":
    main()