# test/python/testgraph.py
"""
Graph module tests
"""
  4  
  5  import os
  6  import itertools
  7  import tempfile
  8  import unittest
  9  
 10  from unittest.mock import patch
 11  
 12  from txtai.archive import ArchiveFactory
 13  from txtai.embeddings import Embeddings
 14  from txtai.graph import Graph, GraphFactory
 15  from txtai.serialize import SerializeFactory
 16  
 17  
 18  # pylint: disable=R0904
 19  class TestGraph(unittest.TestCase):
 20      """
 21      Graph tests.
 22      """
 23  
 24      @classmethod
 25      def setUpClass(cls):
 26          """
 27          Initialize test data.
 28          """
 29  
 30          cls.data = [
 31              "US tops 5 million confirmed virus cases",
 32              "Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg",
 33              "Beijing mobilises invasion craft along coast as Taiwan tensions escalate",
 34              "The National Park Service warns against sacrificing slower friends in a bear attack",
 35              "Maine man wins $1M from $25 lottery ticket",
 36              "Make huge profits without work, earn up to $100,000 a day",
 37          ]
 38  
 39          cls.config = {
 40              "path": "sentence-transformers/nli-mpnet-base-v2",
 41              "content": True,
 42              "functions": [{"name": "graph", "function": "graph.attribute"}],
 43              "expressions": [
 44                  {"name": "category", "expression": "graph(indexid, 'category')"},
 45                  {"name": "topic", "expression": "graph(indexid, 'topic')"},
 46                  {"name": "topicrank", "expression": "graph(indexid, 'topicrank')"},
 47              ],
 48              "graph": {"limit": 5, "minscore": 0.2, "batchsize": 4, "approximate": False, "topics": {"categories": ["News"], "stopwords": ["the"]}},
 49          }
 50  
 51          # Create embeddings instance
 52          cls.embeddings = Embeddings(cls.config)
 53  
 54      def testAnalysis(self):
 55          """
 56          Test analysis methods
 57          """
 58  
 59          # Create an index for the list of text
 60          self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
 61  
 62          # Graph centrality
 63          graph = self.embeddings.graph
 64          centrality = graph.centrality()
 65          self.assertEqual(list(centrality.keys())[0], 5)
 66  
 67          # Page Rank
 68          pagerank = graph.pagerank()
 69          self.assertEqual(list(pagerank.keys())[0], 5)
 70  
 71          # Path between nodes
 72          path = graph.showpath(4, 5)
 73          self.assertEqual(len(path), 2)
 74  
 75      def testCommunity(self):
 76          """
 77          Test community detection
 78          """
 79  
 80          # Create an index for the list of text
 81          self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
 82  
 83          # Get graph reference
 84          graph = self.embeddings.graph
 85  
 86          # Rebuild topics with updated graph settings
 87          graph.config = {"topics": {"algorithm": "greedy"}}
 88          graph.addtopics()
 89          self.assertEqual(sum((len(graph.topics[x]) for x in graph.topics)), 6)
 90  
 91          graph.config = {"topics": {"algorithm": "lpa"}}
 92          graph.addtopics()
 93          self.assertEqual(sum((len(graph.topics[x]) for x in graph.topics)), 4)
 94  
 95      def testCustomBackend(self):
 96          """
 97          Test resolving a custom backend
 98          """
 99  
