# -*- coding: utf-8 -*-
"""CONAD: Contrastive Attributed Network Anomaly Detection.

Constructs an anomalous view by injecting synthetic anomalies
(attribute swapping + random edge injection), then contrasts
with the original view. Dual reconstruction (structure +
attributes) from the original view. Nodes with high contrastive
distance and high reconstruction error are anomalous.

See :cite:`xu2022conad` for details.

Reference:
    Xu, Z., Huang, X., Zhao, Y., Dong, Y., and Li, J., 2022.
    Contrastive Attributed Network Anomaly Detection with Data
    Augmentation. In PAKDD, pp. 444-457.
"""
# Author: Yue Zhao <yzhao062@gmail.com>
# License: BSD 2 clause

import numpy as np
from sklearn.utils.validation import check_is_fitted

from .base import BaseDetector
from ._pyg_utils import validate_graph_input


class CONAD(BaseDetector):
    """CONAD: Contrastive + Reconstruction Anomaly Detection.

    Constructs an anomalous view via attribute swapping and
    random edge injection, encodes both views with a shared
    GCN, and scores nodes by contrastive distance + dual
    (structure + attribute) reconstruction error.

    This detector is **transductive**: it scores the nodes of the
    graph it was fitted on; ``decision_function`` / ``predict`` on
    new data are intentionally unsupported.

    Parameters
    ----------
    hidden_dim : int, default=64
        Hidden dimension.

    num_layers : int, default=2
        Number of GCN layers.

    aug_ratio : float, default=0.2
        Fraction of edges/attributes to drop/mask.

    alpha : float, default=0.5
        Weight for reconstruction loss (vs contrastive).

    dropout : float, default=0.3
        Dropout rate.

    epochs : int, default=100
        Training epochs.

    lr : float, default=1e-3
        Learning rate.

    contamination : float, default=0.1
        Expected proportion of anomalies.

    Attributes
    ----------
    decision_scores_ : numpy array of shape (n_nodes,)
        Anomaly score per node (higher = more anomalous).
    labels_ : numpy array of shape (n_nodes,)
        Binary labels derived from ``decision_scores_`` and
        ``contamination``.
    threshold_ : float
        Score threshold separating inliers from outliers.
    """

    def __init__(self, hidden_dim=64, num_layers=2, aug_ratio=0.2,
                 alpha=0.5, dropout=0.3, epochs=100, lr=1e-3,
                 contamination=0.1):
        super(CONAD, self).__init__(contamination=contamination)
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.aug_ratio = aug_ratio
        self.alpha = alpha
        self.dropout = dropout
        self.epochs = epochs
        self.lr = lr

    def fit(self, X, y=None, edge_index=None):
        """Fit the detector on graph data.

        Parameters
        ----------
        X : Data or array-like
            A graph object accepted by ``validate_graph_input``
            (e.g. a torch_geometric ``Data``), or node features with
            ``edge_index`` supplied separately.
        y : ignored
            Present for scikit-learn API consistency.
        edge_index : array-like or None
            Edge index of shape (2, n_edges) when ``X`` is a plain
            feature matrix.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If the validated graph has no node features.
        """
        # Deferred imports keep torch optional at package import time.
        import torch
        import torch.nn.functional as F

        data = validate_graph_input(X, edge_index)
        n_nodes = data.num_nodes
        self._set_n_classes(y)

        if data.x is None:
            raise ValueError("CONAD requires node features (data.x).")

        in_dim = data.x.shape[1]

        model = _CONADModel(
            in_dim, self.hidden_dim, self.num_layers, self.dropout)
        optimizer = torch.optim.Adam(model.parameters(), lr=self.lr)

        x = data.x
        ei = data.edge_index

        # Dense adjacency for structure reconstruction.
        # NOTE: O(n_nodes^2) memory -- acceptable for typical
        # benchmark graphs, prohibitive for very large ones.
        adj = torch.zeros(n_nodes, n_nodes)
        adj[ei[0], ei[1]] = 1.0

        model.train()
        for _ in range(self.epochs):
            # Build the contrasting view by injecting synthetic
            # anomalies (attribute swaps + random edges).
            x_aug, ei_aug = _create_anomalous_view(
                x, ei, self.aug_ratio)

            z_orig, z_aug, x_hat, a_hat = model(
                x, ei, x_aug, ei_aug)

            # Contrastive term: minimizing the negated cosine
            # similarity pushes the anomalous view's embeddings
            # away from the original view's.
            z_o = F.normalize(z_orig, dim=1)
            z_a = F.normalize(z_aug, dim=1)
            cos_sim = (z_o * z_a).sum(dim=1)
            contrastive_loss = -cos_sim.mean()

            # Dual reconstruction: structure + attributes, both
            # decoded from the original view (see _CONADModel).
            struct_loss = torch.mean(
                torch.sum((adj - a_hat) ** 2, dim=1))
            attr_loss = torch.mean(
                torch.sum((x - x_hat) ** 2, dim=1))
            recon_loss = struct_loss + attr_loss

            loss = contrastive_loss + self.alpha * recon_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Score: contrastive distance + dual reconstruction error.
        # NOTE(review): a fresh random augmentation is drawn here, so
        # the contrastive component of the scores is stochastic
        # across fits unless torch's RNG is seeded by the caller.
        model.eval()
        with torch.no_grad():
            x_aug, ei_aug = _create_anomalous_view(
                x, ei, self.aug_ratio)
            z_orig, z_aug, x_hat, a_hat = model(
                x, ei, x_aug, ei_aug)

            z_o = F.normalize(z_orig, dim=1)
            z_a = F.normalize(z_aug, dim=1)
            cos_dist = 1.0 - (z_o * z_a).sum(dim=1)
            struct_err = torch.sum((adj - a_hat) ** 2, dim=1)
            attr_err = torch.sum((x - x_hat) ** 2, dim=1)
            scores = cos_dist + self.alpha * (struct_err + attr_err)

        self.decision_scores_ = scores.cpu().numpy()
        self._process_decision_scores()
        return self

    def decision_function(self, X):
        """Not supported (transductive detector)."""
        raise NotImplementedError(
            "CONAD is a transductive detector. Use decision_scores_ "
            "after fit().")

    def predict(self, X, return_confidence=False):
        """Not supported (transductive detector)."""
        raise NotImplementedError(
            "CONAD is a transductive detector. Use labels_ "
            "after fit().")

    def predict_proba(self, X, method="linear", return_confidence=False):
        """Not supported (transductive detector)."""
        raise NotImplementedError("CONAD is a transductive detector.")

    def predict_confidence(self, X):
        """Not supported (transductive detector)."""
        raise NotImplementedError("CONAD is a transductive detector.")


