tokendetection.py
1 """ 2 Token Detection module 3 """ 4 5 import inspect 6 import os 7 8 import torch 9 10 from transformers import PreTrainedModel 11 12 13 class TokenDetection(PreTrainedModel): 14 """ 15 Runs the replaced token detection training objective. This method was first proposed by the ELECTRA model. 16 The method consists of a masked language model generator feeding data to a discriminator that determines 17 which of the tokens are incorrect. More on this training objective can be found in the ELECTRA paper. 18 """ 19 20 def __init__(self, generator, discriminator, tokenizer, weight=50.0): 21 """ 22 Creates a new TokenDetection class. 23 24 Args: 25 generator: Generator model, must be a masked language model 26 discriminator: Discriminator model, must be a model that can detect replaced tokens. Any model can 27 can be customized for this task. See ElectraForPretraining for more. 28 """ 29 30 # Initialize model with discriminator config 31 super().__init__(discriminator.config) 32 33 self.generator = generator 34 self.discriminator = discriminator 35 36 # Tokenizer to save with generator and discriminator 37 self.tokenizer = tokenizer 38 39 # Discriminator weight 40 self.weight = weight 41 42 # Share embeddings if both models are the same type 43 # Embeddings must be same size 44 if self.generator.config.model_type == self.discriminator.config.model_type: 45 self.discriminator.set_input_embeddings(self.generator.get_input_embeddings()) 46 47 # Set attention mask present flags 48 self.gattention = "attention_mask" in inspect.signature(self.generator.forward).parameters 49 self.dattention = "attention_mask" in inspect.signature(self.discriminator.forward).parameters 50 51 # pylint: disable=E1101 52 def forward(self, input_ids=None, labels=None, attention_mask=None, token_type_ids=None): 53 """ 54 Runs a forward pass through the model. This method runs the masked language model then randomly samples 55 the generated tokens and builds a binary classification problem for the discriminator (detecting if each token is correct). 56 57 Args: 58 input_ids: token ids 59 labels: token labels 60 attention_mask: attention mask 61 token_type_ids: segment token indices 62 63 Returns: 64 (loss, generator outputs, discriminator outputs, discriminator labels) 65 """ 66 67 # Copy input ids 68 dinputs = input_ids.clone() 69 70 # Run inputs through masked language model 71 inputs = {"attention_mask": attention_mask} if self.gattention else {} 72 goutputs = self.generator(input_ids, labels=labels, token_type_ids=token_type_ids, **inputs) 73 74 # Get predictions 75 preds = torch.softmax(goutputs[1], dim=-1) 76 preds = preds.view(-1, self.config.vocab_size) 77 78 tokens = torch.multinomial(preds, 1).view(-1) 79 tokens = tokens.view(dinputs.shape[0], -1) 80 81 # Labels have a -100 value to ignore loss from unchanged tokens 82 mask = labels.ne(-100) 83 84 # Replace the masked out tokens of the input with the generator predictions 85 dinputs[mask] = tokens[mask] 86 87 # Turn mask into new target labels - 1 (True) for corrupted, 0 otherwise. 88 # If the prediction was correct, mark it as uncorrupted. 
        correct = tokens == labels
        dlabels = mask.long()
        dlabels[correct] = 0

        # Run token classification, predict whether each token was corrupted
        inputs = {"attention_mask": attention_mask} if self.dattention else {}
        doutputs = self.discriminator(dinputs, labels=dlabels, token_type_ids=token_type_ids, **inputs)

        # Compute combined loss
        loss = goutputs[0] + self.weight * doutputs[0]
        return loss, goutputs[1], doutputs[1], dlabels

    def save_pretrained(self, output, state_dict=None, **kwargs):
        """
        Saves current model to output directory.

        Args:
            output: output directory
            state_dict: model state
            kwargs: additional keyword arguments
        """

        # Save combined model to support training from checkpoints
        super().save_pretrained(output, state_dict=state_dict, **kwargs)

        # Save generator tokenizer and model
        gpath = os.path.join(output, "generator")
        self.tokenizer.save_pretrained(gpath)
        self.generator.save_pretrained(gpath)

        # Save discriminator tokenizer and model
        dpath = os.path.join(output, "discriminator")
        self.tokenizer.save_pretrained(dpath)
        self.discriminator.save_pretrained(dpath)
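
# ----------------------------------------------------------------------------------------------------------
# Minimal usage sketch (illustrative addition, not part of the original module). Assumes the Hugging Face
# transformers library is installed; the ELECTRA small checkpoint names below are example choices, and any
# masked language model generator / replaced token detection discriminator pair could be substituted.
# ----------------------------------------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForMaskedLM, AutoTokenizer, ElectraForPreTraining

    # Generator is a masked language model, discriminator predicts which tokens were replaced
    tokenizer = AutoTokenizer.from_pretrained("google/electra-small-generator")
    generator = AutoModelForMaskedLM.from_pretrained("google/electra-small-generator")
    discriminator = ElectraForPreTraining.from_pretrained("google/electra-small-discriminator")

    # Combined ELECTRA-style model
    model = TokenDetection(generator, discriminator, tokenizer)

    # Tokenize a sample sentence and mask a single token; labels use -100 for positions the MLM loss ignores
    batch = tokenizer("the quick brown fox jumps over the lazy dog", return_tensors="pt")
    labels = torch.full_like(batch["input_ids"], -100)
    labels[0, 2] = batch["input_ids"][0, 2]
    batch["input_ids"][0, 2] = tokenizer.mask_token_id

    loss, _, _, dlabels = model(input_ids=batch["input_ids"], labels=labels, attention_mask=batch["attention_mask"])
    print(f"combined loss: {loss.item():.4f}, corrupted labels: {dlabels.tolist()}")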