# base.py
"""
Database module
"""

import logging
import types

from .encoder import EncoderFactory
from .sql import SQL, SQLError, Token

# Logging configuration
logger = logging.getLogger(__name__)


class Database:
    """
    Base class for database instances. This class encapsulates a content database used for
    storing field content as dicts and objects. The database instance works in conjunction
    with a vector index to execute SQL-driven similarity search.

    Subclasses are expected to implement the methods that raise NotImplementedError below.
    """

    def __init__(self, config):
        """
        Creates a new Database.

        Args:
            config: database configuration
        """

        # Initialize configuration
        self.configure(config)

    def load(self, path):
        """
        Loads a database path.

        Args:
            path: database url
        """

        raise NotImplementedError

    def insert(self, documents, index=0):
        """
        Inserts documents into the database.

        Args:
            documents: list of documents to save
            index: indexid offset, used for internal ids
        """

        raise NotImplementedError

    def delete(self, ids):
        """
        Deletes documents from database.

        Args:
            ids: ids to delete
        """

        raise NotImplementedError

    def reindex(self, config):
        """
        Reindexes internal database content and streams results back. This method must renumber indexids
        sequentially as deletes could have caused indexid gaps.

        Args:
            config: new configuration
        """

        raise NotImplementedError

    def save(self, path):
        """
        Saves a database at path.

        Args:
            path: path to write database
        """

        raise NotImplementedError

    def close(self):
        """
        Closes this database.
        """

        raise NotImplementedError

    def ids(self, ids):
        """
        Retrieves the internal indexids for a list of ids. Multiple indexids may be present for an id in cases
        where data is segmented.

        Args:
            ids: list of document ids

        Returns:
            list of (indexid, id)
        """

        raise NotImplementedError

    def count(self):
        """
        Retrieves the count of this database instance.

        Returns:
            total database count
        """

        raise NotImplementedError

    def search(self, query, similarity=None, limit=None, parameters=None, indexids=False):
        """
        Runs a search against the database. Supports the following methods:

           1. Standard similarity query. This mode retrieves content for the ids in the similarity results
           2. Similarity query as SQL. This mode will combine similarity results and database results into
              a single result set. Similarity queries are set via the SIMILAR() function.
           3. SQL with no similarity query. This mode runs a SQL query and retrieves the results without similarity queries.

        Example queries:
            "natural language processing" - standard similarity only query
            "select * from txtai where similar('natural language processing')" - similarity query as SQL
            "select * from txtai where similar('nlp') and entry > '2021-01-01'" - similarity query with additional SQL clauses
            "select id, text, score from txtai where similar('nlp')" - similarity query with additional SQL column selections
            "select * from txtai where entry > '2021-01-01'" - database only query

        Args:
            query: input query, either a query string or an already parsed query dict. When a dict
                   is passed, its "where" entry is updated in place.
            similarity: similarity results as [(indexid, score)]
            limit: maximum number of results to return
            parameters: dict of named parameters to bind to placeholders
            indexids: results are returned as [(indexid, score)] regardless of select clause parameters if True

        Returns:
            query results as a list of dicts
            list of ([indexid, score]) if indexids is True
        """

        # Parse query if necessary
        if isinstance(query, str):
            query = self.parse(query)

        # Add in similar results
        where = query.get("where")

        if "select" in query and similarity:
            # SQL query: each similarity batch x is referenced in the where clause by a
            # placeholder token, replace it with the embedded similarity results
            for x in range(len(similarity)):
                token = f"{Token.SIMILAR_TOKEN}{x}"
                if where and token in where:
                    where = where.replace(token, self.embed(similarity, x))

        elif similarity:
            # Not a SQL query, load similarity results, if any
            where = self.embed(similarity, 0)

        # Save where
        query["where"] = where

        # Run query
        return self.query(query, limit, parameters, indexids)

    def parse(self, query):
        """
        Parses a query into query components.

        Args:
            query: input query

        Returns:
            dict of parsed query components
        """

        return self.sql(query)

    def resolve(self, name, alias=None):
        """
        Resolves a query column name with the database column name. This method also builds alias expressions
        if alias is set.

        Args:
            name: query column name
            alias: alias name, defaults to None

        Returns:
            database column name
        """

        raise NotImplementedError

    def embed(self, similarity, batch):
        """
        Embeds similarity query results into a database query.

        Args:
            similarity: similarity results as [(indexid, score)]
            batch: batch id
        """

        raise NotImplementedError

    def query(self, query, limit, parameters, indexids):
        """
        Executes query against database.

        Args:
            query: input query
            limit: maximum number of results to return
            parameters: dict of named parameters to bind to placeholders
            indexids: results are returned as [(indexid, score)] regardless of select clause parameters if True

        Returns:
            query results
        """

        raise NotImplementedError

    def configure(self, config):
        """
        Initialize configuration.

        Args:
            config: configuration
        """

        # Database configuration
        self.config = config

        # SQL parser
        self.sql = SQL(self)

        # Load objects encoder
        encoder = self.config.get("objects")
        self.encoder = EncoderFactory.create(encoder) if encoder else None

        # Columns configuration
        columns = config.get("columns", {})
        self.text = columns.get("text", "text")
        self.object = columns.get("object", "object")

        # JSON data storage. If not set, all columns are stored (default).
        # Otherwise, only columns stored in store list are kept.
        # If store is set to None, no columns are stored.
        self.store = ([] if columns["store"] is None else columns["store"]) if "store" in columns else None

        # Custom functions and expressions
        self.functions, self.expressions = None, None

        # Load custom functions
        self.registerfunctions(self.config)

        # Load custom expressions
        self.registerexpressions(self.config)

    def registerfunctions(self, config):
        """
        Register custom functions. This method stores the function details for underlying
        database implementations to handle.

        Each configured entry is either a callable or a dict with a required "function" key and
        optional "name", "argcount" (defaults to -1) and "deterministic" keys. Entries are stored
        in self.functions as (name, argcount, function, deterministic) tuples.

        Args:
            config: database configuration
        """

        inputs = config.get("functions") if config else None
        if inputs:
            functions = []
            for fn in inputs:
                name, argcount, deterministic = None, -1, None

                # Optional function configuration
                if isinstance(fn, dict):
                    name, argcount, fn, deterministic = (fn.get("name"), fn.get("argcount", -1), fn["function"], fn.get("deterministic"))

                # Determine if this is a callable object or a function
                if not isinstance(fn, types.FunctionType) and hasattr(fn, "__call__"):
                    # Callable object: default name is the lowercased class name
                    name = name if name else fn.__class__.__name__.lower()
                    fn = fn.__call__
                else:
                    # Plain function: default name is the lowercased function name
                    name = name if name else fn.__name__.lower()

                # Store function details
                functions.append((name, argcount, fn, deterministic))

            # pylint: disable=W0201
            self.functions = functions

    def registerexpressions(self, config):
        """
        Register custom expressions. This method parses and resolves expressions for later use in SQL queries.

        Entries without a "name" are skipped. When "expression" is not set, the name is used as the expression.

        Args:
            config: database configuration
        """

        inputs = config.get("expressions") if config else None
        if inputs:
            expressions = {}
            for entry in inputs:
                name = entry.get("name")
                expression = entry.get("expression", name)
                if name:
                    expressions[name] = {"expression": self.sql.snippet(expression), "index": entry.get("index", False)}

            # pylint: disable=W0201
            self.expressions = expressions

    def execute(self, function, *args):
        """
        Executes a user query. This method has common error handling logic.

        Args:
            function: database execute function
            args: function arguments

        Returns:
            result of function(args)

        Raises:
            SQLError: if function raises any exception
        """

        try:
            # Debug log SQL
            logger.debug(" ".join(["%s"] * len(args)), *args)

            return function(*args)
        except Exception as e:
            # Wrap all failures in SQLError, the original exception context is intentionally suppressed
            raise SQLError(e) from None

    def setting(self, name, default=None):
        """
        Looks up database specific setting.

        Args:
            name: setting name
            default: default value when setting not found

        Returns:
            setting value
        """

        # Get the database-specific config object
        database = self.config.get(self.config["content"])

        # Get setting value, set default value if not found
        setting = database.get(name) if database else None

        # Only fall back to default when the setting is missing (or None). Explicitly configured
        # falsy values such as False, 0 or "" are returned as-is.
        return default if setting is None else setting