/ core / src / storage / replay.rs
replay.rs
  1  //! Replay attack protection.
  2  //!
  3  //! This module provides a cache of seen message IDs to prevent replay
  4  //! attacks. When a message is received, its ID is checked against the
  5  //! cache before processing.
  6  //!
  7  //! # How It Works
  8  //!
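//! 1. When a message is received, check `is_replay(message_id)`
//! 2. If it returns `true`, reject the message (replay detected)
//! 3. If it returns `false`, process the message and call `mark_seen(message_id)`
//!
//! Steps 1 and 3 are two separate database operations. When the check and the
//! mark must happen as one step, use `check_and_mark_seen(message_id)` instead.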
 12  //!
 13  //! # Cleanup
 14  //!
 15  //! Old entries are periodically cleaned up to prevent unbounded growth.
 16  //! The default retention period is 90 days, which should be longer than
 17  //! any message's maximum TTL.
 18  
 19  use chrono::{Duration, Utc};
 20  use rusqlite::params;
 21  
 22  use crate::error::Result;
 23  use crate::protocol::messages::MessageId;
 24  use crate::storage::Database;
 25  
/// Default retention period for seen message IDs (in days).
///
/// This should be longer than the maximum message TTL to ensure
/// replayed messages are always detected: once an entry has been removed by
/// [`Database::cleanup_seen`], [`Database::is_replay`] no longer reports the
/// corresponding message ID as seen.
pub const DEFAULT_SEEN_RETENTION_DAYS: i64 = 90;
 31  
 32  impl Database {
 33      /// Check if a message ID has been seen before.
 34      ///
 35      /// # Arguments
 36      ///
 37      /// * `message_id` - The message ID to check
 38      ///
 39      /// # Returns
 40      ///
 41      /// `true` if the message ID has been seen (replay attack), `false` otherwise.
 42      ///
 43      /// # Example
 44      ///
 45      /// ```no_run
 46      /// use dead_drop_core::storage::Database;
 47      ///
 48      /// let db = Database::open_in_memory(b"passphrase")?;
 49      /// let message_id = [0u8; 16];
 50      ///
 51      /// if db.is_replay(&message_id)? {
 52      ///     // Reject the message - it's a replay
 53      /// } else {
 54      ///     // Process the message
 55      ///     db.mark_seen(&message_id)?;
 56      /// }
 57      /// # Ok::<(), dead_drop_core::error::DeadDropError>(())
 58      /// ```
 59      pub fn is_replay(&self, message_id: &MessageId) -> Result<bool> {
 60          let count: u32 = self.connection().query_row(
 61              "SELECT COUNT(*) FROM seen_messages WHERE message_id = ?",
 62              [message_id.as_slice()],
 63              |row| row.get(0),
 64          )?;
 65  
 66          Ok(count > 0)
 67      }
 68  
 69      /// Mark a message ID as seen.
 70      ///
 71      /// This should be called after successfully processing a message
 72      /// to prevent replay attacks.
 73      ///
 74      /// # Arguments
 75      ///
 76      /// * `message_id` - The message ID to mark as seen
 77      ///
 78      /// # Note
 79      ///
 80      /// If the message ID is already in the cache, this is a no-op
 81      /// (uses INSERT OR IGNORE).
 82      pub fn mark_seen(&self, message_id: &MessageId) -> Result<()> {
 83          let now = Utc::now().timestamp();
 84  
 85          self.connection().execute(
 86              "INSERT OR IGNORE INTO seen_messages (message_id, seen_at) VALUES (?, ?)",
 87              params![message_id.as_slice(), now],
 88          )?;
 89  
 90          Ok(())
 91      }
 92  
 93      /// Check if a message is a replay and mark it as seen atomically.
 94      ///
 95      /// This is a convenience method that combines `is_replay` and `mark_seen`
 96      /// into a single operation. It's useful when you want to check and mark
 97      /// in one step.
 98      ///
 99      /// # Returns
100      ///
101      /// `true` if this was a replay (message ID already seen), `false` if
102      /// the message is new (and has now been marked as seen).
103      pub fn check_and_mark_seen(&self, message_id: &MessageId) -> Result<bool> {
104          // Try to insert first
105          let now = Utc::now().timestamp();
106  
107          let rows = self.connection().execute(
108              "INSERT OR IGNORE INTO seen_messages (message_id, seen_at) VALUES (?, ?)",
109              params![message_id.as_slice(), now],
110          )?;
111  
112          // If no rows were inserted, the ID was already there (replay)
113          Ok(rows == 0)
114      }
115  
116      /// Clean up old seen message entries.
117      ///
118      /// Removes entries older than the specified number of days.
119      ///
120      /// # Arguments
121      ///
122      /// * `max_age_days` - Maximum age of entries to keep
123      ///
124      /// # Returns
125      ///
126      /// The number of entries deleted.
127      ///
128      /// # Example
129      ///
130      /// ```no_run
131      /// use dead_drop_core::storage::Database;
132      /// use dead_drop_core::storage::replay::DEFAULT_SEEN_RETENTION_DAYS;
133      ///
134      /// let db = Database::open_in_memory(b"passphrase")?;
135      /// let deleted = db.cleanup_seen(DEFAULT_SEEN_RETENTION_DAYS)?;
136      /// println!("Cleaned up {} old entries", deleted);
137      /// # Ok::<(), dead_drop_core::error::DeadDropError>(())
138      /// ```
139      pub fn cleanup_seen(&self, max_age_days: i64) -> Result<u32> {
140          let cutoff = (Utc::now() - Duration::days(max_age_days)).timestamp();
141  
142          let count = self.connection().execute(
143              "DELETE FROM seen_messages WHERE seen_at < ?",
144              [cutoff],
145          )?;
146  
147          Ok(count as u32)
148      }
149  
150      /// Get the count of seen message IDs in the cache.
151      pub fn count_seen(&self) -> Result<u32> {
152          let count: u32 = self.connection().query_row(
153              "SELECT COUNT(*) FROM seen_messages",
154              [],
155              |row| row.get(0),
156          )?;
157          Ok(count)
158      }
159  
160      /// Clear all seen message IDs.
161      ///
162      /// # Warning
163      ///
164      /// This will allow previously-seen messages to be replayed.
165      /// Only use this for testing or recovery purposes.
166      pub fn clear_seen(&self) -> Result<()> {
167          self.connection().execute("DELETE FROM seen_messages", [])?;
168          Ok(())
169      }
170  
171      /// Check multiple message IDs for replays in a single query.
172      ///
173      /// This is more efficient than calling `is_replay` multiple times.
174      ///
175      /// # Arguments
176      ///
177      /// * `message_ids` - Slice of message IDs to check
178      ///
179      /// # Returns
180      ///
181      /// Vector of booleans, where `true` indicates a replay.
182      pub fn check_replays(&self, message_ids: &[MessageId]) -> Result<Vec<bool>> {
183          let mut results = vec![false; message_ids.len()];
184  
185          // For each message ID, check if it exists
186          for (i, id) in message_ids.iter().enumerate() {
187              results[i] = self.is_replay(id)?;
188          }
189  
190          Ok(results)
191      }
192  
193      /// Mark multiple message IDs as seen in a single transaction.
194      ///
195      /// More efficient than calling `mark_seen` multiple times.
196      pub fn mark_seen_batch(&mut self, message_ids: &[MessageId]) -> Result<()> {
197          let tx = self.transaction()?;
198          let now = Utc::now().timestamp();
199  
200          for id in message_ids {
201              tx.execute(
202                  "INSERT OR IGNORE INTO seen_messages (message_id, seen_at) VALUES (?, ?)",
203                  params![id.as_slice(), now],
204              )?;
205          }
206  
207          tx.commit()?;
208          Ok(())
209      }
210  }
211  
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::messages::generate_message_id;

    /// Open a fresh in-memory database for a single test.
    fn test_db() -> Database {
        Database::open_in_memory(b"test_passphrase").unwrap()
    }

    /// Record `n` freshly generated message IDs as seen.
    fn seed_seen(db: &Database, n: usize) {
        for _ in 0..n {
            let id = generate_message_id();
            db.mark_seen(&id).unwrap();
        }
    }

    #[test]
    fn test_is_replay_not_seen() {
        let db = test_db();
        let message_id = generate_message_id();

        assert!(!db.is_replay(&message_id).unwrap());
    }

    #[test]
    fn test_is_replay_after_mark_seen() {
        let db = test_db();
        let message_id = generate_message_id();

        // First time: not a replay
        assert!(!db.is_replay(&message_id).unwrap());

        // Mark as seen
        db.mark_seen(&message_id).unwrap();

        // Now it should be a replay
        assert!(db.is_replay(&message_id).unwrap());
    }

    #[test]
    fn test_mark_seen_idempotent() {
        let db = test_db();
        let message_id = generate_message_id();

        // Mark multiple times should not error
        db.mark_seen(&message_id).unwrap();
        db.mark_seen(&message_id).unwrap();
        db.mark_seen(&message_id).unwrap();

        // Should still be counted once
        assert_eq!(db.count_seen().unwrap(), 1);
    }

    #[test]
    fn test_check_and_mark_seen() {
        let db = test_db();
        let message_id = generate_message_id();

        // First call: not a replay, marks it
        assert!(!db.check_and_mark_seen(&message_id).unwrap());

        // Second call: is a replay
        assert!(db.check_and_mark_seen(&message_id).unwrap());

        // The ID is stored exactly once, and plain lookups see it too
        assert_eq!(db.count_seen().unwrap(), 1);
        assert!(db.is_replay(&message_id).unwrap());
    }

    #[test]
    fn test_cleanup_seen() {
        let db = test_db();
        seed_seen(&db, 5);
        assert_eq!(db.count_seen().unwrap(), 5);

        // Entries were stamped just now, so a very large max age removes nothing.
        let deleted = db.cleanup_seen(1000).unwrap();
        assert_eq!(deleted, 0);
        assert_eq!(db.count_seen().unwrap(), 5);
    }

    #[test]
    fn test_cleanup_seen_removes_expired_entries() {
        let db = test_db();
        seed_seen(&db, 5);
        assert_eq!(db.count_seen().unwrap(), 5);

        // A negative max age moves the cutoff one day into the future, so every
        // entry stamped "now" is older than the cutoff and must be deleted.
        // (A max age of 0 is not used here: whether it deletes anything depends
        // on the clock ticking over to the next second between insert and cleanup.)
        let deleted = db.cleanup_seen(-1).unwrap();
        assert_eq!(deleted, 5);
        assert_eq!(db.count_seen().unwrap(), 0);
    }

    #[test]
    fn test_count_seen() {
        let db = test_db();

        assert_eq!(db.count_seen().unwrap(), 0);

        seed_seen(&db, 10);

        assert_eq!(db.count_seen().unwrap(), 10);
    }

    #[test]
    fn test_clear_seen() {
        let db = test_db();

        seed_seen(&db, 5);
        assert_eq!(db.count_seen().unwrap(), 5);

        // Clear
        db.clear_seen().unwrap();

        assert_eq!(db.count_seen().unwrap(), 0);
    }

    #[test]
    fn test_check_replays() {
        let db = test_db();

        let id1 = generate_message_id();
        let id2 = generate_message_id();
        let id3 = generate_message_id();

        // Mark id1 and id2 as seen
        db.mark_seen(&id1).unwrap();
        db.mark_seen(&id2).unwrap();

        let results = db.check_replays(&[id1, id2, id3]).unwrap();

        assert_eq!(results.len(), 3);
        assert!(results[0]); // id1 is replay
        assert!(results[1]); // id2 is replay
        assert!(!results[2]); // id3 is not replay
    }

    #[test]
    fn test_check_replays_empty_input() {
        let db = test_db();

        let results = db.check_replays(&[]).unwrap();

        assert!(results.is_empty());
    }

    #[test]
    fn test_mark_seen_batch() {
        let mut db = test_db();

        let ids: Vec<MessageId> = (0..10).map(|_| generate_message_id()).collect();

        db.mark_seen_batch(&ids).unwrap();

        assert_eq!(db.count_seen().unwrap(), 10);

        // All should be replays now
        for id in &ids {
            assert!(db.is_replay(id).unwrap());
        }
    }

    #[test]
    fn test_mark_seen_batch_empty_and_repeated() {
        let mut db = test_db();

        // An empty batch succeeds and stores nothing
        db.mark_seen_batch(&[]).unwrap();
        assert_eq!(db.count_seen().unwrap(), 0);

        // Re-submitting the same batch must not duplicate entries or fail
        let ids: Vec<MessageId> = (0..3).map(|_| generate_message_id()).collect();
        db.mark_seen_batch(&ids).unwrap();
        db.mark_seen_batch(&ids).unwrap();
        assert_eq!(db.count_seen().unwrap(), 3);
    }

    #[test]
    fn test_replay_persistence() {
        // The temporary directory (and the database file inside it) is removed
        // when `temp_dir` is dropped at the end of the test.
        let temp_dir = tempfile::tempdir().unwrap();
        let db_path = temp_dir.path().join("replay_test.db");
        let db_path_str = db_path.to_str().unwrap();
        let passphrase = b"replay_test";

        let message_id = generate_message_id();

        // Create and mark as seen
        {
            let db = Database::open(db_path_str, passphrase).unwrap();
            db.mark_seen(&message_id).unwrap();
        }

        // Reopen and verify
        {
            let db = Database::open(db_path_str, passphrase).unwrap();
            assert!(db.is_replay(&message_id).unwrap());
        }
    }

    #[test]
    fn test_different_messages_not_confused() {
        let db = test_db();

        let id1 = generate_message_id();
        let id2 = generate_message_id();

        // Mark only id1
        db.mark_seen(&id1).unwrap();

        // id1 should be replay, id2 should not
        assert!(db.is_replay(&id1).unwrap());
        assert!(!db.is_replay(&id2).unwrap());
    }
}