/ pyod / test / test_devnet.py
test_devnet.py
  1  # -*- coding: utf-8 -*-
  2  from __future__ import division
  3  from __future__ import print_function
  4  
  5  import os
  6  import sys
  7  import unittest
  8  
  9  import numpy as np
 10  import torch
 11  from numpy.testing import assert_almost_equal
 12  from numpy.testing import assert_equal
 13  from numpy.testing import assert_raises
 14  from sklearn.metrics import roc_auc_score
 15  
 16  # temporary solution for relative imports in case pyod is not installed
 17  # if pyod is installed, no need to use the following line
 18  
 19  sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
 20  
 21  from pyod.models.devnet import DevNet
 22  from pyod.utils.data import generate_data
 23  
 24  
 25  class TestDevNet(unittest.TestCase):
 26      def setUp(self):
 27          self.n_train = 3000
 28          self.n_test = 1500
 29          self.n_features = 2000
 30          self.contamination = 0.1
 31          self.roc_floor = 0.8
 32          self.X_train, self.X_test, self.y_train, self.y_test = generate_data(
 33              n_train=self.n_train, n_test=self.n_test,
 34              n_features=self.n_features, contamination=self.contamination,
 35              random_state=42)
 36  
 37          self.clf = DevNet(epochs=3, contamination=self.contamination)
 38          self.clf.fit(self.X_train, self.y_train)
 39  
 40      def test_parameters(self):
 41          assert (hasattr(self.clf, 'decision_scores_') and
 42                  self.clf.decision_scores_ is not None)
 43          assert (hasattr(self.clf, 'labels_') and
 44                  self.clf.labels_ is not None)
 45          assert (hasattr(self.clf, 'threshold_') and
 46                  self.clf.threshold_ is not None)
 47          assert (hasattr(self.clf, '_mu') and
 48                  self.clf._mu is not None)
 49          assert (hasattr(self.clf, '_sigma') and
 50                  self.clf._sigma is not None)
 51          assert (hasattr(self.clf, 'model') and
 52                  self.clf.model is not None)
 53  
 54      def test_train_scores(self):
 55          assert_equal(len(self.clf.decision_scores_), self.X_train.shape[0])
 56  
 57      def test_prediction_scores(self):
 58          pred_scores = self.clf.decision_function(self.X_test)
 59  
 60          # check score shapes
 61          assert_equal(pred_scores.shape[0], self.X_test.shape[0])
 62  
 63          # check performance
 64          assert (roc_auc_score(self.y_test, pred_scores) >= self.roc_floor)
 65  
 66      def test_prediction_labels(self):
 67          pred_labels = self.clf.predict(self.X_test)
 68          assert_equal(pred_labels.shape, self.y_test.shape)
 69  
 70      def test_prediction_proba(self):
 71          pred_proba = self.clf.predict_proba(self.X_test)
 72          assert (pred_proba.min() >= 0)
 73          assert (pred_proba.max() <= 1)
 74  
 75      def test_prediction_proba_linear(self):
 76          pred_proba = self.clf.predict_proba(self.X_test, method='linear')
 77          assert (pred_proba.min() >= 0)
 78          assert (pred_proba.max() <= 1)
 79  
 80      def test_prediction_proba_unify(self):
 81          pred_proba = self.clf.predict_proba(self.X_test, method='unify')
 82          assert (pred_proba.min() >= 0)
 83          assert (pred_proba.max() <= 1)
 84  
 85      def test_prediction_proba_parameter(self):
 86          with assert_raises(ValueError):
 87              self.clf.predict_proba(self.X_test, method='something')
 88  
 89      def test_prediction_labels_confidence(self):
 90          pred_labels, confidence = self.clf.predict(self.X_test,
 91                                                     return_confidence=True)
 92          assert_equal(pred_labels.shape, self.y_test.shape)
 93          assert_equal(confidence.shape, self.y_test.shape)
 94          assert (confidence.min() >= 0)
 95          assert (confidence.max() <= 1)
 96  
 97      def test_prediction_proba_linear_confidence(self):
 98          pred_proba, confidence = self.clf.predict_proba(self.X_test,
 99                                                          method='linear',
100                                                          return_confidence=True)
101          assert (pred_proba.min() >= 0)
102          assert (pred_proba.max() <= 1)
103  
104          assert_equal(confidence.shape, self.y_test.shape)
105          assert (confidence.min() >= 0)
106          assert (confidence.max() <= 1)
107  
108      def test_prediction_with_rejection(self):
109          pred_labels = self.clf.predict_with_rejection(self.X_test,
110                                                        return_stats=False)
111          assert_equal(pred_labels.shape, self.y_test.shape)
112  
113      def test_prediction_with_rejection_stats(self):
114          _, [expected_rejrate, ub_rejrate,
115              ub_cost] = self.clf.predict_with_rejection(self.X_test,
116                                                         return_stats=True)
117          assert (expected_rejrate >= 0)
118          assert (expected_rejrate <= 1)
119          assert (ub_rejrate >= 0)
120          assert (ub_rejrate <= 1)
121          assert (ub_cost >= 0)
122  
123      def test_fit_predict(self):
124          pred_labels = self.clf.fit_predict(self.X_train, self.y_train)
125          assert_equal(pred_labels.shape, self.y_train.shape)
126  
127      def test_fit_predict_score(self):
128          self.clf.fit_predict_score(self.X_test, self.y_test)
129          self.clf.fit_predict_score(self.X_test, self.y_test,
130                                     scoring='roc_auc_score')
131          self.clf.fit_predict_score(self.X_test, self.y_test,
132                                     scoring='prc_n_score')
133          with assert_raises(NotImplementedError):
134              self.clf.fit_predict_score(self.X_test, self.y_test,
135                                         scoring='something')
136  
137      def test_model_clone(self):
138          pass
139          # clone_clf = clone(self.clf)
140  
141      def tearDown(self):
142          pass
143  
144  
if __name__ == "__main__":
    # Allow running this module directly (``python test_devnet.py``) in
    # addition to discovery by a test runner.
    unittest.main()