transport.rs
1 use super::channel::{Channel, Pipe}; 2 use super::codec::{ObjectHasher, ObjectWriter}; 3 use super::error::{Error, ProtocolError}; 4 use super::types::{self, AuthMethod, Behavior, Request, SecretKey, TransportError}; 5 use super::wire; 6 7 use chacha20::cipher::{KeyInit, KeyIvInit, StreamCipher, StreamCipherSeek}; 8 use chacha20::ChaCha20Legacy; 9 use constant_time_eq::constant_time_eq; 10 use core::ops::Range; 11 use ed25519_dalek::{Signature, Signer, Verifier, VerifyingKey}; 12 use embedded_io_async::{Read, Write}; 13 use poly1305::Poly1305; 14 use rand::RngCore; 15 use sha2::{Digest, Sha256}; 16 use x25519_dalek::{EphemeralSecret, PublicKey}; 17 18 const KEXINIT_KEX_ALGORITHM: &str = "curve25519-sha256"; 19 const KEXINIT_STRICT_KEX_CLIENT: &str = "kex-strict-c-v00@openssh.com"; 20 const KEXINIT_KEX: &str = "curve25519-sha256,kex-strict-s-v00@openssh.com"; 21 const KEXINIT_HOST_KEY: &str = "ssh-ed25519"; 22 const KEXINIT_ENCRYPTION: &str = "chacha20-poly1305@openssh.com"; 23 // TODO: this violates RFC4253 but seems to be most compatible 24 const KEXINIT_MAC: &str = ""; 25 const KEXINIT_COMPRESSION: &str = "none"; 26 27 struct KexState { 28 discard_guessed: bool, 29 exchange_hash_hasher: ObjectHasher<Sha256>, 30 } 31 32 struct PendingKeys { 33 prefix_hash: ObjectHasher<Sha256>, 34 session_id: [u8; 32], 35 } 36 37 struct KeyMaterial { 38 client_head_key: [u8; 32], 39 client_main_key: [u8; 32], 40 server_head_key: [u8; 32], 41 server_main_key: [u8; 32], 42 session_id: [u8; 32], 43 } 44 45 #[derive(Clone, Copy, Debug)] 46 struct PendingChannel { 47 sender_channel: u32, 48 initial_window_size: u32, 49 maximum_packet_size: u32, 50 } 51 52 #[derive(Clone, Copy, Debug)] 53 enum HalfState { 54 Window(u32), 55 Eof, 56 Close, 57 } 58 59 impl HalfState { 60 pub fn increase_window(&mut self, amount: u32) -> Result<(), ProtocolError> { 61 if let HalfState::Window(value) = self { 62 *value = value 63 .checked_add(amount) 64 .ok_or(ProtocolError::WindowOverflow)?; 65 } 66 67 
Ok(()) 68 } 69 70 pub fn decrease_window(&mut self, amount: u32) -> Result<(), ProtocolError> { 71 if let HalfState::Window(value) = self { 72 *value = value 73 .checked_sub(amount) 74 .ok_or(ProtocolError::WindowOverflow)?; 75 } 76 77 Ok(()) 78 } 79 } 80 81 #[derive(Debug)] 82 struct ChannelState { 83 rx_channel_id: u32, 84 tx_channel_id: u32, 85 tx_max_packet: u32, 86 rx_half: HalfState, 87 tx_half: HalfState, 88 rx_committed: bool, 89 } 90 91 /// Implementation of an SSH server's transport layer. 92 pub struct Transport<'a, T: Behavior> { 93 buffer: &'a mut [u8], 94 95 behavior: T, 96 97 client_ssh_id_buffer: [u8; 255], 98 client_ssh_id_length: usize, 99 100 kex: Option<KexState>, 101 strict_kex: bool, 102 103 next_keys: Option<PendingKeys>, 104 curr_keys: Option<KeyMaterial>, 105 106 client_sequence_number: u32, 107 server_sequence_number: u32, 108 109 userauth_enabled: bool, 110 authenticated: bool, 111 112 request: Option<Request<T::Command>>, 113 current_user: Option<T::User>, 114 115 channels: [Option<PendingChannel>; 4], 116 active_channel: Option<ChannelState>, 117 } 118 119 impl<'a, T: Behavior> Transport<'a, T> { 120 /// Creates a new transport from a packet buffer and behavior. 
121 pub fn new(buffer: &'a mut [u8], behavior: T) -> Self { 122 assert!(buffer.len() >= 512, "packet buffer too small"); 123 124 Self { 125 buffer, 126 behavior, 127 128 client_ssh_id_buffer: [0u8; 255], 129 client_ssh_id_length: 0, 130 131 kex: None, 132 strict_kex: false, 133 134 next_keys: None, 135 curr_keys: None, 136 137 client_sequence_number: u32::MAX, 138 server_sequence_number: u32::MAX, 139 140 userauth_enabled: false, 141 authenticated: false, 142 143 request: None, 144 current_user: None, 145 146 channels: [None; 4], 147 active_channel: None, 148 } 149 } 150 151 pub(crate) fn client_ssh_id_string(&self) -> &str { 152 let slice = &self.client_ssh_id_buffer[..self.client_ssh_id_length]; 153 154 super::unwrap_unreachable(core::str::from_utf8(slice).ok()) 155 } 156 157 async fn perform_handshake(&mut self) -> Result<(), TransportError<T>> { 158 let ssh_str = self.behavior.server_id().as_bytes(); 159 assert!(ssh_str.len() <= 253); // required by spec 160 161 self.behavior.stream().write_all(ssh_str).await?; 162 self.behavior.stream().write_all(b"\r\n").await?; 163 164 // The client is not allowed to send arbitrary lines prior to sending its 165 // identification string (the server can, but we don't). The parser below 166 // checks that the input is well-formed according to RFC4253 section 4.2. 
167 168 for i in 0..255 { 169 self.behavior 170 .stream() 171 .read(&mut self.client_ssh_id_buffer[i..i + 1]) 172 .await?; 173 174 let curr = self.client_ssh_id_buffer[i]; 175 176 if !matches!(curr, b'\r' | b'\n' | 0x20..=0x7E) { 177 Err(ProtocolError::BadIdentificationString)?; 178 } 179 180 if i == 0 { 181 continue; 182 } 183 184 let prev = self.client_ssh_id_buffer[i - 1]; 185 186 if (prev, curr) == (b'\r', b'\n') { 187 self.client_ssh_id_length = i - 1; 188 break; 189 } 190 } 191 192 if !self.client_ssh_id_string().starts_with("SSH-2.0-") { 193 Err(ProtocolError::BadIdentificationString)?; 194 } 195 196 Ok(()) 197 } 198 199 /// Accepts the next client request as a channel. 200 pub async fn accept(&mut self) -> Result<Channel<'a, '_, T>, TransportError<T>> { 201 assert!(self.request.is_none(), "channel request already active"); 202 203 loop { 204 if self.request.is_some() { 205 return Ok(Channel::new(self)); 206 } else { 207 self.poll_client().await?; 208 } 209 } 210 } 211 212 pub(crate) fn channel_request(&self) -> Request<T::Command> { 213 super::unwrap_unreachable(self.request.clone()) 214 } 215 216 pub(crate) fn channel_user(&self) -> T::User { 217 super::unwrap_unreachable(self.current_user.clone()) 218 } 219 220 pub(crate) fn channel_data_payload_buffer(&mut self, pipe: Pipe) -> &mut [u8] { 221 let max_packet_size = wire::from_u32(self.channel_state().tx_max_packet); 222 223 // This is a little bit of a hack, we make an assumption on the representation of the 224 // two ChannelData and ChannelExtendedData messages to retrieve the offset within our 225 // packet buffer at which the caller may write its data so it could be sent in-place. 
226 227 let payload_offset = match pipe { 228 Pipe::Stdout => 9, 229 Pipe::Stderr => 13, 230 }; 231 232 let payload_range = self.payload_range_full(); 233 234 let slice = &mut self.buffer[payload_range][payload_offset..]; 235 236 if slice.len() > max_packet_size { 237 &mut slice[..max_packet_size] 238 } else { 239 slice 240 } 241 } 242 243 fn maximum_channel_data_packet_size(&mut self) -> u32 { 244 let range = self.payload_range_full(); 245 246 // The same principle applies to "stdin" ChannelData messages, we can compute the 247 // largest data packet size we can receive without overflowing our packet buffer. 248 249 wire::into_u32(range.end - range.start - 9) 250 } 251 252 pub(crate) async fn channel_adjust( 253 &mut self, 254 amount: Option<u32>, 255 ) -> Result<(), TransportError<T>> { 256 if self.channel_state().rx_committed { 257 panic!("can no longer read from channel!"); 258 } 259 260 if amount.is_none() { 261 self.channel_state().rx_committed = true; 262 } 263 264 let amount = amount.unwrap_or(u32::MAX); 265 assert!(amount != 0, "window is empty"); 266 267 match self.channel_state().rx_half { 268 HalfState::Window(0) => { 269 let recipient_channel = self.channel_state().tx_channel_id; 270 271 self.send(wire::Message::ChannelWindowAdjust { 272 recipient_channel, 273 bytes_to_add: amount, 274 }) 275 .await?; 276 277 self.channel_state().rx_half.increase_window(amount)?; 278 279 Ok(()) 280 } 281 HalfState::Window(_) => { 282 panic!("channel reader did not read all data"); 283 } 284 HalfState::Eof | HalfState::Close => Ok(()), 285 } 286 } 287 288 pub(crate) fn channel_is_eof(&mut self) -> bool { 289 matches!( 290 self.channel_state().rx_half, 291 HalfState::Eof | HalfState::Close 292 ) 293 } 294 295 pub(crate) async fn channel_read(&mut self) -> Result<Option<&[u8]>, TransportError<T>> { 296 loop { 297 match self.channel_state().rx_half { 298 HalfState::Window(0) => { 299 return Ok(None); 300 } 301 HalfState::Window(amount) => { 302 if 
self.channel_state().rx_committed { 303 // Only re-adjust the window if it is smaller than the maximum data size, this 304 // ensures we will only send a window adjust message approximately every 4GiB. 305 306 if amount < self.maximum_channel_data_packet_size() { 307 let bytes_to_add = u32::MAX - amount; 308 309 let recipient_channel = self.channel_state().tx_channel_id; 310 311 self.send(wire::Message::ChannelWindowAdjust { 312 recipient_channel, 313 bytes_to_add, 314 }) 315 .await?; 316 317 self.channel_state().rx_half.increase_window(bytes_to_add)?; 318 } 319 } 320 321 if let Some(payload_len) = self.poll_client().await? { 322 if let wire::Message::ChannelData { 323 data: wire::Data::Borrowed { bytes }, 324 .. 325 } = wire::Message::decode(&self.buffer[self.payload_range(payload_len)])? 326 { 327 return Ok(Some(bytes)); 328 } else { 329 unreachable!("expected channel data"); 330 } 331 } 332 } 333 HalfState::Eof | HalfState::Close => { 334 return Ok(None); 335 } 336 } 337 } 338 } 339 340 fn channel_state(&mut self) -> &mut ChannelState { 341 super::unwrap_unreachable(self.active_channel.as_mut()) 342 } 343 344 pub(crate) async fn channel_exit(&mut self, exit_status: u32) -> Result<(), TransportError<T>> { 345 let recipient_channel = self.channel_state().tx_channel_id; 346 347 self.send(wire::Message::ChannelEof { recipient_channel }) 348 .await?; 349 350 self.send(wire::Message::ChannelRequest { 351 recipient_channel, 352 request: wire::Request::ExitStatus { 353 want_reply: false, 354 exit_status, 355 }, 356 }) 357 .await?; 358 359 self.send(wire::Message::ChannelClose { recipient_channel }) 360 .await?; 361 362 self.channel_state().tx_half = HalfState::Close; 363 364 if let HalfState::Close = self.channel_state().rx_half { 365 self.dequeue_pending_channel().await?; 366 } 367 368 self.request = None; 369 370 Ok(()) 371 } 372 373 pub(crate) async fn channel_write_all( 374 &mut self, 375 len: usize, 376 pipe: Pipe, 377 ) -> Result<bool, TransportError<T>> { 378 
while !self.channel_write(len, pipe).await? { 379 if let HalfState::Close = self.channel_state().tx_half { 380 return Ok(false); // client has closed the channel 381 } 382 383 self.poll_client().await?; 384 } 385 386 Ok(true) 387 } 388 389 pub(crate) async fn channel_write( 390 &mut self, 391 len: usize, 392 pipe: Pipe, 393 ) -> Result<bool, TransportError<T>> { 394 assert!(len <= self.channel_data_payload_buffer(pipe).len()); 395 396 if len == 0 { 397 return Ok(true); 398 } 399 400 if let HalfState::Window(amount) = self.channel_state().tx_half { 401 if wire::from_u32(amount) >= len { 402 let recipient_channel = self.channel_state().tx_channel_id; 403 404 self.send(match pipe { 405 Pipe::Stdout => wire::Message::ChannelData { 406 recipient_channel, 407 data: wire::Data::InPlace { 408 len: wire::into_u32(len), 409 }, 410 }, 411 Pipe::Stderr => wire::Message::ChannelExtendedData { 412 recipient_channel, 413 data: wire::ExtendedData::Stderr { 414 data: wire::Data::InPlace { 415 len: wire::into_u32(len), 416 }, 417 }, 418 }, 419 }) 420 .await?; 421 422 self.channel_state() 423 .tx_half 424 .decrease_window(wire::into_u32(len))?; 425 426 return Ok(true); 427 } 428 } 429 430 Ok(false) 431 } 432 433 async fn poll_client(&mut self) -> Result<Option<usize>, TransportError<T>> { 434 if self.client_ssh_id_length == 0 { 435 self.perform_handshake().await?; 436 } 437 438 let mut reason = wire::DisconnectReason::ProtocolError; 439 440 let payload_len = self.recv().await?; // we sometimes need the message payload bytes 441 let message = wire::Message::decode(&self.buffer[self.payload_range(payload_len)])?; 442 443 match message { 444 wire::Message::KexInit { 445 kex_algorithms, 446 server_host_key_algorithms, 447 encryption_algorithms_client_to_server, 448 encryption_algorithms_server_to_client, 449 compression_algorithms_client_to_server, 450 compression_algorithms_server_to_client, 451 first_kex_packet_follows, 452 .. 
453 } if self.kex.is_none() => { 454 if self.curr_keys.is_none() 455 && kex_algorithms.find(KEXINIT_STRICT_KEX_CLIENT).is_some() 456 { 457 // Enable strict KEX mode for this transport as specified by OpenSSH 458 459 self.strict_kex = true; 460 461 if self.client_sequence_number != 0 { 462 return Err(Error::ServerDisconnect( 463 wire::DisconnectReason::ProtocolError, 464 )); 465 } 466 } 467 468 // We only have one algorithm for each name list, so the selection algorithm 469 // boils down to "does the client have our algorithm in their list". Process 470 // the kex and host_key algorithms specially if a guessed packet is sent. 471 472 let kex_index = kex_algorithms.find(KEXINIT_KEX_ALGORITHM); 473 let host_key_index = server_host_key_algorithms.find(KEXINIT_HOST_KEY); 474 475 if kex_index.is_none() || host_key_index.is_none() { 476 return Err(Error::ServerDisconnect( 477 wire::DisconnectReason::KeyExchangeFailed, 478 )); 479 } 480 481 if encryption_algorithms_client_to_server 482 .find(KEXINIT_ENCRYPTION) 483 .is_none() 484 { 485 return Err(Error::ServerDisconnect( 486 wire::DisconnectReason::KeyExchangeFailed, 487 )); 488 } 489 490 if encryption_algorithms_server_to_client 491 .find(KEXINIT_ENCRYPTION) 492 .is_none() 493 { 494 return Err(Error::ServerDisconnect( 495 wire::DisconnectReason::KeyExchangeFailed, 496 )); 497 } 498 499 // We use an AEAD cipher that doesn't require (and forbids) MAC algorithm negotiation. 
500 501 if compression_algorithms_client_to_server 502 .find(KEXINIT_COMPRESSION) 503 .is_none() 504 { 505 return Err(Error::ServerDisconnect( 506 wire::DisconnectReason::KeyExchangeFailed, 507 )); 508 } 509 510 if compression_algorithms_server_to_client 511 .find(KEXINIT_COMPRESSION) 512 .is_none() 513 { 514 return Err(Error::ServerDisconnect( 515 wire::DisconnectReason::KeyExchangeFailed, 516 )); 517 } 518 519 // If the kex or host key algorithms were not the client's preferred algorithm, their guess 520 // will be wrong so we must discard the guessed key exchange packet they will have sent us. 521 522 let mut discard_guessed = false; 523 524 if first_kex_packet_follows && (kex_index != Some(0) || host_key_index != Some(0)) { 525 discard_guessed = true; 526 } 527 528 let mut cookie = [0u8; 16]; 529 530 self.behavior.random().fill_bytes(&mut cookie); 531 532 // Note that we never need to guess since we reply to the client's KEXINIT, so we always 533 // know whether our guess would have been correct or not; the client may guess, however. 
534 535 let kex_init_message = wire::Message::KexInit { 536 cookie: &cookie, 537 kex_algorithms: wire::NameList::new_from_string(KEXINIT_KEX)?, 538 server_host_key_algorithms: wire::NameList::new_from_string(KEXINIT_HOST_KEY)?, 539 encryption_algorithms_client_to_server: wire::NameList::new_from_string( 540 KEXINIT_ENCRYPTION, 541 )?, 542 encryption_algorithms_server_to_client: wire::NameList::new_from_string( 543 KEXINIT_ENCRYPTION, 544 )?, 545 mac_algorithms_client_to_server: wire::NameList::new_from_string(KEXINIT_MAC)?, 546 mac_algorithms_server_to_client: wire::NameList::new_from_string(KEXINIT_MAC)?, 547 compression_algorithms_client_to_server: wire::NameList::new_from_string( 548 KEXINIT_COMPRESSION, 549 )?, 550 compression_algorithms_server_to_client: wire::NameList::new_from_string( 551 KEXINIT_COMPRESSION, 552 )?, 553 languages_client_to_server: wire::NameList::default(), 554 languages_server_to_client: wire::NameList::default(), 555 first_kex_packet_follows: false, 556 reserved: 0, 557 }; 558 559 let mut kex = KexState { 560 exchange_hash_hasher: ObjectHasher::new(Sha256::new()), 561 discard_guessed, 562 }; 563 564 kex.exchange_hash_hasher 565 .hash_string_utf8(self.client_ssh_id_string()); 566 kex.exchange_hash_hasher 567 .hash_string_utf8(self.behavior.server_id()); 568 kex.exchange_hash_hasher 569 .hash_string(&self.buffer[self.payload_range(payload_len)]); 570 571 let payload_range = self.payload_range_full(); 572 573 let payload = kex_init_message.encode(&mut self.buffer[payload_range])?; 574 575 kex.exchange_hash_hasher.hash_string(payload); 576 577 let payload_len = payload.len(); 578 579 self.send_preencoded_payload(payload_len).await?; 580 581 self.kex = Some(kex); 582 583 return Ok(None); 584 } 585 wire::Message::KexEcdhInit { 586 client_ephemeral_public_key, 587 } => { 588 if let Some(mut kex) = self.kex.take() { 589 if core::mem::replace(&mut kex.discard_guessed, false) { 590 self.kex = Some(kex); 591 return Ok(None); 592 } else if let 
Ok(client_ephemeral_public_key) = 593 <&[u8] as TryInto<[u8; 32]>>::try_into(client_ephemeral_public_key) 594 { 595 let client_ephemeral_public_key: PublicKey = 596 client_ephemeral_public_key.into(); 597 598 let server_ephemeral_secret_key = 599 EphemeralSecret::random_from_rng(self.behavior.random()); 600 601 let server_ephemeral_public_key: PublicKey = 602 (&server_ephemeral_secret_key).into(); 603 604 // Generate a keypair 605 606 let shared_secret = server_ephemeral_secret_key 607 .diffie_hellman(&client_ephemeral_public_key); 608 609 // Finish building up the exchange hash 610 611 match self.behavior.host_secret_key() { 612 SecretKey::Ed25519 { secret_key } => { 613 let public_key = secret_key.verifying_key(); 614 615 let host_key = wire::PublicKey::Ed25519 { 616 public_key: public_key.as_bytes(), 617 }; 618 619 let mut host_key_writer = ObjectWriter::new(self.buffer); 620 621 host_key.encode_with(&mut host_key_writer)?; 622 623 kex.exchange_hash_hasher 624 .hash_byte_array(host_key_writer.into_written()); 625 kex.exchange_hash_hasher 626 .hash_string(client_ephemeral_public_key.as_bytes()); 627 kex.exchange_hash_hasher 628 .hash_string(server_ephemeral_public_key.as_bytes()); 629 630 let shared_secret = *shared_secret.as_bytes(); 631 632 kex.exchange_hash_hasher.hash_mpint(&shared_secret); 633 634 let exchange_hash = kex.exchange_hash_hasher.into_digest(); 635 636 let signature = secret_key.sign(&exchange_hash); 637 638 self.send(wire::Message::KexEcdhReply { 639 server_public_host_key: wire::PublicKey::Ed25519 { 640 public_key: public_key.as_bytes(), 641 }, 642 server_ephemeral_public_key: server_ephemeral_public_key 643 .as_bytes(), 644 signature: wire::Signature::Ed25519 { 645 signature: &signature.to_bytes(), 646 }, 647 }) 648 .await?; 649 650 let mut prefix_hash = ObjectHasher::new(Sha256::new()); 651 652 prefix_hash.hash_mpint(&shared_secret); 653 prefix_hash.hash_byte_array(&exchange_hash); 654 655 self.next_keys = Some(PendingKeys { 656 session_id: 
match &self.curr_keys { 657 Some(keys) => keys.session_id, 658 None => exchange_hash.into(), 659 }, 660 prefix_hash, 661 }); 662 663 return Ok(None); 664 } 665 } 666 } else { 667 reason = wire::DisconnectReason::KeyExchangeFailed; 668 } 669 } 670 } 671 wire::Message::NewKeys => { 672 if let Some(keys) = self.next_keys.take() { 673 let mut enc_key_client_hash = keys.prefix_hash.clone(); 674 enc_key_client_hash.hash_byte(b'C'); 675 enc_key_client_hash.hash_byte_array(&keys.session_id); 676 let client_enc_k1 = enc_key_client_hash.into_digest(); 677 678 let mut enc_key_server_hash = keys.prefix_hash.clone(); 679 enc_key_server_hash.hash_byte(b'D'); 680 enc_key_server_hash.hash_byte_array(&keys.session_id); 681 let server_enc_k1 = enc_key_server_hash.into_digest(); 682 683 let mut digest = keys.prefix_hash.clone(); 684 digest.hash_byte_array(&client_enc_k1); 685 let client_enc_k2 = digest.into_digest(); 686 687 let mut digest = keys.prefix_hash.clone(); 688 digest.hash_byte_array(&server_enc_k1); 689 let server_enc_k2 = digest.into_digest(); 690 691 self.send(wire::Message::NewKeys).await?; 692 693 self.curr_keys = Some(KeyMaterial { 694 client_head_key: client_enc_k2.into(), 695 client_main_key: client_enc_k1.into(), 696 server_head_key: server_enc_k2.into(), 697 server_main_key: server_enc_k1.into(), 698 session_id: keys.session_id, 699 }); 700 701 if self.strict_kex { 702 self.client_sequence_number = u32::MAX; 703 self.server_sequence_number = u32::MAX; 704 } 705 706 return Ok(None); 707 } 708 } 709 wire::Message::ServiceRequest { service_name } if self.curr_keys.is_some() => { 710 match service_name { 711 "ssh-userauth" => { 712 self.userauth_enabled = true; 713 714 self.send(wire::Message::ServiceAccept { 715 service_name: "ssh-userauth", 716 }) 717 .await?; 718 719 return Ok(None); 720 } 721 _ => { 722 reason = wire::DisconnectReason::ServiceNotAvailable; 723 } 724 } 725 } 726 wire::Message::UserAuthRequest { 727 user_name, 728 service_name: "ssh-connection", 729 
auth_method, 730 } if self.userauth_enabled => { 731 // Unfortunately we need to use the user name for signature verification, meaning we need 732 // to store it outside the packet buffer. Rather than burden the crate user, we just copy 733 // the user name into a temporary string, enforcing a reasonable 80-byte maximum length. 734 735 let mut user_name_buffer = [0u8; 80]; 736 737 if user_name.len() > user_name_buffer.len() { 738 self.send(wire::Message::UserAuthFailure { 739 authentications_that_can_continue: wire::NameList::new_from_string( 740 "publickey", 741 )?, 742 partial_success: false, 743 }) 744 .await?; 745 746 return Ok(None); 747 } 748 749 let user_name_slice = &mut user_name_buffer[..user_name.len()]; 750 user_name_slice.copy_from_slice(user_name.as_bytes()); 751 752 let Some(user_auth_method) = (match auth_method { 753 wire::AuthMethod::None => Some(AuthMethod::None), 754 wire::AuthMethod::PublicKey { 755 public_key_algorithm_name: "ssh-ed25519", 756 public_key: wire::PublicKey::Ed25519 { public_key }, 757 signature: Some(wire::Signature::Ed25519 { .. 
}) | None, 758 } => { 759 if let Ok(public_key) = VerifyingKey::from_bytes(public_key) { 760 Some(AuthMethod::PublicKey(types::PublicKey::Ed25519 { 761 public_key, 762 })) 763 } else { 764 None 765 } 766 } 767 _ => None, 768 }) else { 769 self.send(wire::Message::UserAuthFailure { 770 authentications_that_can_continue: wire::NameList::new_from_string( 771 "publickey", 772 )?, 773 partial_success: false, 774 }) 775 .await?; 776 777 return Ok(None); 778 }; 779 780 if let Some(user) = self.behavior.allow_user(user_name, &user_auth_method) { 781 match user_auth_method { 782 AuthMethod::None => { 783 self.send(wire::Message::UserAuthSuccess).await?; 784 self.current_user = Some(user); 785 self.authenticated = true; 786 } 787 AuthMethod::PublicKey(types::PublicKey::Ed25519 { public_key }) => { 788 if let wire::AuthMethod::PublicKey { 789 public_key_algorithm_name: "ssh-ed25519", 790 public_key: wire::PublicKey::Ed25519 { .. }, 791 signature: Some(wire::Signature::Ed25519 { signature }), 792 } = auth_method 793 { 794 let signed_public_key = wire::PublicKey::Ed25519 { 795 public_key: public_key.as_bytes(), 796 }; 797 798 let signature: Signature = signature.into(); 799 800 let mut writer = ObjectWriter::new(self.buffer); 801 802 // TODO: feels like we could do a little better here 803 804 writer.write_string( 805 &super::unwrap_unreachable(self.curr_keys.as_ref()).session_id, 806 )?; 807 writer.write_byte(wire::MSG_USERAUTH_REQUEST)?; 808 writer.write_string(user_name_slice)?; 809 writer.write_string_utf8("ssh-connection")?; 810 writer.write_string_utf8("publickey")?; 811 writer.write_boolean(true)?; 812 writer.write_string_utf8("ssh-ed25519")?; 813 signed_public_key.encode_with(&mut writer)?; 814 815 if let Ok(()) = public_key.verify(writer.into_written(), &signature) 816 { 817 self.send(wire::Message::UserAuthSuccess).await?; 818 self.current_user = Some(user); 819 self.authenticated = true; 820 } else { 821 self.send(wire::Message::UserAuthFailure { 822 
authentications_that_can_continue: 823 wire::NameList::new_from_string("publickey")?, 824 partial_success: false, 825 }) 826 .await?; 827 } 828 } else { 829 self.send(wire::Message::UserAuthPkOk { 830 public_key_algorithm_name: "ssh-ed25519", 831 public_key: wire::PublicKey::Ed25519 { 832 public_key: public_key.as_bytes(), 833 }, 834 }) 835 .await?; 836 } 837 } 838 } 839 } else { 840 self.send(wire::Message::UserAuthFailure { 841 authentications_that_can_continue: wire::NameList::new_from_string( 842 "publickey", 843 )?, 844 partial_success: false, 845 }) 846 .await?; 847 } 848 849 return Ok(None); 850 } 851 852 wire::Message::GlobalRequest { want_reply, .. } if self.authenticated => { 853 if want_reply { 854 self.send(wire::Message::RequestFailure).await?; 855 } 856 857 return Ok(None); 858 } 859 860 wire::Message::ChannelOpen { 861 channel: 862 wire::ChannelType::Session { 863 sender_channel, 864 initial_window_size, 865 maximum_packet_size, 866 }, 867 } if self.authenticated => { 868 for channel in self.channels.into_iter().flatten() { 869 if channel.sender_channel == sender_channel { 870 self.send(wire::Message::Disconnect { 871 reason: wire::DisconnectReason::ProtocolError, 872 }) 873 .await?; 874 return Err(Error::ServerDisconnect( 875 wire::DisconnectReason::ProtocolError, 876 )); 877 } 878 } 879 880 for channel in &mut self.channels { 881 if channel.is_none() { 882 *channel = Some(PendingChannel { 883 sender_channel, 884 initial_window_size, 885 maximum_packet_size, 886 }); 887 888 if self.active_channel.is_none() { 889 self.dequeue_pending_channel().await?; 890 } 891 892 return Ok(None); 893 } 894 } 895 896 self.send(wire::Message::ChannelOpenFailure { 897 recipient_channel: sender_channel, 898 reason: wire::ChannelOpenFailureReason::ResourceShortage, 899 }) 900 .await?; 901 902 return Ok(None); 903 } 904 905 wire::Message::ChannelOpen { 906 channel: wire::ChannelType::Other { sender_channel, .. 
}, 907 } if self.authenticated => { 908 self.send(wire::Message::ChannelOpenFailure { 909 recipient_channel: sender_channel, 910 reason: wire::ChannelOpenFailureReason::UnknownChannelType, 911 }) 912 .await?; 913 914 return Ok(None); 915 } 916 917 wire::Message::ChannelWindowAdjust { 918 recipient_channel, 919 bytes_to_add, 920 } if self.authenticated => { 921 if let Some(channel_state) = &mut self.active_channel { 922 if channel_state.rx_channel_id == recipient_channel { 923 channel_state.tx_half.increase_window(bytes_to_add)?; 924 return Ok(None); 925 } 926 } 927 } 928 929 wire::Message::ChannelData { 930 recipient_channel, 931 data: wire::Data::Borrowed { bytes }, 932 } if self.authenticated => { 933 if let Some(channel_state) = &mut self.active_channel { 934 if channel_state.rx_channel_id == recipient_channel { 935 channel_state 936 .rx_half 937 .decrease_window(wire::into_u32(bytes.len()))?; 938 return Ok(Some(payload_len)); 939 } 940 } 941 } 942 943 wire::Message::ChannelEof { recipient_channel } if self.authenticated => { 944 if let Some(channel_state) = &mut self.active_channel { 945 if channel_state.rx_channel_id == recipient_channel { 946 channel_state.rx_half = HalfState::Eof; 947 return Ok(None); 948 } 949 } 950 } 951 952 wire::Message::ChannelClose { recipient_channel } if self.authenticated => { 953 if self.active_channel.is_none() { 954 return Ok(None); // ignored 955 } 956 957 if let Some(channel_state) = &mut self.active_channel { 958 if channel_state.rx_channel_id == recipient_channel { 959 channel_state.rx_half = HalfState::Close; 960 961 if let HalfState::Close = channel_state.tx_half { 962 self.dequeue_pending_channel().await?; 963 } else if self.request.is_none() { 964 let sender_channel = channel_state.tx_channel_id; 965 966 self.send(wire::Message::ChannelEof { 967 recipient_channel: sender_channel, 968 }) 969 .await?; 970 971 self.send(wire::Message::ChannelClose { 972 recipient_channel: sender_channel, 973 }) 974 .await?; 975 976 
self.dequeue_pending_channel().await?; 977 } else { 978 channel_state.tx_half = HalfState::Close; 979 } 980 981 return Ok(None); 982 } 983 } 984 } 985 986 wire::Message::ChannelRequest { 987 recipient_channel, 988 request: 989 wire::Request::Exec { 990 want_reply, 991 command, 992 }, 993 } if self.authenticated && self.request.is_none() => { 994 if let Some(channel_state) = &mut self.active_channel { 995 if channel_state.rx_channel_id == recipient_channel { 996 self.request = Some(Request::Exec(self.behavior.parse_command(command))); 997 998 if want_reply { 999 let sender_channel = channel_state.tx_channel_id; 1000 1001 self.send(wire::Message::ChannelSuccess { 1002 recipient_channel: sender_channel, 1003 }) 1004 .await?; 1005 } 1006 1007 return Ok(None); 1008 } 1009 } 1010 } 1011 1012 wire::Message::ChannelRequest { 1013 recipient_channel, 1014 request: wire::Request::Env { want_reply, .. }, 1015 } if self.authenticated => { 1016 if let Some(channel_state) = &mut self.active_channel { 1017 if channel_state.rx_channel_id == recipient_channel { 1018 if want_reply { 1019 let sender_channel = channel_state.tx_channel_id; 1020 1021 self.send(wire::Message::ChannelSuccess { 1022 recipient_channel: sender_channel, 1023 }) 1024 .await?; 1025 } 1026 1027 return Ok(None); 1028 } 1029 } 1030 } 1031 1032 // PTY request — accept it and store terminal dimensions 1033 wire::Message::ChannelRequest { 1034 recipient_channel, 1035 request: wire::Request::PtyReq { want_reply, width_chars, height_rows, .. 
}, 1036 } if self.authenticated => { 1037 if let Some(channel_state) = &mut self.active_channel { 1038 if channel_state.rx_channel_id == recipient_channel { 1039 self.behavior.on_pty_request(width_chars, height_rows); 1040 if want_reply { 1041 let sender_channel = channel_state.tx_channel_id; 1042 self.send(wire::Message::ChannelSuccess { 1043 recipient_channel: sender_channel, 1044 }).await?; 1045 } 1046 return Ok(None); 1047 } 1048 } 1049 } 1050 1051 wire::Message::ChannelRequest { 1052 recipient_channel, 1053 request: wire::Request::Shell { want_reply }, 1054 } if self.authenticated && self.request.is_none() && self.behavior.allow_shell() => { 1055 if let Some(channel_state) = &mut self.active_channel { 1056 if channel_state.rx_channel_id == recipient_channel { 1057 self.request = Some(Request::Shell); 1058 1059 if want_reply { 1060 let sender_channel = channel_state.tx_channel_id; 1061 1062 self.send(wire::Message::ChannelSuccess { 1063 recipient_channel: sender_channel, 1064 }) 1065 .await?; 1066 } 1067 1068 return Ok(None); 1069 } 1070 } 1071 } 1072 1073 wire::Message::ChannelRequest { 1074 recipient_channel, 1075 request, 1076 } if self.authenticated => { 1077 if let Some(channel_state) = &mut self.active_channel { 1078 if channel_state.rx_channel_id == recipient_channel { 1079 if request.want_reply() { 1080 let sender_channel = channel_state.tx_channel_id; 1081 1082 self.send(wire::Message::ChannelFailure { 1083 recipient_channel: sender_channel, 1084 }) 1085 .await?; 1086 } 1087 1088 return Ok(None); 1089 } 1090 } 1091 } 1092 1093 wire::Message::Debug { .. } 1094 | wire::Message::Ignore { .. } 1095 | wire::Message::Unimplemented { .. } => { 1096 if self.strict_kex && self.curr_keys.is_none() { 1097 return Err(Error::ServerDisconnect( 1098 wire::DisconnectReason::ProtocolError, 1099 )); 1100 } 1101 1102 return Ok(None); 1103 } 1104 wire::Message::Unknown { .. 
} => {
                if self.strict_kex && self.curr_keys.is_none() {
                    return Err(Error::ServerDisconnect(
                        wire::DisconnectReason::ProtocolError,
                    ));
                }

                self.send(wire::Message::Unimplemented {
                    sequence_number: self.client_sequence_number,
                })
                .await?;

                return Ok(None);
            }
            wire::Message::Disconnect { reason, .. } => {
                return Err(Error::ClientDisconnect(reason));
            }

            _ => {}
        }

        self.send(wire::Message::Disconnect { reason }).await?;
        Err(Error::ServerDisconnect(reason))
    }

    /// Frames and sends a payload that has already been encoded into
    /// `self.buffer[5..5 + payload_len]`.
    ///
    /// Bumps the server sequence number, writes the packet length, the padding
    /// length byte and random padding. When keys are installed
    /// (`curr_keys` is `Some`), it encrypts the packet with
    /// chacha20-poly1305@openssh.com and appends the 16-byte Poly1305 tag.
    /// Finally it writes the whole frame to the stream.
    async fn send_preencoded_payload(
        &mut self,
        payload_len: usize,
    ) -> Result<(), TransportError<T>> {
        // The counter is incremented *before* use. NOTE(review): this
        // presumably relies on it starting at u32::MAX so the first packet
        // uses 0 — confirm at construction.
        self.server_sequence_number = self.server_sequence_number.wrapping_add(1);
        let nonce = (self.server_sequence_number as u64).to_be_bytes();

        // NOTE: padding rules differ for AEAD cipher modes. With
        // chacha20-poly1305 the 4-byte length field is excluded from the
        // 8-byte alignment (1 + payload + padding == 0 mod 8). Without
        // encryption it is included (4 + 1 + payload + padding == 0 mod 8).
        // The wrapping subtraction is exact modulo 8 because usize::MAX + 1 is
        // a multiple of 8.
        let mut padding_len = if self.curr_keys.is_some() {
            (7usize.wrapping_sub(payload_len)) % 8
        } else {
            (3usize.wrapping_sub(payload_len)) % 8
        };

        // At least 4 bytes of padding are required.
        if padding_len < 4 {
            padding_len += 8;
        }

        let packet_len = wire::into_u32(1 + payload_len + padding_len);
        self.buffer[..4].copy_from_slice(&packet_len.to_be_bytes());

        // The length field is encrypted on its own, with the header key.
        if let Some(ctx) = &self.curr_keys {
            let mut cipher =
                ChaCha20Legacy::new((&ctx.server_head_key).into(), (&nonce).into());
            cipher.apply_keystream(&mut self.buffer[..4]);
        }

        self.buffer[4] = padding_len as u8;

        self.behavior
            .random()
            .fill_bytes(&mut self.buffer[5 + payload_len..][..padding_len]);

        let packet_end = 5 + payload_len + padding_len;

        let wire_len = if let Some(ctx) = &self.curr_keys {
            let (ciphertext, tag_buf) = self.buffer.split_at_mut(packet_end);

            let mut cipher =
                ChaCha20Legacy::new((&ctx.server_main_key).into(), (&nonce).into());

            // The first 32 keystream bytes are the one-time Poly1305 key.
            let mut mac_key = [0u8; 32];
            cipher.apply_keystream(&mut mac_key);
            let mac = Poly1305::new((&mac_key).into());

            // The body (everything after the length field) is encrypted
            // starting at keystream byte 64, i.e. block counter 1.
            cipher.seek(64);
            cipher.apply_keystream(&mut ciphertext[4..]);

            // The tag covers the encrypted length and the encrypted body.
            let tag = mac.compute_unpadded(ciphertext);
            tag_buf[..16].copy_from_slice(&tag);

            packet_end + 16
        } else {
            packet_end
        };

        self.behavior
            .stream()
            .write_all(&self.buffer[..wire_len])
            .await?;

        Ok(())
    }

    /// Encodes `message` into the payload area of the buffer and sends it as
    /// a single packet.
    async fn send(&mut self, message: wire::Message<'_>) -> Result<(), TransportError<T>> {
        let payload_range = self.payload_range_full();

        let payload_len = message.encode(&mut self.buffer[payload_range])?.len();
        self.send_preencoded_payload(payload_len).await
    }

    /// Reads one packet from the stream into `self.buffer` and returns its
    /// payload length. The payload itself lies at `self.buffer[5..5 + len]`.
    ///
    /// When keys are installed, the length is decrypted with the header key.
    /// The Poly1305 tag is then verified over the still-encrypted packet, and
    /// only after that is the body decrypted in place.
    ///
    /// # Errors
    ///
    /// - `ProtocolError::MalformedPacket` for a bad length, alignment,
    ///   padding or tag.
    /// - `ProtocolError::BufferExhausted` when the packet does not fit in the
    ///   buffer.
    /// - Any error from the underlying stream.
    async fn recv(&mut self) -> Result<usize, TransportError<T>> {
        self.client_sequence_number = self.client_sequence_number.wrapping_add(1);
        let nonce = (self.client_sequence_number as u64).to_be_bytes();

        self.behavior
            .stream()
            .read_exact(&mut self.buffer[..4])
            .await?;

        // Decrypt a copy of the length. The MAC below must be computed over
        // the length bytes exactly as they were received.
        let mut decrypted_packet_len = [0u8; 4];
        decrypted_packet_len.copy_from_slice(&self.buffer[..4]);

        if let Some(ctx) = &self.curr_keys {
            let mut cipher =
                ChaCha20Legacy::new((&ctx.client_head_key).into(), (&nonce).into());
            cipher.apply_keystream(&mut decrypted_packet_len);
        }

        let packet_len = wire::from_u32(u32::from_be_bytes(decrypted_packet_len));

        // NOTE: padding rules differ for AEAD cipher modes. With
        // chacha20-poly1305 the length field is excluded from the 8-byte
        // alignment, so packet_len is a multiple of 8. Otherwise it is
        // included, so packet_len == 4 (mod 8).
        let padding_remainder = if self.curr_keys.is_some() { 0 } else { 4 };

        if packet_len < padding_remainder + 8 || packet_len % 8 != padding_remainder {
            Err(ProtocolError::MalformedPacket)?;
        }

        let mac_len = if self.curr_keys.is_some() { 16 } else { 0 };

        // `packet_len` is peer-controlled and may be close to u32::MAX. On
        // targets with a 32-bit usize, `4 + packet_len + mac_len` would
        // overflow: it panics in debug builds, and in release builds it wraps
        // past the size check and then panics on slicing. Use a checked sum.
        let frame_len = packet_len
            .checked_add(4 + mac_len)
            .filter(|&len| len <= self.buffer.len())
            .ok_or(ProtocolError::BufferExhausted)?;

        self.behavior
            .stream()
            .read_exact(&mut self.buffer[4..frame_len])
            .await?;

        if let Some(ctx) = &self.curr_keys {
            let (ciphertext, tag_buf) = self.buffer.split_at_mut(4 + packet_len);

            let mut cipher =
                ChaCha20Legacy::new((&ctx.client_main_key).into(), (&nonce).into());

            // The first 32 keystream bytes are the one-time Poly1305 key.
            let mut mac_key = [0u8; 32];
            cipher.apply_keystream(&mut mac_key);
            let mac = Poly1305::new((&mac_key).into());

            let tag = mac.compute_unpadded(ciphertext);

            // DO NOT report a MAC verification error to the client for security
            // reasons, just disconnect immediately and let it reinitiate later.
            if !constant_time_eq(&tag, &tag_buf[..16]) {
                Err(ProtocolError::MalformedPacket)?;
            }

            // Authenticated: decrypt the body in place from block counter 1.
            cipher.seek(64);
            cipher.apply_keystream(&mut ciphertext[4..]);
        }

        let padding_len = usize::from(self.buffer[4]);

        // At least 4 bytes of padding are required, and the padding has to fit
        // inside the packet alongside the padding-length byte.
        if padding_len < 4 || packet_len < 1 + padding_len {
            Err(ProtocolError::MalformedPacket)?;
        }

        Ok(packet_len - 1 - padding_len)
    }

    /// Disconnects from the client with a given reason.
