base.py
"""
Data module
"""

from .tokens import Tokens


class Data:
    """
    Base data tokenization class. Normalizes Hugging Face datasets,
    Polars/pandas DataFrames and iterables of dicts into tokenized datasets.
    """

    def __init__(self, tokenizer, columns, maxlength):
        """
        Creates new base instance for tokenizing data.

        Args:
            tokenizer: model tokenizer
            columns: column names
            maxlength: maximum sequence length
        """

        self.tokenizer = tokenizer
        self.columns = columns
        self.maxlength = maxlength

    def __call__(self, train, validation, workers):
        """
        Tokenizes training and validation data and returns processed datasets.

        Args:
            train: training data
            validation: validation data, can be None
            workers: number of concurrent tokenizers when processing datasets, only main process used when set to None

        Returns:
            (train, validation)
        """

        # Explicit None check - a truthiness test (if validation) raises for
        # pandas/polars DataFrames ("truth value of a DataFrame is ambiguous"),
        # which are supported inputs per prepare()
        validation = self.prepare(validation, self.process, workers) if validation is not None else None

        return (self.prepare(train, self.process, workers), validation)

    def prepare(self, data, fn, workers):
        """
        Prepares and tokenizes data for training.

        Args:
            data: input data
            fn: tokenize processing function to apply
            workers: number of concurrent tokenizers when processing datasets, only main process used when set to None

        Returns:
            tokens
        """

        if hasattr(data, "map"):
            # Hugging Face dataset - tokenize in place, dropping raw columns
            tokens = data.map(fn, batched=True, num_proc=workers, remove_columns=data.column_names)
        else:
            # Re-orient data into columns for efficient batch tokenization
            columns = {}
            if hasattr(data, "columns"):
                # Polars/pandas DataFrame
                for column in data.columns:
                    columns[column] = list(data[column])
            else:
                # Iterable dicts - accumulate each field into a column list
                for row in data:
                    for column, value in row.items():
                        columns.setdefault(column, []).append(value)

            # Process column-oriented data
            tokens = Tokens(fn(columns))

        return tokens

    def labels(self, data):
        """
        Extracts a list of unique labels from data.

        Args:
            data: input data

        Returns:
            number of unique labels for classification, 1 for regression, or
            the label array length when labels are arrays
        """

        # Last column is label
        column = self.columns[-1]

        # Return length of labels if it's an array (e.g. multi-label task)
        length = self.length(data[column][0] if hasattr(data, "columns") else data[0][column])
        if length:
            return length

        if hasattr(data, "map"):
            # Hugging Face dataset
            labels = sorted(data.unique(column))
        elif hasattr(data, "columns"):
            # Polars/pandas DataFrame
            labels = sorted(data[column].unique())
        else:
            # Iterable dicts
            labels = sorted({row[column] for row in data})

        # Labels are single numeric values per entry
        # - Consider a regression task if at least one label isn't an integer
        # - Otherwise use number of labels for a classification task
        return 1 if any(float(x) != int(x) for x in labels) else len(labels)

    def process(self, data):
        """
        Tokenizes batch of input data. Base implementation is a no-op hook,
        subclasses override with model-specific tokenization.

        Args:
            data: input data batch

        Returns:
            tokenized data
        """

        return data

    def length(self, value):
        """
        Returns the length of value if value has a len function defined. Otherwise,
        None is returned.

        Args:
            value: value to check

        Returns:
            length of value if available, otherwise returns None
        """

        return len(value) if hasattr(value, "__len__") else None