/ repo3-fine-tuning-template / dpo_train.py
dpo_train.py
#!/usr/bin/env python3
"""DPO (Direct Preference Optimization) training script.

Fine-tunes a Qwen3-4B model with the TRL library: loads a YAML config,
prepares a chosen/rejected preference dataset, optionally applies LoRA
(PEFT) adapters, runs ``DPOTrainer``, and saves the model plus metrics.
"""

import os

# Must be set before any TensorFlow-dependent import to suppress
# AVX / oneDNN startup warnings.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

import json
import logging
import sys

import torch
import yaml
from datasets import Dataset
from trl import DPOConfig, DPOTrainer

from data_utils import load_dpo_dataset, create_sample_dpo_data, format_dpo_data_for_training
from model_utils import (
    load_model_and_tokenizer,
    create_peft_config,
    apply_peft_to_model,
    create_training_arguments,
    save_model_and_tokenizer
)

# Log to both a file and stdout.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('dpo_training.log'),
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)


def load_config(config_path: str = "train_config.yaml") -> dict:
    """Load and return the YAML training configuration.

    Args:
        config_path: Path to the YAML config file.

    Returns:
        The parsed configuration as a nested dict.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)


def prepare_dataset(config: dict):
    """Prepare the training/evaluation datasets.

    Creates sample data when the configured train file is missing, then
    splits into train/eval according to ``data.test_size`` (a value of 0
    reuses the full dataset for both splits).

    Args:
        config: Full training configuration (uses the ``data`` section).

    Returns:
        Tuple ``(train_dataset, eval_dataset)``.
    """
    data_config = config['data']
    train_file = data_config['train_file']

    # Fall back to generated sample data so the script runs out of the box.
    if not os.path.exists(train_file):
        logger.info("数据文件不存在,创建示例数据...")
        create_sample_dpo_data(train_file)

    dataset = load_dpo_dataset(train_file)

    test_size = data_config.get('test_size', 0.1)
    if test_size > 0:
        split = dataset.train_test_split(test_size=test_size)
        train_dataset = split['train']
        eval_dataset = split['test']
        logger.info(f"数据集分割完成: 训练集 {len(train_dataset)} 条, 验证集 {len(eval_dataset)} 条")
    else:
        # No held-out split requested: evaluate on the training data itself.
        train_dataset = dataset
        eval_dataset = dataset
        logger.info(f"使用全部数据作为训练集: {len(train_dataset)} 条")

    return train_dataset, eval_dataset


def create_dpo_trainer(
    model,
    tokenizer,
    train_dataset: Dataset,
    eval_dataset: Dataset,
    config: dict
):
    """Build a ``DPOTrainer`` from config, model, tokenizer and datasets.

    Args:
        model: The (possibly PEFT-wrapped) causal LM to train.
        tokenizer: Tokenizer matching ``model``; used as the trainer's
            processing class.
        train_dataset: Formatted DPO training split.
        eval_dataset: Formatted DPO evaluation split.
        config: Full configuration (uses ``training``, ``output``, ``dpo``).

    Returns:
        A configured ``DPOTrainer`` instance.
    """
    training_config = config['training']
    output_config = config['output']
    dpo_config = config['dpo']

    # Generic HF training arguments only. DPO-specific knobs (beta,
    # max_prompt_length, max_length) must NOT be passed here — they go on
    # DPOConfig below; passing them in both places raises
    # "got multiple values for keyword argument".
    training_args = create_training_arguments(
        output_dir=output_config['output_dir'],
        num_train_epochs=training_config['num_train_epochs'],
        per_device_train_batch_size=training_config['per_device_train_batch_size'],
        per_device_eval_batch_size=training_config['per_device_eval_batch_size'],
        gradient_accumulation_steps=training_config['gradient_accumulation_steps'],
        learning_rate=float(training_config['learning_rate']),
        warmup_steps=training_config['warmup_steps'],
        logging_steps=training_config['logging_steps'],
        save_steps=training_config['save_steps'],
        eval_steps=training_config['eval_steps'],
        save_total_limit=training_config['save_total_limit'],
        load_best_model_at_end=training_config['load_best_model_at_end'],
        metric_for_best_model=training_config['metric_for_best_model'],
        greater_is_better=training_config['greater_is_better'],
        logging_dir=output_config['logging_dir']
    )

    # DPOConfig subclasses TrainingArguments, so the generic arguments can
    # be re-expanded into it together with the DPO-specific ones.
    # NOTE(review): to_dict() stringifies enum fields — assumed round-trip
    # safe for the options used here; verify against the installed
    # transformers version.
    dpo_training_args = DPOConfig(
        **training_args.to_dict(),
        beta=float(dpo_config['beta']),
        max_prompt_length=int(dpo_config['max_prompt_length']),
        max_length=int(dpo_config['max_length']),
        padding_value=0,
        truncation_mode="keep_end",
    )

    dpo_trainer = DPOTrainer(
        model=model,
        args=dpo_training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        # Recent TRL versions take the tokenizer via `processing_class`
        # (older ones used `tokenizer=`); without it the trainer cannot
        # tokenize/collate.
        processing_class=tokenizer,
        peft_config=None,  # PEFT is already applied to the model.
    )

    logger.info("DPO训练器创建完成")
    return dpo_trainer


def main():
    """Run the full DPO fine-tuning pipeline end to end."""
    logger.info("开始DPO训练...")

    config = load_config()
    logger.info("配置文件加载完成")

    # Report hardware so slow CPU runs are not a surprise.
    if torch.cuda.is_available():
        logger.info(f"CUDA可用,使用GPU: {torch.cuda.get_device_name()}")
        logger.info(f"GPU内存: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
    else:
        logger.warning("CUDA不可用,将使用CPU训练(速度会很慢)")

    train_dataset, eval_dataset = prepare_dataset(config)

    model_config = config['model']
    hardware_config = config['hardware']

    model, tokenizer = load_model_and_tokenizer(
        model_name=model_config['base_model'],
        use_4bit=hardware_config['use_4bit'],
        use_8bit=hardware_config['use_8bit'],
        bf16=hardware_config['bf16'],
        device_map=hardware_config['device_map']
    )

    # Tokenize/format the preference pairs for DPO.
    logger.info("格式化数据集...")
    train_dataset = format_dpo_data_for_training(train_dataset, tokenizer, config['dpo']['max_length'])
    eval_dataset = format_dpo_data_for_training(eval_dataset, tokenizer, config['dpo']['max_length'])
    logger.info("数据集格式化完成")

    # Wrap the base model with LoRA adapters when enabled.
    if model_config['use_peft']:
        peft_config = create_peft_config(
            lora_r=model_config['lora_r'],
            lora_alpha=model_config['lora_alpha'],
            lora_dropout=model_config['lora_dropout'],
            target_modules=model_config['target_modules']
        )
        model = apply_peft_to_model(model, peft_config)

    dpo_trainer = create_dpo_trainer(
        model, tokenizer, train_dataset, eval_dataset, config
    )

    logger.info("开始训练...")
    train_result = dpo_trainer.train()

    output_dir = config['output']['output_dir']
    os.makedirs(output_dir, exist_ok=True)

    save_model_and_tokenizer(model, tokenizer, output_dir)

    # Also persist trainer state/checkpoint for resumption.
    dpo_trainer.save_model()
    dpo_trainer.save_state()

    # Persist final training metrics alongside the model.
    metrics_file = os.path.join(output_dir, "training_metrics.json")
    with open(metrics_file, 'w', encoding='utf-8') as f:
        json.dump(train_result.metrics, f, indent=2, ensure_ascii=False)

    logger.info(f"训练完成!模型已保存到: {output_dir}")
    logger.info(f"训练指标已保存到: {metrics_file}")

    logger.info("最终训练指标:")
    for key, value in train_result.metrics.items():
        logger.info(f" {key}: {value}")


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        # Top-level boundary: log with traceback and exit non-zero.
        logger.error(f"训练过程中出现错误: {e}", exc_info=True)
        sys.exit(1)