testgraph.py
"""
Graph module tests
"""

import os
import itertools
import tempfile
import unittest

from unittest.mock import patch

from txtai.archive import ArchiveFactory
from txtai.embeddings import Embeddings
from txtai.graph import Graph, GraphFactory
from txtai.serialize import SerializeFactory


# pylint: disable=R0904
class TestGraph(unittest.TestCase):
    """
    Graph tests.

    Most tests share a single class-level Embeddings instance (cls.embeddings) that is
    re-indexed at the start of each test with the rows in cls.data.
    """

    @classmethod
    def setUpClass(cls):
        """
        Initialize test data.

        Builds the shared list of text rows, the embeddings configuration (with graph
        settings and SQL graph functions) and the shared Embeddings instance.
        """

        cls.data = [
            "US tops 5 million confirmed virus cases",
            "Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg",
            "Beijing mobilises invasion craft along coast as Taiwan tensions escalate",
            "The National Park Service warns against sacrificing slower friends in a bear attack",
            "Maine man wins $1M from $25 lottery ticket",
            "Make huge profits without work, earn up to $100,000 a day",
        ]

        cls.config = {
            "path": "sentence-transformers/nli-mpnet-base-v2",
            "content": True,
            "functions": [{"name": "graph", "function": "graph.attribute"}],
            "expressions": [
                {"name": "category", "expression": "graph(indexid, 'category')"},
                {"name": "topic", "expression": "graph(indexid, 'topic')"},
                {"name": "topicrank", "expression": "graph(indexid, 'topicrank')"},
            ],
            "graph": {"limit": 5, "minscore": 0.2, "batchsize": 4, "approximate": False, "topics": {"categories": ["News"], "stopwords": ["the"]}},
        }

        # Create embeddings instance
        cls.embeddings = Embeddings(cls.config)

    @classmethod
    def tearDownClass(cls):
        """
        Cleanup data.

        Closes the shared Embeddings instance created in setUpClass so its resources
        are released once all tests in this class have run.
        """

        if cls.embeddings:
            cls.embeddings.close()

    def testAnalysis(self):
        """
        Test analysis methods
        """

        # Create an index for the list of text
        self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

        # Graph centrality - first key is expected to be node 5
        graph = self.embeddings.graph
        centrality = graph.centrality()
        self.assertEqual(next(iter(centrality)), 5)

        # Page Rank - first key is expected to be node 5
        pagerank = graph.pagerank()
        self.assertEqual(next(iter(pagerank)), 5)

        # Path between nodes
        path = graph.showpath(4, 5)
        self.assertEqual(len(path), 2)

    def testCommunity(self):
        """
        Test community detection
        """

        # Create an index for the list of text
        self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

        # Get graph reference
        graph = self.embeddings.graph

        # Rebuild topics with updated graph settings
        graph.config = {"topics": {"algorithm": "greedy"}}
        graph.addtopics()
        self.assertEqual(sum(len(ids) for ids in graph.topics.values()), 6)

        graph.config = {"topics": {"algorithm": "lpa"}}
        graph.addtopics()
        self.assertEqual(sum(len(ids) for ids in graph.topics.values()), 4)

    def testCustomBackend(self):
        """
        Test resolving a custom backend
        """

        graph = GraphFactory.create({"backend": "txtai.graph.NetworkX"})
        graph.initialize()
        self.assertIsNotNone(graph)

    def testCustomBackendNotFound(self):
        """
        Test resolving an unresolvable backend
        """

        with self.assertRaises(ImportError):
            graph = GraphFactory.create({"backend": "notfound.graph"})
            graph.initialize()

    def testDatabase(self):
        """
        Test creating a Graph backed by a relational database
        """

        # Generate graph database
        path = os.path.join(tempfile.gettempdir(), "graph.sqlite")
        graph = GraphFactory.create({"backend": "rdbms", "url": f"sqlite:///{path}", "schema": "txtai"})

        # Initialize the graph
        graph.initialize()

        for x in range(5):
            graph.addnode(x, field=x)

        for x, y in itertools.combinations(range(5), 2):
            graph.addedge(x, y)

        # Test methods
        self.assertEqual(list(graph.scan()), [str(x) for x in range(5)])
        self.assertEqual(list(graph.scan(attribute="field")), [str(x) for x in range(5)])
        self.assertEqual(list(graph.filter([0]).scan()), [0])

        # Test save/load
        graph.save(None)
        graph.load(None)
        self.assertEqual(list(graph.scan()), [str(x) for x in range(5)])

        # Test remove node
        graph.delete([0])
        self.assertFalse(graph.hasnode(0))
        self.assertFalse(graph.hasedge(0))

        # Close graph
        graph.close()

    def testDefault(self):
        """
        Test embeddings default graph setting
        """

        embeddings = Embeddings(content=True, graph=True)

        # Release the local instance even if an assertion below fails
        self.addCleanup(embeddings.close)

        embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

        self.assertEqual(embeddings.graph.count(), len(self.data))

    def testDelete(self):
        """
        Test delete
        """

        # Create an index for the list of text
        self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

        # Delete row
        self.embeddings.delete([4])

        # Validate counts
        graph = self.embeddings.graph
        self.assertEqual(graph.count(), 5)
        self.assertEqual(graph.edgecount(), 1)
        self.assertEqual(sum(len(ids) for ids in graph.topics.values()), 5)
        self.assertEqual(len(graph.categories), 6)

    def testEdges(self):
        """
        Test edges
        """

        # Create graph
        graph = GraphFactory.create({})
        graph.initialize()
        graph.addedge(0, 1)

        # Test edge exists
        self.assertTrue(graph.hasedge(0))
        self.assertTrue(graph.hasedge(0, 1))

    def testFilter(self):
        """
        Test creating filtered subgraphs
        """

        # Create an index for the list of text
        self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

        # Validate counts
        graph = self.embeddings.search("feel good story", graph=True)
        self.assertEqual(graph.count(), 3)
        self.assertEqual(graph.edgecount(), 2)

    def testFunction(self):
        """
        Test running graph functions with SQL
        """

        # Create an index for the list of text
        self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

        # Test function
        result = self.embeddings.search("select category, topic, topicrank from txtai where id = 0", 1)[0]

        # Check columns have a value
        self.assertIsNotNone(result["category"])
        self.assertIsNotNone(result["topic"])
        self.assertIsNotNone(result["topicrank"])

    def testFunctionReindex(self):
        """
        Test running graph functions with SQL after reindex
        """

        # Create an index for the list of text
        self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

        # Test functions reset with a reindex
        self.embeddings.reindex(self.embeddings.config)

        # Test function
        result = self.embeddings.search("select category, topic, topicrank from txtai where id = 0", 1)[0]

        # Check columns have a value
        self.assertIsNotNone(result["category"])
        self.assertIsNotNone(result["topic"])
        self.assertIsNotNone(result["topicrank"])

    def testIndex(self):
        """
        Test index
        """

        # Create an index for the list of text
        self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

        # Validate counts
        graph = self.embeddings.graph
        self.assertEqual(graph.count(), 6)
        self.assertEqual(graph.edgecount(), 2)
        self.assertEqual(len(graph.topics), 6)
        self.assertEqual(len(graph.categories), 6)

    @patch.dict(os.environ, {"ALLOW_PICKLE": "True"})
    def testLegacy(self):
        """
        Test loading a legacy graph in TAR format
        """

        # Create graph
        graph = GraphFactory.create({})
        graph.initialize()
        graph.addedge(0, 1)

        categories = ["C1"]
        topics = {"T1": [0, 1]}

        # Pickle is only used here to write test fixtures generated by this test itself
        serializer = SerializeFactory.create("pickle", allowpickle=True)

        # Save files to temporary directory and combine into TAR
        path = os.path.join(tempfile.gettempdir(), "graph.tar")
        with tempfile.TemporaryDirectory() as directory:
            # Save graph
            serializer.save(graph.backend, f"{directory}/graph")

            # Save categories
            serializer.save(categories, f"{directory}/categories")

            # Save topics
            serializer.save(topics, f"{directory}/topics")

            # Pack files
            archive = ArchiveFactory.create(directory)
            archive.save(path, "tar")

        # Test loading legacy format
        graph = GraphFactory.create({})
        graph.load(path)

        # Validate graph data is correct
        self.assertEqual(graph.count(), 2)
        self.assertEqual(graph.edgecount(), 1)
        self.assertEqual(graph.topics, topics)
        self.assertEqual(graph.categories, categories)

    def testNotImplemented(self):
        """
        Test exceptions for non-implemented methods
        """

        graph = Graph({})

        self.assertRaises(NotImplementedError, graph.create)
        self.assertRaises(NotImplementedError, graph.count)
        self.assertRaises(NotImplementedError, graph.scan, None)
        self.assertRaises(NotImplementedError, graph.node, None)
        self.assertRaises(NotImplementedError, graph.addnode, None)
        self.assertRaises(NotImplementedError, graph.addnodes, None)
        self.assertRaises(NotImplementedError, graph.removenode, None)
        self.assertRaises(NotImplementedError, graph.hasnode, None)
        self.assertRaises(NotImplementedError, graph.attribute, None, None)
        self.assertRaises(NotImplementedError, graph.addattribute, None, None, None)
        self.assertRaises(NotImplementedError, graph.removeattribute, None, None)
        self.assertRaises(NotImplementedError, graph.edgecount)
        self.assertRaises(NotImplementedError, graph.edges, None)
        self.assertRaises(NotImplementedError, graph.addedge, None, None)
        self.assertRaises(NotImplementedError, graph.addedges, None)
        self.assertRaises(NotImplementedError, graph.hasedge, None, None)
        self.assertRaises(NotImplementedError, graph.centrality)
        self.assertRaises(NotImplementedError, graph.pagerank)
        self.assertRaises(NotImplementedError, graph.showpath, None, None)
        self.assertRaises(NotImplementedError, graph.isquery, None)
        self.assertRaises(NotImplementedError, graph.parse, None)
        self.assertRaises(NotImplementedError, graph.search, None)
        self.assertRaises(NotImplementedError, graph.communities, None)
        self.assertRaises(NotImplementedError, graph.load, None)
        self.assertRaises(NotImplementedError, graph.save, None)
        self.assertRaises(NotImplementedError, graph.loaddict, None)
        self.assertRaises(NotImplementedError, graph.savedict)

    def testRelationships(self):
        """
        Test manually-provided relationships
        """

        # Create relationships for id 0
        relationships = [{"id": f"ID{x}"} for x in range(1, len(self.data))]

        # Test with content enabled
        self.embeddings.index({"id": f"ID{i}", "text": x, "relationships": relationships if i == 0 else None} for i, x in enumerate(self.data))
        self.assertEqual(len(self.embeddings.graph.edges(0)), len(self.data) - 1)

        # Test with content disabled
        config = self.config.copy()
        config["content"] = False

        embeddings = Embeddings(config)

        # Release the local instance even if an assertion below fails
        self.addCleanup(embeddings.close)

        embeddings.index({"id": f"ID{i}", "text": x, "relationships": relationships if i == 0 else None} for i, x in enumerate(self.data))
        self.assertEqual(len(embeddings.graph.edges(0)), len(self.data) - 1)

    def testRelationshipsInvalid(self):
        """
        Test manually-provided relationships with no matching id
        """

        # Create relationships for id 0
        relationships = [{"id": "INVALID"}]

        # Index with invalid relationship
        self.embeddings.index({"text": x, "relationships": relationships if i == 0 else None} for i, x in enumerate(self.data))

        # Validate only relationship is semantically-derived
        # Separate assertions report which condition failed
        edges = list(self.embeddings.graph.edges(0))
        self.assertEqual(len(edges), 1)
        self.assertNotEqual(edges[0], "INVALID")

    def testResetTopics(self):
        """
        Test resetting of topics
        """

        # Index a single row, then replace its text with an upsert
        self.embeddings.index([(1, "text", None)])
        self.embeddings.upsert([(1, "graph", None)])
        self.assertEqual(list(self.embeddings.graph.topics.keys()), ["graph"])

    def testSave(self):
        """
        Test save
        """

        # Create an index for the list of text
        self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

        # Generate temp file path
        index = os.path.join(tempfile.gettempdir(), "graph")

        # Save and reload index
        self.embeddings.save(index)
        self.embeddings.load(index)

        # Validate counts
        graph = self.embeddings.graph
        self.assertEqual(graph.count(), 6)
        self.assertEqual(graph.edgecount(), 2)
        self.assertEqual(sum(len(ids) for ids in graph.topics.values()), 6)
        self.assertEqual(len(graph.categories), 6)

    def testSaveDict(self):
        """
        Test loading and saving to dictionaries
        """

        # Create an index for the list of text
        self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

        # Capture counts before the round trip
        graph = self.embeddings.graph
        count, edgecount = graph.count(), graph.edgecount()

        # Save and reload graph as dict
        data = graph.savedict()
        graph.loaddict(data)

        # Validate counts
        self.assertEqual(graph.count(), count)
        self.assertEqual(graph.edgecount(), edgecount)

    def testSearch(self):
        """
        Test search
        """

        # Create an index for the list of text
        self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

        # Run standard search
        results = self.embeddings.search(
            """
            MATCH (A)-[]->(B)
            RETURN A, B
        """
        )
        self.assertEqual(len(results), 3)

        # Run path search
        results = self.embeddings.search(
            """
            MATCH P=()-[]->()
            RETURN P
        """
        )
        self.assertEqual(len(results), 3)

        # Run graph search
        g = self.embeddings.search(
            """
            MATCH (A)-[]->(B)
            RETURN A, ID(B)
        """,
            graph=True,
        )
        self.assertEqual(g.count(), 3)

        # Run path search
        results = self.embeddings.search(
            """
            MATCH P=()-[]->()
            RETURN P
        """,
            graph=True,
        )

        # NOTE(review): this re-checks g from the previous query; the graph stored in results
        # is not validated here. Confirm whether results.count() was intended.
        self.assertEqual(g.count(), 3)

        # Run similar search
        results = self.embeddings.search(
            """
            MATCH P=(A)-[]->()
            WHERE SIMILAR(A, "feel good story")
            RETURN A
            ORDER BY A.score DESC
            LIMIT 1
        """,
            graph=True,
        )
        self.assertEqual(list(results.scan())[0], 4)

    def testSearchBatch(self):
        """
        Test batch search
        """

        # Create an index for the list of text
        self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

        # Run standard search
        results = self.embeddings.batchsearch(
            [
                """
            MATCH (A)-[]->(B)
            RETURN A, B
        """
            ]
        )
        self.assertEqual(len(results[0]), 3)

    def testSimple(self):
        """
        Test creating a simple graph
        """

        graph = GraphFactory.create({"topics": {}})

        # Initialize the graph
        graph.initialize()

        for x in range(5):
            graph.addnode(x)

        for x, y in itertools.combinations(range(5), 2):
            graph.addedge(x, y)

        # Validate counts
        self.assertEqual(graph.count(), 5)
        self.assertEqual(graph.edgecount(), 10)

        # Test missing edge
        self.assertIsNone(graph.edges(100))

        # Test topics with no text
        graph.addtopics()
        self.assertEqual(len(graph.topics), 5)

    def testSubindex(self):
        """
        Test subindex
        """

        # Build data array
        data = [(uid, text, None) for uid, text in enumerate(self.data)]

        embeddings = Embeddings(
            {
                "content": True,
                "functions": [{"name": "graph", "function": "indexes.index1.graph.attribute"}],
                "expressions": [
                    {"name": "category", "expression": "graph(indexid, 'category')"},
                    {"name": "topic", "expression": "graph(indexid, 'topic')"},
                    {"name": "topicrank", "expression": "graph(indexid, 'topicrank')"},
                ],
                "indexes": {
                    "index1": {
                        "path": "sentence-transformers/nli-mpnet-base-v2",
                        "graph": {
                            "limit": 5,
                            "minscore": 0.2,
                            "batchsize": 4,
                            "approximate": False,
                            "topics": {"categories": ["News"], "stopwords": ["the"]},
                        },
                    }
                },
            }
        )

        # Release the local instance even if an assertion below fails
        self.addCleanup(embeddings.close)

        # Create an index for the list of text
        embeddings.index(data)

        # Test function
        result = embeddings.search("select id, category, topic, topicrank from txtai where id = 0", 1)[0]

        # Check columns have a value
        self.assertIsNotNone(result["category"])
        self.assertIsNotNone(result["topic"])
        self.assertIsNotNone(result["topicrank"])

        # Update data
        data[0] = (0, "Feel good story: lottery winner announced", None)
        embeddings.upsert([data[0]])

        # Test function
        result = embeddings.search("select id, category, topic, topicrank from txtai where id = 0", 1)[0]

        # Check columns have a value
        self.assertIsNotNone(result["category"])
        self.assertIsNotNone(result["topic"])
        self.assertIsNotNone(result["topicrank"])

    def testUpsert(self):
        """
        Test upsert
        """

        # NOTE(review): no index is built in this test; the expected counts below presumably
        # rely on rows indexed by earlier tests on the shared embeddings instance - confirm.

        # Update data
        self.embeddings.upsert([(0, {"text": "Canadian ice shelf collapses".split()}, None)])

        # Validate counts
        graph = self.embeddings.graph
        self.assertEqual(graph.count(), 6)
        self.assertEqual(graph.edgecount(), 2)
        self.assertEqual(sum(len(ids) for ids in graph.topics.values()), 6)
        self.assertEqual(len(graph.categories), 6)