/ src / python / txtai / graph / topics.py
topics.py
  1  """
  2  Topics module
  3  """
  4  
  5  from ..pipeline import Tokenizer
  6  from ..scoring import ScoringFactory
  7  
  8  
  9  class Topics:
 10      """
 11      Topic modeling using community detection.
 12      """
 13  
 14      def __init__(self, config):
 15          """
 16          Creates a new Topics instance.
 17  
 18          Args:
 19              config: topic configuration
 20          """
 21  
 22          self.config = config if config else {}
 23          self.tokenizer = Tokenizer(stopwords=True)
 24  
 25          # Additional stopwords to ignore when building topic names
 26          self.stopwords = set()
 27          if "stopwords" in self.config:
 28              self.stopwords.update(self.config["stopwords"])
 29  
 30      def __call__(self, graph):
 31          """
 32          Runs topic modeling for input graph.
 33  
 34          Args:
 35              graph: Graph instance
 36  
 37          Returns:
 38              dictionary of {topic name: [ids]}
 39          """
 40  
 41          # Detect communities
 42          communities = graph.communities(self.config)
 43  
 44          # Sort by community size, largest to smallest
 45          communities = sorted(communities, key=len, reverse=True)
 46  
 47          # Calculate centrality of graph
 48          centrality = graph.centrality()
 49  
 50          # Score communities and generate topn terms
 51          topics = [self.score(graph, x, community, centrality) for x, community in enumerate(communities)]
 52  
 53          # Merge duplicate topics and return
 54          return self.merge(topics)
 55  
 56      def score(self, graph, index, community, centrality):
 57          """
 58          Scores a community of nodes and generates the topn terms in the community.
 59  
 60          Args:
 61              graph: Graph instance
 62              index: community index
 63              community: community of nodes
 64              centrality: node centrality scores
 65  
 66          Returns:
 67              (topn topic terms, topic ids sorted by score descending)
 68          """
 69  
 70          # Tokenize input and build scoring index
 71          scoring = ScoringFactory.create({"method": self.config.get("labels", "bm25"), "terms": True})
 72          scoring.index(((node, self.tokenize(graph, node), None) for node in community))
 73  
 74          # Check if scoring index has data
 75          if scoring.idf:
 76              # Sort by most commonly occurring terms (i.e. lowest score)
 77              idf = sorted(scoring.idf, key=scoring.idf.get)
 78  
 79              # Term count for generating topic labels
 80              topn = self.config.get("terms", 4)
 81  
 82              # Get topn terms
 83              terms = self.topn(idf, topn)
 84  
 85              # Sort community by score descending
 86              community = [uid for uid, _ in scoring.search(terms, len(community))]
 87          else:
 88              # No text found for topic, generate topic name
 89              terms = ["topic", str(index)]
 90  
 91              # Sort community by centrality scores
 92              community = sorted(community, key=lambda x: centrality[x], reverse=True)
 93  
 94          return (terms, community)
 95  
 96      def tokenize(self, graph, node):
 97          """
 98          Tokenizes node text.
 99  
100          Args:
101              graph: Graph instance
102              node: node id
103  
104          Returns:
105              list of node tokens
106          """
107  
108          text = graph.attribute(node, "text")
109          return self.tokenizer(text) if text else []
110  
111      def topn(self, terms, n):
112          """
113          Gets topn terms.
114  
115          Args:
116              terms: list of terms
117              n: topn
118  
119          Returns:
120              topn terms
121          """
122  
123          topn = []
124  
125          for term in terms:
126              # Add terms that pass tokenization rules
127              if self.tokenizer(term) and term not in self.stopwords:
128                  topn.append(term)
129  
130              # Break once topn terms collected
131              if len(topn) == n:
132                  break
133  
134          return topn
135  
136      def merge(self, topics):
137          """
138          Merges duplicate topics
139  
140          Args:
141              topics: list of (topn terms, topic ids)
142  
143          Returns:
144              dictionary of {topic name:[ids]}
145          """
146  
147          merge, termslist = {}, {}
148  
149          for terms, uids in topics:
150              # Use topic terms as key
151              key = frozenset(terms)
152  
153              # Add key to merged topics, if necessary
154              if key not in merge:
155                  merge[key], termslist[key] = [], terms
156  
157              # Merge communities
158              merge[key].extend(uids)
159  
160          # Sort communities largest to smallest since the order could have changed with merges
161          results = {}
162          for k, v in sorted(merge.items(), key=lambda x: len(x[1]), reverse=True):
163              # Create composite string key using topic terms and store ids
164              results["_".join(termslist[k])] = v
165  
166          return results