networkx.py
1 """ 2 NetworkX module 3 """ 4 5 import os 6 7 from tempfile import TemporaryDirectory 8 9 # Conditional import 10 try: 11 import networkx as nx 12 13 from networkx.algorithms.community import asyn_lpa_communities, greedy_modularity_communities, louvain_partitions 14 from networkx.readwrite import json_graph 15 16 NETWORKX = True 17 except ImportError: 18 NETWORKX = False 19 20 from ..archive import ArchiveFactory 21 from ..serialize import SerializeError, SerializeFactory 22 23 from .base import Graph 24 from .query import Query 25 26 27 # pylint: disable=R0904 28 class NetworkX(Graph): 29 """ 30 Graph instance backed by NetworkX. 31 """ 32 33 def __init__(self, config): 34 super().__init__(config) 35 36 if not NETWORKX: 37 raise ImportError('NetworkX is not available - install "graph" extra to enable') 38 39 def create(self): 40 return nx.Graph() 41 42 def count(self): 43 return self.backend.number_of_nodes() 44 45 def scan(self, attribute=None, data=False): 46 # Full graph 47 graph = self.backend 48 49 # Filter graph to nodes having a specified attribute 50 if attribute: 51 graph = nx.subgraph_view(self.backend, filter_node=lambda x: attribute in self.node(x)) 52 53 # Return either list of matching ids or tuple of (id, attribute dictionary) 54 return graph.nodes(data=True) if data else graph 55 56 def node(self, node): 57 return self.backend.nodes.get(node) 58 59 def addnode(self, node, **attrs): 60 self.backend.add_node(node, **attrs) 61 62 def addnodes(self, nodes): 63 self.backend.add_nodes_from(nodes) 64 65 def removenode(self, node): 66 if self.hasnode(node): 67 self.backend.remove_node(node) 68 69 def hasnode(self, node): 70 return self.backend.has_node(node) 71 72 def attribute(self, node, field): 73 return self.node(node).get(field) if self.hasnode(node) else None 74 75 def addattribute(self, node, field, value): 76 if self.hasnode(node): 77 self.node(node)[field] = value 78 79 def removeattribute(self, node, field): 80 return self.node(node).pop(field, None) if self.hasnode(node) else None 81 82 def edgecount(self): 83 return self.backend.number_of_edges() 84 85 def edges(self, node): 86 edges = self.backend.adj.get(node) 87 if edges: 88 return dict(sorted(edges.items(), key=lambda x: x[1].get("weight", 0), reverse=True)) 89 90 return None 91 92 def addedge(self, source, target, **attrs): 93 self.backend.add_edge(source, target, **attrs) 94 95 def addedges(self, edges): 96 self.backend.add_edges_from(edges) 97 98 def hasedge(self, source, target=None): 99 if target is None: 100 edges = self.backend.adj.get(source) 101 return len(edges) > 0 if edges else False 102 103 return self.backend.has_edge(source, target) 104 105 def centrality(self): 106 rank = nx.degree_centrality(self.backend) 107 return dict(sorted(rank.items(), key=lambda x: x[1], reverse=True)) 108 109 def pagerank(self): 110 rank = nx.pagerank(self.backend, weight="weight") 111 return dict(sorted(rank.items(), key=lambda x: x[1], reverse=True)) 112 113 def showpath(self, source, target): 114 # pylint: disable=E1121 115 return nx.shortest_path(self.backend, source, target, self.distance) 116 117 def isquery(self, queries): 118 return Query().isquery(queries) 119 120 def parse(self, query): 121 return Query().parse(query) 122 123 def search(self, query, limit=None, graph=False): 124 # Run graph query 125 results = Query()(self, query, limit) 126 127 # Transform into filtered graph 128 if graph: 129 nodes = set() 130 for column in results.values(): 131 for value in column: 132 if isinstance(value, list): 133 # Path 
group 134 nodes.update([node for node in value if node and not isinstance(node, dict)]) 135 elif isinstance(value, dict): 136 # Nodes by id attribute 137 nodes.update(uid for uid, attr in self.scan(data=True) if attr["id"] == value["id"]) 138 elif value is not None: 139 # Single node id 140 nodes.add(value) 141 142 return self.filter(list(nodes)) 143 144 # Transform columnar structure into rows 145 keys = list(results.keys()) 146 rows, count = [], len(results[keys[0]]) 147 148 for x in range(count): 149 rows.append({str(key): results[key][x] for key in keys}) 150 151 return rows 152 153 def communities(self, config): 154 # Get community detection algorithm 155 algorithm = config.get("algorithm") 156 157 if algorithm == "greedy": 158 communities = greedy_modularity_communities(self.backend, weight="weight", resolution=config.get("resolution", 100)) 159 elif algorithm == "lpa": 160 communities = asyn_lpa_communities(self.backend, weight="weight", seed=0) 161 else: 162 communities = self.louvain(config) 163 164 return communities 165 166 def load(self, path): 167 try: 168 # Load graph data 169 data = SerializeFactory.create().load(path) 170 171 # Add data to graph 172 self.backend = self.create() 173 self.backend.add_nodes_from(data["nodes"]) 174 self.backend.add_edges_from(data["edges"]) 175 176 # Load categories 177 self.categories = data.get("categories") 178 179 # Load topics 180 self.topics = data.get("topics") 181 182 except SerializeError: 183 # Backwards compatible support for legacy TAR format 184 self.loadtar(path) 185 186 def save(self, path): 187 # Save graph data 188 SerializeFactory.create().save( 189 { 190 "nodes": [(uid, self.node(uid)) for uid in self.scan()], 191 "edges": list(self.backend.edges(data=True)), 192 "categories": self.categories, 193 "topics": self.topics, 194 }, 195 path, 196 ) 197 198 def loaddict(self, data): 199 self.backend = json_graph.node_link_graph(data, name="indexid") 200 self.categories, self.topics = data.get("categories"), data.get("topics") 201 202 def savedict(self): 203 data = json_graph.node_link_data(self.backend, name="indexid") 204 data["categories"] = self.categories 205 data["topics"] = self.topics 206 207 return data 208 209 def louvain(self, config): 210 """ 211 Runs the Louvain community detection algorithm. 212 213 Args: 214 config: topic configuration 215 216 Returns: 217 list of [ids] per community 218 """ 219 220 # Partition level to use 221 level = config.get("level", "best") 222 223 # Run community detection 224 results = list(louvain_partitions(self.backend, weight="weight", resolution=config.get("resolution", 100), seed=0)) 225 226 # Get partition level (first or best) 227 return results[0] if level == "first" else results[-1] 228 229 # pylint: disable=W0613 230 def distance(self, source, target, attrs): 231 """ 232 Computes distance between source and target nodes using weight. 233 234 Args: 235 source: source node 236 target: target node 237 attrs: edge attributes 238 239 Returns: 240 distance between source and target 241 """ 242 243 # Distance is 1 - score. Skip minimal distances as they are near duplicates. 244 distance = max(1.0 - attrs["weight"], 0.0) 245 return distance if distance >= 0.15 else 1.00 246 247 def loadtar(self, path): 248 """ 249 Loads a graph from the legacy TAR file. 
250 251 Args: 252 path: path to graph 253 """ 254 255 # Pickle serialization - backwards compatible 256 serializer = SerializeFactory.create("pickle") 257 258 # Extract files to temporary directory and load content 259 with TemporaryDirectory() as directory: 260 # Unpack files 261 archive = ArchiveFactory.create(directory) 262 archive.load(path, "tar") 263 264 # Load graph backend 265 self.backend = serializer.load(f"{directory}/graph") 266 267 # Load categories, if necessary 268 path = f"{directory}/categories" 269 if os.path.exists(path): 270 self.categories = serializer.load(path) 271 272 # Load topics, if necessary 273 path = f"{directory}/topics" 274 if os.path.exists(path): 275 self.topics = serializer.load(path)
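

# ----------------------------------------------------------------------------
# Illustrative usage sketch - not part of the original module. It assumes the
# Graph base class accepts a plain dict config and that assigning the result
# of create() to backend is enough to start adding data; the surrounding
# project may provide its own initialization path instead.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    # Build a small weighted graph
    graph = NetworkX({})
    graph.backend = graph.create()

    # Nodes are (id, attributes) tuples, edges are (source, target, attributes)
    graph.addnodes([(0, {"id": "a"}), (1, {"id": "b"}), (2, {"id": "c"})])
    graph.addedges([(0, 1, {"weight": 0.9}), (1, 2, {"weight": 0.4})])

    # Degree centrality sorted by score, edges for a node sorted by weight
    print(graph.centrality())
    print(graph.edges(1))

    # Shortest path weighted by the distance method (1 - edge weight)
    print(graph.showpath(0, 2))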