pgsparse.py
"""
PGSparse module
"""

import os

import numpy as np

# Conditional import - pgvector is an optional dependency
try:
    from pgvector import SparseVector
    from pgvector.sqlalchemy import SPARSEVEC

    PGSPARSE = True
except ImportError:
    PGSPARSE = False

from ..dense import PGVector


class PGSparse(PGVector):
    """
    Builds a Sparse ANN index backed by a Postgres database.
    """

    def __init__(self, config):
        """
        Creates a new PGSparse instance.

        Args:
            config: index configuration, passed through to PGVector

        Raises:
            ImportError: if the optional pgvector package could not be imported
        """

        if not PGSPARSE:
            raise ImportError('PGSparse is not available - install "ann" extra to enable')

        super().__init__(config)

        # Quantization not supported, reset after the parent constructor runs
        self.qbits = None

    def defaulttable(self):
        """
        Returns the table name used by this class.

        Returns:
            "svectors"
        """

        return "svectors"

    def url(self):
        """
        Reads the database url setting. The value passed alongside the "url" key is taken
        from the SCORING_URL environment variable, or ANN_URL when SCORING_URL is not set.

        Returns:
            result of self.setting for "url" (presumably the configured url, else the environment value)
        """

        # Both environment variables are read eagerly, ANN_URL is the fallback for SCORING_URL
        fallback = os.environ.get("ANN_URL")
        default = os.environ.get("SCORING_URL", fallback)

        return self.setting("url", default)

    def column(self):
        """
        Builds the sparse vector column type.

        Returns:
            SPARSEVEC sized with the configured "dimensions" value
        """

        dimensions = self.config["dimensions"]
        return SPARSEVEC(dimensions)

    def operation(self):
        """
        Returns the operator class name used by this class.

        Returns:
            "sparsevec_ip_ops"
        """

        return "sparsevec_ip_ops"

    def prepare(self, data):
        """
        Wraps data as a pgvector SparseVector, trimming it to the largest values when it has
        more than 1000 non-zero entries.

        Note that trimming modifies data in place.

        Args:
            data: sparse input - assumes a single-row scipy sparse matrix, TODO confirm against caller

        Returns:
            SparseVector
        """

        # pgvector only allows 1000 non-zero values for sparse vectors
        # Trim to top 1000 values, if necessary
        if data.count_nonzero() > 1000:
            # Stored values of the first row in descending order
            descending = -np.sort(-data[0, :].data)

            # Value at position 1000 (the 1001st largest) is the exclusive cutoff
            cutoff = descending[1000]

            # Strict comparison: entries equal to the cutoff are also zeroed, so at most 1000 remain
            values = data.data
            data.data = np.where(values > cutoff, values, 0)
            data.eliminate_zeros()

        # Wrap as sparse vector
        return SparseVector(data)