model.py
import os
from abc import ABC, abstractmethod
from typing import List, Optional

import torch
from stop_sequencer import StopSequencer
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams

# Generic stop strings shared by both backends: tokenizer end-of-sequence
# markers plus Python patterns that signal the model has moved past the
# target function.
EOS = [
    "<|endoftext|>",
    "<|endofmask|>",
    "</s>",
    "\nif __name__",
    "\ndef main(",
    "\nprint(",
]


def extra_eos_for_direct_completion(dataset: str) -> List[str]:
    print(f"eos: {dataset = }")
    if dataset.lower() == "humaneval":
        return ["\ndef ", "\nclass ", "\nimport ", "\nfrom ", "\nassert "]
    elif dataset.lower() == "mbpp":
        return ['\n"""', "\nassert"]
    raise ValueError(f"Unknown dataset: {dataset}")

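# Illustrative behavior: in direct-completion mode on HumanEval, generation
# stops as soon as the model starts a new top-level statement, e.g.
#
#   >>> extra_eos_for_direct_completion("humaneval")
#   ['\ndef ', '\nclass ', '\nimport ', '\nfrom ', '\nassert ']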

# A unique marker used to split the rendered chat template: everything after
# the marker is discarded, so the final prompt ends inside the assistant's
# open ```python code block.
_MAGIC_SPLITTER_ = "-[[]]-this-is-really-our-highest-priority-[[]]-"


def make_chat_prompt(
    task_prompt: str,
    instruction_prefix: str,
    response_prefix: str,
    tokenizer: AutoTokenizer,
    chat_mode: bool = False,
) -> str:
    if not chat_mode:
        return task_prompt

    assert instruction_prefix is not None, "Instruction prefix is required!"
    assert response_prefix is not None, "Response prefix is required!"

    task_prompt = f"""\
{instruction_prefix}
```
{task_prompt.strip()}
```
"""
    response = f"""\
{response_prefix}
```python
{_MAGIC_SPLITTER_}
```
"""
    # Render a full user/assistant exchange, then truncate at the splitter so
    # the prompt ends right after "```python" in the assistant turn.
    task_prompt = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": task_prompt},
            {"role": "assistant", "content": response},
        ],
        tokenize=False,
    ).split(_MAGIC_SPLITTER_)[0]
    return task_prompt
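
# Sketch of the result (illustrative; the exact markup depends on the
# tokenizer's chat template):
#
#   <|user|>{instruction_prefix}
#   ```
#   {task_prompt}
#   ```
#   <|assistant|>{response_prefix}
#   ```python
#
# The prompt stops inside the open fence, so the model's continuation is the
# code body itself.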


class DecoderBase(ABC):
    """Common interface for the code-generation backends below."""

    def __init__(
        self,
        name: str,
        batch_size: int = 1,
        temperature: float = 0.8,
        max_new_tokens: int = 768,
        dtype: str = "bfloat16",  # default
        trust_remote_code: bool = True,
        instruction_prefix: Optional[str] = None,
        response_prefix: Optional[str] = None,
        chat_mode: bool = False,
    ) -> None:
        print("Initializing a decoder model: {} ...".format(name))
        self.name = name
        self.batch_size = batch_size
        self.temperature = temperature
        # Copy so that `self.eos += ...` in subclasses does not mutate the
        # module-level EOS list shared by other instances.
        self.eos = list(EOS)
        self.skip_special_tokens = False
        self.max_new_tokens = max_new_tokens
        self.dtype = dtype
        self.trust_remote_code = trust_remote_code
        self.instruction_prefix = instruction_prefix
        self.response_prefix = response_prefix
        self.chat_mode = chat_mode

        print(f"{self.chat_mode = }")

    @abstractmethod
    def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]:
        pass

    @abstractmethod
    def is_direct_completion(self) -> bool:
        pass

    def __repr__(self) -> str:
        return self.name

    def __str__(self) -> str:
        return self.name


class VllmDecoder(DecoderBase):
    def __init__(self, name: str, dataset: str, tp: int, **kwargs) -> None:
        super().__init__(name, **kwargs)

        vllm_kwargs = {
            "tensor_parallel_size": int(os.getenv("VLLM_N_GPUS", tp)),
            "dtype": self.dtype,
            "trust_remote_code": self.trust_remote_code,
            "gpu_memory_utilization": 0.95,
            "enforce_eager": True,
            "distributed_executor_backend": "ray",
        }

        self.tokenizer = AutoTokenizer.from_pretrained(self.name)
        if not self.chat_mode:
            self.eos += extra_eos_for_direct_completion(dataset)
        self.llm = LLM(model=name, max_model_len=2048, **vllm_kwargs)

    def is_direct_completion(self) -> bool:
        return not self.chat_mode

    def codegen(self, prompts: List[str], do_sample: bool = True, num_samples: int = 200, use_tqdm: bool = True) -> List[str]:
        if do_sample:
            assert self.temperature > 0, "Temperature must be greater than 0!"

        # Note: num_samples is accepted for interface compatibility but is
        # unused here; one completion is generated per prompt.
        vllm_outputs = self.llm.generate(
            prompts,
            SamplingParams(
                temperature=self.temperature if do_sample else 0.0,  # 0.0 -> greedy
                max_tokens=self.max_new_tokens,
                top_p=0.95 if do_sample else 1.0,
                stop=self.eos,
            ),
            use_tqdm=use_tqdm,
        )

        gen_strs = [x.outputs[0].text.replace("\t", "    ") for x in vllm_outputs]
        return gen_strs
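
# Illustrative use (the model id is a placeholder):
#
#   dec = VllmDecoder("bigcode/starcoder2-7b", dataset="humaneval", tp=1,
#                     temperature=0.0)
#   outs = dec.codegen(["def fib(n):\n"], do_sample=False)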


class GeneralVllmDecoder(VllmDecoder):
    def __init__(self, name: str, no_batching: bool, **kwargs) -> None:
        super().__init__(name, **kwargs)

        self.no_batching = no_batching

        # In chat mode, completions are generated inside a ```python fence,
        # so closing fences also act as stop strings.
        self.eos += ["\n```\n", "```"]
        print(f"EOS strings: {self.eos}")

    def codegen(self, prompts: List[str], do_sample: bool = True, num_samples: int = 200) -> List[str]:
        if self.no_batching:
            print(f"{self.no_batching = }, disabling continuous batching. This may be slow.")
            generated = []
            for prompt in tqdm(prompts):
                chat_prompt = make_chat_prompt(
                    prompt,
                    instruction_prefix=self.instruction_prefix,
                    response_prefix=self.response_prefix,
                    tokenizer=self.tokenizer,
                    chat_mode=self.chat_mode,
                )
                generated.append(super().codegen([chat_prompt], do_sample, num_samples, use_tqdm=False)[0])
        else:
            print(f"{self.no_batching = }, enabling continuous batching.")
            chat_prompts = [
                make_chat_prompt(
                    prompt,
                    instruction_prefix=self.instruction_prefix,
                    response_prefix=self.response_prefix,
                    tokenizer=self.tokenizer,
                    chat_mode=self.chat_mode,
                )
                for prompt in prompts
            ]
            generated = super().codegen(chat_prompts, do_sample, num_samples)

        return generated
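
# no_batching=True generates one prompt at a time and bypasses vLLM's
# continuous-batching scheduler at the cost of throughput; the default path
# submits all chat prompts in a single batched call.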


class HfTorchDecoder(DecoderBase):
    def __init__(self, name: str, dataset: str, **kwargs):
        super().__init__(name=name, **kwargs)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        hf_kwargs = {
            "device_map": "auto",
            "trust_remote_code": self.trust_remote_code,
            # map the dtype string (e.g. "bfloat16") to the torch dtype
            "torch_dtype": getattr(torch, self.dtype),
        }
        self.skip_special_tokens = True

        print(f"{hf_kwargs = }")

        self.tokenizer = AutoTokenizer.from_pretrained(name)
        if not self.chat_mode:
            self.eos += extra_eos_for_direct_completion(dataset)

        self.model = AutoModelForCausalLM.from_pretrained(name, **hf_kwargs)
        # device_map="auto" already places the model, so on a single device
        # this move is a no-op.
        self.model = self.model.to(self.device)

    def is_direct_completion(self) -> bool:
        return not self.chat_mode

    @torch.inference_mode()
    def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]:
        if self.temperature == 0:
            assert not do_sample
            assert num_samples == 1

        input_tokens = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        kwargs = {}
        if do_sample:
            kwargs["top_p"] = 0.95
            kwargs["temperature"] = self.temperature

        # StopSequencer wraps the model so that generation halts once any of
        # the registered stop texts appears in the decoded output.
        stop_sequencer = StopSequencer(
            self.model,
            model_type="causal",  # or seq2seq
            tokenizer=self.tokenizer,
        )

        model = stop_sequencer.register_stop_texts(
            stop_texts=self.eos,
            input_length=input_tokens.size(-1),
        )

        outputs = model.generate(
            input_tokens,
            max_new_tokens=self.max_new_tokens,
            do_sample=do_sample,
            num_return_sequences=min(self.batch_size, num_samples),
            pad_token_id=self.tokenizer.eos_token_id,
            **kwargs,
        )

        gen_strs = self.tokenizer.batch_decode(
            outputs[:, input_tokens.size(-1) :],
            skip_special_tokens=self.skip_special_tokens,
        )
        outputs = []
        # Truncate each generation at the earliest stop string.
        for output in gen_strs:
            min_index = len(output)
            for eos in self.eos:
                if eos in output:
                    min_index = min(min_index, output.index(eos))
            outputs.append(output[:min_index].replace("\t", "    "))
        return outputs
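
# For example, with "\nif __name__" among the stop strings, a raw generation
# of "    return a + b\nif __name__ == '__main__':" is truncated to
# "    return a + b".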


class GeneralHfTorchDecoder(HfTorchDecoder):
    def __init__(self, name: str, **kwargs):
        super().__init__(name=name, **kwargs)
        # As with the vLLM decoder, closing code fences act as stop strings.
        self.eos += ["\n```\n", "```"]
        print(f"EOS strings: {self.eos}")

    def codegen(self, prompt: str, do_sample: bool = True, num_samples: int = 200) -> List[str]:
        prompt = make_chat_prompt(
            prompt,
            self.instruction_prefix,
            self.response_prefix,
            self.tokenizer,
            chat_mode=self.chat_mode,
        )
        return super().codegen(prompt, do_sample, num_samples)


def make_model(
    model: str,
    dataset: str,
    backend: str = "vllm",
    batch_size: int = 1,
    temperature: float = 0.0,
    tp: int = 1,
    instruction_prefix: Optional[str] = None,
    response_prefix: Optional[str] = None,
    no_batching: bool = False,  # no_batching -> disable continuous batching in vLLM
    chat_mode: bool = False,
):
    if backend == "vllm":
        return GeneralVllmDecoder(
            name=model,
            batch_size=batch_size,
            temperature=temperature,
            dataset=dataset,
            tp=tp,
            instruction_prefix=instruction_prefix,
            response_prefix=response_prefix,
            no_batching=no_batching,
            chat_mode=chat_mode,
        )
    elif backend == "hf":
        return GeneralHfTorchDecoder(
            name=model,
            batch_size=batch_size,
            temperature=temperature,
            dataset=dataset,
            instruction_prefix=instruction_prefix,
            response_prefix=response_prefix,
            chat_mode=chat_mode,
        )
    raise ValueError(f"Unknown backend: {backend}")
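

# Illustrative use of the factory (model id and prefixes are placeholders;
# task_prompts is a list of task descriptions):
#
#   gen = make_model("deepseek-ai/deepseek-coder-6.7b-instruct",
#                    dataset="humaneval", backend="vllm",
#                    temperature=0.0, chat_mode=True,
#                    instruction_prefix="Please complete the following task:",
#                    response_prefix="Here is the solution:")
#   completions = gen.codegen(task_prompts, do_sample=False, num_samples=1)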