src/python/txtai/data/texts.py
  1  """
  2  Texts module
  3  """
  4  
  5  from itertools import chain
  6  
  7  from .base import Data
  8  
  9  
 10  class Texts(Data):
 11      """
 12      Tokenizes text datasets as input for training language models.
 13      """
 14  
 15      def __init__(self, tokenizer, columns, maxlength, merge):
 16          """
 17          Creates a new instance for tokenizing Texts training data.
 18  
 19          Args:
 20              tokenizer: model tokenizer
 21              columns: tuple of columns to use for text
 22              maxlength: maximum sequence length
 23              merge: determines how chunks are combined for language modeling tasks - "concat" (default), "pack" or None
 24          """
 25  
 26          super().__init__(tokenizer, columns, maxlength)
 27  
 28          # Standardize columns
 29          if not self.columns:
 30              self.columns = ("text", None)
 31  
 32          # Method to combine chunks
 33          self.merge = merge
 34  
    def process(self, data):
        # Column keys
        text1, text2 = self.columns

        # Tokenizer inputs can be a single string or a string pair, depending on the task
        text = (data[text1], data[text2]) if text2 else (data[text1],)

        # Tokenize text and return the special tokens mask
        inputs = self.tokenizer(*text, return_special_tokens_mask=True)

        # Combine inputs based on the merge strategy
        return self.concat(inputs) if self.merge == "concat" else self.pack(inputs) if self.merge == "pack" else inputs

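    # A minimal usage sketch (hedged; not part of the module). The model name and
    # column name are illustrative. process typically runs over batches, for example
    # via dataset.map(texts.process, batched=True) with Hugging Face datasets:
    #
    #   from transformers import AutoTokenizer
    #
    #   tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    #   texts = Texts(tokenizer, ("text", None), maxlength=128, merge="concat")
    #   chunks = texts.process({"text": ["first document", "second document"]})
    #   # With merge="concat", chunks["input_ids"] holds fixed-size chunks built
    #   # from the combined token stream of both documents
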
    def concat(self, inputs):
        """
        Concatenates tokenized text into chunks of maxlength. All inputs are joined into a single token stream and
        re-split, so records can be split across multiple chunks. Given at least maxlength tokens, every chunk is
        exactly maxlength in size and trailing tokens that don't fill a complete chunk are dropped.

        This works best with general language modeling tasks, like masked language modeling, that train on streams
        of text.

        Args:
            inputs: tokenized input

        Returns:
            Chunks of tokenized text, each with a size of maxlength
        """

        # Concatenate tokenized text
        concat = {k: list(chain(*inputs[k])) for k in inputs}

        # Calculate total length using the first column
        length = len(concat[list(concat.keys())[0]])

        # Ensure total is a multiple of maxlength, dropping the last incomplete chunk
        if length >= self.maxlength:
            length = (length // self.maxlength) * self.maxlength

        # Split into chunks of maxlength
        result = {k: [v[x : x + self.maxlength] for x in range(0, length, self.maxlength)] for k, v in concat.items()}

        return result

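    # Worked example for concat (a hedged sketch with toy token ids, not real
    # tokenizer output), assuming maxlength=4:
    #
    #   inputs = {"input_ids": [[1, 2, 3], [4, 5], [6, 7, 8, 9]]}
    #   stream -> [1, 2, 3, 4, 5, 6, 7, 8, 9]
    #   length -> 9 rounds down to 8, dropping the trailing token 9
    #   result -> {"input_ids": [[1, 2, 3, 4], [5, 6, 7, 8]]}
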
    def pack(self, inputs):
        """
        Packs tokenized text into chunks up to maxlength. This method guarantees that a record is never split across
        multiple chunks. A single record longer than maxlength becomes its own oversized chunk.

        This works best with instruction/prompt tuning, where it's crucial that entire records are preserved.

        Args:
            inputs: tokenized input

        Returns:
            Chunks of tokenized text, each with a size up to maxlength
        """

        # Sort each column by length descending. Sorting is stable and per-row lengths match
        # across columns, so rows stay aligned.
        inputs = {k: sorted(v, key=len, reverse=True) for k, v in inputs.items()}

        # Inputs has lists of equal length per column
        columns = list(inputs.keys())

        # Create empty results dict
        results = {column: [] for column in columns}

        # Iterate over values in the first column since all column lengths per row are equal
        length, index, rows = 0, 0, inputs[columns[0]]
        for x, row in enumerate(rows):
            length += len(row)
            nextlength = len(rows[x + 1]) if x < len(rows) - 1 else 0

            # Close the current chunk when the next row would reach or exceed maxlength
            if (length + nextlength) >= self.maxlength:
                for column in columns:
                    results[column].append(list(chain(*inputs[column][index : x + 1])))

                # Reset length and index
                length, index = 0, x + 1

        # Append any remaining rows as the final chunk
        if length:
            for column in columns:
                results[column].append(list(chain(*inputs[column][index : len(rows)])))

        return results
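
    # Worked example for pack (a hedged sketch with toy token ids, not real
    # tokenizer output), assuming maxlength=6:
    #
    #   inputs = {"input_ids": [[1, 2, 3], [4, 5], [6, 7, 8, 9]]}
    #   sorted -> [[6, 7, 8, 9], [1, 2, 3], [4, 5]]
    #   x=0: length 4 + next 3 >= 6  -> emit [6, 7, 8, 9]
    #   x=1: length 3 + next 2 < 6   -> keep accumulating
    #   x=2: length 5 at end of rows -> emit [1, 2, 3, 4, 5]
    #   result -> {"input_ids": [[6, 7, 8, 9], [1, 2, 3, 4, 5]]}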