# late.py
"""
Late module
"""

import numpy as np
import torch

from safetensors import safe_open
from torch import nn
from transformers.utils import cached_file

from .base import Pooling
from .muvera import Muvera


class LatePooling(Pooling):
    """
    Builds late pooled vectors using outputs from a transformers model.
    """

    def __init__(self, path, device, tokenizer=None, maxlength=None, loadprompts=None, modelargs=None):
        """
        Creates a new LatePooling instance.

        Args:
            path: model path, also used to resolve the linear layer weights and model settings files
            device: passed to the parent constructor
            tokenizer: passed to the parent constructor
            maxlength: passed to the parent constructor
            loadprompts: passed to the parent constructor
            modelargs: additional model arguments, the optional "muvera" entry is removed and used to configure
                       the fixed dimensional encoder - set "muvera" to None to disable the encoder
        """

        # Check if fixed dimensional encoder is enabled
        # A missing "muvera" entry defaults to {} (encoder with default arguments), only an explicit None disables it
        modelargs = modelargs.copy() if modelargs else {}
        muvera = modelargs.pop("muvera", {})
        self.encoder = Muvera(**muvera) if muvera is not None else None

        # Call parent initialization
        super().__init__(path, device, tokenizer, maxlength, loadprompts, modelargs)

        # Get linear weights path
        config = self.load(path, "1_Dense/config.json")
        if config:
            # PyLate weights format
            name = "1_Dense/model.safetensors"
        else:
            # Stanford weights format
            name = "model.safetensors"

        # Read model settings
        self.qprefix, self.qlength, self.dprefix, self.dlength = self.settings(path, config)

        # Load linear layer
        path = cached_file(path_or_repo_id=path, filename=name)
        with safe_open(filename=path, framework="pt") as f:
            weights = f.get_tensor("linear.weight")

        # Load weights into linear layer
        # Weights are stored as (output features, input features), the layer has no bias term
        self.linear = nn.Linear(
            weights.shape[1],
            weights.shape[0],
            bias=False,
            device=self.device,
            dtype=weights.dtype,
        )
        with torch.no_grad():
            self.linear.weight.copy_(weights)

    def forward(self, **inputs):
        """
        Runs late pooling on token embeddings.

        Args:
            inputs: model inputs

        Returns:
            Late pooled embeddings using output token embeddings (i.e. last hidden state)
        """

        # Run through transformers model
        tokens = super().forward(**inputs)

        # Run through final linear layer and return
        return self.linear(tokens)

    def preencode(self, documents, category):
        """
        Apply prefixes and lengths to data.

        Args:
            documents: list of documents used to build embeddings
            category: embeddings category (query or data)

        Returns:
            list of documents with the category prefix applied, if a prefix is set
        """

        query = category == "query"

        # Apply prefix, the prefix only depends on the category so it is resolved once
        prefix = self.qprefix if query else self.dprefix
        results = [f"{prefix}{text}" for text in documents] if prefix else list(documents)

        # Set maxlength
        # NOTE(review): when no length is configured for this category, self.maxlength is left untouched, so a
        # length previously applied for the other category stays in effect - confirm this is intended
        maxlength = self.qlength if query else self.dlength
        if maxlength:
            self.maxlength = maxlength

        return results

    def postencode(self, results, category):
        """
        Normalizes and pads results.

        Args:
            results: input results, a list of 2D NumPy arrays that are normalized in place
            category: embeddings category (query or data), passed to the fixed dimensional encoder

        Returns:
            normalized results with padding
        """

        length = 0
        for vectors in results:
            # Get max length
            length = max(length, vectors.shape[0])

            # Normalize vectors, each row is divided by its norm (this modifies the input array)
            vectors /= np.linalg.norm(vectors, axis=1)[:, np.newaxis]

        # Pad values, zero rows are appended so all arrays share the same first dimension
        data = [np.pad(vectors, [(0, length - vectors.shape[0]), (0, 0)]) for vectors in results]

        # Build NumPy array
        data = np.asarray(data)

        # Apply fixed dimensional encoder, if necessary
        # Uses the same "is not None" check that decides whether the encoder is created
        return self.encoder(data, category) if self.encoder is not None else data

    def settings(self, path, config):
        """
        Reads model settings.

        Args:
            path: model path
            config: PyLate model format if provided, otherwise read from Stanford format

        Returns:
            list of [query prefix, query length, document prefix, document length], a value is None when not set
        """

        if config:
            # PyLate format
            name = "config_sentence_transformers.json"
            params = ["query_prefix", "query_length", "document_prefix", "document_length"]
        else:
            # Stanford format
            name = "artifact.metadata"
            params = ["query_token_id", "query_maxlen", "doc_token_id", "doc_maxlen"]

        # Fall back to empty settings when the file can't be loaded, instead of failing on a missing config
        settings = self.load(path, name) or {}

        return [settings.get(p) for p in params]