/ firmware / src / programs / shell / server.rs
server.rs
  1  use core::fmt::Write;
  2  
  3  use alloc::string::String as AllocString;
  4  use defmt::info;
  5  use embassy_executor::Spawner;
  6  use embassy_net::{Stack, tcp::TcpSocket};
  7  use embassy_time::{Duration, Timer};
  8  use esp_hal::rng::Rng;
  9  
 10  use crate::hardware::crypto::CryptoRng;
 11  use crate::services::{identity, ssh::{AuthMethod, Behavior, Request, SecretKey, Transport}};
 12  
 13  use super::{
 14      build_motd, build_prompt, dispatch, load_history, save_history, set_terminal_width,
 15      CTRL_L, CTRL_N, CTRL_P, CTRL_U, CTRL_W,
 16  };
 17  
/// TCP port the SSH server accepts connections on (taken from `crate::config::app::ssh::PORT`).
const SSH_PORT: u16 = crate::config::app::ssh::PORT;
/// Size in bytes of the static TCP receive buffer backing each SSH connection's socket.
const RX_BUF_SIZE: usize = crate::config::app::ssh::RX_BUF_SIZE;
/// Size in bytes of the static TCP transmit buffer backing each SSH connection's socket.
const TX_BUF_SIZE: usize = crate::config::app::ssh::TX_BUF_SIZE;
 21  
 22  #[derive(Clone, Copy)]
 23  pub struct TermSize {
 24      pub width: u32,
 25      pub height: u32,
 26  }
 27  
