/ src / python / txtai / graph / networkx.py
networkx.py
  1  """
  2  NetworkX module
  3  """
  4  
  5  import os
  6  
  7  from tempfile import TemporaryDirectory
  8  
  9  # Conditional import
 10  try:
 11      import networkx as nx
 12  
 13      from networkx.algorithms.community import asyn_lpa_communities, greedy_modularity_communities, louvain_partitions
 14      from networkx.readwrite import json_graph
 15  
 16      NETWORKX = True
 17  except ImportError:
 18      NETWORKX = False
 19  
 20  from ..archive import ArchiveFactory
 21  from ..serialize import SerializeError, SerializeFactory
 22  
 23  from .base import Graph
 24  from .query import Query
 25  
 26  
 27  # pylint: disable=R0904
 28  class NetworkX(Graph):
 29      """
 30      Graph instance backed by NetworkX.
 31      """
 32  
 33      def __init__(self, config):
 34          super().__init__(config)
 35  
 36          if not NETWORKX:
 37              raise ImportError('NetworkX is not available - install "graph" extra to enable')
 38  
 39      def create(self):
 40          return nx.Graph()
 41  
 42      def count(self):
 43          return self.backend.number_of_nodes()
 44  
 45      def scan(self, attribute=None, data=False):
 46          # Full graph
 47          graph = self.backend
 48  
 49          # Filter graph to nodes having a specified attribute
 50          if attribute:
 51              graph = nx.subgraph_view(self.backend, filter_node=lambda x: attribute in self.node(x))
 52  
 53          # Return either list of matching ids or tuple of (id, attribute dictionary)
 54          return graph.nodes(data=True) if data else graph
 55  
 56      def node(self, node):
 57          return self.backend.nodes.get(node)
 58  
 59      def addnode(self, node, **attrs):
 60          self.backend.add_node(node, **attrs)
 61  
 62      def addnodes(self, nodes):
 63          self.backend.add_nodes_from(nodes)
 64  
 65      def removenode(self, node):
 66          if self.hasnode(node):
 67              self.backend.remove_node(node)
 68  
 69      def hasnode(self, node):
 70          return self.backend.has_node(node)
 71  
 72      def attribute(self, node, field):
 73          return self.node(node).get(field) if self.hasnode(node) else None
 74  
 75      def addattribute(self, node, field, value):
 76          if self.hasnode(node):
 77              self.node(node)[field] = value
 78  
 79      def removeattribute(self, node, field):
 80          return self.node(node).pop(field, None) if self.hasnode(node) else None
 81  
 82      def edgecount(self):
 83          return self.backend.number_of_edges()
 84  
 85      def edges(self, node):
 86          edges = self.backend.adj.get(node)
 87          if edges:
 88              return dict(sorted(edges.items(), key=lambda x: x[1].get("weight", 0), reverse=True))
 89  
 90          return None
 91  
 92      def addedge(self, source, target, **attrs):
 93          self.backend.add_edge(source, target, **attrs)
 94  
 95      def addedges(self, edges):
 96          self.backend.add_edges_from(edges)
 97  
 98      def hasedge(self, source, target=None):
 99          if target is None:
