# mean.py
"""
Mean module
"""

import torch

from .base import Pooling


class MeanPooling(Pooling):
    """
    Builds mean pooled vectors using outputs from a transformers model.
    """

    def forward(self, **inputs):
        """
        Runs mean pooling on token embeddings taking the input mask into account.

        Args:
            inputs: model inputs

        Returns:
            mean pooled embeddings using output token embeddings (i.e. last hidden state)
        """

        # Run through transformers model to get token embeddings (last hidden state)
        tokens = super().forward(**inputs)
        mask = inputs["attention_mask"]

        # Mean pooling: expand mask to token embedding dimensions so padded
        # positions contribute zero to the sum
        # pylint: disable=E1101
        mask = mask.unsqueeze(-1).expand(tokens.size()).float()

        # Sum unmasked token embeddings and divide by the number of unmasked
        # tokens; clamp guards against division by zero for fully-masked input
        return torch.sum(tokens * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)