client.py
"""
Client module
"""

import os
import time

# Conditional import
try:
    from sqlalchemy import StaticPool, Text, cast, create_engine, insert, text as textsql
    from sqlalchemy.orm import Session, aliased
    from sqlalchemy.schema import CreateSchema

    from .schema import Base, Batch, Document, Object, Section, SectionBase, Score

    ORM = True
except ImportError:
    ORM = False

from .rdbms import RDBMS


class Client(RDBMS):
    """
    Database client instance. This class connects to an external database using SQLAlchemy. It supports any database
    that is supported by SQLAlchemy (PostgreSQL, MariaDB, etc) and has JSON support.
    """

    def __init__(self, config):
        """
        Creates a new Database.

        Args:
            config: database configuration parameters

        Raises:
            ImportError: if the SQLAlchemy imports at the top of this module failed
        """

        super().__init__(config)

        if not ORM:
            raise ImportError('SQLAlchemy is not available - install "database" extra to enable')

        # SQLAlchemy parameters, populated by connect()
        self.engine, self.dbconnection = None, None

    def save(self, path):
        """
        Runs the parent save and then commits the SQLAlchemy connection, if one is open.

        Args:
            path: passed through to the parent save method
        """

        # Commit session and database connection
        super().save(path)

        if self.dbconnection:
            self.dbconnection.commit()

    def close(self):
        """
        Runs the parent close and then disposes of the SQLAlchemy engine, if one was created.
        """

        super().close()

        # Dispose of engine, which also closes dbconnection
        if self.engine:
            self.engine.dispose()

    def reindexstart(self):
        """
        Creates a working table for a reindex operation.

        Returns:
            name of the working table
        """

        # Working table name, suffixed with the current time in milliseconds
        name = f"rebuild{round(time.time() * 1000)}"

        # Create working table metadata. Declaring the class is what makes the table
        # available in Base.metadata, the class object itself is not needed afterwards.
        type("Rebuild", (SectionBase,), {"__tablename__": name})
        Base.metadata.tables[name].create(self.dbconnection)

        return name

    def reindexend(self, name):
        """
        Removes a working table definition from the ORM metadata.

        Args:
            name: working table name, as returned by reindexstart
        """

        # Remove table object from metadata
        Base.metadata.remove(Base.metadata.tables[name])

    def jsonprefix(self):
        """
        Returns the JSON column prefix.

        Returns:
            JSON column prefix string
        """

        # JSON column prefix
        return "cast("

    def jsoncolumn(self, name):
        """
        Builds a SQL expression that reads a field from the documents JSON data column as text.

        Args:
            name: JSON field name

        Returns:
            SQL expression string compiled for the current engine dialect
        """

        # Alias documents table
        d = aliased(Document, name="d")

        # Build JSON column expression for column
        expression = cast(d.data[name].as_string(), Text)
        return str(expression.compile(dialect=self.engine.dialect, compile_kwargs={"literal_binds": True}))

    def createtables(self):
        """
        Creates all tables that do not exist yet and clears any existing data.
        """

        # Create tables
        Base.metadata.create_all(self.dbconnection, checkfirst=True)

        # Clear existing data - table schema is created upon connecting to database
        for table in ["sections", "documents", "objects"]:
            self.cursor.execute(f"DELETE FROM {table}")

    def finalize(self):
        """
        Flushes pending objects on the current connection.
        """

        # Flush cached objects
        self.connection.flush()

    def insertdocument(self, uid, data, tags, entry):
        """
        Adds a Document row to the current connection.

        Args:
            uid: document id
            data: document data
            tags: document tags
            entry: entry value stored with the row
        """

        self.connection.add(Document(id=uid, data=data, tags=tags, entry=entry))

    def insertobject(self, uid, data, tags, entry):
        """
        Adds an Object row to the current connection.

        Args:
            uid: object id
            data: object data
            tags: object tags
            entry: entry value stored with the row
        """

        self.connection.add(Object(id=uid, object=data, tags=tags, entry=entry))

    def insertsection(self, index, uid, text, tags, entry):
        """
        Adds a Section row to the current connection.

        Args:
            index: index id
            uid: section id
            text: section text
            tags: section tags
            entry: entry value stored with the row
        """

        # Save text section
        self.connection.add(Section(indexid=index, id=uid, text=text, tags=tags, entry=entry))

    def createbatch(self):
        """
        Creates the batch table, if it does not exist.
        """

        # Create temporary batch table, if necessary
        Base.metadata.tables["batch"].create(self.dbconnection, checkfirst=True)

    def insertbatch(self, indexids, ids, batch):
        """
        Inserts batch rows for a collection of index ids and/or ids.

        Args:
            indexids: index ids to insert, skipped when empty
            ids: ids to insert (stored as strings), skipped when empty
            batch: batch value stored with each row
        """

        if indexids:
            self.connection.execute(insert(Batch), [{"indexid": i, "batch": batch} for i in indexids])
        if ids:
            self.connection.execute(insert(Batch), [{"id": str(uid), "batch": batch} for uid in ids])

    def createscores(self):
        """
        Creates the scores table, if it does not exist.
        """

        # Create temporary scores table, if necessary
        Base.metadata.tables["scores"].create(self.dbconnection, checkfirst=True)

    def insertscores(self, scores):
        """
        Inserts one averaged score row per index id.

        Args:
            scores: dictionary of index id to a collection of scores, skipped when empty
        """

        # Average scores by id
        if scores:
            self.connection.execute(insert(Score), [{"indexid": i, "score": sum(s) / len(s)} for i, s in scores.items()])

    def connect(self, path=None):
        """
        Creates the SQLAlchemy engine and connection and opens a session on that connection.

        Args:
            path: not used by this implementation

        Returns:
            database session
        """

        # Connection URL
        content = self.config.get("content")

        # Read ENV variable, if necessary
        content = os.environ.get("CLIENT_URL") if content == "client" else content

        # Create engine using database URL
        self.engine = create_engine(content, poolclass=StaticPool, echo=False, json_serializer=lambda x: x)
        self.dbconnection = self.engine.connect()

        # Create database session
        database = Session(self.dbconnection)

        # Set default schema, if necessary
        schema = self.config.get("schema")
        if schema:
            with self.engine.begin():
                self.sqldialect(database, CreateSchema(schema, if_not_exists=True))

            self.sqldialect(database, textsql("SET search_path TO :schema"), {"schema": schema})

        return database

    def getcursor(self):
        """
        Creates a cursor wrapper for the current connection.

        Returns:
            Cursor instance
        """

        return Cursor(self.connection)

    def rows(self):
        """
        Returns the current cursor, which is iterated to read result rows.

        Returns:
            current cursor
        """

        return self.cursor

    def addfunctions(self):
        """
        No-op for this implementation.
        """

        return

    def sqldialect(self, database, sql, parameters=None):
        """
        Executes a SQL statement based on the current SQL dialect.

        Args:
            database: current database
            sql: SQL to execute
            parameters: optional bind parameters
        """

        # Only run the statement for PostgreSQL, all other dialects run a placeholder query
        args = (sql, parameters) if self.engine.dialect.name == "postgresql" else (textsql("SELECT 1"),)
        database.execute(*args)


