/ src / ssh_server.rs
ssh_server.rs
  1  use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Duration};
  2  
  3  use anyhow::{Context as _, anyhow};
  4  use radicle::{Profile, identity::Did, node::Alias, storage::ReadRepository};
  5  use russh::{
  6      Channel, ChannelId, CryptoVec,
  7      keys::{
  8          HashAlg, PrivateKey, PublicKey,
  9          ssh_key::private::{Ed25519PrivateKey, KeypairData},
 10      },
 11      server::{self, Msg, Server as _, Session},
 12  };
 13  use tokio::net::TcpListener;
 14  use tracing::{debug, error};
 15  
 16  use crate::{
 17      auth::AuthorizedKeys,
 18      git::{GitSshCommand, handle_git_upload_pack, parse_ssh_command},
 19      radicle::{check_repo_access, get_alias, resolve_repo, ssh_public_key_to_radicle},
 20  };
 21  
 22  #[derive(Debug)]
 23  pub struct Context {
 24      pub profile: Profile,
 25      pub authorized_keys: AuthorizedKeys,
 26  }
 27  
 28  pub async fn serve(
 29      ctx: Context,
 30      host_key: PrivateKey,
 31      listen_address: SocketAddr,
 32  ) -> anyhow::Result<()> {
 33      let config = Arc::new(server::Config {
 34          keys: vec![host_key],
 35          auth_rejection_time_initial: Some(Duration::ZERO),
 36          ..Default::default()
 37      });
 38  
 39      let mut server = Server(ctx.into());
 40  
 41      let socket = TcpListener::bind(listen_address)
 42          .await
 43          .with_context(|| anyhow!("Failed to bind to {listen_address}"))?;
 44      server
 45          .run_on_socket(config, &socket)
 46          .await
 47          .context("Failed to run server")?;
 48  
 49      Ok(())
 50  }
 51  
 52  struct Server(Arc<Context>);
 53  
 54  impl russh::server::Server for Server {
 55      type Handler = Handler;
 56  
 57      fn new_client(&mut self, peer_addr: Option<SocketAddr>) -> Self::Handler {
 58          debug!(?peer_addr, "client connected");
 59          Handler::new(Arc::clone(&self.0), peer_addr)
 60      }
 61  
 62      fn handle_session_error(&mut self, err: anyhow::Error) {
 63          error!("ssh session error: {err:#}");
 64      }
 65  }
 66  
 67  struct Handler {
 68      ctx: Arc<Context>,
 69      peer_addr: Option<SocketAddr>,
 70      channels: HashMap<ChannelId, Channel<Msg>>,
 71      public_key: Option<PublicKey>,
 72      radicle_key: Option<radicle::crypto::PublicKey>,
 73      alias: Option<Alias>,
 74  }
 75  
 76  impl Handler {
 77      fn new(ctx: Arc<Context>, peer_addr: Option<SocketAddr>) -> Self {
 78          Self {
 79              ctx,
 80              peer_addr,
 81              channels: Default::default(),
 82              public_key: None,
 83              radicle_key: None,
 84              alias: None,
 85          }
 86      }
 87  }
 88  
 89  impl server::Handler for Handler {
 90      type Error = anyhow::Error;
 91  
 92      async fn channel_open_session(
 93          &mut self,
 94          channel: Channel<Msg>,
 95          _session: &mut Session,
 96      ) -> Result<bool, Self::Error> {
 97          self.channels.insert(channel.id(), channel);
 98          Ok(true)
 99      }
100  
101      async fn auth_publickey(
102          &mut self,
103          user: &str,
104          key: &PublicKey,
105      ) -> Result<server::Auth, Self::Error> {
106          self.public_key = Some(key.clone());
107          self.radicle_key = ssh_public_key_to_radicle(key);
108          self.alias = self
109              .radicle_key
110              .as_ref()
111              .map(|key| get_alias(&self.ctx.profile, key))
112              .transpose()?
113              .flatten();
114  
115          debug!(
116              peer_addr = ?self.peer_addr,
117              %user,
118              key = ?self.public_key.as_ref().and_then(|key| key.to_openssh().ok()),
119              nid = ?self.radicle_key.as_ref().map(|key| key.to_string()),
120              alias = ?self.alias.as_deref(),
121              "client authenticated via public key",
122          );
123  
124          Ok(server::Auth::Accept)
125      }
126  
127      async fn shell_request(
128          &mut self,
129          channel_id: ChannelId,
130          session: &mut Session,
131      ) -> Result<(), Self::Error> {
132          session.channel_success(channel_id)?;
133  
134          let mut msg = match &self.alias {
135              Some(alias) => format!("\nWelcome, {alias}!\n"),
136              None => "\nWelcome!\n".into(),
137          };
138  
139          if let Some(key) = &self.public_key {
140              msg.push_str("\nYou successfully authenticated using the following public key:\n    ");
141              msg.push_str(&key.to_openssh().context("Failed to encode public key")?);
142              msg.push_str("\n    ");
143              msg.push_str(&key.fingerprint(HashAlg::Sha256).to_string());
144              msg.push('\n');
145          }
146  
147          if let Some(key) = &self.radicle_key {
148              msg.push_str("\nThis corresponds to the following Radicle IDs:\n    DID: ");
149              msg.push_str(&Did::from(key).to_string());
150              if let Some(alias) = &self.alias {
151                  msg.push_str(" (");
152                  msg.push_str(alias);
153                  msg.push(')');
154              }
155              msg.push_str("\n    NID: ");
156              msg.push_str(&key.to_string());
157              msg.push('\n');
158          }
159  
160          msg.push('\n');
161  
162          session.data(channel_id, CryptoVec::from_slice(msg.as_bytes()))?;
163          session.exit_status_request(channel_id, 0)?;
164          session.close(channel_id)?;
165  
166          Ok(())
167      }
168  
169      async fn exec_request(
170          &mut self,
171          channel_id: ChannelId,
172          data: &[u8],
173          session: &mut Session,
174      ) -> Result<(), Self::Error> {
175          session.channel_success(channel_id)?;
176  
177          let channel = self
178              .channels
179              .remove(&channel_id)
180              .ok_or_else(|| anyhow!("Failed to retrieve channel {channel_id}"))?;
181  
182          let data = str::from_utf8(data).context("Failed to decode exec data")?;
183  
184          let path = match parse_ssh_command(data) {
185              Some(GitSshCommand::UploadPack { path }) => path,
186              Some(GitSshCommand::ReceivePack { .. }) => {
187                  session.extended_data(
188                      channel_id,
189                      1,
190                      CryptoVec::from_slice(b"\nERROR: This server does not accept pushes\n\n"),
191                  )?;
192                  session.exit_status_request(channel_id, 1)?;
193                  session.close(channel_id)?;
194                  return Ok(());
195              }
196              None => {
197                  session.extended_data(
198                      channel_id,
199                      1,
200                      CryptoVec::from_slice(b"\nERROR: Invalid command\n\n"),
201                  )?;
202                  session.exit_status_request(channel_id, 1)?;
203                  session.close(channel_id)?;
204                  return Ok(());
205              }
206          };
207  
208          let repo = match resolve_repo(&self.ctx.profile, path)
209              .with_context(|| anyhow!("Failed to resolve repo {path:?}"))?
210          {
211              Some(repo)
212                  if check_repo_access(
213                      &repo,
214                      self.public_key.as_ref(),
215                      &self.ctx.authorized_keys,
216                  )
217                  .with_context(|| anyhow!("Failed to check repo access for {}", repo.id))? =>
218              {
219                  repo
220              }
221              _ => {
222                  let msg = format!("\nERROR: The repository {path:?} does not exist\n\n");
223                  session.extended_data(channel_id, 1, CryptoVec::from_slice(msg.as_bytes()))?;
224                  session.exit_status_request(channel_id, 1)?;
225                  session.close(channel_id)?;
226                  return Ok(());
227              }
228          };
229          debug!(
230              path = %repo.path().display(),
231              rid = %repo.id,
232              name = ?repo.project().ok().as_ref().map(|p| p.name()),
233              "serving repo",
234          );
235  
236          let handle = session.handle();
237  
238          tokio::spawn(async move {
239              let (mut channel_reader, channel_writer) = channel.split();
240              let status = match handle_git_upload_pack(
241                  repo.path(),
242                  channel_reader.make_reader(),
243                  channel_writer.make_writer(),
244                  channel_writer.make_writer_ext(Some(1)),
245              )
246              .await
247              {
248                  Ok(status) => status,
249                  Err(err) => {
250                      error!("git-upload-pack error: {err:#}");
251                      return;
252                  }
253              };
254  
255              let _: Result<_, ()> = handle
256                  .exit_status_request(
257                      channel_id,
258                      status.code().and_then(|c| c.try_into().ok()).unwrap_or(1),
259                  )
260                  .await;
261  
262              let _: Result<_, ()> = handle.close(channel_id).await;
263          });
264  
265          Ok(())
266      }
267  
268      async fn pty_request(
269          &mut self,
270          channel_id: ChannelId,
271          _term: &str,
272          _col_width: u32,
273          _row_height: u32,
274          _pix_width: u32,
275          _pix_height: u32,
276          _modes: &[(russh::Pty, u32)],
277          session: &mut Session,
278      ) -> Result<(), Self::Error> {
279          session.channel_failure(channel_id)?;
280          Ok(())
281      }
282  }
283  
284  pub fn generate_random_ssh_key() -> anyhow::Result<PrivateKey> {
285      let sk = Ed25519PrivateKey::from_bytes(&rand::random());
286      PrivateKey::new(KeypairData::Ed25519(sk.into()), "")
287          .context("Failed to generate random ssh-ed25519 host key")
288  }