/ otp-chat-server / src / main.rs
main.rs
  1  use anyhow::Result;
  2  use std::{collections::HashMap, net::SocketAddr, sync::Arc};
  3  use std::time::{Duration, Instant};
  4  use tokio::{
  5      io::{AsyncReadExt, AsyncWriteExt},
  6      net::{TcpListener, TcpStream},
  7      sync::Mutex,
  8  };
  9  
 10  type Clients = Arc<Mutex<HashMap<SocketAddr, Arc<Mutex<tokio::net::tcp::OwnedWriteHalf>>>>>;
 11  
 12  // Rate limiting tracker
 13  struct ConnectionTracker {
 14      connections: HashMap<std::net::IpAddr, Vec<Instant>>,
 15  }
 16  
 17  impl ConnectionTracker {
 18      fn new() -> Self {
 19          Self {
 20              connections: HashMap::new(),
 21          }
 22      }
 23      
 24      fn allow_connection(&mut self, ip: std::net::IpAddr) -> bool {
 25          let now = Instant::now();
 26          let hour_ago = now - Duration::from_secs(3600);
 27          
 28          let ip_connections = self.connections.entry(ip).or_insert_with(Vec::new);
 29          
 30          // Remove old connections
 31          ip_connections.retain(|&time| time > hour_ago);
 32          
 33          // Allow max 10 connections per hour per IP
 34          if ip_connections.len() >= 10 {
 35              false
 36          } else {
 37              ip_connections.push(now);
 38              true
 39          }
 40      }
 41  }
 42  
 43  // Username validation function
 44  fn validate_username(raw_input: &[u8]) -> Result<String, &'static str> {
 45      // Convert to string, reject if not valid UTF-8
 46      let username = match std::str::from_utf8(raw_input) {
 47          Ok(s) => s.trim(),
 48          Err(_) => return Err("Username contains invalid UTF-8"),
 49      };
 50      
 51      // Check length
 52      if username.len() < 1 || username.len() > 32 {
 53          return Err("Username must be 1-32 characters");
 54      }
 55      
 56      // Only allow printable ASCII + space
 57      if !username.chars().all(|c| c.is_ascii_graphic() || c == ' ') {
 58          return Err("Username contains invalid characters");
 59      }
 60      
 61      // Reject control characters that could mess up terminals
 62      if username.contains('\n') || username.contains('\r') || 
 63         username.contains('\t') || username.contains('\0') {
 64          return Err("Username contains control characters");
 65      }
 66      
 67      Ok(username.to_string())
 68  }
 69  
 70  async fn handle_client(
 71      socket: TcpStream,
 72      addr: SocketAddr,
 73      clients: Clients,
 74  ) -> Result<()> {
 75      let (mut reader, writer) = socket.into_split();
 76      let writer = Arc::new(Mutex::new(writer));
 77      clients.lock().await.insert(addr, Arc::clone(&writer));
 78  
 79      loop {
 80          let mut header = [0u8; 4];
 81          if let Err(e) = reader.read_exact(&mut header).await {
 82              eprintln!("🔌 {} disconnected: {}", addr, e);
 83              break;
 84          }
 85          let len = u32::from_be_bytes(header) as usize;
 86          if len > 1_048_576 {
 87              eprintln!("⚠️ {} sent oversized message", addr);
 88              break;
 89          }
 90          let mut buf = vec![0u8; len];
 91          if let Err(e) = reader.read_exact(&mut buf).await {
 92              eprintln!("⚠️ {} payload error: {}", addr, e);
 93              break;
 94          }
 95          let mut packet = Vec::with_capacity(4 + len);
 96          packet.extend_from_slice(&header);
 97          packet.extend_from_slice(&buf);
 98          broadcast(addr, &packet, &clients).await;
 99      }
100      clients.lock().await.remove(&addr);
101      Ok(())
102  }
103  
104  async fn broadcast(sender: SocketAddr, packet: &[u8], clients: &Clients) {
105      let peers: Vec<_> = {
106          let map = clients.lock().await;
107          map.iter()
108              .filter(|(&addr, _)| addr != sender)
109              .map(|(_, w)| Arc::clone(w))
110              .collect()
111      };
112      for peer in peers {
113          let mut w = peer.lock().await;
114          let _ = w.write_all(packet).await;
115      }
116  }
117  
118  async fn run_otp_server() -> Result<()> {
119      let listener = TcpListener::bind("0.0.0.0:8080").await?;
120      println!("✅ OTP server listening on :8080");
121      let clients: Clients = Arc::new(Mutex::new(HashMap::new()));
122  
123      loop {
124          let (socket, addr) = listener.accept().await?;
125          let clients = clients.clone();
126          tokio::spawn(async move {
127              if let Err(e) = handle_client(socket, addr, clients).await {
128                  eprintln!("⚠️ Client error: {}", e);
129              }
130          });
131      }
132  }
133  
134  #[derive(Debug)]
135  struct ModerationState {
136      moderator: Option<String>,
137      permitted_users: HashMap<String, bool>,
138      user_connections: HashMap<String, Arc<Mutex<tokio::net::tcp::OwnedWriteHalf>>>,
139  }
140  
141  impl ModerationState {
142      fn new() -> Self {
143          Self {
144              moderator: None,
145              permitted_users: HashMap::new(),
146              user_connections: HashMap::new(),
147          }
148      }
149  }
150  
151  async fn handle_moderation_client(
152      socket: TcpStream,
153      addr: SocketAddr,
154      moderation_state: Arc<Mutex<ModerationState>>,
155      connection_tracker: Arc<Mutex<ConnectionTracker>>,
156  ) -> Result<()> {
157      // Rate limiting check
158      {
159          let mut tracker = connection_tracker.lock().await;
160          if !tracker.allow_connection(addr.ip()) {
161              println!("⚠️ Rate limit exceeded for {}", addr.ip());
162              return Ok(()); // Silently drop connection
163          }
164      }
165  
166      let (mut reader, writer) = socket.into_split();
167      let writer = Arc::new(Mutex::new(writer));
168      let mut buf = [0; 512];
169      
170      // Set timeout for username reading
171      let username_result = tokio::time::timeout(
172          Duration::from_secs(10),
173          reader.read(&mut buf)
174      ).await;
175  
176      let username = match username_result {
177          Ok(Ok(n)) if n > 0 => {
178              // Validate username using new validation function
179              match validate_username(&buf[..n]) {
180                  Ok(name) => name,
181                  Err(error) => {
182                      println!("⚠️ Invalid username from {}: {}", addr, error);
183                      // Silently drop connection - don't give attackers feedback
184                      return Ok(());
185                  }
186              }
187          },
188          Ok(Ok(_)) => {
189              println!("⚠️ Empty username from {}", addr);
190              return Ok(());
191          },
192          Ok(Err(e)) => {
193              println!("⚠️ Read error from {}: {}", addr, e);
194              return Ok(());
195          },
196          Err(_) => {
197              println!("⚠️ Username timeout from {}", addr);
198              return Ok(());
199          }
200      };
201  
202      println!("👤 User '{}' connected from {}", username, addr);
203      
204      let writer_clone = Arc::clone(&writer);
205      let mut state = moderation_state.lock().await;
206      
207      // First user becomes moderator
208      if state.moderator.is_none() {
209          state.moderator = Some(username.clone());
210          state.permitted_users.insert(username.clone(), true);
211          println!("👑 '{}' is now the moderator", username);
212          
213          // Notify the user they are moderator
214          let mut w = writer.lock().await;
215          let _ = w.write_all(format!("MOD:{}\n", username).as_bytes()).await;
216      } else {
217          // Regular user - no speaking permission initially
218          state.permitted_users.insert(username.clone(), false);
219      }
220      
221      // Send existing users to new client
222      for existing_user in state.user_connections.keys() {
223          let mut w = writer.lock().await;
224          let _ = w.write_all(format!("👥 {}\n", existing_user).as_bytes()).await;
225      }
226      
227      // Add user to connections
228      state.user_connections.insert(username.clone(), writer_clone);
229      
230      // Broadcast new user to all existing users (excluding the new user)
231      let connections: Vec<_> = state.user_connections.iter()
232          .filter(|(user, _)| *user != &username)
233          .map(|(_, w)| Arc::clone(w))
234          .collect();
235      
236      // Broadcast moderator status to new user if not moderator
237      if Some(&username) != state.moderator.as_ref() {
238          if let Some(mod_name) = &state.moderator {
239              let mut w = writer.lock().await;
240              let _ = w.write_all(format!("MOD:{}\n", mod_name).as_bytes()).await;
241          }
242      }
243      
244      drop(state); // Release the lock
245      
246      for conn in connections {
247          let mut writer_handle = conn.lock().await;
248          let _ = writer_handle.write_all(format!("👥 {}\n", username).as_bytes()).await;
249      }
250      
251      // Handle incoming commands
252      loop {
253          let mut cmd_buf = [0; 512];
254          match reader.read(&mut cmd_buf).await {
255              Ok(n) if n > 0 => {
256                  let command = String::from_utf8_lossy(&cmd_buf[..n]).trim().to_string();
257                  
258                  if command.starts_with("REQUEST:") {
259                      let parts: Vec<&str> = command.splitn(2, ':').collect();
260                      if parts.len() == 2 {
261                          let requester = parts[1];
262                          let state = moderation_state.lock().await;
263                          
264                          // Notify moderator of the request
265                          if let Some(moderator) = &state.moderator {
266                              if let Some(mod_conn) = state.user_connections.get(moderator) {
267                                  let mut mod_writer = mod_conn.lock().await;
268                                  let _ = mod_writer.write_all(format!("REQ:{}\n", requester).as_bytes()).await;
269                              }
270                          }
271                      }
272                  } else if command.starts_with("GRANT ") {
273                      let parts: Vec<&str> = command.splitn(2, ' ').collect();
274                      if parts.len() == 2 {
275                          let user_to_grant = parts[1];
276                          let mut state = moderation_state.lock().await;
277                          
278                          // Only moderator can grant
279                          if Some(&username) == state.moderator.as_ref() {
280                              state.permitted_users.insert(user_to_grant.to_string(), true);
281                              
282                              // Notify all users
283                              let connections: Vec<_> = state.user_connections.values()
284                                  .map(|w| Arc::clone(w))
285                                  .collect();
286                              
287                              drop(state);
288                              
289                              for conn in connections {
290                                  let mut w = conn.lock().await;
291                                  let _ = w.write_all(format!("GRANTED:{}\n", user_to_grant).as_bytes()).await;
292                              }
293                              
294                              println!("✅ {} granted speaking permission to {}", username, user_to_grant);
295                          }
296                      }
297                  } else if command.starts_with("DENY ") {
298                      let parts: Vec<&str> = command.splitn(2, ' ').collect();
299                      if parts.len() == 2 {
300                          let user_to_deny = parts[1];
301                          let state = moderation_state.lock().await;
302                          
303                          // Only moderator can deny
304                          if Some(&username) == state.moderator.as_ref() {
305                              // Notify all users
306                              let connections: Vec<_> = state.user_connections.values()
307                                  .map(|w| Arc::clone(w))
308                                  .collect();
309                              
310                              drop(state);
311                              
312                              for conn in connections {
313                                  let mut w = conn.lock().await;
314                                  let _ = w.write_all(format!("DENIED:{}\n", user_to_deny).as_bytes()).await;
315                              }
316                              
317                              println!("❌ {} denied speaking permission to {}", username, user_to_deny);
318                          }
319                      }
320                  }
321              },
322              Ok(_) => {
323                  // Client disconnected
324                  break;
325              }
326              Err(e) => {
327                  eprintln!("🔌 Command read error from {}: {}", username, e);
328                  break;
329              }
330          }
331      }
332      
333      // Remove user on disconnect
334      let mut state = moderation_state.lock().await;
335      state.user_connections.remove(&username);
336      state.permitted_users.remove(&username);
337      
338      // If moderator disconnects, reset moderation system
339      if Some(&username) == state.moderator.as_ref() {
340          state.moderator = None;
341          state.permitted_users.clear();
342          
343          // Notify remaining users that moderation has been reset
344          let connections: Vec<_> = state.user_connections.values()
345              .map(|w| Arc::clone(w))
346              .collect();
347          
348          drop(state);
349          
350          for conn in connections {
351              let mut w = conn.lock().await;
352              let _ = w.write_all("👑 Moderator disconnected. First to reconnect will become new moderator.\n".as_bytes()).await;
353          }
354          
355          println!("👑 Moderator '{}' disconnected - moderation system reset", username);
356      }
357      
358      println!("👋 User '{}' disconnected", username);
359      
360      Ok(())
361  }
362  
363  async fn run_moderation_server() -> Result<()> {
364      let listener = TcpListener::bind("0.0.0.0:8081").await?;
365      println!("👑 Moderation server listening on :8081");
366      let moderation_state: Arc<Mutex<ModerationState>> = Arc::new(Mutex::new(ModerationState::new()));
367      let connection_tracker: Arc<Mutex<ConnectionTracker>> = Arc::new(Mutex::new(ConnectionTracker::new()));
368  
369      loop {
370          let (socket, addr) = listener.accept().await?;
371          let moderation_state = Arc::clone(&moderation_state);
372          let connection_tracker = Arc::clone(&connection_tracker);
373          tokio::spawn(async move {
374              if let Err(e) = handle_moderation_client(socket, addr, moderation_state, connection_tracker).await {
375                  eprintln!("⚠️ Moderation client error: {}", e);
376              }
377          });
378      }
379  }
380  
381  #[tokio::main]
382  async fn main() -> Result<()> {
383      println!("🚀 Starting OTP Chat Server with Gentlemen's Agreement Moderation");
384      println!("📡 OTP relay server will run on port 8080 (cryptographically blind)");
385      println!("👑 Moderation server will run on port 8081");
386      println!("🔒 Server operates in cryptographically blind mode - no access to keys or plaintext");
387      println!("👨‍⚖️ First user to connect becomes moderator automatically");
388      println!("🙋 Other users must request permission to speak from the moderator");
389      println!("🛡️ Enhanced security: Input validation and rate limiting enabled\n");
390      
391      let otp_server = tokio::spawn(run_otp_server());
392      let moderation_server = tokio::spawn(run_moderation_server());
393      let _ = tokio::try_join!(otp_server, moderation_server)?;
394      Ok(())
395  }