def _create_anomalous_view(x, edge_index, ratio):
    """Create anomalous view by injecting synthetic anomalies.

    Following the CONAD paper: perturb a subset of node attributes
    by swapping with random other nodes, and add random edges.

    Parameters
    ----------
    x : torch.Tensor of shape (n, d)
        Node feature matrix.
    edge_index : torch.LongTensor of shape (2, m)
        COO edge index.
    ratio : float
        Fraction of nodes whose attributes are perturbed and of
        edges added (at least one of each).

    Returns
    -------
    x_aug : torch.Tensor
        Perturbed feature matrix (input is not modified).
    edge_index_aug : torch.LongTensor
        Edge index with random edges appended.
    """
    import torch

    n = x.shape[0]
    n_perturb = max(1, int(n * ratio))

    # Attribute perturbation: overwrite a random subset of rows
    # with attributes of other randomly chosen nodes. The two
    # index draws are independent, so a node may (rarely) be
    # "swapped" with itself; this matches a best-effort injection.
    x_aug = x.clone()
    perturb_idx = torch.randperm(n)[:n_perturb]
    swap_idx = torch.randperm(n)[:n_perturb]
    x_aug[perturb_idx] = x[swap_idx]

    # Edge perturbation: append uniformly random (possibly
    # duplicate / self-loop) edges to the existing edge set.
    n_edges = edge_index.shape[1]
    n_add = max(1, int(n_edges * ratio))
    new_src = torch.randint(0, n, (n_add,))
    new_dst = torch.randint(0, n, (n_add,))
    ei_aug = torch.cat(
        [edge_index, torch.stack([new_src, new_dst])], dim=1)

    return x_aug, ei_aug


def _CONADModel(in_dim, hid_dim, num_layers, dropout):
    """Factory: returns torch.nn.Module for CONAD.

    Defined as a factory (rather than a top-level class) so that
    torch / torch_geometric are only imported when a model is built.
    """
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch_geometric.nn import GCNConv

    class _Model(nn.Module):
        """Shared GCN encoder + attribute/structure decoders."""

        def __init__(self):
            super().__init__()
            self.convs = nn.ModuleList()
            self.convs.append(GCNConv(in_dim, hid_dim))
            for _ in range(num_layers - 1):
                self.convs.append(GCNConv(hid_dim, hid_dim))
            # Attribute decoder maps embeddings back to input space;
            # structure is decoded via inner product (see forward).
            self.attr_decoder = nn.Linear(hid_dim, in_dim)
            self._dropout = dropout

        def _encode(self, x, edge_index):
            # ReLU + dropout between layers; the last layer stays
            # linear so embeddings are unconstrained.
            z = x
            for i, conv in enumerate(self.convs):
                z = conv(z, edge_index)
                if i < len(self.convs) - 1:
                    z = torch.relu(z)
                    z = F.dropout(
                        z, p=self._dropout, training=self.training)
            return z

        def forward(self, x, ei, x_aug, ei_aug):
            # Both views share the same encoder weights.
            z_orig = self._encode(x, ei)
            z_aug = self._encode(x_aug, ei_aug)
            # Dual reconstruction from the ORIGINAL view only.
            x_hat = self.attr_decoder(z_orig)
            a_hat = torch.sigmoid(z_orig @ z_orig.t())
            return z_orig, z_aug, x_hat, a_hat

    return _Model()