# rdbms.py
"""
RDBMS module
"""

import datetime
import json

from .base import Database
from .schema import Statement


# pylint: disable=R0904
class RDBMS(Database):
    """
    Base relational database class. A relational database uses SQL to insert, update, delete and select from a
    database instance.

    Subclasses supply the backend specific pieces: connect, getcursor, jsonprefix, jsoncolumn, rows and addfunctions.
    """

    def __init__(self, config):
        """
        Creates a new Database.

        Args:
            config: database configuration parameters
        """

        super().__init__(config)

        # Database connection and cursor, both created lazily by session()
        self.connection = None
        self.cursor = None

    def load(self, path):
        """
        Loads an existing database. Thread locking must be handled externally.

        Args:
            path: path to database file
        """

        self.session(path)

    def insert(self, documents, index=0):
        """
        Inserts documents, objects and sections.

        Args:
            documents: iterable of (uid, document, tags) tuples. A document can be a dict, a list of tokens,
                       a string or any other object.
            index: index id given to the first stored section, incremented by one for each section stored
        """

        # Initialize connection if not open
        self.initialize()

        # Entry date shared by every row written in this call
        entry = datetime.datetime.now(datetime.timezone.utc)

        for uid, document, tags in documents:
            if isinstance(document, dict):
                # Insert document and use return value for sections table
                document = self.loaddocument(uid, document, tags, entry)

            if document is not None:
                if isinstance(document, list):
                    # Join tokens to text
                    document = " ".join(document)
                elif not isinstance(document, str):
                    # If object support is enabled, save object
                    self.loadobject(uid, document, tags, entry)

                    # Clear section text for objects, even when objects aren't inserted
                    document = None

                # Save text section
                self.loadsection(index, uid, document, tags, entry)
                index += 1

        # Post processing logic
        self.finalize()

    def delete(self, ids):
        """
        Deletes all documents, objects and sections matching ids. No-op when there is no connection.

        Args:
            ids: ids to delete
        """

        if self.connection:
            # Load ids into the batch table, the delete statements are run without parameters
            self.batch(ids=ids)

            # Delete all documents, objects and sections by id
            self.cursor.execute(Statement.DELETE_DOCUMENTS)
            self.cursor.execute(Statement.DELETE_OBJECTS)
            self.cursor.execute(Statement.DELETE_SECTIONS)

    def reindex(self, config):
        """
        Rebuilds the sections table under a new configuration and streams the stored content.

        This method is a generator. The table swap and reindexend() only run after the caller has consumed every
        row. Yields nothing when there is no connection.

        Args:
            config: new configuration

        Returns:
            generator of (uid, data, tags) tuples
        """

        if self.connection:
            # Set new configuration
            self.configure(config)

            # Resolve text column
            select = self.resolve(self.text)

            # Initialize reindex operation
            name = self.reindexstart()

            # Copy data over
            self.cursor.execute(Statement.COPY_SECTIONS % (name, select))

            # Stream new results
            self.cursor.execute(Statement.STREAM_SECTIONS % name)
            for uid, text, data, obj, tags in self.rows():
                if not text and self.encoder and obj:
                    yield (uid, self.encoder.decode(obj), tags)
                else:
                    # Read JSON data, if provided
                    data = json.loads(data) if data and isinstance(data, str) else data

                    # Stream data if available, otherwise use section text
                    yield (uid, data if data else text, tags)

            # Swap as new table
            self.cursor.execute(Statement.DROP_SECTIONS)
            self.cursor.execute(Statement.RENAME_SECTIONS % name)

            # Finish reindex operation
            self.reindexend(name)

    def save(self, path):
        """
        Commits the current connection. No-op when there is no connection.

        Args:
            path: path to database file, not used by this base implementation
        """

        if self.connection:
            self.connection.commit()

    def close(self):
        """
        Closes the current connection, if any, and clears the connection and cursor references.
        """

        if self.connection:
            self.connection.close()

            # Drop stale handles so that "if self.connection" checks elsewhere don't treat a closed connection as open
            self.connection = None
            self.cursor = None

    def ids(self, ids):
        """
        Runs the SELECT_IDS statement for a batch of ids.

        Args:
            ids: ids to look up

        Returns:
            all rows returned by the statement
        """

        # Batch ids and run query
        self.batch(ids=ids)
        self.cursor.execute(Statement.SELECT_IDS)

        # Format and return results
        return self.cursor.fetchall()

    def count(self):
        """
        Runs the COUNT_IDS statement.

        Returns:
            first column of the first result row
        """

        self.cursor.execute(Statement.COUNT_IDS)
        return self.cursor.fetchone()[0]

    def resolve(self, name, alias=None):
        """
        Resolves a column name or builds an alias clause.

        Args:
            name: column name or expression
            alias: optional alias, when set only the alias clause logic runs

        Returns:
            resolved column expression
        """

        # Standard column names
        sections = ["indexid", "id", "tags", "entry"]
        noprefix = ["data", "object", "score", "text"]

        # Alias expression
        if alias:
            # Skip if name matches alias or alias is a standard column name
            if name == alias or alias in sections:
                return name

            # Build alias clause
            return f'{name} as "{alias}"'

        # Resolve expression
        if self.expressions and name in self.expressions:
            return self.expressions[name]["expression"]

        # Name is already resolved, skip
        if name.startswith(self.jsonprefix()) or any(f"s.{s}" == name for s in sections):
            return name

        # Standard columns - need prefixes
        if name.lower() in sections:
            return f"s.{name}"

        # Standard columns - no prefixes
        if name.lower() in noprefix:
            return name

        # Other columns come from documents.data JSON
        return self.jsoncolumn(name)

    def embed(self, similarity, batch):
        """
        Loads the index ids for one set of similarity results into the batch table.

        Args:
            similarity: list of similarity results, each a list of (indexid, score)
            batch: position in similarity to load, also used as the batch index

        Returns:
            ids clause for batch
        """

        # Load similarity results ids for this batch
        self.batch(indexids=[i for i, _ in similarity[batch]], batch=batch)

        # Average and load all similarity scores with first batch
        if not batch:
            self.scores(similarity)

        # Return ids clause placeholder
        return Statement.IDS_CLAUSE % batch

    # pylint: disable=R0912
    def query(self, query, limit, parameters, indexids):
        """
        Builds and runs a SQL statement from parsed query components.

        Args:
            query: dict of query components, the keys read are select, where, groupby, having, orderby, limit,
                   offset and similar
            limit: default result limit, used when the query has no limit of its own
            parameters: bind parameters, passed through to execute when set
            indexids: when True, select only indexid and score and return (indexid, score) tuples

        Returns:
            list of dicts keyed by column name, or list of (indexid, score) tuples when indexids is True
        """

        # Extract query components
        select = query.get("select", self.defaults())
        where = query.get("where")
        groupby, having = query.get("groupby"), query.get("having")
        orderby, qlimit, offset = query.get("orderby"), query.get("limit"), query.get("offset")
        similarity = query.get("similar")

        # Select "indexid, score" when indexids is True
        if indexids:
            select = f"{self.resolve('indexid')}, {self.resolve('score')}"

        # Use JOIN when documents table is used to utilize indexes, default to LEFT JOIN
        join = "JOIN" if any(x and self.jsonprefix() in x for x in [where, groupby, orderby]) else "LEFT JOIN"

        # Build query text
        sql = Statement.TABLE_CLAUSE % (select, join)
        if where is not None:
            sql += f" WHERE {where}"
        if groupby is not None:
            sql += f" GROUP BY {groupby}"
        if having is not None:
            sql += f" HAVING {having}"
        if orderby is not None:
            sql += f" ORDER BY {orderby}"

        # Default ORDER BY if not provided and similarity scores are available
        if similarity and orderby is None:
            sql += " ORDER BY score DESC"

        # Apply query limit. A limit set in the query takes precedence over the default limit. The value is picked
        # with the same "is not None" test as the guard, otherwise a falsy query limit renders as "LIMIT None".
        if qlimit is not None or limit:
            sql += f" LIMIT {qlimit if qlimit is not None else limit}"

        # Apply offset
        if offset is not None:
            sql += f" OFFSET {offset}"

        # Clear scores when no similar clauses present
        if not similarity:
            self.scores(None)

        # Runs a user query through execute method, which has common user query handling logic
        args = (sql, parameters) if parameters else (sql,)
        self.execute(self.cursor.execute, *args)

        # Retrieve column list from query
        columns = [c[0] for c in self.cursor.description]

        # Map results and return
        results = []
        for row in self.rows():
            result = {}

            # Copy columns to result. In cases with duplicate column names, find one with a value
            for x, column in enumerate(columns):
                if column not in result or result[column] is None:
                    # Decode object
                    if self.encoder and column == self.object:
                        result[column] = self.encoder.decode(row[x])
                    else:
                        result[column] = row[x]

            results.append(result)

        # Transform results, if necessary
        return [(x["indexid"], x["score"]) for x in results] if indexids else results

    def initialize(self):
        """
        Creates connection and initial database schema if no connection exists.
        """

        if not self.connection:
            # Create database session. Thread locking must be handled externally.
            self.session()

            # Create initial table schema
            self.createtables()

            # Create indexes
            self.createindexes()

    def session(self, path=None, connection=None):
        """
        Starts a new database session.

        Args:
            path: path to database file
            connection: existing connection to use
        """

        # Use the existing connection when provided, otherwise connect (with path only when one is set)
        if not connection:
            connection = self.connect(path) if path else self.connect()

        # Create database connection and cursor
        self.connection = connection
        self.cursor = self.getcursor()

        # Register custom functions - session scope
        self.addfunctions()

        # Create temporary tables - session scope
        self.createbatch()
        self.createscores()

    def createtables(self):
        """
        Creates the initial table schema.
        """

        self.cursor.execute(Statement.CREATE_DOCUMENTS)
        self.cursor.execute(Statement.CREATE_OBJECTS)
        self.cursor.execute(Statement.CREATE_SECTIONS % "sections")
        self.cursor.execute(Statement.CREATE_SECTIONS_INDEX)

    def createindexes(self):
        """
        Creates expression indexes.
        """

        if self.expressions:
            for key, values in self.expressions.items():
                # Create index for expression, if enabled
                if values["index"]:
                    # Get parameters
                    name = f"expression_{key}".lower()
                    expression = values["expression"]

                    # JSON expressions are indexed on documents, everything else on sections
                    table = "documents" if expression.startswith(self.jsonprefix()) else "sections"

                    # Execute statement
                    self.cursor.execute(Statement.CREATE_EXPRESSION_INDEX % (name, table, expression))

    def finalize(self):
        """
        Post processing logic run after inserting a batch of documents. Default method is no-op.
        """

    def loaddocument(self, uid, document, tags, entry):
        """
        Applies pre-processing logic and inserts a document.

        Args:
            uid: unique id
            document: input document dictionary
            tags: document tags
            entry: generated entry date

        Returns:
            section value
        """

        # Make a copy of document before changing
        document = document.copy()

        # Get and remove object field from document
        obj = document.pop(self.object, None)

        if document:
            # Apply data filters, if necessary
            data = {key: value for key, value in document.items() if key in self.store} if self.store is not None else document

            # Insert document as JSON
            if data:
                self.insertdocument(uid, json.dumps(data, allow_nan=False), tags, entry)

            # If text and object are both available, load object as it won't otherwise be used
            if self.text in document and obj:
                self.loadobject(uid, obj, tags, entry)

        # Return value to use for section - use text if available otherwise use object
        return document.get(self.text, obj)

    def insertdocument(self, uid, data, tags, entry):
        """
        Inserts a document.

        Args:
            uid: unique id
            data: document data
            tags: document tags
            entry: generated entry date
        """

        self.cursor.execute(Statement.INSERT_DOCUMENT, [uid, data, tags, entry])

    def loadobject(self, uid, obj, tags, entry):
        """
        Applies pre-processing logic and inserts an object.

        Args:
            uid: unique id
            obj: input object
            tags: object tags
            entry: generated entry date
        """

        # If object support is enabled, save object
        if self.encoder:
            self.insertobject(uid, self.encoder.encode(obj), tags, entry)

    def insertobject(self, uid, data, tags, entry):
        """
        Inserts an object.

        Args:
            uid: unique id
            data: encoded data
            tags: object tags
            entry: generated entry date
        """

        self.cursor.execute(Statement.INSERT_OBJECT, [uid, data, tags, entry])

    def loadsection(self, index, uid, text, tags, entry):
        """
        Applies pre-processing logic and inserts a section.

        Args:
            index: index id
            uid: unique id
            text: section text
            tags: section tags
            entry: generated entry date
        """

        self.insertsection(index, uid, text, tags, entry)

    def insertsection(self, index, uid, text, tags, entry):
        """
        Inserts a section.

        Args:
            index: index id
            uid: unique id
            text: section text
            tags: section tags
            entry: generated entry date
        """

        # Save text section
        self.cursor.execute(Statement.INSERT_SECTION, [index, uid, text, tags, entry])

    def reindexstart(self):
        """
        Starts a reindex operation.

        Returns:
            temporary working table name
        """

        # Working table name
        name = "rebuild"

        # Create new table to hold reordered sections
        self.cursor.execute(Statement.CREATE_SECTIONS % name)

        return name

    # pylint: disable=W0613
    def reindexend(self, name):
        """
        Ends a reindex operation.

        Args:
            name: working table name, not used by this base implementation
        """

        self.cursor.execute(Statement.CREATE_SECTIONS_INDEX)

    def batch(self, indexids=None, ids=None, batch=None):
        """
        Loads ids to a temporary batch table for efficient query processing.

        Args:
            indexids: list of indexids
            ids: list of ids
            batch: batch index, used when statement has multiple subselects
        """

        # Delete batch when batch id is empty or for batch 0
        if not batch:
            self.cursor.execute(Statement.DELETE_BATCH)

        # Add batch
        self.insertbatch(indexids, ids, batch)

    def createbatch(self):
        """
        Creates temporary batch table.
        """

        # Create or Replace temporary batch table
        self.cursor.execute(Statement.CREATE_BATCH)

    def insertbatch(self, indexids, ids, batch):
        """
        Inserts batch of ids.

        Args:
            indexids: list of indexids
            ids: list of ids, each is converted to a string before insert
            batch: batch index
        """

        if indexids:
            self.cursor.executemany(Statement.INSERT_BATCH_INDEXID, [(i, batch) for i in indexids])
        if ids:
            self.cursor.executemany(Statement.INSERT_BATCH_ID, [(str(uid), batch) for uid in ids])

    def scores(self, similarity):
        """
        Loads a batch of similarity scores to a temporary table for efficient query processing.

        Args:
            similarity: list of similarity results, each a list of (indexid, score). Existing scores are
                        cleared and nothing is loaded when this is empty or None.
        """

        # Delete scores
        self.cursor.execute(Statement.DELETE_SCORES)

        if similarity:
            # Group scores per id, needed for multiple similar() clauses. Averaging happens in insertscores.
            scores = {}
            for s in similarity:
                for i, score in s:
                    scores.setdefault(i, []).append(score)

            # Add scores
            self.insertscores(scores)

    def createscores(self):
        """
        Creates temporary scores table.
        """

        # Create or Replace temporary scores table
        self.cursor.execute(Statement.CREATE_SCORES)

    def insertscores(self, scores):
        """
        Inserts a batch of scores.

        Args:
            scores: dict of id to list of scores, the mean of each list is stored
        """

        # Average scores by id
        if scores:
            self.cursor.executemany(Statement.INSERT_SCORE, [(i, sum(s) / len(s)) for i, s in scores.items()])

    def defaults(self):
        """
        Returns the default columns to use when there is no select clause.

        Returns:
            default columns as a comma separated SQL string
        """

        return "s.id, text, score"

    def connect(self, path=None):
        """
        Creates a new database connection.

        Args:
            path: path to database file

        Returns:
            connection
        """

        raise NotImplementedError

    def getcursor(self):
        """
        Opens a cursor for current connection.

        Returns:
            cursor
        """

        raise NotImplementedError

    def jsonprefix(self):
        """
        Returns json column prefix to test for.

        Returns:
            dynamic column prefix
        """

        raise NotImplementedError

    def jsoncolumn(self, name):
        """
        Builds a json extract column expression for name.

        Args:
            name: column name

        Returns:
            dynamic column expression
        """

        raise NotImplementedError

    def rows(self):
        """
        Returns current cursor row iterator for last executed query.

        Returns:
            iterable collection of rows
        """

        raise NotImplementedError

    def addfunctions(self):
        """
        Adds custom functions in current connection.
        """

        raise NotImplementedError