# train.py
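"""ECAPA-TDNN speaker-embedding training with an AAM-Softmax head.

Batches are drawn with a P x K speaker/utterance sampler over precomputed
fbank features stored as .pt tensors, and each epoch is validated with a
sampled-trial EER. Configuration comes from Hydra (configs/train.yaml).

Example launch (illustrative values only; the real keys and paths live in
the Hydra config):
    python train.py out_dir=exp/ecapa train_list=... val_list=... epochs=30
"""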
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import json
import random
from collections import defaultdict

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf
import hydra

from speaker_verification.models.ecapa import ECAPA_TDNN
from speaker_verification.head.aamsoftmax import AAMSoftmax
from dataset.pk_sampler import PKBatchSampler
from dataset.dataset import TrainFbankPtDataset, collate_fixed

from speaker_verification.checkpointing import ModelCfg, build_ckpt, save_ckpt
from utils.seed import set_seed
from utils.meters import AverageMeter, top1_accuracy, compute_eer
from utils.path_utils import _resolve_path

try:
    from utils.plot import plot_curves
    _HAS_PLOT = True
except Exception:
    _HAS_PLOT = False


# =========================
# Validation (verification EER) - sampled trial evaluation
# =========================
@torch.no_grad()
def validate_eer_sampled(
    model: torch.nn.Module,
    val_meta_path: str,
    device: torch.device,
    crop_frames: int = 400,
    num_crops: int = 6,
    max_spk: int = 120,
    per_spk: int = 3,
    num_pos: int = 3000,
    num_neg: int = 3000,
    seed: int = 1234,
) -> dict:
    """Sampled-trial EER validation; averages embeddings over multiple crops for stability."""
    model.eval()
    rng = random.Random(seed)

    # Read the validation meta file (one JSON object per line: {"spk": ..., "feat": ...})
    items = []
    base_dir = os.path.dirname(os.path.abspath(val_meta_path))
    with open(val_meta_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            j = json.loads(line)
            spk = str(j["spk"])
            p = _resolve_path(j["feat"], base_dir)
            items.append((spk, p))

    spk2paths = defaultdict(list)
    for spk, p in items:
        spk2paths[spk].append(p)
    # Sample speakers (keep only those with at least two utterances)
    spks = [s for s in spk2paths if len(spk2paths[s]) >= 2]
    rng.shuffle(spks)
    spks = spks[:max_spk]

    # Sample up to per_spk utterances per speaker
    sample_paths = []
    sample_spk = []
    for s in spks:
        ps = spk2paths[s][:]
        rng.shuffle(ps)
        ps = ps[:per_spk]
        for p in ps:
            sample_paths.append(p)
            sample_spk.append(s)
    # Extract embeddings, one utterance at a time, averaging over multiple random crops
    emb_cache = {}
    for p in sample_paths:
        feat = torch.load(p, map_location="cpu")  # [T, 80]
        if not torch.is_tensor(feat):
            feat = torch.tensor(feat)
        T = feat.size(0)

        embs = []
        if T <= crop_frames:
            x = feat.unsqueeze(0).to(device)
            e = model(x).squeeze(0).detach().cpu()
            embs.append(e)
        else:
            for _ in range(num_crops):
                s0 = rng.randint(0, T - crop_frames)
                chunk = feat[s0 : s0 + crop_frames]
                x = chunk.unsqueeze(0).to(device)
                e = model(x).squeeze(0).detach().cpu()
                embs.append(e)

        # Average crop embeddings and L2-normalize, so the dot product below is cosine similarity
        e = torch.stack(embs, 0).mean(0)
        e = e / (e.norm() + 1e-12)
        emb_cache[p] = e
    # Group the sampled feature paths by speaker
    spk2utts = defaultdict(list)
    for s, p in zip(sample_spk, sample_paths):
        spk2utts[s].append(p)

    spks_with2 = [s for s in spk2utts if len(spk2utts[s]) >= 2]
    all_spks = list(spk2utts.keys())

    if len(all_spks) < 2 or len(spks_with2) == 0:
        return {"eer": 1.0, "pos_mean": 0.0, "neg_mean": 0.0}
    # Sample positive (same-speaker) and negative (different-speaker) trial pairs
    labels, scores = [], []
    for _ in range(num_pos):
        s = rng.choice(spks_with2)
        p1, p2 = rng.sample(spk2utts[s], 2)
        sc = float((emb_cache[p1] * emb_cache[p2]).sum().item())  # cosine similarity
        labels.append(1)
        scores.append(sc)

    for _ in range(num_neg):
        s1, s2 = rng.sample(all_spks, 2)
        p1 = rng.choice(spk2utts[s1])
        p2 = rng.choice(spk2utts[s2])
        sc = float((emb_cache[p1] * emb_cache[p2]).sum().item())
        labels.append(0)
        scores.append(sc)
    eer, _ = compute_eer(labels, scores)

    pos = [s for s, l in zip(scores, labels) if l == 1]
    neg = [s for s, l in zip(scores, labels) if l == 0]

    return {
        "eer": eer,
        "pos_mean": float(np.mean(pos)),
        "neg_mean": float(np.mean(neg)),
    }


# =========================
# Train one epoch
# =========================
def train_one_epoch(model, head, loader, device, num_classes, optim, scaler, use_amp, params, grad_clip):
    model.train()
    head.train()

    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    pbar = tqdm(loader, desc="TRAIN", ncols=110)
    for x, y in pbar:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        if y.min().item() < 0 or y.max().item() >= num_classes:
            raise RuntimeError(f"[TRAIN] label out of range: min={y.min().item()}, max={y.max().item()}, C={num_classes}")

        optim.zero_grad(set_to_none=True)

        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            emb = model(x)
            if not torch.isfinite(emb).all():
                raise RuntimeError("[TRAIN] Non-finite embedding detected (NaN/Inf).")
            loss, logits = head(emb, y)

        if not torch.isfinite(loss).all() or not torch.isfinite(logits).all():
            raise RuntimeError("[TRAIN] Non-finite loss/logits detected (NaN/Inf).")

        if use_amp:
            # Unscale before clipping so the gradient norm is measured at true scale
            scaler.scale(loss).backward()
            scaler.unscale_(optim)
            torch.nn.utils.clip_grad_norm_(params, grad_clip)
            scaler.step(optim)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(params, grad_clip)
            optim.step()

        acc = top1_accuracy(logits, y)

        bs = y.size(0)
        loss_meter.update(float(loss.item()), bs)
        acc_meter.update(float(acc), bs)

        pbar.set_postfix(
            loss=f"{loss_meter.avg:.4f}",
            acc=f"{acc_meter.avg:.4f}",
            lr=f"{optim.param_groups[0]['lr']:.2e}",
        )

    return loss_meter.avg, acc_meter.avg


# =========================
# Main
# =========================
@hydra.main(version_base=None, config_path="configs", config_name="train")
def main(cfg: DictConfig):
    set_seed(cfg.get("seed", 1234))

    os.makedirs(cfg.out_dir, exist_ok=True)

    cfg_dict = OmegaConf.to_container(cfg, resolve=True)
    with open(os.path.join(cfg.out_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump(cfg_dict, f, ensure_ascii=False, indent=2)

    device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    train_ds = TrainFbankPtDataset(
        cfg.train_list,
        crop_frames=cfg.get("crop_frames", 200),
    )
    num_classes = train_ds.num_classes
    print(f"num_classes = {num_classes}")

    train_labels = [y for (y, _) in train_ds.items]

    # P speakers x K utterances per batch
    pk_sampler = PKBatchSampler(
        train_labels,
        P=cfg.get("p", 32),
        K=cfg.get("k", 4),
        drop_last=True,
        seed=cfg.get("seed", 1234),
    )

    train_loader = DataLoader(
        train_ds,
        batch_sampler=pk_sampler,
        num_workers=cfg.num_workers,
        collate_fn=collate_fixed,
        pin_memory=(device.type == "cuda"),
    )

    model = ECAPA_TDNN(
        in_channels=cfg.feat_dim,
        channels=cfg.channels,
        embd_dim=cfg.emb_dim,
    ).to(device)

    head = AAMSoftmax(
        cfg.emb_dim,
        num_classes,
        s=cfg.scale,
        m=cfg.margin,
    ).to(device)

    params = list(model.parameters()) + list(head.parameters())
    optim = torch.optim.AdamW(params, lr=cfg.lr, weight_decay=cfg.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=cfg.epochs)

    # AMP (CUDA only)
    use_amp = bool(cfg.get("amp", True)) and device.type == "cuda"
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    history = {"train_loss": [], "train_acc": [], "val_eer": [], "val_pos_mean": [], "val_neg_mean": []}
    best_val_eer = 1e9

    for epoch in range(1, cfg.epochs + 1):
        print(f"\n===== Epoch {epoch}/{cfg.epochs} =====")

        # Train
        train_loss, train_acc = train_one_epoch(
            model, head, train_loader, device, num_classes,
            optim, scaler, use_amp, params, cfg.grad_clip,
        )
        scheduler.step()

        torch.cuda.empty_cache()

        # Validation (EER)
        val_info = validate_eer_sampled(
            model=model,
            val_meta_path=cfg.val_list,
            device=device,
            crop_frames=cfg.get("crop_frames_val", 400),
            num_crops=cfg.get("num_crops", 6),
            max_spk=cfg.get("max_spk", 120),
            per_spk=cfg.get("per_spk", 3),
            num_pos=cfg.get("num_pos", 3000),
            num_neg=cfg.get("num_neg", 3000),
            seed=cfg.get("seed", 1234),
        )
        val_eer = val_info["eer"]

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_eer"].append(val_eer)
        history["val_pos_mean"].append(val_info["pos_mean"])
        history["val_neg_mean"].append(val_info["neg_mean"])

        print(
            f"[Epoch {epoch}] train_loss={train_loss:.4f}, train_acc={train_acc:.4f} | "
            f"val_EER={val_eer*100:.2f}% (pos={val_info['pos_mean']:.3f}, neg={val_info['neg_mean']:.3f})"
        )

        # NOTE: assumes the same flat config keys used elsewhere in this script;
        # adjust if your config nests these under model/feature/audio groups.
        model_cfg = ModelCfg(
            channels=cfg.channels,
            emb_dim=cfg.emb_dim,
            feat_dim=cfg.feat_dim,
            sample_rate=cfg.get("sample_rate", 16000),  # assumed default if unset
        )

        # Track the best validation EER across epochs
        is_best = val_eer < best_val_eer
        if is_best:
            best_val_eer = val_eer

        ckpt = build_ckpt(
            model=model,
            head=head,
            optim=optim,
            scheduler=scheduler,
            epoch=epoch,
            best_eer=best_val_eer,
            label_map=train_ds.label_map,
            model_cfg=model_cfg,
            extra={"cfg_text": OmegaConf.to_yaml(cfg)},
        )
        save_ckpt(os.path.join(cfg.out_dir, "last.pt"), ckpt)
        if is_best:
            save_ckpt(os.path.join(cfg.out_dir, "best.pt"), ckpt)

        if _HAS_PLOT:
            try:
                plot_curves(cfg.out_dir, history)
            except Exception as e:
                print(f"[WARN] plot_curves failed: {e}")

    with open(os.path.join(cfg.out_dir, "history.json"), "w", encoding="utf-8") as f:
        json.dump(history, f, ensure_ascii=False, indent=2)

    print(f"Training finished! Output directory: {cfg.out_dir}")


if __name__ == "__main__":
    main()