100              edges = self.backend.adj.get(source)
101              return len(edges) > 0 if edges else False
102  
103          return self.backend.has_edge(source, target)
104  
105      def centrality(self):
106          rank = nx.degree_centrality(self.backend)
107          return dict(sorted(rank.items(), key=lambda x: x[1], reverse=True))
108  
109      def pagerank(self):
110          rank = nx.pagerank(self.backend, weight="weight")
111          return dict(sorted(rank.items(), key=lambda x: x[1], reverse=True))
112  
113      def showpath(self, source, target):
114          # pylint: disable=E1121
115          return nx.shortest_path(self.backend, source, target, self.distance)
116  
117      def isquery(self, queries):
118          return Query().isquery(queries)
119  
120      def parse(self, query):
121          return Query().parse(query)
122  
123      def search(self, query, limit=None, graph=False):
124          # Run graph query
125          results = Query()(self, query, limit)
126  
127          # Transform into filtered graph
128          if graph:
129              nodes = set()
130              for column in results.values():
131                  for value in column:
132                      if isinstance(value, list):
133                          # Path group
134                          nodes.update([node for node in value if node and not isinstance(node, dict)])
135                      elif isinstance(value, dict):
136                          # Nodes by id attribute
137                          nodes.update(uid for uid, attr in self.scan(data=True) if attr["id"] == value["id"])
138                      elif value is not None:
139                          # Single node id
140                          nodes.add(value)
141  
142              return self.filter(list(nodes))
143  
144          # Transform columnar structure into rows
145          keys = list(results.keys())
146          rows, count = [], len(results[keys[0]])
147  
148          for x in range(count):
149              rows.append({str(key): results[key][x] for key in keys})
150  
151          return rows
152  
153      def communities(self, config):
154          # Get community detection algorithm
155          algorithm = config.get("algorithm")
156  
157          if algorithm == "greedy":
158              communities = greedy_modularity_communities(self.backend, weight="weight", resolution=config.get("resolution", 100))
159          elif algorithm == "lpa":
160              communities = asyn_lpa_communities(self.backend, weight="weight", seed=0)
161          else:
162              communities = self.louvain(config)
163  
164          return communities
165  
166      def load(self, path):
167          try:
168              # Load graph data
169              data = SerializeFactory.create().load(path)
170  
171              # Add data to graph
172              self.backend = self.create()
173              self.backend.add_nodes_from(data["nodes"])
174              self.backend.add_edges_from(data["edges"])
175  
176              # Load categories
177              self.categories = data.get("categories")
178  
179              # Load topics
180              self.topics = data.get("topics")
181  
182          except SerializeError:
183              # Backwards compatible support for legacy TAR format
184              self.loadtar(path)
185  
186      def save(self, path):
187          # Save graph data
188          SerializeFactory.create().save(
189              {
190                  "nodes": [(uid, self.node(uid)) for uid in self.scan()],
191                  "edges": list(self.backend.edges(data=True)),
192                  "categories": self.categories,
193                  "topics": self.topics,
194              },
195              path,
196          )
197  
198      def loaddict(self, data):
199          self.backend = json_graph.node_link_graph(data, name="indexid")
200          self.categories, self.topics = data.get("categories"), data.get("topics")
201  
202      def savedict(self):
203          data = json_graph.node_link_data(self.backend, name="indexid")
204          data["categories"] = self.categories
205          data["topics"] = self.topics
206  
207          return data
208  
209      def louvain(self, config):
210          """
211          Runs the Louvain community detection algorithm.
212  
213          Args:
214              config: topic configuration
215  
216          Returns:
217              list of [ids] per community
218          """
219  
220          # Partition level to use
221          level = config.get("level", "best")
222  
223          # Run community detection
224          results = list(louvain_partitions(self.backend, weight="weight", resolution=config.get("resolution", 100), seed=0))
225  
226          # Get partition level (first or best)
227          return results[0] if level == "first" else results[-1]
228  
229      # pylint: disable=W0613
230      def distance(self, source, target, attrs):
231          """
232          Computes distance between source and target nodes using weight.
233  
234          Args:
235              source: source node
236              target: target node
237              attrs: edge attributes
238  
239          Returns:
240              distance between source and target
241          """
242  
243          # Distance is 1 - score. Skip minimal distances as they are near duplicates.
244          distance = max(1.0 - attrs["weight"], 0.0)
245          return distance if distance >= 0.15 else 1.00
246  
247      def loadtar(self, path):
248          """
249          Loads a graph from the legacy TAR file.
250  
251          Args:
252              path: path to graph
253          """
254  
255          # Pickle serialization - backwards compatible
256          serializer = SerializeFactory.create("pickle")
257  
258          # Extract files to temporary directory and load content
259          with TemporaryDirectory() as directory:
260              # Unpack files
261              archive = ArchiveFactory.create(directory)
262              archive.load(path, "tar")
263  
264              # Load graph backend
265              self.backend = serializer.load(f"{directory}/graph")
266  
267              # Load categories, if necessary
268              path = f"{directory}/categories"
269              if os.path.exists(path):
270                  self.categories = serializer.load(path)
271  
272              # Load topics, if necessary
273              path = f"{directory}/topics"
274              if os.path.exists(path):
275                  self.topics = serializer.load(path)