/ core / src / protocol / wire.rs
wire.rs
   1  //! Wire format for Dead Drop protocol messages.
   2  //!
   3  //! This module defines the binary wire format used for communication between
   4  //! devices over BLE. All messages are prefixed with a single byte indicating
   5  //! the message type, followed by bincode-encoded payload data.
   6  //!
   7  //! # Message Types
   8  //!
   9  //! | Type | Code | Direction | Purpose |
  10  //! |------|------|-----------|---------|
  11  //! | [`MsgCount`] | `0x01` | Both | Announce number of messages to send |
  12  //! | [`MsgData`] | `0x02` | Both | Encrypted message payload |
  13  //! | [`MsgAck`] | `0x03` | Both | Acknowledge receipt of message |
  14  //! | [`MsgNack`] | `0x04` | Both | Negative acknowledgement with error |
  15  //! | [`SessionDone`] | `0x05` | Both | Signal exchange completion |
  16  //! | [`Ping`] | `0x06` | Both | Keep-alive / connectivity check |
  17  //! | [`Pong`] | `0x07` | Both | Response to ping |
  18  //!
  19  //! # Wire Format
  20  //!
  21  //! ```text
  22  //! +--------+-------------------+
  23  //! | 1 byte |   N bytes         |
  24  //! | type   |   bincode payload |
  25  //! +--------+-------------------+
  26  //! ```
  27  //!
  28  //! # Protocol Version
  29  //!
  30  //! The current protocol version is 1. Version negotiation happens during
  31  //! the message count exchange phase.
  32  //!
  33  //! # Size Limits
  34  //!
  35  //! All wire messages are limited to [`MAX_WIRE_MESSAGE_SIZE`] bytes to prevent
  36  //! memory exhaustion attacks and ensure compatibility with BLE MTU constraints.
  37  //!
  38  //! # Example
  39  //!
  40  //! ```
  41  //! use dead_drop_core::protocol::wire::{
  42  //!     MessageType, MsgCount, serialize, deserialize,
  43  //! };
  44  //!
  45  //! // Create a message count announcement
  46  //! let msg = MsgCount {
  47  //!     version: 1,
  48  //!     count: 5,
  49  //! };
  50  //!
  51  //! // Serialize for transmission
  52  //! let bytes = serialize(MessageType::MsgCount, &msg).unwrap();
  53  //!
  54  //! // Deserialize on receipt
  55  //! let (msg_type, payload) = deserialize(&bytes).unwrap();
  56  //! assert_eq!(msg_type, MessageType::MsgCount);
  57  //! ```
  58  
  59  use serde::{Deserialize, Serialize};
  60  
  61  use crate::error::{DeadDropError, Result};
  62  
  63  // =============================================================================
  64  // CONSTANTS
  65  // =============================================================================
  66  
  67  /// Current protocol version.
  68  ///
  69  /// This version number is included in the initial message count exchange
  70  /// to allow peers to detect incompatible protocol versions.
  71  pub const PROTOCOL_VERSION: u8 = 1;
  72  
  73  /// Maximum size of a wire message in bytes.
  74  ///
  75  /// This limit ensures messages fit within BLE constraints and prevents
  76  /// memory exhaustion attacks. Set to 64KB which is generous for the
  77  /// expected use case of text messages and small documents.
  78  pub const MAX_WIRE_MESSAGE_SIZE: usize = 65536;
  79  
  80  /// Maximum number of messages that can be exchanged in a single session.
  81  ///
  82  /// This prevents resource exhaustion from a peer claiming to have an
  83  /// unreasonable number of messages.
  84  pub const MAX_MESSAGE_COUNT: u16 = 1000;
  85  
  86  // =============================================================================
  87  // MESSAGE TYPES
  88  // =============================================================================
  89  
  90  /// Wire message type identifier.
  91  ///
  92  /// Each message on the wire begins with a single byte identifying its type.
  93  /// This allows the receiver to determine how to parse the remaining payload.
  94  ///
  95  /// # Wire Encoding
  96  ///
  97  /// The discriminant values are explicitly defined to ensure stable wire format
  98  /// across compiler versions and platforms.
  99  #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 100  #[repr(u8)]
 101  pub enum MessageType {
 102      /// Message count announcement.
 103      ///
 104      /// Sent at the start of an exchange to inform the peer how many
 105      /// messages will be transmitted.
 106      MsgCount = 0x01,
 107  
 108      /// Encrypted message data.
 109      ///
 110      /// Contains an encrypted message payload along with index information
 111      /// for ordering and acknowledgement.
 112      MsgData = 0x02,
 113  
 114      /// Positive acknowledgement.
 115      ///
 116      /// Confirms successful receipt and processing of a message.
 117      MsgAck = 0x03,
 118  
 119      /// Negative acknowledgement.
 120      ///
 121      /// Indicates a problem processing a received message, with an
 122      /// error code explaining the failure.
 123      MsgNack = 0x04,
 124  
 125      /// Session complete signal.
 126      ///
 127      /// Sent when a peer has finished sending all messages and received
 128      /// all expected acknowledgements.
 129      SessionDone = 0x05,
 130  
 131      /// Keep-alive ping.
 132      ///
 133      /// Used to verify the connection is still active during long
 134      /// operations or idle periods.
 135      Ping = 0x06,
 136  
 137      /// Ping response.
 138      ///
 139      /// Sent in response to a Ping to confirm connectivity.
 140      Pong = 0x07,
 141  }
 142  
 143  impl MessageType {
 144      /// Convert a byte to a MessageType.
 145      ///
 146      /// # Errors
 147      ///
 148      /// Returns `InvalidFormat` if the byte does not correspond to a known
 149      /// message type.
 150      pub fn from_byte(byte: u8) -> Result<Self> {
 151          match byte {
 152              0x01 => Ok(MessageType::MsgCount),
 153              0x02 => Ok(MessageType::MsgData),
 154              0x03 => Ok(MessageType::MsgAck),
 155              0x04 => Ok(MessageType::MsgNack),
 156              0x05 => Ok(MessageType::SessionDone),
 157              0x06 => Ok(MessageType::Ping),
 158              0x07 => Ok(MessageType::Pong),
 159              _ => Err(DeadDropError::InvalidFormat(format!(
 160                  "Unknown message type: 0x{:02X}",
 161                  byte
 162              ))),
 163          }
 164      }
 165  
 166      /// Convert MessageType to its wire byte representation.
 167      pub fn to_byte(self) -> u8 {
 168          self as u8
 169      }
 170  }
 171  
 172  // =============================================================================
 173  // MESSAGE STRUCTURES
 174  // =============================================================================
 175  
 176  /// Message count announcement.
 177  ///
 178  /// Sent at the beginning of a message exchange to inform the peer how many
 179  /// messages will be transmitted. Both peers send this message after the
 180  /// handshake completes.
 181  ///
 182  /// # Fields
 183  ///
 184  /// - `version`: Protocol version for compatibility checking
 185  /// - `count`: Number of messages this peer will send (0 to [`MAX_MESSAGE_COUNT`])
 186  ///
 187  /// # Example
 188  ///
 189  /// ```
 190  /// use dead_drop_core::protocol::wire::MsgCount;
 191  ///
 192  /// let count = MsgCount {
 193  ///     version: 1,
 194  ///     count: 3,
 195  /// };
 196  /// ```
 197  #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 198  pub struct MsgCount {
 199      /// Protocol version (must match [`PROTOCOL_VERSION`]).
 200      pub version: u8,
 201      /// Number of messages to send (0 = no messages, just receiving).
 202      pub count: u16,
 203  }
 204  
 205  impl MsgCount {
 206      /// Create a new MsgCount with the current protocol version.
 207      ///
 208      /// # Arguments
 209      ///
 210      /// * `count` - Number of messages to announce
 211      ///
 212      /// # Example
 213      ///
 214      /// ```
 215      /// use dead_drop_core::protocol::wire::MsgCount;
 216      ///
 217      /// let msg = MsgCount::new(5);
 218      /// assert_eq!(msg.version, 1);
 219      /// assert_eq!(msg.count, 5);
 220      /// ```
 221      pub fn new(count: u16) -> Self {
 222          Self {
 223              version: PROTOCOL_VERSION,
 224              count,
 225          }
 226      }
 227  
 228      /// Validate the message count announcement.
 229      ///
 230      /// # Errors
 231      ///
 232      /// - `InvalidFormat` if version doesn't match current protocol version
 233      /// - `InvalidFormat` if count exceeds [`MAX_MESSAGE_COUNT`]
 234      pub fn validate(&self) -> Result<()> {
 235          if self.version != PROTOCOL_VERSION {
 236              return Err(DeadDropError::InvalidFormat(format!(
 237                  "Protocol version mismatch: expected {}, got {}",
 238                  PROTOCOL_VERSION, self.version
 239              )));
 240          }
 241          if self.count > MAX_MESSAGE_COUNT {
 242              return Err(DeadDropError::InvalidFormat(format!(
 243                  "Message count {} exceeds maximum {}",
 244                  self.count, MAX_MESSAGE_COUNT
 245              )));
 246          }
 247          Ok(())
 248      }
 249  }
 250  
 251  /// Encrypted message data payload.
 252  ///
 253  /// Contains an encrypted message along with metadata for ordering and
 254  /// reassembly. Messages are sent in order by index.
 255  ///
 256  /// # Fields
 257  ///
 258  /// - `index`: Zero-based position of this message (0 to total-1)
 259  /// - `total`: Total number of messages being sent
 260  /// - `payload`: The encrypted message bytes
 261  ///
 262  /// # Size Limit
 263  ///
 264  /// The payload must not exceed `MAX_WIRE_MESSAGE_SIZE - overhead` bytes.
 265  /// In practice, encrypted messages should be under 32KB.
 266  #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 267  pub struct MsgData {
 268      /// Zero-based index of this message in the sequence.
 269      pub index: u8,
 270      /// Total number of messages in this exchange.
 271      pub total: u8,
 272      /// Encrypted message payload.
 273      pub payload: Vec<u8>,
 274  }
 275  
 276  impl MsgData {
 277      /// Create a new MsgData.
 278      ///
 279      /// # Arguments
 280      ///
 281      /// * `index` - Position of this message (0-based)
 282      /// * `total` - Total messages in exchange
 283      /// * `payload` - Encrypted message bytes
 284      pub fn new(index: u8, total: u8, payload: Vec<u8>) -> Self {
 285          Self {
 286              index,
 287              total,
 288              payload,
 289          }
 290      }
 291  
 292      /// Validate the message data.
 293      ///
 294      /// # Errors
 295      ///
 296      /// - `InvalidFormat` if index >= total
 297      /// - `InvalidFormat` if total is 0
 298      /// - `MessageTooLarge` if payload exceeds size limit
 299      pub fn validate(&self) -> Result<()> {
 300          if self.total == 0 {
 301              return Err(DeadDropError::InvalidFormat(
 302                  "Message total cannot be zero".to_string(),
 303              ));
 304          }
 305          if self.index >= self.total {
 306              return Err(DeadDropError::InvalidFormat(format!(
 307                  "Message index {} out of range (total: {})",
 308                  self.index, self.total
 309              )));
 310          }
 311          // Reserve some bytes for wire overhead
 312          let max_payload = MAX_WIRE_MESSAGE_SIZE - 16;
 313          if self.payload.len() > max_payload {
 314              return Err(DeadDropError::MessageTooLarge {
 315                  size: self.payload.len(),
 316                  max: max_payload,
 317              });
 318          }
 319          Ok(())
 320      }
 321  }
 322  
 323  /// Positive acknowledgement for a received message.
 324  ///
 325  /// Sent after successfully receiving and processing a message to confirm
 326  /// delivery. The sender can mark the message as delivered upon receiving
 327  /// this acknowledgement.
 328  ///
 329  /// # Fields
 330  ///
 331  /// - `index`: The index of the message being acknowledged
 332  /// - `status`: Status code (currently always 0 for success)
 333  #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 334  pub struct MsgAck {
 335      /// Index of the acknowledged message.
 336      pub index: u8,
 337      /// Status code (0 = success).
 338      pub status: u8,
 339  }
 340  
 341  impl MsgAck {
 342      /// Create a successful acknowledgement.
 343      ///
 344      /// # Arguments
 345      ///
 346      /// * `index` - Index of the message being acknowledged
 347      pub fn success(index: u8) -> Self {
 348          Self { index, status: 0 }
 349      }
 350  }
 351  
 352  /// Negative acknowledgement with error information.
 353  ///
 354  /// Sent when a message could not be processed successfully. The error
 355  /// code indicates the reason for failure.
 356  ///
 357  /// # Security Note
 358  ///
 359  /// Error codes are intentionally limited to avoid leaking information
 360  /// that could aid attackers. For example, we don't distinguish between
 361  /// "wrong key" and "tampered message".
 362  #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 363  pub struct MsgNack {
 364      /// Index of the rejected message.
 365      pub index: u8,
 366      /// Error code indicating the failure reason.
 367      pub error_code: u8,
 368  }
 369  
 370  impl MsgNack {
 371      /// Create a negative acknowledgement.
 372      ///
 373      /// # Arguments
 374      ///
 375      /// * `index` - Index of the rejected message
 376      /// * `error_code` - Reason for rejection (see [`ErrorCode`])
 377      pub fn new(index: u8, error_code: ErrorCode) -> Self {
 378          Self {
 379              index,
 380              error_code: error_code as u8,
 381          }
 382      }
 383  
 384      /// Get the error code as an enum.
 385      ///
 386      /// Returns `None` if the error code is not recognized.
 387      pub fn error(&self) -> Option<ErrorCode> {
 388          ErrorCode::from_byte(self.error_code)
 389      }
 390  }
 391  
 392  /// Session completion signal.
 393  ///
 394  /// Sent when a peer has finished sending all messages and received
 395  /// acknowledgements for all of them (or given up on unacknowledged ones).
 396  ///
 397  /// # Fields
 398  ///
 399  /// - `sent_count`: Number of messages successfully sent
 400  /// - `received_count`: Number of messages successfully received
 401  #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 402  pub struct SessionDone {
 403      /// Number of messages sent by this peer.
 404      pub sent_count: u8,
 405      /// Number of messages received by this peer.
 406      pub received_count: u8,
 407  }
 408  
 409  impl SessionDone {
 410      /// Create a session done message.
 411      pub fn new(sent_count: u8, received_count: u8) -> Self {
 412          Self {
 413              sent_count,
 414              received_count,
 415          }
 416      }
 417  }
 418  
 419  /// Keep-alive ping message.
 420  ///
 421  /// Contains a random nonce that must be echoed in the Pong response.
 422  /// This allows detection of connection issues and prevents replay attacks.
 423  #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 424  pub struct Ping {
 425      /// Random nonce to be echoed in Pong.
 426      pub nonce: u32,
 427  }
 428  
 429  impl Ping {
 430      /// Create a new ping with a random nonce.
 431      pub fn new() -> Self {
 432          use rand::RngCore;
 433          Self {
 434              nonce: rand::rngs::OsRng.next_u32(),
 435          }
 436      }
 437  
 438      /// Create a ping with a specific nonce (for testing).
 439      pub fn with_nonce(nonce: u32) -> Self {
 440          Self { nonce }
 441      }
 442  }
 443  
 444  impl Default for Ping {
 445      fn default() -> Self {
 446          Self::new()
 447      }
 448  }
 449  
 450  /// Keep-alive pong response.
 451  ///
 452  /// Echoes the nonce from the corresponding Ping message.
 453  #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 454  pub struct Pong {
 455      /// Echoed nonce from the Ping.
 456      pub nonce: u32,
 457  }
 458  
 459  impl Pong {
 460      /// Create a pong response to a ping.
 461      pub fn from_ping(ping: &Ping) -> Self {
 462          Self { nonce: ping.nonce }
 463      }
 464  
 465      /// Create a pong with a specific nonce.
 466      pub fn new(nonce: u32) -> Self {
 467          Self { nonce }
 468      }
 469  }
 470  
 471  // =============================================================================
 472  // ERROR CODES
 473  // =============================================================================
 474  
 475  /// Error codes for negative acknowledgements.
 476  ///
 477  /// These codes indicate why a message was rejected. They are intentionally
 478  /// limited to prevent information leakage.
 479  ///
 480  /// # Security Considerations
 481  ///
 482  /// - Cryptographic errors are grouped to prevent oracle attacks
 483  /// - Internal errors are not exposed to the peer
 484  /// - Error codes should be logged but details should not be sent
 485  #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 486  #[repr(u8)]
 487  pub enum ErrorCode {
 488      /// No error (should not appear in NACK).
 489      Ok = 0x00,
 490  
 491      /// Cryptographic verification failed.
 492      ///
 493      /// This covers signature failures, decryption failures, and
 494      /// authentication tag mismatches. We don't distinguish between
 495      /// them to prevent oracle attacks.
 496      CryptoError = 0x01,
 497  
 498      /// Message ID was already seen (replay detected).
 499      ReplayDetected = 0x02,
 500  
 501      /// Message timestamp is too old.
 502      Expired = 0x03,
 503  
 504      /// Sender is not in contact list.
 505      UnknownSender = 0x04,
 506  
 507      /// Storage is full, cannot accept more messages.
 508      StorageFull = 0x05,
 509  
 510      /// Message format is invalid.
 511      InvalidFormat = 0x06,
 512  
 513      /// Internal error (generic, for unexpected failures).
 514      InternalError = 0xFF,
 515  }
 516  
 517  impl ErrorCode {
 518      /// Convert a byte to an ErrorCode.
 519      ///
 520      /// Returns `None` for unrecognized codes.
 521      pub fn from_byte(byte: u8) -> Option<Self> {
 522          match byte {
 523              0x00 => Some(ErrorCode::Ok),
 524              0x01 => Some(ErrorCode::CryptoError),
 525              0x02 => Some(ErrorCode::ReplayDetected),
 526              0x03 => Some(ErrorCode::Expired),
 527              0x04 => Some(ErrorCode::UnknownSender),
 528              0x05 => Some(ErrorCode::StorageFull),
 529              0x06 => Some(ErrorCode::InvalidFormat),
 530              0xFF => Some(ErrorCode::InternalError),
 531              _ => None,
 532          }
 533      }
 534  
 535      /// Convert ErrorCode to its wire byte representation.
 536      pub fn to_byte(self) -> u8 {
 537          self as u8
 538      }
 539  }
 540  
 541  impl From<&DeadDropError> for ErrorCode {
 542      /// Convert a DeadDropError to the appropriate ErrorCode for NACK.
 543      ///
 544      /// This mapping intentionally groups errors to prevent information leakage.
 545      fn from(err: &DeadDropError) -> Self {
 546          match err {
 547              // Crypto errors are grouped
 548              DeadDropError::Decryption(_)
 549              | DeadDropError::InvalidSignature
 550              | DeadDropError::InvalidKey(_) => ErrorCode::CryptoError,
 551  
 552              DeadDropError::ReplayDetected => ErrorCode::ReplayDetected,
 553              DeadDropError::MessageExpired => ErrorCode::Expired,
 554              DeadDropError::UnknownContact => ErrorCode::UnknownSender,
 555              DeadDropError::InvalidFormat(_) | DeadDropError::Deserialization(_) => {
 556                  ErrorCode::InvalidFormat
 557              }
 558  
 559              // Everything else maps to internal error
 560              _ => ErrorCode::InternalError,
 561          }
 562      }
 563  }
 564  
 565  // =============================================================================
 566  // SERIALIZATION
 567  // =============================================================================
 568  
 569  /// Serialize a message for wire transmission.
 570  ///
 571  /// Produces a byte vector with the message type prefix followed by the
 572  /// bincode-encoded payload.
 573  ///
 574  /// # Arguments
 575  ///
 576  /// * `msg_type` - The type of message being serialized
 577  /// * `msg` - The message payload to serialize
 578  ///
 579  /// # Returns
 580  ///
 581  /// A byte vector ready for transmission.
 582  ///
 583  /// # Errors
 584  ///
 585  /// Returns `Serialization` if bincode encoding fails.
 586  ///
 587  /// # Example
 588  ///
 589  /// ```
 590  /// use dead_drop_core::protocol::wire::{MessageType, MsgCount, serialize};
 591  ///
 592  /// let msg = MsgCount::new(3);
 593  /// let bytes = serialize(MessageType::MsgCount, &msg).unwrap();
 594  ///
 595  /// // First byte is message type
 596  /// assert_eq!(bytes[0], 0x01);
 597  /// ```
 598  pub fn serialize<T: Serialize>(msg_type: MessageType, msg: &T) -> Result<Vec<u8>> {
 599      let payload = bincode::serialize(msg)?;
 600  
 601      // Check total size
 602      let total_size = 1 + payload.len();
 603      if total_size > MAX_WIRE_MESSAGE_SIZE {
 604          return Err(DeadDropError::MessageTooLarge {
 605              size: total_size,
 606              max: MAX_WIRE_MESSAGE_SIZE,
 607          });
 608      }
 609  
 610      let mut result = Vec::with_capacity(total_size);
 611      result.push(msg_type.to_byte());
 612      result.extend(payload);
 613  
 614      Ok(result)
 615  }
 616  
 617  /// Deserialize a wire message.
 618  ///
 619  /// Parses the message type prefix and returns it along with the raw
 620  /// payload bytes for type-specific deserialization.
 621  ///
 622  /// # Arguments
 623  ///
 624  /// * `bytes` - The raw wire message bytes
 625  ///
 626  /// # Returns
 627  ///
 628  /// A tuple of (MessageType, payload_bytes).
 629  ///
 630  /// # Errors
 631  ///
 632  /// - `InvalidFormat` if the message is empty or has unknown type
 633  ///
 634  /// # Example
 635  ///
 636  /// ```
 637  /// use dead_drop_core::protocol::wire::{MessageType, MsgCount, serialize, deserialize};
 638  ///
 639  /// let msg = MsgCount::new(3);
 640  /// let bytes = serialize(MessageType::MsgCount, &msg).unwrap();
 641  ///
 642  /// let (msg_type, payload) = deserialize(&bytes).unwrap();
 643  /// assert_eq!(msg_type, MessageType::MsgCount);
 644  /// ```
 645  pub fn deserialize(bytes: &[u8]) -> Result<(MessageType, &[u8])> {
 646      if bytes.is_empty() {
 647          return Err(DeadDropError::InvalidFormat(
 648              "Empty message".to_string(),
 649          ));
 650      }
 651  
 652      if bytes.len() > MAX_WIRE_MESSAGE_SIZE {
 653          return Err(DeadDropError::MessageTooLarge {
 654              size: bytes.len(),
 655              max: MAX_WIRE_MESSAGE_SIZE,
 656          });
 657      }
 658  
 659      let msg_type = MessageType::from_byte(bytes[0])?;
 660      let payload = &bytes[1..];
 661  
 662      Ok((msg_type, payload))
 663  }
 664  
 665  /// Deserialize a typed message from wire bytes.
 666  ///
 667  /// Convenience function that combines [`deserialize`] with bincode decoding
 668  /// of the payload.
 669  ///
 670  /// # Type Parameters
 671  ///
 672  /// * `T` - The expected message type (must match the wire message type)
 673  ///
 674  /// # Arguments
 675  ///
 676  /// * `bytes` - The raw wire message bytes
 677  ///
 678  /// # Returns
 679  ///
 680  /// A tuple of (MessageType, decoded_message).
 681  ///
 682  /// # Errors
 683  ///
 684  /// - `InvalidFormat` for type/format errors
 685  /// - `Deserialization` if bincode decoding fails
 686  ///
 687  /// # Example
 688  ///
 689  /// ```
 690  /// use dead_drop_core::protocol::wire::{MessageType, MsgCount, serialize, deserialize_typed};
 691  ///
 692  /// let original = MsgCount::new(5);
 693  /// let bytes = serialize(MessageType::MsgCount, &original).unwrap();
 694  ///
 695  /// let (msg_type, decoded): (MessageType, MsgCount) = deserialize_typed(&bytes).unwrap();
 696  /// assert_eq!(decoded.count, 5);
 697  /// ```
 698  pub fn deserialize_typed<'a, T: Deserialize<'a>>(bytes: &'a [u8]) -> Result<(MessageType, T)> {
 699      let (msg_type, payload) = deserialize(bytes)?;
 700      let msg: T = bincode::deserialize(payload).map_err(|e| {
 701          DeadDropError::Deserialization(format!("Failed to decode payload: {}", e))
 702      })?;
 703      Ok((msg_type, msg))
 704  }
 705  
 706  /// Serialize a Ping message.
 707  pub fn serialize_ping(ping: &Ping) -> Result<Vec<u8>> {
 708      serialize(MessageType::Ping, ping)
 709  }
 710  
 711  /// Serialize a Pong message.
 712  pub fn serialize_pong(pong: &Pong) -> Result<Vec<u8>> {
 713      serialize(MessageType::Pong, pong)
 714  }
 715  
 716  /// Serialize a SessionDone message.
 717  pub fn serialize_session_done(done: &SessionDone) -> Result<Vec<u8>> {
 718      serialize(MessageType::SessionDone, done)
 719  }
 720  
 721  #[cfg(test)]
 722  mod tests {
 723      use super::*;
 724  
 725      // ==================== MessageType Tests ====================
 726  
 727      #[test]
 728      fn test_message_type_from_byte_valid() {
 729          assert_eq!(MessageType::from_byte(0x01).unwrap(), MessageType::MsgCount);
 730          assert_eq!(MessageType::from_byte(0x02).unwrap(), MessageType::MsgData);
 731          assert_eq!(MessageType::from_byte(0x03).unwrap(), MessageType::MsgAck);
 732          assert_eq!(MessageType::from_byte(0x04).unwrap(), MessageType::MsgNack);
 733          assert_eq!(
 734              MessageType::from_byte(0x05).unwrap(),
 735              MessageType::SessionDone
 736          );
 737          assert_eq!(MessageType::from_byte(0x06).unwrap(), MessageType::Ping);
 738          assert_eq!(MessageType::from_byte(0x07).unwrap(), MessageType::Pong);
 739      }
 740  
 741      #[test]
 742      fn test_message_type_from_byte_invalid() {
 743          assert!(MessageType::from_byte(0x00).is_err());
 744          assert!(MessageType::from_byte(0x08).is_err());
 745          assert!(MessageType::from_byte(0xFF).is_err());
 746      }
 747  
 748      #[test]
 749      fn test_message_type_round_trip() {
 750          for msg_type in [
 751              MessageType::MsgCount,
 752              MessageType::MsgData,
 753              MessageType::MsgAck,
 754              MessageType::MsgNack,
 755              MessageType::SessionDone,
 756              MessageType::Ping,
 757              MessageType::Pong,
 758          ] {
 759              let byte = msg_type.to_byte();
 760              let recovered = MessageType::from_byte(byte).unwrap();
 761              assert_eq!(recovered, msg_type);
 762          }
 763      }
 764  
 765      // ==================== MsgCount Tests ====================
 766  
 767      #[test]
 768      fn test_msg_count_new() {
 769          let msg = MsgCount::new(10);
 770          assert_eq!(msg.version, PROTOCOL_VERSION);
 771          assert_eq!(msg.count, 10);
 772      }
 773  
 774      #[test]
 775      fn test_msg_count_validate_success() {
 776          let msg = MsgCount::new(100);
 777          assert!(msg.validate().is_ok());
 778      }
 779  
 780      #[test]
 781      fn test_msg_count_validate_wrong_version() {
 782          let msg = MsgCount {
 783              version: 99,
 784              count: 5,
 785          };
 786          assert!(msg.validate().is_err());
 787      }
 788  
 789      #[test]
 790      fn test_msg_count_validate_count_too_large() {
 791          let msg = MsgCount {
 792              version: PROTOCOL_VERSION,
 793              count: MAX_MESSAGE_COUNT + 1,
 794          };
 795          assert!(msg.validate().is_err());
 796      }
 797  
 798      #[test]
 799      fn test_msg_count_serialize_deserialize() {
 800          let original = MsgCount::new(42);
 801          let bytes = serialize(MessageType::MsgCount, &original).unwrap();
 802  
 803          let (msg_type, decoded): (_, MsgCount) = deserialize_typed(&bytes).unwrap();
 804  
 805          assert_eq!(msg_type, MessageType::MsgCount);
 806          assert_eq!(decoded.version, original.version);
 807          assert_eq!(decoded.count, original.count);
 808      }
 809  
 810      // ==================== MsgData Tests ====================
 811  
 812      #[test]
 813      fn test_msg_data_new() {
 814          let payload = vec![1, 2, 3, 4, 5];
 815          let msg = MsgData::new(0, 3, payload.clone());
 816  
 817          assert_eq!(msg.index, 0);
 818          assert_eq!(msg.total, 3);
 819          assert_eq!(msg.payload, payload);
 820      }
 821  
 822      #[test]
 823      fn test_msg_data_validate_success() {
 824          let msg = MsgData::new(0, 1, vec![1, 2, 3]);
 825          assert!(msg.validate().is_ok());
 826  
 827          let msg = MsgData::new(2, 5, vec![1, 2, 3]);
 828          assert!(msg.validate().is_ok());
 829      }
 830  
 831      #[test]
 832      fn test_msg_data_validate_zero_total() {
 833          let msg = MsgData::new(0, 0, vec![1, 2, 3]);
 834          assert!(msg.validate().is_err());
 835      }
 836  
 837      #[test]
 838      fn test_msg_data_validate_index_out_of_range() {
 839          let msg = MsgData::new(5, 3, vec![1, 2, 3]);
 840          assert!(msg.validate().is_err());
 841  
 842          let msg = MsgData::new(3, 3, vec![1, 2, 3]);
 843          assert!(msg.validate().is_err());
 844      }
 845  
 846      #[test]
 847      fn test_msg_data_validate_payload_too_large() {
 848          let large_payload = vec![0u8; MAX_WIRE_MESSAGE_SIZE];
 849          let msg = MsgData::new(0, 1, large_payload);
 850          assert!(msg.validate().is_err());
 851      }
 852  
 853      #[test]
 854      fn test_msg_data_serialize_deserialize() {
 855          let payload = vec![0xDE, 0xAD, 0xBE, 0xEF];
 856          let original = MsgData::new(2, 5, payload);
 857          let bytes = serialize(MessageType::MsgData, &original).unwrap();
 858  
 859          let (msg_type, decoded): (_, MsgData) = deserialize_typed(&bytes).unwrap();
 860  
 861          assert_eq!(msg_type, MessageType::MsgData);
 862          assert_eq!(decoded.index, original.index);
 863          assert_eq!(decoded.total, original.total);
 864          assert_eq!(decoded.payload, original.payload);
 865      }
 866  
 867      // ==================== MsgAck Tests ====================
 868  
 869      #[test]
 870      fn test_msg_ack_success() {
 871          let ack = MsgAck::success(3);
 872          assert_eq!(ack.index, 3);
 873          assert_eq!(ack.status, 0);
 874      }
 875  
 876      #[test]
 877      fn test_msg_ack_serialize_deserialize() {
 878          let original = MsgAck::success(7);
 879          let bytes = serialize(MessageType::MsgAck, &original).unwrap();
 880  
 881          let (msg_type, decoded): (_, MsgAck) = deserialize_typed(&bytes).unwrap();
 882  
 883          assert_eq!(msg_type, MessageType::MsgAck);
 884          assert_eq!(decoded.index, original.index);
 885          assert_eq!(decoded.status, original.status);
 886      }
 887  
 888      // ==================== MsgNack Tests ====================
 889  
 890      #[test]
 891      fn test_msg_nack_new() {
 892          let nack = MsgNack::new(5, ErrorCode::ReplayDetected);
 893          assert_eq!(nack.index, 5);
 894          assert_eq!(nack.error_code, ErrorCode::ReplayDetected as u8);
 895      }
 896  
 897      #[test]
 898      fn test_msg_nack_error() {
 899          let nack = MsgNack::new(0, ErrorCode::CryptoError);
 900          assert_eq!(nack.error(), Some(ErrorCode::CryptoError));
 901  
 902          let nack = MsgNack {
 903              index: 0,
 904              error_code: 0x99, // Unknown
 905          };
 906          assert_eq!(nack.error(), None);
 907      }
 908  
 909      #[test]
 910      fn test_msg_nack_serialize_deserialize() {
 911          let original = MsgNack::new(2, ErrorCode::Expired);
 912          let bytes = serialize(MessageType::MsgNack, &original).unwrap();
 913  
 914          let (msg_type, decoded): (_, MsgNack) = deserialize_typed(&bytes).unwrap();
 915  
 916          assert_eq!(msg_type, MessageType::MsgNack);
 917          assert_eq!(decoded.index, original.index);
 918          assert_eq!(decoded.error_code, original.error_code);
 919      }
 920  
 921      // ==================== SessionDone Tests ====================
 922  
 923      #[test]
 924      fn test_session_done_new() {
 925          let done = SessionDone::new(5, 3);
 926          assert_eq!(done.sent_count, 5);
 927          assert_eq!(done.received_count, 3);
 928      }
 929  
 930      #[test]
 931      fn test_session_done_serialize_deserialize() {
 932          let original = SessionDone::new(10, 8);
 933          let bytes = serialize(MessageType::SessionDone, &original).unwrap();
 934  
 935          let (msg_type, decoded): (_, SessionDone) = deserialize_typed(&bytes).unwrap();
 936  
 937          assert_eq!(msg_type, MessageType::SessionDone);
 938          assert_eq!(decoded.sent_count, original.sent_count);
 939          assert_eq!(decoded.received_count, original.received_count);
 940      }
 941  
 942      // ==================== Ping/Pong Tests ====================
 943  
 944      #[test]
 945      fn test_ping_new() {
 946          let ping1 = Ping::new();
 947          let ping2 = Ping::new();
 948          // Random nonces should be different (extremely high probability)
 949          assert_ne!(ping1.nonce, ping2.nonce);
 950      }
 951  
 952      #[test]
 953      fn test_ping_with_nonce() {
 954          let ping = Ping::with_nonce(12345);
 955          assert_eq!(ping.nonce, 12345);
 956      }
 957  
 958      #[test]
 959      fn test_pong_from_ping() {
 960          let ping = Ping::with_nonce(0xDEADBEEF);
 961          let pong = Pong::from_ping(&ping);
 962          assert_eq!(pong.nonce, ping.nonce);
 963      }
 964  
 965      #[test]
 966      fn test_ping_pong_serialize_deserialize() {
 967          let ping = Ping::with_nonce(42);
 968          let bytes = serialize_ping(&ping).unwrap();
 969  
 970          let (msg_type, decoded): (_, Ping) = deserialize_typed(&bytes).unwrap();
 971  
 972          assert_eq!(msg_type, MessageType::Ping);
 973          assert_eq!(decoded.nonce, ping.nonce);
 974  
 975          let pong = Pong::from_ping(&ping);
 976          let bytes = serialize_pong(&pong).unwrap();
 977  
 978          let (msg_type, decoded): (_, Pong) = deserialize_typed(&bytes).unwrap();
 979  
 980          assert_eq!(msg_type, MessageType::Pong);
 981          assert_eq!(decoded.nonce, pong.nonce);
 982      }
 983  
 984      // ==================== ErrorCode Tests ====================
 985  
 986      #[test]
 987      fn test_error_code_from_byte() {
 988          assert_eq!(ErrorCode::from_byte(0x00), Some(ErrorCode::Ok));
 989          assert_eq!(ErrorCode::from_byte(0x01), Some(ErrorCode::CryptoError));
 990          assert_eq!(ErrorCode::from_byte(0x02), Some(ErrorCode::ReplayDetected));
 991          assert_eq!(ErrorCode::from_byte(0x03), Some(ErrorCode::Expired));
 992          assert_eq!(ErrorCode::from_byte(0x04), Some(ErrorCode::UnknownSender));
 993          assert_eq!(ErrorCode::from_byte(0x05), Some(ErrorCode::StorageFull));
 994          assert_eq!(ErrorCode::from_byte(0x06), Some(ErrorCode::InvalidFormat));
 995          assert_eq!(ErrorCode::from_byte(0xFF), Some(ErrorCode::InternalError));
 996          assert_eq!(ErrorCode::from_byte(0x99), None);
 997      }
 998  
 999      #[test]
1000      fn test_error_code_round_trip() {
1001          for code in [
1002              ErrorCode::Ok,
1003              ErrorCode::CryptoError,
1004              ErrorCode::ReplayDetected,
1005              ErrorCode::Expired,
1006              ErrorCode::UnknownSender,
1007              ErrorCode::StorageFull,
1008              ErrorCode::InvalidFormat,
1009              ErrorCode::InternalError,
1010          ] {
1011              let byte = code.to_byte();
1012              let recovered = ErrorCode::from_byte(byte).unwrap();
1013              assert_eq!(recovered, code);
1014          }
1015      }
1016  
1017      #[test]
1018      fn test_error_code_from_dead_drop_error() {
1019          assert_eq!(
1020              ErrorCode::from(&DeadDropError::Decryption("test".to_string())),
1021              ErrorCode::CryptoError
1022          );
1023          assert_eq!(
1024              ErrorCode::from(&DeadDropError::InvalidSignature),
1025              ErrorCode::CryptoError
1026          );
1027          assert_eq!(
1028              ErrorCode::from(&DeadDropError::ReplayDetected),
1029              ErrorCode::ReplayDetected
1030          );
1031          assert_eq!(
1032              ErrorCode::from(&DeadDropError::MessageExpired),
1033              ErrorCode::Expired
1034          );
1035          assert_eq!(
1036              ErrorCode::from(&DeadDropError::UnknownContact),
1037              ErrorCode::UnknownSender
1038          );
1039      }
1040  
1041      // ==================== Serialization Tests ====================
1042  
1043      #[test]
1044      fn test_deserialize_empty() {
1045          let result = deserialize(&[]);
1046          assert!(result.is_err());
1047      }
1048  
1049      #[test]
1050      fn test_deserialize_unknown_type() {
1051          let bytes = [0xFF, 0x01, 0x02, 0x03];
1052          let result = deserialize(&bytes);
1053          assert!(result.is_err());
1054      }
1055  
1056      #[test]
1057      fn test_serialize_message_too_large() {
1058          let large_payload = vec![0u8; MAX_WIRE_MESSAGE_SIZE + 1];
1059          let msg = MsgData::new(0, 1, large_payload);
1060          let result = serialize(MessageType::MsgData, &msg);
1061          assert!(result.is_err());
1062      }
1063  
1064      #[test]
1065      fn test_deserialize_message_too_large() {
1066          let large_bytes = vec![0x02; MAX_WIRE_MESSAGE_SIZE + 1];
1067          let result = deserialize(&large_bytes);
1068          assert!(result.is_err());
1069      }
1070  
1071      // ==================== Integration Tests ====================
1072  
1073      #[test]
1074      fn test_complete_message_exchange_flow() {
1075          // Simulate a message exchange protocol flow
1076  
1077          // 1. Both peers send message counts
1078          let alice_count = MsgCount::new(2);
1079          let bob_count = MsgCount::new(1);
1080  
1081          let alice_count_bytes = serialize(MessageType::MsgCount, &alice_count).unwrap();
1082          let bob_count_bytes = serialize(MessageType::MsgCount, &bob_count).unwrap();
1083  
1084          // Verify message type detection
1085          let (t, _) = deserialize(&alice_count_bytes).unwrap();
1086          assert_eq!(t, MessageType::MsgCount);
1087  
1088          // 2. Exchange messages
1089          let alice_msg_0 = MsgData::new(0, 2, vec![1, 2, 3]);
1090          let alice_msg_1 = MsgData::new(1, 2, vec![4, 5, 6]);
1091          let bob_msg_0 = MsgData::new(0, 1, vec![7, 8, 9]);
1092  
1093          // 3. Send acknowledgements
1094          let ack_0 = MsgAck::success(0);
1095          let ack_1 = MsgAck::success(1);
1096  
1097          // 4. Session done
1098          let alice_done = SessionDone::new(2, 1);
1099          let bob_done = SessionDone::new(1, 2);
1100  
1101          // Verify all serialize correctly
1102          assert!(serialize(MessageType::MsgData, &alice_msg_0).is_ok());
1103          assert!(serialize(MessageType::MsgData, &alice_msg_1).is_ok());
1104          assert!(serialize(MessageType::MsgData, &bob_msg_0).is_ok());
1105          assert!(serialize(MessageType::MsgAck, &ack_0).is_ok());
1106          assert!(serialize(MessageType::MsgAck, &ack_1).is_ok());
1107          assert!(serialize(MessageType::SessionDone, &alice_done).is_ok());
1108          assert!(serialize(MessageType::SessionDone, &bob_done).is_ok());
1109      }
1110  
1111      #[test]
1112      fn test_ping_pong_flow() {
1113          // Alice sends ping
1114          let ping = Ping::new();
1115          let ping_bytes = serialize_ping(&ping).unwrap();
1116  
1117          // Bob receives and responds
1118          let (msg_type, payload) = deserialize(&ping_bytes).unwrap();
1119          assert_eq!(msg_type, MessageType::Ping);
1120  
1121          let received_ping: Ping = bincode::deserialize(payload).unwrap();
1122          let pong = Pong::from_ping(&received_ping);
1123          let pong_bytes = serialize_pong(&pong).unwrap();
1124  
1125          // Alice receives pong
1126          let (msg_type, _): (_, Pong) = deserialize_typed(&pong_bytes).unwrap();
1127          assert_eq!(msg_type, MessageType::Pong);
1128      }
1129  }