# compare_two_wavs.py
"""Compare two audio files and report whether they share a speaker.

Command-line entry point: loads a speaker-verification model through
``load_sv`` (PyTorch or ONNX), scores the two recordings with
``cosine_score`` and prints the verdict against a cosine threshold.
"""
import argparse
from pathlib import Path

import torch

from speaker_verification.inference import load_sv, cosine_score

# onnxruntime is optional: remember whether it imported so that main() can
# refuse ONNX mode with a clear error instead of failing later.
try:
    import onnxruntime as ort  # noqa: F401
except ImportError:
    _HAS_ONNX = False
    print("⚠️ onnxruntime 未安装,ONNX 模式不可用。")
    print(" pip install onnxruntime 或 onnxruntime-gpu")
else:
    _HAS_ONNX = True


def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser describing every command-line option."""
    cli = argparse.ArgumentParser(description="两个音频说话人对比(PyTorch / ONNX)")

    # The two recordings to compare; both are mandatory.
    cli.add_argument("--wav1", type=str, required=True, help="第一个音频路径")
    cli.add_argument("--wav2", type=str, required=True, help="第二个音频路径")

    cli.add_argument(
        "--ckpt",
        type=str,
        default="scripts/outputs/export/model.pt",
        help="模型路径:.pt (PyTorch) 或 .onnx (ONNX)",
    )
    # store_true already defaults to False, so no explicit default is needed.
    cli.add_argument(
        "--onnx",
        action="store_true",
        help="使用 ONNX 推理(默认自动根据文件后缀判断)",
    )
    cli.add_argument(
        "--threshold",
        type=float,
        default=0.55,
        help="判断同一人的余弦相似度阈值(建议通过 verify.py 得到最佳阈值)",
    )
    cli.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "cpu"],
        help="PyTorch 模式使用的设备(ONNX 默认按 onnxruntime providers)",
    )
    cli.add_argument("--num_crops", type=int, default=5, help="多 crop 平均的 crop 数")
    cli.add_argument("--crop_sec", type=float, default=3.0, help="每个 crop 的时长(秒)")
    return cli


def main() -> None:
    """Parse the CLI, score the two recordings and print the verdict.

    Raises:
        ImportError: if ONNX inference is selected (via ``--onnx`` or a
            ``.onnx`` model suffix) but onnxruntime could not be imported.
    """
    args = _build_parser().parse_args()

    model_path = Path(args.ckpt)
    # ONNX mode is either requested explicitly or inferred from the suffix.
    onnx_mode = args.onnx or model_path.suffix.lower() == ".onnx"
    if onnx_mode:
        backend_name = "ONNX"
    else:
        backend_name = "PyTorch"

    rule = "=" * 70
    print(rule)
    print("🎙️ Speaker Verification - Two Wavs Comparison")
    print(rule)
    print(f"Audio 1 : {args.wav1}")
    print(f"Audio 2 : {args.wav2}")
    print(f"Model : {model_path} ({backend_name})")
    print(f"Threshold: {args.threshold}")
    print(f"Crops : {args.num_crops} | Crop_sec: {args.crop_sec}")
    print(rule)

    if onnx_mode and not _HAS_ONNX:
        raise ImportError("请先安装 onnxruntime: pip install onnxruntime 或 onnxruntime-gpu")

    # ONNX always receives "cpu" here; PyTorch honours --device only when
    # CUDA is actually available and otherwise falls back to the CPU.
    if onnx_mode:
        device = "cpu"
    elif torch.cuda.is_available():
        device = args.device
    else:
        device = "cpu"

    sv, meta = load_sv(str(model_path), device=device, use_onnx=onnx_mode)

    if onnx_mode:
        print(f"使用 ONNX Runtime 推理... providers={meta.get('providers')}")
    else:
        print(f"使用 PyTorch 推理... device={device}")

    score = cosine_score(
        sv,
        args.wav1,
        args.wav2,
        num_crops=args.num_crops,
        crop_sec=args.crop_sec,
    )

    # A score at or above the threshold counts as the same speaker.
    is_match = score >= args.threshold
    if is_match:
        summary = "同一说话人"
        verdict = "✅ 同一人"
        color = "\033[92m"  # ANSI bright green
    else:
        summary = "不同说话人"
        verdict = "❌ 不同人"
        color = "\033[91m"  # ANSI bright red

    print(f"\n🔍 Cosine Similarity = {score:.4f}")
    print(f" Threshold = {args.threshold}")
    print(f" → {summary}")

    # The trailing \033[0m resets the terminal colour after the verdict.
    print(f"\n{color}【最终判定】{verdict}\033[0m")
    print("\n" + rule)


if __name__ == "__main__":
    main()