model.py
import json
import os
from abc import ABC, abstractmethod
from typing import List
from warnings import warn

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams

# Stop strings shared by all decoders; subclasses extend this list per mode.
EOS = [
    "<|endoftext|>",
    "<|endofmask|>",
    "</s>",
    "\nif __name__",
    "\ndef main(",
    "\nprint(",
]


def extra_eos_for_direct_completion(dataset) -> List[str]:
    if dataset.lower() == "bigcodebench":
        return ["\ndef ", "\nclass ", "\nimport ", "\nfrom ", "\nassert "]
    raise ValueError(f"Unknown dataset: {dataset}")


# Some random words which serve as the splitter.
_MAGIC_SPLITTER_ = "-[[]]-this-is-really-our-highest-priority-[[]]-"


def make_chat_prompt(
    prompt: str,
    tokenizer: AutoTokenizer,
    chat_mode: bool,
) -> str:
    if not chat_mode:  # completion tasks use the raw prompt
        return prompt

    prompt = f"""\
Please provide a self-contained Python script that solves the following problem in a markdown code block:
```
{prompt.strip()}
```
"""
    response = f"""\
Below is a Python script with a self-contained function that solves the problem and passes corresponding tests:
```python
{_MAGIC_SPLITTER_}
```
"""
    # Render both turns with the chat template, then cut at the magic splitter so
    # the final prompt ends right after the assistant's opening ```python fence.
    prompt = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": response},
        ],
        tokenize=False,
    ).split(_MAGIC_SPLITTER_)[0]
    return prompt


class DecoderBase(ABC):
    def __init__(
        self,
        name: str,
        batch_size: int = 1,
        temperature: float = 0.8,
        max_new_tokens: int = 1280,
        dtype: str = "bfloat16",  # default
        trust_remote_code: bool = True,
        tokenizer_name: str = None,
        tokenizer_legacy: bool = False,
        chat_mode: bool = False,
    ) -> None:
        print("Initializing a decoder model: {} ...".format(name))
        self.name = name
        self.batch_size = batch_size
        self.temperature = temperature
        # Copy EOS so the per-instance `+=` below does not mutate the module-level list.
        self.eos = list(EOS)
        self.skip_special_tokens = False
        self.max_new_tokens = max_new_tokens
        self.dtype = dtype
        self.trust_remote_code = trust_remote_code
        self.tokenizer_name = tokenizer_name
        self.tokenizer_legacy = tokenizer_legacy
        self.chat_mode = chat_mode

    @abstractmethod
    def codegen(
        self, prompt: str, do_sample: bool = True, num_samples: int = 200
    ) -> List[str]:
        pass

    @abstractmethod
    def is_direct_completion(self) -> bool:
        pass

    def __repr__(self) -> str:
        return self.name

    def __str__(self) -> str:
        return self.name


class VllmDecoder(DecoderBase):
    def __init__(self, name: str, dataset: str, tp: int, **kwargs) -> None:
        super().__init__(name, **kwargs)

        # Engine arguments for vLLM only; the tokenizer is loaded separately below.
        kwargs = {
            "tensor_parallel_size": int(os.getenv("VLLM_N_GPUS", tp)),
            "dtype": self.dtype,
            "trust_remote_code": self.trust_remote_code,
            "gpu_memory_utilization": 0.95,
            "enforce_eager": True,
            "distributed_executor_backend": "ray",
        }
        if self.tokenizer_name is None:
            self.tokenizer_name = self.name

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer_name,
            trust_remote_code=self.trust_remote_code,
            legacy=self.tokenizer_legacy,
        )
        if not self.chat_mode:
            self.eos += extra_eos_for_direct_completion(dataset)
        self.llm = LLM(model=name, max_model_len=2048, **kwargs)
        self.llm.set_tokenizer(tokenizer=self.tokenizer)

    def is_direct_completion(self) -> bool:
        return not self.chat_mode

    def codegen(
        self, prompts: List[str], do_sample: bool = True, num_samples: int = 200
    ) -> List[str]:
        if do_sample:
            assert self.temperature > 0, "Temperature must be greater than 0!"

        vllm_outputs = self.llm.generate(
            prompts,
            SamplingParams(
                temperature=self.temperature,
                max_tokens=self.max_new_tokens,
                top_p=0.95 if do_sample else 1.0,
                stop=self.eos,
            ),
            use_tqdm=True,
        )

        gen_strs = [x.outputs[0].text.replace("\t", "    ") for x in vllm_outputs]
        return gen_strs


class GeneralVllmDecoder(VllmDecoder):
    def __init__(self, name: str, **kwargs) -> None:
        super().__init__(name, **kwargs)
        # Also stop at markdown code fences emitted by chat models.
        self.eos += ["\n```\n", "```"]
        print(f"EOS strings: {self.eos}")

    def codegen(
        self, prompts: List[str], do_sample: bool = True, num_samples: int = 200
    ) -> List[str]:
        chat_prompts = [
            make_chat_prompt(prompt, self.tokenizer, self.chat_mode)
            for prompt in prompts
        ]
        return VllmDecoder.codegen(self, chat_prompts, do_sample, num_samples)


class HfTorchDecoder(DecoderBase):
    def __init__(self, name: str, dataset: str, **kwargs):
        super().__init__(name=name, **kwargs)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Model-loading arguments; the tokenizer is loaded separately below.
        kwargs = {}
        kwargs["device_map"] = "auto"
        kwargs["trust_remote_code"] = self.trust_remote_code
        # Map the dtype string (e.g. "bfloat16") to the torch dtype object.
        kwargs["torch_dtype"] = getattr(torch, self.dtype)
        self.skip_special_tokens = True

        print(f"{kwargs = }", self.tokenizer_name)
        if self.tokenizer_name is None:
            self.tokenizer_name = self.name

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer_name,
            trust_remote_code=self.trust_remote_code,
            legacy=self.tokenizer_legacy,
        )

        if not self.chat_mode:
            self.eos += extra_eos_for_direct_completion(dataset)

        self.model = AutoModelForCausalLM.from_pretrained(name, **kwargs)
        self.model = self.model.to(self.device)

    def is_direct_completion(self) -> bool:
        return not self.chat_mode

    @torch.inference_mode()
    def codegen(
        self, prompt: str, do_sample: bool = True, num_samples: int = 200
    ) -> List[str]:
        if self.temperature == 0:
            assert not do_sample
            assert num_samples == 1

        input_tokens = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        kwargs = {}
        if do_sample:
            kwargs["top_p"] = 0.95
            kwargs["temperature"] = self.temperature

        outputs = self.model.generate(
            input_tokens,
            max_new_tokens=self.max_new_tokens,
            do_sample=do_sample,
            num_return_sequences=min(self.batch_size, num_samples),
            pad_token_id=self.tokenizer.eos_token_id,
            stop_strings=self.eos,
            tokenizer=self.tokenizer,
            **kwargs,
        )

        gen_strs = self.tokenizer.batch_decode(
            outputs[:, input_tokens.size(-1):],
            skip_special_tokens=self.skip_special_tokens,
        )
        outputs = []
        # Truncate each sample at the earliest EOS string and normalize tabs.
        for output in gen_strs:
            min_index = 10000
            for eos in self.eos:
                if eos in output:
                    min_index = min(min_index, output.index(eos))
            outputs.append(output[:min_index].replace("\t", "    "))
        return outputs


class GeneralHfTorchDecoder(HfTorchDecoder):
    def __init__(self, name: str, **kwargs):
        super().__init__(name=name, **kwargs)
        # Also stop at markdown code fences emitted by chat models.
        self.eos += ["\n```\n", "```"]
        print(f"EOS strings: {self.eos}")
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer_name if self.tokenizer_name else self.name,
            trust_remote_code=self.trust_remote_code,
            legacy=self.tokenizer_legacy,
        )

    def codegen(
        self, prompt: str, do_sample: bool = True, num_samples: int = 200
    ) -> List[str]:
        prompt = make_chat_prompt(prompt, self.tokenizer, self.chat_mode)
        return HfTorchDecoder.codegen(self, prompt, do_sample, num_samples)


def make_model(
    model: str,
    backend: str,
    dataset: str = "bigcodebench",
    batch_size: int = 1,
    temperature: float = 0.0,
    tp=1,
    base_url=None,
    trust_remote_code=True,
    tokenizer_name=None,
    tokenizer_legacy=True,
    chat_mode=False,
):
    print(f"{chat_mode = }")
    if backend == "vllm":
        return GeneralVllmDecoder(
            name=model,
            batch_size=batch_size,
            temperature=temperature,
            dataset=dataset,
            tp=tp,
            trust_remote_code=trust_remote_code,
            tokenizer_name=tokenizer_name,
            tokenizer_legacy=tokenizer_legacy,
            chat_mode=chat_mode,
        )
    elif backend == "hf":
        return GeneralHfTorchDecoder(
            name=model,
            batch_size=batch_size,
            temperature=temperature,
            dataset=dataset,
            trust_remote_code=trust_remote_code,
            tokenizer_name=tokenizer_name,
            tokenizer_legacy=tokenizer_legacy,
            chat_mode=chat_mode,
        )
    raise ValueError(f"Unknown backend: {backend}")
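
A minimal usage sketch, not part of the original file: it builds a direct-completion decoder through make_model with the vLLM backend and runs greedy decoding on one toy prompt. The model id bigcode/starcoder2-3b and the prompt are placeholders; any causal LM supported by vLLM should work, with VLLM_N_GPUS or tp controlling tensor parallelism.

# Usage sketch (illustrative only): assumes a CUDA machine with vLLM installed
# and uses "bigcode/starcoder2-3b" purely as a placeholder model id.
if __name__ == "__main__":
    decoder = make_model(
        model="bigcode/starcoder2-3b",
        backend="vllm",
        dataset="bigcodebench",
        temperature=0.0,  # greedy decoding
        chat_mode=False,  # direct completion: prompts are passed through unchanged
    )
    completions = decoder.codegen(
        ['def add(a, b):\n    """Return the sum of a and b."""\n'],
        do_sample=False,
        num_samples=1,
    )
    print(completions[0])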