main.rs
1 use anyhow::{Context, Result, bail}; 2 use clap::{CommandFactory, Parser, Subcommand}; 3 use clap_complete::{ 4 generate, 5 shells::{Bash, Elvish, Fish, PowerShell, Zsh}, 6 }; 7 8 use ssh2::Session; 9 use std::io::{Read, Write}; 10 use std::net::TcpStream; 11 use std::path::{Path, PathBuf}; 12 use std::sync::Arc; 13 use std::sync::mpsc; 14 use std::time::{Duration, Instant}; 15 use tracing::{debug, info, instrument, warn}; // Import tracing macros 16 17 mod progress; 18 mod progress_writer; 19 mod utils; 20 21 use progress::{ProgressDisplay, TransferProgress}; 22 use progress_writer::{ProgressTarBuilder, ProgressWriter}; 23 24 #[derive(Parser, Debug)] 25 #[command( 26 author, 27 version, 28 about = "A tool to efficiently transfer files and directories over SSH using a tar stream." 29 )] 30 struct Cli { 31 #[command(subcommand)] 32 command: Option<Commands>, 33 34 /// Enable Zstandard compression for the tar stream. 35 #[arg(short = 'z', long)] 36 zstd: bool, 37 38 /// The local source files or directories to transfer, followed by the remote destination. 39 /// Format: [SOURCES]... <DESTINATION> 40 /// Where DESTINATION is 'user@host:/path/to/dest' 41 args: Vec<String>, 42 } 43 44 #[derive(Subcommand, Debug)] 45 enum Commands { 46 /// Generate shell completions 47 Completions { 48 /// The shell to generate completions for 49 #[arg(value_enum)] 50 shell: Shell, 51 }, 52 } 53 54 #[derive(clap::ValueEnum, Clone, Debug)] 55 enum Shell { 56 Bash, 57 Elvish, 58 Fish, 59 PowerShell, 60 Zsh, 61 } 62 63 fn main() { 64 // --- INITIALIZATION --- 65 // Set up the tracing subscriber. 66 // This will read the RUST_LOG environment variable to determine the log level. 67 // Example: `RUST_LOG=info` or `RUST_LOG=bale=debug,ssh2=warn` 68 tracing_subscriber::fmt() 69 .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) 70 .with_writer(std::io::stderr) // Log to stderr 71 .init(); 72 73 if let Err(e) = run() { 74 // Use tracing::error for the final error output 75 tracing::error!("Application failed: {:?}", e); 76 std::process::exit(1); 77 } 78 } 79 80 // The `instrument` attribute automatically creates a span for this function. 81 // Every log event inside `run` will be associated with this span. 82 #[instrument] 83 fn run() -> Result<()> { 84 let cli = Cli::parse(); 85 86 // Handle subcommands 87 if let Some(command) = cli.command { 88 match command { 89 Commands::Completions { shell } => { 90 generate_completions(shell)?; 91 return Ok(()); 92 } 93 } 94 } 95 96 // --- 1. Parse Arguments & Validate Local Paths --- 97 if cli.args.len() < 2 { 98 bail!("At least one source and a destination must be provided"); 99 } 100 101 let destination = cli.args.last().unwrap(); 102 let sources_str = &cli.args[..cli.args.len() - 1]; 103 104 if sources_str.is_empty() { 105 bail!("At least one source must be provided"); 106 } 107 108 let mut sources = Vec::new(); 109 let mut seen_names = std::collections::HashSet::new(); 110 111 for source_str in sources_str { 112 let source = PathBuf::from(source_str); 113 if !source.exists() { 114 bail!("Source '{}' does not exist.", source.display()); 115 } 116 117 // Check for duplicate basenames which would cause conflicts in tar 118 if let Some(name) = source.file_name().and_then(|n| n.to_str()) { 119 if !seen_names.insert(name.to_string()) { 120 bail!( 121 "Duplicate source name '{}'. Multiple sources with the same basename are not supported.", 122 name 123 ); 124 } 125 } 126 127 sources.push(source); 128 } 129 130 let (user, connection, remote_path_str) = utils::parse_destination(destination)?; 131 let remote_path = Path::new(&remote_path_str); 132 133 // --- 2. Connect and Authenticate --- 134 // We create a span to time and group all authentication-related logs. 135 // The `user` and `host` are attached as structured fields to the span. 136 let auth_span = 137 tracing::info_span!("authentication", user = %user, host = %connection.hostname); 138 let _auth_guard = auth_span.enter(); // Enter the span 139 140 info!( 141 "Connecting to {}:{}...", 142 connection.hostname, connection.port 143 ); 144 let tcp = TcpStream::connect(format!("{}:{}", connection.hostname, connection.port)) 145 .with_context(|| { 146 format!( 147 "Failed to connect to {}:{}", 148 connection.hostname, connection.port 149 ) 150 })?; 151 let mut sess = Session::new()?; 152 sess.set_tcp_stream(tcp); 153 sess.handshake()?; 154 155 authenticate(&mut sess, &user, &connection.hostname)?; 156 info!("Authentication successful."); 157 drop(_auth_guard); // Explicitly drop the guard to end the span here 158 159 // --- 3. Pre-flight Checks on Remote Host --- 160 let checks_span = tracing::info_span!("pre_flight_checks"); 161 let _checks_guard = checks_span.enter(); 162 163 info!("Performing checks on remote host..."); 164 // For multiple sources, we always treat the remote path as a directory 165 let has_multiple_sources = sources.len() > 1; 166 let has_directory = sources.iter().any(|s| s.is_dir()); 167 let treat_as_directory = has_multiple_sources || has_directory; 168 169 let final_dest_path = 170 determine_and_validate_remote_path(&sess, remote_path, treat_as_directory)?; 171 info!( 172 "Final remote destination will be: {}", 173 final_dest_path.display() 174 ); 175 drop(_checks_guard); 176 177 // --- 4. Data Transfer --- 178 let transfer_span = tracing::info_span!( 179 "transfer", 180 sources = ?sources.iter().map(|s| s.display().to_string()).collect::<Vec<_>>(), 181 dest = %final_dest_path.display(), 182 compressed = cli.zstd 183 ); 184 let _transfer_guard = transfer_span.enter(); 185 186 // Check if we're in a TTY environment 187 let is_tty = atty::is(atty::Stream::Stdout); 188 189 // Create progress tracking 190 let progress = Arc::new(TransferProgress::new()); 191 let mut display = if is_tty { 192 Some(ProgressDisplay::new().context("Failed to create progress display")?) 193 } else { 194 None 195 }; 196 197 // Calculate source information for all sources 198 let (total_files, total_bytes) = progress 199 .calculate_multiple_sources_info(&sources) 200 .context("Failed to calculate source information")?; 201 202 let source_name = if sources.len() == 1 { 203 sources[0] 204 .file_name() 205 .and_then(|n| n.to_str()) 206 .unwrap_or("source") 207 .to_string() 208 } else { 209 format!("{} sources", sources.len()) 210 }; 211 212 if let Some(ref mut display) = display { 213 if cli.zstd { 214 // For compressed transfers, use uncompressed size but show compression info 215 display.initialize(total_files, total_bytes, source_name.to_string()); 216 } else { 217 display.initialize(total_files, total_bytes, source_name.to_string()); 218 } 219 } else { 220 info!( 221 "Found {} files ({} bytes) to transfer", 222 total_files, total_bytes 223 ); 224 } 225 226 let compression_flag = if cli.zstd { "--zstd" } else { "" }; 227 let remote_cmd = format!( 228 "tar {} -xvf - -C '{}'", 229 compression_flag, 230 final_dest_path.to_string_lossy() 231 ); 232 233 debug!(command = %remote_cmd, "Executing remote command"); 234 let mut remote_channel = sess.channel_session()?; 235 remote_channel.exec(&remote_cmd)?; 236 237 let start_time = Instant::now(); 238 let _progress_clone = Arc::clone(&progress); 239 240 // Start progress update thread for TTY environments 241 let (completion_tx, completion_rx) = mpsc::channel(); 242 let progress_handle = if is_tty && display.is_some() { 243 let progress = Arc::clone(&progress); 244 let display = display.take().unwrap(); 245 246 Some(std::thread::spawn(move || { 247 let mut flushing_phase = false; 248 249 loop { 250 let snapshot = progress.get_progress(); 251 252 // Check for completion signal from main thread 253 if completion_rx.try_recv().is_ok() { 254 let final_snapshot = progress.get_progress(); 255 // Ensure the bar is at 100% on finish 256 display.main_bar.set_position(final_snapshot.total_bytes); 257 display.finish(&final_snapshot); 258 break; 259 } 260 261 // MODIFIED: Use uncompressed_bytes_processed to set the progress bar position. 262 // This correctly tracks progress against the original file sizes. 263 display 264 .main_bar 265 .set_position(snapshot.uncompressed_bytes_processed); 266 267 // Update display with appropriate message 268 if flushing_phase { 269 display.status_bar.set_message("Finishing transfer..."); 270 } else { 271 display.update(&snapshot); 272 273 // MODIFIED: Check if we're entering flushing phase based on uncompressed bytes. 274 if snapshot.uncompressed_bytes_processed >= snapshot.total_bytes 275 && snapshot.total_bytes > 0 276 { 277 flushing_phase = true; 278 } 279 } 280 281 std::thread::sleep(Duration::from_millis(100)); 282 } 283 })) 284 } else { 285 None 286 }; 287 288 info!("Starting transfer..."); 289 { 290 let mut remote_stdin = remote_channel.stream(0); 291 let progress_writer = ProgressWriter::new(&mut remote_stdin, Arc::clone(&progress)); 292 293 if cli.zstd { 294 let zstd_encoder = zstd::stream::Encoder::new(progress_writer, 3)?; 295 let mut tar_builder = ProgressTarBuilder::new(zstd_encoder, Arc::clone(&progress)); 296 297 // Add all sources to the tar archive 298 for source in &sources { 299 if source.is_dir() { 300 let dir_name = source 301 .file_name() 302 .and_then(|n| n.to_str()) 303 .unwrap_or("unnamed"); 304 // CORRECTED: Swapped arguments to match (path_in_tar, src_path) 305 tar_builder.append_dir_all(dir_name, source)?; 306 } else { 307 let file_name = source 308 .file_name() 309 .and_then(|n| n.to_str()) 310 .unwrap_or("unnamed"); 311 tar_builder.append_path_with_name(source, file_name)?; 312 } 313 } 314 315 let zstd_encoder = tar_builder.into_inner()?; 316 let mut zstd_encoder = zstd_encoder.finish()?; 317 zstd_encoder.flush()?; 318 } else { 319 let mut tar_builder = ProgressTarBuilder::new(progress_writer, Arc::clone(&progress)); 320 321 // Add all sources to the tar archive 322 for source in &sources { 323 if source.is_dir() { 324 let dir_name = source 325 .file_name() 326 .and_then(|n| n.to_str()) 327 .unwrap_or("unnamed"); 328 // CORRECTED: Swapped arguments to match (path_in_tar, src_path) 329 tar_builder.append_dir_all(dir_name, source)?; 330 } else { 331 let file_name = source 332 .file_name() 333 .and_then(|n| n.to_str()) 334 .unwrap_or("unnamed"); 335 tar_builder.append_path_with_name(source, file_name)?; 336 } 337 } 338 339 let mut writer = tar_builder.into_inner()?; 340 writer.flush()?; 341 } 342 } 343 344 remote_channel.send_eof()?; 345 346 // Read stdout and stderr 347 let mut stdout = String::new(); 348 let mut stderr = String::new(); 349 350 let _ = remote_channel.read_to_string(&mut stdout); 351 let _ = remote_channel.stderr().read_to_string(&mut stderr); 352 353 if !stdout.is_empty() { 354 debug!("Remote stdout: {}", stdout.trim()); 355 } 356 if !stderr.is_empty() { 357 debug!("Remote stderr: {}", stderr.trim()); 358 } 359 360 // Ensure we properly close the channel - this is where actual transfer completion happens 361 remote_channel.wait_eof()?; 362 remote_channel.wait_close()?; 363 let exit_code = remote_channel.exit_status()?; 364 if exit_code != 0 { 365 let mut stderr = String::new(); 366 remote_channel.stderr().read_to_string(&mut stderr)?; 367 bail!( 368 "Remote command failed with exit code {}.\nStderr: {}", 369 exit_code, 370 stderr 371 ); 372 } 373 374 // Wait for actual completion - this ensures all buffered data is transmitted 375 let final_progress = progress.get_progress(); 376 if let Some(handle) = progress_handle { 377 // Keep progress active until remote command completes 378 completion_tx.send(()).unwrap(); 379 handle.join().unwrap(); 380 } else { 381 // Non-TTY completion message 382 info!( 383 "Transfer completed in {:.1}s", 384 start_time.elapsed().as_secs_f64() 385 ); 386 info!( 387 "Transferred {} files ({} bytes)", 388 final_progress.total_files, final_progress.bytes_transferred 389 ); 390 } 391 392 info!("Successfully transferred {} sources.", sources.len()); 393 Ok(()) 394 } 395 396 fn generate_completions(shell: Shell) -> Result<()> { 397 let mut cmd = Cli::command(); 398 399 match shell { 400 Shell::Bash => generate(Bash, &mut cmd, "bale", &mut std::io::stdout()), 401 Shell::Elvish => generate(Elvish, &mut cmd, "bale", &mut std::io::stdout()), 402 Shell::Fish => generate(Fish, &mut cmd, "bale", &mut std::io::stdout()), 403 Shell::PowerShell => generate(PowerShell, &mut cmd, "bale", &mut std::io::stdout()), 404 Shell::Zsh => generate(Zsh, &mut cmd, "bale", &mut std::io::stdout()), 405 } 406 407 Ok(()) 408 } 409 410 // Helper functions remain largely the same, just with `tracing` macros 411 #[instrument(skip(sess))] 412 fn authenticate(sess: &mut Session, user: &str, host: &str) -> Result<()> { 413 // Try SSH agent authentication first with better error handling 414 info!("Attempting SSH agent authentication..."); 415 416 match sess.userauth_agent(user) { 417 Ok(_) => { 418 if sess.authenticated() { 419 info!("Authenticated using SSH agent."); 420 return Ok(()); 421 } else { 422 warn!("SSH agent authentication attempted but not authenticated."); 423 } 424 } 425 Err(e) => { 426 debug!("SSH agent authentication error: {:?}", e); 427 warn!("SSH agent authentication failed or not available."); 428 } 429 } 430 431 // Try public key authentication from default SSH keys 432 info!("Attempting public key authentication..."); 433 let home_dir = std::env::var("HOME").unwrap_or_else(|_| "/home".to_string()); 434 let key_paths = [ 435 format!("{}/.ssh/id_rsa", home_dir), 436 format!("{}/.ssh/id_ed25519", home_dir), 437 format!("{}/.ssh/id_ecdsa", home_dir), 438 ]; 439 440 for key_path in &key_paths { 441 let key_path = Path::new(key_path); 442 if key_path.exists() { 443 debug!("Trying key: {}", key_path.display()); 444 if sess 445 .userauth_pubkey_file(user, None, key_path, None) 446 .is_ok() 447 && sess.authenticated() 448 { 449 info!("Authenticated using public key: {}", key_path.display()); 450 return Ok(()); 451 } 452 } 453 } 454 455 // Fall back to password authentication 456 warn!("Public key authentication failed, falling back to password..."); 457 let password = rpassword::prompt_password(format!("Password for {}@{}: ", user, host))?; 458 459 match sess.userauth_password(user, &password) { 460 Ok(_) => { 461 if !sess.authenticated() { 462 bail!("Password authentication failed after attempting."); 463 } 464 info!("Authenticated using password."); 465 Ok(()) 466 } 467 Err(e) => { 468 // Check if the error is specifically due to authentication failure 469 // This might indicate password authentication is disabled on the server 470 if e.code() == ssh2::ErrorCode::Session(-18) { 471 // LIBSSH2_ERROR_AUTHENTICATION_FAILED 472 bail!( 473 "Password authentication failed. This might be because password authentication is disabled on the remote server. Please ensure SSH agent is configured correctly or enable password authentication on the server." 474 ); 475 } 476 // Re-propagate other errors 477 Err(e.into()) 478 } 479 } 480 } 481 482 /// Checks the remote path and determines the final destination directory. 483 /// `treat_as_directory` indicates whether to treat the destination as a directory 484 /// (true for multiple sources or when transferring directories). 485 fn determine_and_validate_remote_path( 486 sess: &Session, 487 path: &Path, 488 treat_as_directory: bool, 489 ) -> Result<PathBuf> { 490 let path_str = path.to_string_lossy(); 491 492 // Case 1: Remote path exists 493 if run_remote_test(sess, &format!("test -e '{}'", path_str))? { 494 if run_remote_test(sess, &format!("test -d '{}'", path_str))? { 495 // Remote path exists and is a directory - extract directly into it 496 info!( 497 "Remote path '{}' exists and is a directory. Will extract into it.", 498 path_str 499 ); 500 return Ok(path.to_path_buf()); 501 } else { 502 // Remote path exists but is not a directory (it's a file). 503 if treat_as_directory { 504 bail!( 505 "Remote path '{}' exists but is not a directory, cannot extract directory contents into file.", 506 path_str 507 ); 508 } else { 509 // For file transfer, we can overwrite the existing file 510 info!("Remote path '{}' exists as file, will overwrite.", path_str); 511 return Ok(path.to_path_buf()); 512 } 513 } 514 } 515 516 // Case 2: Remote path does not exist. Check if it ends with '/' (explicit directory) 517 let is_explicit_directory = path_str.ends_with('/'); 518 519 if is_explicit_directory { 520 // Explicit directory requested, create it 521 info!("Creating remote directory '{}' as requested.", path_str); 522 run_remote_cmd(sess, &format!("mkdir -p '{}'", path_str))?; 523 return Ok(path.to_path_buf()); 524 } 525 526 // Case 3: Remote path does not exist and no trailing slash 527 // Check if parent exists 528 if let Some(parent) = path.parent() { 529 let parent_str = parent.to_string_lossy(); 530 if parent_str.is_empty() { 531 if !run_remote_test(sess, "test -d /")? { 532 bail!("Remote root directory '/' does not exist or is not a directory."); 533 } 534 } else if !run_remote_test(sess, &format!("test -d '{}'", parent_str))? { 535 bail!( 536 "Remote parent path '{}' does not exist or is not a directory.", 537 parent.display() 538 ); 539 } 540 541 if treat_as_directory { 542 // Directory transfer: create the path as a directory 543 info!( 544 "Creating remote directory '{}' for directory transfer.", 545 path_str 546 ); 547 run_remote_cmd(sess, &format!("mkdir -p '{}'", path_str))?; 548 Ok(path.to_path_buf()) 549 } else { 550 // File transfer: the path is the target filename 551 info!("Remote path '{}' will be created as file.", path_str); 552 Ok(path.to_path_buf()) 553 } 554 } else { 555 bail!( 556 "Could not determine parent of remote path '{}'.", 557 path.display() 558 ); 559 } 560 } 561 562 /// Helper to run a command and return true if it exits with status 0. 563 fn run_remote_test(sess: &Session, cmd: &str) -> Result<bool> { 564 let mut channel = sess.channel_session()?; 565 channel.exec(cmd)?; 566 channel.wait_eof()?; 567 channel.wait_close()?; 568 Ok(channel.exit_status()? == 0) 569 } 570 571 /// Helper to run a command and bail if it fails. 572 fn run_remote_cmd(sess: &Session, cmd: &str) -> Result<()> { 573 let mut channel = sess.channel_session()?; 574 channel.exec(cmd)?; 575 channel.wait_eof()?; 576 channel.wait_close()?; 577 let exit_code = channel.exit_status()?; 578 if exit_code != 0 { 579 let mut stderr = String::new(); 580 channel.stderr().read_to_string(&mut stderr)?; 581 bail!( 582 "Remote command '{}' failed with exit code {}: {}", 583 cmd, 584 exit_code, 585 stderr.trim() 586 ); 587 } 588 Ok(()) 589 }