base.py
1 """ 2 Pooling module 3 """ 4 5 import json 6 7 import numpy as np 8 import torch 9 10 from huggingface_hub.errors import HFValidationError 11 from torch import nn 12 from transformers.utils import cached_file 13 14 from ..models import Models 15 16 17 class Pooling(nn.Module): 18 """ 19 Builds pooled vectors usings outputs from a transformers model. 20 """ 21 22 def __init__(self, path, device, tokenizer=None, maxlength=None, loadprompts=None, modelargs=None): 23 """ 24 Creates a new Pooling model. 25 26 Args: 27 path: path to model, accepts Hugging Face model hub id or local path 28 device: tensor device id 29 tokenizer: optional path to tokenizer 30 maxlength: max sequence length 31 loadprompts: whether instruction prompts should be loaded 32 modelargs: additional model arguments 33 """ 34 35 super().__init__() 36 37 self.model = Models.load(path, modelargs=modelargs) 38 self.tokenizer = Models.tokenizer(tokenizer if tokenizer else path) 39 self.device = Models.device(device) 40 41 # Detect unbounded tokenizer typically found in older models 42 Models.checklength(self.model, self.tokenizer) 43 44 # Set max length 45 self.maxlength = maxlength if maxlength else self.tokenizer.model_max_length if self.tokenizer.model_max_length != int(1e30) else None 46 47 # Load stored prompts 48 self.prompts = self.loadprompts(path) if loadprompts else None 49 50 # Move to device 51 self.to(self.device) 52 53 def encode(self, documents, batch=32, category=None): 54 """ 55 Builds an array of pooled embeddings for documents. 56 57 Args: 58 documents: list of documents used to build embeddings 59 batch: model batch size 60 category: embeddings category (query or data) 61 62 Returns: 63 pooled embeddings 64 """ 65 66 # Split documents into batches and process 67 results = [] 68 69 # Apply pre encoding transformation logic 70 documents = self.preencode(documents, category) 71 72 # Sort document indices from largest to smallest to enable efficient batching 73 # This performance tweak matches logic in sentence-transformers 74 lengths = np.argsort([-len(x) if x else 0 for x in documents]) 75 documents = [documents[x] for x in lengths] 76 77 for chunk in self.chunk(documents, batch): 78 # Tokenize input 79 inputs = self.tokenizer(chunk, padding=True, truncation="longest_first", return_tensors="pt", max_length=self.maxlength) 80 81 # Move inputs to device 82 inputs = inputs.to(self.device) 83 84 # Run inputs through model 85 with torch.no_grad(): 86 outputs = self.forward(**inputs) 87 88 # Add batch result 89 results.extend(outputs.cpu().to(torch.float32).numpy()) 90 91 # Apply post encoding transformation logic 92 results = self.postencode(results, category) 93 94 # Restore original order and return array 95 return np.asarray([results[x] for x in np.argsort(lengths)]) 96 97 def chunk(self, texts, size): 98 """ 99 Splits texts into separate batch sizes specified by size. 100 101 Args: 102 texts: text elements 103 size: batch size 104 105 Returns: 106 list of evenly sized batches with the last batch having the remaining elements 107 """ 108 109 return [texts[x : x + size] for x in range(0, len(texts), size)] 110 111 def forward(self, **inputs): 112 """ 113 Runs inputs through transformers model and returns outputs. 114 115 Args: 116 inputs: model inputs 117 118 Returns: 119 model outputs 120 """ 121 122 return self.model(**inputs)[0] 123 124 # pylint: disable=W0613 125 def preencode(self, documents, category): 126 """ 127 Applies pre encoding transformation logic. 
128 129 Args: 130 documents: list of documents used to build embeddings 131 category: embeddings category (query or data) 132 """ 133 134 # Prepend prompt 135 prompt = self.prompts.get(category) if self.prompts else None 136 if prompt: 137 documents = [f"{prompt}{x}" if isinstance(x, str) else x for x in documents] 138 139 return documents 140 141 # pylint: disable=W0613 142 def postencode(self, results, category): 143 """ 144 Applies post encoding transformation logic. 145 146 Args: 147 results: list of results 148 category: embeddings category (query or data) 149 150 Returns: 151 results with transformation logic applied 152 """ 153 154 return results 155 156 def load(self, path, name): 157 """ 158 Loads a JSON config file from the Hugging Face Hub. 159 160 Args: 161 path: model path 162 name: file to load 163 164 Returns: 165 config 166 """ 167 168 # Download file and parse JSON 169 config = None 170 try: 171 path = cached_file(path_or_repo_id=path, filename=name) 172 if path: 173 with open(path, encoding="utf-8") as f: 174 config = json.load(f) 175 176 # Ignore this error - invalid repo or directory 177 except (HFValidationError, OSError): 178 pass 179 180 return config 181 182 def loadprompts(self, path): 183 """ 184 Loads prompts from a sentence transformers configuration file. 185 186 Args: 187 path: model path 188 189 Returns: 190 prompts dictionary, if available 191 """ 192 193 prompts = None 194 config = self.load(path, "config_sentence_transformers.json") 195 if config: 196 # Copy document prompt to data 197 prompts = config.get("prompts") 198 if prompts and "document" in prompts: 199 prompts["data"] = prompts["document"] 200 201 return prompts
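The base class's forward method returns the raw transformer outputs, so a pooling strategy is typically supplied by a subclass that overrides forward. The following is a minimal sketch, not code from this module: the subclass name ExampleMeanPooling, the model id, and the "cpu" device value are illustrative assumptions, and it is assumed to run in the same package context where Pooling and torch are already importable.

# Hypothetical subclass showing how a pooling strategy plugs into forward();
# this is only a sketch, the project ships its own pooling implementations.
import torch


class ExampleMeanPooling(Pooling):
    def forward(self, **inputs):
        # Token-level embeddings from the underlying transformer: (batch, tokens, dims)
        tokens = self.model(**inputs)[0]

        # Expand the attention mask and average only over non-padding tokens
        mask = inputs["attention_mask"].unsqueeze(-1).expand(tokens.size()).float()
        return torch.sum(tokens * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)


# Illustrative usage: model id and device are example values
pooling = ExampleMeanPooling("sentence-transformers/all-MiniLM-L6-v2", device="cpu", maxlength=256)
vectors = pooling.encode(["first document", "second document"], batch=16, category="data")
print(vectors.shape)  # (2, hidden dimension of the underlying transformer)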