examples/huggingface/training.py
from dataclasses import dataclass

import datasets
import torch
import transformers
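# EfficiencyCallback lives in callback.py alongside this script; it reports
# training-efficiency metrics (e.g. throughput and memory use) during the run.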
from callback import EfficiencyCallback
from trl import DataCollatorForCompletionOnlyLM, SFTTrainer

from liger_kernel.transformers import AutoLigerKernelForCausalLM


@dataclass
class CustomArguments:
    model_name: str = "meta-llama/Meta-Llama-3-8B"
    dataset: str = "tatsu-lab/alpaca"
    max_seq_length: int = 512
    use_liger: bool = False


def formatting_prompts_func(example):
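    # The "text" column of tatsu-lab/alpaca already holds the fully formatted
    # prompt (instruction/input plus "### Response:" and the output), so no
    # extra templating is needed here.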
    return example["text"]


def train():
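    # HfArgumentParser exposes both dataclasses as a single CLI, so any
    # transformers.TrainingArguments flag can be combined with the custom
    # flags above, e.g. (illustrative invocation):
    #   python training.py --output_dir=./out --bf16 True --use_liger True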
    parser = transformers.HfArgumentParser(
        (transformers.TrainingArguments, CustomArguments)
    )
    training_args, custom_args = parser.parse_args_into_dataclasses()
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        custom_args.model_name,
        padding_side="left",
        truncation_side="left",
    )
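    # Llama checkpoints ship without a pad token, so reuse EOS for padding.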
    tokenizer.pad_token = tokenizer.eos_token

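    # Carve a 10% evaluation split out of the dataset's train split.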
    dataset = datasets.load_dataset(custom_args.dataset)["train"].train_test_split(
        test_size=0.1
    )
    train_dataset = dataset["train"]
    eval_dataset = dataset["test"]
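    # Compute loss on the completion only: DataCollatorForCompletionOnlyLM
    # masks the labels of every token up to and including the response
    # template. Passing the template as token ids (rather than a string)
    # sidesteps tokenizers that encode the same text differently
    # mid-sequence; pad_to_multiple_of=16 keeps sequence lengths friendly
    # to tensor-core kernels.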
    response_prompt = tokenizer.encode("### Response:\n", add_special_tokens=False)
    collator = DataCollatorForCompletionOnlyLM(
        tokenizer=tokenizer,
        response_template=response_prompt,
        pad_to_multiple_of=16,
    )

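    # Load the model in bfloat16 with the KV cache disabled; caching only
    # helps generation and wastes memory during training.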
    if custom_args.use_liger:
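        # AutoLigerKernelForCausalLM is a drop-in replacement for
        # AutoModelForCausalLM: it patches supported architectures with
        # Liger's fused Triton kernels before loading the weights.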
        model = AutoLigerKernelForCausalLM.from_pretrained(
            custom_args.model_name,
            trust_remote_code=True,
            use_cache=False,
            torch_dtype=torch.bfloat16,
            # These args will get passed to the appropriate apply_liger_kernel_to_* function
            # to override the default settings
            # cross_entropy=True,
            # fused_linear_cross_entropy=False,
        )
    else:
        model = transformers.AutoModelForCausalLM.from_pretrained(
            custom_args.model_name,
            trust_remote_code=True,
            use_cache=False,
            torch_dtype=torch.bfloat16,
        )

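    # SFTTrainer applies formatting_func, tokenizes, truncates to
    # max_seq_length, and batches with the completion-only collator. Note:
    # passing max_seq_length directly to SFTTrainer assumes an older TRL
    # release; newer versions expect it via trl.SFTConfig instead.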
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        data_collator=collator,
        max_seq_length=custom_args.max_seq_length,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        formatting_func=formatting_prompts_func,
        callbacks=[EfficiencyCallback()],
    )
    trainer.train()


if __name__ == "__main__":
    train()