/ compare_two_wavs.py
compare_two_wavs.py
  1  import argparse
  2  from pathlib import Path
  3  
  4  import torch
  5  
  6  from speaker_verification.inference import load_sv, cosine_score
  7  
  8  try:
  9      import onnxruntime as ort  # noqa: F401
 10      _HAS_ONNX = True
 11  except ImportError:
 12      _HAS_ONNX = False
 13      print("⚠️ onnxruntime 未安装,ONNX 模式不可用。")
 14      print("   pip install onnxruntime  或  onnxruntime-gpu")
 15  
 16  
 17  def main():
 18      parser = argparse.ArgumentParser(description="两个音频说话人对比(PyTorch / ONNX)")
 19  
 20      parser.add_argument("--wav1", type=str, required=True, help="第一个音频路径")
 21      parser.add_argument("--wav2", type=str, required=True, help="第二个音频路径")
 22  
 23      parser.add_argument(
 24          "--ckpt",
 25          type=str,
 26          default="scripts/outputs/export/model.pt",
 27          help="模型路径:.pt (PyTorch) 或 .onnx (ONNX)",
 28      )
 29  
 30      parser.add_argument(
 31          "--onnx",
 32          action="store_true",
 33          default=False,
 34          help="使用 ONNX 推理(默认自动根据文件后缀判断)",
 35      )
 36  
 37      parser.add_argument(
 38          "--threshold",
 39          type=float,
 40          default=0.55,
 41          help="判断同一人的余弦相似度阈值(建议通过 verify.py 得到最佳阈值)",
 42      )
 43  
 44      parser.add_argument(
 45          "--device",
 46          type=str,
 47          default="cuda",
 48          choices=["cuda", "cpu"],
 49          help="PyTorch 模式使用的设备(ONNX 默认按 onnxruntime providers)",
 50      )
 51  
 52      parser.add_argument("--num_crops", type=int, default=5, help="多 crop 平均的 crop 数")
 53      parser.add_argument("--crop_sec", type=float, default=3.0, help="每个 crop 的时长(秒)")
 54  
 55      args = parser.parse_args()
 56  
 57      ckpt_path = Path(args.ckpt)
 58      use_onnx = args.onnx or ckpt_path.suffix.lower() == ".onnx"
 59  
 60      print("=" * 70)
 61      print("🎙️  Speaker Verification - Two Wavs Comparison")
 62      print("=" * 70)
 63      print(f"Audio 1  : {args.wav1}")
 64      print(f"Audio 2  : {args.wav2}")
 65      print(f"Model    : {ckpt_path}  ({'ONNX' if use_onnx else 'PyTorch'})")
 66      print(f"Threshold: {args.threshold}")
 67      print(f"Crops    : {args.num_crops}  |  Crop_sec: {args.crop_sec}")
 68      print("=" * 70)
 69  
 70      if use_onnx and not _HAS_ONNX:
 71          raise ImportError("请先安装 onnxruntime: pip install onnxruntime 或 onnxruntime-gpu")
 72  
 73      backend_device = "cpu" if use_onnx else (args.device if torch.cuda.is_available() else "cpu")
 74      sv, meta = load_sv(str(ckpt_path), device=backend_device, use_onnx=use_onnx)
 75  
 76      if use_onnx:
 77          print(f"使用 ONNX Runtime 推理... providers={meta.get('providers')}")
 78      else:
 79          print(f"使用 PyTorch 推理... device={backend_device}")
 80  
 81      score = cosine_score(
 82          sv,
 83          args.wav1,
 84          args.wav2,
 85          num_crops=args.num_crops,
 86          crop_sec=args.crop_sec,
 87      )
 88  
 89      same = score >= args.threshold
 90  
 91      print(f"\n🔍 Cosine Similarity = {score:.4f}")
 92      print(f"   Threshold        = {args.threshold}")
 93      print(f"   → {'同一说话人' if same else '不同说话人'}")
 94  
 95      color = "\033[92m" if same else "\033[91m"
 96      print(f"\n{color}【最终判定】{'✅ 同一人' if same else '❌ 不同人'}\033[0m")
 97      print("\n" + "=" * 70)
 98  
 99  
# Script entry point: run the comparison only when executed directly, not on import.
# main() returns None, so SystemExit(None) ends the process with status 0,
# exactly like falling off the end of the script.
if __name__ == "__main__":
    raise SystemExit(main())