# training.py
"""Fine-tune a causal LM on Alpaca-style data, optionally with Liger kernels."""

from dataclasses import dataclass

import datasets
import torch
import transformers
from callback import EfficiencyCallback
from trl import DataCollatorForCompletionOnlyLM, SFTTrainer

from liger_kernel.transformers import AutoLigerKernelForCausalLM


@dataclass
class CustomArguments:
    """Extra CLI options parsed alongside ``transformers.TrainingArguments``."""

    model_name: str = "meta-llama/Meta-Llama-3-8B"
    dataset: str = "tatsu-lab/alpaca"
    max_seq_length: int = 512
    use_liger: bool = False


def formatting_prompts_func(example):
    """Return the raw prompt text for one dataset row."""
    return example["text"]


def train():
    """Parse CLI args, prepare data and model, and run SFT training."""
    arg_parser = transformers.HfArgumentParser(
        (transformers.TrainingArguments, CustomArguments)
    )
    training_args, custom_args = arg_parser.parse_args_into_dataclasses()

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        custom_args.model_name,
        padding_side="left",
        truncation_side="left",
    )
    # The base model ships without a pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token

    # 90/10 split of the dataset's "train" portion for train/eval.
    split = datasets.load_dataset(custom_args.dataset)["train"].train_test_split(
        test_size=0.1
    )

    # Mask everything before the response marker so loss covers completions only.
    response_ids = tokenizer.encode("### Response:\n", add_special_tokens=False)
    data_collator = DataCollatorForCompletionOnlyLM(
        tokenizer=tokenizer,
        response_template=response_ids,
        pad_to_multiple_of=16,
    )

    # Same kwargs either way; only the loader class differs. Kwargs such as
    # cross_entropy=True / fused_linear_cross_entropy=False may be passed to the
    # Liger loader to override the default apply_liger_kernel_to_* settings.
    model_loader = (
        AutoLigerKernelForCausalLM
        if custom_args.use_liger
        else transformers.AutoModelForCausalLM
    )
    model = model_loader.from_pretrained(
        custom_args.model_name,
        trust_remote_code=True,
        use_cache=False,
        torch_dtype=torch.bfloat16,
    )

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        max_seq_length=custom_args.max_seq_length,
        train_dataset=split["train"],
        eval_dataset=split["test"],
        formatting_func=formatting_prompts_func,
        callbacks=[EfficiencyCallback()],
    )
    trainer.train()


if __name__ == "__main__":
    train()