training.py
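"""Fine-tune a causal language model on the MMLU auxiliary_train split using Liger Kernel.

The model is loaded through AutoLigerKernelForCausalLM and trained with PyTorch Lightning.
Each MMLU example is formatted as a multiple-choice prompt, and trl's
DataCollatorForCompletionOnlyLM restricts the loss to the answer tokens. The Trainer strategy
can be "auto", "fsdp", "deepspeed", or "ddp"; see the Args dataclass and train() below.
"""
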
import argparse
import math
import os
from dataclasses import _MISSING_TYPE, dataclass

import datasets
import lightning.pytorch as pl
import torch
import transformers
from lightning.pytorch.strategies import DeepSpeedStrategy, FSDPStrategy
from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision
from torch.utils.data import DataLoader
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
from trl import DataCollatorForCompletionOnlyLM

from liger_kernel.transformers import AutoLigerKernelForCausalLM
from liger_kernel.utils import infer_device

_RETAIN_COLUMNS = {"input_ids", "attention_mask", "labels"}
QUESTION = "<Question>"
CHOICES = "<Choices>"


@dataclass
class Args:
    model: str = "Qwen/Qwen2-0.5B-Instruct"
    data: str = "cais/mmlu"
    output_dir: str = "mmlu_finetuning"
    max_length: int = 2048
    # for the Llama 3 8B model, DeepSpeed OOMs with batch size 16 on 8xA100 80GB, and batch size 8 OOMs on 8xA100 40GB
    batch_size: int = 4
    lr: float = 6e-6
    weight_decay: float = 0.05
    warmup_ratio: float = 0.1
    seed: int = 42
    strategy: str = "auto"
    num_gpu: int = None


def warmup_cosine_schedule(warmup_steps, total_steps, min_lr=0):
    def lr_lambda(current_step):
        if current_step < warmup_steps:
            # Linear warmup
            return float(current_step) / float(max(1, warmup_steps))
        else:
            # Cosine annealing
            progress = float(current_step - warmup_steps) / float(
                max(1, total_steps - warmup_steps)
            )
            return max(min_lr, 0.5 * (1 + math.cos(math.pi * progress)))

    return lr_lambda
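
# Quick sanity check of the schedule shape (illustrative values only, not used by the script):
#   lr_lambda = warmup_cosine_schedule(warmup_steps=100, total_steps=1000)
#   lr_lambda(0)    -> 0.0  (start of linear warmup)
#   lr_lambda(100)  -> 1.0  (warmup finished, full base lr)
#   lr_lambda(1000) -> 0.0  (cosine-annealed back down to min_lr)
# LambdaLR multiplies the optimizer's base lr by these factors at every step.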


def parse_args() -> Args:
    parser = argparse.ArgumentParser()
    for k, v in Args.__dataclass_fields__.items():
        parser.add_argument(f"--{k}", type=v.type, default=v.default)
    parsed = parser.parse_args()
    return Args(
        **{k: v for k, v in vars(parsed).items() if not isinstance(v, _MISSING_TYPE)}
    )


class LanguageModel(pl.LightningModule):
    def __init__(self, args: Args, tokenizer):
        super().__init__()
        self.args = args
        self.tokenizer = tokenizer
        self.model = None

    def configure_model(self):
        # https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/fsdp.html#speed-up-model-initialization
        if self.model is not None:
            return
        self.model = AutoLigerKernelForCausalLM.from_pretrained(
            self.args.model, use_cache=False, ignore_mismatched_sizes=True
        )
        if self.args.strategy == "deepspeed":
            self.model.train()
            self.model.gradient_checkpointing_enable()

    def forward(self, input_ids, attention_mask, labels=None, **kwargs):
        return self.model(
            input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs
        )

    def training_step(self, batch):
        outputs = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
        )
        loss = outputs.loss
        self.log_dict(
            {"train_loss": loss},
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            rank_zero_only=True,
            sync_dist=False,
        )
        return loss

    def validation_step(self, batch):
        outputs = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
        )
        loss = outputs.loss
        self.log_dict(
            {"val_loss": loss},
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            rank_zero_only=True,
            sync_dist=True,
        )
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.args.lr,
            weight_decay=self.args.weight_decay,
            fused=True,
        )
        lr_lambda = warmup_cosine_schedule(
            warmup_steps=self.trainer.estimated_stepping_batches
            * self.args.warmup_ratio,
            total_steps=self.trainer.estimated_stepping_batches,
            min_lr=0,
        )
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": lr_scheduler, "interval": "step"},
        }


class DataModule(pl.LightningDataModule):
    def __init__(self, tokenizer, args: Args):
        super().__init__()
        self.args = args
        self.tokenizer = tokenizer
        self.response_template_str = " <Answer>"
        response_prompt = tokenizer.encode(
            f"{self.response_template_str}", add_special_tokens=False
        )
        self.collator = DataCollatorForCompletionOnlyLM(
            tokenizer=tokenizer,
            response_template=response_prompt,
            pad_to_multiple_of=16,
        )

    def formatting_func(self, example):
        output_texts = []
        for i in range(len(example["question"])):
            choices = ""
            for j in range(len(example["choices"][i])):
                choices += f"{j+1}. {example['choices'][i][j]}; "
            s = "Below is a question and multiple choice answers, choices separated by a semicolon. Please select the best answer for the question. "
            s += f"{QUESTION}{example['question'][i]} "
            s += f"{CHOICES}{choices} "
            s += f"{self.response_template_str}{example['answer'][i]}"
            output_texts.append(s)
        return output_texts

    def tokenize(self, example):
        outputs = self.tokenizer(
            self.formatting_func(example),
            truncation=True,
            padding=False,
            max_length=self.args.max_length,
        )
        return {
            "input_ids": outputs["input_ids"],
            "attention_mask": outputs["attention_mask"],
        }

    def setup(self, stage) -> None:
        dataset = datasets.load_dataset(self.args.data, "auxiliary_train")
        flattened_data = [
            {
                "answer": x["train"]["answer"],
                "choices": x["train"]["choices"],
                "question": x["train"]["question"],
                "subject": x["train"]["subject"],
            }
            for x in dataset["train"]
        ]
        dataset = datasets.Dataset.from_list(flattened_data)
        dataset = dataset.train_test_split(test_size=4096, seed=self.args.seed)
        train_dataset, val_dataset = dataset["train"], dataset["test"]
        self.train_dataset = train_dataset.map(
            self.tokenize,
            remove_columns=list(set(train_dataset.column_names) - _RETAIN_COLUMNS),
            batched=True,
            batch_size=1,
            num_proc=4,
        )
        self.val_dataset = val_dataset.map(
            self.tokenize,
            remove_columns=list(set(val_dataset.column_names) - _RETAIN_COLUMNS),
            batched=True,
            batch_size=1,
            num_proc=4,
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.args.batch_size,
            collate_fn=self.collator,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.args.batch_size,
            collate_fn=self.collator,
        )
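
# For reference, each string produced by DataModule.formatting_func above has the shape
#   "Below is a question and multiple choice answers, ... <Question>{question} "
#   "<Choices>1. {choice_1}; 2. {choice_2}; ...  <Answer>{answer}"
# where the placeholders are illustrative and {answer} is the dataset's raw answer field.
# DataCollatorForCompletionOnlyLM then sets the labels of every token up to and including the
# " <Answer>" response template to -100, so the loss is computed only on the answer tokens.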


def train():
    args = parse_args()
    pl.seed_everything(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    if "Meta-Llama-3-8B" in args.model:
        layers = {LlamaDecoderLayer}
    elif "Qwen2" in args.model:
        layers = {Qwen2DecoderLayer}
    else:
        layers = {}
        raise Warning(
            f"Unimplemented layer wrap policy for {args.model} in this example"
        )

    if args.strategy == "fsdp":
        strategy = FSDPStrategy(
            auto_wrap_policy=layers,
            sharding_strategy="FULL_SHARD",
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            sync_module_states=True,
            activation_checkpointing_policy=layers,
            mixed_precision=MixedPrecision(
                param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16
            ),
            forward_prefetch=True,
        )
        precision = None
    elif args.strategy == "deepspeed":
        strategy = DeepSpeedStrategy(stage=3)
        precision = "bf16-mixed"
    elif args.strategy == "ddp":
        strategy = "ddp"
        precision = "bf16-true"
    else:
        strategy = "auto"
        precision = "bf16-true"

    device = infer_device()
    trainer = pl.Trainer(
        accelerator=device,
        strategy=strategy,
        devices=(
            getattr(torch, device).device_count()
            if args.num_gpu is None
            else args.num_gpu
        ),
        default_root_dir=args.output_dir,
        log_every_n_steps=1,
        max_epochs=1,
        precision=precision,
    )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.model, padding_side="left", truncation_side="left"
    )
    tokenizer.pad_token = tokenizer.eos_token
    data_module = DataModule(
        tokenizer=tokenizer,
        args=args,
    )
    model = LanguageModel(args=args, tokenizer=tokenizer)
    trainer.fit(model, datamodule=data_module)


if __name__ == "__main__":
    train()
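
# Example invocations (illustrative only; the flag names come from the Args dataclass above, but
# the model ids, GPU counts, and batch sizes here are assumptions rather than tested settings):
#   python training.py --model Qwen/Qwen2-0.5B-Instruct --strategy auto --batch_size 4
#   python training.py --model meta-llama/Meta-Llama-3-8B --strategy fsdp --num_gpu 8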