/ src / python / txtai / graph / rdbms.py
rdbms.py
  1  """
  2  RDBMS module
  3  """
  4  
  5  import os
  6  
  7  # Conditional import
  8  try:
  9      from grand import Graph
 10      from grand.backends import SQLBackend, InMemoryCachedBackend
 11  
 12      from sqlalchemy import create_engine, text, StaticPool
 13      from sqlalchemy.schema import CreateSchema
 14  
 15      ORM = True
 16  except ImportError:
 17      ORM = False
 18  
 19  from .networkx import NetworkX
 20  
 21  
 22  class RDBMS(NetworkX):
 23      """
 24      Graph instance backed by a relational database.
 25      """
 26  
 27      def __init__(self, config):
 28          # Check before super() in case those required libraries are also not available
 29          if not ORM:
 30              raise ImportError('RDBMS is not available - install "graph" extra to enable')
 31  
 32          super().__init__(config)
 33  
 34          # Graph and database instances
 35          self.graph = None
 36          self.database = None
 37  
 38      def __del__(self):
 39          if hasattr(self, "database") and self.database:
 40              self.database.close()
 41  
 42      def create(self):
 43          # Create graph instance
 44          self.graph, self.database = self.connect()
 45  
 46          # Clear previous graph, if available
 47          for table in [self.config.get("nodes", "nodes"), self.config.get("edges", "edges")]:
 48              self.database.execute(text(f"DELETE FROM {table}"))
 49  
 50          # Return NetworkX compatible backend
 51          return self.graph.nx
 52  
 53      def scan(self, attribute=None, data=False):
 54          if attribute:
 55              for node in self.backend:
 56                  attributes = self.node(node)
 57                  if attribute in attributes:
 58                      yield (node, attributes) if data else node
 59          else:
 60              yield from super().scan(attribute, data)
 61  
 62      def load(self, path):
 63          # Create graph instance
 64          self.graph, self.database = self.connect()
 65  
 66          # Store NetworkX compatible backend
 67          self.backend = self.graph.nx
 68  
 69      def save(self, path):
 70          self.database.commit()
 71  
 72      def close(self):
 73          # Parent logic
 74          super().close()
 75  
 76          # Close database connection
 77          self.database.close()
 78  
 79      def filter(self, nodes, graph=None):
 80          return super().filter(nodes, graph if graph else NetworkX(self.config))
 81  
 82      def connect(self):
 83          """
 84          Connects to a graph backed by a relational database.
 85  
 86          Args:
 87              Graph database instance
 88          """
 89  
 90          # Keyword arguments for SQLAlchemy
 91          kwargs = {"poolclass": StaticPool, "echo": False}
 92          url = self.config.get("url", os.environ.get("GRAPH_URL"))
 93  
 94          # Set default schema, if necessary
 95          schema = self.config.get("schema")
 96          if schema:
 97              # Check that schema exists
 98              engine = create_engine(url)
 99              with engine.begin() as connection:
100                  connection.execute(CreateSchema(schema, if_not_exists=True) if "postgresql" in url else text("SELECT 1"))
101  
102              # Set default schema
103              kwargs["connect_args"] = {"options": f'-c search_path="{schema}"'} if "postgresql" in url else {}
104  
105          backend = SQLBackend(
106              db_url=url,
107              node_table_name=self.config.get("nodes", "nodes"),
108              edge_table_name=self.config.get("edges", "edges"),
109              sqlalchemy_kwargs=kwargs,
110          )
111  
112          # pylint: disable=W0212
113          return Graph(backend=InMemoryCachedBackend(backend, maxsize=None)), backend._connection