  1  """
  2  Token Detection module
  3  """
  4  
  5  import inspect
  6  import os
  7  
  8  import torch
  9  
 10  from transformers import PreTrainedModel
 11  
 12  
 13  class TokenDetection(PreTrainedModel):
 14      """
 15      Runs the replaced token detection training objective. This method was first proposed by the ELECTRA model.
 16      The method consists of a masked language model generator feeding data to a discriminator that determines
 17      which of the tokens are incorrect. More on this training objective can be found in the ELECTRA paper.
 18      """

    def __init__(self, generator, discriminator, tokenizer, weight=50.0):
        """
        Creates a new TokenDetection class.

        Args:
            generator: Generator model, must be a masked language model
            discriminator: Discriminator model, must be a model that can detect replaced tokens. Any model
                           can be customized for this task. See ElectraForPreTraining for more.
            tokenizer: Tokenizer to save with the generator and discriminator
            weight: Weight applied to the discriminator loss, defaults to 50.0
        """

        # Initialize model with discriminator config
        super().__init__(discriminator.config)

        self.generator = generator
        self.discriminator = discriminator

        # Tokenizer to save with generator and discriminator
        self.tokenizer = tokenizer

        # Discriminator weight
        self.weight = weight

        # Share embeddings if both models are the same type
        # Embeddings must be the same size
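        # Sharing mirrors the embedding tying used by ELECTRA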
        if self.generator.config.model_type == self.discriminator.config.model_type:
            self.discriminator.set_input_embeddings(self.generator.get_input_embeddings())

        # Flag whether each model's forward method accepts an attention mask
        self.gattention = "attention_mask" in inspect.signature(self.generator.forward).parameters
        self.dattention = "attention_mask" in inspect.signature(self.discriminator.forward).parameters

    # pylint: disable=E1101
    def forward(self, input_ids=None, labels=None, attention_mask=None, token_type_ids=None):
        """
        Runs a forward pass through the model. This method runs inputs through the masked language model,
        samples replacement tokens from the generator predictions and builds a binary classification problem
        for the discriminator (detecting whether each token was replaced).

        Args:
            input_ids: token ids
            labels: token labels
            attention_mask: attention mask
            token_type_ids: segment token indices

        Returns:
            (loss, generator outputs, discriminator outputs, discriminator labels)
        """

        # Copy input ids
        dinputs = input_ids.clone()

        # Run inputs through masked language model
        inputs = {"attention_mask": attention_mask} if self.gattention else {}
        goutputs = self.generator(input_ids, labels=labels, token_type_ids=token_type_ids, **inputs)

        # Get predictions
        preds = torch.softmax(goutputs[1], dim=-1)
        preds = preds.view(-1, self.config.vocab_size)

        tokens = torch.multinomial(preds, 1).view(-1)
        tokens = tokens.view(dinputs.shape[0], -1)
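        # tokens now holds one sampled id per position with shape (batch, seq_len);
        # multinomial requires 2D input, hence the flatten and reshape above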

        # Labels have a -100 value to ignore loss from unchanged tokens
        mask = labels.ne(-100)

        # Replace the masked out tokens of the input with the generator predictions
        dinputs[mask] = tokens[mask]

        # Turn mask into new target labels - 1 (True) for corrupted, 0 otherwise.
        # If the prediction was correct, mark it as uncorrupted.
        correct = tokens == labels
        dlabels = mask.long()
        dlabels[correct] = 0

        # Run token classification, predict whether each token was corrupted
        inputs = {"attention_mask": attention_mask} if self.dattention else {}
        doutputs = self.discriminator(dinputs, labels=dlabels, token_type_ids=token_type_ids, **inputs)

        # Compute combined loss
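        # The 50.0 default for weight matches the discriminator loss weight used in the ELECTRA paper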
        loss = goutputs[0] + self.weight * doutputs[0]
        return loss, goutputs[1], doutputs[1], dlabels

    def save_pretrained(self, output, state_dict=None, **kwargs):
        """
        Saves the current model to the output directory.

        Args:
            output: output directory
            state_dict: model state
            kwargs: additional keyword arguments
        """

        # Save combined model to support training from checkpoints
        # state_dict is passed by keyword since it is not the second positional parameter of save_pretrained
        super().save_pretrained(output, state_dict=state_dict, **kwargs)

        # Save generator tokenizer and model
        gpath = os.path.join(output, "generator")
        self.tokenizer.save_pretrained(gpath)
        self.generator.save_pretrained(gpath)

        # Save discriminator tokenizer and model
        dpath = os.path.join(output, "discriminator")
        self.tokenizer.save_pretrained(dpath)
        self.discriminator.save_pretrained(dpath)
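

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the txtai API). This block
# wires the public ELECTRA-small checkpoints into TokenDetection and runs a
# single forward pass on a toy masked batch. The checkpoint names are real
# Hugging Face models; everything else here is an assumed example setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer, ElectraForMaskedLM, ElectraForPreTraining

    tokenizer = AutoTokenizer.from_pretrained("google/electra-small-generator")
    generator = ElectraForMaskedLM.from_pretrained("google/electra-small-generator")
    discriminator = ElectraForPreTraining.from_pretrained("google/electra-small-discriminator")

    model = TokenDetection(generator, discriminator, tokenizer)

    # Build a toy batch: mask a single token, set labels to -100 everywhere else
    batch = tokenizer("the quick brown fox jumps over the lazy dog", return_tensors="pt")
    labels = batch["input_ids"].clone()
    masked = labels == tokenizer.convert_tokens_to_ids("fox")
    labels[~masked] = -100
    batch["input_ids"][masked] = tokenizer.mask_token_id

    loss, glogits, dlogits, dlabels = model(
        input_ids=batch["input_ids"],
        labels=labels,
        attention_mask=batch["attention_mask"],
        token_type_ids=batch["token_type_ids"],
    )
    print(f"loss: {loss.item():.4f}, corrupted positions: {int(dlabels.sum())}")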