training_multimodal.py
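"""
Example: multimodal supervised fine-tuning of Qwen2-VL on the_cauldron,
with optional Liger Kernel patching.

The script loads a subset of HuggingFaceM4/the_cauldron, formats each Q/A
pair with the model's chat template, and fine-tunes
Qwen/Qwen2-VL-2B-Instruct using TRL's SFTTrainer. All CustomArguments
fields and the standard transformers.TrainingArguments are exposed as CLI
flags via HfArgumentParser; pass --use_liger True to apply the Liger
Kernel patch (see the example invocation at the bottom of the file).
"""
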
import os
from dataclasses import dataclass

import datasets
import torch
import transformers
from callback import EfficiencyCallback
from datasets import Image as ImageFeature
from trl import SFTTrainer

from liger_kernel.transformers import monkey_patch


@dataclass
class CustomArguments:
    model_name: str = "Qwen/Qwen2-VL-2B-Instruct"
    dataset: str = "HuggingFaceM4/the_cauldron"
    dataset_subset: str = "ai2d"
    dataset_split: str = "train"
    max_seq_length: int = 512
    dataset_text_field: str = "texts"
    use_liger: bool = False


def construct_model_and_processor(
    model_name: str, use_liger: bool
) -> tuple[torch.nn.Module, transformers.ProcessorMixin, int]:
    if "Qwen2-VL" in model_name:
        from transformers import Qwen2VLForConditionalGeneration

        # These settings reduce the memory footprint of the Qwen2-VL model, which
        # supports training/inference on images in their native resolution. Large
        # images -> many visual tokens (up to 16384) -> high memory consumption.
        # If fine-tuning for a real-world application, choose these values carefully.
        min_visual_tokens_per_image = 256
        max_visual_tokens_per_image = 256

        processor = transformers.AutoProcessor.from_pretrained(
            model_name,
            padding_side="left",
            truncation_side="left",
            # One visual token covers a 2x2 merge of 14x14 patches, i.e. 28x28 pixels,
            # so min/max_pixels bound the visual token count per image
            min_pixels=min_visual_tokens_per_image * 28 * 28,
            max_pixels=max_visual_tokens_per_image * 28 * 28,
        )
        processor.tokenizer.pad_token = processor.tokenizer.eos_token
        image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")

        if use_liger:
            print("Applying Liger Kernel to Qwen2-VL model")
            monkey_patch.apply_liger_kernel_to_qwen2_vl(
                # These args can be used to override the default Liger settings
                # cross_entropy=True,
                # fused_linear_cross_entropy=False,
            )

        model = Qwen2VLForConditionalGeneration.from_pretrained(
            pretrained_model_name_or_path=model_name,
            use_cache=False,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            attn_implementation="sdpa",
        )
        return model, processor, image_token_id

    raise NotImplementedError(f"Model {model_name} not supported")


def _validate_and_extract_the_cauldron(examples) -> dict[str, list]:
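    """
    Fan each example out into one row per Q/A pair, duplicating the shared
    image. Illustratively (hypothetical values), a batched input like
        {"images": [[img]], "texts": [[qa1, qa2]]}
    becomes
        {"texts": [qa1, qa2], "images": [img, img]}
    """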
    batch_texts = []
    batch_images = []
    for images, texts in zip(examples["images"], examples["texts"]):
        if not images:
            raise ValueError("No image found in example from the_cauldron dataset")
        if len(images) > 1:
            raise ValueError("Only one image per example is supported")
        batch_texts.extend(texts)
        batch_images.extend([images[0]] * len(texts))
    return {"texts": batch_texts, "images": batch_images}


def _format_for_convo(example, tokenizer):
    # cauldron data is already in message format {"user": ..., "assistant": ...}
    text = example["texts"]
    messages = [
        {
            "role": "user",
            "content": [{"type": "image"}, {"type": "text", "text": text["user"]}],
        },
        {"role": "assistant", "content": [{"type": "text", "text": text["assistant"]}]},
    ]
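    # apply_chat_template below collapses the messages into one training string,
    # illustratively (the exact special tokens come from the model's chat template):
    #   <|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{question}<|im_end|>
    #   <|im_start|>assistant\n{answer}<|im_end|>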
    text = tokenizer.apply_chat_template(messages, tokenize=False)
    return {"texts": text}


def train():
    parser = transformers.HfArgumentParser(
        (transformers.TrainingArguments, CustomArguments)
    )
    training_args, custom_args = parser.parse_args_into_dataclasses()
    training_args.remove_unused_columns = False  # required so the image column is not dropped
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}  # our collator does the tokenization

    model, processor, image_token_id = construct_model_and_processor(
        custom_args.model_name, custom_args.use_liger
    )

    dataset = (
        datasets.load_dataset(
            custom_args.dataset,
            custom_args.dataset_subset,
            split=custom_args.dataset_split,
        )
        .map(
            _validate_and_extract_the_cauldron,
            batched=True,
            num_proc=min(os.cpu_count(), 16),
            desc="Extracting text and images",
        )
        .map(
            _format_for_convo,
            fn_kwargs={"tokenizer": processor.tokenizer},
            desc="Formatting for convo",
        )
        .cast_column("images", ImageFeature())
        .train_test_split(test_size=0.1)
    )

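    # Each resulting row is, illustratively:
    #   {"texts": "<chat-templated string>", "images": <PIL.Image>}
    # and train_test_split above holds out 10% of the rows for evaluation.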
    train_dataset = dataset["train"]
    eval_dataset = dataset["test"]

    def collate_fn(examples):
        """
        Taken directly from the TRL documentation with minor modifications:
        https://huggingface.co/docs/trl/en/sft_trainer#a-custom-collator-for-processing-multi-modal-data

        Modifications:
        1. `apply_chat_template` is used to preprocess the texts before training begins (see above)
        2. `example["messages"]` -> `example["texts"]` to conform with the_cauldron dataset schema
        3. Ignoring image tokens in the loss computation
        """
        # Get the texts and images
        texts = [example["texts"] for example in examples]
        images = [example["images"] for example in examples]

        # Tokenize the texts and process the images
        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

        # The labels are the input_ids, and we mask the padding tokens in the loss computation
        labels = batch["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = -100

        # Ignore the image token index in the loss computation
        labels[labels == image_token_id] = -100
        batch["labels"] = labels

        return batch

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        data_collator=collate_fn,
        max_seq_length=custom_args.max_seq_length,
        dataset_text_field=custom_args.dataset_text_field,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=processor.tokenizer,
        callbacks=[EfficiencyCallback()],
    )
    trainer.train()


if __name__ == "__main__":
    train()
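

# A minimal sketch of how this script might be launched (hypothetical flags and
# paths; every CustomArguments field and standard TrainingArguments option is a
# CLI flag thanks to HfArgumentParser):
#
#   torchrun --nproc_per_node=4 training_multimodal.py \
#       --model_name Qwen/Qwen2-VL-2B-Instruct \
#       --use_liger True \
#       --bf16 True \
#       --output_dir ./qwen2-vl-sft \
#       --per_device_train_batch_size 4 \
#       --num_train_epochs 1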