# verify.py
1 """ 2 Speaker Verification Evaluation Script 3 使用方式: 4 python verify.py --val_meta processed/cn_celeb2/val_meta.jsonl --ckpt outputs/best.pt 5 """ 6 7 import os 8 import json 9 import random 10 import argparse 11 from collections import defaultdict 12 13 import torch 14 import numpy as np 15 import matplotlib.pyplot as plt 16 from tqdm import tqdm 17 18 from utils.meters import compute_eer, roc_points, det_points, recall_at_k, _l2norm 19 from utils.path_utils import _resolve_path 20 from speaker_verification.models.ecapa import ECAPA_TDNN 21 22 try: 23 from sklearn.manifold import TSNE 24 _HAS_SKLEARN = True 25 except Exception: 26 _HAS_SKLEARN = False 27 28 29 def read_meta_jsonl(meta_path: str): 30 meta_path = os.path.abspath(meta_path) 31 base_dir = os.path.dirname(meta_path) 32 items = [] 33 34 with open(meta_path, "r", encoding="utf-8") as f: 35 for line in f: 36 line = line.strip() 37 if not line: 38 continue 39 j = json.loads(line) 40 spk = str(j["spk"]) 41 feat = _resolve_path(j["feat"], base_dir) 42 items.append((spk, feat)) 43 return items 44 45 46 # ========================= 47 # Pair building 48 # ========================= 49 def build_pairs(items, num_pos=3000, num_neg=3000, seed=1234): 50 random.seed(seed) 51 spk2paths = defaultdict(list) 52 for spk, p in items: 53 spk2paths[spk].append(p) 54 55 spks_with2 = [s for s in spk2paths if len(spk2paths[s]) >= 2] 56 all_spks = list(spk2paths.keys()) 57 58 if len(spks_with2) == 0: 59 raise RuntimeError("Not enough speakers to generate positive pairs.") 60 61 pairs = [] 62 63 # Positive pairs 64 for _ in range(num_pos): 65 spk = random.choice(spks_with2) 66 p1, p2 = random.sample(spk2paths[spk], 2) 67 pairs.append((1, p1, p2)) 68 69 # Negative pairs 70 for _ in range(num_neg): 71 s1, s2 = random.sample(all_spks, 2) 72 p1 = random.choice(spk2paths[s1]) 73 p2 = random.choice(spk2paths[s2]) 74 pairs.append((0, p1, p2)) 75 76 random.shuffle(pairs) 77 return pairs 78 79 80 # ========================= 81 # 
# =========================
# Embedding Extraction
# =========================
@torch.no_grad()
def load_feat_pt(feat_path: str):
    """Load a precomputed 2-D feature tensor (frames x dims) from a .pt file.

    Returns None when the file is missing or does not contain a 2-D tensor.
    """
    if not os.path.exists(feat_path):
        return None
    try:
        # Prefer the safe loader; weights_only is unsupported on older torch
        # and may reject non-tensor payloads.
        feat = torch.load(feat_path, map_location="cpu", weights_only=True)
    except Exception:
        # FIX: was a bare `except:`, which also swallowed KeyboardInterrupt
        # and SystemExit. Fall back to the legacy (unrestricted) loader.
        feat = torch.load(feat_path, map_location="cpu")
    if not torch.is_tensor(feat) or feat.dim() != 2:
        return None
    return feat


@torch.no_grad()
def embed_from_feat(model, feat: torch.Tensor, device, crop_frames=400, num_crops=6, seed=1234):
    """Compute an L2-normalized embedding for one utterance.

    Utterances with <= crop_frames frames are embedded whole; longer ones are
    embedded as the mean over num_crops random fixed-length crops.
    """
    rng = random.Random(seed)  # local RNG: crop positions are reproducible
    T = feat.size(0)

    if T <= crop_frames:
        x = feat.unsqueeze(0).to(device)
        emb = model(x).squeeze(0).cpu()
        return _l2norm(emb)

    embs = []
    for _ in range(num_crops):
        s = rng.randint(0, T - crop_frames)
        chunk = feat[s:s + crop_frames]
        x = chunk.unsqueeze(0).to(device)
        embs.append(model(x).squeeze(0).cpu())

    emb = torch.stack(embs, 0).mean(0)
    return _l2norm(emb)


@torch.no_grad()
def embed_from_fbank_pt(model, feat_path: str, device, crop_frames=400, num_crops=6):
    """Load features from feat_path and embed them; None if loading fails."""
    feat = load_feat_pt(feat_path)
    if feat is None:
        return None
    return embed_from_feat(model, feat, device, crop_frames, num_crops)


def cosine_sim(a: torch.Tensor, b: torch.Tensor) -> float:
    """Dot product of two embeddings (cosine similarity for L2-normed inputs)."""
    return float((a * b).sum().item())


