src/python/txtai/models/pooling/mean.py
"""
Mean module
"""

import torch

from .base import Pooling


class MeanPooling(Pooling):
    """
    Builds mean pooled vectors using outputs from a transformers model.
    """

    def forward(self, **inputs):
        """
        Runs mean pooling on token embeddings taking the input mask into account.

        Args:
            inputs: model inputs

        Returns:
            mean pooled embeddings using output token embeddings (i.e. last hidden state)
        """

        # Run through transformers model
        tokens = super().forward(**inputs)
        mask = inputs["attention_mask"]

        # Mean pooling
        # pylint: disable=E1101
        mask = mask.unsqueeze(-1).expand(tokens.size()).float()
        return torch.sum(tokens * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
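
Below is a minimal standalone sketch (not part of mean.py) of the same masked mean computation on dummy tensors. The shapes, values, and variable names are hypothetical; it only assumes token embeddings shaped (batch, seq, dim), as returned by a transformers model's last hidden state, and a 0/1 attention mask shaped (batch, seq).

import torch

# Hypothetical token embeddings: batch=2, seq=4, dim=8
tokens = torch.randn(2, 4, 8)

# Attention mask: 1 = real token, 0 = padding
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])

# Expand the mask to the embedding dimension, then average only over real tokens.
# The clamp keeps an all-padding row from dividing by zero.
expanded = mask.unsqueeze(-1).expand(tokens.size()).float()
pooled = torch.sum(tokens * expanded, 1) / torch.clamp(expanded.sum(1), min=1e-9)

print(pooled.shape)  # torch.Size([2, 8])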