# dataset.py
import json
import os
import random

import torch
from torch.utils.data import Dataset

from utils.path_utils import _resolve_path


def _load_feat(feat_path: str) -> torch.Tensor:
    """Load a feature tensor from a .pt file, converting to a Tensor if needed.

    The stored object is expected to be a [T, 80] fbank matrix — TODO confirm
    against the feature-extraction script.
    """
    feat = torch.load(feat_path, map_location="cpu")
    if not torch.is_tensor(feat):
        feat = torch.tensor(feat)
    return feat


def _crop_to_length(feat: torch.Tensor, crop_frames: int) -> torch.Tensor:
    """Return exactly `crop_frames` frames of `feat` along dim 0.

    Longer inputs get a random contiguous crop; shorter inputs are tiled
    (repeated along dim 0) until they reach the target length, then truncated.
    This fixed length is the key guard against OOM from very long utterances.

    Raises:
        ValueError: if `feat` has zero frames (previously this surfaced as a
            bare ZeroDivisionError from the `reps` computation).
    """
    T = feat.size(0)
    if T == 0:
        raise ValueError("cannot crop an empty feature tensor")
    if T > crop_frames:
        s = random.randint(0, T - crop_frames)
        return feat[s:s + crop_frames]
    reps = (crop_frames + T - 1) // T  # ceil(crop_frames / T)
    return feat.repeat(reps, 1)[:crop_frames]


class TrainFbankPtDataset(Dataset):
    """Training dataset driven by a plain-text list file.

    Each non-empty line of `list_path` is `<label> <path-to-.pt>`. Raw integer
    labels are remapped to a contiguous 0..num_classes-1 range so they can be
    used directly as classifier targets.
    """

    def __init__(self, list_path: str, crop_frames: int = 200):
        self.list_path = os.path.abspath(list_path)
        base_dir = os.path.dirname(self.list_path)

        raw_items = []
        raw_labels = []

        with open(self.list_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue

                lab_str, p = line.split(maxsplit=1)
                lab = int(lab_str)

                # Paths in the list may be relative to the list file itself.
                p = _resolve_path(p, base_dir)

                raw_items.append((lab, p))
                raw_labels.append(lab)

        # Make labels contiguous (0..K-1) regardless of gaps in the raw ids.
        uniq = sorted(set(raw_labels))
        self.label_map = {old: new for new, old in enumerate(uniq)}
        self.num_classes = len(uniq)

        self.items = [(self.label_map[lab], p) for lab, p in raw_items]

        self.crop_frames = int(crop_frames)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        y, feat_path = self.items[idx]
        feat = _load_feat(feat_path)
        # Fixed-length crop keeps per-batch memory bounded (OOM guard).
        feat = _crop_to_length(feat, self.crop_frames)
        return feat, int(y)


class ValMetaDataset(Dataset):
    """Validation dataset driven by val_meta.jsonl.

    Each line is a JSON object of the form:
        {"spk": "id0001", "feat": "...pt"}
    """

    def __init__(self, meta_path: str, crop_frames: int = 200):
        self.meta_path = os.path.abspath(meta_path)
        base_dir = os.path.dirname(self.meta_path)

        self.items = []
        with open(self.meta_path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                j = json.loads(line)
                spk = str(j["spk"])
                feat = _resolve_path(j["feat"], base_dir)
                self.items.append((spk, feat))

        # int() cast for consistency with TrainFbankPtDataset (was un-cast).
        self.crop_frames = int(crop_frames)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        spk, feat_path = self.items[idx]
        feat = _load_feat(feat_path)
        # Fixed crop at validation too (same memory bound as training).
        # NOTE(review): the crop start is random, so validation embeddings
        # are not reproducible across runs — confirm this is intended.
        feat = _crop_to_length(feat, self.crop_frames)
        return feat, spk


def spec_augment(feat, time_mask=20, freq_mask=8, p=0.5):
    """Apply one time mask and one frequency mask to `feat`, in place.

    Args:
        feat: [T, F] feature tensor (F presumably 80 mel bins — TODO confirm).
        time_mask: maximum width of the time mask, in frames.
        freq_mask: maximum width of the frequency mask, in bins.
        p: probability of applying any masking at all.

    Returns:
        The (possibly mutated) input tensor. NOTE(review): masking writes
        into `feat` directly; pass a copy if the caller reuses the original.
    """
    if random.random() > p:
        return feat
    T, F = feat.size(0), feat.size(1)

    # Time mask: zero out up to `time_mask` consecutive frames.
    t = random.randint(0, time_mask)
    t0 = random.randint(0, max(0, T - t))
    feat[t0:t0 + t, :] = 0

    # Frequency mask: zero out up to `freq_mask` consecutive bins.
    f = random.randint(0, freq_mask)
    f0 = random.randint(0, max(0, F - f))
    feat[:, f0:f0 + f] = 0
    return feat


def collate_val(batch):
    """Stack fixed-length validation features; keep speaker ids as a list."""
    feats, spks = zip(*batch)
    x = torch.stack(feats, dim=0)
    return x, list(spks)


def collate_fixed(batch):
    """Stack fixed-length training features and build a LongTensor of labels."""
    feats, ys = zip(*batch)
    x = torch.stack(feats, dim=0)  # [B, crop_frames, 80]
    y = torch.tensor(ys, dtype=torch.long)
    return x, y