replay.rs
1 //! Replay attack protection. 2 //! 3 //! This module provides a cache of seen message IDs to prevent replay 4 //! attacks. When a message is received, its ID is checked against the 5 //! cache before processing. 6 //! 7 //! # How It Works 8 //! 9 //! 1. When a message is received, check `is_replay(message_id)` 10 //! 2. If it returns `true`, reject the message (replay detected) 11 //! 3. If it returns `false`, process the message and call `mark_seen(message_id)` 12 //! 13 //! # Cleanup 14 //! 15 //! Old entries are periodically cleaned up to prevent unbounded growth. 16 //! The default retention period is 90 days, which should be longer than 17 //! any message's maximum TTL. 18 19 use chrono::{Duration, Utc}; 20 use rusqlite::params; 21 22 use crate::error::Result; 23 use crate::protocol::messages::MessageId; 24 use crate::storage::Database; 25 26 /// Default retention period for seen message IDs (in days). 27 /// 28 /// This should be longer than the maximum message TTL to ensure 29 /// replayed messages are always detected. 30 pub const DEFAULT_SEEN_RETENTION_DAYS: i64 = 90; 31 32 impl Database { 33 /// Check if a message ID has been seen before. 34 /// 35 /// # Arguments 36 /// 37 /// * `message_id` - The message ID to check 38 /// 39 /// # Returns 40 /// 41 /// `true` if the message ID has been seen (replay attack), `false` otherwise. 42 /// 43 /// # Example 44 /// 45 /// ```no_run 46 /// use dead_drop_core::storage::Database; 47 /// 48 /// let db = Database::open_in_memory(b"passphrase")?; 49 /// let message_id = [0u8; 16]; 50 /// 51 /// if db.is_replay(&message_id)? { 52 /// // Reject the message - it's a replay 53 /// } else { 54 /// // Process the message 55 /// db.mark_seen(&message_id)?; 56 /// } 57 /// # Ok::<(), dead_drop_core::error::DeadDropError>(()) 58 /// ``` 59 pub fn is_replay(&self, message_id: &MessageId) -> Result<bool> { 60 let count: u32 = self.connection().query_row( 61 "SELECT COUNT(*) FROM seen_messages WHERE message_id = ?", 62 [message_id.as_slice()], 63 |row| row.get(0), 64 )?; 65 66 Ok(count > 0) 67 } 68 69 /// Mark a message ID as seen. 70 /// 71 /// This should be called after successfully processing a message 72 /// to prevent replay attacks. 73 /// 74 /// # Arguments 75 /// 76 /// * `message_id` - The message ID to mark as seen 77 /// 78 /// # Note 79 /// 80 /// If the message ID is already in the cache, this is a no-op 81 /// (uses INSERT OR IGNORE). 82 pub fn mark_seen(&self, message_id: &MessageId) -> Result<()> { 83 let now = Utc::now().timestamp(); 84 85 self.connection().execute( 86 "INSERT OR IGNORE INTO seen_messages (message_id, seen_at) VALUES (?, ?)", 87 params![message_id.as_slice(), now], 88 )?; 89 90 Ok(()) 91 } 92 93 /// Check if a message is a replay and mark it as seen atomically. 94 /// 95 /// This is a convenience method that combines `is_replay` and `mark_seen` 96 /// into a single operation. It's useful when you want to check and mark 97 /// in one step. 98 /// 99 /// # Returns 100 /// 101 /// `true` if this was a replay (message ID already seen), `false` if 102 /// the message is new (and has now been marked as seen). 103 pub fn check_and_mark_seen(&self, message_id: &MessageId) -> Result<bool> { 104 // Try to insert first 105 let now = Utc::now().timestamp(); 106 107 let rows = self.connection().execute( 108 "INSERT OR IGNORE INTO seen_messages (message_id, seen_at) VALUES (?, ?)", 109 params![message_id.as_slice(), now], 110 )?; 111 112 // If no rows were inserted, the ID was already there (replay) 113 Ok(rows == 0) 114 } 115 116 /// Clean up old seen message entries. 117 /// 118 /// Removes entries older than the specified number of days. 119 /// 120 /// # Arguments 121 /// 122 /// * `max_age_days` - Maximum age of entries to keep 123 /// 124 /// # Returns 125 /// 126 /// The number of entries deleted. 127 /// 128 /// # Example 129 /// 130 /// ```no_run 131 /// use dead_drop_core::storage::Database; 132 /// use dead_drop_core::storage::replay::DEFAULT_SEEN_RETENTION_DAYS; 133 /// 134 /// let db = Database::open_in_memory(b"passphrase")?; 135 /// let deleted = db.cleanup_seen(DEFAULT_SEEN_RETENTION_DAYS)?; 136 /// println!("Cleaned up {} old entries", deleted); 137 /// # Ok::<(), dead_drop_core::error::DeadDropError>(()) 138 /// ``` 139 pub fn cleanup_seen(&self, max_age_days: i64) -> Result<u32> { 140 let cutoff = (Utc::now() - Duration::days(max_age_days)).timestamp(); 141 142 let count = self.connection().execute( 143 "DELETE FROM seen_messages WHERE seen_at < ?", 144 [cutoff], 145 )?; 146 147 Ok(count as u32) 148 } 149 150 /// Get the count of seen message IDs in the cache. 151 pub fn count_seen(&self) -> Result<u32> { 152 let count: u32 = self.connection().query_row( 153 "SELECT COUNT(*) FROM seen_messages", 154 [], 155 |row| row.get(0), 156 )?; 157 Ok(count) 158 } 159 160 /// Clear all seen message IDs. 161 /// 162 /// # Warning 163 /// 164 /// This will allow previously-seen messages to be replayed. 165 /// Only use this for testing or recovery purposes. 166 pub fn clear_seen(&self) -> Result<()> { 167 self.connection().execute("DELETE FROM seen_messages", [])?; 168 Ok(()) 169 } 170 171 /// Check multiple message IDs for replays in a single query. 172 /// 173 /// This is more efficient than calling `is_replay` multiple times. 174 /// 175 /// # Arguments 176 /// 177 /// * `message_ids` - Slice of message IDs to check 178 /// 179 /// # Returns 180 /// 181 /// Vector of booleans, where `true` indicates a replay. 182 pub fn check_replays(&self, message_ids: &[MessageId]) -> Result<Vec<bool>> { 183 let mut results = vec![false; message_ids.len()]; 184 185 // For each message ID, check if it exists 186 for (i, id) in message_ids.iter().enumerate() { 187 results[i] = self.is_replay(id)?; 188 } 189 190 Ok(results) 191 } 192 193 /// Mark multiple message IDs as seen in a single transaction. 194 /// 195 /// More efficient than calling `mark_seen` multiple times. 196 pub fn mark_seen_batch(&mut self, message_ids: &[MessageId]) -> Result<()> { 197 let tx = self.transaction()?; 198 let now = Utc::now().timestamp(); 199 200 for id in message_ids { 201 tx.execute( 202 "INSERT OR IGNORE INTO seen_messages (message_id, seen_at) VALUES (?, ?)", 203 params![id.as_slice(), now], 204 )?; 205 } 206 207 tx.commit()?; 208 Ok(()) 209 } 210 } 211 212 #[cfg(test)] 213 mod tests { 214 use super::*; 215 use crate::protocol::messages::generate_message_id; 216 217 fn test_db() -> Database { 218 Database::open_in_memory(b"test_passphrase").unwrap() 219 } 220 221 #[test] 222 fn test_is_replay_not_seen() { 223 let db = test_db(); 224 let message_id = generate_message_id(); 225 226 assert!(!db.is_replay(&message_id).unwrap()); 227 } 228 229 #[test] 230 fn test_is_replay_after_mark_seen() { 231 let db = test_db(); 232 let message_id = generate_message_id(); 233 234 // First time: not a replay 235 assert!(!db.is_replay(&message_id).unwrap()); 236 237 // Mark as seen 238 db.mark_seen(&message_id).unwrap(); 239 240 // Now it should be a replay 241 assert!(db.is_replay(&message_id).unwrap()); 242 } 243 244 #[test] 245 fn test_mark_seen_idempotent() { 246 let db = test_db(); 247 let message_id = generate_message_id(); 248 249 // Mark multiple times should not error 250 db.mark_seen(&message_id).unwrap(); 251 db.mark_seen(&message_id).unwrap(); 252 db.mark_seen(&message_id).unwrap(); 253 254 // Should still be counted once 255 assert_eq!(db.count_seen().unwrap(), 1); 256 } 257 258 #[test] 259 fn test_check_and_mark_seen() { 260 let db = test_db(); 261 let message_id = generate_message_id(); 262 263 // First call: not a replay, marks it 264 assert!(!db.check_and_mark_seen(&message_id).unwrap()); 265 266 // Second call: is a replay 267 assert!(db.check_and_mark_seen(&message_id).unwrap()); 268 } 269 270 #[test] 271 fn test_cleanup_seen() { 272 let db = test_db(); 273 274 // Add some entries 275 for _ in 0..5 { 276 let id = generate_message_id(); 277 db.mark_seen(&id).unwrap(); 278 } 279 280 assert_eq!(db.count_seen().unwrap(), 5); 281 282 // Cleanup with 0 days should remove all (since they're "old") 283 // Actually, entries are created with current timestamp, so 0 days won't remove them 284 // We need to test with negative days or manipulate timestamps 285 286 // Clean up with very large max age - should remove nothing 287 let deleted = db.cleanup_seen(1000).unwrap(); 288 assert_eq!(deleted, 0); 289 assert_eq!(db.count_seen().unwrap(), 5); 290 } 291 292 #[test] 293 fn test_count_seen() { 294 let db = test_db(); 295 296 assert_eq!(db.count_seen().unwrap(), 0); 297 298 // Add entries 299 for _ in 0..10 { 300 let id = generate_message_id(); 301 db.mark_seen(&id).unwrap(); 302 } 303 304 assert_eq!(db.count_seen().unwrap(), 10); 305 } 306 307 #[test] 308 fn test_clear_seen() { 309 let db = test_db(); 310 311 // Add entries 312 for _ in 0..5 { 313 let id = generate_message_id(); 314 db.mark_seen(&id).unwrap(); 315 } 316 317 assert_eq!(db.count_seen().unwrap(), 5); 318 319 // Clear 320 db.clear_seen().unwrap(); 321 322 assert_eq!(db.count_seen().unwrap(), 0); 323 } 324 325 #[test] 326 fn test_check_replays() { 327 let db = test_db(); 328 329 let id1 = generate_message_id(); 330 let id2 = generate_message_id(); 331 let id3 = generate_message_id(); 332 333 // Mark id1 and id2 as seen 334 db.mark_seen(&id1).unwrap(); 335 db.mark_seen(&id2).unwrap(); 336 337 let results = db.check_replays(&[id1, id2, id3]).unwrap(); 338 339 assert!(results[0]); // id1 is replay 340 assert!(results[1]); // id2 is replay 341 assert!(!results[2]); // id3 is not replay 342 } 343 344 #[test] 345 fn test_mark_seen_batch() { 346 let mut db = test_db(); 347 348 let ids: Vec<MessageId> = (0..10).map(|_| generate_message_id()).collect(); 349 350 db.mark_seen_batch(&ids).unwrap(); 351 352 assert_eq!(db.count_seen().unwrap(), 10); 353 354 // All should be replays now 355 for id in &ids { 356 assert!(db.is_replay(id).unwrap()); 357 } 358 } 359 360 #[test] 361 fn test_replay_persistence() { 362 use std::fs; 363 let temp_dir = tempfile::tempdir().unwrap(); 364 let db_path = temp_dir.path().join("replay_test.db"); 365 let db_path_str = db_path.to_str().unwrap(); 366 let passphrase = b"replay_test"; 367 368 let message_id = generate_message_id(); 369 370 // Create and mark as seen 371 { 372 let db = Database::open(db_path_str, passphrase).unwrap(); 373 db.mark_seen(&message_id).unwrap(); 374 } 375 376 // Reopen and verify 377 { 378 let db = Database::open(db_path_str, passphrase).unwrap(); 379 assert!(db.is_replay(&message_id).unwrap()); 380 } 381 382 fs::remove_file(&db_path).ok(); 383 } 384 385 #[test] 386 fn test_different_messages_not_confused() { 387 let db = test_db(); 388 389 let id1 = generate_message_id(); 390 let id2 = generate_message_id(); 391 392 // Mark only id1 393 db.mark_seen(&id1).unwrap(); 394 395 // id1 should be replay, id2 should not 396 assert!(db.is_replay(&id1).unwrap()); 397 assert!(!db.is_replay(&id2).unwrap()); 398 } 399 }