# src/python/txtai/data/questions.py
  1  """
  2  Questions module
  3  """
  4  
  5  from .base import Data
  6  
  7  
  8  class Questions(Data):
  9      """
 10      Tokenizes question-answering datasets as input for training question-answering models.
 11      """
 12  
 13      def __init__(self, tokenizer, columns, maxlength, stride):
 14          """
 15          Creates a new instance for tokenizing Questions training data.
 16  
 17          Args:
 18              tokenizer: model tokenizer
 19              columns: tuple of columns to use for question/context/answer
 20              maxlength: maximum sequence length
 21              stride: chunk size for splitting data for QA tasks
 22          """
 23  
 24          super().__init__(tokenizer, columns, maxlength)
 25  
 26          if not self.columns:
 27              self.columns = ("question", "context", "answers")
 28  
 29          self.question, self.context, self.answer = self.columns
 30          self.stride = stride
 31          self.rpad = tokenizer.padding_side == "right"
 32  
 33      def process(self, data):
 34          # Tokenize data
 35          tokenized = self.tokenize(data)
 36  
 37          # Get mapping of overflowing tokens and answer offsets
 38          samples = tokenized.pop("overflow_to_sample_mapping")
 39          offsets = tokenized.pop("offset_mapping")
 40  
 41          # Start/end positions
 42          tokenized["start_positions"] = []
 43          tokenized["end_positions"] = []
 44  
 45          for x, offset in enumerate(offsets):
 46              # Label NO ANSWER with CLS token
 47              inputids = tokenized["input_ids"][x]
 48              clstoken = inputids.index(self.tokenizer.cls_token_id)
 49  
 50              # Sequence ids
 51              sequences = tokenized.sequence_ids(x)
 52  
 53              # Get and format answer
 54              answers = self.answers(data, samples[x])
 55  
 56              # If no answers are given, set cls token as answer.
 57              if len(answers["answer_start"]) == 0:
 58                  tokenized["start_positions"].append(clstoken)
 59                  tokenized["end_positions"].append(clstoken)
 60              else:
 61                  # Start/end character index of the answer in the text.
 62                  startchar = answers["answer_start"][0]
 63                  endchar = startchar + len(answers["text"][0])
 64  
 65                  # Start token index of the current span in the text.
 66                  start = 0
 67                  while sequences[start] != (1 if self.rpad else 0):
 68                      start += 1
 69  
 70                  # End token index of the current span in the text.
 71                  end = len(inputids) - 1
 72                  while sequences[end] != (1 if self.rpad else 0):
 73                      end -= 1
 74  
 75                  # Map start character and end character to matching token index
 76                  while start < len(offset) and offset[start][0] <= startchar:
 77                      start += 1
 78                  tokenized["start_positions"].append(start - 1)
 79  
 80                  while offset[end][1] >= endchar:
 81                      end -= 1
 82                  tokenized["end_positions"].append(end + 1)
 83  
 84          return tokenized
 85  
 86      def tokenize(self, data):
 87          """
 88          Tokenizes batch of data
 89  
 90          Args:
 91              data: input data batch
 92  
 93          Returns:
 94              tokenized data
 95          """
 96  
 97          # Trim question whitespace
 98          data[self.question] = [x.lstrip() for x in data[self.question]]
 99  
100          # Tokenize records
101          return self.tokenizer(
102              data[self.question if self.rpad else self.context],
103              data[self.context if self.rpad else self.question],
104              truncation="only_second" if self.rpad else "only_first",
105              max_length=self.maxlength,
106              stride=self.stride,
107              return_overflowing_tokens=True,
108              return_offsets_mapping=True,
109              padding=True,
110          )
111  
112      def answers(self, data, index):
113          """
114          Gets and formats an answer.
115  
116          Args:
117              data: input examples
118              index: answer index to retrieve
119  
120          Returns:
121              answers dict
122          """
123  
124          # Answer mappings
125          answers = data[self.answer][index]
126          context = data[self.context][index]
127  
128          # Handle mapping string answers to dict
129          if not isinstance(answers, dict):
130              if not answers:
131                  answers = {"text": [], "answer_start": []}
132              else:
133                  answers = {"text": [answers], "answer_start": [context.index(answers)]}
134  
135          return answers