/ app / src / main.rs
main.rs
  1  use anyhow::{Context, Result, bail};
  2  use clap::{CommandFactory, Parser, Subcommand};
  3  use clap_complete::{
  4      generate,
  5      shells::{Bash, Elvish, Fish, PowerShell, Zsh},
  6  };
  7  
  8  use ssh2::Session;
  9  use std::io::{Read, Write};
 10  use std::net::TcpStream;
 11  use std::path::{Path, PathBuf};
 12  use std::sync::Arc;
 13  use std::sync::mpsc;
 14  use std::time::{Duration, Instant};
 15  use tracing::{debug, info, instrument, warn}; // Import tracing macros
 16  
 17  mod progress;
 18  mod progress_writer;
 19  mod utils;
 20  
 21  use progress::{ProgressDisplay, TransferProgress};
 22  use progress_writer::{ProgressTarBuilder, ProgressWriter};
 23  
// Command-line interface of the tool. Note: clap's derive turns the `///` doc comments in
// this struct into `--help` text, so they are user-facing output, not just documentation.
#[derive(Parser, Debug)]
#[command(
    author,
    version,
    about = "A tool to efficiently transfer files and directories over SSH using a tar stream."
)]
struct Cli {
    // Optional subcommand. When present, `run` handles it and returns without transferring.
    #[command(subcommand)]
    command: Option<Commands>,

    /// Enable Zstandard compression for the tar stream.
    #[arg(short = 'z', long)]
    zstd: bool,

    /// The local source files or directories to transfer, followed by the remote destination.
    /// Format: [SOURCES]... <DESTINATION>
    /// Where DESTINATION is 'user@host:/path/to/dest'
    // `run` treats the last element as the destination and everything before it as sources,
    // and rejects the invocation if fewer than two values are given.
    args: Vec<String>,
}
 43  
