model.py
import json
import os
from abc import ABC, abstractmethod
from typing import List
from warnings import warn

import openai

try:
    import anthropic

    from bigcodebench.gen.util import anthropic_request
except ImportError:
    warn("Anthropic decoder will not work. Fix by `pip install anthropic`")

# mistral.ai
try:
    from mistralai.client import MistralClient
    from mistralai.models.chat_completion import ChatMessage
except ImportError:
    warn("MistralAI decoder will not work. Fix by `pip install mistralai`")
try:
    import google.generativeai as genai
except ImportError:
    warn("GoogleGenAI decoder will not work. Fix by `pip install google-generativeai`")

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

try:
    from vllm import LLM, SamplingParams
except ImportError:
    warn("VLLM decoder will not work. Fix by `pip install vllm`")

from bigcodebench.gen.util import openai_request

EOS = [
    "<|endoftext|>",
    "<|endofmask|>",
    "</s>",
    "\nif __name__",
    "\ndef main(",
    "\nprint(",
]


def extra_eos_for_direct_completion(dataset) -> List[str]:
    if dataset.lower() == "bigcodebench":
        return ["\ndef ", "\nclass ", "\nimport ", "\nfrom ", "\nassert "]
    raise ValueError(f"Unknown dataset: {dataset}")


# some random words which serve as the splitter
_MAGIC_SPLITTER_ = "-[[]]-this-is-really-our-highest-priority-[[]]-"


def make_chat_prompt(prompt: str, tokenizer: AutoTokenizer) -> str:
    # return the raw prompt if the tokenizer has no chat template
    if tokenizer.chat_template is None:
        return prompt

    prompt = f"""\
Please provide a self-contained Python script that solves the following problem in a markdown code block:
```
{prompt.strip()}
```
"""
    response = f"""\
Below is a Python script with a self-contained function that solves the problem and passes corresponding tests:
```python
{_MAGIC_SPLITTER_}
```
"""
    prompt = tokenizer.apply_chat_template(
        [
            {
                "role": "user",
                "content": prompt
            },
            {
                "role": "assistant",
                "content": response
            },
        ],
        tokenize=False,
    ).split(_MAGIC_SPLITTER_)[0]
    return prompt


class DecoderBase(ABC):

    def __init__(
        self,
        name: str,
        batch_size: int = 1,
        temperature: float = 0.8,
        max_new_tokens: int = 1280,
        dtype: str = "bfloat16",  # default
        trust_remote_code: bool = False,
        tokenizer_name: str = None,
        tokenizer_legacy: bool = False,
    ) -> None:
        print("Initializing a decoder model: {} ...".format(name))
        self.name = name
        self.batch_size = batch_size
        self.temperature = temperature
        self.eos = list(EOS)  # copy so per-instance additions do not mutate the module-level list
        self.skip_special_tokens = False
        self.max_new_tokens = max_new_tokens
        self.dtype = dtype
        self.trust_remote_code = trust_remote_code
        self.tokenizer_name = tokenizer_name
        self.tokenizer_legacy = tokenizer_legacy

    @abstractmethod
    def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]:
        pass

    @abstractmethod
    def is_direct_completion(self) -> bool:
        pass

    def __repr__(self) -> str:
        return self.name

    def __str__(self) -> str:
        return self.name


class VllmDecoder(DecoderBase):

    def __init__(self, name: str, dataset: str, tp: int, **kwargs) -> None:
        super().__init__(name, **kwargs)

        kwargs = {
            "tensor_parallel_size": int(os.getenv("VLLM_N_GPUS", tp)),
            "dtype": self.dtype,
            "trust_remote_code": True,
            "enforce_eager": True,
            "gpu_memory_utilization": 0.95,
            "worker_use_ray": True,
        }
        if self.tokenizer_name is None:
            self.tokenizer_name = self.name

        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, **kwargs, legacy=self.tokenizer_legacy)
        if self.tokenizer.chat_template is None:
            self.eos += extra_eos_for_direct_completion(dataset)
        self.llm = LLM(model=name, max_model_len=2048, **kwargs)
        self.llm.set_tokenizer(tokenizer=self.tokenizer)

    def is_direct_completion(self) -> bool:
        return self.tokenizer.chat_template is None

    def codegen(self, prompts: List[str], do_sample: bool = True, num_samples: int = 200) -> List[str]:
        if do_sample:
            assert self.temperature > 0, "Temperature must be greater than 0!"

        vllm_outputs = self.llm.generate(
            prompts,
            SamplingParams(
                temperature=self.temperature,
                max_tokens=self.max_new_tokens,
                top_p=0.95 if do_sample else 1.0,
                stop=self.eos,
            ),
            use_tqdm=True,
        )

        gen_strs = [x.outputs[0].text.replace("\t", "    ") for x in vllm_outputs]
        return gen_strs


class GeneralVllmDecoder(VllmDecoder):

    def __init__(self, name: str, **kwargs) -> None:
        super().__init__(name, **kwargs)
        self.eos += ["\n```\n"]
        print(f"EOS strings: {self.eos}")

    def codegen(self, prompts: List[str], do_sample: bool = True, num_samples: int = 200) -> List[str]:
        chat_prompts = [make_chat_prompt(prompt, self.tokenizer) for prompt in prompts]
        return VllmDecoder.codegen(self, chat_prompts, do_sample, num_samples)


class HfTorchDecoder(DecoderBase):

    def __init__(self, name: str, dataset: str, **kwargs):
        super().__init__(name=name, **kwargs)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        kwargs = {}
        kwargs["device_map"] = "auto"
        kwargs["trust_remote_code"] = self.trust_remote_code
        # map the dtype string to the corresponding torch dtype
        kwargs["torch_dtype"] = getattr(torch, self.dtype)
        self.skip_special_tokens = True

        print(f"{kwargs = }", self.tokenizer_name)
        if self.tokenizer_name is None:
            self.tokenizer_name = self.name

        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, **kwargs, legacy=self.tokenizer_legacy)

        if self.tokenizer.chat_template is None:
            self.eos += extra_eos_for_direct_completion(dataset)

        self.model = AutoModelForCausalLM.from_pretrained(name, **kwargs)
        self.model = self.model.to(self.device)

    def is_direct_completion(self) -> bool:
        return self.tokenizer.chat_template is None

    @torch.inference_mode()
    def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]:
        if self.temperature == 0:
            assert not do_sample
            assert num_samples == 1

        input_tokens = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        kwargs = {}
        if do_sample:
            kwargs["top_p"] = 0.95
            kwargs["temperature"] = self.temperature

        outputs = self.model.generate(
            input_tokens,
            max_new_tokens=self.max_new_tokens,
            do_sample=do_sample,
            num_return_sequences=min(self.batch_size, num_samples),
            pad_token_id=self.tokenizer.eos_token_id,
            stop_strings=self.eos,
            tokenizer=self.tokenizer,
            **kwargs,
        )

        gen_strs = self.tokenizer.batch_decode(
            outputs[:, input_tokens.size(-1):],
            skip_special_tokens=self.skip_special_tokens,
        )
        outputs = []
        # truncate each generation at the first EOS string
        for output in gen_strs:
            min_index = 10000
            for eos in self.eos:
                if eos in output:
                    min_index = min(min_index, output.index(eos))
            outputs.append(output[:min_index].replace("\t", "    "))
        return outputs


class GeneralHfTorchDecoder(HfTorchDecoder):

    def __init__(self, name: str, **kwargs):
        super().__init__(name=name, **kwargs)
        self.eos += ["\n```\n"]
        print(f"EOS strings: {self.eos}")
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer_name if self.tokenizer_name else self.name,
            **kwargs,
            legacy=self.tokenizer_legacy,
        )

    def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]:
        prompt = make_chat_prompt(prompt, self.tokenizer)
        return HfTorchDecoder.codegen(self, prompt, do_sample, num_samples)


class OpenAIChatDecoder(DecoderBase):

    def __init__(self, name: str, base_url=None, **kwargs) -> None:
        super().__init__(name, **kwargs)
        self.client = openai.OpenAI(base_url=base_url)

    def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]:
        if do_sample:
            assert self.temperature > 0, "Temperature must be positive for sampling"
        batch_size = min(self.batch_size, num_samples)

        # construct the prompt
        fmt = "json_object" if self.name == "gpt-4-1106-preview" else "text"
        if fmt == "json_object":
            message = r'Please complete the following code snippet by generating JSON like {"code": ""}'
        else:
            message = r"Please generate self-contained code to complete the following problem:"

        message += f"\n```python\n{prompt.strip()}\n```"

        ret = openai_request.make_auto_request(
            self.client,
            message=message,
            model=self.name,
            max_tokens=self.max_new_tokens,
            temperature=self.temperature,
            n=batch_size,
            response_format={"type": fmt},
        )

        outputs = []
        for item in ret.choices:
            content = item.message.content
            # if the response is JSON-serializable, extract the "code" field
            if fmt == "json_object":
                try:
                    json_data = json.loads(content)
                    if json_data.get("code", None) is not None:
                        outputs.append(prompt + "\n" + json_data["code"])
                        continue

                    print(f"'code' field not found in: {json_data}")
                except Exception as e:
                    print(e)
            outputs.append(content)

        return outputs

    def is_direct_completion(self) -> bool:
        return False


class MistralChatDecoder(DecoderBase):

    def __init__(self, name: str, **kwargs) -> None:
        super().__init__(name, **kwargs)
        self.client = MistralClient(api_key=os.getenv("MISTRAL_API_KEY"))

    def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]:
        kwargs = {}
        if do_sample:
            assert self.temperature > 0, "Temperature must be positive for sampling"
            kwargs["top_p"] = 0.95
            kwargs["temperature"] = self.temperature
        else:
            self.temperature = 0

        batch_size = min(self.batch_size, num_samples)

        outputs = []
        for _ in range(batch_size):
            ret = self.client.chat(
                model=self.name,
                messages=[
                    ChatMessage(
                        role="user",
                        content="Please generate self-contained code to solve the following problem in a Python markdown block:"
                        + f"\n```python\n{prompt.strip()}\n```",
                    )
                ],
                max_tokens=self.max_new_tokens,
                **kwargs,
            )

            outputs.append(ret.choices[0].message.content)

        return outputs

    def is_direct_completion(self) -> bool:
        return False


class AnthropicDecoder(DecoderBase, ABC):
    def __init__(self, name: str, **kwargs) -> None:
        super().__init__(name, **kwargs)
        self.client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_KEY"))

    def is_direct_completion(self) -> bool:
        return False


class AnthropicMessageDecoder(AnthropicDecoder):

    def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]:
        kwargs = {}
        if do_sample:
            assert self.temperature > 0, "Temperature must be positive for sampling"
            kwargs["top_p"] = 0.95
            kwargs["temperature"] = self.temperature
        else:
            self.temperature = 0

        batch_size = min(self.batch_size, num_samples)
        if not do_sample:
            assert batch_size == 1, "Greedy decoding only supports a batch size of 1"

        outputs = []
        for _ in range(batch_size):
            message = anthropic_request.make_auto_request(
                client=self.client,
                model=self.name,
                messages=[{
                    "role": "user",
                    "content": "Please generate self-contained code to complete the following problem wrapped in a Python markdown block:"
                    + f"\n```python\n{prompt.strip()}\n```\n",
                }],
                max_tokens=self.max_new_tokens,
                stop_sequences=["\n```\n", "\nif "],
                **kwargs,
            )
            outputs.append(message.content[0].text)

        return outputs


class GoogleGenAIDecoder(DecoderBase, ABC):

    def __init__(self, name: str, **kwargs) -> None:
        super().__init__(name, **kwargs)
        genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

    def is_direct_completion(self) -> bool:
        return False


class GeminiDecoder(GoogleGenAIDecoder):

    def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]:
        kwargs = {}
        if do_sample:
            assert self.temperature > 0, "Temperature must be positive for sampling"
            kwargs["top_p"] = 0.95
            kwargs["temperature"] = self.temperature
        else:
            self.temperature = 0

        batch_size = min(self.batch_size, num_samples)
        if not do_sample:
            assert batch_size == 1, "Greedy decoding only supports a batch size of 1"

        genai_config = genai.GenerationConfig(
            max_output_tokens=self.max_new_tokens,
            **kwargs,
        )

        safety_settings = [
            {
                "category": "HARM_CATEGORY_DANGEROUS",
                "threshold": "BLOCK_NONE",
            },
            {
                "category": "HARM_CATEGORY_HARASSMENT",
                "threshold": "BLOCK_NONE",
            },
            {
                "category": "HARM_CATEGORY_HATE_SPEECH",
                "threshold": "BLOCK_NONE",
            },
            {
                "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                "threshold": "BLOCK_NONE",
            },
            {
                "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                "threshold": "BLOCK_NONE",
            },
        ]

        model = genai.GenerativeModel(
            model_name=self.name,
            generation_config=genai_config,
            safety_settings=safety_settings,
        )

        outputs = []
        for _ in range(batch_size):
            while True:
                try:
                    response = model.generate_content(
                        "Please generate self-contained code to complete the following problem wrapped in a Python markdown block:"
                        + f"\n```python\n{prompt.strip()}\n```",
                        generation_config=genai_config,
                    )
                    output = response.candidates[0].content.parts[0].text
                    outputs.append(output)
                    break
                except Exception as e:
                    if "list index out of range" in str(e):
                        # the candidate came back empty; append a dummy response
                        outputs.append("NO_RESPONSE")
                        break
                    else:
                        print(e)
                        continue

        return outputs


def make_model(
    model: str,
    backend: str,
    dataset: str = "bigcodebench",
    batch_size: int = 1,
    temperature: float = 0.0,
    tp=1,
    base_url=None,
    trust_remote_code=False,
    tokenizer_name=None,
    tokenizer_legacy=True,
):
    if backend == "vllm":
        return GeneralVllmDecoder(
            name=model,
            batch_size=batch_size,
            temperature=temperature,
            dataset=dataset,
            tp=tp,
            trust_remote_code=trust_remote_code,
            tokenizer_name=tokenizer_name,
            tokenizer_legacy=tokenizer_legacy,
        )
    elif backend == "hf":
        return GeneralHfTorchDecoder(
            name=model,
            batch_size=batch_size,
            temperature=temperature,
            dataset=dataset,
            trust_remote_code=trust_remote_code,
            tokenizer_name=tokenizer_name,
            tokenizer_legacy=tokenizer_legacy,
        )
    elif backend == "openai":
        return OpenAIChatDecoder(
            name=model,
            batch_size=batch_size,
            temperature=temperature,
            base_url=base_url,
        )
    elif backend == "mistral":
        return MistralChatDecoder(
            name=model,
            batch_size=batch_size,
            temperature=temperature,
        )
    elif backend == "anthropic":
        return AnthropicMessageDecoder(
            name=model,
            batch_size=batch_size,
            temperature=temperature,
        )
    elif backend == "google":
        return GeminiDecoder(
            name=model,
            batch_size=batch_size,
            temperature=temperature,
        )
    raise ValueError(f"Unknown backend: {backend}")
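

# ---------------------------------------------------------------------------
# Usage sketch (not part of the upstream module): a minimal example of how
# make_model and codegen might be driven. The model id and prompt below are
# illustrative placeholders, and the vLLM backend assumes a local GPU with
# vllm installed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    decoder = make_model(
        model="bigcode/starcoder2-3b",  # hypothetical Hugging Face model id
        backend="vllm",
        dataset="bigcodebench",
        batch_size=1,
        temperature=0.0,
    )
    # GeneralVllmDecoder.codegen expects a list of prompts; temperature=0.0
    # with do_sample=False requests greedy decoding.
    samples = decoder.codegen(
        ['def add(a, b):\n    """Return the sum of a and b."""\n'],
        do_sample=False,
        num_samples=1,
    )
    print(samples[0])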