# src/python/txtai/models/pooling/late.py
  1  """
  2  Late module
  3  """
  4  
  5  import numpy as np
  6  import torch
  7  
  8  from safetensors import safe_open
  9  from torch import nn
 10  from transformers.utils import cached_file
 11  
 12  from .base import Pooling
 13  from .muvera import Muvera
 14  
 15  
 16  class LatePooling(Pooling):
 17      """
 18      Builds late pooled vectors using outputs from a transformers model.
 19      """
 20  
 21      def __init__(self, path, device, tokenizer=None, maxlength=None, loadprompts=None, modelargs=None):
 22          # Check if fixed dimensional encoder is enabled
 23          modelargs = modelargs.copy() if modelargs else {}
 24          muvera = modelargs.pop("muvera", {})
 25          self.encoder = Muvera(**muvera) if muvera is not None else None
 26  
 27          # Call parent initialization
 28          super().__init__(path, device, tokenizer, maxlength, loadprompts, modelargs)
 29  
 30          # Get linear weights path
 31          config = self.load(path, "1_Dense/config.json")
 32          if config:
 33              # PyLate weights format
 34              name = "1_Dense/model.safetensors"
 35          else:
 36              # Stanford weights format
 37              name = "model.safetensors"
 38  
 39          # Read model settings
 40          self.qprefix, self.qlength, self.dprefix, self.dlength = self.settings(path, config)
 41  
 42          # Load linear layer
 43          path = cached_file(path_or_repo_id=path, filename=name)
 44          with safe_open(filename=path, framework="pt") as f:
 45              weights = f.get_tensor("linear.weight")
 46  
 47              # Load weights into linear layer
 48              self.linear = nn.Linear(weights.shape[1], weights.shape[0], bias=False, device=self.device, dtype=weights.dtype)
 49              with torch.no_grad():
 50                  self.linear.weight.copy_(weights)
 51  
 52      def forward(self, **inputs):
 53          """
 54          Runs late pooling on token embeddings.
 55  
 56          Args:
 57              inputs: model inputs
 58  
 59          Returns:
 60              Late pooled embeddings using output token embeddings (i.e. last hidden state)
 61          """
 62  
 63          # Run through transformers model
 64          tokens = super().forward(**inputs)
 65  
 66          # Run through final linear layer and return
 67          return self.linear(tokens)
 68  
 69      def preencode(self, documents, category):
 70          """
 71          Apply prefixes and lengths to data.
 72  
 73          Args:
 74              documents: list of documents used to build embeddings
 75              category: embeddings category (query or data)
 76          """
 77  
 78          results = []
 79  
 80          # Apply prefix
 81          for text in documents:
 82              prefix = self.qprefix if category == "query" else self.dprefix
 83              if prefix:
 84                  text = f"{prefix}{text}"
 85  
 86              results.append(text)
 87  
 88          # Set maxlength
 89          maxlength = self.qlength if category == "query" else self.dlength
 90          if maxlength:
 91              self.maxlength = maxlength
 92  
 93          return results
 94  
 95      def postencode(self, results, category):
 96          """
 97          Normalizes and pads results.
 98  
 99          Args:
100              results: input results
101  
102          Returns:
103              normalized results with padding
104          """
105  
106          length = 0
107          for vectors in results:
108              # Get max length
109              if vectors.shape[0] > length:
110                  length = vectors.shape[0]
111  
112              # Normalize vectors
113              vectors /= np.linalg.norm(vectors, axis=1)[:, np.newaxis]
114  
115          # Pad values
116          data = []
117          for vectors in results:
118              data.append(np.pad(vectors, [(0, length - vectors.shape[0]), (0, 0)]))
119  
120          # Build NumPy array
121          data = np.asarray(data)
122  
123          # Apply fixed dimesional encoder, if necessary
124          return self.encoder(data, category) if self.encoder else data
125  
126      def settings(self, path, config):
127          """
128          Reads model settings.
129  
130          Args:
131              path: model path
132              config: PyLate model format if provided, otherwise read from Stanford format
133          """
134  
135          if config:
136              # PyLate format
137              config = self.load(path, "config_sentence_transformers.json")
138              params = ["query_prefix", "query_length", "document_prefix", "document_length"]
139          else:
140              # Stanford format
141              config = self.load(path, "artifact.metadata")
142              params = ["query_token_id", "query_maxlen", "doc_token_id", "doc_maxlen"]
143  
144          return [config.get(p) for p in params]