src/python/txtai/models/pooling/last.py
 1  """
 2  Last module
 3  """
 4  
 5  import torch
 6  
 7  from .base import Pooling
 8  
 9  
10  class LastPooling(Pooling):
11      """
12      Builds last token pooled vectors usings outputs from a transformers model.
13      """
14  
15      def forward(self, **inputs):
16          """
17          Runs last pooling on token embeddings taking the input mask into account.
18  
19          Args:
20              inputs: model inputs
21  
22          Returns:
23              last pooled embeddings using output token embeddings (i.e. last hidden state)
24          """
25  
26          # Run through transformers model
27          tokens = super().forward(**inputs)
28          mask = inputs["attention_mask"]
29  
30          # Last pooling logic from Sentence Transformers
31          _, sequence, dimensions = tokens.shape
32  
33          # Avoid tracing the argmax with int64 input that can not be handled by ONNX Runtime
34          mask = mask.to(torch.int32) if torch.jit.is_tracing() else mask
35  
36          # Use flip and max() to get the last index of 1 in the attention mask
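        # For example, mask [1, 1, 1, 0, 0] flips to [0, 0, 1, 1, 1]; max(1) returns the
        # first max index (2), so gather = 5 - 2 - 1 = 2, the last non-padded position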
        values, indices = mask.flip(1).max(1)
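        # A row with an all-zero mask has max value 0; point its index at the last
        # position so gather resolves to 0, where tokens * mask below is all zeros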
        indices = torch.where(values == 0, sequence - 1, indices)
        gather = sequence - indices - 1

        # Turn gather indices from shape [bs] --> [bs, 1, hidden_dim]
        gather = gather.unsqueeze(-1).repeat(1, dimensions)
        gather = gather.unsqueeze(1)

        # Expand mask to the token embedding shape so masked positions zero out below
        mask = mask.unsqueeze(-1).expand(tokens.size()).to(tokens.dtype)

        # Return last pooled embeddings
        return torch.gather(tokens * mask, 1, gather).squeeze(dim=1)
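
# Usage sketch (hedged): assumes the base Pooling class loads a transformers model
# from a model path, as txtai's other pooling classes do; the path below and the
# exact constructor signature are assumptions, not part of this module.
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("hypothetical/model-path")
#   pooling = LastPooling("hypothetical/model-path")
#   batch = tokenizer(["first text", "second text"], padding=True, return_tensors="pt")
#   embeddings = pooling(**batch)  # shape: [batch size, hidden dimensions]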