/ pyod / models / pyg_conad.py
pyg_conad.py
  1  # -*- coding: utf-8 -*-
  2  """CONAD: Contrastive Attributed Network Anomaly Detection.
  3  
  4  Constructs an anomalous view by injecting synthetic anomalies
  5  (attribute swapping + random edge injection), then contrasts
  6  with the original view. Dual reconstruction (structure +
  7  attributes) from the original view. Nodes with high contrastive
  8  distance and high reconstruction error are anomalous.
  9  
 10  See :cite:`xu2022conad` for details.
 11  
 12  Reference:
 13      Xu, Z., Huang, X., Zhao, Y., Dong, Y., and Li, J., 2022.
 14      Contrastive Attributed Network Anomaly Detection with Data
 15      Augmentation. In PAKDD, pp. 444-457.
 16  """
 17  # Author: Yue Zhao <yzhao062@gmail.com>
 18  # License: BSD 2 clause
 19  
 20  import numpy as np
 21  from sklearn.utils.validation import check_is_fitted
 22  
 23  from .base import BaseDetector
 24  from ._pyg_utils import validate_graph_input
 25  
 26  
 27  class CONAD(BaseDetector):
 28      """CONAD: Contrastive + Reconstruction Anomaly Detection.
 29  
 30      Constructs an anomalous view via attribute swapping and
 31      random edge injection, encodes both views with a shared
 32      GCN, and scores nodes by contrastive distance + dual
 33      (structure + attribute) reconstruction error.
 34  
 35      This detector is **transductive**.
 36  
 37      Parameters
 38      ----------
 39      hidden_dim : int, default=64
 40          Hidden dimension.
 41  
 42      num_layers : int, default=2
 43          Number of GCN layers.
 44  
 45      aug_ratio : float, default=0.2
 46          Fraction of edges/attributes to drop/mask.
 47  
 48      alpha : float, default=0.5
 49          Weight for reconstruction loss (vs contrastive).
 50  
 51      dropout : float, default=0.3
 52          Dropout rate.
 53  
 54      epochs : int, default=100
 55          Training epochs.
 56  
 57      lr : float, default=1e-3
 58          Learning rate.
 59  
 60      contamination : float, default=0.1
 61          Expected proportion of anomalies.
 62  
 63      Attributes
 64      ----------
 65      decision_scores_ : numpy array of shape (n_nodes,)
 66      labels_ : numpy array of shape (n_nodes,)
 67      threshold_ : float
 68      """
 69  
 70      def __init__(self, hidden_dim=64, num_layers=2, aug_ratio=0.2,
 71                   alpha=0.5, dropout=0.3, epochs=100, lr=1e-3,
 72                   contamination=0.1):
 73          super(CONAD, self).__init__(contamination=contamination)
 74          self.hidden_dim = hidden_dim
 75          self.num_layers = num_layers
 76          self.aug_ratio = aug_ratio
 77          self.alpha = alpha
 78          self.dropout = dropout
 79          self.epochs = epochs
 80          self.lr = lr
 81  
 82      def fit(self, X, y=None, edge_index=None):
 83          """Fit the detector on graph data.
 84  
 85          Parameters
 86          ----------
 87          X : Data or array-like
 88          y : ignored
 89          edge_index : array-like or None
 90  
 91          Returns
 92          -------
 93          self
 94          """
 95          import torch
 96          import torch.nn as nn
 97          import torch.nn.functional as F
 98          from torch_geometric.nn import GCNConv
 99  