// Subcommands that replace the default transfer behaviour. As with `Cli`, the `///` doc
// comments here become clap help text.
#[derive(Subcommand, Debug)]
enum Commands {
    /// Generate shell completions
    // Handled in `run`, which prints the script via `generate_completions` and returns.
    Completions {
        /// The shell to generate completions for
        #[arg(value_enum)]
        shell: Shell,
    },
}
 53  
 54  #[derive(clap::ValueEnum, Clone, Debug)]
 55  enum Shell {
 56      Bash,
 57      Elvish,
 58      Fish,
 59      PowerShell,
 60      Zsh,
 61  }
 62  
 63  fn main() {
 64      // --- INITIALIZATION ---
 65      // Set up the tracing subscriber.
 66      // This will read the RUST_LOG environment variable to determine the log level.
 67      // Example: `RUST_LOG=info` or `RUST_LOG=bale=debug,ssh2=warn`
 68      tracing_subscriber::fmt()
 69          .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
 70          .with_writer(std::io::stderr) // Log to stderr
 71          .init();
 72  
 73      if let Err(e) = run() {
 74          // Use tracing::error for the final error output
 75          tracing::error!("Application failed: {:?}", e);
 76          std::process::exit(1);
 77      }
 78  }
 79  
 80  // The `instrument` attribute automatically creates a span for this function.
 81  // Every log event inside `run` will be associated with this span.
 82  #[instrument]
 83  fn run() -> Result<()> {
 84      let cli = Cli::parse();
 85  
 86      // Handle subcommands
 87      if let Some(command) = cli.command {
 88          match command {
 89              Commands::Completions { shell } => {
 90                  generate_completions(shell)?;
 91                  return Ok(());
 92              }
 93          }
 94      }
 95  
 96      // --- 1. Parse Arguments & Validate Local Paths ---
 97      if cli.args.len() < 2 {
 98          bail!("At least one source and a destination must be provided");
 99      }
100  
101      let destination = cli.args.last().unwrap();
102      let sources_str = &cli.args[..cli.args.len() - 1];
103  
104      if sources_str.is_empty() {
105          bail!("At least one source must be provided");
106      }
107  
108      let mut sources = Vec::new();
109      let mut seen_names = std::collections::HashSet::new();
110  
111      for source_str in sources_str {
112          let source = PathBuf::from(source_str);
113          if !source.exists() {
114              bail!("Source '{}' does not exist.", source.display());
115          }
116  
117          // Check for duplicate basenames which would cause conflicts in tar
118          if let Some(name) = source.file_name().and_then(|n| n.to_str()) {
119              if !seen_names.insert(name.to_string()) {
120                  bail!(
121                      "Duplicate source name '{}'. Multiple sources with the same basename are not supported.",
122                      name
123                  );
124              }
125          }
126  
127          sources.push(source);
128      }
129  
130      let (user, connection, remote_path_str) = utils::parse_destination(destination)?;
131      let remote_path = Path::new(&remote_path_str);
132  
133      // --- 2. Connect and Authenticate ---
134      // We create a span to time and group all authentication-related logs.
135      // The `user` and `host` are attached as structured fields to the span.
136      let auth_span =
137          tracing::info_span!("authentication", user = %user, host = %connection.hostname);
138      let _auth_guard = auth_span.enter(); // Enter the span
139  
140      info!(
141          "Connecting to {}:{}...",
142          connection.hostname, connection.port
143      );
144      let tcp = TcpStream::connect(format!("{}:{}", connection.hostname, connection.port))
145          .with_context(|| {
146              format!(
147                  "Failed to connect to {}:{}",
148                  connection.hostname, connection.port
149              )
150          })?;
151      let mut sess = Session::new()?;
152      sess.set_tcp_stream(tcp);
153      sess.handshake()?;
154  
155      authenticate(&mut sess, &user, &connection.hostname)?;
156      info!("Authentication successful.");
157      drop(_auth_guard); // Explicitly drop the guard to end the span here
158  
159      // --- 3. Pre-flight Checks on Remote Host ---
160      let checks_span = tracing::info_span!("pre_flight_checks");
161      let _checks_guard = checks_span.enter();
162  
163      info!("Performing checks on remote host...");
164      // For multiple sources, we always treat the remote path as a directory
165      let has_multiple_sources = sources.len() > 1;
166      let has_directory = sources.iter().any(|s| s.is_dir());
167      let treat_as_directory = has_multiple_sources || has_directory;
168  
169      let final_dest_path =
170          determine_and_validate_remote_path(&sess, remote_path, treat_as_directory)?;
171      info!(
172          "Final remote destination will be: {}",
173          final_dest_path.display()
174      );
175      drop(_checks_guard);
176  
177      // --- 4. Data Transfer ---
178      let transfer_span = tracing::info_span!(
179          "transfer",
180          sources = ?sources.iter().map(|s| s.display().to_string()).collect::<Vec<_>>(),
181          dest = %final_dest_path.display(),
182          compressed = cli.zstd
183      );
184      let _transfer_guard = transfer_span.enter();
185  
186      // Check if we're in a TTY environment
187      let is_tty = atty::is(atty::Stream::Stdout);
188  
189      // Create progress tracking
190      let progress = Arc::new(TransferProgress::new());
191      let mut display = if is_tty {
192          Some(ProgressDisplay::new().context("Failed to create progress display")?)
193      } else {
194          None
195      };
196  
197      // Calculate source information for all sources
198      let (total_files, total_bytes) = progress
199          .calculate_multiple_sources_info(&sources)
200          .context("Failed to calculate source information")?;
201  
202      let source_name = if sources.len() == 1 {
203          sources[0]
204              .file_name()
205              .and_then(|n| n.to_str())
206              .unwrap_or("source")
207              .to_string()
208      } else {
209          format!("{} sources", sources.len())
210      };
211  
212      if let Some(ref mut display) = display {
213          if cli.zstd {
214              // For compressed transfers, use uncompressed size but show compression info
215              display.initialize(total_files, total_bytes, source_name.to_string());
216          } else {
217              display.initialize(total_files, total_bytes, source_name.to_string());
218          }
219      } else {
220          info!(
221              "Found {} files ({} bytes) to transfer",
222              total_files, total_bytes
223          );
224      }
225  
226      let compression_flag = if cli.zstd { "--zstd" } else { "" };
227      let remote_cmd = format!(
228          "tar {} -xvf - -C '{}'",
229          compression_flag,
230          final_dest_path.to_string_lossy()
231      );
232  
233      debug!(command = %remote_cmd, "Executing remote command");
234      let mut remote_channel = sess.channel_session()?;
235      remote_channel.exec(&remote_cmd)?;
236  
237      let start_time = Instant::now();
238      let _progress_clone = Arc::clone(&progress);
239  
240      // Start progress update thread for TTY environments
241      let (completion_tx, completion_rx) = mpsc::channel();
242      let progress_handle = if is_tty && display.is_some() {
243          let progress = Arc::clone(&progress);
244          let display = display.take().unwrap();
245  
246          Some(std::thread::spawn(move || {
247              let mut flushing_phase = false;
248  
249              loop {
250                  let snapshot = progress.get_progress();
251  
252                  // Check for completion signal from main thread
253                  if completion_rx.try_recv().is_ok() {
254                      let final_snapshot = progress.get_progress();
255                      // Ensure the bar is at 100% on finish
256                      display.main_bar.set_position(final_snapshot.total_bytes);
257                      display.finish(&final_snapshot);
258                      break;
259                  }
260  
261                  // MODIFIED: Use uncompressed_bytes_processed to set the progress bar position.
262                  // This correctly tracks progress against the original file sizes.
263                  display
264                      .main_bar
265                      .set_position(snapshot.uncompressed_bytes_processed);
266  
267                  // Update display with appropriate message
268                  if flushing_phase {
269                      display.status_bar.set_message("Finishing transfer...");
270                  } else {
271                      display.update(&snapshot);
272  
273                      // MODIFIED: Check if we're entering flushing phase based on uncompressed bytes.
274                      if snapshot.uncompressed_bytes_processed >= snapshot.total_bytes
275                          && snapshot.total_bytes > 0
276                      {
277                          flushing_phase = true;
278                      }
279                  }
280  
281                  std::thread::sleep(Duration::from_millis(100));
282              }
283          }))
284      } else {
285          None
286      };
287  
288      info!("Starting transfer...");
289      {
290          let mut remote_stdin = remote_channel.stream(0);
291          let progress_writer = ProgressWriter::new(&mut remote_stdin, Arc::clone(&progress));
292  
293          if cli.zstd {
294              let zstd_encoder = zstd::stream::Encoder::new(progress_writer, 3)?;
295              let mut tar_builder = ProgressTarBuilder::new(zstd_encoder, Arc::clone(&progress));
296  
297              // Add all sources to the tar archive
298              for source in &sources {
299                  if source.is_dir() {
300                      let dir_name = source
301                          .file_name()
302                          .and_then(|n| n.to_str())
303                          .unwrap_or("unnamed");
304                      // CORRECTED: Swapped arguments to match (path_in_tar, src_path)
305                      tar_builder.append_dir_all(dir_name, source)?;
306                  } else {
307                      let file_name = source
308                          .file_name()
309                          .and_then(|n| n.to_str())
310                          .unwrap_or("unnamed");
311                      tar_builder.append_path_with_name(source, file_name)?;
312                  }
313              }
314  
315              let zstd_encoder = tar_builder.into_inner()?;
316              let mut zstd_encoder = zstd_encoder.finish()?;
317              zstd_encoder.flush()?;
318          } else {
319              let mut tar_builder = ProgressTarBuilder::new(progress_writer, Arc::clone(&progress));
320  
321              // Add all sources to the tar archive
322              for source in &sources {
323                  if source.is_dir() {
324                      let dir_name = source
325                          .file_name()
326                          .and_then(|n| n.to_str())
327                          .unwrap_or("unnamed");
328                      // CORRECTED: Swapped arguments to match (path_in_tar, src_path)
329                      tar_builder.append_dir_all(dir_name, source)?;
330                  } else {
331                      let file_name = source
332                          .file_name()
333                          .and_then(|n| n.to_str())
334                          .unwrap_or("unnamed");
335                      tar_builder.append_path_with_name(source, file_name)?;
336                  }
337              }
338  
339              let mut writer = tar_builder.into_inner()?;
340              writer.flush()?;
341          }
342      }
343  
344      remote_channel.send_eof()?;
345  
346      // Read stdout and stderr
347      let mut stdout = String::new();
348      let mut stderr = String::new();
349  
350      let _ = remote_channel.read_to_string(&mut stdout);
351      let _ = remote_channel.stderr().read_to_string(&mut stderr);
352  
353      if !stdout.is_empty() {
354          debug!("Remote stdout: {}", stdout.trim());
355      }
356      if !stderr.is_empty() {
357          debug!("Remote stderr: {}", stderr.trim());
358      }
359  
360      // Ensure we properly close the channel - this is where actual transfer completion happens
361      remote_channel.wait_eof()?;
362      remote_channel.wait_close()?;
363      let exit_code = remote_channel.exit_status()?;
364      if exit_code != 0 {
365          let mut stderr = String::new();
366          remote_channel.stderr().read_to_string(&mut stderr)?;
367          bail!(
368              "Remote command failed with exit code {}.\nStderr: {}",
369              exit_code,
370              stderr
371          );
372      }
373  
374      // Wait for actual completion - this ensures all buffered data is transmitted
375      let final_progress = progress.get_progress();
376      if let Some(handle) = progress_handle {
377          // Keep progress active until remote command completes
378          completion_tx.send(()).unwrap();
379          handle.join().unwrap();
380      } else {
381          // Non-TTY completion message
382          info!(
383              "Transfer completed in {:.1}s",
384              start_time.elapsed().as_secs_f64()
385          );
386          info!(
387              "Transferred {} files ({} bytes)",
388              final_progress.total_files, final_progress.bytes_transferred
389          );
390      }
391  
392      info!("Successfully transferred {} sources.", sources.len());
393      Ok(())
394  }
395  
396  fn generate_completions(shell: Shell) -> Result<()> {
397      let mut cmd = Cli::command();
398  
399      match shell {
400          Shell::Bash => generate(Bash, &mut cmd, "bale", &mut std::io::stdout()),
401          Shell::Elvish => generate(Elvish, &mut cmd, "bale", &mut std::io::stdout()),
402          Shell::Fish => generate(Fish, &mut cmd, "bale", &mut std::io::stdout()),
403          Shell::PowerShell => generate(PowerShell, &mut cmd, "bale", &mut std::io::stdout()),
404          Shell::Zsh => generate(Zsh, &mut cmd, "bale", &mut std::io::stdout()),
405      }
406  
407      Ok(())
408  }
409  
410  // Helper functions remain largely the same, just with `tracing` macros
411  #[instrument(skip(sess))]
412  fn authenticate(sess: &mut Session, user: &str, host: &str) -> Result<()> {
413      // Try SSH agent authentication first with better error handling
414      info!("Attempting SSH agent authentication...");
415  
416      match sess.userauth_agent(user) {
417          Ok(_) => {
418              if sess.authenticated() {
419                  info!("Authenticated using SSH agent.");
420                  return Ok(());
421              } else {
422                  warn!("SSH agent authentication attempted but not authenticated.");
423              }
424          }
425          Err(e) => {
426              debug!("SSH agent authentication error: {:?}", e);
427              warn!("SSH agent authentication failed or not available.");
428          }
429      }
430  
431      // Try public key authentication from default SSH keys
432      info!("Attempting public key authentication...");
433      let home_dir = std::env::var("HOME").unwrap_or_else(|_| "/home".to_string());
434      let key_paths = [
435          format!("{}/.ssh/id_rsa", home_dir),
436          format!("{}/.ssh/id_ed25519", home_dir),
437          format!("{}/.ssh/id_ecdsa", home_dir),
438      ];
439  
440      for key_path in &key_paths {
441          let key_path = Path::new(key_path);
442          if key_path.exists() {
443              debug!("Trying key: {}", key_path.display());
444              if sess
445                  .userauth_pubkey_file(user, None, key_path, None)
446                  .is_ok()
447                  && sess.authenticated()
448              {
449                  info!("Authenticated using public key: {}", key_path.display());
450                  return Ok(());
451              }
452          }
453      }
454  
455      // Fall back to password authentication
456      warn!("Public key authentication failed, falling back to password...");
457      let password = rpassword::prompt_password(format!("Password for {}@{}: ", user, host))?;
458  
459      match sess.userauth_password(user, &password) {
460          Ok(_) => {
461              if !sess.authenticated() {
462                  bail!("Password authentication failed after attempting.");
463              }
464              info!("Authenticated using password.");
465              Ok(())
466          }
467          Err(e) => {
468              // Check if the error is specifically due to authentication failure
469              // This might indicate password authentication is disabled on the server
470              if e.code() == ssh2::ErrorCode::Session(-18) {
471                  // LIBSSH2_ERROR_AUTHENTICATION_FAILED
472                  bail!(
473                      "Password authentication failed. This might be because password authentication is disabled on the remote server. Please ensure SSH agent is configured correctly or enable password authentication on the server."
474                  );
475              }
476              // Re-propagate other errors
477              Err(e.into())
478          }
479      }
480  }
481  
482  /// Checks the remote path and determines the final destination directory.
483  /// `treat_as_directory` indicates whether to treat the destination as a directory
484  /// (true for multiple sources or when transferring directories).
485  fn determine_and_validate_remote_path(
486      sess: &Session,
487      path: &Path,
488      treat_as_directory: bool,
489  ) -> Result<PathBuf> {
490      let path_str = path.to_string_lossy();
491  
492      // Case 1: Remote path exists
493      if run_remote_test(sess, &format!("test -e '{}'", path_str))? {
494          if run_remote_test(sess, &format!("test -d '{}'", path_str))? {
495              // Remote path exists and is a directory - extract directly into it
496              info!(
497                  "Remote path '{}' exists and is a directory. Will extract into it.",
498                  path_str
499              );
500              return Ok(path.to_path_buf());
501          } else {
502              // Remote path exists but is not a directory (it's a file).
503              if treat_as_directory {
504                  bail!(
505                      "Remote path '{}' exists but is not a directory, cannot extract directory contents into file.",
506                      path_str
507                  );
508              } else {
509                  // For file transfer, we can overwrite the existing file
510                  info!("Remote path '{}' exists as file, will overwrite.", path_str);
511                  return Ok(path.to_path_buf());
512              }
513          }
514      }
515  
516      // Case 2: Remote path does not exist. Check if it ends with '/' (explicit directory)
517      let is_explicit_directory = path_str.ends_with('/');
518  
519      if is_explicit_directory {
520          // Explicit directory requested, create it
521          info!("Creating remote directory '{}' as requested.", path_str);
522          run_remote_cmd(sess, &format!("mkdir -p '{}'", path_str))?;
523          return Ok(path.to_path_buf());
524      }
525  
526      // Case 3: Remote path does not exist and no trailing slash
527      // Check if parent exists
528      if let Some(parent) = path.parent() {
529          let parent_str = parent.to_string_lossy();
530          if parent_str.is_empty() {
531              if !run_remote_test(sess, "test -d /")? {
532                  bail!("Remote root directory '/' does not exist or is not a directory.");
533              }
534          } else if !run_remote_test(sess, &format!("test -d '{}'", parent_str))? {
535              bail!(
536                  "Remote parent path '{}' does not exist or is not a directory.",
537                  parent.display()
538              );
539          }
540  
541          if treat_as_directory {
542              // Directory transfer: create the path as a directory
543              info!(
544                  "Creating remote directory '{}' for directory transfer.",
545                  path_str
546              );
547              run_remote_cmd(sess, &format!("mkdir -p '{}'", path_str))?;
548              Ok(path.to_path_buf())
549          } else {
550              // File transfer: the path is the target filename
551              info!("Remote path '{}' will be created as file.", path_str);
552              Ok(path.to_path_buf())
553          }
554      } else {
555          bail!(
556              "Could not determine parent of remote path '{}'.",
557              path.display()
558          );
559      }
560  }
561  
562  /// Helper to run a command and return true if it exits with status 0.
563  fn run_remote_test(sess: &Session, cmd: &str) -> Result<bool> {
564      let mut channel = sess.channel_session()?;
565      channel.exec(cmd)?;
566      channel.wait_eof()?;
567      channel.wait_close()?;
568      Ok(channel.exit_status()? == 0)
569  }
570  
571  /// Helper to run a command and bail if it fails.
572  fn run_remote_cmd(sess: &Session, cmd: &str) -> Result<()> {
573      let mut channel = sess.channel_session()?;
574      channel.exec(cmd)?;
575      channel.wait_eof()?;
576      channel.wait_close()?;
577      let exit_code = channel.exit_status()?;
578      if exit_code != 0 {
579          let mut stderr = String::new();
580          channel.stderr().read_to_string(&mut stderr)?;
581          bail!(
582              "Remote command '{}' failed with exit code {}: {}",
583              cmd,
584              exit_code,
585              stderr.trim()
586          );
587      }
588      Ok(())
589  }