// server/src/encryption.rs
  1  use chacha20poly1305::{
  2      aead::{Aead, KeyInit, OsRng},
  3      ChaCha20Poly1305, Key, Nonce,
  4  };
  5  use rand::RngCore;
  6  use sha2::{Digest, Sha256};
  7  use zeroize::Zeroize;
  8  
  9  /// Derive encryption key from room_id using SHA-256 (server-side encryption)
 10  pub fn derive_key(room_id: &str, salt: &[u8]) -> [u8; 32] {
 11      let mut hasher = Sha256::new();
 12      hasher.update(room_id.as_bytes());
 13      hasher.update(salt);
 14      let result = hasher.finalize();
 15      let mut key = [0u8; 32];
 16      key.copy_from_slice(&result);
 17      key
 18  }
 19  
 20  // Encrypt the message using ChaCha20Poly1305
 21  pub fn encrypt_message(plain_text: &str, room_id: &str) -> Result<String, &'static str> {
 22      let mut salt = [0u8; 16];
 23      OsRng.fill_bytes(&mut salt);
 24  
 25      let mut key = derive_key(room_id, &salt);
 26      let cipher = ChaCha20Poly1305::new(&Key::from_slice(&key));
 27  
 28      let mut nonce_bytes = [0u8; 12];
 29      OsRng.fill_bytes(&mut nonce_bytes);
 30      let nonce = Nonce::from_slice(&nonce_bytes);
 31  
 32      let encrypted_data = cipher
 33          .encrypt(nonce, plain_text.as_bytes())
 34          .map_err(|_| "Encryption error")?;
 35  
 36      key.zeroize();
 37  
 38      Ok(format!(
 39          "{}:{}:{}",
 40          hex::encode(salt),
 41          hex::encode(nonce_bytes),
 42          hex::encode(encrypted_data)
 43      ))
 44  }
 45  
 46  // Decrypt the message using ChaCha20Poly1305
 47  pub fn decrypt_message(encrypted_text: &str, room_id: &str) -> Result<String, &'static str> {
 48      let parts: Vec<&str> = encrypted_text.split(':').collect();
 49      if parts.len() != 3 {
 50          return Err("Invalid encrypted message format");
 51      }
 52  
 53      let salt = hex::decode(parts[0]).map_err(|_| "Decryption error")?;
 54      let nonce_bytes = hex::decode(parts[1]).map_err(|_| "Decryption error")?;
 55      let encrypted_data = hex::decode(parts[2]).map_err(|_| "Decryption error")?;
 56  
 57      let mut key = derive_key(room_id, &salt);
 58      let cipher = ChaCha20Poly1305::new(&Key::from_slice(&key));
 59  
 60      let nonce = Nonce::from_slice(&nonce_bytes);
 61  
 62      let decrypted_data = cipher
 63          .decrypt(nonce, encrypted_data.as_ref())
 64          .map_err(|_| "Decryption error")?;
 65  
 66      key.zeroize();
 67  
 68      String::from_utf8(decrypted_data).map_err(|_| "Decryption error")
 69  }
 70  
 71  pub fn is_message_encrypted(message: &str) -> bool {
 72      // Define markers for both types of blocks
 73      const ENCRYPTED_BEGIN_MARKER: &str = "-----BEGIN ENCRYPTED MESSAGE-----";
 74      const ENCRYPTED_END_MARKER: &str = "-----END ENCRYPTED MESSAGE-----";
 75      const DILITHIUM_PUBLIC_KEY_PREFIX: &str = "DILITHIUM_PUBLIC_KEY:";
 76      const EDDSA_PUBLIC_KEY_PREFIX: &str = "EDDSA_PUBLIC_KEY:";
 77      const ECDH_KEY_EXCHANGE_PREFIX: &str = "ECDH_PUBLIC_KEY:";
 78      const KYBER_KEY_EXCHANGE_PREFIX: &str = "KYBER_PUBLIC_KEY:";
 79  
 80      // Check for key exchange prefixes and handle them separately
 81      if message.starts_with(DILITHIUM_PUBLIC_KEY_PREFIX)
 82          || message.starts_with(EDDSA_PUBLIC_KEY_PREFIX)
 83          || message.starts_with(ECDH_KEY_EXCHANGE_PREFIX)
 84          || message.starts_with(KYBER_KEY_EXCHANGE_PREFIX)
 85      {
 86          // Allow key exchange messages and return true
 87          return true;
 88      }
 89  
 90      // Determine which markers are present for PGP encryption or key block
 91      let begin_marker = if message.contains(ENCRYPTED_BEGIN_MARKER) { // Check for encrypted message
 92          ENCRYPTED_BEGIN_MARKER
 93      } else {
 94          println!("Missing or unrecognized begin marker.");
 95          return false;
 96      };
 97  
 98      let end_marker = if message.contains(ENCRYPTED_END_MARKER) { // Check for encrypted message
 99          ENCRYPTED_END_MARKER
100      } else {
101          println!("Missing or unrecognized end marker.");
102          return false;
103      };
104  
105      // Locate the markers
106      let begin_marker_pos = message.find(begin_marker);
107      let end_marker_pos = message.find(end_marker);
108  
109      if let (Some(begin), Some(end)) = (begin_marker_pos, end_marker_pos) {
110          if begin < end {
111              return true;
112          } else {
113              println!("Markers out of order.");
114          }
115      } else {
116          println!("Missing markers.");
117      }
118  
119      false
120  }