model.py
import json
import os
from abc import ABC, abstractmethod
from typing import List
from warnings import warn

import torch
from stop_sequencer import StopSequencer
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams

EOS = [
    "<|endoftext|>",
    "<|endofmask|>",
    "</s>",
    "\nif __name__",
    "\ndef main(",
    "\nprint(",
]


def extra_eos_for_direct_completion(dataset) -> List[str]:
    print(f"eos: {dataset = }")
    if dataset.lower() == "humaneval":
        return ["\ndef ", "\nclass ", "\nimport ", "\nfrom ", "\nassert "]
    elif dataset.lower() == "mbpp":
        return ['\n"""', "\nassert"]
    raise ValueError(f"Unknown dataset: {dataset}")


# some random words which serve as the splitter
_MAGIC_SPLITTER_ = "-[[]]-this-is-really-our-highest-priority-[[]]-"


def make_chat_prompt(
    task_prompt: str,
    instruction_prefix: str,
    response_prefix: str,
    tokenizer: AutoTokenizer,
    chat_mode=False,
) -> str:
    # if tokenizer.chat_template is None:
    if not chat_mode:
        return task_prompt

    assert instruction_prefix is not None, "Instruction prefix is required!"
    assert response_prefix is not None, "Response prefix is required!"

    task_prompt = f"""\
{instruction_prefix}
```
{task_prompt.strip()}
```
"""
    response = f"""\
{response_prefix}
```python
{_MAGIC_SPLITTER_}
```
"""
    task_prompt = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": task_prompt},
            {"role": "assistant", "content": response},
        ],
        tokenize=False,
    ).split(_MAGIC_SPLITTER_)[0]
    return task_prompt
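

# Illustrative use of make_chat_prompt (the model id below is only a placeholder):
# with chat_mode=True, the task is rendered through the tokenizer's chat template as
# a user turn plus a partial assistant turn that opens a ```python block containing
# _MAGIC_SPLITTER_; cutting at the splitter leaves a prompt that ends right after the
# opening of that code block, so the model continues by writing code.
#
#     tokenizer = AutoTokenizer.from_pretrained("your-org/your-chat-model")
#     prompt = make_chat_prompt(
#         "def add(a, b):\n    ...",
#         instruction_prefix="Please complete the following function:",
#         response_prefix="Here is the completed function:",
#         tokenizer=tokenizer,
#         chat_mode=True,
#     )
#     # `prompt` now ends with the response prefix followed by "```python\n".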


class DecoderBase(ABC):

    def __init__(
        self,
        name: str,
        batch_size: int = 1,
        temperature: float = 0.8,
        max_new_tokens: int = 768,
        dtype: str = "bfloat16",  # default
        trust_remote_code: bool = True,
        instruction_prefix: str = None,
        response_prefix: str = None,
        chat_mode=False,
    ) -> None:
        print("Initializing a decoder model: {} ...".format(name))
        self.name = name
        self.batch_size = batch_size
        self.temperature = temperature
        self.eos = list(EOS)  # copy, so per-instance additions do not mutate the module-level list
        self.skip_special_tokens = False
        self.max_new_tokens = max_new_tokens
        self.dtype = dtype
        self.trust_remote_code = trust_remote_code
        self.instruction_prefix = instruction_prefix
        self.response_prefix = response_prefix
        self.chat_mode = chat_mode

        print(f"{self.chat_mode = }")

    @abstractmethod
    def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]:
        pass

    @abstractmethod
    def is_direct_completion(self) -> bool:
        pass

    def __repr__(self) -> str:
        return self.name

    def __str__(self) -> str:
        return self.name


class VllmDecoder(DecoderBase):

    def __init__(self, name: str, dataset: str, tp: int, **kwargs) -> None:
        super().__init__(name, **kwargs)

        kwargs = {
            "tensor_parallel_size": int(os.getenv("VLLM_N_GPUS", tp)),
            "dtype": self.dtype,
            "trust_remote_code": self.trust_remote_code,
            "gpu_memory_utilization": 0.95,
            "enforce_eager": True,
            "distributed_executor_backend": "ray",
        }

        self.tokenizer = AutoTokenizer.from_pretrained(self.name)
        if not self.chat_mode:
            self.eos += extra_eos_for_direct_completion(dataset)
        self.llm = LLM(model=name, max_model_len=2048, **kwargs)

    def is_direct_completion(self) -> bool:
        return not self.chat_mode

    def codegen(self, prompts: List[str], do_sample: bool = True, num_samples: int = 200, use_tqdm=True) -> List[str]:
        if do_sample:
            assert self.temperature > 0, "Temperature must be greater than 0!"

        vllm_outputs = self.llm.generate(
            prompts,
            SamplingParams(
                temperature=self.temperature,
                max_tokens=self.max_new_tokens,
                top_p=0.95 if do_sample else 1.0,
                stop=self.eos,
            ),
            use_tqdm=use_tqdm,
        )

        gen_strs = [x.outputs[0].text.replace("\t", "    ") for x in vllm_outputs]
        return gen_strs


class GeneralVllmDecoder(VllmDecoder):

    def __init__(self, name: str, no_batching: bool, **kwargs) -> None:
        super().__init__(name, **kwargs)

        self.no_batching = no_batching

        self.eos += ["\n```\n", "```"]
        print(f"EOS strings: {self.eos}")

    def codegen(self, prompts: List[str], do_sample: bool = True, num_samples: int = 200) -> List[str]:
        if self.no_batching:
            print(f"{self.no_batching = }, disabling continuous batching. This may be slow.")
            generated = []
            for prompt in tqdm(prompts):
                chat_prompt = make_chat_prompt(
                    prompt,
                    instruction_prefix=self.instruction_prefix,
                    response_prefix=self.response_prefix,
                    tokenizer=self.tokenizer,
                    chat_mode=self.chat_mode,
                )
                generated.append(VllmDecoder.codegen(self, chat_prompt, do_sample, num_samples, use_tqdm=False)[0])
        else:
            print(f"{self.no_batching = }, enabling continuous batching.")
            chat_prompts = [
                make_chat_prompt(
                    prompt,
                    instruction_prefix=self.instruction_prefix,
                    response_prefix=self.response_prefix,
                    tokenizer=self.tokenizer,
                    chat_mode=self.chat_mode,
                )
                for prompt in prompts
            ]
            generated = VllmDecoder.codegen(self, chat_prompts, do_sample, num_samples)

        return generated


class HfTorchDecoder(DecoderBase):

    def __init__(self, name: str, dataset: str, **kwargs):
        super().__init__(name=name, **kwargs)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        kwargs = {}
        kwargs["device_map"] = "auto"
        kwargs["trust_remote_code"] = self.trust_remote_code
        # string to torch dtype
        kwargs["torch_dtype"] = getattr(torch, self.dtype)
        self.skip_special_tokens = True

        print(f"{kwargs = }")

        self.tokenizer = AutoTokenizer.from_pretrained(name)
        if not self.chat_mode:
            self.eos += extra_eos_for_direct_completion(dataset)

        self.model = AutoModelForCausalLM.from_pretrained(name, **kwargs)
        self.model = self.model.to(self.device)

    def is_direct_completion(self) -> bool:
        return not self.chat_mode

    @torch.inference_mode()
    def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]:
        if self.temperature == 0:
            assert not do_sample
            assert num_samples == 1

        input_tokens = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        kwargs = {}
        if do_sample:
            kwargs["top_p"] = 0.95
            kwargs["temperature"] = self.temperature

        stop_sequencer = StopSequencer(
            self.model,
            model_type="causal",  # or seq2seq
            tokenizer=self.tokenizer,
        )

        model = stop_sequencer.register_stop_texts(
            stop_texts=self.eos,
            input_length=input_tokens.size(-1),
        )

        outputs = model.generate(
            input_tokens,
            max_new_tokens=self.max_new_tokens,
            do_sample=do_sample,
            num_return_sequences=min(self.batch_size, num_samples),
            pad_token_id=self.tokenizer.eos_token_id,
            **kwargs,
        )

        gen_strs = self.tokenizer.batch_decode(
            outputs[:, input_tokens.size(-1) :],
            skip_special_tokens=self.skip_special_tokens,
        )
        outputs = []
        # truncate each sample at the first EOS string it contains
        for output in gen_strs:
            min_index = 10000
            for eos in self.eos:
                if eos in output:
                    min_index = min(min_index, output.index(eos))
            outputs.append(output[:min_index].replace("\t", "    "))
        return outputs


class GenenralHfTorchDecoder(HfTorchDecoder):

    def __init__(self, name: str, **kwargs):
        super().__init__(name=name, **kwargs)
        self.eos += ["\n```\n", "```"]
        print(f"EOS strings: {self.eos}")
        self.tokenizer = AutoTokenizer.from_pretrained(self.name)

    def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]:
        prompt = make_chat_prompt(
            prompt,
            self.instruction_prefix,
            self.response_prefix,
            self.tokenizer,
            chat_mode=self.chat_mode,
        )
        return HfTorchDecoder.codegen(self, prompt, do_sample, num_samples)


def make_model(
    model: str,
    dataset: str,
    backend: str = "vllm",
    batch_size: int = 1,
    temperature: float = 0.0,
    tp=1,
    instruction_prefix=None,
    response_prefix=None,
    no_batching=False,  # no_batching -> disable continuous batching in vLLM
    chat_mode=False,
):
    if backend == "vllm":
        return GeneralVllmDecoder(
            name=model,
            batch_size=batch_size,
            temperature=temperature,
            dataset=dataset,
            tp=tp,
            instruction_prefix=instruction_prefix,
            response_prefix=response_prefix,
            no_batching=no_batching,
            chat_mode=chat_mode,
        )
    elif backend == "hf":
        return GenenralHfTorchDecoder(
            name=model,
            batch_size=batch_size,
            temperature=temperature,
            dataset=dataset,
            instruction_prefix=instruction_prefix,
            response_prefix=response_prefix,
            chat_mode=chat_mode,
        )
    raise ValueError(f"Unknown backend: {backend}")
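

# Usage sketch: build a decoder via make_model and generate one greedy sample.
# The model id and prefixes are placeholders; running this requires vLLM and a GPU.
if __name__ == "__main__":
    decoder = make_model(
        model="your-org/your-chat-model",  # placeholder model id
        dataset="humaneval",
        backend="vllm",
        temperature=0.0,
        instruction_prefix="Please provide a self-contained Python solution for the following task:",
        response_prefix="Below is a Python solution to the task:",
        chat_mode=True,
    )
    completions = decoder.codegen(
        ['def fib(n: int) -> int:\n    """Return the n-th Fibonacci number."""\n'],
        do_sample=False,
        num_samples=1,
    )
    print(completions[0])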