100          data = validate_graph_input(X, edge_index)
101          n_nodes = data.num_nodes
102          self._set_n_classes(y)
103  
104          if data.x is None:
105              raise ValueError("CONAD requires node features (data.x).")
106  
107          in_dim = data.x.shape[1]
108  
109          model = _CONADModel(
110              in_dim, self.hidden_dim, self.num_layers, self.dropout)
111          optimizer = torch.optim.Adam(model.parameters(), lr=self.lr)
112  
113          x = data.x
114          ei = data.edge_index
115  
116          # Dense adjacency for structure reconstruction
117          adj = torch.zeros(n_nodes, n_nodes)
118          adj[ei[0], ei[1]] = 1.0
119  
120          model.train()
121          for epoch in range(self.epochs):
122              # Create anomalous view (inject synthetic anomalies)
123              x_aug, ei_aug = _create_anomalous_view(
124                  x, ei, self.aug_ratio)
125  
126              z_orig, z_aug, x_hat, a_hat = model(
127                  x, ei, x_aug, ei_aug)
128  
129              # Contrastive loss between original and anomalous views
130              z_o = F.normalize(z_orig, dim=1)
131              z_a = F.normalize(z_aug, dim=1)
132              cos_sim = (z_o * z_a).sum(dim=1)
133              contrastive_loss = -cos_sim.mean()
134  
135              # Dual reconstruction: structure + attributes
136              struct_loss = torch.mean(
137                  torch.sum((adj - a_hat) ** 2, dim=1))
138              attr_loss = torch.mean(
139                  torch.sum((x - x_hat) ** 2, dim=1))
140              recon_loss = struct_loss + attr_loss
141  
142              loss = contrastive_loss + self.alpha * recon_loss
143  
144              optimizer.zero_grad()
145              loss.backward()
146              optimizer.step()
147  
148          # Score: contrastive distance + dual reconstruction error
149          model.eval()
150          with torch.no_grad():
151              x_aug, ei_aug = _create_anomalous_view(
152                  x, ei, self.aug_ratio)
153              z_orig, z_aug, x_hat, a_hat = model(
154                  x, ei, x_aug, ei_aug)
155  
156              z_o = F.normalize(z_orig, dim=1)
157              z_a = F.normalize(z_aug, dim=1)
158              cos_dist = 1.0 - (z_o * z_a).sum(dim=1)
159              struct_err = torch.sum((adj - a_hat) ** 2, dim=1)
160              attr_err = torch.sum((x - x_hat) ** 2, dim=1)
161              scores = cos_dist + self.alpha * (struct_err + attr_err)
162  
163          self.decision_scores_ = scores.cpu().numpy()
164          self._process_decision_scores()
165          return self
166  
167      def decision_function(self, X):
168          """Not supported (transductive detector)."""
169          raise NotImplementedError(
170              "CONAD is a transductive detector. Use decision_scores_ "
171              "after fit().")
172  
173      def predict(self, X, return_confidence=False):
174          """Not supported (transductive detector)."""
175          raise NotImplementedError(
176              "CONAD is a transductive detector. Use labels_ "
177              "after fit().")
178  
179      def predict_proba(self, X, method="linear", return_confidence=False):
180          """Not supported (transductive detector)."""
181          raise NotImplementedError("CONAD is a transductive detector.")
182  
183      def predict_confidence(self, X):
184          """Not supported (transductive detector)."""
185          raise NotImplementedError("CONAD is a transductive detector.")
186  
187  
def _create_anomalous_view(x, edge_index, ratio):
    """Create an anomalous view by injecting synthetic anomalies.

    Following the CONAD paper: the attributes of a random subset of
    nodes are overwritten with the attributes of other randomly chosen
    nodes, and random edges are appended to the edge list. The inputs
    are never modified in place.

    Randomness comes from torch's global RNG; injected edges may be
    self-loops or duplicates of existing edges.

    Parameters
    ----------
    x : torch.Tensor of shape (n, d)
        Node attribute matrix.
    edge_index : torch.LongTensor of shape (2, m)
        Edge list of the original graph.
    ratio : float
        Perturbation ratio. ``max(1, int(n * ratio))`` nodes are
        perturbed and ``max(1, int(m * ratio))`` edges are added, so a
        non-empty graph always receives at least one of each.

    Returns
    -------
    x_aug : torch.Tensor of shape (n, d)
        Perturbed copy of ``x``.
    edge_index_aug : torch.LongTensor of shape (2, m + n_add)
        ``edge_index`` followed by the injected edges. For an empty
        graph (``n == 0``) an unperturbed copy is returned instead.
    """
    import torch

    n = x.shape[0]
    x_aug = x.clone()

    if n == 0:
        # No nodes to perturb or connect; torch.randint(0, 0, ...)
        # would raise, so hand back unperturbed copies.
        return x_aug, edge_index.clone()

    # Attribute perturbation: overwrite the selected rows with the
    # attributes of randomly chosen nodes. Indices are created on the
    # tensor's own device so non-CPU inputs work.
    n_perturb = max(1, int(n * ratio))
    perturb_idx = torch.randperm(n, device=x.device)[:n_perturb]
    swap_idx = torch.randperm(n, device=x.device)[:n_perturb]
    x_aug[perturb_idx] = x[swap_idx]

    # Edge perturbation: append random edges, built on the same device
    # as edge_index so the concatenation never mixes devices.
    n_edges = edge_index.shape[1]
    n_add = max(1, int(n_edges * ratio))
    new_src = torch.randint(0, n, (n_add,), device=edge_index.device)
    new_dst = torch.randint(0, n, (n_add,), device=edge_index.device)
    ei_aug = torch.cat(
        [edge_index, torch.stack([new_src, new_dst])], dim=1)

    return x_aug, ei_aug
225  
226  
def _CONADModel(in_dim, hid_dim, num_layers, dropout):
    """Build the CONAD network and return it as a ``torch.nn.Module``.

    The module is defined inside this factory so that torch and
    torch_geometric are only imported when a model is requested.

    Parameters
    ----------
    in_dim : int
        Number of input node features.
    hid_dim : int
        Width of every GCN layer output (and of the embeddings).
    num_layers : int
        Number of GCN layers; one layer is always created, and each
        unit above one adds a hidden-to-hidden layer.
    dropout : float
        Dropout probability applied after every layer but the last.

    Returns
    -------
    torch.nn.Module
        Calling it as ``model(x, ei, x_aug, ei_aug)`` yields
        ``(z_orig, z_aug, x_hat, a_hat)``.
    """
    import torch
    import torch.nn as nn
    from torch_geometric.nn import GCNConv

    class _Model(nn.Module):
        def __init__(self):
            super().__init__()
            # The first layer consumes raw features; all later layers
            # consume hidden representations.
            layer_inputs = [in_dim] + [hid_dim] * (num_layers - 1)
            self.convs = nn.ModuleList(
                [GCNConv(width, hid_dim) for width in layer_inputs])
            self.attr_decoder = nn.Linear(hid_dim, in_dim)
            self._dropout = dropout

        def _encode(self, x, edge_index):
            """Run the shared GCN stack and return node embeddings."""
            last = len(self.convs) - 1
            hidden = x
            for depth, layer in enumerate(self.convs):
                hidden = layer(hidden, edge_index)
                if depth == last:
                    # The final layer output is left linear.
                    continue
                hidden = torch.relu(hidden)
                hidden = torch.dropout(
                    hidden, p=self._dropout, train=self.training)
            return hidden

        def forward(self, x, ei, x_aug, ei_aug):
            """Encode both views; reconstruct from the original one."""
            emb_orig = self._encode(x, ei)
            emb_aug = self._encode(x_aug, ei_aug)
            # Attribute decoder and inner-product structure decoder
            # both read the original-view embeddings only.
            attrs_rec = self.attr_decoder(emb_orig)
            adj_rec = torch.sigmoid(emb_orig @ emb_orig.t())
            return emb_orig, emb_aug, attrs_rec, adj_rec

    return _Model()