# finetune.py
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import json
import random
from collections import defaultdict

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

import hydra
from omegaconf import DictConfig, OmegaConf

from speaker_verification.models.ecapa import ECAPA_TDNN
from speaker_verification.head.aamsoftmax import AAMSoftmax
from speaker_verification.checkpointing import ModelCfg, build_ckpt, save_ckpt, load_ckpt

from dataset.pk_sampler import PKBatchSampler
from dataset.dataset import TrainFbankPtDataset, collate_fixed

from utils.seed import set_seed
from utils.meters import AverageMeter, top1_accuracy, compute_eer
from utils.path_utils import _resolve_path

try:
    from utils.plot import plot_curves
    _HAS_PLOT = True
except Exception:
    _HAS_PLOT = False


# =========================
# Validation (verification EER) - sampled verification, consistent with train.py
# =========================
@torch.no_grad()
def validate_eer_sampled(
    model: torch.nn.Module,
    val_meta_path: str,
    device: torch.device,
    crop_frames: int = 400,
    num_crops: int = 6,
    max_spk: int = 120,
    per_spk: int = 3,
    num_pos: int = 3000,
    num_neg: int = 3000,
    seed: int = 1234,
) -> dict:
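    """Estimate verification EER on a sampled subset of the validation set.

    Samples up to `max_spk` speakers with `per_spk` utterances each, builds a
    multi-crop averaged embedding per utterance, then scores `num_pos` positive
    and `num_neg` negative cosine trials with a fixed RNG seed.
    """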
    model.eval()
    rng = random.Random(seed)

    items = []
    base_dir = os.path.dirname(os.path.abspath(val_meta_path))
    with open(val_meta_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            j = json.loads(line)
            spk = str(j["spk"])
            p = _resolve_path(j["feat"], base_dir)
            items.append((spk, p))

    spk2paths = defaultdict(list)
    for spk, p in items:
        spk2paths[spk].append(p)

    spks = [s for s in spk2paths if len(spk2paths[s]) >= 2]
    rng.shuffle(spks)
    spks = spks[:max_spk]

    sample_paths, sample_spk = [], []
    for s in spks:
        ps = spk2paths[s][:]
        rng.shuffle(ps)
        ps = ps[:per_spk]
        for p in ps:
            sample_paths.append(p)
            sample_spk.append(s)

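    # Embed each sampled utterance: average embeddings over random fixed-length
    # crops (single pass when the utterance is shorter than crop_frames), then
    # L2-normalize so the dot product used below equals cosine similarity.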
    emb_cache = {}
    for p in sample_paths:
        feat = torch.load(p, map_location="cpu")  # [T,80]
        if not torch.is_tensor(feat):
            feat = torch.tensor(feat)

        T = feat.size(0)
        embs = []
        if T <= crop_frames:
            x = feat.unsqueeze(0).to(device)
            e = model(x).squeeze(0).detach().cpu()
            embs.append(e)
        else:
            for _ in range(num_crops):
                s0 = rng.randint(0, T - crop_frames)
                chunk = feat[s0 : s0 + crop_frames]
                x = chunk.unsqueeze(0).to(device)
                e = model(x).squeeze(0).detach().cpu()
                embs.append(e)

        e = torch.stack(embs, 0).mean(0)
        e = e / (e.norm() + 1e-12)
        emb_cache[p] = e

    spk2idx = defaultdict(list)
    for s, p in zip(sample_spk, sample_paths):
        spk2idx[s].append(p)

    spks_with2 = [s for s in spk2idx if len(spk2idx[s]) >= 2]
    all_spks = list(spk2idx.keys())
    if len(all_spks) < 2 or len(spks_with2) == 0:
        return {"eer": 1.0, "pos_mean": 0.0, "neg_mean": 0.0}

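    # Build scoring trials: positive pairs share a speaker, negative pairs come
    # from two different speakers; scores are cosine similarities.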
    labels, scores = [], []
    for _ in range(num_pos):
        s = rng.choice(spks_with2)
        p1, p2 = rng.sample(spk2idx[s], 2)
        sc = float((emb_cache[p1] * emb_cache[p2]).sum().item())
        labels.append(1)
        scores.append(sc)

    for _ in range(num_neg):
        s1, s2 = rng.sample(all_spks, 2)
        p1 = rng.choice(spk2idx[s1])
        p2 = rng.choice(spk2idx[s2])
        sc = float((emb_cache[p1] * emb_cache[p2]).sum().item())
        labels.append(0)
        scores.append(sc)

    eer, _ = compute_eer(labels, scores)
    pos = [s for s, l in zip(scores, labels) if l == 1]
    neg = [s for s, l in zip(scores, labels) if l == 0]
    return {
        "eer": eer,
        "pos_mean": float(np.mean(pos)),
        "neg_mean": float(np.mean(neg)),
    }


def set_requires_grad(module: torch.nn.Module, flag: bool) -> None:
    for p in module.parameters():
        p.requires_grad = flag


def build_optimizer(cfg: DictConfig, model: torch.nn.Module, head: torch.nn.Module, phase: str):
    """
    phase:
      - "head_only": train the head only (backbone frozen)
      - "full": train backbone + head (with separate lr multipliers)
    """
    lr = float(cfg.lr)
    wd = float(cfg.weight_decay)

    fin = cfg.get("finetune", {})
    backbone_mult = float(fin.get("backbone_lr_mult", 0.5))
    head_mult = float(fin.get("head_lr_mult", 1.0))

    if phase == "head_only":
        params = [{"params": head.parameters(), "lr": lr * head_mult}]
    else:
        params = [
            {"params": model.parameters(), "lr": lr * backbone_mult},
            {"params": head.parameters(), "lr": lr * head_mult},
        ]

    optim = torch.optim.AdamW(params, lr=lr, weight_decay=wd)
    return optim


def train_one_epoch(model, head, loader, device, num_classes, optim, scaler, use_amp, grad_clip):
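    """Run one training epoch over `loader`; returns (avg_loss, avg_top1_acc)."""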
    model.train()
    head.train()

    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    pbar = tqdm(loader, desc="TRAIN", ncols=110)

    for x, y in pbar:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        if y.min().item() < 0 or y.max().item() >= num_classes:
            raise RuntimeError(f"[TRAIN] label out of range: min={y.min().item()}, max={y.max().item()}, C={num_classes}")

        optim.zero_grad(set_to_none=True)

        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            emb = model(x)
            if not torch.isfinite(emb).all():
                raise RuntimeError("[TRAIN] Non-finite embedding detected (NaN/Inf).")

            loss, logits = head(emb, y)
            if not torch.isfinite(loss).all() or not torch.isfinite(logits).all():
                raise RuntimeError("[TRAIN] Non-finite loss/logits detected (NaN/Inf).")

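        # With AMP, unscale gradients before clipping so the clip threshold
        # applies to the true gradient magnitudes rather than the scaled ones.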
        if use_amp:
            scaler.scale(loss).backward()
            scaler.unscale_(optim)
            torch.nn.utils.clip_grad_norm_(list(model.parameters()) + list(head.parameters()), grad_clip)
            scaler.step(optim)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(list(model.parameters()) + list(head.parameters()), grad_clip)
            optim.step()

        acc = top1_accuracy(logits, y)
        bs = y.size(0)
        loss_meter.update(float(loss.item()), bs)
        acc_meter.update(float(acc), bs)

        lr0 = optim.param_groups[0]["lr"]
        pbar.set_postfix(loss=f"{loss_meter.avg:.4f}", acc=f"{acc_meter.avg:.4f}", lr=f"{lr0:.2e}")

    return loss_meter.avg, acc_meter.avg


def load_backbone_weights(cfg: DictConfig, model: torch.nn.Module, device: torch.device):
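    """Load pretrained backbone weights from either a full checkpoint
    (finetune.pretrained_full_ckpt) or a model-only state_dict
    (finetune.pretrained_model_only). Returns the full checkpoint dict if one
    was loaded, otherwise None."""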
    fin = cfg.get("finetune", {})
    model_only = fin.get("pretrained_model_only", "")
    full_ckpt = fin.get("pretrained_full_ckpt", "")

    if (not model_only) and (not full_ckpt):
        print("[FINETUNE] No pretrained weights provided. Training from scratch.")
        return None

    if model_only and full_ckpt:
        print("[FINETUNE] Both pretrained_model_only and pretrained_full_ckpt are set. "
              "Prefer pretrained_full_ckpt (has model_cfg).")
        model_only = ""

    if full_ckpt:
        ckpt = load_ckpt(full_ckpt, map_location=str(device))
        model.load_state_dict(ckpt["model_state"], strict=True)
        print(f"[FINETUNE] Loaded backbone from full ckpt: {full_ckpt}")
        return ckpt
    else:
        sd = torch.load(model_only, map_location=str(device))
        model.load_state_dict(sd, strict=True)
        print(f"[FINETUNE] Loaded backbone from model-only state_dict: {model_only}")
        return None


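# Illustrative sketch of configs/finetune.yaml: the keys mirror the cfg
# accesses in main() below; values without an explicit default in this file
# (paths, model sizes, margin/scale, lr, epochs) are placeholders only.
#
#   seed: 1234
#   device: cuda
#   out_dir: exp/finetune
#   train_list: data/finetune_train.jsonl
#   val_list: data/finetune_val.jsonl
#   num_workers: 4
#   crop_frames: 200
#   p: 32
#   k: 4
#   feat_dim: 80
#   channels: 1024
#   emb_dim: 192
#   scale: 30.0
#   margin: 0.2
#   lr: 1.0e-4
#   weight_decay: 1.0e-5
#   epochs: 20
#   grad_clip: 5.0
#   amp: true
#   finetune:
#     pretrained_full_ckpt: exp/pretrain/best.pt
#     pretrained_model_only: ""
#     freeze_backbone_epochs: 3
#     backbone_lr_mult: 0.5
#     head_lr_mult: 1.0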
@hydra.main(version_base=None, config_path="configs", config_name="finetune")
def main(cfg: DictConfig):
    set_seed(cfg.seed if hasattr(cfg, "seed") else 1234)

    os.makedirs(cfg.out_dir, exist_ok=True)
    cfg_dict = OmegaConf.to_container(cfg, resolve=True)
    with open(os.path.join(cfg.out_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump(cfg_dict, f, ensure_ascii=False, indent=2)

    device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # --------
    # Dataset
    # --------
    train_ds = TrainFbankPtDataset(cfg.train_list, crop_frames=cfg.get("crop_frames", 200))
    num_classes = train_ds.num_classes
    print(f"num_classes (NEW DATASET) = {num_classes}")

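    # P x K batch sampling: each batch draws P speakers with K utterances each,
    # which keeps every batch multi-speaker for the AAM-softmax objective.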
    train_labels = [y for (y, _) in train_ds.items]
    pk_sampler = PKBatchSampler(
        train_labels,
        P=cfg.get("p", 32),
        K=cfg.get("k", 4),
        drop_last=True,
        seed=cfg.get("seed", 1234),
    )
    train_loader = DataLoader(
        train_ds,
        batch_sampler=pk_sampler,
        num_workers=cfg.num_workers,
        collate_fn=collate_fixed,
        pin_memory=(device.type == "cuda"),
    )

    # --------
    # Model
    # --------
    model = ECAPA_TDNN(
        in_channels=int(cfg.feat_dim),
        channels=int(cfg.channels),
        embd_dim=int(cfg.emb_dim),
    ).to(device)

    ckpt = load_backbone_weights(cfg, model, device)

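    # AAM-softmax head: given embeddings and labels it returns (loss, logits);
    # the logits are used for top-1 accuracy monitoring during training.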
    head = AAMSoftmax(
        int(cfg.emb_dim),
        int(num_classes),
        s=float(cfg.scale),
        m=float(cfg.margin),
    ).to(device)

    # --------
    # Freeze schedule
    # --------
    fin = cfg.get("finetune", {})
    freeze_epochs = int(fin.get("freeze_backbone_epochs", 3))
    if freeze_epochs > 0:
        set_requires_grad(model, False)
        set_requires_grad(head, True)
        print(f"[FINETUNE] Freeze backbone for {freeze_epochs} epochs (train head only).")
        optim = build_optimizer(cfg, model, head, phase="head_only")
    else:
        set_requires_grad(model, True)
        set_requires_grad(head, True)
        optim = build_optimizer(cfg, model, head, phase="full")

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=int(cfg.epochs))

    use_amp = bool(cfg.get("amp", True)) and device.type == "cuda"
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    history = {"train_loss": [], "train_acc": [], "val_eer": [], "val_pos_mean": [], "val_neg_mean": []}
    best_val_eer = 1e9

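    # Main loop: optionally warm up with a frozen backbone, then fine-tune
    # everything; validate with the sampled EER after every epoch and keep the
    # checkpoint with the lowest EER as best.pt.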
    for epoch in range(1, int(cfg.epochs) + 1):
        print(f"\n===== Epoch {epoch}/{cfg.epochs} =====")

        if freeze_epochs > 0 and epoch == freeze_epochs + 1:
            print("[FINETUNE] Unfreeze backbone and switch to full training.")
            set_requires_grad(model, True)
            # Rebuild the optimizer and scheduler (cleaner than mutating param groups)
            optim = build_optimizer(cfg, model, head, phase="full")
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=int(cfg.epochs) - epoch + 1)
            scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

        train_loss, train_acc = train_one_epoch(
            model=model,
            head=head,
            loader=train_loader,
            device=device,
            num_classes=num_classes,
            optim=optim,
            scaler=scaler,
            use_amp=use_amp,
            grad_clip=float(cfg.grad_clip),
        )

        scheduler.step()
        torch.cuda.empty_cache()

        val_info = validate_eer_sampled(
            model=model,
            val_meta_path=cfg.val_list,
            device=device,
            crop_frames=int(cfg.get("crop_frames_val", 400)),
            num_crops=int(cfg.get("num_crops", 6)),
            max_spk=int(cfg.get("max_spk", 120)),
            per_spk=int(cfg.get("per_spk", 3)),
            num_pos=int(cfg.get("num_pos", 3000)),
            num_neg=int(cfg.get("num_neg", 3000)),
            seed=int(cfg.get("seed", 1234)),
        )
        val_eer = float(val_info["eer"])

        history["train_loss"].append(float(train_loss))
        history["train_acc"].append(float(train_acc))
        history["val_eer"].append(val_eer)
        history["val_pos_mean"].append(float(val_info["pos_mean"]))
        history["val_neg_mean"].append(float(val_info["neg_mean"]))

        print(
            f"[Epoch {epoch}] train_loss={train_loss:.4f}, train_acc={train_acc:.4f} | "
            f"val_EER={val_eer*100:.2f}% (pos={val_info['pos_mean']:.3f}, neg={val_info['neg_mean']:.3f})"
        )

        model_cfg = ModelCfg(
            channels=int(cfg.channels),
            emb_dim=int(cfg.emb_dim),
            feat_dim=int(cfg.feat_dim),
            sample_rate=int(cfg.get("sample_rate", 16000)),
        )
        ckpt_out = build_ckpt(
            model=model,
            head=head,
            optim=optim,
            scheduler=scheduler,
            epoch=epoch,
            best_eer=val_eer,
            label_map=train_ds.label_map,
            model_cfg=model_cfg,
            extra={"cfg_text": OmegaConf.to_yaml(cfg)},
        )
        save_ckpt(os.path.join(cfg.out_dir, "last.pt"), ckpt_out)

        if val_eer < best_val_eer:
            best_val_eer = val_eer
            save_ckpt(os.path.join(cfg.out_dir, "best.pt"), ckpt_out)
            print(f"[FINETUNE] New best EER: {best_val_eer*100:.2f}% -> saved best.pt")

        if _HAS_PLOT:
            try:
                plot_curves(cfg.out_dir, history)
            except Exception as e:
                print(f"[WARN] plot_curves failed: {e}")

        with open(os.path.join(cfg.out_dir, "history.json"), "w", encoding="utf-8") as f:
            json.dump(history, f, ensure_ascii=False, indent=2)

    print(f"Finetune done! out_dir = {cfg.out_dir}")


if __name__ == "__main__":
    main()