texts.py
1 """ 2 Texts module 3 """ 4 5 from itertools import chain 6 7 from .base import Data 8 9 10 class Texts(Data): 11 """ 12 Tokenizes text datasets as input for training language models. 13 """ 14 15 def __init__(self, tokenizer, columns, maxlength, merge): 16 """ 17 Creates a new instance for tokenizing Texts training data. 18 19 Args: 20 tokenizer: model tokenizer 21 columns: tuple of columns to use for text 22 maxlength: maximum sequence length 23 merge: determines how chunks are combined for language modeling tasks - "concat" (default), "pack" or None 24 """ 25 26 super().__init__(tokenizer, columns, maxlength) 27 28 # Standardize columns 29 if not self.columns: 30 self.columns = ("text", None) 31 32 # Method to combine chunks 33 self.merge = merge 34 35 def process(self, data): 36 # Column keys 37 text1, text2 = self.columns 38 39 # Tokenizer inputs can be single string or string pair, depending on task 40 text = (data[text1], data[text2]) if text2 else (data[text1],) 41 42 # Tokenize text and add label 43 inputs = self.tokenizer(*text, return_special_tokens_mask=True) 44 45 # Combine inputs based on parameters 46 return self.concat(inputs) if self.merge == "concat" else self.pack(inputs) if self.merge == "pack" else inputs 47 48 def concat(self, inputs): 49 """ 50 Concatenates tokenized text into chunks of maxlength. This method guarantees that each chunk is maxlength 51 size and splits data across multiple chunks if needed. 52 53 This is best with general language modeling tasks like masked language modeling that are streams of text. 54 55 Args: 56 inputs: tokenized input 57 58 Returns: 59 Chunks of tokenized text each with a size of maxlength 60 """ 61 62 # Concatenate tokenized text 63 concat = {k: list(chain(*inputs[k])) for k in inputs.keys()} 64 65 # Calculate total length 66 length = len(concat[list(inputs.keys())[0]]) 67 68 # Ensure total is multiple of maxlength, drop last incomplete chunk 69 if length >= self.maxlength: 70 length = (length // self.maxlength) * self.maxlength 71 72 # Split into chunks of maxlength 73 result = {k: [v[x : x + self.maxlength] for x in range(0, length, self.maxlength)] for k, v in concat.items()} 74 75 return result 76 77 def pack(self, inputs): 78 """ 79 Packs tokenized text into chunks up to maxlength. This method guarantees that data is never split across 80 multiple chunks. 81 82 This is best with instruction/prompt learning where it's crucial to ensure entire records are preserved. 83 84 Args: 85 inputs: tokenized input 86 87 Returns: 88 Chunks of tokenized text each with a size of maxlength 89 """ 90 91 # Sort lists by length descending 92 inputs = {k: sorted(v, key=len, reverse=True) for k, v in inputs.items()} 93 94 # Inputs has lists of equal length per column 95 columns = list(inputs.keys()) 96 97 # Create empty results dict 98 results = {column: [] for column in columns} 99 100 # Iterate over values in first column since all column lengths per row are equal 101 length, index, rows = 0, 0, inputs[columns[0]] 102 for x, row in enumerate(rows): 103 length += len(row) 104 nextlength = len(rows[x + 1]) if x < len(rows) - 1 else 0 105 106 # New row 107 if (length + nextlength) >= self.maxlength: 108 for column in columns: 109 results[column].append(list(chain(*inputs[column][index : x + 1]))) 110 111 # Reset length and index 112 length, index = 0, x + 1 113 114 # Last row 115 if length: 116 for column in columns: 117 results[column].append(list(chain(*inputs[column][index : len(rows)]))) 118 119 return results