pgtext.py
"""
PGText module
"""

import os
import re

# Conditional import
try:
    from sqlalchemy import create_engine, desc, delete, func, text
    from sqlalchemy import Column, Computed, Index, Integer, MetaData, StaticPool, Table, Text
    from sqlalchemy.dialects.postgresql import TSVECTOR
    from sqlalchemy.orm import Session
    from sqlalchemy.schema import CreateSchema

    PGTEXT = True
except ImportError:
    PGTEXT = False

from .base import Scoring


class PGText(Scoring):
    """
    Postgres full text search (FTS) based scoring.

    Documents are stored in a database table. On Postgres, the table has a stored generated
    TSVECTOR column, and queries are ranked with ts_rank. Settings read from the config:
    `url` (falls back to the SCORING_URL environment variable), optional `schema`,
    `table` (default "scoring") and `language` (default "english").
    """

    def __init__(self, config=None):
        super().__init__(config)

        if not PGTEXT:
            raise ImportError('PGText is not available - install "scoring" extra to enable')

        # Database connection - created lazily by initialize()
        self.engine, self.database, self.connection, self.table = None, None, None, None

        # Text search configuration passed to to_tsvector / plainto_tsquery
        self.language = self.config.get("language", "english")

    def insert(self, documents, index=None, checkpoint=None):
        """
        Inserts documents into the scoring table.

        Args:
            documents: iterable of (id, data, tags) tuples; data may be a string, a list of
                       tokens or a dict holding the text under self.text / self.object
            index: optional starting index id - when set, rows are keyed by sequential index ids
                   instead of the document ids
            checkpoint: unused by this backend
        """

        # Initialize tables - drops an existing table only when this creates a new session
        self.initialize(recreate=True)

        # Collection of (indexid, text) rows to insert
        rows = []

        for uid, document, _ in documents:
            # Extract text, if necessary
            if isinstance(document, dict):
                document = document.get(self.text, document.get(self.object))

            if document is not None:
                # If index is passed, use indexid, otherwise use id
                uid = index if index is not None else uid

                # Add row if the data type is accepted, joining token lists into one string
                if isinstance(document, (str, list)):
                    rows.append((uid, " ".join(document) if isinstance(document, list) else document))

                # Increment index
                index = index + 1 if index is not None else None

        # Insert rows. Skip when empty: an empty parameter list is not a no-op in SQLAlchemy,
        # it runs a single INSERT without values, which violates the indexid primary key.
        if rows:
            self.database.execute(self.table.insert(), [{"indexid": uid, "text": data} for uid, data in rows])

    def delete(self, ids):
        """
        Deletes all rows whose indexid is in ids.

        Args:
            ids: list of index ids to delete
        """

        self.database.execute(delete(self.table).where(self.table.c["indexid"].in_(ids)))

    def weights(self, tokens):
        # Term weights are not supported by this backend
        return None

    def search(self, query, limit=3):
        """
        Runs a full text search.

        Args:
            query: query text
            limit: maximum number of results

        Returns:
            list of (indexid, score) tuples ordered by descending ts_rank score
        """

        # Replace wildcards with prefix wildcard term
        # NOTE(review): plainto_tsquery does not interpret tsquery operators or prefix-match labels
        # (per the PostgreSQL docs), so this rewrite likely has no effect on results - confirm.
        query = re.sub(r"(?<!\:)\*", ":*", query)

        # Rank rows against the query and keep the top `limit` rows
        results = (
            self.database.query(self.table.c["indexid"], text("ts_rank(vector, plainto_tsquery(:language, :query)) rank"))
            .order_by(desc(text("rank")))
            .limit(limit)
            .params({"language": self.language, "query": query})
        )

        # Drop rows with a (near) zero rank
        return [(uid, score) for uid, score in results if score > 1e-5]

    def batchsearch(self, queries, limit=3, threads=True):
        """
        Runs search for each query sequentially. The threads argument is accepted for interface
        compatibility but is not used.
        """

        return [self.search(query, limit) for query in queries]

    def count(self):
        """
        Returns the number of rows in the scoring table.
        """

        # pylint: disable=E1102
        return self.database.query(func.count(self.table.c["indexid"])).scalar()

    def load(self, path):
        """
        Loads the scoring index. Data lives in the database, so path is unused. Uncommitted
        changes on an open session are rolled back before (re)initializing.
        """

        # Reset database to original checkpoint
        if self.database is not None:
            self.database.rollback()
            self.connection.rollback()

        # Initialize tables
        self.initialize()

    def save(self, path):
        """
        Saves the scoring index by committing the session and connection. Path is unused.
        """

        # Commit session and connection
        if self.database is not None:
            self.database.commit()
            self.connection.commit()

    def close(self):
        """
        Closes the session and connection and disposes the engine.

        All database state is reset to None afterwards, so repeated calls are no-ops and a later
        initialize()/load() creates a fresh connection.
        """

        if self.database is not None:
            self.database.close()
            self.connection.close()
            self.engine.dispose()

        self.engine, self.database, self.connection, self.table = None, None, None, None

    def issparse(self):
        return True

    def isnormalized(self):
        return True

    def isbayes(self):
        return False

    def initialize(self, recreate=False):
        """
        Initializes a new database session.

        Creates the engine, connection, session, scoring table and text index when no session
        is open. Does nothing while a session is already open.

        Args:
            recreate: Recreates the database tables if True
        """

        if self.database is None:
            # Create engine, connection and session (StaticPool: one connection for all requests)
            self.engine = create_engine(self.config.get("url", os.environ.get("SCORING_URL")), poolclass=StaticPool, echo=False)
            self.connection = self.engine.connect()
            self.database = Session(self.connection)

            # Set default schema, if necessary
            schema = self.config.get("schema")
            if schema:
                with self.engine.begin():
                    self.sqldialect(CreateSchema(schema, if_not_exists=True))

                self.sqldialect(text("SET search_path TO :schema"), {"schema": schema})

            # Table name
            table = self.config.get("table", "scoring")

            # Create vectors table. On Postgres, vector is a stored generated TSVECTOR column;
            # other dialects get a placeholder Integer column.
            # NOTE(review): self.language is interpolated into DDL verbatim - it must come from trusted config.
            self.table = Table(
                table,
                MetaData(),
                Column("indexid", Integer, primary_key=True, autoincrement=False),
                Column("text", Text),
                (
                    Column("vector", TSVECTOR, Computed(f"to_tsvector('{self.language}', text)", persisted=True))
                    if self.engine.dialect.name == "postgresql"
                    else Column("vector", Integer)
                ),
            )

            # Create text index
            index = Index(
                f"{table}-index",
                self.table.c["vector"],
                postgresql_using="gin",
            )

            # Drop and recreate table
            if recreate:
                self.table.drop(self.connection, checkfirst=True)
                index.drop(self.connection, checkfirst=True)

            # Create table and index
            self.table.create(self.connection, checkfirst=True)
            index.create(self.connection, checkfirst=True)

    def sqldialect(self, sql, parameters=None):
        """
        Executes a SQL statement based on the current SQL dialect.

        On Postgres the statement runs as given. On any other dialect a `SELECT 1` placeholder
        runs instead.

        Args:
            sql: SQL to execute
            parameters: optional bind parameters
        """

        args = (sql, parameters) if self.engine.dialect.name == "postgresql" else (text("SELECT 1"),)
        self.database.execute(*args)