# src/python/txtai/ann/dense/torch.py
  1  """
  2  PyTorch module
  3  """
  4  
  5  import numpy as np
  6  import torch
  7  
  8  try:
  9      from bitsandbytes import matmul_4bit
 10      from bitsandbytes.functional import (
 11          QuantState,
 12          int8_vectorwise_quant,
 13          int8_vectorwise_dequant,
 14          int8_linear_matmul,
 15          int8_mm_dequant,
 16          quantize_4bit,
 17          dequantize_4bit,
 18      )
 19  
 20      BNB = True
 21  except ImportError:
 22      BNB = False
 23  
 24  from .numpy import NumPy
 25  
 26  
class Torch(NumPy):
    """
    Builds an ANN index backed by a PyTorch array.

    Supports optional vector quantization (int8 or 4-bit) via bitsandbytes when
    the "quantize" setting is enabled.
    """

    def __init__(self, config):
        """
        Creates a Torch index.

        Args:
            config: index configuration

        Raises:
            ImportError: if quantization is enabled but bitsandbytes is not installed
        """

        super().__init__(config)

        # Define array functions (PyTorch equivalents of the NumPy parent's operations)
        self.all, self.cat, self.dot, self.zeros = torch.all, torch.cat, torch.mm, torch.zeros
        self.argsort, self.xor, self.clip = torch.argsort, torch.bitwise_xor, torch.clip

        # Quantization parameters - quantization state and count of deleted rows in quantized data
        self.qstate, self.qdeleted = None, 0

        # Initialize quantization
        settings = self.qsettings()
        if settings:
            if not BNB:
                raise ImportError('bitsandbytes is not available - install "ann" extra to enable')

            if settings.get("type") == "int8":
                # Matrix multiply for 8 bit vectors
                self.dot = self.matmul8bit
            else:
                # Matrix multiply for 4 bit vectors
                self.dot = self.matmul4bit

            # Require safetensors storage
            self.config[self.config["backend"]]["safetensors"] = True

    def index(self, embeddings):
        """
        Builds the index over embeddings. Data is dequantized before indexing and
        requantized after (see QuantizeContext).

        Args:
            embeddings: embeddings array
        """

        with QuantizeContext(self):
            super().index(embeddings)

    def append(self, embeddings):
        """
        Appends embeddings to the index. Data is dequantized before appending and
        requantized after.

        Args:
            embeddings: embeddings array
        """

        with QuantizeContext(self):
            super().append(embeddings)

    def delete(self, ids):
        """
        Deletes ids from the index. Data is dequantized before deleting and
        requantized after.

        Args:
            ids: ids to delete
        """

        with QuantizeContext(self):
            super().delete(ids)

            # Calculate deleted for quantized data, if necessary
            # qstate.shape reflects all stored rows while super().count() reflects remaining rows
            if self.qstate:
                self.qdeleted = self.qstate.shape[0] - super().count()

    def count(self):
        """
        Number of elements in this index.

        Returns:
            count of elements; for quantized data, total rows minus deleted rows
        """

        return self.qstate.shape[0] - self.qdeleted if self.qstate else super().count()

    def tensor(self, array):
        """
        Converts array to a Tensor, moving it to the GPU when CUDA is available.

        Args:
            array: input array (NumPy array or Tensor)

        Returns:
            Tensor
        """

        # Convert array to Tensor
        if isinstance(array, np.ndarray):
            array = torch.from_numpy(array)

        # Load to GPU device, if available
        return array.cuda() if torch.cuda.is_available() else array

    def numpy(self, array):
        """
        Converts a Tensor to a NumPy array on CPU.

        Args:
            array: Tensor

        Returns:
            NumPy array
        """

        return array.cpu().numpy()

    def totype(self, array, dtype):
        """
        Casts array to dtype. Only int64 is handled; any other dtype returns array unchanged.

        Args:
            array: input Tensor
            dtype: target NumPy dtype

        Returns:
            Tensor
        """

        return array.long() if dtype == np.int64 else array

    def settings(self):
        """
        Returns index settings.

        Returns:
            dict with the installed torch version
        """

        return {"torch": torch.__version__}

    def loadsafetensors(self, path):
        """
        Loads index data from a safetensors file. Restores quantization state
        when quantization is enabled.

        Args:
            path: path to load

        Returns:
            data loaded by the parent class
        """

        data = super().loadsafetensors(path)

        # Load quantization settings
        if self.qsettings():
            # Rebuild quantization state from saved tensors and metadata
            # NOTE(review): assumes metadata values (blocksize, qdeleted, dtype) were stored as
            # strings by savesafetensors - verify against the parent class loader
            self.qstate = QuantState(
                absmax=self.tensor(data["absmax"]),
                shape=torch.Size(data["shape"].tolist()),
                code=self.tensor(data["code"]) if "code" in data else None,
                blocksize=int(data["blocksize"]) if "blocksize" in data else None,
                quant_type=data["quant_type"],
                dtype=getattr(torch, data["dtype"]),
            )
            self.qdeleted = int(data["qdeleted"])

        return data

    def savesafetensors(self, data, path, metadata=None):
        """
        Saves index data to a safetensors file. Persists quantization state
        (tensors in data, scalars as string metadata) when present.

        Args:
            data: data to save
            path: path to save to
            metadata: optional metadata, replaced with quantization metadata when quantized
        """

        # Save quantization settings
        if self.qstate:
            # Required elements
            data["absmax"] = self.qstate.absmax.cpu().numpy()
            data["shape"] = np.array(list(self.qstate.shape))

            # Safetensors metadata values must be strings
            metadata = {
                "quant_type": str(self.qstate.quant_type),
                "dtype": str(self.qstate.dtype).rsplit(".", maxsplit=1)[-1],
                "qdeleted": str(self.qdeleted),
            }

            # Add optional elements
            if self.qstate.code is not None:
                data["code"] = self.qstate.code.cpu().numpy()

            if self.qstate.blocksize:
                metadata["blocksize"] = str(self.qstate.blocksize)

        super().savesafetensors(data, path, metadata)

    def quantize(self):
        """
        Quantizes data if quantization is supported and enabled. Replaces the
        backend array with quantized data and stores the quantization state.
        """

        # Get quantization settings and quantize
        settings = self.qsettings()
        if settings:
            if settings.get("type") == "int8":
                # Get current backend config before quantization overwrites it
                shape, dtype = self.backend.shape, self.backend.dtype

                # 8-bit quantization
                self.backend, absmax, _ = int8_vectorwise_quant(self.backend.half())
                self.qstate = QuantState(absmax=absmax, shape=shape, quant_type=settings["type"], dtype=dtype)
            else:
                # 4-bit quantization - defaults to nf4 with blocksize of 64
                self.backend, self.qstate = quantize_4bit(
                    self.backend, blocksize=settings.get("blocksize", 64), quant_type=settings.get("type", "nf4")
                )

    def dequantize(self):
        """
        Dequantizes data if quantization is supported and enabled. Restores the
        backend array to full precision using the stored quantization state.
        """

        # Dequantize using current quantization state
        if self.qstate:
            if self.qstate.quant_type == "int8":
                # 8-bit quantization
                self.backend = int8_vectorwise_dequant(self.backend, self.qstate.absmax)
            else:
                # 4-bit quantization
                self.backend = dequantize_4bit(self.backend, self.qstate)

    def qsettings(self):
        """
        Parses quantization settings. A bare boolean True is normalized to a
        settings dictionary; otherwise the raw "quantize" setting is returned.

        Returns:
            {quantization settings} when enabled, otherwise a falsy value
        """

        quantize = self.setting("quantize")
        return {"quantize": True} if quantize and isinstance(quantize, bool) else quantize

    def matmul8bit(self, query, data):
        """
        8-bit integer matrix multiplication. Quantizes the query, multiplies and
        dequantizes the result back to float.

        Args:
            query: query matrix
            data: data matrix

        Returns:
            query @ data
        """

        # Matrix multiplication method requires transposing data matrix
        query, absmax, _ = int8_vectorwise_quant(query.half())
        return int8_mm_dequant(int8_linear_matmul(query, data.T), absmax, self.qstate.absmax).float()

    def matmul4bit(self, query, data):
        """
        4-bit float matrix multiplication using the stored quantization state.

        Args:
            query: query matrix
            data: data matrix

        Returns:
            query @ data
        """

        # Matrix multiplication method transposes data already
        return matmul_4bit(query, data, self.qstate)
209  
210  
class QuantizeContext:
    """
    Quantization context. Facilitates modifications to quantized tensors.

    While the context is active, the index data is held in dequantized (full
    precision) form so it can be modified. Leaving the context re-quantizes
    the data, capturing any changes made within the block.
    """

    def __init__(self, ann):
        """
        Creates a context for an ANN index.

        Args:
            ann: ANN index instance providing quantize/dequantize methods
        """

        self.ann = ann

    def __enter__(self):
        # Expand index data to full precision for editing
        self.ann.dequantize()

    def __exit__(self, etype, evalue, etraceback):
        # Compress index data back to quantized form
        self.ann.quantize()