100          graph = GraphFactory.create({"backend": "txtai.graph.NetworkX"})
101          graph.initialize()
102          self.assertIsNotNone(graph)
103  
104      def testCustomBackendNotFound(self):
105          """
106          Test resolving an unresolvable backend
107          """
108  
109          with self.assertRaises(ImportError):
110              graph = GraphFactory.create({"backend": "notfound.graph"})
111              graph.initialize()
112  
113      def testDatabase(self):
114          """
115          Test creating a Graph backed by a relational database
116          """
117  
118          # Generate graph database
119          path = os.path.join(tempfile.gettempdir(), "graph.sqlite")
120          graph = GraphFactory.create({"backend": "rdbms", "url": f"sqlite:///{path}", "schema": "txtai"})
121  
122          # Initialize the graph
123          graph.initialize()
124  
125          for x in range(5):
126              graph.addnode(x, field=x)
127  
128          for x, y in itertools.combinations(range(5), 2):
129              graph.addedge(x, y)
130  
131          # Test methods
132          self.assertEqual(list(graph.scan()), [str(x) for x in range(5)])
133          self.assertEqual(list(graph.scan(attribute="field")), [str(x) for x in range(5)])
134          self.assertEqual(list(graph.filter([0]).scan()), [0])
135  
136          # Test save/load
137          graph.save(None)
138          graph.load(None)
139          self.assertEqual(list(graph.scan()), [str(x) for x in range(5)])
140  
141          # Test remove node
142          graph.delete([0])
143          self.assertFalse(graph.hasnode(0))
144          self.assertFalse(graph.hasedge(0))
145  
146          # Close graph
147          graph.close()
148  
149      def testDefault(self):
150          """
151          Test embeddings default graph setting
152          """
153  
154          embeddings = Embeddings(content=True, graph=True)
155          embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
156  
157          self.assertEqual(embeddings.graph.count(), len(self.data))
158  
159      def testDelete(self):
160          """
161          Test delete
162          """
163  
164          # Create an index for the list of text
165          self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
166  
167          # Delete row
168          self.embeddings.delete([4])
169  
170          # Validate counts
171          graph = self.embeddings.graph
172          self.assertEqual(graph.count(), 5)
173          self.assertEqual(graph.edgecount(), 1)
174          self.assertEqual(sum((len(graph.topics[x]) for x in graph.topics)), 5)
175          self.assertEqual(len(graph.categories), 6)
176  
177      def testEdges(self):
178          """
179          Test edges
180          """
181  
182          # Create graph
183          graph = GraphFactory.create({})
184          graph.initialize()
185          graph.addedge(0, 1)
186  
187          # Test edge exists
188          self.assertTrue(graph.hasedge(0))
189          self.assertTrue(graph.hasedge(0, 1))
190  
191      def testFilter(self):
192          """
193          Test creating filtered subgraphs
194          """
195  
196          # Create an index for the list of text
197          self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
198  
199          # Validate counts
200          graph = self.embeddings.search("feel good story", graph=True)
201          self.assertEqual(graph.count(), 3)
202          self.assertEqual(graph.edgecount(), 2)
203  
204      def testFunction(self):
205          """
206          Test running graph functions with SQL
207          """
208  
209          # Create an index for the list of text
210          self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
211  
212          # Test function
213          result = self.embeddings.search("select category, topic, topicrank from txtai where id = 0", 1)[0]
214  
215          # Check columns have a value
216          self.assertIsNotNone(result["category"])
217          self.assertIsNotNone(result["topic"])
218          self.assertIsNotNone(result["topicrank"])
219  
220      def testFunctionReindex(self):
221          """
222          Test running graph functions with SQL after reindex
223          """
224  
225          # Create an index for the list of text
226          self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
227  
228          # Test functions reset with a reindex
229          self.embeddings.reindex(self.embeddings.config)
230  
231          # Test function
232          result = self.embeddings.search("select category, topic, topicrank from txtai where id = 0", 1)[0]
233  
234          # Check columns have a value
235          self.assertIsNotNone(result["category"])
236          self.assertIsNotNone(result["topic"])
237          self.assertIsNotNone(result["topicrank"])
238  
239      def testIndex(self):
240          """
241          Test index
242          """
243  
244          # Create an index for the list of text
245          self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
246  
247          # Validate counts
248          graph = self.embeddings.graph
249          self.assertEqual(graph.count(), 6)
250          self.assertEqual(graph.edgecount(), 2)
251          self.assertEqual(len(graph.topics), 6)
252          self.assertEqual(len(graph.categories), 6)
253  
254      @patch.dict(os.environ, {"ALLOW_PICKLE": "True"})
255      def testLegacy(self):
256          """
257          Test loading a legacy graph in TAR format
258          """
259  
260          # Create graph
261          graph = GraphFactory.create({})
262          graph.initialize()
263          graph.addedge(0, 1)
264  
265          categories = ["C1"]
266          topics = {"T1": [0, 1]}
267  
268          serializer = SerializeFactory.create("pickle", allowpickle=True)
269  
270          # Save files to temporary directory and combine into TAR
271          path = os.path.join(tempfile.gettempdir(), "graph.tar")
272          with tempfile.TemporaryDirectory() as directory:
273              # Save graph
274              serializer.save(graph.backend, f"{directory}/graph")
275  
276              # Save categories, if necessary
277              serializer.save(categories, f"{directory}/categories")
278  
279              # Save topics, if necessary
280              serializer.save(topics, f"{directory}/topics")
281  
282              # Pack files
283              archive = ArchiveFactory.create(directory)
284              archive.save(path, "tar")
285  
286          # Load loading legacy format
287          graph = GraphFactory.create({})
288          graph.load(path)
289  
290          # Validate graph data is correct
291          self.assertEqual(graph.count(), 2)
292          self.assertEqual(graph.edgecount(), 1)
293          self.assertEqual(graph.topics, topics)
294          self.assertEqual(graph.categories, categories)
295  
296      def testNotImplemented(self):
297          """
298          Test exceptions for non-implemented methods
299          """
300  
301          graph = Graph({})
302  
303          self.assertRaises(NotImplementedError, graph.create)
304          self.assertRaises(NotImplementedError, graph.count)
305          self.assertRaises(NotImplementedError, graph.scan, None)
306          self.assertRaises(NotImplementedError, graph.node, None)
307          self.assertRaises(NotImplementedError, graph.addnode, None)
308          self.assertRaises(NotImplementedError, graph.addnodes, None)
309          self.assertRaises(NotImplementedError, graph.removenode, None)
310          self.assertRaises(NotImplementedError, graph.hasnode, None)
311          self.assertRaises(NotImplementedError, graph.attribute, None, None)
312          self.assertRaises(NotImplementedError, graph.addattribute, None, None, None)
313          self.assertRaises(NotImplementedError, graph.removeattribute, None, None)
314          self.assertRaises(NotImplementedError, graph.edgecount)
315          self.assertRaises(NotImplementedError, graph.edges, None)
316          self.assertRaises(NotImplementedError, graph.addedge, None, None)
317          self.assertRaises(NotImplementedError, graph.addedges, None)
318          self.assertRaises(NotImplementedError, graph.hasedge, None, None)
319          self.assertRaises(NotImplementedError, graph.centrality)
320          self.assertRaises(NotImplementedError, graph.pagerank)
321          self.assertRaises(NotImplementedError, graph.showpath, None, None)
322          self.assertRaises(NotImplementedError, graph.isquery, None)
323          self.assertRaises(NotImplementedError, graph.parse, None)
324          self.assertRaises(NotImplementedError, graph.search, None)
325          self.assertRaises(NotImplementedError, graph.communities, None)
326          self.assertRaises(NotImplementedError, graph.load, None)
327          self.assertRaises(NotImplementedError, graph.save, None)
328          self.assertRaises(NotImplementedError, graph.loaddict, None)
329          self.assertRaises(NotImplementedError, graph.savedict)
330  
331      def testRelationships(self):
332          """
333          Test manually-provided relationships
334          """
335  
336          # Create relationships for id 0
337          relationships = [{"id": f"ID{x}"} for x in range(1, len(self.data))]
338  
339          # Test with content enabled
340          self.embeddings.index({"id": f"ID{i}", "text": x, "relationships": relationships if i == 0 else None} for i, x in enumerate(self.data))
341          self.assertEqual(len(self.embeddings.graph.edges(0)), len(self.data) - 1)
342  
343          # Test with content disabled
344          config = self.config.copy()
345          config["content"] = False
346  
347          embeddings = Embeddings(config)
348          embeddings.index({"id": f"ID{i}", "text": x, "relationships": relationships if i == 0 else None} for i, x in enumerate(self.data))
349          self.assertEqual(len(embeddings.graph.edges(0)), len(self.data) - 1)
350          embeddings.close()
351  
352      def testRelationshipsInvalid(self):
353          """
354          Test manually-provided relationships with no matching id
355          """
356  
357          # Create relationships for id 0
358          relationships = [{"id": "INVALID"}]
359  
360          # Index with invalid relationship
361          self.embeddings.index({"text": x, "relationships": relationships if i == 0 else None} for i, x in enumerate(self.data))
362  
363          # Validate only relationship is semantically-derived
364          edges = list(self.embeddings.graph.edges(0))
365          self.assertTrue(len(edges) == 1 and edges[0] != "INVALID")
366  
367      def testResetTopics(self):
368          """
369          Test resetting of topics
370          """
371  
372          # Create an index for the list of text
373          self.embeddings.index([(1, "text", None)])
374          self.embeddings.upsert([(1, "graph", None)])
375          self.assertEqual(list(self.embeddings.graph.topics.keys()), ["graph"])
376  
377      def testSave(self):
378          """
379          Test save
380          """
381  
382          # Create an index for the list of text
383          self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
384  
385          # Generate temp file path
386          index = os.path.join(tempfile.gettempdir(), "graph")
387  
388          # Save and reload index
389          self.embeddings.save(index)
390          self.embeddings.load(index)
391  
392          # Validate counts
393          graph = self.embeddings.graph
394          self.assertEqual(graph.count(), 6)
395          self.assertEqual(graph.edgecount(), 2)
396          self.assertEqual(sum((len(graph.topics[x]) for x in graph.topics)), 6)
397          self.assertEqual(len(graph.categories), 6)
398  
399      def testSaveDict(self):
400          """
401          Test loading and saving to dictionaries
402          """
403  
404          # Create an index for the list of text
405          self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
406  
407          # Validate counts
408          graph = self.embeddings.graph
409          count, edgecount = graph.count(), graph.edgecount()
410  
411          # Save and reload graph as dict
412          data = graph.savedict()
413          graph.loaddict(data)
414  
415          # Validate counts
416          self.assertEqual(graph.count(), count)
417          self.assertEqual(graph.edgecount(), edgecount)
418  
419      def testSearch(self):
420          """
421          Test search
422          """
423  
424          # Create an index for the list of text
425          self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
426  
427          # Run standard search
428          results = self.embeddings.search(
429              """
430              MATCH (A)-[]->(B)
431              RETURN A, B
432          """
433          )
434          self.assertEqual(len(results), 3)
435  
436          # Run path search
437          results = self.embeddings.search(
438              """
439              MATCH P=()-[]->()
440              RETURN P
441          """
442          )
443          self.assertEqual(len(results), 3)
444  
445          # Run graph search
446          g = self.embeddings.search(
447              """
448              MATCH (A)-[]->(B)
449              RETURN A, ID(B)
450          """,
451              graph=True,
452          )
453          self.assertEqual(g.count(), 3)
454  
455          # Run path search
456          results = self.embeddings.search(
457              """
458              MATCH P=()-[]->()
459              RETURN P
460          """,
461              graph=True,
462          )
463          self.assertEqual(g.count(), 3)
464  
465          # Run similar search
466          results = self.embeddings.search(
467              """
468              MATCH P=(A)-[]->()
469              WHERE SIMILAR(A, "feel good story")
470              RETURN A
471              ORDER BY A.score DESC
472              LIMIT 1
473          """,
474              graph=True,
475          )
476          self.assertEqual(list(results.scan())[0], 4)
477  
478      def testSearchBatch(self):
479          """
480          Test batch search
481          """
482  
483          # Create an index for the list of text
484          self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
485  
486          # Run standard search
487          results = self.embeddings.batchsearch(
488              [
489                  """
490              MATCH (A)-[]->(B)
491              RETURN A, B
492          """
493              ]
494          )
495          self.assertEqual(len(results[0]), 3)
496  
497      def testSimple(self):
498          """
499          Test creating a simple graph
500          """
501  
502          graph = GraphFactory.create({"topics": {}})
503  
504          # Initialize the graph
505          graph.initialize()
506  
507          for x in range(5):
508              graph.addnode(x)
509  
510          for x, y in itertools.combinations(range(5), 2):
511              graph.addedge(x, y)
512  
513          # Validate counts
514          self.assertEqual(graph.count(), 5)
515          self.assertEqual(graph.edgecount(), 10)
516  
517          # Test missing edge
518          self.assertIsNone(graph.edges(100))
519  
520          # Test topics with no text
521          graph.addtopics()
522          self.assertEqual(len(graph.topics), 5)
523  
524      def testSubindex(self):
525          """
526          Test subindex
527          """
528  
529          # Build data array
530          data = [(uid, text, None) for uid, text in enumerate(self.data)]
531  
532          embeddings = Embeddings(
533              {
534                  "content": True,
535                  "functions": [{"name": "graph", "function": "indexes.index1.graph.attribute"}],
536                  "expressions": [
537                      {"name": "category", "expression": "graph(indexid, 'category')"},
538                      {"name": "topic", "expression": "graph(indexid, 'topic')"},
539                      {"name": "topicrank", "expression": "graph(indexid, 'topicrank')"},
540                  ],
541                  "indexes": {
542                      "index1": {
543                          "path": "sentence-transformers/nli-mpnet-base-v2",
544                          "graph": {
545                              "limit": 5,
546                              "minscore": 0.2,
547                              "batchsize": 4,
548                              "approximate": False,
549                              "topics": {"categories": ["News"], "stopwords": ["the"]},
550                          },
551                      }
552                  },
553              }
554          )
555  
556          # Create an index for the list of text
557          embeddings.index(data)
558  
559          # Test function
560          result = embeddings.search("select id, category, topic, topicrank from txtai where id = 0", 1)[0]
561  
562          # Check columns have a value
563          self.assertIsNotNone(result["category"])
564          self.assertIsNotNone(result["topic"])
565          self.assertIsNotNone(result["topicrank"])
566  
567          # Update data
568          data[0] = (0, "Feel good story: lottery winner announced", None)
569          embeddings.upsert([data[0]])
570  
571          # Test function
572          result = embeddings.search("select id, category, topic, topicrank from txtai where id = 0", 1)[0]
573  
574          # Check columns have a value
575          self.assertIsNotNone(result["category"])
576          self.assertIsNotNone(result["topic"])
577          self.assertIsNotNone(result["topicrank"])
578  
579      def testUpsert(self):
580          """
581          Test upsert
582          """
583  
584          # Update data
585          self.embeddings.upsert([(0, {"text": "Canadian ice shelf collapses".split()}, None)])
586  
587          # Validate counts
588          graph = self.embeddings.graph
589          self.assertEqual(graph.count(), 6)
590          self.assertEqual(graph.edgecount(), 2)
591          self.assertEqual(sum((len(graph.topics[x]) for x in graph.topics)), 6)
592          self.assertEqual(len(graph.categories), 6)