# rdbms.py
"""
RDBMS module
"""

import os

# Conditional import
try:
    from grand import Graph
    from grand.backends import SQLBackend, InMemoryCachedBackend

    from sqlalchemy import create_engine, text, StaticPool
    from sqlalchemy.schema import CreateSchema

    ORM = True
except ImportError:
    ORM = False

from .networkx import NetworkX


class RDBMS(NetworkX):
    """
    Graph instance backed by a relational database.

    Nodes and edges are stored through grand's SQLBackend and exposed as a NetworkX-compatible view.
    Connection settings come from the graph config: "url" (falls back to the GRAPH_URL environment
    variable), optional "schema", and table names "nodes" / "edges" (default "nodes" / "edges").
    """

    def __init__(self, config):
        """
        Creates a new RDBMS graph.

        Args:
            config: graph configuration

        Raises:
            ImportError: if the optional graph dependencies (grand, sqlalchemy) are not installed
        """

        # Check before super() in case those required libraries are also not available
        if not ORM:
            raise ImportError('RDBMS is not available - install "graph" extra to enable')

        super().__init__(config)

        # Graph and database instances, populated by create() / load()
        self.graph = None
        self.database = None

    def __del__(self):
        # hasattr guard: __init__ can raise before self.database is assigned
        if hasattr(self, "database") and self.database:
            self.database.close()

    def create(self):
        """
        Connects to the database and deletes all rows from the node and edge tables.

        Returns:
            NetworkX-compatible view of the graph
        """

        # Create graph instance
        self.graph, self.database = self.connect()

        # Clear previous graph, if available. Quote table names with the dialect's identifier rules so the
        # DELETE targets the same names SQLAlchemy used when creating the tables (unquoted mixed-case names
        # are case-folded by databases such as PostgreSQL).
        # NOTE(review): assumes grand creates its tables via SQLAlchemy Table objects - confirm.
        preparer = self.database.dialect.identifier_preparer
        for table in [self.config.get("nodes", "nodes"), self.config.get("edges", "edges")]:
            self.database.execute(text(f"DELETE FROM {preparer.quote(table)}"))

        # Return NetworkX compatible backend
        return self.graph.nx

    def scan(self, attribute=None, data=False):
        """
        Iterates over graph nodes.

        Args:
            attribute: if set, only nodes whose attributes contain this key are yielded
            data: if True, yields (node, attributes) tuples instead of nodes

        Returns:
            generator of nodes or (node, attributes) tuples
        """

        if attribute:
            # Attribute filter is applied in Python, reading each node's attributes via self.node()
            for node in self.backend:
                attributes = self.node(node)
                if attribute in attributes:
                    yield (node, attributes) if data else node
        else:
            yield from super().scan(attribute, data)

    def load(self, path):
        """
        Connects to a graph already stored in the database.

        Args:
            path: unused - the database location comes from the "url" config setting or GRAPH_URL
        """

        # Create graph instance
        self.graph, self.database = self.connect()

        # Store NetworkX compatible backend
        self.backend = self.graph.nx

    def save(self, path):
        """
        Commits pending changes to the database.

        Args:
            path: unused - data is persisted in the database
        """

        self.database.commit()

    def close(self):
        """
        Runs parent close logic, then closes the database connection if one is open.
        """

        # Parent logic
        super().close()

        # Close database connection; skip when never connected and drop the reference so __del__
        # does not close it again
        if self.database:
            self.database.close()
            self.database = None

    def filter(self, nodes, graph=None):
        """
        Delegates to the parent filter, defaulting the target graph to a new NetworkX instance built from
        this graph's config.

        Args:
            nodes: nodes to keep
            graph: optional target graph instance

        Returns:
            result of the parent filter
        """

        return super().filter(nodes, graph if graph else NetworkX(self.config))

    def connect(self):
        """
        Connects to a graph backed by a relational database.

        Returns:
            (graph, database) tuple - grand Graph wrapping an InMemoryCachedBackend over the SQLBackend,
            and the SQLBackend's underlying SQLAlchemy connection
        """

        # Keyword arguments for SQLAlchemy
        kwargs = {"poolclass": StaticPool, "echo": False}
        url = self.config.get("url", os.environ.get("GRAPH_URL"))

        # Set default schema, if necessary
        schema = self.config.get("schema")
        if schema:
            # `url or ""` keeps a missing URL failing inside create_engine, as before
            postgresql = "postgresql" in (url or "")

            # Check that schema exists - only PostgreSQL creates the schema; other databases run SELECT 1
            engine = create_engine(url)
            try:
                with engine.begin() as connection:
                    connection.execute(CreateSchema(schema, if_not_exists=True) if postgresql else text("SELECT 1"))
            finally:
                # Release the temporary engine's pooled connections
                engine.dispose()

            # Set default schema
            kwargs["connect_args"] = {"options": f'-c search_path="{schema}"'} if postgresql else {}

        backend = SQLBackend(
            db_url=url,
            node_table_name=self.config.get("nodes", "nodes"),
            edge_table_name=self.config.get("edges", "edges"),
            sqlalchemy_kwargs=kwargs,
        )

        # pylint: disable=W0212
        return Graph(backend=InMemoryCachedBackend(backend, maxsize=None)), backend._connection