"""
Pooling module
"""

import json

import numpy as np
import torch

from huggingface_hub.errors import HFValidationError
from torch import nn
from transformers.utils import cached_file

from ..models import Models


class Pooling(nn.Module):
    """
    Builds pooled vectors using outputs from a transformers model.
    """

    def __init__(self, path, device, tokenizer=None, maxlength=None, loadprompts=None, modelargs=None):
        """
        Creates a new Pooling model.

        Args:
            path: path to model, accepts Hugging Face model hub id or local path
            device: tensor device id
            tokenizer: optional path to tokenizer
            maxlength: max sequence length
            loadprompts: whether instruction prompts should be loaded
            modelargs: additional model arguments
        """

        super().__init__()

        self.model = Models.load(path, modelargs=modelargs)
        self.tokenizer = Models.tokenizer(tokenizer if tokenizer else path)
        self.device = Models.device(device)

        # Detect unbounded tokenizer typically found in older models
        Models.checklength(self.model, self.tokenizer)

        # Set max length
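        # Note: Hugging Face tokenizers default model_max_length to int(1e30) when
        # no limit is configured, so that sentinel is treated as no max length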
        self.maxlength = maxlength if maxlength else self.tokenizer.model_max_length if self.tokenizer.model_max_length != int(1e30) else None

        # Load stored prompts
        self.prompts = self.loadprompts(path) if loadprompts else None

        # Move to device
        self.to(self.device)

    def encode(self, documents, batch=32, category=None):
        """
        Builds an array of pooled embeddings for documents.

        Args:
            documents: list of documents used to build embeddings
            batch: model batch size
            category: embeddings category (query or data)

        Returns:
            pooled embeddings
        """

        # Split documents into batches and process
        results = []

        # Apply pre-encoding transformation logic
        documents = self.preencode(documents, category)

        # Sort document indices from largest to smallest to enable efficient batching
        # This performance tweak matches logic in sentence-transformers
        lengths = np.argsort([-len(x) if x else 0 for x in documents])
        documents = [documents[x] for x in lengths]

        for chunk in self.chunk(documents, batch):
            # Tokenize input
            inputs = self.tokenizer(chunk, padding=True, truncation="longest_first", return_tensors="pt", max_length=self.maxlength)

            # Move inputs to device
            inputs = inputs.to(self.device)

            # Run inputs through model
            with torch.no_grad():
                outputs = self.forward(**inputs)

            # Add batch result
            results.extend(outputs.cpu().to(torch.float32).numpy())

        # Apply post-encoding transformation logic
        results = self.postencode(results, category)

        # Restore original order and return array
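        # np.argsort(lengths) computes the inverse of the sort permutation applied above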
        return np.asarray([results[x] for x in np.argsort(lengths)])

    def chunk(self, texts, size):
        """
        Splits texts into batches, each with at most size elements.

        Args:
            texts: text elements
            size: batch size

        Returns:
            list of evenly sized batches with the last batch having the remaining elements
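
        Example (illustrative):
            chunk(["a", "b", "c", "d", "e"], 2) -> [["a", "b"], ["c", "d"], ["e"]]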
        """

        return [texts[x : x + size] for x in range(0, len(texts), size)]

    def forward(self, **inputs):
        """
        Runs inputs through transformers model and returns outputs.

        Args:
            inputs: model inputs

        Returns:
            model outputs
        """

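        # Return the first element of the model output, which for base transformer
        # models is the last hidden state (one vector per input token)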
        return self.model(**inputs)[0]

    # pylint: disable=W0613
    def preencode(self, documents, category):
        """
        Applies pre-encoding transformation logic.

        Args:
            documents: list of documents used to build embeddings
            category: embeddings category (query or data)

        Returns:
            documents, with a prompt prefix applied when one matches category
        """

        # Prepend prompt
        prompt = self.prompts.get(category) if self.prompts else None
        if prompt:
            documents = [f"{prompt}{x}" if isinstance(x, str) else x for x in documents]

        return documents

    # pylint: disable=W0613
    def postencode(self, results, category):
        """
        Applies post-encoding transformation logic.

        Args:
            results: list of results
            category: embeddings category (query or data)

        Returns:
            results with transformation logic applied
        """

        return results

    def load(self, path, name):
        """
        Loads a JSON config file from the Hugging Face Hub.

        Args:
            path: model path
            name: file to load

        Returns:
            config
        """

        # Download file and parse JSON
        config = None
        try:
            path = cached_file(path_or_repo_id=path, filename=name)
            if path:
                with open(path, encoding="utf-8") as f:
                    config = json.load(f)

        # Ignore this error - invalid repo or directory
        except (HFValidationError, OSError):
            pass

        return config

    def loadprompts(self, path):
        """
        Loads prompts from a sentence transformers configuration file.

        Args:
            path: model path

        Returns:
            prompts dictionary, if available
        """

        prompts = None
        config = self.load(path, "config_sentence_transformers.json")
        if config:
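            # Expected file shape (a sketch following sentence-transformers conventions;
            # the prompt values are assumptions):
            #   {"prompts": {"query": "query: ", "document": "passage: "}}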
            # Copy document prompt to data
            prompts = config.get("prompts")
            if prompts and "document" in prompts:
                prompts["data"] = prompts["document"]

        return prompts
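

# Example usage (a minimal sketch, not executed on import; "bert-base-uncased" is an
# assumption, any Hugging Face feature extraction model id or local path works, and
# pooling token vectors down to a single vector is done by subclasses overriding forward):
#
#   pooling = Pooling("bert-base-uncased", device=0, loadprompts=True)
#   embeddings = pooling.encode(["first document", "second document"], batch=16, category="data")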