# pyod/models/pyg_anomalydae.py
  1  # -*- coding: utf-8 -*-
  2  """AnomalyDAE: Dual Autoencoder for Anomaly Detection.
  3  
  4  Attention-based structure encoder (GAT) and MLP attribute encoder,
  5  with separate decoders for each modality. Per-node anomaly score
  6  is the weighted sum of structure and attribute reconstruction error.
  7  
  8  See :cite:`fan2020anomalydae` for details.
  9  
 10  Reference:
 11      Fan, H., Zhang, F. and Li, Z., 2020. AnomalyDAE: Dual Autoencoder
 12      for Anomaly Detection on Attributed Networks. In CIKM, pp. 747-756.
 13  """
 14  # Author: Yue Zhao <yzhao062@gmail.com>
 15  # License: BSD 2 clause
 16  
 17  import numpy as np
 18  from sklearn.utils.validation import check_is_fitted
 19  
 20  from .base import BaseDetector
 21  from ._pyg_utils import validate_graph_input
 22  
 23  
 24  class AnomalyDAE(BaseDetector):
 25      """AnomalyDAE: Dual Autoencoder for Anomaly Detection.
 26  
 27      Uses GATConv for structure encoding and an MLP for attribute
 28      encoding. Reconstructs adjacency via inner product and
 29      attributes via an MLP decoder.
 30  
 31      This detector is **transductive**.
 32  
 33      Parameters
 34      ----------
 35      embed_dim : int, default=64
 36          Embedding dimension.
 37  
 38      num_heads : int, default=4
 39          Number of attention heads in GAT.
 40  
 41      alpha : float, default=0.5
 42          Weight for structure loss.
 43  
 44      dropout : float, default=0.3
 45          Dropout rate.
 46  
 47      epochs : int, default=100
 48          Number of training epochs.
 49  
 50      lr : float, default=5e-3
 51          Learning rate.
 52  
 53      contamination : float, default=0.1
 54          Expected proportion of anomalies.
 55  
 56      Attributes
 57      ----------
 58      decision_scores_ : numpy array of shape (n_nodes,)
 59      labels_ : numpy array of shape (n_nodes,)
 60      threshold_ : float
 61      """
 62  
 63      def __init__(self, embed_dim=64, num_heads=4, alpha=0.5,
 64                   dropout=0.3, epochs=100, lr=5e-3,
 65                   contamination=0.1):
 66          super(AnomalyDAE, self).__init__(contamination=contamination)
 67          self.embed_dim = embed_dim
 68          self.num_heads = num_heads
 69          self.alpha = alpha
 70          self.dropout = dropout
 71          self.epochs = epochs
 72          self.lr = lr
 73  
 74      def fit(self, X, y=None, edge_index=None):
 75          """Fit the detector on graph data.
 76  
 77          Parameters
 78          ----------
 79          X : Data or array-like
 80          y : ignored
 81          edge_index : array-like or None
 82  
 83          Returns
 84          -------
 85          self
 86          """
 87          import torch
 88          import torch.nn as nn
 89          from torch_geometric.nn import GATConv
 90  
 91          data = validate_graph_input(X, edge_index)
 92          n_nodes = data.num_nodes
 93          self._set_n_classes(y)
 94  
 95          if data.x is None:
 96              raise ValueError(
 97                  "AnomalyDAE requires node features (data.x).")
 98  
 99          in_dim = data.x.shape[1]
100  
101          model = _AnomalyDAEModel(
102              in_dim, self.embed_dim, self.num_heads, self.dropout)
103          optimizer = torch.optim.Adam(model.parameters(), lr=self.lr)
104  
105          x = data.x
106          ei = data.edge_index
107  
108          adj = torch.zeros(n_nodes, n_nodes)
109          adj[ei[0], ei[1]] = 1.0
110  
111          model.train()
112          for epoch in range(self.epochs):
113              a_hat, x_hat = model(x, ei)
114              struct_loss = torch.sum((adj - a_hat) ** 2, dim=1)
115              attr_loss = torch.sum((x - x_hat) ** 2, dim=1)
116              loss = torch.mean(
117                  self.alpha * struct_loss
118                  + (1 - self.alpha) * attr_loss)
119  
120              optimizer.zero_grad()
121              loss.backward()
122              optimizer.step()
123  
124          model.eval()
125          with torch.no_grad():
126              a_hat, x_hat = model(x, ei)
127              struct_err = torch.sum((adj - a_hat) ** 2, dim=1)
128              attr_err = torch.sum((x - x_hat) ** 2, dim=1)
129              scores = (self.alpha * struct_err
130                        + (1 - self.alpha) * attr_err)
131  
132          self.decision_scores_ = scores.cpu().numpy()
133          self._process_decision_scores()
134          return self
135  
136      def decision_function(self, X):
137          """Not supported (transductive detector)."""
138          raise NotImplementedError(
139              "AnomalyDAE is a transductive detector. Use "
140              "decision_scores_ after fit().")
141  
142      def predict(self, X, return_confidence=False):
143          """Not supported (transductive detector)."""
144          raise NotImplementedError(
145              "AnomalyDAE is a transductive detector. Use labels_ "
146              "after fit().")
147  
148      def predict_proba(self, X, method="linear", return_confidence=False):
149          """Not supported (transductive detector)."""
150          raise NotImplementedError("AnomalyDAE is a transductive detector.")
151  
152      def predict_confidence(self, X):
153          """Not supported (transductive detector)."""
154          raise NotImplementedError("AnomalyDAE is a transductive detector.")
155  
156  
def _AnomalyDAEModel(in_dim, embed_dim, num_heads, dropout):
    """Factory: returns a torch.nn.Module for the AnomalyDAE."""
    import torch
    import torch.nn as nn
    from torch_geometric.nn import GATConv

    class _Model(nn.Module):
        """Dual-branch autoencoder: GAT structure branch + MLP attribute branch."""

        def __init__(self):
            super().__init__()
            # Structure branch: multi-head attention; with concat=False
            # the head outputs are averaged, keeping width at embed_dim.
            self.gat = GATConv(
                in_dim, embed_dim, heads=num_heads, dropout=dropout,
                concat=False)
            # Attribute branch: one-hidden-layer MLP encoder.
            self.attr_encoder = nn.Sequential(
                nn.Linear(in_dim, embed_dim),
                nn.ReLU(),
            )
            # Projects the fused embedding back to feature space.
            self.attr_decoder = nn.Linear(embed_dim, in_dim)

        def forward(self, x, edge_index):
            """Return (a_hat, x_hat): adjacency and attribute reconstructions."""
            h_struct = self.gat(x, edge_index)  # (n_nodes, embed_dim)
            h_attr = self.attr_encoder(x)       # (n_nodes, embed_dim)
            # Fuse the two views by averaging.
            fused = 0.5 * (h_struct + h_attr)
            # Inner-product decoder for edges; linear decoder for features.
            adj_rec = torch.sigmoid(torch.mm(fused, fused.t()))
            feat_rec = self.attr_decoder(fused)
            return adj_rec, feat_rec

    return _Model()