# =========================
# t-SNE + Recall@K
# =========================
@torch.no_grad()
def collect_embeddings_for_tsne(model, items, device, max_spk=20, per_spk=25,
                                crop_frames=400, num_crops=6, seed=1234):
    """Sample up to per_spk utterances from up to max_spk speakers and embed them.

    Returns:
        (X, y): X is an (N, emb_dim) float array of embeddings, y an (N,)
        int64 array of contiguous speaker ids; (None, None) if nothing usable.
    """
    random.seed(seed)
    spk2paths = defaultdict(list)
    for spk, p in items:
        spk2paths[spk].append(p)

    # Only speakers with >= 2 utterances are meaningful for clustering plots.
    spks = [s for s in spk2paths if len(spk2paths[s]) >= 2]
    random.shuffle(spks)
    spks = spks[:max_spk]

    X_list, y_list = [], []
    for spk in spks:
        paths = random.sample(spk2paths[spk], min(per_spk, len(spk2paths[spk])))
        for p in paths:
            emb = embed_from_fbank_pt(model, p, device, crop_frames, num_crops)
            if emb is not None:
                X_list.append(emb.numpy())
                y_list.append(spk)

    if len(X_list) == 0:
        return None, None

    # Remap speaker labels to dense integer ids for plotting/Recall@K.
    uniq = sorted(set(y_list))
    spk2id = {s: i for i, s in enumerate(uniq)}
    y = np.array([spk2id[s] for s in y_list], dtype=np.int64)

    return np.stack(X_list), y