class Cursor:
    """
    Implements basic compatibility with the Python DB-API.
    """

    def __init__(self, connection):
        """
        Creates a new Cursor.

        Args:
            connection: object used to execute statements
        """

        self.connection = connection

        # Result of the last executed statement, None until execute is called
        self.result = None

    def __iter__(self):
        """
        Iterates over rows of the current result.

        Returns:
            current result, or an empty iterator when no statement has been executed
        """

        # Returning None from __iter__ makes iter() raise a TypeError, so fall back to an
        # empty iterator to match the other accessors, which all tolerate a missing result
        return self.result if self.result is not None else iter(())

    def execute(self, statement, parameters=None):
        """
        Executes statement.

        Args:
            statement: statement to execute
            parameters: optional dictionary with bind parameters
        """

        # Wrap raw SQL strings as text clauses
        if isinstance(statement, str):
            statement = textsql(statement)

        self.result = self.connection.execute(statement, parameters)

    def fetchall(self):
        """
        Fetches all rows from the current result.

        Returns:
            all rows from current result
        """

        return self.result.all() if self.result else None

    def fetchone(self):
        """
        Fetches first row from current result.

        Returns:
            first row from current result
        """

        return self.result.first() if self.result else None

    @property
    def description(self):
        """
        Returns columns for current result.

        Returns:
            list of columns
        """

        return [(key,) for key in self.result.keys()] if self.result else None