model.py
import json
import os
from abc import ABC, abstractmethod
from typing import List, Optional
from warnings import warn

import openai

try:
    import anthropic

    from bigcodebench.gen.util import anthropic_request
except ImportError:
    warn("Anthropic decoder will not work. Fix by `pip install anthropic`")

# mistral.ai
try:
    from mistralai.client import MistralClient
    from mistralai.models.chat_completion import ChatMessage
except ImportError:
    warn("MistralAI decoder will not work. Fix by `pip install mistralai`")

try:
    import google.generativeai as genai
except ImportError:
    warn("GoogleGenAI decoder will not work. Fix by `pip install google-generativeai`")

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

try:
    from vllm import LLM, SamplingParams
except ImportError:
    warn("VLLM decoder will not work. Fix by `pip install vllm`")

from bigcodebench.gen.util import openai_request

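# Common stop strings; generated text is truncated at the first occurrence.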
EOS = [
    "<|endoftext|>",
    "<|endofmask|>",
    "</s>",
    "\nif __name__",
    "\ndef main(",
    "\nprint(",
]


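# Extra stop strings for base models doing direct (non-chat) completion.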
def extra_eos_for_direct_completion(dataset) -> List[str]:
    if dataset.lower() == "bigcodebench":
        return ["\ndef ", "\nclass ", "\nimport ", "\nfrom ", "\nassert "]
    raise ValueError(f"Unknown dataset: {dataset}")


# Some random words that serve as the splitter
_MAGIC_SPLITTER_ = "-[[]]-this-is-really-our-highest-priority-[[]]-"


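# Build a chat-formatted prompt whose assistant turn is pre-filled up to
# _MAGIC_SPLITTER_, so the rendered prompt ends inside the ```python block
# and the model continues generating the script from there.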
def make_chat_prompt(prompt: str, tokenizer: AutoTokenizer) -> str:
    # Return the prompt unchanged if the tokenizer has no chat template.
    if tokenizer.chat_template is None:
        return prompt

    prompt = f"""\
Please provide a self-contained Python script that solves the following problem in a markdown code block:
```
{prompt.strip()}
```
"""
    response = f"""\
Below is a Python script with a self-contained function that solves the problem and passes corresponding tests:
```python
{_MAGIC_SPLITTER_}
```
"""
    # Render the user/assistant exchange, then cut at the splitter so the prompt
    # ends right after the opening ```python fence of the assistant turn.
    prompt = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": response},
        ],
        tokenize=False,
    ).split(_MAGIC_SPLITTER_)[0]
    return prompt


class DecoderBase(ABC):
    """Shared configuration and interface for all code-generation backends."""

    def __init__(
        self,
        name: str,
        batch_size: int = 1,
        temperature: float = 0.8,
        max_new_tokens: int = 1280,
        dtype: str = "bfloat16",  # default
        trust_remote_code: bool = False,
        tokenizer_name: Optional[str] = None,
        tokenizer_legacy: bool = False,
    ) -> None:
        print(f"Initializing a decoder model: {name} ...")
        self.name = name
        self.batch_size = batch_size
        self.temperature = temperature
        # Copy so subclasses can extend self.eos without mutating the module-level EOS.
        self.eos = list(EOS)
        self.skip_special_tokens = False
        self.max_new_tokens = max_new_tokens
        self.dtype = dtype
        self.trust_remote_code = trust_remote_code
        self.tokenizer_name = tokenizer_name
        self.tokenizer_legacy = tokenizer_legacy

    @abstractmethod
    def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]:
        pass

    @abstractmethod
    def is_direct_completion(self) -> bool:
        pass

    def __repr__(self) -> str:
        return self.name

    def __str__(self) -> str:
        return self.name


class VllmDecoder(DecoderBase):

    def __init__(self, name: str, dataset: str, tp: int, **kwargs) -> None:
        super().__init__(name, **kwargs)

        llm_kwargs = {
            "tensor_parallel_size": int(os.getenv("VLLM_N_GPUS", tp)),
            "dtype": self.dtype,
            "trust_remote_code": True,
            "enforce_eager": True,
            "gpu_memory_utilization": 0.95,
            "worker_use_ray": True,
        }
        if self.tokenizer_name is None:
            self.tokenizer_name = self.name

        # The kwargs above are vLLM engine arguments; the tokenizer gets only its own flags.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer_name,
            trust_remote_code=self.trust_remote_code,
            legacy=self.tokenizer_legacy,
        )
        if self.tokenizer.chat_template is None:
            self.eos += extra_eos_for_direct_completion(dataset)
        self.llm = LLM(model=name, max_model_len=2048, **llm_kwargs)
        self.llm.set_tokenizer(tokenizer=self.tokenizer)

    def is_direct_completion(self) -> bool:
        return self.tokenizer.chat_template is None

    def codegen(self, prompts: List[str], do_sample: bool = True, num_samples: int = 200) -> List[str]:
        if do_sample:
            assert self.temperature > 0, "Temperature must be greater than 0!"

        vllm_outputs = self.llm.generate(
            prompts,
            SamplingParams(
                temperature=self.temperature,
                max_tokens=self.max_new_tokens,
                top_p=0.95 if do_sample else 1.0,
                stop=self.eos,
            ),
            use_tqdm=True,
        )

        # Keep the first sample per prompt and normalize tabs to spaces.
        gen_strs = [x.outputs[0].text.replace("\t", "    ") for x in vllm_outputs]
        return gen_strs


class GeneralVllmDecoder(VllmDecoder):

    def __init__(self, name: str, **kwargs) -> None:
        super().__init__(name, **kwargs)
        self.eos += ["\n```\n"]
        print(f"EOS strings: {self.eos}")

    def codegen(self, prompts: List[str], do_sample: bool = True, num_samples: int = 200) -> List[str]:
        chat_prompts = [make_chat_prompt(prompt, self.tokenizer) for prompt in prompts]
        return VllmDecoder.codegen(self, chat_prompts, do_sample, num_samples)


class HfTorchDecoder(DecoderBase):

    def __init__(self, name: str, dataset: str, **kwargs):
        super().__init__(name=name, **kwargs)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        model_kwargs = {
            "device_map": "auto",
            "trust_remote_code": self.trust_remote_code,
            # Map the dtype string (e.g. "bfloat16") to the torch dtype object.
            "torch_dtype": getattr(torch, self.dtype),
        }
        self.skip_special_tokens = True

        print(f"{model_kwargs = }", self.tokenizer_name)
        if self.tokenizer_name is None:
            self.tokenizer_name = self.name

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer_name,
            trust_remote_code=self.trust_remote_code,
            legacy=self.tokenizer_legacy,
        )

        if self.tokenizer.chat_template is None:
            self.eos += extra_eos_for_direct_completion(dataset)

        # `device_map="auto"` already places the weights, so no explicit `.to()` call is needed.
        self.model = AutoModelForCausalLM.from_pretrained(name, **model_kwargs)

    def is_direct_completion(self) -> bool:
        return self.tokenizer.chat_template is None

    @torch.inference_mode()
    def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]:
        if self.temperature == 0:
            assert not do_sample
            assert num_samples == 1

        input_tokens = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        kwargs = {}
        if do_sample:
            kwargs["top_p"] = 0.95
            kwargs["temperature"] = self.temperature

        outputs = self.model.generate(
            input_tokens,
            max_new_tokens=self.max_new_tokens,
            do_sample=do_sample,
            num_return_sequences=min(self.batch_size, num_samples),
            pad_token_id=self.tokenizer.eos_token_id,
            stop_strings=self.eos,
            tokenizer=self.tokenizer,
            **kwargs,
        )

        # Decode only the newly generated tokens.
        gen_strs = self.tokenizer.batch_decode(
            outputs[:, input_tokens.size(-1):],
            skip_special_tokens=self.skip_special_tokens,
        )
        outputs = []
        # Truncate each sample at the earliest EOS string.
        for output in gen_strs:
            min_index = len(output)
            for eos in self.eos:
                if eos in output:
                    min_index = min(min_index, output.index(eos))
            outputs.append(output[:min_index].replace("\t", "    "))
        return outputs


class GeneralHfTorchDecoder(HfTorchDecoder):

    def __init__(self, name: str, **kwargs):
        super().__init__(name=name, **kwargs)
        # The tokenizer is already loaded by HfTorchDecoder; chat models
        # additionally stop at the end of a markdown code block.
        self.eos += ["\n```\n"]
        print(f"EOS strings: {self.eos}")

    def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]:
        prompt = make_chat_prompt(prompt, self.tokenizer)
        return HfTorchDecoder.codegen(self, prompt, do_sample, num_samples)


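# The decoders below call hosted APIs; each expects its key in the environment
# (OPENAI_API_KEY, MISTRAL_API_KEY, ANTHROPIC_KEY, GOOGLE_API_KEY).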
class OpenAIChatDecoder(DecoderBase):

    def __init__(self, name: str, base_url=None, **kwargs) -> None:
        super().__init__(name, **kwargs)
        self.client = openai.OpenAI(base_url=base_url)

    def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]:
        if do_sample:
            assert self.temperature > 0, "Temperature must be positive for sampling"
        batch_size = min(self.batch_size, num_samples)

        # Construct the prompt; gpt-4-1106-preview is queried in JSON mode.
        fmt = "json_object" if self.name == "gpt-4-1106-preview" else "text"
        if fmt == "json_object":
            message = r'Please complete the following code snippet by generating JSON like {"code": ""}'
        else:
            message = r"Please generate self-contained code to complete the following problem:"

        message += f"\n```python\n{prompt.strip()}\n```"

        ret = openai_request.make_auto_request(
            self.client,
            message=message,
            model=self.name,
            max_tokens=self.max_new_tokens,
            temperature=self.temperature,
            n=batch_size,
            response_format={"type": fmt},
        )

        outputs = []
        for item in ret.choices:
            content = item.message.content
            # In JSON mode, extract the "code" field; fall back to the raw content.
            if fmt == "json_object":
                try:
                    json_data = json.loads(content)
                    if json_data.get("code", None) is not None:
                        outputs.append(prompt + "\n" + json_data["code"])
                        continue

                    print(f"'code' field not found in: {json_data}")
                except Exception as e:
                    print(e)
            outputs.append(content)

        return outputs

    def is_direct_completion(self) -> bool:
        return False


class MistralChatDecoder(DecoderBase):

    def __init__(self, name: str, **kwargs) -> None:
        super().__init__(name, **kwargs)
        self.client = MistralClient(api_key=os.getenv("MISTRAL_API_KEY"))

    def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]:
        kwargs = {}
        if do_sample:
            assert self.temperature > 0, "Temperature must be positive for sampling"
            kwargs["top_p"] = 0.95
            kwargs["temperature"] = self.temperature
        else:
            self.temperature = 0

        batch_size = min(self.batch_size, num_samples)

        outputs = []
        for _ in range(batch_size):
            ret = self.client.chat(
                model=self.name,
                messages=[ChatMessage(
                    role="user",
                    content="Please generate self-contained code to solve the following problem in a Python markdown block:" + f"\n```python\n{prompt.strip()}\n```",
                )],
                max_tokens=self.max_new_tokens,
                **kwargs,
            )

            outputs.append(ret.choices[0].message.content)

        return outputs

    def is_direct_completion(self) -> bool:
        return False


class AnthropicDecoder(DecoderBase, ABC):

    def __init__(self, name: str, **kwargs) -> None:
        super().__init__(name, **kwargs)
        self.client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_KEY"))

    def is_direct_completion(self) -> bool:
        return False


class AnthropicMessageDecoder(AnthropicDecoder):

    def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]:
        kwargs = {}
        if do_sample:
            assert self.temperature > 0, "Temperature must be positive for sampling"
            kwargs["top_p"] = 0.95
            kwargs["temperature"] = self.temperature
        else:
            self.temperature = 0

        batch_size = min(self.batch_size, num_samples)
        if not do_sample:
            assert batch_size == 1, "Greedy decoding only supports batch size of 1"

        outputs = []
        for _ in range(batch_size):
            message = anthropic_request.make_auto_request(
                client=self.client,
                model=self.name,
                messages=[{
                    "role": "user",
                    "content": "Please generate self-contained code to complete the following problem wrapped in a Python markdown block:" + f"\n```python\n{prompt.strip()}\n```\n",
                }],
                max_tokens=self.max_new_tokens,
                stop_sequences=["\n```\n", "\nif "],
                **kwargs,
            )
            outputs.append(message.content[0].text)

        return outputs


class GoogleGenAIDecoder(DecoderBase, ABC):

    def __init__(self, name: str, **kwargs) -> None:
        super().__init__(name, **kwargs)
        genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

    def is_direct_completion(self) -> bool:
        return False


class GeminiDecoder(GoogleGenAIDecoder):

    def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]:
        kwargs = {}
        if do_sample:
            assert self.temperature > 0, "Temperature must be positive for sampling"
            kwargs["top_p"] = 0.95
            kwargs["temperature"] = self.temperature
        else:
            self.temperature = 0

        batch_size = min(self.batch_size, num_samples)
        if not do_sample:
            assert batch_size == 1, "Greedy decoding only supports batch size of 1"

        genai_config = genai.GenerationConfig(
            max_output_tokens=self.max_new_tokens,
            **kwargs,
        )

        # Disable all safety filters so code generations are not blocked.
        safety_settings = [
            {
                "category": "HARM_CATEGORY_DANGEROUS",
                "threshold": "BLOCK_NONE",
            },
            {
                "category": "HARM_CATEGORY_HARASSMENT",
                "threshold": "BLOCK_NONE",
            },
            {
                "category": "HARM_CATEGORY_HATE_SPEECH",
                "threshold": "BLOCK_NONE",
            },
            {
                "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                "threshold": "BLOCK_NONE",
            },
            {
                "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                "threshold": "BLOCK_NONE",
            },
        ]

        model = genai.GenerativeModel(
            model_name=self.name,
            generation_config=genai_config,
            safety_settings=safety_settings,
        )

        outputs = []
        for _ in range(batch_size):
            # Retry until the API returns a candidate with text.
            while True:
                try:
                    response = model.generate_content(
                        "Please generate self-contained code to complete the following problem wrapped in a Python markdown block:"
                        + f"\n```python\n{prompt.strip()}\n```",
                        generation_config=genai_config,
                    )
                    output = response.candidates[0].content.parts[0].text
                    outputs.append(output)
                    break
                except Exception as e:
                    if "list index out of range" in str(e):
                        # The response was blocked or empty; append a dummy response.
                        outputs.append("NO_RESPONSE")
                        break
                    else:
                        print(e)
                        continue

        return outputs


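# Factory that maps a backend name to the matching decoder class.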
def make_model(
    model: str,
    backend: str,
    dataset: str = "bigcodebench",
    batch_size: int = 1,
    temperature: float = 0.0,
    tp=1,
    base_url=None,
    trust_remote_code=False,
    tokenizer_name=None,
    tokenizer_legacy=True,
):
    if backend == "vllm":
        return GeneralVllmDecoder(
            name=model,
            batch_size=batch_size,
            temperature=temperature,
            dataset=dataset,
            tp=tp,
            trust_remote_code=trust_remote_code,
            tokenizer_name=tokenizer_name,
            tokenizer_legacy=tokenizer_legacy,
        )
    elif backend == "hf":
        return GeneralHfTorchDecoder(
            name=model,
            batch_size=batch_size,
            temperature=temperature,
            dataset=dataset,
            trust_remote_code=trust_remote_code,
            tokenizer_name=tokenizer_name,
            tokenizer_legacy=tokenizer_legacy,
        )
    elif backend == "openai":
        return OpenAIChatDecoder(
            name=model,
            batch_size=batch_size,
            temperature=temperature,
            base_url=base_url,
        )
    elif backend == "mistral":
        return MistralChatDecoder(
            name=model,
            batch_size=batch_size,
            temperature=temperature,
        )
    elif backend == "anthropic":
        return AnthropicMessageDecoder(
            name=model,
            batch_size=batch_size,
            temperature=temperature,
        )
    elif backend == "google":
        return GeminiDecoder(
            name=model,
            batch_size=batch_size,
            temperature=temperature,
        )
    # Fail loudly instead of silently returning None for an unknown backend.
    raise ValueError(f"Unknown backend: {backend}")
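

# A minimal usage sketch (illustrative model name and prompt; assumes a GPU
# machine with vLLM installed):
#
#   model = make_model(
#       model="bigcode/starcoder2-15b-instruct-v0.1",
#       backend="vllm",
#       temperature=0.2,
#   )
#   samples = model.codegen(
#       ["Write a function that returns the n-th Fibonacci number."],
#       do_sample=True,
#       num_samples=1,
#   )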