# src/python/txtai/data/base.py
  1  """
  2  Data module
  3  """
  4  
  5  from .tokens import Tokens
  6  
  7  
  8  class Data:
  9      """
 10      Base data tokenization class.
 11      """
 12  
 13      def __init__(self, tokenizer, columns, maxlength):
 14          """
 15          Creates new base instance for tokenizing data.
 16  
 17          Args:
 18              tokenizer: model tokenizer
 19              columns: column names
 20              maxlength: maximum sequence length
 21          """
 22  
 23          self.tokenizer = tokenizer
 24          self.columns = columns
 25          self.maxlength = maxlength
 26  
 27      def __call__(self, train, validation, workers):
 28          """
 29          Tokenizes training and validation data and returns processed datasets.
 30  
 31          Args:
 32              train: training data
 33              validation: validation data
 34              workers: number of concurrent tokenizers when processing datasets, only main process used when set to None
 35  
 36          Returns:
 37              (train, validation)
 38          """
 39  
 40          return (self.prepare(train, self.process, workers), self.prepare(validation, self.process, workers) if validation else None)
 41  
 42      def prepare(self, data, fn, workers):
 43          """
 44          Prepares and tokenizes data for training.
 45  
 46          Args:
 47              data: input data
 48              fn: tokenize processing function to apply
 49              workers: number of concurrent tokenizers when processing datasets, only main process used when set to None
 50  
 51          Returns:
 52              tokens
 53          """
 54  
 55          if hasattr(data, "map"):
 56              # Hugging Face dataset
 57              tokens = data.map(fn, batched=True, num_proc=workers, remove_columns=data.column_names)
 58          else:
 59              # Re-orient data into columns for efficient batch tokenization
 60              columns = {}
 61              if hasattr(data, "columns"):
 62                  # Polars/pandas DataFrame
 63                  for column in data.columns:
 64                      columns[column] = list(data[column])
 65              else:
 66                  # Iterable dicts
 67                  for row in data:
 68                      for column in row.keys():
 69                          if column not in columns:
 70                              columns[column] = []
 71  
 72                          columns[column].append(row[column])
 73  
 74              # Process column-oriented data
 75              tokens = Tokens(fn(columns))
 76  
 77          return tokens
 78  
 79      def labels(self, data):
 80          """
 81          Extracts a list of unique labels from data.
 82  
 83          Args:
 84              data: input data
 85  
 86          Returns:
 87              list of unique labels
 88          """
 89  
 90          # Last column is label
 91          column = self.columns[-1]
 92  
 93          # Return length of labels if it's an array
 94          length = self.length(data[column][0] if hasattr(data, "columns") else data[0][column])
 95          if length:
 96              return length
 97  
 98          if hasattr(data, "map"):
 99              # Hugging Face dataset
100              labels = sorted(data.unique(self.columns[-1]))
101          elif hasattr(data, "columns"):
102              # Polars/pandas DataFrame
103              labels = sorted(data[self.columns[-1]].unique())
104          else:
105              # Iterable dicts
106              labels = sorted({row[self.columns[-1]] for row in data})
107  
108          # Labels are single numeric values per entry
109          #   - Consider a regression task if at least one label isn't an integer
110          #   - Otherwise use number of labels for a classification task
111          return 1 if [x for x in labels if float(x) != int(x)] else len(labels)
112  
113      def process(self, data):
114          """
115          Tokenizes batch of input data
116  
117          Args:
118              data: input data batch
119  
120          Returns:
121              tokenized data
122          """
123  
124          return data
125  
126      def length(self, value):
127          """
128          Returns the length of value if value has a len function defined. Otherwise,
129          None is returned.
130  
131          Args:
132              value: value to check
133  
134          Returns:
135              length of value if available, otherwise returns None
136          """
137  
138          return len(value) if hasattr(value, "__len__") else None