# =========================
# Main Function
# =========================
def main(args):
    """Run the full evaluation: pair scoring, EER, plots, t-SNE, Recall@K."""
    os.makedirs(args.out_dir, exist_ok=True)
    random.seed(args.seed)
    torch.manual_seed(args.seed)

    print("=" * 70)
    print("Speaker Verification Evaluation")
    print("=" * 70)
    print(f"VAL_META   : {args.val_meta}")
    print(f"CHECKPOINT : {args.ckpt}")
    print(f"OUTPUT DIR : {args.out_dir}")
    print(f"Crop Frames: {args.crop_frames} | Num Crops: {args.num_crops}")
    print(f"Pairs      : {args.num_pos} pos + {args.num_neg} neg")
    print(f"Device     : {args.device}")
    print("=" * 70)

    # Fall back to CPU when CUDA was requested but is unavailable.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    # 1. Load meta
    items = read_meta_jsonl(args.val_meta)
    print(f"Loaded {len(items)} utterances from {len(set(spk for spk, _ in items))} speakers\n")

    # 2. Build pairs
    pairs = build_pairs(items, num_pos=args.num_pos, num_neg=args.num_neg, seed=args.seed)
    print(f"Generated {len(pairs)} pairs\n")

    # 3. Load model
    ckpt = torch.load(args.ckpt, map_location="cpu")
    model = ECAPA_TDNN(
        in_channels=80,
        channels=args.channels,
        embd_dim=args.emb_dim
    ).to(device)

    model.load_state_dict(ckpt["model"], strict=True)
    model.eval()
    print(f"Model loaded: ECAPA-TDNN (channels={args.channels}, emb_dim={args.emb_dim})\n")

    # 4. Scoring — cache embeddings so each utterance is embedded only once.
    emb_cache = {}
    labels, scores = [], []
    missing = 0

    for is_same, p1, p2 in tqdm(pairs, desc="Scoring"):
        if p1 not in emb_cache:
            emb_cache[p1] = embed_from_fbank_pt(model, p1, device, args.crop_frames, args.num_crops)
        if p2 not in emb_cache:
            emb_cache[p2] = embed_from_fbank_pt(model, p2, device, args.crop_frames, args.num_crops)

        e1 = emb_cache[p1]
        e2 = emb_cache[p2]

        if e1 is None or e2 is None:
            missing += 1  # unreadable/malformed feature file; skip this pair
            continue

        scores.append(cosine_sim(e1, e2))
        labels.append(is_same)

    print(f"\nScoring completed! Used pairs: {len(scores)}, Skipped: {missing}")

    if len(scores) == 0:
        print("[ERROR] No valid pairs! Check your feature paths.")
        return

    # 5. EER & Metrics
    eer, th = compute_eer(labels, scores)
    print(f"\n>>> EER = {eer*100:.3f}% (threshold ≈ {th:.4f})")

    pos = [s for s, l in zip(scores, labels) if l == 1]
    neg = [s for s, l in zip(scores, labels) if l == 0]
    print(f"Pos mean: {np.mean(pos):.4f} | Neg mean: {np.mean(neg):.4f}")

    # 6. Save Plots
    fpr, tpr = roc_points(labels, scores, num_th=200)
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr)
    plt.title(f"ROC Curve (EER = {eer*100:.2f}%)")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.grid(True)
    plt.savefig(os.path.join(args.out_dir, "roc.png"), dpi=300)
    plt.close()

    fars, frrs = det_points(labels, scores, num_th=400)
    plt.figure(figsize=(8, 6))
    plt.plot(fars, frrs)
    plt.title(f"DET Curve (EER = {eer*100:.2f}%)")
    plt.xlabel("False Acceptance Rate")
    plt.ylabel("False Rejection Rate")
    plt.grid(True)
    plt.savefig(os.path.join(args.out_dir, "det.png"), dpi=300)
    plt.close()

    plt.figure(figsize=(9, 6))
    plt.hist(pos, bins=80, alpha=0.7, label="Same Speaker")
    plt.hist(neg, bins=80, alpha=0.7, label="Different Speaker")
    plt.title("Score Distribution")
    plt.xlabel("Cosine Similarity")
    plt.ylabel("Count")
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(args.out_dir, "score_hist.png"), dpi=300)
    plt.close()

    # 7. t-SNE + Recall@K (optional: requires sklearn and usable embeddings)
    X, y_tsne = collect_embeddings_for_tsne(
        model, items, device,
        max_spk=20, per_spk=25,
        crop_frames=args.crop_frames,
        num_crops=args.num_crops,
        seed=args.seed
    )

    recall = None
    if X is not None and _HAS_SKLEARN:
        # FIX: sklearn requires perplexity < n_samples; a fixed 30 crashed
        # whenever fewer than 31 embeddings were collected.
        perplexity = min(30, max(1, len(X) - 1))
        tsne = TSNE(n_components=2, perplexity=perplexity, random_state=args.seed, init="pca")
        Z = tsne.fit_transform(X)

        plt.figure(figsize=(10, 8))
        for spk_id in np.unique(y_tsne):
            mask = (y_tsne == spk_id)
            plt.scatter(Z[mask, 0], Z[mask, 1], s=12, alpha=0.8)
        plt.title("t-SNE Visualization of Speaker Embeddings")
        plt.grid(True)
        plt.savefig(os.path.join(args.out_dir, "tsne.png"), dpi=300)
        plt.close()

        # Recall@K on the sampled, re-normalized embeddings
        emb_t = torch.from_numpy(X).float()
        emb_t = emb_t / (emb_t.norm(dim=1, keepdim=True) + 1e-12)
        lab_t = torch.from_numpy(y_tsne).long()
        recall = recall_at_k(emb_t, lab_t, ks=(1, 5, 10))

        print("\nRecall@K (sampled):")
        for k, v in recall.items():
            print(f"  R@{k}: {v*100:.2f}%")

    # FIX: metrics.txt was previously written only inside the sklearn branch,
    # so EER/threshold were silently lost when sklearn was missing or no
    # embeddings could be collected. Always write the core metrics; append
    # Recall@K only when it was computed.
    with open(os.path.join(args.out_dir, "metrics.txt"), "w") as f:
        f.write(f"EER: {eer*100:.3f}%\n")
        f.write(f"Threshold: {th:.4f}\n")
        if recall is not None:
            for k, v in recall.items():
                f.write(f"Recall@{k}: {v*100:.2f}%\n")

    print(f"\n所有结果已保存至: {args.out_dir}")
    print("文件: roc.png, det.png, score_hist.png, tsne.png, metrics.txt")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Speaker Verification Evaluation Tool")

    parser.add_argument("--val_meta", type=str, required=True,
                        help="Path to validation meta.jsonl file")

    parser.add_argument("--ckpt", type=str, required=True,
                        help="Path to model checkpoint (.pt)")

    parser.add_argument("--out_dir", type=str, default="outputs_eval",
                        help="Directory to save evaluation results (default: outputs_eval)")

    parser.add_argument("--crop_frames", type=int, default=400,
                        help="Number of frames per crop (default: 400 ≈ 4秒)")

    parser.add_argument("--num_crops", type=int, default=6,
                        help="Number of crops to average (default: 6)")

    parser.add_argument("--num_pos", type=int, default=3000,
                        help="Number of positive pairs")

    parser.add_argument("--num_neg", type=int, default=3000,
                        help="Number of negative pairs")

    parser.add_argument("--emb_dim", type=int, default=256,
                        help="Embedding dimension (must match checkpoint)")

    parser.add_argument("--channels", type=int, default=512,
                        help="ECAPA channels (default: 512)")

    parser.add_argument("--seed", type=int, default=1234,
                        help="Random seed")

    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"],
                        help="Device to use")

    args = parser.parse_args()

    main(args)