model.py
import json
import os
from abc import ABC, abstractmethod
from typing import List, Optional
from warnings import warn

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams
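
# This module wraps two code-generation backends (vLLM and HuggingFace
# Transformers) behind the common DecoderBase interface defined below;
# make_model() at the bottom of the file is the factory that selects a backend.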

# Base stop strings shared by all decoders; subclasses extend this list.
EOS = [
    "<|endoftext|>",
    "<|endofmask|>",
    "</s>",
    "\nif __name__",
    "\ndef main(",
    "\nprint(",
]


def extra_eos_for_direct_completion(dataset) -> List[str]:
    if dataset.lower() == "bigcodebench":
        return ["\ndef ", "\nclass ", "\nimport ", "\nfrom ", "\nassert "]
    raise ValueError(f"Unknown dataset: {dataset}")


# some random words which serve as the splitter
_MAGIC_SPLITTER_ = "-[[]]-this-is-really-our-highest-priority-[[]]-"


def make_chat_prompt(
    prompt: str,
    tokenizer: AutoTokenizer,
    chat_mode: bool,
) -> str:
    """Wrap a raw task prompt in the tokenizer's chat template.

    The assistant turn is pre-filled up to a ```python fence and the rendered
    template is cut at _MAGIC_SPLITTER_, so the model continues with the body
    of the code block instead of a conversational preamble.
    """
    if not chat_mode:  # completion tasks use the raw prompt directly
        return prompt

    prompt = f"""\
Please provide a self-contained Python script that solves the following problem in a markdown code block:
```
{prompt.strip()}
```
"""
    response = f"""\
Below is a Python script with a self-contained function that solves the problem and passes corresponding tests:
```python
{_MAGIC_SPLITTER_}
```
"""
    prompt = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": response},
        ],
        tokenize=False,
    ).split(_MAGIC_SPLITTER_)[0]
    return prompt
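
# Example (hypothetical checkpoint, for illustration only):
#   tok = AutoTokenizer.from_pretrained("deepseek-ai/deepseek-coder-6.7b-instruct")
#   make_chat_prompt("def add(a, b):\n    ...", tok, chat_mode=True)
# returns the chat-formatted prompt up to and including the "```python" line.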


class DecoderBase(ABC):

    def __init__(
        self,
        name: str,
        batch_size: int = 1,
        temperature: float = 0.8,
        max_new_tokens: int = 1280,
        dtype: str = "bfloat16",  # default
        trust_remote_code: bool = True,
        tokenizer_name: Optional[str] = None,
        tokenizer_legacy: bool = False,
        chat_mode: bool = False,
    ) -> None:
        print(f"Initializing a decoder model: {name} ...")
        self.name = name
        self.batch_size = batch_size
        self.temperature = temperature
        # copy so subclass additions do not mutate the module-level EOS list
        self.eos = list(EOS)
        self.skip_special_tokens = False
        self.max_new_tokens = max_new_tokens
        self.dtype = dtype
        self.trust_remote_code = trust_remote_code
        self.tokenizer_name = tokenizer_name
        self.tokenizer_legacy = tokenizer_legacy
        self.chat_mode = chat_mode

    @abstractmethod
    def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]:
        pass

    @abstractmethod
    def is_direct_completion(self) -> bool:
        pass

    def __repr__(self) -> str:
        return self.name

    def __str__(self) -> str:
        return self.name


class VllmDecoder(DecoderBase):

    def __init__(self, name: str, dataset: str, tp: int, **kwargs) -> None:
        super().__init__(name, **kwargs)

        # Engine arguments for vLLM; the tokenizer is loaded separately below.
        engine_kwargs = {
            "tensor_parallel_size": int(os.getenv("VLLM_N_GPUS", tp)),
            "dtype": self.dtype,
            "trust_remote_code": self.trust_remote_code,
            "gpu_memory_utilization": 0.95,
            "enforce_eager": True,
            "distributed_executor_backend": "ray",
        }
        if self.tokenizer_name is None:
            self.tokenizer_name = self.name

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer_name,
            trust_remote_code=self.trust_remote_code,
            legacy=self.tokenizer_legacy,
        )
        if not self.chat_mode:
            self.eos += extra_eos_for_direct_completion(dataset)
        self.llm = LLM(model=name, max_model_len=2048, **engine_kwargs)
        self.llm.set_tokenizer(tokenizer=self.tokenizer)

    def is_direct_completion(self) -> bool:
        return not self.chat_mode

    def codegen(self, prompts: List[str], do_sample: bool = True, num_samples: int = 200) -> List[str]:
        if do_sample:
            assert self.temperature > 0, "Temperature must be greater than 0!"

        vllm_outputs = self.llm.generate(
            prompts,
            SamplingParams(
                temperature=self.temperature,
                max_tokens=self.max_new_tokens,
                top_p=0.95 if do_sample else 1.0,
                stop=self.eos,
            ),
            use_tqdm=True,
        )

        gen_strs = [x.outputs[0].text.replace("\t", "    ") for x in vllm_outputs]
        return gen_strs


class GeneralVllmDecoder(VllmDecoder):

    def __init__(self, name: str, **kwargs) -> None:
        super().__init__(name, **kwargs)
        self.eos += ["\n```\n", "```"]
        print(f"EOS strings: {self.eos}")

    def codegen(self, prompts: List[str], do_sample: bool = True, num_samples: int = 200) -> List[str]:
        chat_prompts = [
            make_chat_prompt(prompt, self.tokenizer, self.chat_mode)
            for prompt in prompts
        ]
        return VllmDecoder.codegen(self, chat_prompts, do_sample, num_samples)


class HfTorchDecoder(DecoderBase):

    def __init__(self, name: str, dataset: str, **kwargs):
        super().__init__(name=name, **kwargs)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        model_kwargs = {
            "device_map": "auto",
            "trust_remote_code": self.trust_remote_code,
            # map the dtype string (e.g. "bfloat16") to the torch dtype
            "torch_dtype": getattr(torch, self.dtype),
        }
        self.skip_special_tokens = True

        print(f"{model_kwargs = }", self.tokenizer_name)
        if self.tokenizer_name is None:
            self.tokenizer_name = self.name

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer_name,
            trust_remote_code=self.trust_remote_code,
            legacy=self.tokenizer_legacy,
        )

        if not self.chat_mode:
            self.eos += extra_eos_for_direct_completion(dataset)

        # device_map="auto" already places the weights, so no explicit .to() is needed
        self.model = AutoModelForCausalLM.from_pretrained(name, **model_kwargs)

    def is_direct_completion(self) -> bool:
        return not self.chat_mode

    @torch.inference_mode()
    def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]:
        if self.temperature == 0:
            assert not do_sample
            assert num_samples == 1

        input_tokens = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        kwargs = {}
        if do_sample:
            kwargs["top_p"] = 0.95
            kwargs["temperature"] = self.temperature

        outputs = self.model.generate(
            input_tokens,
            max_new_tokens=self.max_new_tokens,
            do_sample=do_sample,
            num_return_sequences=min(self.batch_size, num_samples),
            pad_token_id=self.tokenizer.eos_token_id,
            stop_strings=self.eos,
            tokenizer=self.tokenizer,
            **kwargs,
        )

        gen_strs = self.tokenizer.batch_decode(
            outputs[:, input_tokens.size(-1) :],
            skip_special_tokens=self.skip_special_tokens,
        )
        outputs = []
        # Truncate each sample at the earliest EOS string and normalize tabs.
        for output in gen_strs:
            min_index = len(output)
            for eos in self.eos:
                if eos in output:
                    min_index = min(min_index, output.index(eos))
            outputs.append(output[:min_index].replace("\t", "    "))
        return outputs


class GenenralHfTorchDecoder(HfTorchDecoder):

    def __init__(self, name: str, **kwargs):
        super().__init__(name=name, **kwargs)
        self.eos += ["\n```\n", "```"]
        print(f"EOS strings: {self.eos}")
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer_name if self.tokenizer_name else self.name,
            trust_remote_code=self.trust_remote_code,
            legacy=self.tokenizer_legacy,
        )

    def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]:
        prompt = make_chat_prompt(prompt, self.tokenizer, self.chat_mode)
        return HfTorchDecoder.codegen(self, prompt, do_sample, num_samples)


def make_model(
    model: str,
    backend: str,
    dataset: str = "bigcodebench",
    batch_size: int = 1,
    temperature: float = 0.0,
    tp=1,
    base_url=None,
    trust_remote_code=True,
    tokenizer_name=None,
    tokenizer_legacy=True,
    chat_mode=False,
):
    print(f"{chat_mode = }")
    if backend == "vllm":
        return GeneralVllmDecoder(
            name=model,
            batch_size=batch_size,
            temperature=temperature,
            dataset=dataset,
            tp=tp,
            trust_remote_code=trust_remote_code,
            tokenizer_name=tokenizer_name,
            tokenizer_legacy=tokenizer_legacy,
            chat_mode=chat_mode,
        )
    elif backend == "hf":
        return GenenralHfTorchDecoder(
            name=model,
            batch_size=batch_size,
            temperature=temperature,
            dataset=dataset,
            trust_remote_code=trust_remote_code,
            tokenizer_name=tokenizer_name,
            tokenizer_legacy=tokenizer_legacy,
            chat_mode=chat_mode,
        )
    raise ValueError(f"Unknown backend: {backend}")
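

# Minimal usage sketch (hypothetical checkpoint and settings; requires a GPU
# with vLLM installed, and is only meant to illustrate the make_model API):
if __name__ == "__main__":
    decoder = make_model(
        model="deepseek-ai/deepseek-coder-6.7b-instruct",  # example checkpoint
        backend="vllm",
        dataset="bigcodebench",
        temperature=0.2,
        tp=1,
        chat_mode=True,
    )
    task = "Write a function fibonacci(n) that returns the n-th Fibonacci number."
    # One sampled completion per prompt; prompts are wrapped by make_chat_prompt.
    print(decoder.codegen([task], do_sample=True, num_samples=1)[0])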