pub async fn disconnect(
        mut self,
        reason: wire::DisconnectReason,
    ) -> Result<(), TransportError<T>> {
        // A zero `client_ssh_id_length` presumably means no client
        // identification line has been received yet. In that case nothing is
        // sent and the call succeeds immediately. The transport is consumed
        // either way.
        if self.client_ssh_id_length != 0 {
            self.send(wire::Message::Disconnect { reason }).await
        } else {
            Ok(())
        }
    }

    /// Activates the first pending channel-open request, if any, and sends
    /// the client a `ChannelOpenConfirmation` for it.
    ///
    /// `active_channel` is cleared first, so it stays `None` when nothing is
    /// pending. The activated channel has:
    /// - a receive window of 0 and local channel id 0;
    /// - the requester's initial window and maximum packet size for the
    ///   transmit half.
    async fn dequeue_pending_channel(&mut self) -> Result<(), TransportError<T>> {
        self.active_channel = None;

        // Take the first occupied slot. `Option::take` leaves empty slots as
        // they are, and the search stops at the first hit.
        let next_pending = self.channels.iter_mut().find_map(Option::take);

        if let Some(pending) = next_pending {
            self.active_channel = Some(ChannelState {
                rx_half: HalfState::Window(0),
                tx_half: HalfState::Window(pending.initial_window_size),
                tx_max_packet: pending.maximum_packet_size,
                rx_channel_id: 0,
                tx_channel_id: pending.sender_channel,
                rx_committed: false,
            });

            let maximum_packet_size = self.maximum_channel_data_packet_size();

            self.send(wire::Message::ChannelOpenConfirmation {
                recipient_channel: pending.sender_channel,
                sender_channel: 0,
                initial_window_size: 0,
                maximum_packet_size,
                payload: &[],
            })
            .await?;
        }

        Ok(())
    }

    /// Returns the buffer range holding a payload of `payload_len` bytes. The
    /// payload starts right after the 4-byte packet length and the 1-byte
    /// padding length.
    fn payload_range(&self, payload_len: usize) -> Range<usize> {
        let payload_start = 5;
        payload_start..payload_start + payload_len
    }

    /// Returns the largest buffer range available for an encoded payload.
    ///
    /// It reserves 255 + 16 bytes at the end of the buffer. This is presumably
    /// room for the maximum padding a one-byte padding-length field can
    /// describe, plus the 16-byte Poly1305 tag appended on the encrypted send
    /// path.
    fn payload_range_full(&self) -> Range<usize> {
        let reserved_tail = 255 + 16;
        5..self.buffer.len() - reserved_tail
    }
}