# tokens.py
"""
Tokens module
"""

import torch


class Tokens(torch.utils.data.Dataset):
    """
    Default dataset used to hold tokenized data.

    Input arrives column-oriented (one sequence of values per column name)
    and is stored row-oriented: one dict per example, keyed by column name.
    """

    def __init__(self, columns):
        """
        Transpose column-oriented data into a list of row dicts.

        Args:
            columns: mapping of column name -> sequence of values; presumably
                tokenizer outputs such as input ids / attention masks (TODO
                confirm against callers). Only iteration over the keys and
                ``columns[name]`` lookups are used.

        Row ``i`` holds ``{name: columns[name][i]}`` for every column long
        enough to have an ``i``-th value. If columns differ in length, rows
        past the end of a shorter column simply lack that key.
        """
        # Bind the attribute up front and fill it through a local alias.
        rows = self.data = []

        for name in columns:
            for position, value in enumerate(columns[name]):
                # Positions advance one at a time, so the list is never more
                # than one row short: add a fresh row exactly when needed.
                if position == len(rows):
                    rows.append({})

                rows[position][name] = value

    def __len__(self):
        """Return how many rows (examples) the dataset holds."""
        rows = self.data
        return len(rows)

    def __getitem__(self, index):
        """Return the row dict stored at ``index`` (list indexing semantics)."""
        rows = self.data
        return rows[index]