/ abzu-chameleon / src / bin / train.rs
train.rs
  1  //! GRU Training Binary for Ghost Mode
  2  //!
  3  //! Ingests PCAP files or generates synthetic TLS patterns, trains GRU, exports weights.
  4  //!
  5  //! Usage:
  6  //!   cargo run --bin chameleon-train --features train -- /path/to/pcaps/ -o weights/ghost_v1.safetensors
  7  //!   cargo run --bin chameleon-train --features train -- --synthetic -o weights/ghost_v1.safetensors
  8  
  9  use std::collections::HashMap;
 10  use std::fs::File;
 11  use std::io::Read;
 12  use std::path::{Path, PathBuf};
 13  
 14  use abzu_chameleon::size_to_bucket;
 15  use abzu_chameleon::synthetic::{TLS_PROFILES, generate_flow};
 16  use candle_core::{DType, Device, Module, Tensor};
 17  use candle_nn::{linear, rnn, Linear, Optimizer, VarBuilder, VarMap, SGD, RNN};
 18  use clap::Parser;
 19  use indicatif::{ProgressBar, ProgressStyle};
 20  use pcap_parser::traits::PcapReaderIterator;
 21  use pcap_parser::{LegacyPcapReader, PcapBlockOwned, PcapError, PcapNGReader};
 22  use rand::distributions::{Distribution, WeightedIndex};
 23  
/// Training CLI arguments
//
// NOTE(review): every `///` doc comment on the fields below doubles as the
// clap-generated `--help` text, i.e. it is user-visible CLI output — treat
// those comments as part of the program's interface, not free-form docs.
#[derive(Parser, Debug)]
#[command(name = "chameleon-train")]
#[command(about = "Train Ghost Mode traffic generator from PCAPs or synthetic data")]
struct Args {
    /// Directory containing PCAP/PCAPNG files (optional with --synthetic)
    #[arg(value_name = "PCAP_DIR")]
    input_dir: Option<PathBuf>,

    /// Output safetensors file
    #[arg(short, long, default_value = "weights/ghost_v1.safetensors")]
    output: PathBuf,

    /// Number of training epochs
    #[arg(short, long, default_value = "100")]
    epochs: usize,

    /// Learning rate
    #[arg(short, long, default_value = "0.01")]
    learning_rate: f64,

    /// Minimum packets per flow to include
    #[arg(long, default_value = "10")]
    min_packets: usize,

    /// Maximum number of sequences to use (0 = all)
    #[arg(long, default_value = "1000")]
    max_sequences: usize,

    /// Maximum steps per sequence to use (0 = all)
    #[arg(long, default_value = "50")]
    max_steps: usize,

    /// Use synthetic TLS data instead of PCAPs
    #[arg(long)]
    synthetic: bool,

    /// Number of synthetic flows to generate (default: 500)
    #[arg(long, default_value = "500")]
    num_flows: usize,

    /// Packets per synthetic flow (default: 100)
    #[arg(long, default_value = "100")]
    packets_per_flow: usize,
}
 69  
