/ src / python / txtai / graph / base.py
base.py
  1  """
  2  Graph module
  3  """
  4  
  5  from collections import Counter
  6  
  7  from .topics import Topics
  8  
  9  
 10  # pylint: disable=R0904
 11  class Graph:
 12      """
 13      Base class for Graph instances. This class builds graph networks. Supports topic modeling
 14      and relationship traversal.
 15      """
 16  
 17      def __init__(self, config):
 18          """
 19          Creates a new Graph.
 20  
 21          Args:
 22              config: graph configuration
 23          """
 24  
 25          # Graph configuration
 26          self.config = config if config is not None else {}
 27  
 28          # Graph backend
 29          self.backend = None
 30  
 31          # Topic modeling
 32          self.categories = None
 33          self.topics = None
 34  
 35          # Transform columns
 36          columns = config.get("columns", {})
 37          self.text = columns.get("text", "text")
 38          self.object = columns.get("object", "object")
 39  
 40          # Attributes to copy - skips text/object/relationship fields - set to True to copy all
 41          self.copyattributes = config.get("copyattributes", False)
 42  
 43          # Relationships are manually-provided edges
 44          self.relationships = columns.get("relationships", "relationships")
 45          self.relations = {}
 46  
 47      def create(self):
 48          """
 49          Creates the graph network.
 50          """
 51  
 52          raise NotImplementedError
 53  
 54      def count(self):
 55          """
 56          Returns the total number of nodes in graph.
 57  
 58          Returns:
 59              total nodes in graph
 60          """
 61  
 62          raise NotImplementedError
 63  
 64      def scan(self, attribute=None, data=False):
 65          """
 66          Iterates over nodes that match a criteria. If no criteria specified, all nodes
 67          are returned.
 68  
 69          Args:
 70              attribute: if specified, nodes having this attribute are returned
 71              data: if True, attribute data is also returned
 72  
 73          Returns:
 74              node id iterator if data is False or (id, attribute dictionary) iterator if data is True
 75          """
 76  
 77          raise NotImplementedError
 78  
 79      def node(self, node):
 80          """
 81          Get node by id. Returns None if not found.
 82  
 83          Args:
 84              node: node id
 85  
 86          Returns:
 87              graph node
 88          """
 89  
 90          raise NotImplementedError
 91  
 92      def addnode(self, node, **attrs):
 93          """
 94          Adds a node to the graph.
 95  
 96          Args:
 97              node: node id
 98              attrs: node attributes
 99          """
