# verify.py
  1  """
  2  Speaker Verification Evaluation Script
  3  使用方式:
  4      python verify.py --val_meta processed/cn_celeb2/val_meta.jsonl --ckpt outputs/best.pt
  5  """
  6  
  7  import os
  8  import json
  9  import random
 10  import argparse
 11  from collections import defaultdict
 12  
 13  import torch
 14  import numpy as np
 15  import matplotlib.pyplot as plt
 16  from tqdm import tqdm
 17  
 18  from utils.meters import compute_eer, roc_points, det_points, recall_at_k, _l2norm
 19  from utils.path_utils import _resolve_path
 20  from speaker_verification.models.ecapa import ECAPA_TDNN
 21  
# scikit-learn is an optional dependency: it is only needed for the t-SNE
# visualization. When it is missing, evaluation still runs and the t-SNE
# plot (plus Recall@K on the sampled embeddings) is simply skipped.
try:
    from sklearn.manifold import TSNE
    _HAS_SKLEARN = True
except Exception:
    _HAS_SKLEARN = False
 27  
 28  
 29  def read_meta_jsonl(meta_path: str):
 30      meta_path = os.path.abspath(meta_path)
 31      base_dir = os.path.dirname(meta_path)
 32      items = []
 33  
 34      with open(meta_path, "r", encoding="utf-8") as f:
 35          for line in f:
 36              line = line.strip()
 37              if not line:
 38                  continue
 39              j = json.loads(line)
 40              spk = str(j["spk"])
 41              feat = _resolve_path(j["feat"], base_dir)
 42              items.append((spk, feat))
 43      return items
 44  
 45  
 46  # =========================
 47  # Pair building
 48  # =========================
 49  def build_pairs(items, num_pos=3000, num_neg=3000, seed=1234):
 50      random.seed(seed)
 51      spk2paths = defaultdict(list)
 52      for spk, p in items:
 53          spk2paths[spk].append(p)
 54  
 55      spks_with2 = [s for s in spk2paths if len(spk2paths[s]) >= 2]
 56      all_spks = list(spk2paths.keys())
 57  
 58      if len(spks_with2) == 0:
 59          raise RuntimeError("Not enough speakers to generate positive pairs.")
 60  
 61      pairs = []
 62  
 63      # Positive pairs
 64      for _ in range(num_pos):
 65          spk = random.choice(spks_with2)
 66          p1, p2 = random.sample(spk2paths[spk], 2)
 67          pairs.append((1, p1, p2))
 68  
 69      # Negative pairs
 70      for _ in range(num_neg):
 71          s1, s2 = random.sample(all_spks, 2)
 72          p1 = random.choice(spk2paths[s1])
 73          p2 = random.choice(spk2paths[s2])
 74          pairs.append((0, p1, p2))
 75  
 76      random.shuffle(pairs)
 77      return pairs
 78  
 79  
 80  # =========================
 81  # Embedding Extraction
 82  # =========================
 83  @torch.no_grad()
 84  def load_feat_pt(feat_path: str):
 85      if not os.path.exists(feat_path):
 86          return None
 87      try:
 88          feat = torch.load(feat_path, map_location="cpu", weights_only=True)
 89      except:
 90          feat = torch.load(feat_path, map_location="cpu")
 91      if not torch.is_tensor(feat) or feat.dim() != 2:
 92          return None
 93      return feat
 94  
 95  
 96  @torch.no_grad()
 97  def embed_from_feat(model, feat: torch.Tensor, device, crop_frames=400, num_crops=6, seed=1234):
 98      rng = random.Random(seed)
 99      T = feat.size(0)
100  
101      if T <= crop_frames:
102          x = feat.unsqueeze(0).to(device)
103          emb = model(x).squeeze(0).cpu()
104          return _l2norm(emb)
105  
106      embs = []
107      for _ in range(num_crops):
108          s = rng.randint(0, T - crop_frames)
109          chunk = feat[s:s + crop_frames]
110          x = chunk.unsqueeze(0).to(device)
111          embs.append(model(x).squeeze(0).cpu())
112  
113      emb = torch.stack(embs, 0).mean(0)
114      return _l2norm(emb)
115  
116  
@torch.no_grad()
def embed_from_fbank_pt(model, feat_path: str, device, crop_frames=400, num_crops=6):
    """Load fbank features from disk and embed them; None if unreadable."""
    feat = load_feat_pt(feat_path)
    return None if feat is None else embed_from_feat(model, feat, device, crop_frames, num_crops)
123  
124  
def cosine_sim(a: torch.Tensor, b: torch.Tensor) -> float:
    """Dot product of two embeddings (cosine similarity, since both are L2-normalized)."""
    return a.mul(b).sum().item()