/// Per-connection state handed to the SSH layer through the [`Behavior`] trait.
///
/// Owns the TCP socket for the current client. The terminal size is shared with the
/// accepting task through a borrowed `Cell`, so the task can read what the client reported.
struct SshBehavior<'a> {
    /// TCP connection to the client; exposed to the SSH layer via `Behavior::stream`.
    socket: TcpSocket<'a>,
    /// Random-number source exposed via `Behavior::random`.
    rng: CryptoRng,
    /// Host secret key returned by `Behavior::host_secret_key`.
    host_key: SecretKey,
    /// Written by `on_pty_request`; read by the task once the shell request arrives.
    term_size: &'a core::cell::Cell<TermSize>,
}
 34  
 35  impl<'a> Behavior for SshBehavior<'a> {
 36      type Stream = TcpSocket<'a>;
 37  
 38      fn stream(&mut self) -> &mut Self::Stream {
 39          &mut self.socket
 40      }
 41  
 42      type Random = CryptoRng;
 43  
 44      fn random(&mut self) -> &mut Self::Random {
 45          &mut self.rng
 46      }
 47  
 48      fn host_secret_key(&self) -> &SecretKey {
 49          &self.host_key
 50      }
 51  
 52      type User = ();
 53  
 54      fn allow_user(&mut self, username: &str, auth_method: &AuthMethod) -> Option<()> {
 55          if username == identity::ssh_user() && matches!(auth_method, AuthMethod::None) {
 56              Some(())
 57          } else {
 58              None
 59          }
 60      }
 61  
 62      fn allow_shell(&self) -> bool {
 63          true
 64      }
 65  
 66      fn on_pty_request(&mut self, width: u32, height: u32) {
 67          self.term_size.set(TermSize { width, height });
 68      }
 69  
 70      type Command = ();
 71  
 72      fn parse_command(&mut self, _: &str) {}
 73  }
 74  
 75  async fn redraw_line<T: Behavior>(
 76      channel: &mut crate::services::ssh::Channel<'_, '_, T>,
 77      terminal: &crate::services::ssh::terminal::Terminal<256>,
 78  ) {
 79      let _ = channel.write_all_stdout(b"\r\x1b[K").await;
 80      let mut prefix = AllocString::new();
 81      let _ = write!(
 82          prefix,
 83          "{}{}{} ",
 84          super::prompt::theme::FRAME_COLOR,
 85          super::prompt::theme::FRAME_BOT_LEFT,
 86          super::prompt::theme::RESET
 87      );
 88      let _ = channel.write_all_stdout(prefix.as_bytes()).await;
 89  
 90      if let Ok(buf) = terminal.buffer_str() {
 91          let _ = channel.write_all_stdout(buf.as_bytes()).await;
 92          let cursor = terminal.cursor_position();
 93          let buf_len = buf.len();
 94          if cursor < buf_len {
 95              let mut back = AllocString::new();
 96              let _ = write!(back, "\x1b[{}D", buf_len - cursor);
 97              let _ = channel.write_all_stdout(back.as_bytes()).await;
 98          }
 99      }
100  }
101  
102  #[embassy_executor::task]
103  pub async fn task(stack: Stack<'static>) {
104      info!("Microshell (SSH) listening on port {}", SSH_PORT);
105  
106      loop {
107          static mut RX_BUFFER: [u8; RX_BUF_SIZE] = [0; RX_BUF_SIZE];
108          static mut TX_BUFFER: [u8; TX_BUF_SIZE] = [0; TX_BUF_SIZE];
109  
110          let socket = unsafe {
111              TcpSocket::new(
112                  stack,
113                  &mut *core::ptr::addr_of_mut!(RX_BUFFER),
114                  &mut *core::ptr::addr_of_mut!(TX_BUFFER),
115              )
116          };
117  
118          let term_size = core::cell::Cell::new(TermSize { width: 80, height: 24 });
119  
120          let mut behavior = SshBehavior {
121              socket,
122              rng: CryptoRng(Rng::new()),
123              host_key: SecretKey::Ed25519 {
124                  secret_key: identity::signing_key(),
125              },
126              term_size: &term_size,
127          };
128  
129          if let Err(error) = behavior.socket.accept(SSH_PORT).await {
130              info!("SSH accept failed: {:?}", error);
131              Timer::after(Duration::from_millis(250)).await;
132              continue;
133          }
134  
135          let remote_str = behavior
136              .socket
137              .remote_endpoint()
138              .map(|endpoint| {
139                  let mut s = AllocString::new();
140                  let _ = write!(s, "{}", endpoint);
141                  s
142              })
143              .unwrap_or_else(|| AllocString::from("unknown"));
144          info!("SSH client connected from {}", remote_str.as_str());
145          behavior.socket.set_timeout(Some(Duration::from_secs(300)));
146  
147          let mut packet_buffer = [0u8; 4096];
148          let mut transport = Transport::new(&mut packet_buffer, behavior);
149  
150          match transport.accept().await {
151              Ok(mut channel) => {
152                  info!("SSH channel opened");
153  
154                  match channel.request() {
155                      Request::Shell => {
156                          let term_size = term_size.get();
157                          info!("Terminal size: {}x{}", term_size.width, term_size.height);
158                          set_terminal_width(term_size.width);
159  
160                          let _ = channel.write_all_stdout(b"\x1b[2J\x1b[H").await;
161                          let motd = build_motd(remote_str.as_str());
162                          let _ = channel.write_all_stdout(motd.as_bytes()).await;
163  
164                          let mut cwd = identity::home_dir();
165  
166                          if let Ok(mshrc) =
167                              crate::filesystems::sd::read_file_at::<1024>(cwd.as_str(), ".MSHRC")
168                          {
169                              if let Ok(text) = core::str::from_utf8(mshrc.as_slice()) {
170                                  for line in text.lines() {
171                                      let line = line.trim();
172                                      if line.is_empty() || line.starts_with('#') {
173                                          continue;
174                                      }
175                                      let (output, _) = dispatch(line, &mut cwd);
176                                      if !output.is_empty() {
177                                          let _ = channel.write_all_stdout(output.as_bytes()).await;
178                                      }
179                                  }
180                              }
181                          }
182  
183                          let prompt_str = build_prompt(&cwd);
184                          let _ = channel.write_all_stdout(prompt_str.as_bytes()).await;
185  
186                          use crate::services::ssh::history::{History, HistoryConfig};
187                          use crate::services::ssh::terminal::{Terminal, TerminalConfig, TerminalEvent};
188  
189                          let mut terminal = Terminal::<256>::new(TerminalConfig {
190                              buffer_size: 256,
191                              prompt: "",
192                              echo: true,
193                              ansi_enabled: true,
194                          });
195  
196                          let mut history = History::<256>::new(HistoryConfig {
197                              max_entries: 16,
198                              deduplicate: true,
199                          });
200                          load_history(&mut history);
201  
202                          loop {
203                              let mut byte_buf = [0u8; 1];
204                              match channel.read_exact_stdin(&mut byte_buf).await {
205                                  Ok(0) => break,
206                                  Err(_) => break,
207                                  Ok(_) => {}
208                              }
209  
210                              let byte = byte_buf[0];
211  
212                              if byte == CTRL_L {
213                                  let _ = channel.write_all_stdout(b"\x1b[2J\x1b[H").await;
214                                  terminal.clear_buffer();
215                                  let prompt = build_prompt(&cwd);
216                                  let _ = channel.write_all_stdout(prompt.as_bytes()).await;
217                                  continue;
218                              }
219  
220                              let byte = match byte {
221                                  CTRL_P => {
222                                      if let Some(entry) = history.previous() {
223                                          let _ = terminal.set_buffer(entry);
224                                          redraw_line(&mut channel, &terminal).await;
225                                      }
226                                      continue;
227                                  }
228                                  CTRL_N => {
229                                      if let Some(entry) = history.next() {
230                                          let _ = terminal.set_buffer(entry);
231                                      } else {
232                                          terminal.clear_buffer();
233                                      }
234                                      redraw_line(&mut channel, &terminal).await;
235                                      continue;
236                                  }
237                                  CTRL_W => {
238                                      let buffer_copy = AllocString::from(terminal.buffer_str().unwrap_or(""));
239                                      let trimmed = buffer_copy.trim_end();
240                                      if let Some(last_space) = trimmed.rfind(' ') {
241                                          let _ = terminal.set_buffer(&buffer_copy[..last_space + 1]);
242                                      } else {
243                                          terminal.clear_buffer();
244                                      }
245                                      redraw_line(&mut channel, &terminal).await;
246                                      continue;
247                                  }
248                                  CTRL_U => {
249                                      terminal.clear_buffer();
250                                      redraw_line(&mut channel, &terminal).await;
251                                      continue;
252                                  }
253                                  other => other,
254                              };
255  
256                              let key = match terminal.process_byte(byte) {
257                                  Some(key) => key,
258                                  None => continue,
259                              };
260  
261                              match terminal.handle_key(key) {
262                                  TerminalEvent::CommandReady => {
263                                      let _ = channel.write_all_stdout(b"\r\n").await;
264                                      if let Ok(cmd) = terminal.take_command() {
265                                          let cmd_str = cmd.as_str().trim();
266                                          if !cmd_str.is_empty() {
267                                              let _ = history.add(cmd_str);
268                                          }
269                                          let (output, should_exit) = dispatch(cmd_str, &mut cwd);
270                                          if !output.is_empty() {
271                                              let _ = channel.write_all_stdout(output.as_bytes()).await;
272                                          }
273                                          if should_exit {
274                                              break;
275                                          }
276                                      }
277                                      history.reset_position();
278                                      let prompt = build_prompt(&cwd);
279                                      let _ = channel.write_all_stdout(prompt.as_bytes()).await;
280                                  }
281                                  TerminalEvent::EmptyCommand => {
282                                      let _ = channel.write_all_stdout(b"\r\n").await;
283                                      let prompt = build_prompt(&cwd);
284                                      let _ = channel.write_all_stdout(prompt.as_bytes()).await;
285                                  }
286                                  TerminalEvent::BufferChanged | TerminalEvent::CursorMoved => {
287                                      redraw_line(&mut channel, &terminal).await;
288                                  }
289                                  TerminalEvent::Interrupt => {
290                                      terminal.clear_buffer();
291                                      let _ = channel.write_all_stdout(b"^C\r\n").await;
292                                      history.reset_position();
293                                      let prompt = build_prompt(&cwd);
294                                      let _ = channel.write_all_stdout(prompt.as_bytes()).await;
295                                  }
296                                  TerminalEvent::EndOfFile => break,
297                                  TerminalEvent::HistoryPrevious => {
298                                      if let Some(entry) = history.previous() {
299                                          let _ = terminal.set_buffer(entry);
300                                          redraw_line(&mut channel, &terminal).await;
301                                      }
302                                  }
303                                  TerminalEvent::HistoryNext => {
304                                      if let Some(entry) = history.next() {
305                                          let _ = terminal.set_buffer(entry);
306                                      } else {
307                                          terminal.clear_buffer();
308                                      }
309                                      redraw_line(&mut channel, &terminal).await;
310                                  }
311                                  _ => {}
312                              }
313                          }
314  
315                          save_history(&history);
316                          let _ = channel.exit(0).await;
317                      }
318                      _ => {
319                          let _ = channel
320                              .write_all_stderr(b"Only shell mode is supported.\n")
321                              .await;
322                          let _ = channel.exit(1).await;
323                      }
324                  }
325              }
326              Err(_) => info!("SSH handshake failed"),
327          }
328  
329          info!("SSH client disconnected");
330      }
331  }
332  
333  pub fn spawn(spawner: &Spawner, stack: Stack<'static>) {
334      spawner.spawn(task(stack).unwrap());
335  }