questions.py
1 """ 2 Questions module 3 """ 4 5 from .base import Data 6 7 8 class Questions(Data): 9 """ 10 Tokenizes question-answering datasets as input for training question-answering models. 11 """ 12 13 def __init__(self, tokenizer, columns, maxlength, stride): 14 """ 15 Creates a new instance for tokenizing Questions training data. 16 17 Args: 18 tokenizer: model tokenizer 19 columns: tuple of columns to use for question/context/answer 20 maxlength: maximum sequence length 21 stride: chunk size for splitting data for QA tasks 22 """ 23 24 super().__init__(tokenizer, columns, maxlength) 25 26 if not self.columns: 27 self.columns = ("question", "context", "answers") 28 29 self.question, self.context, self.answer = self.columns 30 self.stride = stride 31 self.rpad = tokenizer.padding_side == "right" 32 33 def process(self, data): 34 # Tokenize data 35 tokenized = self.tokenize(data) 36 37 # Get mapping of overflowing tokens and answer offsets 38 samples = tokenized.pop("overflow_to_sample_mapping") 39 offsets = tokenized.pop("offset_mapping") 40 41 # Start/end positions 42 tokenized["start_positions"] = [] 43 tokenized["end_positions"] = [] 44 45 for x, offset in enumerate(offsets): 46 # Label NO ANSWER with CLS token 47 inputids = tokenized["input_ids"][x] 48 clstoken = inputids.index(self.tokenizer.cls_token_id) 49 50 # Sequence ids 51 sequences = tokenized.sequence_ids(x) 52 53 # Get and format answer 54 answers = self.answers(data, samples[x]) 55 56 # If no answers are given, set cls token as answer. 57 if len(answers["answer_start"]) == 0: 58 tokenized["start_positions"].append(clstoken) 59 tokenized["end_positions"].append(clstoken) 60 else: 61 # Start/end character index of the answer in the text. 62 startchar = answers["answer_start"][0] 63 endchar = startchar + len(answers["text"][0]) 64 65 # Start token index of the current span in the text. 66 start = 0 67 while sequences[start] != (1 if self.rpad else 0): 68 start += 1 69 70 # End token index of the current span in the text. 71 end = len(inputids) - 1 72 while sequences[end] != (1 if self.rpad else 0): 73 end -= 1 74 75 # Map start character and end character to matching token index 76 while start < len(offset) and offset[start][0] <= startchar: 77 start += 1 78 tokenized["start_positions"].append(start - 1) 79 80 while offset[end][1] >= endchar: 81 end -= 1 82 tokenized["end_positions"].append(end + 1) 83 84 return tokenized 85 86 def tokenize(self, data): 87 """ 88 Tokenizes batch of data 89 90 Args: 91 data: input data batch 92 93 Returns: 94 tokenized data 95 """ 96 97 # Trim question whitespace 98 data[self.question] = [x.lstrip() for x in data[self.question]] 99 100 # Tokenize records 101 return self.tokenizer( 102 data[self.question if self.rpad else self.context], 103 data[self.context if self.rpad else self.question], 104 truncation="only_second" if self.rpad else "only_first", 105 max_length=self.maxlength, 106 stride=self.stride, 107 return_overflowing_tokens=True, 108 return_offsets_mapping=True, 109 padding=True, 110 ) 111 112 def answers(self, data, index): 113 """ 114 Gets and formats an answer. 

        Args:
            data: input examples
            index: answer index to retrieve

        Returns:
            answers dict
        """

        # Answer mappings
        answers = data[self.answer][index]
        context = data[self.context][index]

        # Handle mapping string answers to dict
        if not isinstance(answers, dict):
            if not answers:
                answers = {"text": [], "answer_start": []}
            else:
                answers = {"text": [answers], "answer_start": [context.index(answers)]}

        return answers
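
# Usage sketch (illustrative, not part of the original module): shows how a
# batch flows through Questions. Assumes this module is imported from its
# package (it uses a relative import) and that the tokenizer is a Hugging Face
# *fast* tokenizer, which offset mappings and sequence_ids() require; the model
# name and the maxlength/stride values below are placeholder choices.
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#
#   # None falls back to the default ("question", "context", "answers") columns
#   questions = Questions(tokenizer, None, maxlength=384, stride=128)
#
#   batch = {
#       "question": ["What is the capital of France?"],
#       "context": ["Paris is the capital of France."],
#       "answers": ["Paris"],
#   }
#
#   tokenized = questions.process(batch)
#   # tokenized["start_positions"] / tokenized["end_positions"] now hold the
#   # token indices labeling the answer span in each feature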