/// A single traffic flow.
///
/// Despite the original "5-tuple" wording, `extract_flows_from_pcap` keys
/// flows by their sorted TCP/UDP port pair only, so both directions of a
/// connection are merged into one `Flow`.
#[derive(Default)]
struct Flow {
    /// Timestamps in microseconds, in capture order; parallel to `sizes`
    /// (one entry is pushed to each vector per packet).
    timestamps_us: Vec<u64>,
    /// Captured packet sizes in bytes; parallel to `timestamps_us`.
    sizes: Vec<usize>,
}
 78  
 79  
 80  impl Flow {
 81      /// Convert to training sequence: (delay_ms, size_bucket) pairs
 82      fn to_sequence(&self) -> Vec<(f32, f32)> {
 83          let mut seq = Vec::new();
 84          for i in 1..self.timestamps_us.len() {
 85              let delay_us = self.timestamps_us[i].saturating_sub(self.timestamps_us[i - 1]);
 86              let delay_ms = (delay_us as f32) / 1000.0;
 87              let size_bucket = size_to_bucket(self.sizes[i]) as f32;
 88              seq.push((delay_ms, size_bucket));
 89          }
 90          seq
 91      }
 92  }
 93  
 94  /// Extract TCP/UDP source port from Ethernet frame
 95  ///
 96  /// Returns (src_port, dst_port) or None if not TCP/UDP
 97  fn extract_ports(data: &[u8]) -> Option<(u16, u16)> {
 98      if data.len() < 14 {
 99          return None;
100      }
101      
102      // Ethernet header: 14 bytes (dst MAC 6 + src MAC 6 + ethertype 2)
103      let ethertype = u16::from_be_bytes([data[12], data[13]]);
104      
105      match ethertype {
106          0x0800 => {
107              // IPv4
108              if data.len() < 14 + 20 {
109                  return None;
110              }
111              let ihl = (data[14] & 0x0f) as usize * 4;
112              let protocol = data[14 + 9];
113              
114              // TCP or UDP?
115              if protocol == 6 || protocol == 17 {
116                  let transport_offset = 14 + ihl;
117                  if data.len() < transport_offset + 4 {
118                      return None;
119                  }
120                  let src_port = u16::from_be_bytes([data[transport_offset], data[transport_offset + 1]]);
121                  let dst_port = u16::from_be_bytes([data[transport_offset + 2], data[transport_offset + 3]]);
122                  return Some((src_port, dst_port));
123              }
124          }
125          0x86dd => {
126              // IPv6 - fixed header is 40 bytes
127              if data.len() < 14 + 40 + 4 {
128                  return None;
129              }
130              let next_header = data[14 + 6];
131              
132              // TCP (6) or UDP (17)?
133              if next_header == 6 || next_header == 17 {
134                  let transport_offset = 14 + 40;
135                  if data.len() < transport_offset + 4 {
136                      return None;
137                  }
138                  let src_port = u16::from_be_bytes([data[transport_offset], data[transport_offset + 1]]);
139                  let dst_port = u16::from_be_bytes([data[transport_offset + 2], data[transport_offset + 3]]);
140                  return Some((src_port, dst_port));
141              }
142          }
143          _ => {}
144      }
145      None
146  }
147  
148  /// Extract flows from a PCAP or PCAPNG file
149  fn extract_flows_from_pcap(path: &PathBuf) -> Result<Vec<Flow>, Box<dyn std::error::Error>> {
150      let file = File::open(path)?;
151      let mut reader = std::io::BufReader::new(file);
152      
153      // Read raw bytes
154      let mut data = Vec::new();
155      reader.read_to_end(&mut data)?;
156      
157      if data.len() < 4 {
158          return Ok(Vec::new());
159      }
160      
161      // Detect format by magic bytes
162      // Classic PCAP: d4c3b2a1 (little endian) or a1b2c3d4 (big endian)
163      // PCAPNG: 0a0d0d0a (Section Header Block)
164      let magic = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
165      
166      let mut flows: HashMap<(u16, u16), Flow> = HashMap::new();
167      
168      if magic == 0xa1b2c3d4 || magic == 0xd4c3b2a1 {
169          // Classic PCAP format
170          if let Ok(mut pcap_reader) = LegacyPcapReader::new(65536, std::io::Cursor::new(&data)) {
171              loop {
172                  match pcap_reader.next() {
173                      Ok((offset, block)) => {
174                          match block {
175                              PcapBlockOwned::Legacy(pkt) => {
176                                  let ts_us = (pkt.ts_sec as u64) * 1_000_000 + (pkt.ts_usec as u64);
177                                  let size = pkt.data.len();
178                                  
179                                  if let Some((src_port, dst_port)) = extract_ports(pkt.data) {
180                                      // Use port pair as flow key (sorted to group bidrectional)
181                                      let key = if src_port < dst_port { (src_port, dst_port) } else { (dst_port, src_port) };
182                                      let flow = flows.entry(key).or_default();
183                                      flow.timestamps_us.push(ts_us);
184                                      flow.sizes.push(size);
185                                  }
186                              }
187                              _ => {}
188                          }
189                          pcap_reader.consume(offset);
190                      }
191                      Err(PcapError::Eof) => break,
192                      Err(PcapError::Incomplete(_)) => {
193                          if pcap_reader.refill().is_err() {
194                              break;
195                          }
196                      }
197                      Err(_) => break,
198                  }
199              }
200          }
201      } else {
202          // Try PCAPNG format
203          if let Ok(mut pcapng_reader) = PcapNGReader::new(65536, std::io::Cursor::new(&data)) {
204              loop {
205                  match pcapng_reader.next() {
206                      Ok((offset, block)) => {
207                          match block {
208                              PcapBlockOwned::NG(pcap_parser::Block::EnhancedPacket(epb)) => {
209                                  let ts_us = (epb.ts_high as u64) << 32 | (epb.ts_low as u64);
210                                  let size = epb.data.len();
211                                  
212                                  if let Some((src_port, dst_port)) = extract_ports(epb.data) {
213                                      let key = if src_port < dst_port { (src_port, dst_port) } else { (dst_port, src_port) };
214                                      let flow = flows.entry(key).or_default();
215                                      flow.timestamps_us.push(ts_us);
216                                      flow.sizes.push(size);
217                                  }
218                              }
219                              _ => {}
220                          }
221                          pcapng_reader.consume(offset);
222                      }
223                      Err(PcapError::Eof) => break,
224                      Err(PcapError::Incomplete(_)) => {
225                          pcapng_reader.refill().ok();
226                      }
227                      Err(_) => break,
228                  }
229              }
230          }
231      }
232      
233      Ok(flows.into_values().collect())
234  }
235  
236  /// Prepare training data from flows
237  fn prepare_training_data(
238      flows: &[Flow],
239      min_packets: usize,
240      max_sequences: usize,
241      max_steps: usize,
242  ) -> (Vec<Vec<(f32, f32)>>, usize) {
243      let mut sequences: Vec<_> = flows
244          .iter()
245          .filter(|f| f.sizes.len() >= min_packets)
246          .map(|f| f.to_sequence())
247          .filter(|s| !s.is_empty())
248          .collect();
249      
250      // Limit number of sequences to prevent stack overflow
251      if max_sequences > 0 && sequences.len() > max_sequences {
252          // Take a random sample for diversity
253          use rand::seq::SliceRandom;
254          let mut rng = rand::thread_rng();
255          sequences.shuffle(&mut rng);
256          sequences.truncate(max_sequences);
257      }
258      
259      // Limit steps per sequence
260      if max_steps > 0 {
261          for seq in &mut sequences {
262              if seq.len() > max_steps {
263                  seq.truncate(max_steps);
264              }
265          }
266      }
267      
268      let total_samples: usize = sequences.iter().map(|s| s.len().saturating_sub(1)).sum();
269      (sequences, total_samples)
270  }
271  
/// GRU trainer: a single-layer GRU followed by a linear output head,
/// trained one timestep at a time on (delay, size-bucket) sequences.
struct GruTrainer {
    // Variable store holding all trainable parameters; serialized by `save`.
    varmap: VarMap,
    // Single GRU cell (input dim 2, hidden dim 64 — see the constants below).
    gru: rnn::GRU,
    // Linear head projecting the hidden state to the 2-dim prediction.
    output: Linear,
    // Device all tensors live on (always `Device::Cpu`, set in `new`).
    device: Device,
}
279  
280  impl GruTrainer {
281      const HIDDEN_DIM: usize = 64;
282      const INPUT_DIM: usize = 2;
283      const OUTPUT_DIM: usize = 2;
284  
285      fn new() -> candle_core::Result<Self> {
286          let device = Device::Cpu;
287          let varmap = VarMap::new();
288          let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
289  
290          let gru_config = rnn::GRUConfig::default();
291          let gru = rnn::gru(Self::INPUT_DIM, Self::HIDDEN_DIM, gru_config, vb.pp("gru"))?;
292          let output = linear(Self::HIDDEN_DIM, Self::OUTPUT_DIM, vb.pp("output"))?;
293  
294          Ok(Self {
295              varmap,
296              gru,
297              output,
298              device,
299          })
300      }
301  
302      /// Train on a batch of sequences
303      fn train_step(
304          &self,
305          sequences: &[Vec<(f32, f32)>],
306          optimizer: &mut SGD,
307      ) -> candle_core::Result<f32> {
308          let mut total_loss = 0.0;
309          let mut count = 0;
310  
311          for seq in sequences {
312              if seq.len() < 2 {
313                  continue;
314              }
315  
316              // Process sequence step by step
317              let mut hidden = Tensor::zeros((1, Self::HIDDEN_DIM), DType::F32, &self.device)?;
318  
319              for i in 0..seq.len() - 1 {
320                  let (delay, size) = seq[i];
321                  let (target_delay, target_size) = seq[i + 1];
322  
323                  // Normalize and clamp inputs to prevent NaN
324                  // Delay: clamp to [1ms, 5000ms] then divide by 5000 to get [0.0002, 1.0]
325                  let delay_norm = (delay.clamp(1.0, 5000.0)) / 5000.0;
326                  let size_norm = (size.clamp(0.0, 5.0)) / 5.0;
327                  let target_delay_norm = (target_delay.clamp(1.0, 5000.0)) / 5000.0;
328                  let target_size_norm = (target_size.clamp(0.0, 5.0)) / 5.0;
329  
330                  let input = Tensor::new(&[[delay_norm, size_norm]], &self.device)?;
331                  let target = Tensor::new(&[[target_delay_norm, target_size_norm]], &self.device)?;
332  
333                  // Forward pass
334                  let state = rnn::GRUState { h: hidden.clone() };
335                  let new_state = self.gru.step(&input, &state)?;
336                  hidden = new_state.h.clone();
337  
338                  let pred = self.output.forward(&new_state.h)?;
339  
340                  // MSE loss
341                  let diff = (&pred - &target)?;
342                  let loss = diff.sqr()?.mean_all()?;
343  
344                  // Skip NaN losses (rare edge cases)
345                  let loss_val = loss.to_scalar::<f32>()?;
346                  if loss_val.is_nan() || loss_val.is_infinite() {
347                      continue;
348                  }
349  
350                  // Backward pass
351                  optimizer.backward_step(&loss)?;
352  
353                  total_loss += loss.to_scalar::<f32>()?;
354                  count += 1;
355              }
356          }
357  
358          Ok(if count > 0 { total_loss / count as f32 } else { 0.0 })
359      }
360  
361      /// Save weights to safetensors
362      fn save(&self, path: &PathBuf) -> Result<(), Box<dyn std::error::Error>> {
363          // Ensure parent dir exists
364          if let Some(parent) = path.parent() {
365              std::fs::create_dir_all(parent)?;
366          }
367          self.varmap.save(path)?;
368          Ok(())
369      }
370  }
371  
372  fn main() -> Result<(), Box<dyn std::error::Error>> {
373      let args = Args::parse();
374  
375      println!("šŸ¦Ž Chameleon GRU Trainer");
376      println!("   Output: {}", args.output.display());
377      println!("   Epochs: {}", args.epochs);
378      println!("   LR:     {}", args.learning_rate);
379      println!();
380  
381      let sequences: Vec<Vec<(f32, f32)>>;
382      let total_samples: usize;
383  
384      if args.synthetic {
385          // Generate synthetic training data
386          println!("šŸ”§ Generating synthetic TLS data...");
387          println!("   Flows: {}", args.num_flows);
388          println!("   Packets/flow: {}", args.packets_per_flow);
389          
390          let mut rng = rand::thread_rng();
391          let profile_weights: Vec<_> = TLS_PROFILES.iter().map(|p| p.weight).collect();
392          let profile_dist = WeightedIndex::new(&profile_weights).unwrap();
393          
394          let pb = ProgressBar::new(args.num_flows as u64);
395          pb.set_style(
396              ProgressStyle::default_bar()
397                  .template("{spinner:.green} [{bar:40.cyan/blue}] {pos}/{len} {msg}")
398                  .unwrap(),
399          );
400          pb.set_message("Generating flows...");
401          
402          let mut all_sequences = Vec::new();
403          for _ in 0..args.num_flows {
404              let profile = &TLS_PROFILES[profile_dist.sample(&mut rng)];
405              let flow = generate_flow(profile, args.packets_per_flow);
406              // Convert to (f32, f32) format
407              let seq: Vec<(f32, f32)> = flow.iter().map(|(d, s)| (*d as f32, *s as f32)).collect();
408              all_sequences.push(seq);
409              pb.inc(1);
410          }
411          pb.finish_with_message("Done!");
412          
413          total_samples = all_sequences.iter().map(|s| s.len().saturating_sub(1)).sum();
414          sequences = all_sequences;
415          
416          println!("šŸ“Š Generated {} flows, {} training samples", args.num_flows, total_samples);
417      } else {
418          // Load from PCAP files
419          let input_dir = args.input_dir.as_ref().ok_or("PCAP_DIR required unless --synthetic")?;
420          println!("   Input:  {}", input_dir.display());
421          
422          let pcap_files: Vec<_> = std::fs::read_dir(input_dir)?
423              .filter_map(|e| e.ok())
424              .map(|e| e.path())
425              .filter(|p| {
426                  p.extension()
427                      .map(|e| e == "pcap" || e == "pcapng")
428                      .unwrap_or(false)
429              })
430              .collect();
431  
432          if pcap_files.is_empty() {
433              eprintln!("āŒ No PCAP files found in {}", input_dir.display());
434              eprintln!("   Tip: Use --synthetic to train on generated TLS patterns");
435              std::process::exit(1);
436          }
437  
438          println!("šŸ“ Found {} PCAP files", pcap_files.len());
439  
440          let pb = ProgressBar::new(pcap_files.len() as u64);
441          pb.set_style(
442              ProgressStyle::default_bar()
443                  .template("{spinner:.green} [{bar:40.cyan/blue}] {pos}/{len} {msg}")
444                  .unwrap(),
445          );
446          pb.set_message("Extracting flows...");
447  
448          let mut all_flows = Vec::new();
449          for path in &pcap_files {
450              match extract_flows_from_pcap(path) {
451                  Ok(flows) => all_flows.extend(flows),
452                  Err(e) => eprintln!("āš ļø  Failed to parse {}: {}", path.display(), e),
453              }
454              pb.inc(1);
455          }
456          pb.finish_with_message("Done!");
457  
458          let (seqs, samples) = prepare_training_data(&all_flows, args.min_packets, args.max_sequences, args.max_steps);
459          sequences = seqs;
460          total_samples = samples;
461          
462          println!(
463              "šŸ“Š {} flows → {} sequences, {} training samples",
464              all_flows.len(),
465              sequences.len(),
466              total_samples
467          );
468      }
469  
470      if sequences.is_empty() {
471          eprintln!("āŒ No valid training sequences");
472          std::process::exit(1);
473      }
474  
475      // Initialize model
476      println!("\n🧠 Initializing GRU model...");
477      let trainer = GruTrainer::new()?;
478      let mut optimizer = SGD::new(trainer.varmap.all_vars(), args.learning_rate)?;
479  
480      // Training loop
481      println!("šŸ‹ļø Training for {} epochs...\n", args.epochs);
482  
483      let pb = ProgressBar::new(args.epochs as u64);
484      pb.set_style(
485          ProgressStyle::default_bar()
486              .template("{spinner:.green} [{bar:40.cyan/blue}] {pos}/{len} loss: {msg}")
487              .unwrap(),
488      );
489  
490      for epoch in 0..args.epochs {
491          let loss = trainer.train_step(&sequences, &mut optimizer)?;
492          pb.set_message(format!("{:.6}", loss));
493          pb.inc(1);
494  
495          // Early stopping if loss is very low
496          if loss < 0.0001 {
497              pb.finish_with_message(format!("Converged at epoch {} with loss {:.6}", epoch, loss));
498              break;
499          }
500      }
501      pb.finish_with_message("Training complete!");
502  
503      // Save weights
504      println!("\nšŸ’¾ Saving weights to {}...", args.output.display());
505      trainer.save(&args.output)?;
506  
507      println!("āœ… Done! Weights saved to {}", args.output.display());
508      println!("\nšŸ“ Next steps:");
509      println!("   1. Copy {} to abzu-chameleon/weights/", args.output.display());
510      println!("   2. Update MlGenerator to use include_bytes!(\"weights/ghost_v1.safetensors\")");
511      println!("   3. Rebuild with: cargo build -p abzu-chameleon --features ml");
512  
513      Ok(())
514  }