/ firmware / src / services / ssh / transport.rs
transport.rs
   1  use super::channel::{Channel, Pipe};
   2  use super::codec::{ObjectHasher, ObjectWriter};
   3  use super::error::{Error, ProtocolError};
   4  use super::types::{self, AuthMethod, Behavior, Request, SecretKey, TransportError};
   5  use super::wire;
   6  
   7  use chacha20::cipher::{KeyInit, KeyIvInit, StreamCipher, StreamCipherSeek};
   8  use chacha20::ChaCha20Legacy;
   9  use constant_time_eq::constant_time_eq;
  10  use core::ops::Range;
  11  use ed25519_dalek::{Signature, Signer, Verifier, VerifyingKey};
  12  use embedded_io_async::{Read, Write};
  13  use poly1305::Poly1305;
  14  use rand::RngCore;
  15  use sha2::{Digest, Sha256};
  16  use x25519_dalek::{EphemeralSecret, PublicKey};
  17  
  18  const KEXINIT_KEX_ALGORITHM: &str = "curve25519-sha256";
  19  const KEXINIT_STRICT_KEX_CLIENT: &str = "kex-strict-c-v00@openssh.com";
  20  const KEXINIT_KEX: &str = "curve25519-sha256,kex-strict-s-v00@openssh.com";
  21  const KEXINIT_HOST_KEY: &str = "ssh-ed25519";
  22  const KEXINIT_ENCRYPTION: &str = "chacha20-poly1305@openssh.com";
  23  // TODO: this violates RFC4253 but seems to be most compatible
  24  const KEXINIT_MAC: &str = "";
  25  const KEXINIT_COMPRESSION: &str = "none";
  26  
  27  struct KexState {
  28      discard_guessed: bool,
  29      exchange_hash_hasher: ObjectHasher<Sha256>,
  30  }
  31  
  32  struct PendingKeys {
  33      prefix_hash: ObjectHasher<Sha256>,
  34      session_id: [u8; 32],
  35  }
  36  
  37  struct KeyMaterial {
  38      client_head_key: [u8; 32],
  39      client_main_key: [u8; 32],
  40      server_head_key: [u8; 32],
  41      server_main_key: [u8; 32],
  42      session_id: [u8; 32],
  43  }
  44  
  45  #[derive(Clone, Copy, Debug)]
  46  struct PendingChannel {
  47      sender_channel: u32,
  48      initial_window_size: u32,
  49      maximum_packet_size: u32,
  50  }
  51  
  52  #[derive(Clone, Copy, Debug)]
  53  enum HalfState {
  54      Window(u32),
  55      Eof,
  56      Close,
  57  }
  58  
  59  impl HalfState {
  60      pub fn increase_window(&mut self, amount: u32) -> Result<(), ProtocolError> {
  61          if let HalfState::Window(value) = self {
  62              *value = value
  63                  .checked_add(amount)
  64                  .ok_or(ProtocolError::WindowOverflow)?;
  65          }
  66  
  67          Ok(())
  68      }
  69  
  70      pub fn decrease_window(&mut self, amount: u32) -> Result<(), ProtocolError> {
  71          if let HalfState::Window(value) = self {
  72              *value = value
  73                  .checked_sub(amount)
  74                  .ok_or(ProtocolError::WindowOverflow)?;
  75          }
  76  
  77          Ok(())
  78      }
  79  }
  80  
  81  #[derive(Debug)]
  82  struct ChannelState {
  83      rx_channel_id: u32,
  84      tx_channel_id: u32,
  85      tx_max_packet: u32,
  86      rx_half: HalfState,
  87      tx_half: HalfState,
  88      rx_committed: bool,
  89  }
  90  
  91  /// Implementation of an SSH server's transport layer.
  92  pub struct Transport<'a, T: Behavior> {
  93      buffer: &'a mut [u8],
  94  
  95      behavior: T,
  96  
  97      client_ssh_id_buffer: [u8; 255],
  98      client_ssh_id_length: usize,
  99  
 100      kex: Option<KexState>,
 101      strict_kex: bool,
 102  
 103      next_keys: Option<PendingKeys>,
 104      curr_keys: Option<KeyMaterial>,
 105  
 106      client_sequence_number: u32,
 107      server_sequence_number: u32,
 108  
 109      userauth_enabled: bool,
 110      authenticated: bool,
 111  
 112      request: Option<Request<T::Command>>,
 113      current_user: Option<T::User>,
 114  
 115      channels: [Option<PendingChannel>; 4],
 116      active_channel: Option<ChannelState>,
 117  }
 118  
 119  impl<'a, T: Behavior> Transport<'a, T> {
 120      /// Creates a new transport from a packet buffer and behavior.
 121      pub fn new(buffer: &'a mut [u8], behavior: T) -> Self {
 122          assert!(buffer.len() >= 512, "packet buffer too small");
 123  
 124          Self {
 125              buffer,
 126              behavior,
 127  
 128              client_ssh_id_buffer: [0u8; 255],
 129              client_ssh_id_length: 0,
 130  
 131              kex: None,
 132              strict_kex: false,
 133  
 134              next_keys: None,
 135              curr_keys: None,
 136  
 137              client_sequence_number: u32::MAX,
 138              server_sequence_number: u32::MAX,
 139  
 140              userauth_enabled: false,
 141              authenticated: false,
 142  
 143              request: None,
 144              current_user: None,
 145  
 146              channels: [None; 4],
 147              active_channel: None,
 148          }
 149      }
 150  
 151      pub(crate) fn client_ssh_id_string(&self) -> &str {
 152          let slice = &self.client_ssh_id_buffer[..self.client_ssh_id_length];
 153  
 154          super::unwrap_unreachable(core::str::from_utf8(slice).ok())
 155      }
 156  
 157      async fn perform_handshake(&mut self) -> Result<(), TransportError<T>> {
 158          let ssh_str = self.behavior.server_id().as_bytes();
 159          assert!(ssh_str.len() <= 253); // required by spec
 160  
 161          self.behavior.stream().write_all(ssh_str).await?;
 162          self.behavior.stream().write_all(b"\r\n").await?;
 163  
 164          // The client is not allowed to send arbitrary lines prior to sending its
 165          // identification string (the server can, but we don't). The parser below
 166          // checks that the input is well-formed according to RFC4253 section 4.2.
 167  
 168          for i in 0..255 {
 169              self.behavior
 170                  .stream()
 171                  .read(&mut self.client_ssh_id_buffer[i..i + 1])
 172                  .await?;
 173  
 174              let curr = self.client_ssh_id_buffer[i];
 175  
 176              if !matches!(curr, b'\r' | b'\n' | 0x20..=0x7E) {
 177                  Err(ProtocolError::BadIdentificationString)?;
 178              }
 179  
 180              if i == 0 {
 181                  continue;
 182              }
 183  
 184              let prev = self.client_ssh_id_buffer[i - 1];
 185  
 186              if (prev, curr) == (b'\r', b'\n') {
 187                  self.client_ssh_id_length = i - 1;
 188                  break;
 189              }
 190          }
 191  
 192          if !self.client_ssh_id_string().starts_with("SSH-2.0-") {
 193              Err(ProtocolError::BadIdentificationString)?;
 194          }
 195  
 196          Ok(())
 197      }
 198  
 199      /// Accepts the next client request as a channel.
 200      pub async fn accept(&mut self) -> Result<Channel<'a, '_, T>, TransportError<T>> {
 201          assert!(self.request.is_none(), "channel request already active");
 202  
 203          loop {
 204              if self.request.is_some() {
 205                  return Ok(Channel::new(self));
 206              } else {
 207                  self.poll_client().await?;
 208              }
 209          }
 210      }
 211  
 212      pub(crate) fn channel_request(&self) -> Request<T::Command> {
 213          super::unwrap_unreachable(self.request.clone())
 214      }
 215  
 216      pub(crate) fn channel_user(&self) -> T::User {
 217          super::unwrap_unreachable(self.current_user.clone())
 218      }
 219  
 220      pub(crate) fn channel_data_payload_buffer(&mut self, pipe: Pipe) -> &mut [u8] {
 221          let max_packet_size = wire::from_u32(self.channel_state().tx_max_packet);
 222  
 223          // This is a little bit of a hack, we make an assumption on the representation of the
 224          // two ChannelData and ChannelExtendedData messages to retrieve the offset within our
 225          // packet buffer at which the caller may write its data so it could be sent in-place.
 226  
 227          let payload_offset = match pipe {
 228              Pipe::Stdout => 9,
 229              Pipe::Stderr => 13,
 230          };
 231  
 232          let payload_range = self.payload_range_full();
 233  
 234          let slice = &mut self.buffer[payload_range][payload_offset..];
 235  
 236          if slice.len() > max_packet_size {
 237              &mut slice[..max_packet_size]
 238          } else {
 239              slice
 240          }
 241      }
 242  
 243      fn maximum_channel_data_packet_size(&mut self) -> u32 {
 244          let range = self.payload_range_full();
 245  
 246          // The same principle applies to "stdin" ChannelData messages, we can compute the
 247          // largest data packet size we can receive without overflowing our packet buffer.
 248  
 249          wire::into_u32(range.end - range.start - 9)
 250      }
 251  
 252      pub(crate) async fn channel_adjust(
 253          &mut self,
 254          amount: Option<u32>,
 255      ) -> Result<(), TransportError<T>> {
 256          if self.channel_state().rx_committed {
 257              panic!("can no longer read from channel!");
 258          }
 259  
 260          if amount.is_none() {
 261              self.channel_state().rx_committed = true;
 262          }
 263  
 264          let amount = amount.unwrap_or(u32::MAX);
 265          assert!(amount != 0, "window is empty");
 266  
 267          match self.channel_state().rx_half {
 268              HalfState::Window(0) => {
 269                  let recipient_channel = self.channel_state().tx_channel_id;
 270  
 271                  self.send(wire::Message::ChannelWindowAdjust {
 272                      recipient_channel,
 273                      bytes_to_add: amount,
 274                  })
 275                  .await?;
 276  
 277                  self.channel_state().rx_half.increase_window(amount)?;
 278  
 279                  Ok(())
 280              }
 281              HalfState::Window(_) => {
 282                  panic!("channel reader did not read all data");
 283              }
 284              HalfState::Eof | HalfState::Close => Ok(()),
 285          }
 286      }
 287  
 288      pub(crate) fn channel_is_eof(&mut self) -> bool {
 289          matches!(
 290              self.channel_state().rx_half,
 291              HalfState::Eof | HalfState::Close
 292          )
 293      }
 294  
 295      pub(crate) async fn channel_read(&mut self) -> Result<Option<&[u8]>, TransportError<T>> {
 296          loop {
 297              match self.channel_state().rx_half {
 298                  HalfState::Window(0) => {
 299                      return Ok(None);
 300                  }
 301                  HalfState::Window(amount) => {
 302                      if self.channel_state().rx_committed {
 303                          // Only re-adjust the window if it is smaller than the maximum data size, this
 304                          // ensures we will only send a window adjust message approximately every 4GiB.
 305  
 306                          if amount < self.maximum_channel_data_packet_size() {
 307                              let bytes_to_add = u32::MAX - amount;
 308  
 309                              let recipient_channel = self.channel_state().tx_channel_id;
 310  
 311                              self.send(wire::Message::ChannelWindowAdjust {
 312                                  recipient_channel,
 313                                  bytes_to_add,
 314                              })
 315                              .await?;
 316  
 317                              self.channel_state().rx_half.increase_window(bytes_to_add)?;
 318                          }
 319                      }
 320  
 321                      if let Some(payload_len) = self.poll_client().await? {
 322                          if let wire::Message::ChannelData {
 323                              data: wire::Data::Borrowed { bytes },
 324                              ..
 325                          } = wire::Message::decode(&self.buffer[self.payload_range(payload_len)])?
 326                          {
 327                              return Ok(Some(bytes));
 328                          } else {
 329                              unreachable!("expected channel data");
 330                          }
 331                      }
 332                  }
 333                  HalfState::Eof | HalfState::Close => {
 334                      return Ok(None);
 335                  }
 336              }
 337          }
 338      }
 339  
 340      fn channel_state(&mut self) -> &mut ChannelState {
 341          super::unwrap_unreachable(self.active_channel.as_mut())
 342      }
 343  
 344      pub(crate) async fn channel_exit(&mut self, exit_status: u32) -> Result<(), TransportError<T>> {
 345          let recipient_channel = self.channel_state().tx_channel_id;
 346  
 347          self.send(wire::Message::ChannelEof { recipient_channel })
 348              .await?;
 349  
 350          self.send(wire::Message::ChannelRequest {
 351              recipient_channel,
 352              request: wire::Request::ExitStatus {
 353                  want_reply: false,
 354                  exit_status,
 355              },
 356          })
 357          .await?;
 358  
 359          self.send(wire::Message::ChannelClose { recipient_channel })
 360              .await?;
 361  
 362          self.channel_state().tx_half = HalfState::Close;
 363  
 364          if let HalfState::Close = self.channel_state().rx_half {
 365              self.dequeue_pending_channel().await?;
 366          }
 367  
 368          self.request = None;
 369  
 370          Ok(())
 371      }
 372  
 373      pub(crate) async fn channel_write_all(
 374          &mut self,
 375          len: usize,
 376          pipe: Pipe,
 377      ) -> Result<bool, TransportError<T>> {
 378          while !self.channel_write(len, pipe).await? {
 379              if let HalfState::Close = self.channel_state().tx_half {
 380                  return Ok(false); // client has closed the channel
 381              }
 382  
 383              self.poll_client().await?;
 384          }
 385  
 386          Ok(true)
 387      }
 388  
 389      pub(crate) async fn channel_write(
 390          &mut self,
 391          len: usize,
 392          pipe: Pipe,
 393      ) -> Result<bool, TransportError<T>> {
 394          assert!(len <= self.channel_data_payload_buffer(pipe).len());
 395  
 396          if len == 0 {
 397              return Ok(true);
 398          }
 399  
 400          if let HalfState::Window(amount) = self.channel_state().tx_half {
 401              if wire::from_u32(amount) >= len {
 402                  let recipient_channel = self.channel_state().tx_channel_id;
 403  
 404                  self.send(match pipe {
 405                      Pipe::Stdout => wire::Message::ChannelData {
 406                          recipient_channel,
 407                          data: wire::Data::InPlace {
 408                              len: wire::into_u32(len),
 409                          },
 410                      },
 411                      Pipe::Stderr => wire::Message::ChannelExtendedData {
 412                          recipient_channel,
 413                          data: wire::ExtendedData::Stderr {
 414                              data: wire::Data::InPlace {
 415                                  len: wire::into_u32(len),
 416                              },
 417                          },
 418                      },
 419                  })
 420                  .await?;
 421  
 422                  self.channel_state()
 423                      .tx_half
 424                      .decrease_window(wire::into_u32(len))?;
 425  
 426                  return Ok(true);
 427              }
 428          }
 429  
 430          Ok(false)
 431      }
 432  
 433      async fn poll_client(&mut self) -> Result<Option<usize>, TransportError<T>> {
 434          if self.client_ssh_id_length == 0 {
 435              self.perform_handshake().await?;
 436          }
 437  
 438          let mut reason = wire::DisconnectReason::ProtocolError;
 439  
 440          let payload_len = self.recv().await?; // we sometimes need the message payload bytes
 441          let message = wire::Message::decode(&self.buffer[self.payload_range(payload_len)])?;
 442  
 443          match message {
 444              wire::Message::KexInit {
 445                  kex_algorithms,
 446                  server_host_key_algorithms,
 447                  encryption_algorithms_client_to_server,
 448                  encryption_algorithms_server_to_client,
 449                  compression_algorithms_client_to_server,
 450                  compression_algorithms_server_to_client,
 451                  first_kex_packet_follows,
 452                  ..
 453              } if self.kex.is_none() => {
 454                  if self.curr_keys.is_none()
 455                      && kex_algorithms.find(KEXINIT_STRICT_KEX_CLIENT).is_some()
 456                  {
 457                      // Enable strict KEX mode for this transport as specified by OpenSSH
 458  
 459                      self.strict_kex = true;
 460  
 461                      if self.client_sequence_number != 0 {
 462                          return Err(Error::ServerDisconnect(
 463                              wire::DisconnectReason::ProtocolError,
 464                          ));
 465                      }
 466                  }
 467  
 468                  // We only have one algorithm for each name list, so the selection algorithm
 469                  // boils down to "does the client have our algorithm in their list". Process
 470                  // the kex and host_key algorithms specially if a guessed packet is sent.
 471  
 472                  let kex_index = kex_algorithms.find(KEXINIT_KEX_ALGORITHM);
 473                  let host_key_index = server_host_key_algorithms.find(KEXINIT_HOST_KEY);
 474  
 475                  if kex_index.is_none() || host_key_index.is_none() {
 476                      return Err(Error::ServerDisconnect(
 477                          wire::DisconnectReason::KeyExchangeFailed,
 478                      ));
 479                  }
 480  
 481                  if encryption_algorithms_client_to_server
 482                      .find(KEXINIT_ENCRYPTION)
 483                      .is_none()
 484                  {
 485                      return Err(Error::ServerDisconnect(
 486                          wire::DisconnectReason::KeyExchangeFailed,
 487                      ));
 488                  }
 489  
 490                  if encryption_algorithms_server_to_client
 491                      .find(KEXINIT_ENCRYPTION)
 492                      .is_none()
 493                  {
 494                      return Err(Error::ServerDisconnect(
 495                          wire::DisconnectReason::KeyExchangeFailed,
 496                      ));
 497                  }
 498  
 499                  // We use an AEAD cipher that doesn't require (and forbids) MAC algorithm negotiation.
 500  
 501                  if compression_algorithms_client_to_server
 502                      .find(KEXINIT_COMPRESSION)
 503                      .is_none()
 504                  {
 505                      return Err(Error::ServerDisconnect(
 506                          wire::DisconnectReason::KeyExchangeFailed,
 507                      ));
 508                  }
 509  
 510                  if compression_algorithms_server_to_client
 511                      .find(KEXINIT_COMPRESSION)
 512                      .is_none()
 513                  {
 514                      return Err(Error::ServerDisconnect(
 515                          wire::DisconnectReason::KeyExchangeFailed,
 516                      ));
 517                  }
 518  
 519                  // If the kex or host key algorithms were not the client's preferred algorithm, their guess
 520                  // will be wrong so we must discard the guessed key exchange packet they will have sent us.
 521  
 522                  let mut discard_guessed = false;
 523  
 524                  if first_kex_packet_follows && (kex_index != Some(0) || host_key_index != Some(0)) {
 525                      discard_guessed = true;
 526                  }
 527  
 528                  let mut cookie = [0u8; 16];
 529  
 530                  self.behavior.random().fill_bytes(&mut cookie);
 531  
 532                  // Note that we never need to guess since we reply to the client's KEXINIT, so we always
 533                  // know whether our guess would have been correct or not; the client may guess, however.
 534  
 535                  let kex_init_message = wire::Message::KexInit {
 536                      cookie: &cookie,
 537                      kex_algorithms: wire::NameList::new_from_string(KEXINIT_KEX)?,
 538                      server_host_key_algorithms: wire::NameList::new_from_string(KEXINIT_HOST_KEY)?,
 539                      encryption_algorithms_client_to_server: wire::NameList::new_from_string(
 540                          KEXINIT_ENCRYPTION,
 541                      )?,
 542                      encryption_algorithms_server_to_client: wire::NameList::new_from_string(
 543                          KEXINIT_ENCRYPTION,
 544                      )?,
 545                      mac_algorithms_client_to_server: wire::NameList::new_from_string(KEXINIT_MAC)?,
 546                      mac_algorithms_server_to_client: wire::NameList::new_from_string(KEXINIT_MAC)?,
 547                      compression_algorithms_client_to_server: wire::NameList::new_from_string(
 548                          KEXINIT_COMPRESSION,
 549                      )?,
 550                      compression_algorithms_server_to_client: wire::NameList::new_from_string(
 551                          KEXINIT_COMPRESSION,
 552                      )?,
 553                      languages_client_to_server: wire::NameList::default(),
 554                      languages_server_to_client: wire::NameList::default(),
 555                      first_kex_packet_follows: false,
 556                      reserved: 0,
 557                  };
 558  
 559                  let mut kex = KexState {
 560                      exchange_hash_hasher: ObjectHasher::new(Sha256::new()),
 561                      discard_guessed,
 562                  };
 563  
 564                  kex.exchange_hash_hasher
 565                      .hash_string_utf8(self.client_ssh_id_string());
 566                  kex.exchange_hash_hasher
 567                      .hash_string_utf8(self.behavior.server_id());
 568                  kex.exchange_hash_hasher
 569                      .hash_string(&self.buffer[self.payload_range(payload_len)]);
 570  
 571                  let payload_range = self.payload_range_full();
 572  
 573                  let payload = kex_init_message.encode(&mut self.buffer[payload_range])?;
 574  
 575                  kex.exchange_hash_hasher.hash_string(payload);
 576  
 577                  let payload_len = payload.len();
 578  
 579                  self.send_preencoded_payload(payload_len).await?;
 580  
 581                  self.kex = Some(kex);
 582  
 583                  return Ok(None);
 584              }
 585              wire::Message::KexEcdhInit {
 586                  client_ephemeral_public_key,
 587              } => {
 588                  if let Some(mut kex) = self.kex.take() {
 589                      if core::mem::replace(&mut kex.discard_guessed, false) {
 590                          self.kex = Some(kex);
 591                          return Ok(None);
 592                      } else if let Ok(client_ephemeral_public_key) =
 593                          <&[u8] as TryInto<[u8; 32]>>::try_into(client_ephemeral_public_key)
 594                      {
 595                          let client_ephemeral_public_key: PublicKey =
 596                              client_ephemeral_public_key.into();
 597  
 598                          let server_ephemeral_secret_key =
 599                              EphemeralSecret::random_from_rng(self.behavior.random());
 600  
 601                          let server_ephemeral_public_key: PublicKey =
 602                              (&server_ephemeral_secret_key).into();
 603  
 604                          // Generate a keypair
 605  
 606                          let shared_secret = server_ephemeral_secret_key
 607                              .diffie_hellman(&client_ephemeral_public_key);
 608  
 609                          // Finish building up the exchange hash
 610  
 611                          match self.behavior.host_secret_key() {
 612                              SecretKey::Ed25519 { secret_key } => {
 613                                  let public_key = secret_key.verifying_key();
 614  
 615                                  let host_key = wire::PublicKey::Ed25519 {
 616                                      public_key: public_key.as_bytes(),
 617                                  };
 618  
 619                                  let mut host_key_writer = ObjectWriter::new(self.buffer);
 620  
 621                                  host_key.encode_with(&mut host_key_writer)?;
 622  
 623                                  kex.exchange_hash_hasher
 624                                      .hash_byte_array(host_key_writer.into_written());
 625                                  kex.exchange_hash_hasher
 626                                      .hash_string(client_ephemeral_public_key.as_bytes());
 627                                  kex.exchange_hash_hasher
 628                                      .hash_string(server_ephemeral_public_key.as_bytes());
 629  
 630                                  let shared_secret = *shared_secret.as_bytes();
 631  
 632                                  kex.exchange_hash_hasher.hash_mpint(&shared_secret);
 633  
 634                                  let exchange_hash = kex.exchange_hash_hasher.into_digest();
 635  
 636                                  let signature = secret_key.sign(&exchange_hash);
 637  
 638                                  self.send(wire::Message::KexEcdhReply {
 639                                      server_public_host_key: wire::PublicKey::Ed25519 {
 640                                          public_key: public_key.as_bytes(),
 641                                      },
 642                                      server_ephemeral_public_key: server_ephemeral_public_key
 643                                          .as_bytes(),
 644                                      signature: wire::Signature::Ed25519 {
 645                                          signature: &signature.to_bytes(),
 646                                      },
 647                                  })
 648                                  .await?;
 649  
 650                                  let mut prefix_hash = ObjectHasher::new(Sha256::new());
 651  
 652                                  prefix_hash.hash_mpint(&shared_secret);
 653                                  prefix_hash.hash_byte_array(&exchange_hash);
 654  
 655                                  self.next_keys = Some(PendingKeys {
 656                                      session_id: match &self.curr_keys {
 657                                          Some(keys) => keys.session_id,
 658                                          None => exchange_hash.into(),
 659                                      },
 660                                      prefix_hash,
 661                                  });
 662  
 663                                  return Ok(None);
 664                              }
 665                          }
 666                      } else {
 667                          reason = wire::DisconnectReason::KeyExchangeFailed;
 668                      }
 669                  }
 670              }
 671              wire::Message::NewKeys => {
 672                  if let Some(keys) = self.next_keys.take() {
 673                      let mut enc_key_client_hash = keys.prefix_hash.clone();
 674                      enc_key_client_hash.hash_byte(b'C');
 675                      enc_key_client_hash.hash_byte_array(&keys.session_id);
 676                      let client_enc_k1 = enc_key_client_hash.into_digest();
 677  
 678                      let mut enc_key_server_hash = keys.prefix_hash.clone();
 679                      enc_key_server_hash.hash_byte(b'D');
 680                      enc_key_server_hash.hash_byte_array(&keys.session_id);
 681                      let server_enc_k1 = enc_key_server_hash.into_digest();
 682  
 683                      let mut digest = keys.prefix_hash.clone();
 684                      digest.hash_byte_array(&client_enc_k1);
 685                      let client_enc_k2 = digest.into_digest();
 686  
 687                      let mut digest = keys.prefix_hash.clone();
 688                      digest.hash_byte_array(&server_enc_k1);
 689                      let server_enc_k2 = digest.into_digest();
 690  
 691                      self.send(wire::Message::NewKeys).await?;
 692  
 693                      self.curr_keys = Some(KeyMaterial {
 694                          client_head_key: client_enc_k2.into(),
 695                          client_main_key: client_enc_k1.into(),
 696                          server_head_key: server_enc_k2.into(),
 697                          server_main_key: server_enc_k1.into(),
 698                          session_id: keys.session_id,
 699                      });
 700  
 701                      if self.strict_kex {
 702                          self.client_sequence_number = u32::MAX;
 703                          self.server_sequence_number = u32::MAX;
 704                      }
 705  
 706                      return Ok(None);
 707                  }
 708              }
 709              wire::Message::ServiceRequest { service_name } if self.curr_keys.is_some() => {
 710                  match service_name {
 711                      "ssh-userauth" => {
 712                          self.userauth_enabled = true;
 713  
 714                          self.send(wire::Message::ServiceAccept {
 715                              service_name: "ssh-userauth",
 716                          })
 717                          .await?;
 718  
 719                          return Ok(None);
 720                      }
 721                      _ => {
 722                          reason = wire::DisconnectReason::ServiceNotAvailable;
 723                      }
 724                  }
 725              }
 726              wire::Message::UserAuthRequest {
 727                  user_name,
 728                  service_name: "ssh-connection",
 729                  auth_method,
 730              } if self.userauth_enabled => {
 731                  // Unfortunately we need to use the user name for signature verification, meaning we need
 732                  // to store it outside the packet buffer. Rather than burden the crate user, we just copy
 733                  // the user name into a temporary string, enforcing a reasonable 80-byte maximum length.
 734  
 735                  let mut user_name_buffer = [0u8; 80];
 736  
 737                  if user_name.len() > user_name_buffer.len() {
 738                      self.send(wire::Message::UserAuthFailure {
 739                          authentications_that_can_continue: wire::NameList::new_from_string(
 740                              "publickey",
 741                          )?,
 742                          partial_success: false,
 743                      })
 744                      .await?;
 745  
 746                      return Ok(None);
 747                  }
 748  
 749                  let user_name_slice = &mut user_name_buffer[..user_name.len()];
 750                  user_name_slice.copy_from_slice(user_name.as_bytes());
 751  
 752                  let Some(user_auth_method) = (match auth_method {
 753                      wire::AuthMethod::None => Some(AuthMethod::None),
 754                      wire::AuthMethod::PublicKey {
 755                          public_key_algorithm_name: "ssh-ed25519",
 756                          public_key: wire::PublicKey::Ed25519 { public_key },
 757                          signature: Some(wire::Signature::Ed25519 { .. }) | None,
 758                      } => {
 759                          if let Ok(public_key) = VerifyingKey::from_bytes(public_key) {
 760                              Some(AuthMethod::PublicKey(types::PublicKey::Ed25519 {
 761                                  public_key,
 762                              }))
 763                          } else {
 764                              None
 765                          }
 766                      }
 767                      _ => None,
 768                  }) else {
 769                      self.send(wire::Message::UserAuthFailure {
 770                          authentications_that_can_continue: wire::NameList::new_from_string(
 771                              "publickey",
 772                          )?,
 773                          partial_success: false,
 774                      })
 775                      .await?;
 776  
 777                      return Ok(None);
 778                  };
 779  
 780                  if let Some(user) = self.behavior.allow_user(user_name, &user_auth_method) {
 781                      match user_auth_method {
 782                          AuthMethod::None => {
 783                              self.send(wire::Message::UserAuthSuccess).await?;
 784                              self.current_user = Some(user);
 785                              self.authenticated = true;
 786                          }
 787                          AuthMethod::PublicKey(types::PublicKey::Ed25519 { public_key }) => {
 788                              if let wire::AuthMethod::PublicKey {
 789                                  public_key_algorithm_name: "ssh-ed25519",
 790                                  public_key: wire::PublicKey::Ed25519 { .. },
 791                                  signature: Some(wire::Signature::Ed25519 { signature }),
 792                              } = auth_method
 793                              {
 794                                  let signed_public_key = wire::PublicKey::Ed25519 {
 795                                      public_key: public_key.as_bytes(),
 796                                  };
 797  
 798                                  let signature: Signature = signature.into();
 799  
 800                                  let mut writer = ObjectWriter::new(self.buffer);
 801  
 802                                  // TODO: feels like we could do a little better here
 803  
 804                                  writer.write_string(
 805                                      &super::unwrap_unreachable(self.curr_keys.as_ref()).session_id,
 806                                  )?;
 807                                  writer.write_byte(wire::MSG_USERAUTH_REQUEST)?;
 808                                  writer.write_string(user_name_slice)?;
 809                                  writer.write_string_utf8("ssh-connection")?;
 810                                  writer.write_string_utf8("publickey")?;
 811                                  writer.write_boolean(true)?;
 812                                  writer.write_string_utf8("ssh-ed25519")?;
 813                                  signed_public_key.encode_with(&mut writer)?;
 814  
 815                                  if let Ok(()) = public_key.verify(writer.into_written(), &signature)
 816                                  {
 817                                      self.send(wire::Message::UserAuthSuccess).await?;
 818                                      self.current_user = Some(user);
 819                                      self.authenticated = true;
 820                                  } else {
 821                                      self.send(wire::Message::UserAuthFailure {
 822                                          authentications_that_can_continue:
 823                                              wire::NameList::new_from_string("publickey")?,
 824                                          partial_success: false,
 825                                      })
 826                                      .await?;
 827                                  }
 828                              } else {
 829                                  self.send(wire::Message::UserAuthPkOk {
 830                                      public_key_algorithm_name: "ssh-ed25519",
 831                                      public_key: wire::PublicKey::Ed25519 {
 832                                          public_key: public_key.as_bytes(),
 833                                      },
 834                                  })
 835                                  .await?;
 836                              }
 837                          }
 838                      }
 839                  } else {
 840                      self.send(wire::Message::UserAuthFailure {
 841                          authentications_that_can_continue: wire::NameList::new_from_string(
 842                              "publickey",
 843                          )?,
 844                          partial_success: false,
 845                      })
 846                      .await?;
 847                  }
 848  
 849                  return Ok(None);
 850              }
 851  
 852              wire::Message::GlobalRequest { want_reply, .. } if self.authenticated => {
 853                  if want_reply {
 854                      self.send(wire::Message::RequestFailure).await?;
 855                  }
 856  
 857                  return Ok(None);
 858              }
 859  
 860              wire::Message::ChannelOpen {
 861                  channel:
 862                      wire::ChannelType::Session {
 863                          sender_channel,
 864                          initial_window_size,
 865                          maximum_packet_size,
 866                      },
 867              } if self.authenticated => {
 868                  for channel in self.channels.into_iter().flatten() {
 869                      if channel.sender_channel == sender_channel {
 870                          self.send(wire::Message::Disconnect {
 871                              reason: wire::DisconnectReason::ProtocolError,
 872                          })
 873                          .await?;
 874                          return Err(Error::ServerDisconnect(
 875                              wire::DisconnectReason::ProtocolError,
 876                          ));
 877                      }
 878                  }
 879  
 880                  for channel in &mut self.channels {
 881                      if channel.is_none() {
 882                          *channel = Some(PendingChannel {
 883                              sender_channel,
 884                              initial_window_size,
 885                              maximum_packet_size,
 886                          });
 887  
 888                          if self.active_channel.is_none() {
 889                              self.dequeue_pending_channel().await?;
 890                          }
 891  
 892                          return Ok(None);
 893                      }
 894                  }
 895  
 896                  self.send(wire::Message::ChannelOpenFailure {
 897                      recipient_channel: sender_channel,
 898                      reason: wire::ChannelOpenFailureReason::ResourceShortage,
 899                  })
 900                  .await?;
 901  
 902                  return Ok(None);
 903              }
 904  
 905              wire::Message::ChannelOpen {
 906                  channel: wire::ChannelType::Other { sender_channel, .. },
 907              } if self.authenticated => {
 908                  self.send(wire::Message::ChannelOpenFailure {
 909                      recipient_channel: sender_channel,
 910                      reason: wire::ChannelOpenFailureReason::UnknownChannelType,
 911                  })
 912                  .await?;
 913  
 914                  return Ok(None);
 915              }
 916  
 917              wire::Message::ChannelWindowAdjust {
 918                  recipient_channel,
 919                  bytes_to_add,
 920              } if self.authenticated => {
 921                  if let Some(channel_state) = &mut self.active_channel {
 922                      if channel_state.rx_channel_id == recipient_channel {
 923                          channel_state.tx_half.increase_window(bytes_to_add)?;
 924                          return Ok(None);
 925                      }
 926                  }
 927              }
 928  
 929              wire::Message::ChannelData {
 930                  recipient_channel,
 931                  data: wire::Data::Borrowed { bytes },
 932              } if self.authenticated => {
 933                  if let Some(channel_state) = &mut self.active_channel {
 934                      if channel_state.rx_channel_id == recipient_channel {
 935                          channel_state
 936                              .rx_half
 937                              .decrease_window(wire::into_u32(bytes.len()))?;
 938                          return Ok(Some(payload_len));
 939                      }
 940                  }
 941              }
 942  
 943              wire::Message::ChannelEof { recipient_channel } if self.authenticated => {
 944                  if let Some(channel_state) = &mut self.active_channel {
 945                      if channel_state.rx_channel_id == recipient_channel {
 946                          channel_state.rx_half = HalfState::Eof;
 947                          return Ok(None);
 948                      }
 949                  }
 950              }
 951  
 952              wire::Message::ChannelClose { recipient_channel } if self.authenticated => {
 953                  if self.active_channel.is_none() {
 954                      return Ok(None); // ignored
 955                  }
 956  
 957                  if let Some(channel_state) = &mut self.active_channel {
 958                      if channel_state.rx_channel_id == recipient_channel {
 959                          channel_state.rx_half = HalfState::Close;
 960  
 961                          if let HalfState::Close = channel_state.tx_half {
 962                              self.dequeue_pending_channel().await?;
 963                          } else if self.request.is_none() {
 964                              let sender_channel = channel_state.tx_channel_id;
 965  
 966                              self.send(wire::Message::ChannelEof {
 967                                  recipient_channel: sender_channel,
 968                              })
 969                              .await?;
 970  
 971                              self.send(wire::Message::ChannelClose {
 972                                  recipient_channel: sender_channel,
 973                              })
 974                              .await?;
 975  
 976                              self.dequeue_pending_channel().await?;
 977                          } else {
 978                              channel_state.tx_half = HalfState::Close;
 979                          }
 980  
 981                          return Ok(None);
 982                      }
 983                  }
 984              }
 985  
 986              wire::Message::ChannelRequest {
 987                  recipient_channel,
 988                  request:
 989                      wire::Request::Exec {
 990                          want_reply,
 991                          command,
 992                      },
 993              } if self.authenticated && self.request.is_none() => {
 994                  if let Some(channel_state) = &mut self.active_channel {
 995                      if channel_state.rx_channel_id == recipient_channel {
 996                          self.request = Some(Request::Exec(self.behavior.parse_command(command)));
 997  
 998                          if want_reply {
 999                              let sender_channel = channel_state.tx_channel_id;
1000  
1001                              self.send(wire::Message::ChannelSuccess {
1002                                  recipient_channel: sender_channel,
1003                              })
1004                              .await?;
1005                          }
1006  
1007                          return Ok(None);
1008                      }
1009                  }
1010              }
1011  
1012              wire::Message::ChannelRequest {
1013                  recipient_channel,
1014                  request: wire::Request::Env { want_reply, .. },
1015              } if self.authenticated => {
1016                  if let Some(channel_state) = &mut self.active_channel {
1017                      if channel_state.rx_channel_id == recipient_channel {
1018                          if want_reply {
1019                              let sender_channel = channel_state.tx_channel_id;
1020  
1021                              self.send(wire::Message::ChannelSuccess {
1022                                  recipient_channel: sender_channel,
1023                              })
1024                              .await?;
1025                          }
1026  
1027                          return Ok(None);
1028                      }
1029                  }
1030              }
1031  
1032              // PTY request — accept it and store terminal dimensions
1033              wire::Message::ChannelRequest {
1034                  recipient_channel,
1035                  request: wire::Request::PtyReq { want_reply, width_chars, height_rows, .. },
1036              } if self.authenticated => {
1037                  if let Some(channel_state) = &mut self.active_channel {
1038                      if channel_state.rx_channel_id == recipient_channel {
1039                          self.behavior.on_pty_request(width_chars, height_rows);
1040                          if want_reply {
1041                              let sender_channel = channel_state.tx_channel_id;
1042                              self.send(wire::Message::ChannelSuccess {
1043                                  recipient_channel: sender_channel,
1044                              }).await?;
1045                          }
1046                          return Ok(None);
1047                      }
1048                  }
1049              }
1050  
1051              wire::Message::ChannelRequest {
1052                  recipient_channel,
1053                  request: wire::Request::Shell { want_reply },
1054              } if self.authenticated && self.request.is_none() && self.behavior.allow_shell() => {
1055                  if let Some(channel_state) = &mut self.active_channel {
1056                      if channel_state.rx_channel_id == recipient_channel {
1057                          self.request = Some(Request::Shell);
1058  
1059                          if want_reply {
1060                              let sender_channel = channel_state.tx_channel_id;
1061  
1062                              self.send(wire::Message::ChannelSuccess {
1063                                  recipient_channel: sender_channel,
1064                              })
1065                              .await?;
1066                          }
1067  
1068                          return Ok(None);
1069                      }
1070                  }
1071              }
1072  
1073              wire::Message::ChannelRequest {
1074                  recipient_channel,
1075                  request,
1076              } if self.authenticated => {
1077                  if let Some(channel_state) = &mut self.active_channel {
1078                      if channel_state.rx_channel_id == recipient_channel {
1079                          if request.want_reply() {
1080                              let sender_channel = channel_state.tx_channel_id;
1081  
1082                              self.send(wire::Message::ChannelFailure {
1083                                  recipient_channel: sender_channel,
1084                              })
1085                              .await?;
1086                          }
1087  
1088                          return Ok(None);
1089                      }
1090                  }
1091              }
1092  
1093              wire::Message::Debug { .. }
1094              | wire::Message::Ignore { .. }
1095              | wire::Message::Unimplemented { .. } => {
1096                  if self.strict_kex && self.curr_keys.is_none() {
1097                      return Err(Error::ServerDisconnect(
1098                          wire::DisconnectReason::ProtocolError,
1099                      ));
1100                  }
1101  
1102                  return Ok(None);
1103              }
1104              wire::Message::Unknown { .. } => {
1105                  if self.strict_kex && self.curr_keys.is_none() {
1106                      return Err(Error::ServerDisconnect(
1107                          wire::DisconnectReason::ProtocolError,
1108                      ));
1109                  }
1110  
1111                  self.send(wire::Message::Unimplemented {
1112                      sequence_number: self.client_sequence_number,
1113                  })
1114                  .await?;
1115  
1116                  return Ok(None);
1117              }
1118              wire::Message::Disconnect { reason, .. } => {
1119                  return Err(Error::ClientDisconnect(reason));
1120              }
1121  
1122              _ => {}
1123          }
1124  
1125          self.send(wire::Message::Disconnect { reason }).await?;
1126          Err(Error::ServerDisconnect(reason))
1127      }
1128  
1129      async fn send_preencoded_payload(
1130          &mut self,
1131          payload_len: usize,
1132      ) -> Result<(), TransportError<T>> {
1133          self.server_sequence_number = self.server_sequence_number.wrapping_add(1);
1134  
1135          // NOTE: padding rules differ for AEAD cipher modes
1136  
1137          let mut padding_len = if self.curr_keys.is_some() {
1138              (7usize.wrapping_sub(payload_len)) % 8
1139          } else {
1140              (3usize.wrapping_sub(payload_len)) % 8
1141          };
1142  
1143          if padding_len < 4 {
1144              padding_len += 8;
1145          }
1146  
1147          let packet_len = wire::into_u32(1 + payload_len + padding_len);
1148          self.buffer[..4].copy_from_slice(&packet_len.to_be_bytes());
1149  
1150          if let Some(ctx) = &mut self.curr_keys {
1151              let mut cipher = ChaCha20Legacy::new(
1152                  (&ctx.server_head_key).into(),
1153                  (&(self.server_sequence_number as u64).to_be_bytes()).into(),
1154              );
1155  
1156              cipher.apply_keystream(&mut self.buffer[..4]);
1157          }
1158  
1159          self.buffer[4] = padding_len as u8;
1160  
1161          self.behavior
1162              .random()
1163              .fill_bytes(&mut self.buffer[5 + payload_len..][..padding_len]);
1164  
1165          if let Some(ctx) = &mut self.curr_keys {
1166              let (ciphertext, tag_buf) = self.buffer.split_at_mut(5 + payload_len + padding_len);
1167  
1168              let sequence_number = self.server_sequence_number as u64;
1169  
1170              let mut cipher = ChaCha20Legacy::new(
1171                  (&ctx.server_main_key).into(),
1172                  (&sequence_number.to_be_bytes()).into(),
1173              );
1174  
1175              let mut mac_key = [0u8; 32];
1176              cipher.apply_keystream(&mut mac_key);
1177              let mac = Poly1305::new((&mac_key).into());
1178  
1179              cipher.seek(64);
1180              cipher.apply_keystream(&mut ciphertext[4..]);
1181  
1182              let tag = mac.compute_unpadded(ciphertext);
1183              tag_buf[..16].copy_from_slice(&tag);
1184  
1185              self.behavior
1186                  .stream()
1187                  .write_all(&self.buffer[..5 + payload_len + padding_len + 16])
1188                  .await?;
1189          } else {
1190              self.behavior
1191                  .stream()
1192                  .write_all(&self.buffer[..5 + payload_len + padding_len])
1193                  .await?;
1194          }
1195  
1196          Ok(())
1197      }
1198  
1199      async fn send(&mut self, message: wire::Message<'_>) -> Result<(), TransportError<T>> {
1200          let payload_range = self.payload_range_full();
1201  
1202          let payload_len = message.encode(&mut self.buffer[payload_range])?.len();
1203          self.send_preencoded_payload(payload_len).await
1204      }
1205  
1206      async fn recv(&mut self) -> Result<usize, TransportError<T>> {
1207          self.client_sequence_number = self.client_sequence_number.wrapping_add(1);
1208  
1209          self.behavior
1210              .stream()
1211              .read_exact(&mut self.buffer[..4])
1212              .await?;
1213  
1214          let mut decrypted_packet_len = [0u8; 4];
1215          decrypted_packet_len.copy_from_slice(&self.buffer[..4]);
1216  
1217          if let Some(ctx) = &mut self.curr_keys {
1218              let mut cipher = ChaCha20Legacy::new(
1219                  (&ctx.client_head_key).into(),
1220                  (&(self.client_sequence_number as u64).to_be_bytes()).into(),
1221              );
1222  
1223              cipher.apply_keystream(&mut decrypted_packet_len);
1224          }
1225  
1226          let packet_len = wire::from_u32(u32::from_be_bytes(decrypted_packet_len));
1227  
1228          // NOTE: padding rules differ for AEAD cipher modes
1229  
1230          let padding_remainder = if self.curr_keys.is_some() { 0 } else { 4 };
1231  
1232          if packet_len < padding_remainder + 8 {
1233              Err(ProtocolError::MalformedPacket)?;
1234          }
1235  
1236          if packet_len % 8 != padding_remainder {
1237              Err(ProtocolError::MalformedPacket)?;
1238          }
1239  
1240          let mac_len = if self.curr_keys.is_some() { 16 } else { 0 };
1241  
1242          if 4 + packet_len + mac_len > self.buffer.len() {
1243              Err(ProtocolError::BufferExhausted)?;
1244          }
1245  
1246          self.behavior
1247              .stream()
1248              .read_exact(&mut self.buffer[4..4 + packet_len + mac_len])
1249              .await?;
1250  
1251          if let Some(ctx) = &mut self.curr_keys {
1252              let (ciphertext, tag_buf) = self.buffer.split_at_mut(4 + packet_len);
1253  
1254              let mut cipher = ChaCha20Legacy::new(
1255                  (&ctx.client_main_key).into(),
1256                  (&(self.client_sequence_number as u64).to_be_bytes()).into(),
1257              );
1258  
1259              let mut mac_key = [0u8; 32];
1260              cipher.apply_keystream(&mut mac_key);
1261              let mac = Poly1305::new((&mac_key).into());
1262  
1263              let tag = mac.compute_unpadded(ciphertext);
1264  
1265              // DO NOT report a MAC verification error to the client for security
1266              // reasons, just disconnect immediately and let it reinitiate later.
1267  
1268              if !constant_time_eq(&tag, &tag_buf[..16]) {
1269                  Err(ProtocolError::MalformedPacket)?;
1270              }
1271  
1272              cipher.seek(64);
1273              cipher.apply_keystream(&mut ciphertext[4..]);
1274          }
1275  
1276          let padding_len: usize = self.buffer[4] as usize;
1277  
1278          if padding_len < 4 {
1279              Err(ProtocolError::MalformedPacket)?;
1280          }
1281  
1282          if packet_len < 1 + padding_len {
1283              Err(ProtocolError::MalformedPacket)?;
1284          }
1285  
1286          Ok(packet_len - 1 - padding_len)
1287      }
1288  
1289      /// Disconnects from the client with a given reason.
1290      pub async fn disconnect(
1291          mut self,
1292          reason: wire::DisconnectReason,
1293      ) -> Result<(), TransportError<T>> {
1294          if self.client_ssh_id_length == 0 {
1295              return Ok(());
1296          }
1297  
1298          self.send(wire::Message::Disconnect { reason }).await
1299      }
1300  
1301      async fn dequeue_pending_channel(&mut self) -> Result<(), TransportError<T>> {
1302          self.active_channel = None;
1303  
1304          for channel in &mut self.channels {
1305              if let Some(state) = channel.take() {
1306                  self.active_channel = Some(ChannelState {
1307                      rx_half: HalfState::Window(0),
1308                      tx_half: HalfState::Window(state.initial_window_size),
1309                      tx_max_packet: state.maximum_packet_size,
1310                      rx_channel_id: 0,
1311                      tx_channel_id: state.sender_channel,
1312                      rx_committed: false,
1313                  });
1314  
1315                  let maximum_packet_size = self.maximum_channel_data_packet_size();
1316  
1317                  self.send(wire::Message::ChannelOpenConfirmation {
1318                      recipient_channel: state.sender_channel,
1319                      sender_channel: 0,
1320                      initial_window_size: 0,
1321                      maximum_packet_size,
1322                      payload: &[],
1323                  })
1324                  .await?;
1325  
1326                  break;
1327              }
1328          }
1329  
1330          Ok(())
1331      }
1332  
1333      fn payload_range(&self, payload_len: usize) -> Range<usize> {
1334          5..5 + payload_len
1335      }
1336  
1337      fn payload_range_full(&self) -> Range<usize> {
1338          5..self.buffer.len() - 255 - 16
1339      }
1340  }