last.py
"""
Last token pooling module
"""

import torch

from .base import Pooling


class LastPooling(Pooling):
    """
    Builds last token pooled vectors using outputs from a transformers model.
    """

    def forward(self, **inputs):
        """
        Selects, for each sequence, the token embedding at the last attended position.

        Args:
            inputs: model inputs; must contain an "attention_mask" entry, which is
                    broadcast against the token embeddings, so presumably shaped
                    (batch, sequence)

        Returns:
            last pooled embeddings taken from the output token embeddings
            (i.e. last hidden state). Rows whose attention mask is all zeros
            yield an all-zero vector.
        """

        # Token embeddings from the underlying transformers model, unpacked below as
        # (batch, sequence, dimensions)
        embeddings = super().forward(**inputs)
        attention = inputs["attention_mask"]

        # Last pooling logic from Sentence Transformers
        _, length, width = embeddings.shape

        # ONNX Runtime can't handle an argmax over int64 input, so cast the mask while tracing
        if torch.jit.is_tracing():
            attention = attention.to(torch.int32)

        # Reversing the mask turns "last 1" into "first 1", which max() reports as its index
        found, offsets = attention.flip(1).max(1)

        # A row with no 1s gets offset length - 1, i.e. position 0; that token is
        # zeroed out by the mask multiplication below
        offsets = torch.where(found == 0, length - 1, offsets)
        positions = length - offsets - 1

        # torch.gather needs an index matching the output shape: [bs] --> [bs, 1, hidden_dim]
        index = positions.unsqueeze(-1).repeat(1, width).unsqueeze(1)

        # Mask weights broadcast to the embedding shape so fully-masked rows become zeros
        weights = attention.unsqueeze(-1).expand(embeddings.size()).to(embeddings.dtype)

        selected = torch.gather(embeddings * weights, 1, index)
        return selected.squeeze(dim=1)