# finetune.py
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import json
import random
from collections import defaultdict

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

import hydra
from omegaconf import DictConfig, OmegaConf

from speaker_verification.models.ecapa import ECAPA_TDNN
from speaker_verification.head.aamsoftmax import AAMSoftmax
from speaker_verification.checkpointing import ModelCfg, build_ckpt, save_ckpt, load_ckpt

from dataset.pk_sampler import PKBatchSampler
from dataset.dataset import TrainFbankPtDataset, collate_fixed

from utils.seed import set_seed
from utils.meters import AverageMeter, top1_accuracy, compute_eer
from utils.path_utils import _resolve_path

try:
    from utils.plot import plot_curves
    _HAS_PLOT = True
except Exception:
    _HAS_PLOT = False


# =========================
# Validation (verification EER) - sampled verification trials, same protocol as train.py
# =========================
@torch.no_grad()
def validate_eer_sampled(
    model: torch.nn.Module,
    val_meta_path: str,
    device: torch.device,
    crop_frames: int = 400,
    num_crops: int = 6,
    max_spk: int = 120,
    per_spk: int = 3,
    num_pos: int = 3000,
    num_neg: int = 3000,
    seed: int = 1234,
) -> dict:
    model.eval()
    rng = random.Random(seed)

    # Read the validation manifest: one JSON object per line with "spk" and "feat" fields.
    items = []
    base_dir = os.path.dirname(os.path.abspath(val_meta_path))
    with open(val_meta_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            j = json.loads(line)
            spk = str(j["spk"])
            p = _resolve_path(j["feat"], base_dir)
            items.append((spk, p))

    spk2paths = defaultdict(list)
    for spk, p in items:
        spk2paths[spk].append(p)

    # Keep speakers with at least two utterances, then subsample speakers and utterances.
    spks = [s for s in spk2paths if len(spk2paths[s]) >= 2]
    rng.shuffle(spks)
    spks = spks[:max_spk]

    sample_paths, sample_spk = [], []
    for s in spks:
        ps = spk2paths[s][:]
        rng.shuffle(ps)
        ps = ps[:per_spk]
        for p in ps:
            sample_paths.append(p)
            sample_spk.append(s)

    # Embed each sampled utterance: average over random crops, then L2-normalize.
    emb_cache = {}
    for p in sample_paths:
        feat = torch.load(p, map_location="cpu")  # [T, 80]
        if not torch.is_tensor(feat):
            feat = torch.tensor(feat)

        T = feat.size(0)
        embs = []
        if T <= crop_frames:
            x = feat.unsqueeze(0).to(device)
            e = model(x).squeeze(0).detach().cpu()
            embs.append(e)
        else:
            for _ in range(num_crops):
                s0 = rng.randint(0, T - crop_frames)
                chunk = feat[s0 : s0 + crop_frames]
                x = chunk.unsqueeze(0).to(device)
                e = model(x).squeeze(0).detach().cpu()
                embs.append(e)

        e = torch.stack(embs, 0).mean(0)
        e = e / (e.norm() + 1e-12)
        emb_cache[p] = e

    spk2idx = defaultdict(list)
    for s, p in zip(sample_spk, sample_paths):
        spk2idx[s].append(p)

    spks_with2 = [s for s in spk2idx if len(spk2idx[s]) >= 2]
    all_spks = list(spk2idx.keys())
    if len(all_spks) < 2 or len(spks_with2) == 0:
        return {"eer": 1.0, "pos_mean": 0.0, "neg_mean": 0.0}

    # Score random same-speaker (positive) and cross-speaker (negative) pairs
    # with cosine similarity (embeddings are already unit-norm).
    labels, scores = [], []
    for _ in range(num_pos):
        s = rng.choice(spks_with2)
        p1, p2 = rng.sample(spk2idx[s], 2)
        sc = float((emb_cache[p1] * emb_cache[p2]).sum().item())
        labels.append(1)
        scores.append(sc)

    for _ in range(num_neg):
        s1, s2 = rng.sample(all_spks, 2)
        p1 = rng.choice(spk2idx[s1])
        p2 = rng.choice(spk2idx[s2])
        sc = float((emb_cache[p1] * emb_cache[p2]).sum().item())
        labels.append(0)
        scores.append(sc)

    eer, _ = compute_eer(labels, scores)
    pos = [s for s, l in zip(scores, labels) if l == 1]
    neg = [s for s, l in zip(scores, labels) if l == 0]
    return {
        "eer": eer,
        "pos_mean": float(np.mean(pos)),
        "neg_mean": float(np.mean(neg)),
    }


def set_requires_grad(module: torch.nn.Module, flag: bool) -> None:
    for p in module.parameters():
        p.requires_grad = flag


def build_optimizer(cfg: DictConfig, model: torch.nn.Module, head: torch.nn.Module, phase: str):
    """
    phase:
      - "head_only": train the head only (backbone frozen)
      - "full": train backbone + head (with separate lr multipliers)
    """
    lr = float(cfg.lr)
    wd = float(cfg.weight_decay)

    fin = cfg.get("finetune", {})
    backbone_mult = float(fin.get("backbone_lr_mult", 0.5))
    head_mult = float(fin.get("head_lr_mult", 1.0))

    if phase == "head_only":
        params = [{"params": head.parameters(), "lr": lr * head_mult}]
    else:
        params = [
            {"params": model.parameters(), "lr": lr * backbone_mult},
            {"params": head.parameters(), "lr": lr * head_mult},
        ]

    optim = torch.optim.AdamW(params, lr=lr, weight_decay=wd)
    return optim


def train_one_epoch(model, head, loader, device, num_classes, optim, scaler, use_amp, grad_clip):
    model.train()
    head.train()

    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    pbar = tqdm(loader, desc="TRAIN", ncols=110)

    for x, y in pbar:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        if y.min().item() < 0 or y.max().item() >= num_classes:
            raise RuntimeError(
                f"[TRAIN] label out of range: min={y.min().item()}, max={y.max().item()}, C={num_classes}"
            )

        optim.zero_grad(set_to_none=True)

        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            emb = model(x)
            if not torch.isfinite(emb).all():
                raise RuntimeError("[TRAIN] Non-finite embedding detected (NaN/Inf).")

            loss, logits = head(emb, y)
            if not torch.isfinite(loss).all() or not torch.isfinite(logits).all():
                raise RuntimeError("[TRAIN] Non-finite loss/logits detected (NaN/Inf).")

        if use_amp:
            scaler.scale(loss).backward()
            scaler.unscale_(optim)
            torch.nn.utils.clip_grad_norm_(list(model.parameters()) + list(head.parameters()), grad_clip)
            scaler.step(optim)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(list(model.parameters()) + list(head.parameters()), grad_clip)
            optim.step()

        acc = top1_accuracy(logits, y)
        bs = y.size(0)
        loss_meter.update(float(loss.item()), bs)
        acc_meter.update(float(acc), bs)

        lr0 = optim.param_groups[0]["lr"]
        pbar.set_postfix(loss=f"{loss_meter.avg:.4f}", acc=f"{acc_meter.avg:.4f}", lr=f"{lr0:.2e}")

    return loss_meter.avg, acc_meter.avg


def load_backbone_weights(cfg: DictConfig, model: torch.nn.Module, device: torch.device):
    fin = cfg.get("finetune", {})
    model_only = fin.get("pretrained_model_only", "")
    full_ckpt = fin.get("pretrained_full_ckpt", "")

    if (not model_only) and (not full_ckpt):
        print("[FINETUNE] No pretrained weights provided. Training from scratch.")
        return None

    if model_only and full_ckpt:
        print(
            "[FINETUNE] Both pretrained_model_only and pretrained_full_ckpt are set. "
            "Preferring pretrained_full_ckpt (it has model_cfg)."
        )
" 230 "Prefer pretrained_full_ckpt (has model_cfg).") 231 model_only = "" 232 233 if full_ckpt: 234 ckpt = load_ckpt(full_ckpt, map_location=str(device)) 235 model.load_state_dict(ckpt["model_state"], strict=True) 236 print(f"[FINETUNE] Loaded backbone from full ckpt: {full_ckpt}") 237 return ckpt 238 else: 239 sd = torch.load(model_only, map_location=str(device)) 240 model.load_state_dict(sd, strict=True) 241 print(f"[FINETUNE] Loaded backbone from model-only state_dict: {model_only}") 242 return None 243 244 245 @hydra.main(version_base=None, config_path="configs", config_name="finetune") 246 def main(cfg: DictConfig): 247 set_seed(cfg.seed if hasattr(cfg, "seed") else 1234) 248 249 os.makedirs(cfg.out_dir, exist_ok=True) 250 cfg_dict = OmegaConf.to_container(cfg, resolve=True) 251 with open(os.path.join(cfg.out_dir, "config.json"), "w", encoding="utf-8") as f: 252 json.dump(cfg_dict, f, ensure_ascii=False, indent=2) 253 254 device = torch.device(cfg.device if torch.cuda.is_available() else "cpu") 255 print(f"Using device: {device}") 256 257 # -------- 258 # Dataset 259 # -------- 260 train_ds = TrainFbankPtDataset(cfg.train_list, crop_frames=cfg.get("crop_frames", 200)) 261 num_classes = train_ds.num_classes 262 print(f"num_classes (NEW DATASET) = {num_classes}") 263 264 train_labels = [y for (y, _) in train_ds.items] 265 pk_sampler = PKBatchSampler( 266 train_labels, 267 P=cfg.get("p", 32), 268 K=cfg.get("k", 4), 269 drop_last=True, 270 seed=cfg.get("seed", 1234), 271 ) 272 train_loader = DataLoader( 273 train_ds, 274 batch_sampler=pk_sampler, 275 num_workers=cfg.num_workers, 276 collate_fn=collate_fixed, 277 pin_memory=(device.type == "cuda"), 278 ) 279 280 # -------- 281 # Model 282 # -------- 283 model = ECAPA_TDNN( 284 in_channels=int(cfg.feat_dim), 285 channels=int(cfg.channels), 286 embd_dim=int(cfg.emb_dim), 287 ).to(device) 288 289 ckpt = load_backbone_weights(cfg, model, device) 290 291 head = AAMSoftmax( 292 int(cfg.emb_dim), 293 int(num_classes), 294 s=float(cfg.scale), 295 m=float(cfg.margin), 296 ).to(device) 297 298 # -------- 299 # Freeze schedule 300 # -------- 301 fin = cfg.get("finetune", {}) 302 freeze_epochs = int(fin.get("freeze_backbone_epochs", 3)) 303 if freeze_epochs > 0: 304 set_requires_grad(model, False) 305 set_requires_grad(head, True) 306 print(f"[FINETUNE] Freeze backbone for {freeze_epochs} epochs (train head only).") 307 optim = build_optimizer(cfg, model, head, phase="head_only") 308 else: 309 set_requires_grad(model, True) 310 set_requires_grad(head, True) 311 optim = build_optimizer(cfg, model, head, phase="full") 312 313 scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=int(cfg.epochs)) 314 315 use_amp = bool(cfg.get("amp", True)) and device.type == "cuda" 316 scaler = torch.amp.GradScaler("cuda", enabled=use_amp) 317 318 history = {"train_loss": [], "train_acc": [], "val_eer": [], "val_pos_mean": [], "val_neg_mean": []} 319 best_val_eer = 1e9 320 321 for epoch in range(1, int(cfg.epochs) + 1): 322 print(f"\n===== Epoch {epoch}/{cfg.epochs} =====") 323 324 if freeze_epochs > 0 and epoch == freeze_epochs + 1: 325 print("[FINETUNE] Unfreeze backbone and switch to full training.") 326 set_requires_grad(model, True) 327 # 重新建优化器(更干净) 328 optim = build_optimizer(cfg, model, head, phase="full") 329 scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=int(cfg.epochs) - epoch + 1) 330 scaler = torch.amp.GradScaler("cuda", enabled=use_amp) 331 332 train_loss, train_acc = train_one_epoch( 333 model=model, 334 head=head, 335 
            loader=train_loader,
            device=device,
            num_classes=num_classes,
            optim=optim,
            scaler=scaler,
            use_amp=use_amp,
            grad_clip=float(cfg.grad_clip),
        )

        scheduler.step()
        torch.cuda.empty_cache()

        val_info = validate_eer_sampled(
            model=model,
            val_meta_path=cfg.val_list,
            device=device,
            crop_frames=int(cfg.get("crop_frames_val", 400)),
            num_crops=int(cfg.get("num_crops", 6)),
            max_spk=int(cfg.get("max_spk", 120)),
            per_spk=int(cfg.get("per_spk", 3)),
            num_pos=int(cfg.get("num_pos", 3000)),
            num_neg=int(cfg.get("num_neg", 3000)),
            seed=int(cfg.get("seed", 1234)),
        )
        val_eer = float(val_info["eer"])

        history["train_loss"].append(float(train_loss))
        history["train_acc"].append(float(train_acc))
        history["val_eer"].append(val_eer)
        history["val_pos_mean"].append(float(val_info["pos_mean"]))
        history["val_neg_mean"].append(float(val_info["neg_mean"]))

        print(
            f"[Epoch {epoch}] train_loss={train_loss:.4f}, train_acc={train_acc:.4f} | "
            f"val_EER={val_eer*100:.2f}% (pos={val_info['pos_mean']:.3f}, neg={val_info['neg_mean']:.3f})"
        )

        model_cfg = ModelCfg(
            channels=int(cfg.channels),
            emb_dim=int(cfg.emb_dim),
            feat_dim=int(cfg.feat_dim),
            sample_rate=int(cfg.get("sample_rate", 16000)),
        )
        ckpt_out = build_ckpt(
            model=model,
            head=head,
            optim=optim,
            scheduler=scheduler,
            epoch=epoch,
            best_eer=val_eer,
            label_map=train_ds.label_map,
            model_cfg=model_cfg,
            # DictConfig has no to_yaml_string(); serialize the config via OmegaConf instead.
            extra={"cfg_text": OmegaConf.to_yaml(cfg)},
        )
        save_ckpt(os.path.join(cfg.out_dir, "last.pt"), ckpt_out)

        if val_eer < best_val_eer:
            best_val_eer = val_eer
            save_ckpt(os.path.join(cfg.out_dir, "best.pt"), ckpt_out)
            print(f"[FINETUNE] New best EER: {best_val_eer*100:.2f}% -> saved best.pt")

        if _HAS_PLOT:
            try:
                plot_curves(cfg.out_dir, history)
            except Exception as e:
                print(f"[WARN] plot_curves failed: {e}")

        with open(os.path.join(cfg.out_dir, "history.json"), "w", encoding="utf-8") as f:
            json.dump(history, f, ensure_ascii=False, indent=2)

    print(f"Finetune done! out_dir = {cfg.out_dir}")


if __name__ == "__main__":
    main()
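
# ---------------------------------------------------------------------------
# Example configs/finetune.yaml (a minimal sketch, not part of this file).
# Field names follow the cfg.* accesses above; where the script defines a
# default (crop_frames, p, k, finetune.*), the value shown matches it, and all
# other values are illustrative placeholders, not tuned recommendations.
# ---------------------------------------------------------------------------
# seed: 1234
# device: "cuda:0"
# out_dir: "exp/finetune"
#
# train_list: "data/finetune_train.jsonl"   # assumed manifest paths
# val_list: "data/finetune_val.jsonl"
#
# feat_dim: 80           # fbank dimension, matches the [T, 80] features above
# channels: 512          # illustrative ECAPA width
# emb_dim: 192           # illustrative embedding size
# scale: 30.0            # AAM-softmax scale s (illustrative)
# margin: 0.2            # AAM-softmax margin m (illustrative)
#
# lr: 1.0e-4
# weight_decay: 1.0e-5
# epochs: 20
# grad_clip: 5.0
# num_workers: 4
# amp: true
# crop_frames: 200       # training crop (frames); validation uses crop_frames_val (default 400)
# p: 32                  # speakers per batch (PK sampler)
# k: 4                   # utterances per speaker
#
# finetune:
#   pretrained_full_ckpt: ""     # checkpoint written by save_ckpt (preferred; carries model_cfg)
#   pretrained_model_only: ""    # or a bare backbone state_dict
#   freeze_backbone_epochs: 3
#   backbone_lr_mult: 0.5
#   head_lr_mult: 1.0
#
# Launch with Hydra overrides on the command line, e.g.:
#   python finetune.py out_dir=exp/ft1 finetune.pretrained_full_ckpt=exp/base/best.pt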