training_multimodal.py
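"""
Fine-tunes a Qwen2-VL vision-language model on a subset of HuggingFaceM4/the_cauldron
with TRL's SFTTrainer, optionally patching the model with Liger Kernel to reduce memory
usage. Arguments are parsed from the command line via HfArgumentParser (see
CustomArguments below plus the standard transformers.TrainingArguments).
"""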
import os
from dataclasses import dataclass

import datasets
import torch
import transformers
from callback import EfficiencyCallback
from datasets import Image as ImageFeature
from trl import SFTTrainer

from liger_kernel.transformers import monkey_patch


@dataclass
class CustomArguments:
    model_name: str = "Qwen/Qwen2-VL-2B-Instruct"
    dataset: str = "HuggingFaceM4/the_cauldron"
    dataset_subset: str = "ai2d"
    dataset_split: str = "train"
    max_seq_length: int = 512
    dataset_text_field: str = "texts"
    use_liger: bool = False


def construct_model_and_processor(
    model_name: str, use_liger: bool
) -> tuple[torch.nn.Module, transformers.ProcessorMixin, int]:
    if "Qwen2-VL" in model_name:
        from transformers import Qwen2VLForConditionalGeneration

        # These settings are used to reduce the memory footprint of the Qwen2-VL model,
        # which supports training/inference on images in their native resolution. Large
        # images -> many visual tokens (a max of 16384) -> large memory consumption.
        # If fine-tuning for a real-world application, consider these values carefully.
        min_visual_tokens_per_image = 256
        max_visual_tokens_per_image = 256

        processor = transformers.AutoProcessor.from_pretrained(
            model_name,
            padding_side="left",
            truncation_side="left",
            min_pixels=min_visual_tokens_per_image * 28 * 28,  # patch size is 14x14
            max_pixels=max_visual_tokens_per_image * 28 * 28,  # 4 patches / token
        )
        processor.tokenizer.pad_token = processor.tokenizer.eos_token
        image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")

        if use_liger:
            print("Applying Liger Kernel to Qwen2-VL model")
            monkey_patch.apply_liger_kernel_to_qwen2_vl(
                # These args can be used to override the default Liger settings
                # cross_entropy=True,
                # fused_linear_cross_entropy=False,
            )

        model = Qwen2VLForConditionalGeneration.from_pretrained(
            pretrained_model_name_or_path=model_name,
            use_cache=False,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            attn_implementation="sdpa",
        )
        return model, processor, image_token_id

    raise NotImplementedError(f"Model {model_name} not supported")


def _validate_and_extract_the_cauldron(examples) -> dict[str, list]:
    batch_texts = []
    batch_images = []
    for images, texts in zip(examples["images"], examples["texts"]):
        if not images:
            raise ValueError("No image found in example from the_cauldron dataset")
        if len(images) > 1:
            raise ValueError("Only one image per example is supported")
        batch_texts.extend(texts)
        batch_images.extend([images[0]] * len(texts))
    return {"texts": batch_texts, "images": batch_images}


def _format_for_convo(example, tokenizer):
    # cauldron data is already in message format {"user": ..., "assistant": ...}
    text = example["texts"]
    messages = [
        {
            "role": "user",
            "content": [{"type": "image"}, {"type": "text", "text": text["user"]}],
        },
        {"role": "assistant", "content": [{"type": "text", "text": text["assistant"]}]},
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False)
    return {"texts": text}


def train():
    parser = transformers.HfArgumentParser(
        (transformers.TrainingArguments, CustomArguments)
    )
    training_args, custom_args = parser.parse_args_into_dataclasses()
    training_args.remove_unused_columns = False  # required to not drop the image column
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}

    model, processor, image_token_id = construct_model_and_processor(
        custom_args.model_name, custom_args.use_liger
    )

    dataset = (
        datasets.load_dataset(
            custom_args.dataset,
            custom_args.dataset_subset,
            split=custom_args.dataset_split,
        )
        .map(
            _validate_and_extract_the_cauldron,
            batched=True,
            num_proc=min(os.cpu_count(), 16),
            desc="Extracting text and images",
        )
        .map(
            _format_for_convo,
            fn_kwargs={"tokenizer": processor.tokenizer},
            desc="Formatting for convo",
        )
        .cast_column("images", ImageFeature())
        .train_test_split(test_size=0.1)
    )

    train_dataset = dataset["train"]
    eval_dataset = dataset["test"]

    def collate_fn(examples):
        """
        Taken directly from the TRL documentation with minor modifications:
        https://huggingface.co/docs/trl/en/sft_trainer#a-custom-collator-for-processing-multi-modal-data

        Modifications:
        1. `apply_chat_template` is used to preprocess the texts before training begins (see above)
        2. `example["messages"]` -> `example["texts"]` to conform with the_cauldron dataset schema
        3. Image tokens are ignored in the loss computation
        """
        # Get the texts and images
        texts = [example["texts"] for example in examples]
        images = [example["images"] for example in examples]

        # Tokenize the texts and process the images
        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

        # The labels are the input_ids, and we mask the padding tokens in the loss computation
        labels = batch["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = -100

        # Ignore the image token index in the loss computation
        labels[labels == image_token_id] = -100
        batch["labels"] = labels

        return batch

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        data_collator=collate_fn,
        max_seq_length=custom_args.max_seq_length,
        dataset_text_field=custom_args.dataset_text_field,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=processor.tokenizer,
        callbacks=[EfficiencyCallback()],
    )
    trainer.train()


if __name__ == "__main__":
    train()
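# Example invocation (illustrative only; the hyperparameter values below are assumptions,
# and any standard transformers.TrainingArguments flag can be added or changed):
#
#   torchrun --nproc_per_node=4 training_multimodal.py \
#       --model_name Qwen/Qwen2-VL-2B-Instruct \
#       --use_liger True \
#       --bf16 True \
#       --output_dir ./qwen2-vl-sft \
#       --per_device_train_batch_size 4 \
#       --num_train_epochs 1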