# torch.py
"""
PyTorch module
"""

import numpy as np
import torch

# Quantization support requires the optional bitsandbytes dependency
try:
    from bitsandbytes import matmul_4bit
    from bitsandbytes.functional import (
        QuantState,
        int8_vectorwise_quant,
        int8_vectorwise_dequant,
        int8_linear_matmul,
        int8_mm_dequant,
        quantize_4bit,
        dequantize_4bit,
    )

    BNB = True
except ImportError:
    BNB = False

from .numpy import NumPy


class Torch(NumPy):
    """
    Builds an ANN index backed by a PyTorch array.
    """

    def __init__(self, config):
        """
        Creates a new Torch ANN index.

        Args:
            config: index configuration
        """

        super().__init__(config)

        # Define array functions - map the NumPy-style operations onto torch equivalents
        self.all, self.cat, self.dot, self.zeros = torch.all, torch.cat, torch.mm, torch.zeros
        self.argsort, self.xor, self.clip = torch.argsort, torch.bitwise_xor, torch.clip

        # Quantization parameters - qstate holds the bitsandbytes QuantState,
        # qdeleted tracks rows deleted while data was quantized
        self.qstate, self.qdeleted = None, 0

        # Initialize quantization
        settings = self.qsettings()
        if settings:
            if not BNB:
                raise ImportError('bitsandbytes is not available - install "ann" extra to enable')

            if settings.get("type") == "int8":
                # Matrix multiply for 8 bit vectors
                self.dot = self.matmul8bit
            else:
                # Matrix multiply for 4 bit vectors
                self.dot = self.matmul4bit

            # Require safetensors storage - quantization state must be persisted alongside the data
            self.config[self.config["backend"]]["safetensors"] = True

    def index(self, embeddings):
        # Dequantize on entry, requantize on exit so the base class works on raw tensors
        with QuantizeContext(self):
            super().index(embeddings)

    def append(self, embeddings):
        # Dequantize on entry, requantize on exit so the base class works on raw tensors
        with QuantizeContext(self):
            super().append(embeddings)

    def delete(self, ids):
        # Dequantize on entry, requantize on exit so the base class works on raw tensors
        with QuantizeContext(self):
            super().delete(ids)

        # Calculate deleted for quantized data, if necessary
        # qstate.shape retains the pre-delete row count, super().count() is the live count
        if self.qstate:
            self.qdeleted = self.qstate.shape[0] - super().count()

    def count(self):
        # When quantized, derive count from the stored shape minus tracked deletes
        return self.qstate.shape[0] - self.qdeleted if self.qstate else super().count()

    def tensor(self, array):
        """
        Converts array to a Tensor, moving it to the GPU when CUDA is available.

        Args:
            array: numpy array or Tensor

        Returns:
            Tensor
        """

        # Convert array to Tensor
        if isinstance(array, np.ndarray):
            array = torch.from_numpy(array)

        # Load to GPU device, if available
        return array.cuda() if torch.cuda.is_available() else array

    def numpy(self, array):
        """
        Converts a Tensor back to a CPU numpy array.

        Args:
            array: Tensor

        Returns:
            numpy array
        """

        return array.cpu().numpy()

    def totype(self, array, dtype):
        # Only int64 requires an explicit cast, other dtypes pass through unchanged
        return array.long() if dtype == np.int64 else array

    def settings(self):
        """
        Returns settings for this index.

        Returns:
            dict with the torch version
        """

        return {"torch": torch.__version__}

    def loadsafetensors(self, path):
        """
        Loads index data from path. Rebuilds quantization state when quantization is enabled.

        Args:
            path: path to load

        Returns:
            loaded data
        """

        data = super().loadsafetensors(path)

        # Load quantization settings - metadata values come back as strings/arrays
        if self.qsettings():
            self.qstate = QuantState(
                absmax=self.tensor(data["absmax"]),
                shape=torch.Size(data["shape"].tolist()),
                code=self.tensor(data["code"]) if "code" in data else None,
                blocksize=int(data["blocksize"]) if "blocksize" in data else None,
                quant_type=data["quant_type"],
                dtype=getattr(torch, data["dtype"]),
            )
            self.qdeleted = int(data["qdeleted"])

        return data

    def savesafetensors(self, data, path, metadata=None):
        """
        Saves index data to path. Adds quantization state when quantization is enabled.

        Args:
            data: data to save
            path: path to save to
            metadata: optional metadata to store alongside the data
        """

        # Save quantization settings
        if self.qstate:
            # Required elements
            data["absmax"] = self.qstate.absmax.cpu().numpy()
            data["shape"] = np.array(list(self.qstate.shape))

            # Merge quantization metadata with caller-provided metadata so caller
            # values are preserved; quantization keys take precedence on conflict
            metadata = {
                **(metadata or {}),
                "quant_type": str(self.qstate.quant_type),
                "dtype": str(self.qstate.dtype).rsplit(".", maxsplit=1)[-1],
                "qdeleted": str(self.qdeleted),
            }

            # Add optional elements
            if self.qstate.code is not None:
                data["code"] = self.qstate.code.cpu().numpy()

            if self.qstate.blocksize:
                metadata["blocksize"] = str(self.qstate.blocksize)

        super().savesafetensors(data, path, metadata)

    def quantize(self):
        """
        Quantizes data if quantization is supported and enabled.
        """

        # Get quantization settings and quantize
        settings = self.qsettings()
        if settings:
            if settings.get("type") == "int8":
                # Get current backend config before quantization replaces the tensor
                shape, dtype = self.backend.shape, self.backend.dtype

                # 8-bit quantization
                self.backend, absmax, _ = int8_vectorwise_quant(self.backend.half())
                self.qstate = QuantState(absmax=absmax, shape=shape, quant_type=settings["type"], dtype=dtype)
            else:
                # 4-bit quantization
                self.backend, self.qstate = quantize_4bit(
                    self.backend, blocksize=settings.get("blocksize", 64), quant_type=settings.get("type", "nf4")
                )

    def dequantize(self):
        """
        Dequantizes data if quantization is supported and enabled.
        """

        # Dequantize using current quantization state
        if self.qstate:
            if self.qstate.quant_type == "int8":
                # 8-bit quantization
                self.backend = int8_vectorwise_dequant(self.backend, self.qstate.absmax)
            else:
                # 4-bit quantization
                self.backend = dequantize_4bit(self.backend, self.qstate)

    def qsettings(self):
        """
        Parses quantization settings.

        Returns:
            {quantization settings} when enabled, otherwise the raw setting value
            (falsy when quantization is disabled)
        """

        quantize = self.setting("quantize")

        # quantize=True enables defaults, a dict of parameters passes through as-is
        return {"quantize": True} if quantize and isinstance(quantize, bool) else quantize

    def matmul8bit(self, query, data):
        """
        8-bit integer matrix multiplication.

        Args:
            query: query matrix
            data: data matrix

        Returns:
            query @ data
        """

        # Matrix multiply method requires transposing data matrix
        query, absmax, _ = int8_vectorwise_quant(query.half())
        return int8_mm_dequant(int8_linear_matmul(query, data.T), absmax, self.qstate.absmax).float()

    def matmul4bit(self, query, data):
        """
        4-bit float matrix multiplication.

        Args:
            query: query matrix
            data: data matrix

        Returns:
            query @ data
        """

        # Matrix multiply method transposes data already
        return matmul_4bit(query, data, self.qstate)


class QuantizeContext:
    """
    Quantization context. Facilitates modifications to quantized tensors by
    dequantizing on entry and requantizing on exit.
    """

    def __init__(self, ann):
        self.ann = ann

    def __enter__(self):
        self.ann.dequantize()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.ann.quantize()