# examples/lightning/training.py
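# Fine-tune a causal LM on MMLU's auxiliary_train split with PyTorch Lightning,
# using Liger Kernel's patched model classes. Example invocation (flags are
# illustrative; adjust model, strategy, and device count to your setup):
#   python training.py --model Qwen/Qwen2-0.5B-Instruct --strategy fsdp --batch_size 4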
import argparse
import math
import os
from dataclasses import _MISSING_TYPE, dataclass

import datasets
import lightning.pytorch as pl
import torch
import transformers
from lightning.pytorch.strategies import DeepSpeedStrategy, FSDPStrategy
from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision
from torch.utils.data import DataLoader
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
from trl import DataCollatorForCompletionOnlyLM

from liger_kernel.transformers import AutoLigerKernelForCausalLM
from liger_kernel.utils import infer_device

_RETAIN_COLUMNS = {"input_ids", "attention_mask", "labels"}
QUESTION = "<Question>"
CHOICES = "<Choices>"


@dataclass
class Args:
    model: str = "Qwen/Qwen2-0.5B-Instruct"
    data: str = "cais/mmlu"
    output_dir: str = "mmlu_finetuning"
    max_length: int = 2048
    # For the Llama 3 8B model, DeepSpeed will OOM with batch size 16 on 8x A100 80GB,
    # and batch size 8 will OOM on 8x A100 40GB.
    batch_size: int = 4
    lr: float = 6e-6
    weight_decay: float = 0.05
    warmup_ratio: float = 0.1
    seed: int = 42
    strategy: str = "auto"
    num_gpu: int = None  # defaults to all visible devices when left unset


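# Learning-rate multiplier schedule: linear warmup from 0 to 1 over `warmup_steps`,
# then cosine decay from 1 toward `min_lr` over the remaining steps. For example,
# with warmup_steps=100 and total_steps=1000, step 50 returns 0.5 (halfway through
# warmup) and step 550 returns 0.5 * (1 + cos(0.5 * pi)) = 0.5 (halfway through decay).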
def warmup_cosine_schedule(warmup_steps, total_steps, min_lr=0):
    def lr_lambda(current_step):
        if current_step < warmup_steps:
            # Linear warmup
            return float(current_step) / float(max(1, warmup_steps))
        else:
            # Cosine annealing
            progress = float(current_step - warmup_steps) / float(
                max(1, total_steps - warmup_steps)
            )
            return max(min_lr, 0.5 * (1 + math.cos(math.pi * progress)))

    return lr_lambda


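# Build CLI flags from the Args dataclass so every field above can be overridden
# on the command line; values left at the dataclass MISSING sentinel are dropped.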
def parse_args() -> Args:
    parser = argparse.ArgumentParser()
    for k, v in Args.__dataclass_fields__.items():
        parser.add_argument(f"--{k}", type=v.type, default=v.default)
    parsed = parser.parse_args()
    return Args(
        **{k: v for k, v in vars(parsed).items() if not isinstance(v, _MISSING_TYPE)}
    )


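# LightningModule wrapping a Liger-patched Hugging Face causal LM. Model weights
# are created lazily in configure_model() so sharded strategies (FSDP/DeepSpeed)
# can set themselves up first.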
class LanguageModel(pl.LightningModule):
    def __init__(self, args: Args, tokenizer):
        super().__init__()
        self.args = args
        self.tokenizer = tokenizer
        self.model = None

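    # use_cache must be disabled when training with gradient checkpointing; the
    # explicit checkpointing call below only applies to the DeepSpeed path, since
    # FSDP activation checkpointing is configured via the strategy in train().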
    def configure_model(self):
        # https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/fsdp.html#speed-up-model-initialization
        if self.model is not None:
            return
        self.model = AutoLigerKernelForCausalLM.from_pretrained(
            self.args.model, use_cache=False, ignore_mismatched_sizes=True
        )
        if self.args.strategy == "deepspeed":
            self.model.train()
            self.model.gradient_checkpointing_enable()

    def forward(self, input_ids, attention_mask, labels=None, **kwargs):
        return self.model(
            input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs
        )

    def training_step(self, batch):
        outputs = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
        )
        loss = outputs.loss
        self.log_dict(
            {"train_loss": loss},
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            rank_zero_only=True,
            sync_dist=False,
        )
        return loss

    def validation_step(self, batch):
        outputs = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
        )
        loss = outputs.loss
        self.log_dict(
            {"val_loss": loss},
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            rank_zero_only=True,
            sync_dist=True,
        )
        return loss

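    # Fused AdamW with a per-step LambdaLR: warmup covers `warmup_ratio` of the
    # estimated total optimizer steps, then the multiplier follows the cosine decay.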
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.args.lr,
            weight_decay=self.args.weight_decay,
            fused=True,
        )
        lr_lambda = warmup_cosine_schedule(
            warmup_steps=self.trainer.estimated_stepping_batches
            * self.args.warmup_ratio,
            total_steps=self.trainer.estimated_stepping_batches,
            min_lr=0,
        )
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": lr_scheduler, "interval": "step"},
        }


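# LightningDataModule for MMLU: formats each record as a single prompt string
# (instruction + question + numbered choices + " <Answer>" + the answer label) and
# relies on DataCollatorForCompletionOnlyLM to mask everything before the response
# template, so loss is computed only on the answer tokens.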
class DataModule(pl.LightningDataModule):
    def __init__(self, tokenizer, args: Args):
        super().__init__()
        self.args = args
        self.tokenizer = tokenizer
        self.response_template_str = " <Answer>"
        response_prompt = tokenizer.encode(
            f"{self.response_template_str}", add_special_tokens=False
        )
        self.collator = DataCollatorForCompletionOnlyLM(
            tokenizer=tokenizer,
            response_template=response_prompt,
            pad_to_multiple_of=16,
        )

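    # Turn a batch of raw MMLU records into prompt strings:
    # instruction + "<Question>..." + "<Choices>1. ...; 2. ...; " + " <Answer>" + answer label.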
    def formatting_func(self, example):
        output_texts = []
        for i in range(len(example["question"])):
            choices = ""
            for j in range(len(example["choices"][i])):
                choices += f"{j+1}. {example['choices'][i][j]}; "
            s = "Below is a question and multiple choice answers, choices separated by a semicolon. Please select the best answer for the question. "
            s += f"{QUESTION}{example['question'][i]} "
            s += f"{CHOICES}{choices} "
            s += f"{self.response_template_str}{example['answer'][i]}"
            output_texts.append(s)
        return output_texts

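    # Tokenize the formatted prompts, truncating to max_length; labels are created
    # later by the collator, so only input_ids and attention_mask are kept here.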
    def tokenize(self, example):
        outputs = self.tokenizer(
            self.formatting_func(example),
            truncation=True,
            padding=False,
            max_length=self.args.max_length,
        )
        return {
            "input_ids": outputs["input_ids"],
            "attention_mask": outputs["attention_mask"],
        }

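    # Load MMLU's auxiliary_train split, flatten the nested "train" records, and
    # hold out 4096 examples for validation.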
    def setup(self, stage) -> None:
        dataset = datasets.load_dataset(self.args.data, "auxiliary_train")
        flattened_data = [
            {
                "answer": x["train"]["answer"],
                "choices": x["train"]["choices"],
                "question": x["train"]["question"],
                "subject": x["train"]["subject"],
            }
            for x in dataset["train"]
        ]
        dataset = datasets.Dataset.from_list(flattened_data)
        dataset = dataset.train_test_split(test_size=4096, seed=self.args.seed)
        train_dataset, val_dataset = dataset["train"], dataset["test"]
        self.train_dataset = train_dataset.map(
            self.tokenize,
            remove_columns=list(set(train_dataset.column_names) - _RETAIN_COLUMNS),
            batched=True,
            batch_size=1,
            num_proc=4,
        )
        self.val_dataset = val_dataset.map(
            self.tokenize,
            remove_columns=list(set(val_dataset.column_names) - _RETAIN_COLUMNS),
            batched=True,
            batch_size=1,
            num_proc=4,
        )

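    # The collator pads each batch dynamically (to a multiple of 16) and builds the
    # labels, so the dataloaders only need to supply tokenized examples.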
    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.args.batch_size,
            collate_fn=self.collator,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.args.batch_size,
            collate_fn=self.collator,
        )


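# Entry point: pick a transformer-layer wrap policy for the chosen model family,
# configure the distributed strategy and precision, then run one epoch of fine-tuning.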
def train():
    args = parse_args()
    pl.seed_everything(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    if "Meta-Llama-3-8B" in args.model:
        layers = {LlamaDecoderLayer}
    elif "Qwen2" in args.model:
        layers = {Qwen2DecoderLayer}
    else:
        raise NotImplementedError(
            f"Unimplemented layer wrap policy for {args.model} in this example"
        )

    if args.strategy == "fsdp":
        strategy = FSDPStrategy(
            auto_wrap_policy=layers,
            sharding_strategy="FULL_SHARD",
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            sync_module_states=True,
            activation_checkpointing_policy=layers,
            mixed_precision=MixedPrecision(
                param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16
            ),
            forward_prefetch=True,
        )
        # FSDP handles bf16 casting via MixedPrecision above, so the Trainer's
        # precision flag is left unset.
        precision = None
    elif args.strategy == "deepspeed":
        strategy = DeepSpeedStrategy(stage=3)
        precision = "bf16-mixed"
    elif args.strategy == "ddp":
        strategy = "ddp"
        precision = "bf16-true"
    else:
        strategy = "auto"
        precision = "bf16-true"

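    # The Trainer runs a single epoch; devices defaults to every visible accelerator
    # unless --num_gpu is set explicitly.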
    device = infer_device()
    trainer = pl.Trainer(
        accelerator=device,
        strategy=strategy,
        devices=(
            getattr(torch, device).device_count()
            if args.num_gpu is None
            else args.num_gpu
        ),
        default_root_dir=args.output_dir,
        log_every_n_steps=1,
        max_epochs=1,
        precision=precision,
    )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.model, padding_side="left", truncation_side="left"
    )
    tokenizer.pad_token = tokenizer.eos_token
    data_module = DataModule(
        tokenizer=tokenizer,
        args=args,
    )
    model = LanguageModel(args=args, tokenizer=tokenizer)
    trainer.fit(model, datamodule=data_module)


if __name__ == "__main__":
    train()