dpo_train.py
  1  #!/usr/bin/env python3
"""
DPO (Direct Preference Optimization) training script.
Fine-tunes a Qwen3-4B model using the TRL library.
"""
  6  
import os
# Suppress TensorFlow AVX/oneDNN startup warnings; must be set before
# TensorFlow is (transitively) imported for the variables to take effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
 11  
 12  import os
 13  import sys
 14  import logging
 15  import yaml
 16  import torch
 17  from pathlib import Path
 18  from datasets import Dataset
 19  from transformers import AutoTokenizer
 20  from trl import DPOTrainer
 21  
 22  from data_utils import load_dpo_dataset, create_sample_dpo_data, format_dpo_data_for_training
 23  from model_utils import (
 24      load_model_and_tokenizer,
 25      create_peft_config,
 26      apply_peft_to_model,
 27      create_training_arguments,
 28      save_model_and_tokenizer
 29  )
 30  
# Configure logging: mirror all INFO+ records to both a log file and stdout.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('dpo_training.log'),
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)
 41  
 42  def load_config(config_path: str = "train_config.yaml"):
 43      """加载配置文件"""
 44      with open(config_path, 'r', encoding='utf-8') as f:
 45          config = yaml.safe_load(f)
 46      return config
 47  
 48  def prepare_dataset(config: dict):
 49      """准备训练数据集"""
 50      data_config = config['data']
 51      train_file = data_config['train_file']
 52      
 53      # 如果数据文件不存在,创建示例数据
 54      if not os.path.exists(train_file):
 55          logger.info("数据文件不存在,创建示例数据...")
 56          create_sample_dpo_data(train_file)
 57      
 58      # 加载数据集
 59      dataset = load_dpo_dataset(train_file)
 60      
 61      # 分割训练集和验证集
 62      test_size = data_config.get('test_size', 0.1)
 63      if test_size > 0:
 64          dataset = dataset.train_test_split(test_size=test_size)
 65          train_dataset = dataset['train']
 66          eval_dataset = dataset['test']
 67          logger.info(f"数据集分割完成: 训练集 {len(train_dataset)} 条, 验证集 {len(eval_dataset)} 条")
 68      else:
 69          train_dataset = dataset
 70          eval_dataset = dataset
 71          logger.info(f"使用全部数据作为训练集: {len(train_dataset)} 条")
 72      
 73      return train_dataset, eval_dataset
 74  
 75  def create_dpo_trainer(
 76      model,
 77      tokenizer,
 78      train_dataset: Dataset,
 79      eval_dataset: Dataset,
 80      config: dict
 81  ):
 82      """创建DPO训练器"""
 83      
 84      # 创建训练参数
 85      training_config = config['training']
 86      output_config = config['output']
 87      dpo_config = config['dpo']
 88      
 89      training_args = create_training_arguments(
 90          output_dir=output_config['output_dir'],
 91          num_train_epochs=training_config['num_train_epochs'],
 92          per_device_train_batch_size=training_config['per_device_train_batch_size'],
 93          per_device_eval_batch_size=training_config['per_device_eval_batch_size'],
 94          gradient_accumulation_steps=training_config['gradient_accumulation_steps'],
 95          learning_rate=float(training_config['learning_rate']),
 96          warmup_steps=training_config['warmup_steps'],
 97          logging_steps=training_config['logging_steps'],
 98          save_steps=training_config['save_steps'],
 99          eval_steps=training_config['eval_steps'],
100          save_total_limit=training_config['save_total_limit'],
101          load_best_model_at_end=training_config['load_best_model_at_end'],
102          metric_for_best_model=training_config['metric_for_best_model'],
103          greater_is_better=training_config['greater_is_better'],
104          logging_dir=output_config['logging_dir'],
105          beta=float(dpo_config['beta']),
106          max_prompt_length=int(dpo_config['max_prompt_length']),
107          max_length=int(dpo_config['max_length'])
108      )
109      
110      # 创建DPO训练器
111      dpo_config = config['dpo']
112      from trl import DPOConfig
113      
114      # 创建DPO配置
115      dpo_training_args = DPOConfig(
116          **training_args.to_dict(),
117          beta=dpo_config['beta'],
118          max_prompt_length=dpo_config['max_prompt_length'],
119          max_length=dpo_config['max_length'],
120          padding_value=0,
121          truncation_mode="keep_end",
122      )
123      
124      dpo_trainer = DPOTrainer(
125          model=model,
126          args=dpo_training_args,
127          train_dataset=train_dataset,
128          eval_dataset=eval_dataset,
129          peft_config=None,  # 已经在模型上应用了PEFT
130      )
131      
132      logger.info("DPO训练器创建完成")
133      return dpo_trainer
134  
def main():
    """Run the end-to-end DPO fine-tuning pipeline.

    Steps: load the YAML config, prepare the datasets, load the base model
    and tokenizer, format the data for DPO, optionally apply PEFT/LoRA,
    train, then persist the model, trainer state and training metrics.
    """
    import json  # hoisted from the metrics `with` block below

    logger.info("开始DPO训练...")

    # Load configuration
    config = load_config()
    logger.info("配置文件加载完成")

    # Report hardware availability up front so a slow CPU run is expected.
    if torch.cuda.is_available():
        logger.info(f"CUDA可用,使用GPU: {torch.cuda.get_device_name()}")
        logger.info(f"GPU内存: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
    else:
        logger.warning("CUDA不可用,将使用CPU训练(速度会很慢)")

    # Prepare train/eval splits
    train_dataset, eval_dataset = prepare_dataset(config)

    # Load the base model and tokenizer per the model/hardware config
    model_config = config['model']
    hardware_config = config['hardware']

    model, tokenizer = load_model_and_tokenizer(
        model_name=model_config['base_model'],
        use_4bit=hardware_config['use_4bit'],
        use_8bit=hardware_config['use_8bit'],
        bf16=hardware_config['bf16'],
        device_map=hardware_config['device_map']
    )

    # Format datasets for DPO training (requires the tokenizer)
    logger.info("格式化数据集...")
    train_dataset = format_dpo_data_for_training(train_dataset, tokenizer, config['dpo']['max_length'])
    eval_dataset = format_dpo_data_for_training(eval_dataset, tokenizer, config['dpo']['max_length'])
    logger.info("数据集格式化完成")

    # Optionally wrap the model with LoRA adapters
    if model_config['use_peft']:
        peft_config = create_peft_config(
            lora_r=model_config['lora_r'],
            lora_alpha=model_config['lora_alpha'],
            lora_dropout=model_config['lora_dropout'],
            target_modules=model_config['target_modules']
        )
        model = apply_peft_to_model(model, peft_config)

    # Build the trainer and run training
    dpo_trainer = create_dpo_trainer(
        model, tokenizer, train_dataset, eval_dataset, config
    )

    logger.info("开始训练...")
    train_result = dpo_trainer.train()

    # Persist all artifacts
    output_dir = config['output']['output_dir']
    os.makedirs(output_dir, exist_ok=True)

    save_model_and_tokenizer(model, tokenizer, output_dir)
    dpo_trainer.save_model()
    dpo_trainer.save_state()

    # Save training metrics; explicit encoding avoids platform-default surprises.
    metrics_file = os.path.join(output_dir, "training_metrics.json")
    with open(metrics_file, 'w', encoding='utf-8') as f:
        json.dump(train_result.metrics, f, indent=2)

    logger.info(f"训练完成!模型已保存到: {output_dir}")
    logger.info(f"训练指标已保存到: {metrics_file}")

    # Echo the final metrics to the log
    logger.info("最终训练指标:")
    for key, value in train_result.metrics.items():
        logger.info(f"  {key}: {value}")
214  
if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        # Top-level boundary: log the full traceback, then fail the process.
        # logger.exception(...) is equivalent to logger.error(..., exc_info=True).
        logger.exception(f"训练过程中出现错误: {e}")
        sys.exit(1)