/ pyod / test / test_pyg_cola.py
test_pyg_cola.py
 1  # -*- coding: utf-8 -*-
 2  """Tests for CoLA graph anomaly detector."""
 3  
 4  import os
 5  import sys
 6  import unittest
 7  
 8  import numpy as np
 9  
10  sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
11  
12  from pyod.utils.data import generate_graph_data
13  
14  try:
15      import torch
16      from torch_geometric.data import Data
17      from pyod.models.pyg_cola import CoLA
18      HAS_PYG = True
19  except ImportError:
20      HAS_PYG = False
21  
22  
@unittest.skipUnless(HAS_PYG, "torch_geometric not installed")
class TestCoLA(unittest.TestCase):
    """API-contract tests for the transductive CoLA graph anomaly detector.

    Each test trains for only a few epochs on a small synthetic graph, so
    these tests check fitted attributes, accepted input formats and error
    paths rather than detection quality.
    """

    # Synthetic-graph configuration shared by every test, so expected sizes
    # are derived from one place instead of repeated literals.
    N_NODES = 100
    N_FEATURES = 16
    CONTAMINATION = 0.1

    def setUp(self):
        """Build one synthetic graph both as numpy arrays and as a PyG Data."""
        self.X, self.edge_index, self.y = generate_graph_data(
            n_nodes=self.N_NODES, n_features=self.N_FEATURES,
            contamination=self.CONTAMINATION, random_state=42)
        # torch.tensor with an explicit dtype replaces the legacy
        # FloatTensor / LongTensor constructors; like them, it copies.
        self.data = Data(
            x=torch.tensor(self.X, dtype=torch.float),
            edge_index=torch.tensor(self.edge_index, dtype=torch.long))

    def _assert_fitted(self, clf):
        """Assert the fitted attributes exposed after a transductive fit."""
        for attr in ('decision_scores_', 'labels_', 'threshold_'):
            self.assertTrue(hasattr(clf, attr), attr)
        self.assertEqual(len(clf.decision_scores_), self.N_NODES)
        self.assertEqual(len(clf.labels_), self.N_NODES)
        # PyOD's labels_ are binary: 0 = inlier, 1 = outlier.
        observed = set(np.unique(np.asarray(clf.labels_)).tolist())
        self.assertTrue(observed <= {0, 1}, observed)

    def test_fit_pyg_data(self):
        """Fitting on a torch_geometric ``Data`` object scores every node."""
        clf = CoLA(hidden_dim=32, num_layers=2, epochs=5,
                   contamination=self.CONTAMINATION)
        clf.fit(self.data)
        self._assert_fitted(clf)

    def test_fit_numpy(self):
        """Fitting on numpy features plus an explicit ``edge_index`` works."""
        clf = CoLA(hidden_dim=32, epochs=5, contamination=self.CONTAMINATION)
        clf.fit(self.X, edge_index=self.edge_index)
        self._assert_fitted(clf)

    def test_transductive_no_decision_function(self):
        """CoLA is transductive: scoring new data is not supported."""
        clf = CoLA(hidden_dim=32, epochs=5)
        clf.fit(self.data)
        with self.assertRaises(NotImplementedError):
            clf.decision_function(self.data)

    def test_transductive_no_predict(self):
        """CoLA is transductive: predicting on new data is not supported."""
        clf = CoLA(hidden_dim=32, epochs=5)
        clf.fit(self.data)
        with self.assertRaises(NotImplementedError):
            clf.predict(self.data)

    def test_no_features_raises(self):
        """CoLA requires node features; a featureless graph is rejected."""
        data_no_feat = Data(
            edge_index=torch.tensor(self.edge_index, dtype=torch.long),
            num_nodes=self.N_NODES)
        clf = CoLA(hidden_dim=32, epochs=5)
        with self.assertRaises(ValueError):
            clf.fit(data_no_feat)
67  
68  
if __name__ == "__main__":
    # Allow running this module directly, e.g. ``python test_pyg_cola.py``.
    unittest.main()