127  
128  
129  # =========================
130  # t-SNE + Recall@K
131  # =========================
@torch.no_grad()
def collect_embeddings_for_tsne(model, items, device, max_spk=20, per_spk=25,
                                crop_frames=400, num_crops=6, seed=1234):
    """Sample up to `max_spk` speakers (`per_spk` utterances each) and embed them.

    Only speakers with at least two utterances are considered. Seeds the
    *global* random module, consistent with the rest of the script.

    Returns:
        (X, y): stacked embedding matrix and int64 speaker-id labels, or
        (None, None) when no embedding could be produced.
    """
    random.seed(seed)
    by_spk = defaultdict(list)
    for spk, path in items:
        by_spk[spk].append(path)

    candidates = [s for s in by_spk if len(by_spk[s]) >= 2]
    random.shuffle(candidates)

    feats, labels = [], []
    for spk in candidates[:max_spk]:
        chosen = random.sample(by_spk[spk], min(per_spk, len(by_spk[spk])))
        for path in chosen:
            emb = embed_from_fbank_pt(model, path, device, crop_frames, num_crops)
            if emb is not None:
                feats.append(emb.numpy())
                labels.append(spk)

    if not feats:
        return None, None

    # Map speaker names to contiguous integer ids for plotting / Recall@K.
    spk2id = {s: i for i, s in enumerate(sorted(set(labels)))}
    y = np.array([spk2id[s] for s in labels], dtype=np.int64)

    return np.stack(feats), y
161  
162  
163  # =========================
164  # Main Function
165  # =========================
def main(args):
    """Evaluate a trained ECAPA-TDNN speaker-verification checkpoint.

    Pipeline: read the validation meta, sample trial pairs, embed each
    utterance once (cached by path), score pairs with cosine similarity,
    compute EER, and save ROC/DET/histogram plots plus metrics.txt into
    ``args.out_dir``. When scikit-learn is available, additionally saves a
    t-SNE plot and appends Recall@K on a sampled subset.
    """
    os.makedirs(args.out_dir, exist_ok=True)
    random.seed(args.seed)
    torch.manual_seed(args.seed)

    print("=" * 70)
    print("Speaker Verification Evaluation")
    print("=" * 70)
    print(f"VAL_META   : {args.val_meta}")
    print(f"CHECKPOINT : {args.ckpt}")
    print(f"OUTPUT DIR : {args.out_dir}")
    print(f"Crop Frames: {args.crop_frames} | Num Crops: {args.num_crops}")
    print(f"Pairs      : {args.num_pos} pos + {args.num_neg} neg")
    print(f"Device     : {args.device}")
    print("=" * 70)

    # Transparently fall back to CPU when CUDA was requested but is absent.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    # 1. Load meta
    items = read_meta_jsonl(args.val_meta)
    print(f"Loaded {len(items)} utterances from {len(set(spk for spk, _ in items))} speakers\n")

    # 2. Build pairs
    pairs = build_pairs(items, num_pos=args.num_pos, num_neg=args.num_neg, seed=args.seed)
    print(f"Generated {len(pairs)} pairs\n")

    # 3. Load model (in_channels=80 must match the fbank dimension used in training)
    ckpt = torch.load(args.ckpt, map_location="cpu")
    model = ECAPA_TDNN(
        in_channels=80,
        channels=args.channels,
        embd_dim=args.emb_dim
    ).to(device)

    model.load_state_dict(ckpt["model"], strict=True)
    model.eval()
    print(f"Model loaded: ECAPA-TDNN (channels={args.channels}, emb_dim={args.emb_dim})\n")

    # 4. Scoring — each utterance is embedded at most once via the cache.
    emb_cache = {}
    labels, scores = [], []
    missing = 0

    for is_same, p1, p2 in tqdm(pairs, desc="Scoring"):
        for p in (p1, p2):
            if p not in emb_cache:
                emb_cache[p] = embed_from_fbank_pt(model, p, device, args.crop_frames, args.num_crops)

        e1 = emb_cache[p1]
        e2 = emb_cache[p2]

        if e1 is None or e2 is None:
            # Unreadable / malformed feature file: skip the pair, keep count.
            missing += 1
            continue

        scores.append(cosine_sim(e1, e2))
        labels.append(is_same)

    print(f"\nScoring completed! Used pairs: {len(scores)}, Skipped: {missing}")

    if len(scores) == 0:
        print("[ERROR] No valid pairs! Check your feature paths.")
        return

    # 5. EER & Metrics
    eer, th = compute_eer(labels, scores)
    print(f"\n>>> EER = {eer*100:.3f}%   (threshold ≈ {th:.4f})")

    pos = [s for s, l in zip(scores, labels) if l == 1]
    neg = [s for s, l in zip(scores, labels) if l == 0]
    print(f"Pos mean: {np.mean(pos):.4f} | Neg mean: {np.mean(neg):.4f}")

    # Persist scalar metrics unconditionally (bug fix: previously metrics.txt
    # was only written when scikit-learn was installed AND t-SNE sampling
    # succeeded, so the EER could be silently lost).
    metrics_path = os.path.join(args.out_dir, "metrics.txt")
    with open(metrics_path, "w") as f:
        f.write(f"EER: {eer*100:.3f}%\n")
        f.write(f"Threshold: {th:.4f}\n")

    # 6. Save Plots
    fpr, tpr = roc_points(labels, scores, num_th=200)
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr)
    plt.title(f"ROC Curve (EER = {eer*100:.2f}%)")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.grid(True)
    plt.savefig(os.path.join(args.out_dir, "roc.png"), dpi=300)
    plt.close()

    fars, frrs = det_points(labels, scores, num_th=400)
    plt.figure(figsize=(8, 6))
    plt.plot(fars, frrs)
    plt.title(f"DET Curve (EER = {eer*100:.2f}%)")
    plt.xlabel("False Acceptance Rate")
    plt.ylabel("False Rejection Rate")
    plt.grid(True)
    plt.savefig(os.path.join(args.out_dir, "det.png"), dpi=300)
    plt.close()

    plt.figure(figsize=(9, 6))
    plt.hist(pos, bins=80, alpha=0.7, label="Same Speaker")
    plt.hist(neg, bins=80, alpha=0.7, label="Different Speaker")
    plt.title("Score Distribution")
    plt.xlabel("Cosine Similarity")
    plt.ylabel("Count")
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(args.out_dir, "score_hist.png"), dpi=300)
    plt.close()

    # 7. t-SNE + Recall@K (optional: requires scikit-learn)
    X, y_tsne = collect_embeddings_for_tsne(
        model, items, device,
        max_spk=20, per_spk=25,
        crop_frames=args.crop_frames,
        num_crops=args.num_crops,
        seed=args.seed
    )

    if X is not None and _HAS_SKLEARN:
        tsne = TSNE(n_components=2, perplexity=30, random_state=args.seed, init="pca")
        Z = tsne.fit_transform(X)

        plt.figure(figsize=(10, 8))
        for spk_id in np.unique(y_tsne):
            mask = (y_tsne == spk_id)
            plt.scatter(Z[mask, 0], Z[mask, 1], s=12, alpha=0.8)
        plt.title("t-SNE Visualization of Speaker Embeddings")
        plt.grid(True)
        plt.savefig(os.path.join(args.out_dir, "tsne.png"), dpi=300)
        plt.close()

        # Recall@K on re-normalized sampled embeddings.
        emb_t = torch.from_numpy(X).float()
        emb_t = emb_t / (emb_t.norm(dim=1, keepdim=True) + 1e-12)
        lab_t = torch.from_numpy(y_tsne).long()
        recall = recall_at_k(emb_t, lab_t, ks=(1, 5, 10))

        print("\nRecall@K (sampled):")
        for k, v in recall.items():
            print(f"  R@{k}: {v*100:.2f}%")

        # Append to the metrics file written above.
        with open(metrics_path, "a") as f:
            for k, v in recall.items():
                f.write(f"Recall@{k}: {v*100:.2f}%\n")

    print(f"\n所有结果已保存至: {args.out_dir}")
    print("文件: roc.png, det.png, score_hist.png, tsne.png, metrics.txt")
