# train.py — ECAPA-TDNN speaker-verification training script
import os

# Must be set before any CUDA allocation takes place so the allocator picks up
# expandable segments (reduces fragmentation-related OOMs).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import json
import random
from collections import defaultdict

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf
import hydra

from speaker_verification.models.ecapa import ECAPA_TDNN
from speaker_verification.head.aamsoftmax import AAMSoftmax
from dataset.pk_sampler import PKBatchSampler
from dataset.dataset import TrainFbankPtDataset, collate_fixed

from speaker_verification.checkpointing import ModelCfg, build_ckpt, save_ckpt
from utils.seed import set_seed
from utils.meters import AverageMeter, top1_accuracy, compute_eer
from utils.path_utils import _resolve_path

# Plotting is optional; training proceeds without it.
try:
    from utils.plot import plot_curves
    _HAS_PLOT = True
except Exception:
    _HAS_PLOT = False


# =========================
# Validation (verification EER) — sampled
# =========================
@torch.no_grad()
def validate_eer_sampled(
    model: torch.nn.Module,
    val_meta_path: str,
    device: torch.device,
    crop_frames: int = 400,
    num_crops: int = 6,
    max_spk: int = 120,
    per_spk: int = 3,
    num_pos: int = 3000,
    num_neg: int = 3000,
    seed: int = 1234,
) -> dict:
    """Estimate verification EER on a sampled subset of the validation set.

    Embeddings are averaged over multiple random crops of long utterances to
    stabilize the estimate; cosine similarity is used for trial scoring.

    Args:
        model: Embedding extractor; called as ``model(x)`` with ``x`` of shape
            ``[1, T, feat_dim]`` (assumed [T, 80] fbank features on disk —
            TODO confirm against the feature-extraction pipeline).
        val_meta_path: JSONL file; each line holds ``{"spk": ..., "feat": ...}``.
        device: Device used for the forward passes.
        crop_frames: Crop length (frames) for long utterances.
        num_crops: Number of random crops averaged per long utterance.
        max_spk: Maximum number of speakers sampled.
        per_spk: Maximum utterances sampled per speaker.
        num_pos: Number of same-speaker (target) trials.
        num_neg: Number of different-speaker (non-target) trials.
        seed: Seed for the local RNG, making the sampling reproducible.

    Returns:
        Dict with ``eer``, ``pos_mean`` and ``neg_mean`` (mean cosine scores).
        Returns ``{"eer": 1.0, ...}`` when no valid trials can be formed.
    """
    model.eval()
    rng = random.Random(seed)

    # Read the JSONL meta; feature paths may be relative to the meta file.
    items = []
    base_dir = os.path.dirname(os.path.abspath(val_meta_path))
    with open(val_meta_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            j = json.loads(line)
            items.append((str(j["spk"]), _resolve_path(j["feat"], base_dir)))

    spk2paths = defaultdict(list)
    for spk, p in items:
        spk2paths[spk].append(p)

    # Keep only speakers with >= 2 utterances (required for positive pairs).
    spks = [s for s in spk2paths if len(spk2paths[s]) >= 2]
    rng.shuffle(spks)
    spks = spks[:max_spk]

    # Sample up to per_spk utterances per selected speaker.
    sample_paths = []
    sample_spk = []
    for s in spks:
        ps = spk2paths[s][:]
        rng.shuffle(ps)
        for p in ps[:per_spk]:
            sample_paths.append(p)
            sample_spk.append(s)

    # Extract embeddings one utterance at a time; average over num_crops random
    # crops when the utterance is longer than crop_frames.
    emb_cache = {}
    for p in sample_paths:
        feat = torch.load(p, map_location="cpu")  # expected [T, feat_dim]
        if not torch.is_tensor(feat):
            feat = torch.tensor(feat)
        T = feat.size(0)

        embs = []
        if T <= crop_frames:
            x = feat.unsqueeze(0).to(device)
            embs.append(model(x).squeeze(0).detach().cpu())
        else:
            for _ in range(num_crops):
                s0 = rng.randint(0, T - crop_frames)
                x = feat[s0 : s0 + crop_frames].unsqueeze(0).to(device)
                embs.append(model(x).squeeze(0).detach().cpu())

        e = torch.stack(embs, 0).mean(0)
        # L2-normalize so the dot product below equals cosine similarity.
        e = e / (e.norm() + 1e-12)
        emb_cache[p] = e

    # speaker -> list of sampled feature paths
    spk2idx = defaultdict(list)
    for s, p in zip(sample_spk, sample_paths):
        spk2idx[s].append(p)

    spks_with2 = [s for s in spk2idx if len(spk2idx[s]) >= 2]
    all_spks = list(spk2idx.keys())

    # Degenerate case: cannot form both positive and negative trials.
    if len(all_spks) < 2 or not spks_with2:
        return {"eer": 1.0, "pos_mean": 0.0, "neg_mean": 0.0}

    # Build positive (same-speaker) and negative (cross-speaker) trials.
    labels, scores = [], []
    for _ in range(num_pos):
        s = rng.choice(spks_with2)
        p1, p2 = rng.sample(spk2idx[s], 2)
        scores.append(float((emb_cache[p1] * emb_cache[p2]).sum().item()))
        labels.append(1)

    for _ in range(num_neg):
        s1, s2 = rng.sample(all_spks, 2)
        p1 = rng.choice(spk2idx[s1])
        p2 = rng.choice(spk2idx[s2])
        scores.append(float((emb_cache[p1] * emb_cache[p2]).sum().item()))
        labels.append(0)

    eer, _ = compute_eer(labels, scores)

    pos = [sc for sc, lbl in zip(scores, labels) if lbl == 1]
    neg = [sc for sc, lbl in zip(scores, labels) if lbl == 0]

    return {
        "eer": eer,
        # Guard num_pos/num_neg == 0: np.mean([]) would return NaN (with a warning).
        "pos_mean": float(np.mean(pos)) if pos else 0.0,
        "neg_mean": float(np.mean(neg)) if neg else 0.0,
    }


# =========================
# Train one epoch
# =========================
def train_one_epoch(model, head, loader, device, num_classes, optim, scaler, use_amp, params, grad_clip):
    """Train for one epoch.

    Args:
        model: Embedding extractor (returns embeddings for a feature batch).
        head: Classification head; called as ``head(emb, y)`` and returns
            ``(loss, logits)``.
        loader: DataLoader yielding ``(x, y)`` batches.
        device: Device for the forward/backward pass.
        num_classes: Number of speaker classes (for label sanity checks).
        optim: Optimizer over ``params``.
        scaler: ``torch.amp.GradScaler`` (may be disabled).
        use_amp: Whether mixed-precision autocast/scaling is active.
        params: Parameter list used for gradient clipping.
        grad_clip: Max gradient norm.

    Returns:
        Tuple ``(mean_loss, mean_top1_accuracy)`` over the epoch.

    Raises:
        RuntimeError: On out-of-range labels or non-finite values, so training
            fails fast instead of silently diverging.
    """
    model.train()
    head.train()

    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    pbar = tqdm(loader, desc="TRAIN", ncols=110)
    for x, y in pbar:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        # Fail fast on corrupt labels rather than hitting an opaque CUDA assert.
        if y.min().item() < 0 or y.max().item() >= num_classes:
            raise RuntimeError(f"[TRAIN] label out of range: min={y.min().item()}, max={y.max().item()}, C={num_classes}")

        optim.zero_grad(set_to_none=True)

        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            emb = model(x)
            if not torch.isfinite(emb).all():
                raise RuntimeError("[TRAIN] Non-finite embedding detected (NaN/Inf).")
            loss, logits = head(emb, y)

        if not torch.isfinite(loss).all() or not torch.isfinite(logits).all():
            raise RuntimeError("[TRAIN] Non-finite loss/logits detected (NaN/Inf).")

        if use_amp:
            scaler.scale(loss).backward()
            # Unscale before clipping so the clip threshold applies to the
            # true (unscaled) gradients.
            scaler.unscale_(optim)
            torch.nn.utils.clip_grad_norm_(params, grad_clip)
            scaler.step(optim)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(params, grad_clip)
            optim.step()

        acc = top1_accuracy(logits, y)

        bs = y.size(0)
        loss_meter.update(float(loss.item()), bs)
        acc_meter.update(float(acc), bs)

        pbar.set_postfix(
            loss=f"{loss_meter.avg:.4f}",
            acc=f"{acc_meter.avg:.4f}",
            lr=f"{optim.param_groups[0]['lr']:.2e}",
        )

    return loss_meter.avg, acc_meter.avg


# =========================
# Main
# =========================
@hydra.main(version_base=None, config_path="configs", config_name="train")
def main(cfg: DictConfig):
    """Entry point: build data/model/optimizer, train, validate, checkpoint."""
    set_seed(cfg.seed if hasattr(cfg, "seed") else 1234)

    os.makedirs(cfg.out_dir, exist_ok=True)

    # Persist the fully resolved config for reproducibility.
    cfg_dict = OmegaConf.to_container(cfg, resolve=True)
    with open(os.path.join(cfg.out_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump(cfg_dict, f, ensure_ascii=False, indent=2)

    device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    train_ds = TrainFbankPtDataset(
        cfg.train_list,
        crop_frames=cfg.get("crop_frames", 200),
    )
    num_classes = train_ds.num_classes
    print(f"num_classes = {num_classes}")

    train_labels = [y for (y, _) in train_ds.items]

    # P speakers x K utterances per batch (metric-learning-friendly batches).
    pk_sampler = PKBatchSampler(
        train_labels,
        P=cfg.get("p", 32),
        K=cfg.get("k", 4),
        drop_last=True,
        seed=cfg.get("seed", 1234),
    )

    train_loader = DataLoader(
        train_ds,
        batch_sampler=pk_sampler,
        num_workers=cfg.num_workers,
        collate_fn=collate_fixed,
        pin_memory=(device.type == "cuda"),
    )

    model = ECAPA_TDNN(
        in_channels=cfg.feat_dim,
        channels=cfg.channels,
        embd_dim=cfg.emb_dim,
    ).to(device)

    head = AAMSoftmax(
        cfg.emb_dim,
        num_classes,
        s=cfg.scale,
        m=cfg.margin,
    ).to(device)

    params = list(model.parameters()) + list(head.parameters())
    optim = torch.optim.AdamW(params, lr=cfg.lr, weight_decay=cfg.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=cfg.epochs)

    # AMP is only meaningful on CUDA; the scaler is a no-op when disabled.
    use_amp = bool(cfg.get("amp", True)) and device.type == "cuda"
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    history = {"train_loss": [], "train_acc": [], "val_eer": [], "val_pos_mean": [], "val_neg_mean": []}
    best_val_eer = float("inf")

    for epoch in range(1, cfg.epochs + 1):
        print(f"\n===== Epoch {epoch}/{cfg.epochs} =====")

        # Train
        train_loss, train_acc = train_one_epoch(
            model, head, train_loader, device, num_classes,
            optim, scaler, use_amp, params, cfg.grad_clip,
        )
        scheduler.step()

        # Release cached blocks between train and validation phases.
        if device.type == "cuda":
            torch.cuda.empty_cache()

        # Validation (EER)
        val_info = validate_eer_sampled(
            model=model,
            val_meta_path=cfg.val_list,
            device=device,
            crop_frames=cfg.get("crop_frames_val", 400),
            num_crops=cfg.get("num_crops", 6),
            max_spk=cfg.get("max_spk", 120),
            per_spk=cfg.get("per_spk", 3),
            num_pos=cfg.get("num_pos", 3000),
            num_neg=cfg.get("num_neg", 3000),
            seed=cfg.get("seed", 1234),
        )
        val_eer = val_info["eer"]

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_eer"].append(val_eer)
        history["val_pos_mean"].append(val_info["pos_mean"])
        history["val_neg_mean"].append(val_info["neg_mean"])

        print(
            f"[Epoch {epoch}] train_loss={train_loss:.4f}, train_acc={train_acc:.4f} | "
            f"val_EER={val_eer*100:.2f}% (pos={val_info['pos_mean']:.3f}, neg={val_info['neg_mean']:.3f})"
        )

        # NOTE(review): the rest of this file reads flat keys (cfg.channels,
        # cfg.emb_dim, cfg.feat_dim), while the original read only nested
        # cfg.model.* / cfg.feature.* / cfg.audio.* here. Prefer the nested
        # keys when present (original behavior) and fall back to the flat
        # ones — confirm against configs/train.yaml.
        model_cfg = ModelCfg(
            channels=OmegaConf.select(cfg, "model.channels", default=cfg.channels),
            emb_dim=OmegaConf.select(cfg, "model.emb_dim", default=cfg.emb_dim),
            feat_dim=OmegaConf.select(cfg, "feature.n_mels", default=cfg.feat_dim),
            sample_rate=OmegaConf.select(cfg, "audio.sample_rate", default=16000),
        )

        ckpt = build_ckpt(
            model=model,
            head=head,
            optim=optim,
            scheduler=scheduler,
            epoch=epoch,
            best_eer=val_eer,
            label_map=train_ds.label_map,
            model_cfg=model_cfg,
            # FIX: DictConfig has no to_yaml_string(), so the original hasattr
            # guard meant cfg_text was always None; use OmegaConf.to_yaml.
            extra={"cfg_text": OmegaConf.to_yaml(cfg)},
        )
        save_ckpt(os.path.join(cfg.out_dir, "last.pt"), ckpt)

        # FIX: best_val_eer was initialized but never used — also keep the
        # best-so-far checkpoint.
        if val_eer < best_val_eer:
            best_val_eer = val_eer
            save_ckpt(os.path.join(cfg.out_dir, "best.pt"), ckpt)

        if _HAS_PLOT:
            try:
                plot_curves(cfg.out_dir, history)
            except Exception as e:
                # Plotting is best-effort; never abort training over it.
                print(f"[WARN] plot_curves failed: {e}")

        # Write history every epoch so progress survives interruptions.
        with open(os.path.join(cfg.out_dir, "history.json"), "w", encoding="utf-8") as f:
            json.dump(history, f, ensure_ascii=False, indent=2)

    print(f"训练完成!输出目录:{cfg.out_dir}")


if __name__ == "__main__":
    main()