100  
101          raise NotImplementedError
102  
103      def addnodes(self, nodes):
104          """
105          Adds nodes to the graph.
106  
107          Args:
108              nodes: list of (node, attributes) to add
109          """
110  
111          raise NotImplementedError
112  
113      def removenode(self, node):
114          """
115          Removes a node and all it's edges from graph.
116  
117          Args:
118              node: node id
119          """
120  
121          raise NotImplementedError
122  
123      def hasnode(self, node):
124          """
125          Returns True if node found, False otherwise.
126  
127          Args:
128              node: node id
129  
130          Returns:
131              True if node found, False otherwise
132          """
133  
134          raise NotImplementedError
135  
136      def attribute(self, node, field):
137          """
138          Gets a node attribute.
139  
140          Args:
141              node: node id
142              field: attribute name
143  
144          Returns:
145              attribute value
146          """
147  
148          raise NotImplementedError
149  
150      def addattribute(self, node, field, value):
151          """
152          Adds an attribute to node.
153  
154          Args:
155              node: node id
156              field: attribute name
157              value: attribute value
158          """
159  
160          raise NotImplementedError
161  
162      def removeattribute(self, node, field):
163          """
164          Removes an attribute from node.
165  
166          Args:
167              node: node id
168              field: attribute name
169  
170          Returns:
171              attribute value or None if not present
172          """
173  
174          raise NotImplementedError
175  
176      def edgecount(self):
177          """
178          Returns the total number of edges.
179  
180          Returns:
181              total number of edges in graph
182          """
183  
184          raise NotImplementedError
185  
186      def edges(self, node):
187          """
188          Gets edges of node by id.
189  
190          Args:
191              node: node id
192  
193          Returns:
194              list of edge node ids
195          """
196  
197          raise NotImplementedError
198  
199      def addedge(self, source, target, **attrs):
200          """
201          Adds an edge to graph.
202  
203          Args:
204              source: node 1 id
205              target: node 2 id
206          """
207  
208          raise NotImplementedError
209  
210      def addedges(self, edges):
211          """
212          Adds an edge to graph.
213  
214          Args:
215              edges: list of (source, target, attributes) to add
216          """
217  
218          raise NotImplementedError
219  
220      def hasedge(self, source, target=None):
221          """
222          Returns True if edge found, False otherwise. If target is None, this method
223          returns True if any edge is found.
224  
225          Args:
226              source: node 1 id
227              target: node 2 id
228  
229          Returns:
230              True if edge found, False otherwise
231          """
232  
233          raise NotImplementedError
234  
235      def centrality(self):
236          """
237          Runs a centrality algorithm on the graph.
238  
239          Returns:
240              dict of {node id: centrality score}
241          """
242  
243          raise NotImplementedError
244  
245      def pagerank(self):
246          """
247          Runs the pagerank algorithm on the graph.
248  
249          Returns:
250              dict of {node id, page rank score}
251          """
252  
253          raise NotImplementedError
254  
255      def showpath(self, source, target):
256          """
257          Gets the shortest path between source and target.
258  
259          Args:
260              source: start node id
261              target: end node id
262  
263          Returns:
264              list of node ids representing the shortest path
265          """
266  
267          raise NotImplementedError
268  
269      def isquery(self, queries):
270          """
271          Checks if queries are supported graph queries.
272  
273          Args:
274              queries: queries to check
275  
276          Returns:
277              True if all the queries are supported graph queries, False otherwise
278          """
279  
280          raise NotImplementedError
281  
282      def parse(self, query):
283          """
284          Parses a graph query into query components.
285  
286          Args:
287              query: graph query
288  
289          Returns:
290              query components as a dictionary
291          """
292  
293          raise NotImplementedError
294  
295      def search(self, query, limit=None, graph=False):
296          """
297          Searches graph for nodes matching query.
298  
299          Args:
300              query: graph query
301              limit: maximum results
302              graph: return graph results if True
303  
304          Returns:
305              list of dict if graph is set to False
306              filtered graph if graph is set to True
307          """
308  
309          raise NotImplementedError
310  
311      def batchsearch(self, queries, limit=None, graph=False):
312          """
313          Searches graph for nodes matching query.
314  
315          Args:
316              query: graph query
317              limit: maximum results
318              graph: return graph results if True
319  
320          Returns:
321              list of dict if graph is set to False
322              filtered graph if graph is set to True
323          """
324  
325          return [self.search(query, limit, graph) for query in queries]
326  
327      def communities(self, config):
328          """
329          Run community detection on the graph.
330  
331          Args:
332              config: configuration
333  
334          Returns:
335              dictionary of {topic name:[ids]}
336          """
337  
338          raise NotImplementedError
339  
340      def load(self, path):
341          """
342          Loads a graph at path.
343  
344          Args:
345              path: path to graph
346          """
347  
348          raise NotImplementedError
349  
350      def save(self, path):
351          """
352          Saves a graph at path.
353  
354          Args:
355              path: path to save graph
356          """
357  
358          raise NotImplementedError
359  
360      def loaddict(self, data):
361          """
362          Loads data from input dictionary into this graph.
363  
364          Args:
365              data: input dictionary
366          """
367  
368          raise NotImplementedError
369  
370      def savedict(self):
371          """
372          Saves graph data to a dictionary.
373  
374          Returns:
375              dict
376          """
377  
378          raise NotImplementedError
379  
380      def initialize(self):
381          """
382          Initialize graph instance.
383          """
384  
385          if not self.backend:
386              self.backend = self.create()
387  
388      def close(self):
389          """
390          Closes this graph.
391          """
392  
393          self.backend, self.categories, self.topics = None, None, None
394  
395      def insert(self, documents, index=0):
396          """
397          Insert graph nodes for each document.
398  
399          Args:
400              documents: list of (id, data, tags)
401              index: indexid offset, used for node ids
402          """
403  
404          # Initialize graph backend
405          self.initialize()
406  
407          nodes = []
408          for uid, document, _ in documents:
409              # Manually provided relationships and attributes to copy
410              relations, attributes = None, {}
411  
412              # Extract data from dictionary
413              if isinstance(document, dict):
414                  # Extract relationships
415                  relations = document.get(self.relationships)
416  
417                  # Attributes to copy, if any
418                  search = self.copyattributes if isinstance(self.copyattributes, list) else []
419                  attributes = {
420                      k: v
421                      for k, v in document.items()
422                      if k not in [self.text, self.object, self.relationships] and (self.copyattributes is True or k in search)
423                  }
424  
425                  # Require text or object field
426                  document = document.get(self.text, document.get(self.object))
427  
428              if document is not None:
429                  if isinstance(document, list):
430                      # Join tokens as text
431                      document = " ".join(document)
432  
433                  # Create node
434                  nodes.append((index, {**{"id": uid, "data": document}, **attributes}))
435  
436                  # Add relationships
437                  self.addrelations(index, relations)
438  
439                  index += 1
440  
441          # Add nodes
442          self.addnodes(nodes)
443  
444      def delete(self, ids):
445          """
446          Deletes ids from graph.
447  
448          Args:
449              ids: node ids to delete
450          """
451  
452          for node in ids:
453              # Remove existing node, if it exists
454              if self.hasnode(node):
455                  # Delete from topics
456                  topic = self.attribute(node, "topic")
457                  if topic and self.topics:
458                      # Delete id from topic
459                      self.topics[topic].remove(node)
460  
461                      # Also delete topic, if it's empty
462                      if not self.topics[topic]:
463                          self.topics.pop(topic)
464  
465                  # Delete node
466                  self.removenode(node)
467  
468      def index(self, search, ids, similarity):
469          """
470          Build relationships between graph nodes using a score-based search function.
471  
472          Args:
473              search: batch search function - takes a list of queries and returns lists of (id, scores) to use as edge weights
474              ids: ids function - internal id resolver
475              similarity: batch similarity function - takes a list of text and labels and returns best matches
476          """
477  
478          # Add relationship edges
479          self.resolverelations(ids)
480  
481          # Infer node edges using search function
482          self.inferedges(self.scan(), search)
483  
484          # Label categories/topics
485          if "topics" in self.config:
486              self.addtopics(similarity)
487  
488      def upsert(self, search, ids, similarity=None):
489          """
490          Adds relationships for new graph nodes using a score-based search function.
491  
492          Args:
493              search: batch search function - takes a list of queries and returns lists of (id, scores) to use as edge weights
494              ids: ids function - internal id resolver
495              similarity: batch similarity function - takes a list of text and labels and returns best matches
496          """
497  
498          # Detect if topics processing is enabled
499          hastopics = "topics" in self.config
500  
501          # Add relationship edges
502          self.resolverelations(ids)
503  
504          # Infer node edges using new/updated nodes, set updated flag for topic processing, if necessary
505          self.inferedges(self.scan(attribute="data"), search, {"updated": True} if hastopics else None)
506  
507          # Infer topics with topics of connected nodes
508          if hastopics:
509              # Infer topics if there is at least one topic, otherwise rebuild
510              if self.topics:
511                  self.infertopics()
512              else:
513                  self.addtopics(similarity)
514  
515      def filter(self, nodes, graph=None):
516          """
517          Creates a subgraph of this graph using the list of input nodes. This method creates a new graph
518          selecting only matching nodes, edges, topics and categories.
519  
520          Args:
521              nodes: nodes to select as a list of ids or list of (id, score) tuples
522              graph: optional graph used to store filtered results
523  
524          Returns:
525              graph
526          """
527  
528          # Set graph if available, otherwise create a new empty graph of the same type
529          graph = graph if graph else type(self)(self.config)
530  
531          # Initalize subgraph
532          graph.initialize()
533  
534          nodeids = {node[0] if isinstance(node, tuple) else node for node in nodes}
535          for node in nodes:
536              # Unpack node and score, if available
537              node, score = node if isinstance(node, tuple) else (node, None)
538  
539              # Add nodes
540              graph.addnode(node, **self.node(node))
541  
542              # Add score if present
543              if score is not None:
544                  graph.addattribute(node, "score", score)
545  
546              # Add edges
547              edges = self.edges(node)
548              if edges:
549                  for target, attributes in self.edges(node).items():
550                      if target in nodeids:
551                          graph.addedge(node, target, **attributes)
552  
553          # Filter categories and topics
554          if self.topics:
555              topics = {}
556              for i, (topic, ids) in enumerate(self.topics.items()):
557                  ids = [x for x in ids if x in nodeids]
558                  if ids:
559                      topics[topic] = (self.categories[i] if self.categories else None, ids)
560  
561              # Sort by number of nodes descending
562              topics = sorted(topics.items(), key=lambda x: len(x[1][1]), reverse=True)
563  
564              # Copy filtered categories and topics
565              graph.categories = [category for _, (category, _) in topics] if self.categories else None
566              graph.topics = {topic: ids for topic, (_, ids) in topics}
567  
568          return graph
569  
570      def addrelations(self, node, relations):
571          """
572          Add manually-provided relationships.
573  
574          Args:
575              node: node id
576              relations: list of relationships to add
577          """
578  
579          # Add relationships, if any
580          if relations:
581              if node not in self.relations:
582                  self.relations[node] = []
583  
584              # Add each relationship
585              for relation in relations:
586                  # Support both dict and string ids
587                  relation = {"id": relation} if not isinstance(relation, dict) else relation
588                  self.relations[node].append(relation)
589  
590      def resolverelations(self, ids):
591          """
592          Resolves ids and creates edges for manually-provided relationships.
593  
594          Args:
595              ids: internal id resolver
596          """
597  
598          # Relationship edges
599          edges = []
600  
601          # Resolve ids and create edges for relationships
602          for node, relations in self.relations.items():
603              # Resolve internal ids
604              iids = ids(y["id"] for y in relations)
605  
606              # Add each edge
607              for relation in relations:
608                  # Make copy of relation
609                  relation = relation.copy()
610  
611                  # Lookup targets for relationship
612                  targets = iids.get(str(relation.pop("id")))
613  
614                  # Create edge for each instance of id - internal id pair
615                  if targets:
616                      for target in targets:
617                          # Add weight, if not provided
618                          relation["weight"] = relation.get("weight", 1.0)
619  
620                          # Add edge and all other attributes
621                          edges.append((node, target, relation))
622  
623          # Add relationships
624          if edges:
625              self.addedges(edges)
626  
627          # Clear temporary relationship storage
628          self.relations = {}
629  
630      def inferedges(self, nodes, search, attributes=None):
631          """
632          Infers edges for a list of nodes using a score-based search function.
633  
634          Args:
635              nodes: list of nodes
636              search: search function to use to identify edges
637              attribute: dictionary of attributes to add to each node
638          """
639  
640          # Read graph parameters
641          batchsize, limit, minscore = self.config.get("batchsize", 256), self.config.get("limit", 15), self.config.get("minscore", 0.1)
642          approximate = self.config.get("approximate", True)
643  
644          batch = []
645          for node in nodes:
646              # Get data attribute
647              data = self.removeattribute(node, "data")
648  
649              # Set text field when data is a string
650              if isinstance(data, str):
651                  self.addattribute(node, "text", data)
652  
653              # Add additional attributes, if specified
654              if attributes:
655                  for field, value in attributes.items():
656                      self.addattribute(node, field, value)
657  
658              # Skip nodes with existing edges when building an approximate network
659              if not approximate or not self.hasedge(node):
660                  batch.append((node, data))
661  
662              # Process batch
663              if len(batch) == batchsize:
664                  self.addbatch(search, batch, limit, minscore)
665                  batch = []
666  
667          if batch:
668              self.addbatch(search, batch, limit, minscore)
669  
670      def addbatch(self, search, batch, limit, minscore):
671          """
672          Adds batch of documents to graph. This method runs the search function for each item in batch
673          and adds node edges between the input and each search result.
674  
675          Args:
676              search: search function to use to identify edges
677              batch: batch to add
678              limit: max edges to add per node
679              minscore: min score to add node edge
680          """
681  
682          edges = []
683          for x, result in enumerate(search([data for _, data in batch], limit)):
684              # Get input node id
685              x, _ = batch[x]
686  
687              # Add edges for each input node id and result node id pair that meets specified criteria
688              for y, score in result:
689                  if str(x) != str(y) and score > minscore:
690                      edges.append((x, y, {"weight": score}))
691  
692          self.addedges(edges)
693  
694      def addtopics(self, similarity=None):
695          """
696          Identifies and adds topics using community detection.
697  
698          Args:
699              similarity: similarity function for labeling categories
700          """
701  
702          # Clear previous topics, if any
703          self.cleartopics()
704  
705          # Use community detection to get topics
706          topics = Topics(self.config["topics"])
707          config = topics.config
708          self.topics = topics(self)
709  
710          # Label each topic with a higher level category
711          if "categories" in config and similarity:
712              self.categories = []
713              results = similarity(self.topics.keys(), config["categories"])
714              for result in results:
715                  self.categories.append(config["categories"][result[0][0]])
716  
717          # Add topic-related node attributes
718          for x, topic in enumerate(self.topics):
719              for r, node in enumerate(self.topics[topic]):
720                  self.addattribute(node, "topic", topic)
721                  self.addattribute(node, "topicrank", r)
722  
723                  if self.categories:
724                      self.addattribute(node, "category", self.categories[x])
725  
726      def cleartopics(self):
727          """
728          Clears topic fields from all nodes.
729          """
730  
731          # Clear previous topics, if any
732          if self.topics:
733              for node in self.scan():
734                  self.removeattribute(node, "topic")
735                  self.removeattribute(node, "topicrank")
736  
737                  if self.categories:
738                      self.removeattribute(node, "category")
739  
740              self.topics, self.categories = None, None
741  
742      def infertopics(self):
743          """
744          Infers topics for all nodes with an "updated" attribute. This method analyzes the direct node
745          neighbors and set the most commonly occuring topic and category for each node.
746          """
747  
748          # Iterate over nodes missing topic attribute (only occurs for new nodes)
749          for node in self.scan(attribute="updated"):
750              # Remove updated attribute
751              self.removeattribute(node, "updated")
752  
753              # Get list of neighboring nodes
754              ids = self.edges(node)
755              ids = ids.keys() if ids else None
756  
757              # Infer topic
758              topic = Counter(self.attribute(x, "topic") for x in ids).most_common(1)[0][0] if ids else None
759              if topic:
760                  # Add id to topic list and set topic attribute
761                  self.topics[topic].append(node)
762                  self.addattribute(node, "topic", topic)
763  
764                  # Set topic rank
765                  self.addattribute(node, "topicrank", len(self.topics[topic]) - 1)
766  
767                  # Infer category
768                  category = Counter(self.attribute(x, "category") for x in ids).most_common(1)[0][0]
769                  self.addattribute(node, "category", category)