train.rs
1 //! GRU Training Binary for Ghost Mode 2 //! 3 //! Ingests PCAP files or generates synthetic TLS patterns, trains GRU, exports weights. 4 //! 5 //! Usage: 6 //! cargo run --bin chameleon-train --features train -- /path/to/pcaps/ -o weights/ghost_v1.safetensors 7 //! cargo run --bin chameleon-train --features train -- --synthetic -o weights/ghost_v1.safetensors 8 9 use std::collections::HashMap; 10 use std::fs::File; 11 use std::io::Read; 12 use std::path::PathBuf; 13 14 use abzu_chameleon::size_to_bucket; 15 use abzu_chameleon::synthetic::{TLS_PROFILES, generate_flow}; 16 use candle_core::{DType, Device, Module, Tensor}; 17 use candle_nn::{linear, rnn, Linear, Optimizer, VarBuilder, VarMap, SGD, RNN}; 18 use clap::Parser; 19 use indicatif::{ProgressBar, ProgressStyle}; 20 use pcap_parser::traits::PcapReaderIterator; 21 use pcap_parser::{LegacyPcapReader, PcapBlockOwned, PcapError, PcapNGReader}; 22 use rand::distributions::{Distribution, WeightedIndex}; 23 24 /// Training CLI arguments 25 #[derive(Parser, Debug)] 26 #[command(name = "chameleon-train")] 27 #[command(about = "Train Ghost Mode traffic generator from PCAPs or synthetic data")] 28 struct Args { 29 /// Directory containing PCAP/PCAPNG files (optional with --synthetic) 30 #[arg(value_name = "PCAP_DIR")] 31 input_dir: Option<PathBuf>, 32 33 /// Output safetensors file 34 #[arg(short, long, default_value = "weights/ghost_v1.safetensors")] 35 output: PathBuf, 36 37 /// Number of training epochs 38 #[arg(short, long, default_value = "100")] 39 epochs: usize, 40 41 /// Learning rate 42 #[arg(short, long, default_value = "0.01")] 43 learning_rate: f64, 44 45 /// Minimum packets per flow to include 46 #[arg(long, default_value = "10")] 47 min_packets: usize, 48 49 /// Maximum number of sequences to use (0 = all) 50 #[arg(long, default_value = "1000")] 51 max_sequences: usize, 52 53 /// Maximum steps per sequence to use (0 = all) 54 #[arg(long, default_value = "50")] 55 max_steps: usize, 56 57 /// Use synthetic TLS data instead of PCAPs 58 #[arg(long)] 59 synthetic: bool, 60 61 /// Number of synthetic flows to generate (default: 500) 62 #[arg(long, default_value = "500")] 63 num_flows: usize, 64 65 /// Packets per synthetic flow (default: 100) 66 #[arg(long, default_value = "100")] 67 packets_per_flow: usize, 68 } 69 70 /// A single traffic flow (5-tuple identified) 71 #[derive(Default)] 72 struct Flow { 73 /// Timestamps in microseconds 74 timestamps_us: Vec<u64>, 75 /// Packet sizes 76 sizes: Vec<usize>, 77 } 78 79 80 impl Flow { 81 /// Convert to training sequence: (delay_ms, size_bucket) pairs 82 fn to_sequence(&self) -> Vec<(f32, f32)> { 83 let mut seq = Vec::new(); 84 for i in 1..self.timestamps_us.len() { 85 let delay_us = self.timestamps_us[i].saturating_sub(self.timestamps_us[i - 1]); 86 let delay_ms = (delay_us as f32) / 1000.0; 87 let size_bucket = size_to_bucket(self.sizes[i]) as f32; 88 seq.push((delay_ms, size_bucket)); 89 } 90 seq 91 } 92 } 93 94 /// Extract TCP/UDP source port from Ethernet frame 95 /// 96 /// Returns (src_port, dst_port) or None if not TCP/UDP 97 fn extract_ports(data: &[u8]) -> Option<(u16, u16)> { 98 if data.len() < 14 { 99 return None; 100 } 101 102 // Ethernet header: 14 bytes (dst MAC 6 + src MAC 6 + ethertype 2) 103 let ethertype = u16::from_be_bytes([data[12], data[13]]); 104 105 match ethertype { 106 0x0800 => { 107 // IPv4 108 if data.len() < 14 + 20 { 109 return None; 110 } 111 let ihl = (data[14] & 0x0f) as usize * 4; 112 let protocol = data[14 + 9]; 113 114 // TCP or UDP? 115 if protocol == 6 || protocol == 17 { 116 let transport_offset = 14 + ihl; 117 if data.len() < transport_offset + 4 { 118 return None; 119 } 120 let src_port = u16::from_be_bytes([data[transport_offset], data[transport_offset + 1]]); 121 let dst_port = u16::from_be_bytes([data[transport_offset + 2], data[transport_offset + 3]]); 122 return Some((src_port, dst_port)); 123 } 124 } 125 0x86dd => { 126 // IPv6 - fixed header is 40 bytes 127 if data.len() < 14 + 40 + 4 { 128 return None; 129 } 130 let next_header = data[14 + 6]; 131 132 // TCP (6) or UDP (17)? 133 if next_header == 6 || next_header == 17 { 134 let transport_offset = 14 + 40; 135 if data.len() < transport_offset + 4 { 136 return None; 137 } 138 let src_port = u16::from_be_bytes([data[transport_offset], data[transport_offset + 1]]); 139 let dst_port = u16::from_be_bytes([data[transport_offset + 2], data[transport_offset + 3]]); 140 return Some((src_port, dst_port)); 141 } 142 } 143 _ => {} 144 } 145 None 146 } 147 148 /// Extract flows from a PCAP or PCAPNG file 149 fn extract_flows_from_pcap(path: &PathBuf) -> Result<Vec<Flow>, Box<dyn std::error::Error>> { 150 let file = File::open(path)?; 151 let mut reader = std::io::BufReader::new(file); 152 153 // Read raw bytes 154 let mut data = Vec::new(); 155 reader.read_to_end(&mut data)?; 156 157 if data.len() < 4 { 158 return Ok(Vec::new()); 159 } 160 161 // Detect format by magic bytes 162 // Classic PCAP: d4c3b2a1 (little endian) or a1b2c3d4 (big endian) 163 // PCAPNG: 0a0d0d0a (Section Header Block) 164 let magic = u32::from_le_bytes([data[0], data[1], data[2], data[3]]); 165 166 let mut flows: HashMap<(u16, u16), Flow> = HashMap::new(); 167 168 if magic == 0xa1b2c3d4 || magic == 0xd4c3b2a1 { 169 // Classic PCAP format 170 if let Ok(mut pcap_reader) = LegacyPcapReader::new(65536, std::io::Cursor::new(&data)) { 171 loop { 172 match pcap_reader.next() { 173 Ok((offset, block)) => { 174 match block { 175 PcapBlockOwned::Legacy(pkt) => { 176 let ts_us = (pkt.ts_sec as u64) * 1_000_000 + (pkt.ts_usec as u64); 177 let size = pkt.data.len(); 178 179 if let Some((src_port, dst_port)) = extract_ports(pkt.data) { 180 // Use port pair as flow key (sorted to group bidrectional) 181 let key = if src_port < dst_port { (src_port, dst_port) } else { (dst_port, src_port) }; 182 let flow = flows.entry(key).or_default(); 183 flow.timestamps_us.push(ts_us); 184 flow.sizes.push(size); 185 } 186 } 187 _ => {} 188 } 189 pcap_reader.consume(offset); 190 } 191 Err(PcapError::Eof) => break, 192 Err(PcapError::Incomplete(_)) => { 193 if pcap_reader.refill().is_err() { 194 break; 195 } 196 } 197 Err(_) => break, 198 } 199 } 200 } 201 } else { 202 // Try PCAPNG format 203 if let Ok(mut pcapng_reader) = PcapNGReader::new(65536, std::io::Cursor::new(&data)) { 204 loop { 205 match pcapng_reader.next() { 206 Ok((offset, block)) => { 207 match block { 208 PcapBlockOwned::NG(pcap_parser::Block::EnhancedPacket(epb)) => { 209 let ts_us = (epb.ts_high as u64) << 32 | (epb.ts_low as u64); 210 let size = epb.data.len(); 211 212 if let Some((src_port, dst_port)) = extract_ports(epb.data) { 213 let key = if src_port < dst_port { (src_port, dst_port) } else { (dst_port, src_port) }; 214 let flow = flows.entry(key).or_default(); 215 flow.timestamps_us.push(ts_us); 216 flow.sizes.push(size); 217 } 218 } 219 _ => {} 220 } 221 pcapng_reader.consume(offset); 222 } 223 Err(PcapError::Eof) => break, 224 Err(PcapError::Incomplete(_)) => { 225 pcapng_reader.refill().ok(); 226 } 227 Err(_) => break, 228 } 229 } 230 } 231 } 232 233 Ok(flows.into_values().collect()) 234 } 235 236 /// Prepare training data from flows 237 fn prepare_training_data( 238 flows: &[Flow], 239 min_packets: usize, 240 max_sequences: usize, 241 max_steps: usize, 242 ) -> (Vec<Vec<(f32, f32)>>, usize) { 243 let mut sequences: Vec<_> = flows 244 .iter() 245 .filter(|f| f.sizes.len() >= min_packets) 246 .map(|f| f.to_sequence()) 247 .filter(|s| !s.is_empty()) 248 .collect(); 249 250 // Limit number of sequences to prevent stack overflow 251 if max_sequences > 0 && sequences.len() > max_sequences { 252 // Take a random sample for diversity 253 use rand::seq::SliceRandom; 254 let mut rng = rand::thread_rng(); 255 sequences.shuffle(&mut rng); 256 sequences.truncate(max_sequences); 257 } 258 259 // Limit steps per sequence 260 if max_steps > 0 { 261 for seq in &mut sequences { 262 if seq.len() > max_steps { 263 seq.truncate(max_steps); 264 } 265 } 266 } 267 268 let total_samples: usize = sequences.iter().map(|s| s.len().saturating_sub(1)).sum(); 269 (sequences, total_samples) 270 } 271 272 /// GRU trainer 273 struct GruTrainer { 274 varmap: VarMap, 275 gru: rnn::GRU, 276 output: Linear, 277 device: Device, 278 } 279 280 impl GruTrainer { 281 const HIDDEN_DIM: usize = 64; 282 const INPUT_DIM: usize = 2; 283 const OUTPUT_DIM: usize = 2; 284 285 fn new() -> candle_core::Result<Self> { 286 let device = Device::Cpu; 287 let varmap = VarMap::new(); 288 let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device); 289 290 let gru_config = rnn::GRUConfig::default(); 291 let gru = rnn::gru(Self::INPUT_DIM, Self::HIDDEN_DIM, gru_config, vb.pp("gru"))?; 292 let output = linear(Self::HIDDEN_DIM, Self::OUTPUT_DIM, vb.pp("output"))?; 293 294 Ok(Self { 295 varmap, 296 gru, 297 output, 298 device, 299 }) 300 } 301 302 /// Train on a batch of sequences 303 fn train_step( 304 &self, 305 sequences: &[Vec<(f32, f32)>], 306 optimizer: &mut SGD, 307 ) -> candle_core::Result<f32> { 308 let mut total_loss = 0.0; 309 let mut count = 0; 310 311 for seq in sequences { 312 if seq.len() < 2 { 313 continue; 314 } 315 316 // Process sequence step by step 317 let mut hidden = Tensor::zeros((1, Self::HIDDEN_DIM), DType::F32, &self.device)?; 318 319 for i in 0..seq.len() - 1 { 320 let (delay, size) = seq[i]; 321 let (target_delay, target_size) = seq[i + 1]; 322 323 // Normalize and clamp inputs to prevent NaN 324 // Delay: clamp to [1ms, 5000ms] then divide by 5000 to get [0.0002, 1.0] 325 let delay_norm = (delay.clamp(1.0, 5000.0)) / 5000.0; 326 let size_norm = (size.clamp(0.0, 5.0)) / 5.0; 327 let target_delay_norm = (target_delay.clamp(1.0, 5000.0)) / 5000.0; 328 let target_size_norm = (target_size.clamp(0.0, 5.0)) / 5.0; 329 330 let input = Tensor::new(&[[delay_norm, size_norm]], &self.device)?; 331 let target = Tensor::new(&[[target_delay_norm, target_size_norm]], &self.device)?; 332 333 // Forward pass 334 let state = rnn::GRUState { h: hidden.clone() }; 335 let new_state = self.gru.step(&input, &state)?; 336 hidden = new_state.h.clone(); 337 338 let pred = self.output.forward(&new_state.h)?; 339 340 // MSE loss 341 let diff = (&pred - &target)?; 342 let loss = diff.sqr()?.mean_all()?; 343 344 // Skip NaN losses (rare edge cases) 345 let loss_val = loss.to_scalar::<f32>()?; 346 if loss_val.is_nan() || loss_val.is_infinite() { 347 continue; 348 } 349 350 // Backward pass 351 optimizer.backward_step(&loss)?; 352 353 total_loss += loss.to_scalar::<f32>()?; 354 count += 1; 355 } 356 } 357 358 Ok(if count > 0 { total_loss / count as f32 } else { 0.0 }) 359 } 360 361 /// Save weights to safetensors 362 fn save(&self, path: &PathBuf) -> Result<(), Box<dyn std::error::Error>> { 363 // Ensure parent dir exists 364 if let Some(parent) = path.parent() { 365 std::fs::create_dir_all(parent)?; 366 } 367 self.varmap.save(path)?; 368 Ok(()) 369 } 370 } 371 372 fn main() -> Result<(), Box<dyn std::error::Error>> { 373 let args = Args::parse(); 374 375 println!("š¦ Chameleon GRU Trainer"); 376 println!(" Output: {}", args.output.display()); 377 println!(" Epochs: {}", args.epochs); 378 println!(" LR: {}", args.learning_rate); 379 println!(); 380 381 let sequences: Vec<Vec<(f32, f32)>>; 382 let total_samples: usize; 383 384 if args.synthetic { 385 // Generate synthetic training data 386 println!("š§ Generating synthetic TLS data..."); 387 println!(" Flows: {}", args.num_flows); 388 println!(" Packets/flow: {}", args.packets_per_flow); 389 390 let mut rng = rand::thread_rng(); 391 let profile_weights: Vec<_> = TLS_PROFILES.iter().map(|p| p.weight).collect(); 392 let profile_dist = WeightedIndex::new(&profile_weights).unwrap(); 393 394 let pb = ProgressBar::new(args.num_flows as u64); 395 pb.set_style( 396 ProgressStyle::default_bar() 397 .template("{spinner:.green} [{bar:40.cyan/blue}] {pos}/{len} {msg}") 398 .unwrap(), 399 ); 400 pb.set_message("Generating flows..."); 401 402 let mut all_sequences = Vec::new(); 403 for _ in 0..args.num_flows { 404 let profile = &TLS_PROFILES[profile_dist.sample(&mut rng)]; 405 let flow = generate_flow(profile, args.packets_per_flow); 406 // Convert to (f32, f32) format 407 let seq: Vec<(f32, f32)> = flow.iter().map(|(d, s)| (*d as f32, *s as f32)).collect(); 408 all_sequences.push(seq); 409 pb.inc(1); 410 } 411 pb.finish_with_message("Done!"); 412 413 total_samples = all_sequences.iter().map(|s| s.len().saturating_sub(1)).sum(); 414 sequences = all_sequences; 415 416 println!("š Generated {} flows, {} training samples", args.num_flows, total_samples); 417 } else { 418 // Load from PCAP files 419 let input_dir = args.input_dir.as_ref().ok_or("PCAP_DIR required unless --synthetic")?; 420 println!(" Input: {}", input_dir.display()); 421 422 let pcap_files: Vec<_> = std::fs::read_dir(input_dir)? 423 .filter_map(|e| e.ok()) 424 .map(|e| e.path()) 425 .filter(|p| { 426 p.extension() 427 .map(|e| e == "pcap" || e == "pcapng") 428 .unwrap_or(false) 429 }) 430 .collect(); 431 432 if pcap_files.is_empty() { 433 eprintln!("ā No PCAP files found in {}", input_dir.display()); 434 eprintln!(" Tip: Use --synthetic to train on generated TLS patterns"); 435 std::process::exit(1); 436 } 437 438 println!("š Found {} PCAP files", pcap_files.len()); 439 440 let pb = ProgressBar::new(pcap_files.len() as u64); 441 pb.set_style( 442 ProgressStyle::default_bar() 443 .template("{spinner:.green} [{bar:40.cyan/blue}] {pos}/{len} {msg}") 444 .unwrap(), 445 ); 446 pb.set_message("Extracting flows..."); 447 448 let mut all_flows = Vec::new(); 449 for path in &pcap_files { 450 match extract_flows_from_pcap(path) { 451 Ok(flows) => all_flows.extend(flows), 452 Err(e) => eprintln!("ā ļø Failed to parse {}: {}", path.display(), e), 453 } 454 pb.inc(1); 455 } 456 pb.finish_with_message("Done!"); 457 458 let (seqs, samples) = prepare_training_data(&all_flows, args.min_packets, args.max_sequences, args.max_steps); 459 sequences = seqs; 460 total_samples = samples; 461 462 println!( 463 "š {} flows ā {} sequences, {} training samples", 464 all_flows.len(), 465 sequences.len(), 466 total_samples 467 ); 468 } 469 470 if sequences.is_empty() { 471 eprintln!("ā No valid training sequences"); 472 std::process::exit(1); 473 } 474 475 // Initialize model 476 println!("\nš§ Initializing GRU model..."); 477 let trainer = GruTrainer::new()?; 478 let mut optimizer = SGD::new(trainer.varmap.all_vars(), args.learning_rate)?; 479 480 // Training loop 481 println!("šļø Training for {} epochs...\n", args.epochs); 482 483 let pb = ProgressBar::new(args.epochs as u64); 484 pb.set_style( 485 ProgressStyle::default_bar() 486 .template("{spinner:.green} [{bar:40.cyan/blue}] {pos}/{len} loss: {msg}") 487 .unwrap(), 488 ); 489 490 for epoch in 0..args.epochs { 491 let loss = trainer.train_step(&sequences, &mut optimizer)?; 492 pb.set_message(format!("{:.6}", loss)); 493 pb.inc(1); 494 495 // Early stopping if loss is very low 496 if loss < 0.0001 { 497 pb.finish_with_message(format!("Converged at epoch {} with loss {:.6}", epoch, loss)); 498 break; 499 } 500 } 501 pb.finish_with_message("Training complete!"); 502 503 // Save weights 504 println!("\nš¾ Saving weights to {}...", args.output.display()); 505 trainer.save(&args.output)?; 506 507 println!("ā Done! Weights saved to {}", args.output.display()); 508 println!("\nš Next steps:"); 509 println!(" 1. Copy {} to abzu-chameleon/weights/", args.output.display()); 510 println!(" 2. Update MlGenerator to use include_bytes!(\"weights/ghost_v1.safetensors\")"); 511 println!(" 3. Rebuild with: cargo build -p abzu-chameleon --features ml"); 512 513 Ok(()) 514 }