# dataset/dataset.py
  1  import json
  2  import os
  3  import random
  4  import torch
  5  from torch.utils.data import Dataset
  6  from utils.path_utils import _resolve_path
  7  
  8  
  9  class TrainFbankPtDataset(torch.utils.data.Dataset):
 10      def __init__(self, list_path: str, crop_frames: int = 200):
 11          self.list_path = os.path.abspath(list_path)
 12          base_dir = os.path.dirname(self.list_path)
 13  
 14          raw_items = []
 15          raw_labels = []
 16  
 17          with open(self.list_path, "r", encoding="utf-8") as f:
 18              for line in f:
 19                  line = line.strip()
 20                  if not line:
 21                      continue
 22  
 23                  lab_str, p = line.split(maxsplit=1)
 24                  lab = int(lab_str)
 25  
 26                  p = _resolve_path(p, base_dir)
 27  
 28                  raw_items.append((lab, p))
 29                  raw_labels.append(lab)
 30  
 31          # label 连续化
 32          uniq = sorted(set(raw_labels))
 33          self.label_map = {old: new for new, old in enumerate(uniq)}
 34          self.num_classes = len(uniq)
 35  
 36          self.items = [(self.label_map[lab], p) for lab, p in raw_items]
 37  
 38          self.crop_frames = int(crop_frames)
 39  
 40      def __len__(self):
 41          return len(self.items)
 42  
 43      def __getitem__(self, idx):
 44          y, feat_path = self.items[idx]
 45  
 46          feat = torch.load(feat_path, map_location="cpu")  # [T,80]
 47          if not torch.is_tensor(feat):
 48              feat = torch.tensor(feat)
 49  
 50          T = feat.size(0)
 51  
 52          # 固定长度裁剪(关键防 OOM)
 53          if T > self.crop_frames:
 54              s = random.randint(0, T - self.crop_frames)
 55              feat = feat[s:s + self.crop_frames]
 56          else:
 57              reps = (self.crop_frames + T - 1) // T
 58              feat = feat.repeat(reps, 1)[:self.crop_frames]
 59  
 60          return feat, int(y)
 61  
 62  class ValMetaDataset(torch.utils.data.Dataset):
 63      """
 64      读取 val_meta.jsonl:
 65        {"spk":"id0001","feat":"...pt"}
 66      """
 67      def __init__(self, meta_path: str, crop_frames: int = 200):
 68          self.meta_path = os.path.abspath(meta_path)
 69          base_dir = os.path.dirname(self.meta_path)
 70  
 71          self.items = []
 72          with open(self.meta_path, "r", encoding="utf-8") as f:
 73              for line in f:
 74                  if not line.strip():
 75                      continue
 76                  j = json.loads(line)
 77                  spk = str(j["spk"])
 78                  feat = _resolve_path(j["feat"], base_dir)
 79                  self.items.append((spk, feat))
 80  
 81          self.crop_frames = crop_frames
 82  
 83      def __len__(self):
 84          return len(self.items)
 85  
 86      def __getitem__(self, idx):
 87          spk, feat_path = self.items[idx]
 88  
 89          feat = torch.load(feat_path, map_location="cpu")
 90          if not torch.is_tensor(feat):
 91              feat = torch.tensor(feat)
 92  
 93          T = feat.size(0)
 94  
 95          # 固定 crop(验证也必须 crop)
 96          if T > self.crop_frames:
 97              s = random.randint(0, T - self.crop_frames)
 98              feat = feat[s:s + self.crop_frames]
 99          else:
100              reps = (self.crop_frames + T - 1) // T
101              feat = feat.repeat(reps, 1)[:self.crop_frames]
102  
103          return feat, spk
104  
105  
106  
def spec_augment(feat, time_mask=20, freq_mask=8, p=0.5):
    """Apply one random time mask and one random frequency mask to ``feat``.

    Args:
        feat: ``[T, F]`` feature tensor (e.g. fbank frames).
        time_mask: maximum number of consecutive frames to zero.
        freq_mask: maximum number of consecutive feature bins to zero.
        p: probability of applying the augmentation at all.

    Returns:
        The (possibly) augmented tensor. The caller's tensor is never
        modified: a clone is masked instead, so cached/shared feature
        tensors stay intact.
    """
    if random.random() > p:
        return feat

    # Fix: the original wrote zeros into the caller's tensor in place,
    # permanently corrupting any cached/shared copy of the features.
    feat = feat.clone()
    T, F = feat.size(0), feat.size(1)

    # Time mask: zero a random span of up to `time_mask` frames.
    t = random.randint(0, time_mask)
    t0 = random.randint(0, max(0, T - t))
    feat[t0:t0 + t, :] = 0

    # Frequency mask: zero a random span of up to `freq_mask` bins.
    f = random.randint(0, freq_mask)
    f0 = random.randint(0, max(0, F - f))
    feat[:, f0:f0 + f] = 0
    return feat
123  
def collate_val(batch):
    """Collate validation samples: stack fixed-length features, keep spk ids."""
    feature_batch = [feat for feat, _ in batch]
    speaker_ids = [spk for _, spk in batch]
    return torch.stack(feature_batch, dim=0), speaker_ids
128  
129  
def collate_fixed(batch):
    """Collate training samples into a [B, crop_frames, F] batch and long labels."""
    features = torch.stack([item[0] for item in batch], dim=0)
    labels = torch.tensor([item[1] for item in batch], dtype=torch.long)
    return features, labels
135