311  
312  
313  if __name__ == "__main__":
314      parser = argparse.ArgumentParser(description="Speaker Verification Evaluation Tool")
315  
316      parser.add_argument("--val_meta", type=str, required=True,
317                          help="Path to validation meta.jsonl file")
318  
319      parser.add_argument("--ckpt", type=str, required=True,
320                          help="Path to model checkpoint (.pt)")
321  
322      parser.add_argument("--out_dir", type=str, default="outputs_eval",
323                          help="Directory to save evaluation results (default: outputs_eval)")
324  
325      parser.add_argument("--crop_frames", type=int, default=400,
326                          help="Number of frames per crop (default: 400 ≈ 4秒)")
327  
328      parser.add_argument("--num_crops", type=int, default=6,
329                          help="Number of crops to average (default: 6)")
330  
331      parser.add_argument("--num_pos", type=int, default=3000,
332                          help="Number of positive pairs")
333  
334      parser.add_argument("--num_neg", type=int, default=3000,
335                          help="Number of negative pairs")
336  
337      parser.add_argument("--emb_dim", type=int, default=256,
338                          help="Embedding dimension (must match checkpoint)")
339  
340      parser.add_argument("--channels", type=int, default=512,
341                          help="ECAPA channels (default: 512)")
342  
343      parser.add_argument("--seed", type=int, default=1234,
344                          help="Random seed")
345  
346      parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"],
347                          help="Device to use")
348  
349      args = parser.parse_args()
350  
351      main(args)