base.py
"""
Graph module
"""

from collections import Counter

from .topics import Topics


# pylint: disable=R0904
class Graph:
    """
    Base class for Graph instances. This class builds graph networks. Supports topic modeling
    and relationship traversal.

    Storage is delegated to subclasses: every method that raises NotImplementedError here is a
    backend hook. The concrete logic in this class (insert, delete, index, upsert, filter and
    topic handling) is written purely in terms of those hooks.
    """

    def __init__(self, config):
        """
        Creates a new Graph.

        Args:
            config: graph configuration dictionary, None is treated as an empty configuration
        """

        # Graph configuration
        self.config = config if config is not None else {}

        # Graph backend
        self.backend = None

        # Topic modeling. categories is a list aligned by position with the topics dictionary.
        self.categories = None
        self.topics = None

        # Transform columns - read from self.config so that config=None does not fail
        columns = self.config.get("columns", {})
        self.text = columns.get("text", "text")
        self.object = columns.get("object", "object")

        # Attributes to copy - skips text/object/relationship fields - set to True to copy all
        self.copyattributes = self.config.get("copyattributes", False)

        # Relationships are manually-provided edges
        self.relationships = columns.get("relationships", "relationships")

        # Temporary storage of {node id: [relation dict]} until resolverelations runs
        self.relations = {}

    def create(self):
        """
        Creates the graph network.
        """

        raise NotImplementedError

    def count(self):
        """
        Returns the total number of nodes in graph.

        Returns:
            total nodes in graph
        """

        raise NotImplementedError

    def scan(self, attribute=None, data=False):
        """
        Iterates over nodes that match a criteria. If no criteria specified, all nodes
        are returned.

        Args:
            attribute: if specified, nodes having this attribute are returned
            data: if True, attribute data is also returned

        Returns:
            node id iterator if data is False or (id, attribute dictionary) iterator if data is True
        """

        raise NotImplementedError

    def node(self, node):
        """
        Get node by id. Returns None if not found.

        Args:
            node: node id

        Returns:
            graph node
        """

        raise NotImplementedError

    def addnode(self, node, **attrs):
        """
        Adds a node to the graph.

        Args:
            node: node id
            attrs: node attributes
        """

        raise NotImplementedError

    def addnodes(self, nodes):
        """
        Adds nodes to the graph.

        Args:
            nodes: list of (node, attributes) to add
        """

        raise NotImplementedError

    def removenode(self, node):
        """
        Removes a node and all its edges from graph.

        Args:
            node: node id
        """

        raise NotImplementedError

    def hasnode(self, node):
        """
        Returns True if node found, False otherwise.

        Args:
            node: node id

        Returns:
            True if node found, False otherwise
        """

        raise NotImplementedError

    def attribute(self, node, field):
        """
        Gets a node attribute.

        Args:
            node: node id
            field: attribute name

        Returns:
            attribute value
        """

        raise NotImplementedError

    def addattribute(self, node, field, value):
        """
        Adds an attribute to node.

        Args:
            node: node id
            field: attribute name
            value: attribute value
        """

        raise NotImplementedError

    def removeattribute(self, node, field):
        """
        Removes an attribute from node.

        Args:
            node: node id
            field: attribute name

        Returns:
            attribute value or None if not present
        """

        raise NotImplementedError

    def edgecount(self):
        """
        Returns the total number of edges.

        Returns:
            total number of edges in graph
        """

        raise NotImplementedError

    def edges(self, node):
        """
        Gets edges of node by id.

        Args:
            node: node id

        Returns:
            mapping of {target node id: edge attributes}; callers in this class treat a falsy
            value as "no edges"
        """

        raise NotImplementedError

    def addedge(self, source, target, **attrs):
        """
        Adds an edge to graph.

        Args:
            source: node 1 id
            target: node 2 id
            attrs: edge attributes
        """

        raise NotImplementedError

    def addedges(self, edges):
        """
        Adds edges to graph.

        Args:
            edges: list of (source, target, attributes) to add
        """

        raise NotImplementedError

    def hasedge(self, source, target=None):
        """
        Returns True if edge found, False otherwise. If target is None, this method
        returns True if any edge is found.

        Args:
            source: node 1 id
            target: node 2 id

        Returns:
            True if edge found, False otherwise
        """

        raise NotImplementedError

    def centrality(self):
        """
        Runs a centrality algorithm on the graph.

        Returns:
            dict of {node id: centrality score}
        """

        raise NotImplementedError

    def pagerank(self):
        """
        Runs the pagerank algorithm on the graph.

        Returns:
            dict of {node id: page rank score}
        """

        raise NotImplementedError

    def showpath(self, source, target):
        """
        Gets the shortest path between source and target.

        Args:
            source: start node id
            target: end node id

        Returns:
            list of node ids representing the shortest path
        """

        raise NotImplementedError

    def isquery(self, queries):
        """
        Checks if queries are supported graph queries.

        Args:
            queries: queries to check

        Returns:
            True if all the queries are supported graph queries, False otherwise
        """

        raise NotImplementedError

    def parse(self, query):
        """
        Parses a graph query into query components.

        Args:
            query: graph query

        Returns:
            query components as a dictionary
        """

        raise NotImplementedError

    def search(self, query, limit=None, graph=False):
        """
        Searches graph for nodes matching query.

        Args:
            query: graph query
            limit: maximum results
            graph: return graph results if True

        Returns:
            list of dict if graph is set to False
            filtered graph if graph is set to True
        """

        raise NotImplementedError

    def batchsearch(self, queries, limit=None, graph=False):
        """
        Searches graph for nodes matching each query in queries.

        Args:
            queries: list of graph queries
            limit: maximum results per query
            graph: return graph results if True

        Returns:
            list with one search result per query, each result being
            a list of dict if graph is set to False or a filtered graph if graph is set to True
        """

        return [self.search(query, limit, graph) for query in queries]

    def communities(self, config):
        """
        Run community detection on the graph.

        Args:
            config: configuration

        Returns:
            dictionary of {topic name:[ids]}
        """

        raise NotImplementedError

    def load(self, path):
        """
        Loads a graph at path.

        Args:
            path: path to graph
        """

        raise NotImplementedError

    def save(self, path):
        """
        Saves a graph at path.

        Args:
            path: path to save graph
        """

        raise NotImplementedError

    def loaddict(self, data):
        """
        Loads data from input dictionary into this graph.

        Args:
            data: input dictionary
        """

        raise NotImplementedError

    def savedict(self):
        """
        Saves graph data to a dictionary.

        Returns:
            dict
        """

        raise NotImplementedError

    def initialize(self):
        """
        Initialize graph instance. Creates the backend only if one does not already exist.
        """

        if not self.backend:
            self.backend = self.create()

    def close(self):
        """
        Closes this graph.
        """

        self.backend, self.categories, self.topics = None, None, None

    def insert(self, documents, index=0):
        """
        Insert graph nodes for each document.

        Documents that resolve to no text/object content are skipped and do not consume a node id.

        Args:
            documents: list of (id, data, tags)
            index: indexid offset, used for node ids
        """

        # Initialize graph backend
        self.initialize()

        # Fields never copied as node attributes - loop invariant, built once
        skip = (self.text, self.object, self.relationships)

        # Explicit list of attributes to copy, when copyattributes is a list
        copyfields = self.copyattributes if isinstance(self.copyattributes, list) else []

        nodes = []
        for uid, document, _ in documents:
            # Manually provided relationships and attributes to copy
            relations, attributes = None, {}

            # Extract data from dictionary
            if isinstance(document, dict):
                # Extract relationships
                relations = document.get(self.relationships)

                # Attributes to copy, if any
                attributes = {
                    k: v
                    for k, v in document.items()
                    if k not in skip and (self.copyattributes is True or k in copyfields)
                }

                # Require text or object field
                document = document.get(self.text, document.get(self.object))

            if document is not None:
                if isinstance(document, list):
                    # Join tokens as text
                    document = " ".join(document)

                # Create node - copied attributes take precedence over id/data, as before
                nodes.append((index, {"id": uid, "data": document, **attributes}))

                # Add relationships
                self.addrelations(index, relations)

                index += 1

        # Add nodes
        self.addnodes(nodes)

    def delete(self, ids):
        """
        Deletes ids from graph. Ids not present in the graph are ignored.

        Args:
            ids: node ids to delete
        """

        for node in ids:
            # Remove existing node, if it exists
            if self.hasnode(node):
                # Delete from topics
                topic = self.attribute(node, "topic")
                if topic and self.topics and topic in self.topics:
                    members = self.topics[topic]

                    # Delete id from topic
                    if node in members:
                        members.remove(node)

                    # Also delete topic, if it's empty. categories is a list aligned by position
                    # with the topics dictionary, so the matching category must be removed as well
                    # to keep the remaining topics and categories paired correctly.
                    if not members:
                        position = list(self.topics).index(topic)
                        self.topics.pop(topic)
                        if self.categories:
                            self.categories.pop(position)

                # Delete node
                self.removenode(node)

    def index(self, search, ids, similarity):
        """
        Build relationships between graph nodes using a score-based search function.

        Args:
            search: batch search function - takes a list of queries and returns lists of (id, scores) to use as edge weights
            ids: ids function - internal id resolver
            similarity: batch similarity function - takes a list of text and labels and returns best matches
        """

        # Add relationship edges
        self.resolverelations(ids)

        # Infer node edges using search function
        self.inferedges(self.scan(), search)

        # Label categories/topics
        if "topics" in self.config:
            self.addtopics(similarity)

    def upsert(self, search, ids, similarity=None):
        """
        Adds relationships for new graph nodes using a score-based search function.

        Args:
            search: batch search function - takes a list of queries and returns lists of (id, scores) to use as edge weights
            ids: ids function - internal id resolver
            similarity: batch similarity function - takes a list of text and labels and returns best matches
        """

        # Detect if topics processing is enabled
        hastopics = "topics" in self.config

        # Add relationship edges
        self.resolverelations(ids)

        # Infer node edges using new/updated nodes (those still holding a "data" attribute),
        # set updated flag for topic processing, if necessary
        self.inferedges(self.scan(attribute="data"), search, {"updated": True} if hastopics else None)

        # Infer topics with topics of connected nodes
        if hastopics:
            # Infer topics if there is at least one topic, otherwise rebuild
            if self.topics:
                self.infertopics()
            else:
                self.addtopics(similarity)

    def filter(self, nodes, graph=None):
        """
        Creates a subgraph of this graph using the list of input nodes. This method creates a new graph
        selecting only matching nodes, edges, topics and categories.

        Args:
            nodes: nodes to select as a list of ids or list of (id, score) tuples
            graph: optional graph used to store filtered results

        Returns:
            graph
        """

        # Materialize input - it is iterated twice below, which would silently
        # produce an empty subgraph if a generator were passed
        nodes = list(nodes)

        # Set graph if available, otherwise create a new empty graph of the same type
        graph = graph if graph else type(self)(self.config)

        # Initialize subgraph
        graph.initialize()

        nodeids = {node[0] if isinstance(node, tuple) else node for node in nodes}
        for node in nodes:
            # Unpack node and score, if available
            node, score = node if isinstance(node, tuple) else (node, None)

            # Add nodes
            graph.addnode(node, **self.node(node))

            # Add score if present
            if score is not None:
                graph.addattribute(node, "score", score)

            # Add edges that stay inside the selected node set
            edges = self.edges(node)
            if edges:
                for target, attributes in edges.items():
                    if target in nodeids:
                        graph.addedge(node, target, **attributes)

        # Filter categories and topics
        if self.topics:
            topics = {}
            for i, (topic, ids) in enumerate(self.topics.items()):
                ids = [x for x in ids if x in nodeids]
                if ids:
                    topics[topic] = (self.categories[i] if self.categories else None, ids)

            # Sort by number of nodes descending
            topics = sorted(topics.items(), key=lambda x: len(x[1][1]), reverse=True)

            # Copy filtered categories and topics
            graph.categories = [category for _, (category, _) in topics] if self.categories else None
            graph.topics = {topic: ids for topic, (_, ids) in topics}

        return graph

    def addrelations(self, node, relations):
        """
        Add manually-provided relationships. They are stored until resolverelations is called.

        Args:
            node: node id
            relations: list of relationships to add, each either a dict with an "id" key or a raw id
        """

        # Add relationships, if any
        if relations:
            if node not in self.relations:
                self.relations[node] = []

            # Add each relationship
            for relation in relations:
                # Support both dict and string ids
                relation = {"id": relation} if not isinstance(relation, dict) else relation
                self.relations[node].append(relation)

    def resolverelations(self, ids):
        """
        Resolves ids and creates edges for manually-provided relationships.

        Args:
            ids: internal id resolver - called with the relation ids and expected to return a
                 mapping of {str(id): [internal ids]}
        """

        # Relationship edges
        edges = []

        # Resolve ids and create edges for relationships
        for node, relations in self.relations.items():
            # Resolve internal ids
            iids = ids(y["id"] for y in relations)

            # Add each edge
            for relation in relations:
                # Make copy of relation so the stored relation is not modified
                relation = relation.copy()

                # Lookup targets for relationship
                targets = iids.get(str(relation.pop("id")))

                # Add weight, if not provided
                relation["weight"] = relation.get("weight", 1.0)

                # Create edge for each instance of id - internal id pair, each edge getting its
                # own attribute dictionary
                for target in targets or []:
                    edges.append((node, target, dict(relation)))

        # Add relationships
        if edges:
            self.addedges(edges)

        # Clear temporary relationship storage
        self.relations = {}

    def inferedges(self, nodes, search, attributes=None):
        """
        Infers edges for a list of nodes using a score-based search function.

        Each node's "data" attribute is removed; when it is a string it is stored as "text".

        Args:
            nodes: list of nodes
            search: search function to use to identify edges
            attributes: dictionary of attributes to add to each node
        """

        # Read graph parameters
        batchsize, limit, minscore = self.config.get("batchsize", 256), self.config.get("limit", 15), self.config.get("minscore", 0.1)
        approximate = self.config.get("approximate", True)

        batch = []
        for node in nodes:
            # Get data attribute
            data = self.removeattribute(node, "data")

            # Set text field when data is a string
            if isinstance(data, str):
                self.addattribute(node, "text", data)

            # Add additional attributes, if specified
            if attributes:
                for field, value in attributes.items():
                    self.addattribute(node, field, value)

            # Skip nodes with existing edges when building an approximate network
            if not approximate or not self.hasedge(node):
                batch.append((node, data))

            # Process batch
            if len(batch) == batchsize:
                self.addbatch(search, batch, limit, minscore)
                batch = []

        if batch:
            self.addbatch(search, batch, limit, minscore)

    def addbatch(self, search, batch, limit, minscore):
        """
        Adds batch of documents to graph. This method runs the search function for each item in batch
        and adds node edges between the input and each search result.

        Args:
            search: search function to use to identify edges
            batch: batch of (node id, data) to add
            limit: max edges to add per node
            minscore: edges are only added when score is strictly greater than this value
        """

        edges = []
        for position, result in enumerate(search([data for _, data in batch], limit)):
            # Get input node id
            source, _ = batch[position]

            # Add edges for each input node id and result node id pair that meets specified criteria,
            # ids compared as strings to skip self-matches regardless of id type
            for target, score in result:
                if str(source) != str(target) and score > minscore:
                    edges.append((source, target, {"weight": score}))

        self.addedges(edges)

    def addtopics(self, similarity=None):
        """
        Identifies and adds topics using community detection.

        Args:
            similarity: similarity function for labeling categories
        """

        # Clear previous topics, if any
        self.cleartopics()

        # Use community detection to get topics
        topics = Topics(self.config["topics"])
        config = topics.config
        self.topics = topics(self)

        # Label each topic with a higher level category - one category per topic, same order
        if "categories" in config and similarity:
            self.categories = []
            results = similarity(self.topics.keys(), config["categories"])
            for result in results:
                self.categories.append(config["categories"][result[0][0]])

        # Add topic-related node attributes
        for x, topic in enumerate(self.topics):
            for r, node in enumerate(self.topics[topic]):
                self.addattribute(node, "topic", topic)
                self.addattribute(node, "topicrank", r)

                if self.categories:
                    self.addattribute(node, "category", self.categories[x])

    def cleartopics(self):
        """
        Clears topic fields from all nodes.
        """

        # Clear previous topics, if any
        if self.topics:
            for node in self.scan():
                self.removeattribute(node, "topic")
                self.removeattribute(node, "topicrank")

                if self.categories:
                    self.removeattribute(node, "category")

            self.topics, self.categories = None, None

    def infertopics(self):
        """
        Infers topics for all nodes with an "updated" attribute. This method analyzes the direct node
        neighbors and sets the most commonly occurring topic and category for each node. Neighbors
        without a topic/category do not take part in the vote.
        """

        # Iterate over nodes missing topic attribute (only occurs for new nodes)
        for node in self.scan(attribute="updated"):
            # Remove updated attribute
            self.removeattribute(node, "updated")

            # Get list of neighboring nodes
            ids = self.edges(node)
            ids = list(ids.keys()) if ids else []

            # Infer topic - only neighbors that actually have a topic are counted
            votes = Counter(topic for topic in (self.attribute(x, "topic") for x in ids) if topic)
            topic = votes.most_common(1)[0][0] if votes else None

            if topic and topic in self.topics:
                # Add id to topic list and set topic attribute
                self.topics[topic].append(node)
                self.addattribute(node, "topic", topic)

                # Set topic rank
                self.addattribute(node, "topicrank", len(self.topics[topic]) - 1)

                # Infer category, only when categories are enabled (matches addtopics)
                if self.categories:
                    categories = Counter(category for category in (self.attribute(x, "category") for x in ids) if category)
                    if categories:
                        self.addattribute(node, "category", categories.most_common(1)[0][0])