wire.rs
1 //! Wire format for Dead Drop protocol messages. 2 //! 3 //! This module defines the binary wire format used for communication between 4 //! devices over BLE. All messages are prefixed with a single byte indicating 5 //! the message type, followed by bincode-encoded payload data. 6 //! 7 //! # Message Types 8 //! 9 //! | Type | Code | Direction | Purpose | 10 //! |------|------|-----------|---------| 11 //! | [`MsgCount`] | `0x01` | Both | Announce number of messages to send | 12 //! | [`MsgData`] | `0x02` | Both | Encrypted message payload | 13 //! | [`MsgAck`] | `0x03` | Both | Acknowledge receipt of message | 14 //! | [`MsgNack`] | `0x04` | Both | Negative acknowledgement with error | 15 //! | [`SessionDone`] | `0x05` | Both | Signal exchange completion | 16 //! | [`Ping`] | `0x06` | Both | Keep-alive / connectivity check | 17 //! | [`Pong`] | `0x07` | Both | Response to ping | 18 //! 19 //! # Wire Format 20 //! 21 //! ```text 22 //! +--------+-------------------+ 23 //! | 1 byte | N bytes | 24 //! | type | bincode payload | 25 //! +--------+-------------------+ 26 //! ``` 27 //! 28 //! # Protocol Version 29 //! 30 //! The current protocol version is 1. Version negotiation happens during 31 //! the message count exchange phase. 32 //! 33 //! # Size Limits 34 //! 35 //! All wire messages are limited to [`MAX_WIRE_MESSAGE_SIZE`] bytes to prevent 36 //! memory exhaustion attacks and ensure compatibility with BLE MTU constraints. 37 //! 38 //! # Example 39 //! 40 //! ``` 41 //! use dead_drop_core::protocol::wire::{ 42 //! MessageType, MsgCount, serialize, deserialize, 43 //! }; 44 //! 45 //! // Create a message count announcement 46 //! let msg = MsgCount { 47 //! version: 1, 48 //! count: 5, 49 //! }; 50 //! 51 //! // Serialize for transmission 52 //! let bytes = serialize(MessageType::MsgCount, &msg).unwrap(); 53 //! 54 //! // Deserialize on receipt 55 //! let (msg_type, payload) = deserialize(&bytes).unwrap(); 56 //! assert_eq!(msg_type, MessageType::MsgCount); 57 //! ``` 58 59 use serde::{Deserialize, Serialize}; 60 61 use crate::error::{DeadDropError, Result}; 62 63 // ============================================================================= 64 // CONSTANTS 65 // ============================================================================= 66 67 /// Current protocol version. 68 /// 69 /// This version number is included in the initial message count exchange 70 /// to allow peers to detect incompatible protocol versions. 71 pub const PROTOCOL_VERSION: u8 = 1; 72 73 /// Maximum size of a wire message in bytes. 74 /// 75 /// This limit ensures messages fit within BLE constraints and prevents 76 /// memory exhaustion attacks. Set to 64KB which is generous for the 77 /// expected use case of text messages and small documents. 78 pub const MAX_WIRE_MESSAGE_SIZE: usize = 65536; 79 80 /// Maximum number of messages that can be exchanged in a single session. 81 /// 82 /// This prevents resource exhaustion from a peer claiming to have an 83 /// unreasonable number of messages. 84 pub const MAX_MESSAGE_COUNT: u16 = 1000; 85 86 // ============================================================================= 87 // MESSAGE TYPES 88 // ============================================================================= 89 90 /// Wire message type identifier. 91 /// 92 /// Each message on the wire begins with a single byte identifying its type. 93 /// This allows the receiver to determine how to parse the remaining payload. 94 /// 95 /// # Wire Encoding 96 /// 97 /// The discriminant values are explicitly defined to ensure stable wire format 98 /// across compiler versions and platforms. 99 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 100 #[repr(u8)] 101 pub enum MessageType { 102 /// Message count announcement. 103 /// 104 /// Sent at the start of an exchange to inform the peer how many 105 /// messages will be transmitted. 106 MsgCount = 0x01, 107 108 /// Encrypted message data. 109 /// 110 /// Contains an encrypted message payload along with index information 111 /// for ordering and acknowledgement. 112 MsgData = 0x02, 113 114 /// Positive acknowledgement. 115 /// 116 /// Confirms successful receipt and processing of a message. 117 MsgAck = 0x03, 118 119 /// Negative acknowledgement. 120 /// 121 /// Indicates a problem processing a received message, with an 122 /// error code explaining the failure. 123 MsgNack = 0x04, 124 125 /// Session complete signal. 126 /// 127 /// Sent when a peer has finished sending all messages and received 128 /// all expected acknowledgements. 129 SessionDone = 0x05, 130 131 /// Keep-alive ping. 132 /// 133 /// Used to verify the connection is still active during long 134 /// operations or idle periods. 135 Ping = 0x06, 136 137 /// Ping response. 138 /// 139 /// Sent in response to a Ping to confirm connectivity. 140 Pong = 0x07, 141 } 142 143 impl MessageType { 144 /// Convert a byte to a MessageType. 145 /// 146 /// # Errors 147 /// 148 /// Returns `InvalidFormat` if the byte does not correspond to a known 149 /// message type. 150 pub fn from_byte(byte: u8) -> Result<Self> { 151 match byte { 152 0x01 => Ok(MessageType::MsgCount), 153 0x02 => Ok(MessageType::MsgData), 154 0x03 => Ok(MessageType::MsgAck), 155 0x04 => Ok(MessageType::MsgNack), 156 0x05 => Ok(MessageType::SessionDone), 157 0x06 => Ok(MessageType::Ping), 158 0x07 => Ok(MessageType::Pong), 159 _ => Err(DeadDropError::InvalidFormat(format!( 160 "Unknown message type: 0x{:02X}", 161 byte 162 ))), 163 } 164 } 165 166 /// Convert MessageType to its wire byte representation. 167 pub fn to_byte(self) -> u8 { 168 self as u8 169 } 170 } 171 172 // ============================================================================= 173 // MESSAGE STRUCTURES 174 // ============================================================================= 175 176 /// Message count announcement. 177 /// 178 /// Sent at the beginning of a message exchange to inform the peer how many 179 /// messages will be transmitted. Both peers send this message after the 180 /// handshake completes. 181 /// 182 /// # Fields 183 /// 184 /// - `version`: Protocol version for compatibility checking 185 /// - `count`: Number of messages this peer will send (0 to [`MAX_MESSAGE_COUNT`]) 186 /// 187 /// # Example 188 /// 189 /// ``` 190 /// use dead_drop_core::protocol::wire::MsgCount; 191 /// 192 /// let count = MsgCount { 193 /// version: 1, 194 /// count: 3, 195 /// }; 196 /// ``` 197 #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] 198 pub struct MsgCount { 199 /// Protocol version (must match [`PROTOCOL_VERSION`]). 200 pub version: u8, 201 /// Number of messages to send (0 = no messages, just receiving). 202 pub count: u16, 203 } 204 205 impl MsgCount { 206 /// Create a new MsgCount with the current protocol version. 207 /// 208 /// # Arguments 209 /// 210 /// * `count` - Number of messages to announce 211 /// 212 /// # Example 213 /// 214 /// ``` 215 /// use dead_drop_core::protocol::wire::MsgCount; 216 /// 217 /// let msg = MsgCount::new(5); 218 /// assert_eq!(msg.version, 1); 219 /// assert_eq!(msg.count, 5); 220 /// ``` 221 pub fn new(count: u16) -> Self { 222 Self { 223 version: PROTOCOL_VERSION, 224 count, 225 } 226 } 227 228 /// Validate the message count announcement. 229 /// 230 /// # Errors 231 /// 232 /// - `InvalidFormat` if version doesn't match current protocol version 233 /// - `InvalidFormat` if count exceeds [`MAX_MESSAGE_COUNT`] 234 pub fn validate(&self) -> Result<()> { 235 if self.version != PROTOCOL_VERSION { 236 return Err(DeadDropError::InvalidFormat(format!( 237 "Protocol version mismatch: expected {}, got {}", 238 PROTOCOL_VERSION, self.version 239 ))); 240 } 241 if self.count > MAX_MESSAGE_COUNT { 242 return Err(DeadDropError::InvalidFormat(format!( 243 "Message count {} exceeds maximum {}", 244 self.count, MAX_MESSAGE_COUNT 245 ))); 246 } 247 Ok(()) 248 } 249 } 250 251 /// Encrypted message data payload. 252 /// 253 /// Contains an encrypted message along with metadata for ordering and 254 /// reassembly. Messages are sent in order by index. 255 /// 256 /// # Fields 257 /// 258 /// - `index`: Zero-based position of this message (0 to total-1) 259 /// - `total`: Total number of messages being sent 260 /// - `payload`: The encrypted message bytes 261 /// 262 /// # Size Limit 263 /// 264 /// The payload must not exceed `MAX_WIRE_MESSAGE_SIZE - overhead` bytes. 265 /// In practice, encrypted messages should be under 32KB. 266 #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] 267 pub struct MsgData { 268 /// Zero-based index of this message in the sequence. 269 pub index: u8, 270 /// Total number of messages in this exchange. 271 pub total: u8, 272 /// Encrypted message payload. 273 pub payload: Vec<u8>, 274 } 275 276 impl MsgData { 277 /// Create a new MsgData. 278 /// 279 /// # Arguments 280 /// 281 /// * `index` - Position of this message (0-based) 282 /// * `total` - Total messages in exchange 283 /// * `payload` - Encrypted message bytes 284 pub fn new(index: u8, total: u8, payload: Vec<u8>) -> Self { 285 Self { 286 index, 287 total, 288 payload, 289 } 290 } 291 292 /// Validate the message data. 293 /// 294 /// # Errors 295 /// 296 /// - `InvalidFormat` if index >= total 297 /// - `InvalidFormat` if total is 0 298 /// - `MessageTooLarge` if payload exceeds size limit 299 pub fn validate(&self) -> Result<()> { 300 if self.total == 0 { 301 return Err(DeadDropError::InvalidFormat( 302 "Message total cannot be zero".to_string(), 303 )); 304 } 305 if self.index >= self.total { 306 return Err(DeadDropError::InvalidFormat(format!( 307 "Message index {} out of range (total: {})", 308 self.index, self.total 309 ))); 310 } 311 // Reserve some bytes for wire overhead 312 let max_payload = MAX_WIRE_MESSAGE_SIZE - 16; 313 if self.payload.len() > max_payload { 314 return Err(DeadDropError::MessageTooLarge { 315 size: self.payload.len(), 316 max: max_payload, 317 }); 318 } 319 Ok(()) 320 } 321 } 322 323 /// Positive acknowledgement for a received message. 324 /// 325 /// Sent after successfully receiving and processing a message to confirm 326 /// delivery. The sender can mark the message as delivered upon receiving 327 /// this acknowledgement. 328 /// 329 /// # Fields 330 /// 331 /// - `index`: The index of the message being acknowledged 332 /// - `status`: Status code (currently always 0 for success) 333 #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] 334 pub struct MsgAck { 335 /// Index of the acknowledged message. 336 pub index: u8, 337 /// Status code (0 = success). 338 pub status: u8, 339 } 340 341 impl MsgAck { 342 /// Create a successful acknowledgement. 343 /// 344 /// # Arguments 345 /// 346 /// * `index` - Index of the message being acknowledged 347 pub fn success(index: u8) -> Self { 348 Self { index, status: 0 } 349 } 350 } 351 352 /// Negative acknowledgement with error information. 353 /// 354 /// Sent when a message could not be processed successfully. The error 355 /// code indicates the reason for failure. 356 /// 357 /// # Security Note 358 /// 359 /// Error codes are intentionally limited to avoid leaking information 360 /// that could aid attackers. For example, we don't distinguish between 361 /// "wrong key" and "tampered message". 362 #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] 363 pub struct MsgNack { 364 /// Index of the rejected message. 365 pub index: u8, 366 /// Error code indicating the failure reason. 367 pub error_code: u8, 368 } 369 370 impl MsgNack { 371 /// Create a negative acknowledgement. 372 /// 373 /// # Arguments 374 /// 375 /// * `index` - Index of the rejected message 376 /// * `error_code` - Reason for rejection (see [`ErrorCode`]) 377 pub fn new(index: u8, error_code: ErrorCode) -> Self { 378 Self { 379 index, 380 error_code: error_code as u8, 381 } 382 } 383 384 /// Get the error code as an enum. 385 /// 386 /// Returns `None` if the error code is not recognized. 387 pub fn error(&self) -> Option<ErrorCode> { 388 ErrorCode::from_byte(self.error_code) 389 } 390 } 391 392 /// Session completion signal. 393 /// 394 /// Sent when a peer has finished sending all messages and received 395 /// acknowledgements for all of them (or given up on unacknowledged ones). 396 /// 397 /// # Fields 398 /// 399 /// - `sent_count`: Number of messages successfully sent 400 /// - `received_count`: Number of messages successfully received 401 #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] 402 pub struct SessionDone { 403 /// Number of messages sent by this peer. 404 pub sent_count: u8, 405 /// Number of messages received by this peer. 406 pub received_count: u8, 407 } 408 409 impl SessionDone { 410 /// Create a session done message. 411 pub fn new(sent_count: u8, received_count: u8) -> Self { 412 Self { 413 sent_count, 414 received_count, 415 } 416 } 417 } 418 419 /// Keep-alive ping message. 420 /// 421 /// Contains a random nonce that must be echoed in the Pong response. 422 /// This allows detection of connection issues and prevents replay attacks. 423 #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] 424 pub struct Ping { 425 /// Random nonce to be echoed in Pong. 426 pub nonce: u32, 427 } 428 429 impl Ping { 430 /// Create a new ping with a random nonce. 431 pub fn new() -> Self { 432 use rand::RngCore; 433 Self { 434 nonce: rand::rngs::OsRng.next_u32(), 435 } 436 } 437 438 /// Create a ping with a specific nonce (for testing). 439 pub fn with_nonce(nonce: u32) -> Self { 440 Self { nonce } 441 } 442 } 443 444 impl Default for Ping { 445 fn default() -> Self { 446 Self::new() 447 } 448 } 449 450 /// Keep-alive pong response. 451 /// 452 /// Echoes the nonce from the corresponding Ping message. 453 #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] 454 pub struct Pong { 455 /// Echoed nonce from the Ping. 456 pub nonce: u32, 457 } 458 459 impl Pong { 460 /// Create a pong response to a ping. 461 pub fn from_ping(ping: &Ping) -> Self { 462 Self { nonce: ping.nonce } 463 } 464 465 /// Create a pong with a specific nonce. 466 pub fn new(nonce: u32) -> Self { 467 Self { nonce } 468 } 469 } 470 471 // ============================================================================= 472 // ERROR CODES 473 // ============================================================================= 474 475 /// Error codes for negative acknowledgements. 476 /// 477 /// These codes indicate why a message was rejected. They are intentionally 478 /// limited to prevent information leakage. 479 /// 480 /// # Security Considerations 481 /// 482 /// - Cryptographic errors are grouped to prevent oracle attacks 483 /// - Internal errors are not exposed to the peer 484 /// - Error codes should be logged but details should not be sent 485 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 486 #[repr(u8)] 487 pub enum ErrorCode { 488 /// No error (should not appear in NACK). 489 Ok = 0x00, 490 491 /// Cryptographic verification failed. 492 /// 493 /// This covers signature failures, decryption failures, and 494 /// authentication tag mismatches. We don't distinguish between 495 /// them to prevent oracle attacks. 496 CryptoError = 0x01, 497 498 /// Message ID was already seen (replay detected). 499 ReplayDetected = 0x02, 500 501 /// Message timestamp is too old. 502 Expired = 0x03, 503 504 /// Sender is not in contact list. 505 UnknownSender = 0x04, 506 507 /// Storage is full, cannot accept more messages. 508 StorageFull = 0x05, 509 510 /// Message format is invalid. 511 InvalidFormat = 0x06, 512 513 /// Internal error (generic, for unexpected failures). 514 InternalError = 0xFF, 515 } 516 517 impl ErrorCode { 518 /// Convert a byte to an ErrorCode. 519 /// 520 /// Returns `None` for unrecognized codes. 521 pub fn from_byte(byte: u8) -> Option<Self> { 522 match byte { 523 0x00 => Some(ErrorCode::Ok), 524 0x01 => Some(ErrorCode::CryptoError), 525 0x02 => Some(ErrorCode::ReplayDetected), 526 0x03 => Some(ErrorCode::Expired), 527 0x04 => Some(ErrorCode::UnknownSender), 528 0x05 => Some(ErrorCode::StorageFull), 529 0x06 => Some(ErrorCode::InvalidFormat), 530 0xFF => Some(ErrorCode::InternalError), 531 _ => None, 532 } 533 } 534 535 /// Convert ErrorCode to its wire byte representation. 536 pub fn to_byte(self) -> u8 { 537 self as u8 538 } 539 } 540 541 impl From<&DeadDropError> for ErrorCode { 542 /// Convert a DeadDropError to the appropriate ErrorCode for NACK. 543 /// 544 /// This mapping intentionally groups errors to prevent information leakage. 545 fn from(err: &DeadDropError) -> Self { 546 match err { 547 // Crypto errors are grouped 548 DeadDropError::Decryption(_) 549 | DeadDropError::InvalidSignature 550 | DeadDropError::InvalidKey(_) => ErrorCode::CryptoError, 551 552 DeadDropError::ReplayDetected => ErrorCode::ReplayDetected, 553 DeadDropError::MessageExpired => ErrorCode::Expired, 554 DeadDropError::UnknownContact => ErrorCode::UnknownSender, 555 DeadDropError::InvalidFormat(_) | DeadDropError::Deserialization(_) => { 556 ErrorCode::InvalidFormat 557 } 558 559 // Everything else maps to internal error 560 _ => ErrorCode::InternalError, 561 } 562 } 563 } 564 565 // ============================================================================= 566 // SERIALIZATION 567 // ============================================================================= 568 569 /// Serialize a message for wire transmission. 570 /// 571 /// Produces a byte vector with the message type prefix followed by the 572 /// bincode-encoded payload. 573 /// 574 /// # Arguments 575 /// 576 /// * `msg_type` - The type of message being serialized 577 /// * `msg` - The message payload to serialize 578 /// 579 /// # Returns 580 /// 581 /// A byte vector ready for transmission. 582 /// 583 /// # Errors 584 /// 585 /// Returns `Serialization` if bincode encoding fails. 586 /// 587 /// # Example 588 /// 589 /// ``` 590 /// use dead_drop_core::protocol::wire::{MessageType, MsgCount, serialize}; 591 /// 592 /// let msg = MsgCount::new(3); 593 /// let bytes = serialize(MessageType::MsgCount, &msg).unwrap(); 594 /// 595 /// // First byte is message type 596 /// assert_eq!(bytes[0], 0x01); 597 /// ``` 598 pub fn serialize<T: Serialize>(msg_type: MessageType, msg: &T) -> Result<Vec<u8>> { 599 let payload = bincode::serialize(msg)?; 600 601 // Check total size 602 let total_size = 1 + payload.len(); 603 if total_size > MAX_WIRE_MESSAGE_SIZE { 604 return Err(DeadDropError::MessageTooLarge { 605 size: total_size, 606 max: MAX_WIRE_MESSAGE_SIZE, 607 }); 608 } 609 610 let mut result = Vec::with_capacity(total_size); 611 result.push(msg_type.to_byte()); 612 result.extend(payload); 613 614 Ok(result) 615 } 616 617 /// Deserialize a wire message. 618 /// 619 /// Parses the message type prefix and returns it along with the raw 620 /// payload bytes for type-specific deserialization. 621 /// 622 /// # Arguments 623 /// 624 /// * `bytes` - The raw wire message bytes 625 /// 626 /// # Returns 627 /// 628 /// A tuple of (MessageType, payload_bytes). 629 /// 630 /// # Errors 631 /// 632 /// - `InvalidFormat` if the message is empty or has unknown type 633 /// 634 /// # Example 635 /// 636 /// ``` 637 /// use dead_drop_core::protocol::wire::{MessageType, MsgCount, serialize, deserialize}; 638 /// 639 /// let msg = MsgCount::new(3); 640 /// let bytes = serialize(MessageType::MsgCount, &msg).unwrap(); 641 /// 642 /// let (msg_type, payload) = deserialize(&bytes).unwrap(); 643 /// assert_eq!(msg_type, MessageType::MsgCount); 644 /// ``` 645 pub fn deserialize(bytes: &[u8]) -> Result<(MessageType, &[u8])> { 646 if bytes.is_empty() { 647 return Err(DeadDropError::InvalidFormat( 648 "Empty message".to_string(), 649 )); 650 } 651 652 if bytes.len() > MAX_WIRE_MESSAGE_SIZE { 653 return Err(DeadDropError::MessageTooLarge { 654 size: bytes.len(), 655 max: MAX_WIRE_MESSAGE_SIZE, 656 }); 657 } 658 659 let msg_type = MessageType::from_byte(bytes[0])?; 660 let payload = &bytes[1..]; 661 662 Ok((msg_type, payload)) 663 } 664 665 /// Deserialize a typed message from wire bytes. 666 /// 667 /// Convenience function that combines [`deserialize`] with bincode decoding 668 /// of the payload. 669 /// 670 /// # Type Parameters 671 /// 672 /// * `T` - The expected message type (must match the wire message type) 673 /// 674 /// # Arguments 675 /// 676 /// * `bytes` - The raw wire message bytes 677 /// 678 /// # Returns 679 /// 680 /// A tuple of (MessageType, decoded_message). 681 /// 682 /// # Errors 683 /// 684 /// - `InvalidFormat` for type/format errors 685 /// - `Deserialization` if bincode decoding fails 686 /// 687 /// # Example 688 /// 689 /// ``` 690 /// use dead_drop_core::protocol::wire::{MessageType, MsgCount, serialize, deserialize_typed}; 691 /// 692 /// let original = MsgCount::new(5); 693 /// let bytes = serialize(MessageType::MsgCount, &original).unwrap(); 694 /// 695 /// let (msg_type, decoded): (MessageType, MsgCount) = deserialize_typed(&bytes).unwrap(); 696 /// assert_eq!(decoded.count, 5); 697 /// ``` 698 pub fn deserialize_typed<'a, T: Deserialize<'a>>(bytes: &'a [u8]) -> Result<(MessageType, T)> { 699 let (msg_type, payload) = deserialize(bytes)?; 700 let msg: T = bincode::deserialize(payload).map_err(|e| { 701 DeadDropError::Deserialization(format!("Failed to decode payload: {}", e)) 702 })?; 703 Ok((msg_type, msg)) 704 } 705 706 /// Serialize a Ping message. 707 pub fn serialize_ping(ping: &Ping) -> Result<Vec<u8>> { 708 serialize(MessageType::Ping, ping) 709 } 710 711 /// Serialize a Pong message. 712 pub fn serialize_pong(pong: &Pong) -> Result<Vec<u8>> { 713 serialize(MessageType::Pong, pong) 714 } 715 716 /// Serialize a SessionDone message. 717 pub fn serialize_session_done(done: &SessionDone) -> Result<Vec<u8>> { 718 serialize(MessageType::SessionDone, done) 719 } 720 721 #[cfg(test)] 722 mod tests { 723 use super::*; 724 725 // ==================== MessageType Tests ==================== 726 727 #[test] 728 fn test_message_type_from_byte_valid() { 729 assert_eq!(MessageType::from_byte(0x01).unwrap(), MessageType::MsgCount); 730 assert_eq!(MessageType::from_byte(0x02).unwrap(), MessageType::MsgData); 731 assert_eq!(MessageType::from_byte(0x03).unwrap(), MessageType::MsgAck); 732 assert_eq!(MessageType::from_byte(0x04).unwrap(), MessageType::MsgNack); 733 assert_eq!( 734 MessageType::from_byte(0x05).unwrap(), 735 MessageType::SessionDone 736 ); 737 assert_eq!(MessageType::from_byte(0x06).unwrap(), MessageType::Ping); 738 assert_eq!(MessageType::from_byte(0x07).unwrap(), MessageType::Pong); 739 } 740 741 #[test] 742 fn test_message_type_from_byte_invalid() { 743 assert!(MessageType::from_byte(0x00).is_err()); 744 assert!(MessageType::from_byte(0x08).is_err()); 745 assert!(MessageType::from_byte(0xFF).is_err()); 746 } 747 748 #[test] 749 fn test_message_type_round_trip() { 750 for msg_type in [ 751 MessageType::MsgCount, 752 MessageType::MsgData, 753 MessageType::MsgAck, 754 MessageType::MsgNack, 755 MessageType::SessionDone, 756 MessageType::Ping, 757 MessageType::Pong, 758 ] { 759 let byte = msg_type.to_byte(); 760 let recovered = MessageType::from_byte(byte).unwrap(); 761 assert_eq!(recovered, msg_type); 762 } 763 } 764 765 // ==================== MsgCount Tests ==================== 766 767 #[test] 768 fn test_msg_count_new() { 769 let msg = MsgCount::new(10); 770 assert_eq!(msg.version, PROTOCOL_VERSION); 771 assert_eq!(msg.count, 10); 772 } 773 774 #[test] 775 fn test_msg_count_validate_success() { 776 let msg = MsgCount::new(100); 777 assert!(msg.validate().is_ok()); 778 } 779 780 #[test] 781 fn test_msg_count_validate_wrong_version() { 782 let msg = MsgCount { 783 version: 99, 784 count: 5, 785 }; 786 assert!(msg.validate().is_err()); 787 } 788 789 #[test] 790 fn test_msg_count_validate_count_too_large() { 791 let msg = MsgCount { 792 version: PROTOCOL_VERSION, 793 count: MAX_MESSAGE_COUNT + 1, 794 }; 795 assert!(msg.validate().is_err()); 796 } 797 798 #[test] 799 fn test_msg_count_serialize_deserialize() { 800 let original = MsgCount::new(42); 801 let bytes = serialize(MessageType::MsgCount, &original).unwrap(); 802 803 let (msg_type, decoded): (_, MsgCount) = deserialize_typed(&bytes).unwrap(); 804 805 assert_eq!(msg_type, MessageType::MsgCount); 806 assert_eq!(decoded.version, original.version); 807 assert_eq!(decoded.count, original.count); 808 } 809 810 // ==================== MsgData Tests ==================== 811 812 #[test] 813 fn test_msg_data_new() { 814 let payload = vec![1, 2, 3, 4, 5]; 815 let msg = MsgData::new(0, 3, payload.clone()); 816 817 assert_eq!(msg.index, 0); 818 assert_eq!(msg.total, 3); 819 assert_eq!(msg.payload, payload); 820 } 821 822 #[test] 823 fn test_msg_data_validate_success() { 824 let msg = MsgData::new(0, 1, vec![1, 2, 3]); 825 assert!(msg.validate().is_ok()); 826 827 let msg = MsgData::new(2, 5, vec![1, 2, 3]); 828 assert!(msg.validate().is_ok()); 829 } 830 831 #[test] 832 fn test_msg_data_validate_zero_total() { 833 let msg = MsgData::new(0, 0, vec![1, 2, 3]); 834 assert!(msg.validate().is_err()); 835 } 836 837 #[test] 838 fn test_msg_data_validate_index_out_of_range() { 839 let msg = MsgData::new(5, 3, vec![1, 2, 3]); 840 assert!(msg.validate().is_err()); 841 842 let msg = MsgData::new(3, 3, vec![1, 2, 3]); 843 assert!(msg.validate().is_err()); 844 } 845 846 #[test] 847 fn test_msg_data_validate_payload_too_large() { 848 let large_payload = vec![0u8; MAX_WIRE_MESSAGE_SIZE]; 849 let msg = MsgData::new(0, 1, large_payload); 850 assert!(msg.validate().is_err()); 851 } 852 853 #[test] 854 fn test_msg_data_serialize_deserialize() { 855 let payload = vec![0xDE, 0xAD, 0xBE, 0xEF]; 856 let original = MsgData::new(2, 5, payload); 857 let bytes = serialize(MessageType::MsgData, &original).unwrap(); 858 859 let (msg_type, decoded): (_, MsgData) = deserialize_typed(&bytes).unwrap(); 860 861 assert_eq!(msg_type, MessageType::MsgData); 862 assert_eq!(decoded.index, original.index); 863 assert_eq!(decoded.total, original.total); 864 assert_eq!(decoded.payload, original.payload); 865 } 866 867 // ==================== MsgAck Tests ==================== 868 869 #[test] 870 fn test_msg_ack_success() { 871 let ack = MsgAck::success(3); 872 assert_eq!(ack.index, 3); 873 assert_eq!(ack.status, 0); 874 } 875 876 #[test] 877 fn test_msg_ack_serialize_deserialize() { 878 let original = MsgAck::success(7); 879 let bytes = serialize(MessageType::MsgAck, &original).unwrap(); 880 881 let (msg_type, decoded): (_, MsgAck) = deserialize_typed(&bytes).unwrap(); 882 883 assert_eq!(msg_type, MessageType::MsgAck); 884 assert_eq!(decoded.index, original.index); 885 assert_eq!(decoded.status, original.status); 886 } 887 888 // ==================== MsgNack Tests ==================== 889 890 #[test] 891 fn test_msg_nack_new() { 892 let nack = MsgNack::new(5, ErrorCode::ReplayDetected); 893 assert_eq!(nack.index, 5); 894 assert_eq!(nack.error_code, ErrorCode::ReplayDetected as u8); 895 } 896 897 #[test] 898 fn test_msg_nack_error() { 899 let nack = MsgNack::new(0, ErrorCode::CryptoError); 900 assert_eq!(nack.error(), Some(ErrorCode::CryptoError)); 901 902 let nack = MsgNack { 903 index: 0, 904 error_code: 0x99, // Unknown 905 }; 906 assert_eq!(nack.error(), None); 907 } 908 909 #[test] 910 fn test_msg_nack_serialize_deserialize() { 911 let original = MsgNack::new(2, ErrorCode::Expired); 912 let bytes = serialize(MessageType::MsgNack, &original).unwrap(); 913 914 let (msg_type, decoded): (_, MsgNack) = deserialize_typed(&bytes).unwrap(); 915 916 assert_eq!(msg_type, MessageType::MsgNack); 917 assert_eq!(decoded.index, original.index); 918 assert_eq!(decoded.error_code, original.error_code); 919 } 920 921 // ==================== SessionDone Tests ==================== 922 923 #[test] 924 fn test_session_done_new() { 925 let done = SessionDone::new(5, 3); 926 assert_eq!(done.sent_count, 5); 927 assert_eq!(done.received_count, 3); 928 } 929 930 #[test] 931 fn test_session_done_serialize_deserialize() { 932 let original = SessionDone::new(10, 8); 933 let bytes = serialize(MessageType::SessionDone, &original).unwrap(); 934 935 let (msg_type, decoded): (_, SessionDone) = deserialize_typed(&bytes).unwrap(); 936 937 assert_eq!(msg_type, MessageType::SessionDone); 938 assert_eq!(decoded.sent_count, original.sent_count); 939 assert_eq!(decoded.received_count, original.received_count); 940 } 941 942 // ==================== Ping/Pong Tests ==================== 943 944 #[test] 945 fn test_ping_new() { 946 let ping1 = Ping::new(); 947 let ping2 = Ping::new(); 948 // Random nonces should be different (extremely high probability) 949 assert_ne!(ping1.nonce, ping2.nonce); 950 } 951 952 #[test] 953 fn test_ping_with_nonce() { 954 let ping = Ping::with_nonce(12345); 955 assert_eq!(ping.nonce, 12345); 956 } 957 958 #[test] 959 fn test_pong_from_ping() { 960 let ping = Ping::with_nonce(0xDEADBEEF); 961 let pong = Pong::from_ping(&ping); 962 assert_eq!(pong.nonce, ping.nonce); 963 } 964 965 #[test] 966 fn test_ping_pong_serialize_deserialize() { 967 let ping = Ping::with_nonce(42); 968 let bytes = serialize_ping(&ping).unwrap(); 969 970 let (msg_type, decoded): (_, Ping) = deserialize_typed(&bytes).unwrap(); 971 972 assert_eq!(msg_type, MessageType::Ping); 973 assert_eq!(decoded.nonce, ping.nonce); 974 975 let pong = Pong::from_ping(&ping); 976 let bytes = serialize_pong(&pong).unwrap(); 977 978 let (msg_type, decoded): (_, Pong) = deserialize_typed(&bytes).unwrap(); 979 980 assert_eq!(msg_type, MessageType::Pong); 981 assert_eq!(decoded.nonce, pong.nonce); 982 } 983 984 // ==================== ErrorCode Tests ==================== 985 986 #[test] 987 fn test_error_code_from_byte() { 988 assert_eq!(ErrorCode::from_byte(0x00), Some(ErrorCode::Ok)); 989 assert_eq!(ErrorCode::from_byte(0x01), Some(ErrorCode::CryptoError)); 990 assert_eq!(ErrorCode::from_byte(0x02), Some(ErrorCode::ReplayDetected)); 991 assert_eq!(ErrorCode::from_byte(0x03), Some(ErrorCode::Expired)); 992 assert_eq!(ErrorCode::from_byte(0x04), Some(ErrorCode::UnknownSender)); 993 assert_eq!(ErrorCode::from_byte(0x05), Some(ErrorCode::StorageFull)); 994 assert_eq!(ErrorCode::from_byte(0x06), Some(ErrorCode::InvalidFormat)); 995 assert_eq!(ErrorCode::from_byte(0xFF), Some(ErrorCode::InternalError)); 996 assert_eq!(ErrorCode::from_byte(0x99), None); 997 } 998 999 #[test] 1000 fn test_error_code_round_trip() { 1001 for code in [ 1002 ErrorCode::Ok, 1003 ErrorCode::CryptoError, 1004 ErrorCode::ReplayDetected, 1005 ErrorCode::Expired, 1006 ErrorCode::UnknownSender, 1007 ErrorCode::StorageFull, 1008 ErrorCode::InvalidFormat, 1009 ErrorCode::InternalError, 1010 ] { 1011 let byte = code.to_byte(); 1012 let recovered = ErrorCode::from_byte(byte).unwrap(); 1013 assert_eq!(recovered, code); 1014 } 1015 } 1016 1017 #[test] 1018 fn test_error_code_from_dead_drop_error() { 1019 assert_eq!( 1020 ErrorCode::from(&DeadDropError::Decryption("test".to_string())), 1021 ErrorCode::CryptoError 1022 ); 1023 assert_eq!( 1024 ErrorCode::from(&DeadDropError::InvalidSignature), 1025 ErrorCode::CryptoError 1026 ); 1027 assert_eq!( 1028 ErrorCode::from(&DeadDropError::ReplayDetected), 1029 ErrorCode::ReplayDetected 1030 ); 1031 assert_eq!( 1032 ErrorCode::from(&DeadDropError::MessageExpired), 1033 ErrorCode::Expired 1034 ); 1035 assert_eq!( 1036 ErrorCode::from(&DeadDropError::UnknownContact), 1037 ErrorCode::UnknownSender 1038 ); 1039 } 1040 1041 // ==================== Serialization Tests ==================== 1042 1043 #[test] 1044 fn test_deserialize_empty() { 1045 let result = deserialize(&[]); 1046 assert!(result.is_err()); 1047 } 1048 1049 #[test] 1050 fn test_deserialize_unknown_type() { 1051 let bytes = [0xFF, 0x01, 0x02, 0x03]; 1052 let result = deserialize(&bytes); 1053 assert!(result.is_err()); 1054 } 1055 1056 #[test] 1057 fn test_serialize_message_too_large() { 1058 let large_payload = vec![0u8; MAX_WIRE_MESSAGE_SIZE + 1]; 1059 let msg = MsgData::new(0, 1, large_payload); 1060 let result = serialize(MessageType::MsgData, &msg); 1061 assert!(result.is_err()); 1062 } 1063 1064 #[test] 1065 fn test_deserialize_message_too_large() { 1066 let large_bytes = vec![0x02; MAX_WIRE_MESSAGE_SIZE + 1]; 1067 let result = deserialize(&large_bytes); 1068 assert!(result.is_err()); 1069 } 1070 1071 // ==================== Integration Tests ==================== 1072 1073 #[test] 1074 fn test_complete_message_exchange_flow() { 1075 // Simulate a message exchange protocol flow 1076 1077 // 1. Both peers send message counts 1078 let alice_count = MsgCount::new(2); 1079 let bob_count = MsgCount::new(1); 1080 1081 let alice_count_bytes = serialize(MessageType::MsgCount, &alice_count).unwrap(); 1082 let bob_count_bytes = serialize(MessageType::MsgCount, &bob_count).unwrap(); 1083 1084 // Verify message type detection 1085 let (t, _) = deserialize(&alice_count_bytes).unwrap(); 1086 assert_eq!(t, MessageType::MsgCount); 1087 1088 // 2. Exchange messages 1089 let alice_msg_0 = MsgData::new(0, 2, vec![1, 2, 3]); 1090 let alice_msg_1 = MsgData::new(1, 2, vec![4, 5, 6]); 1091 let bob_msg_0 = MsgData::new(0, 1, vec![7, 8, 9]); 1092 1093 // 3. Send acknowledgements 1094 let ack_0 = MsgAck::success(0); 1095 let ack_1 = MsgAck::success(1); 1096 1097 // 4. Session done 1098 let alice_done = SessionDone::new(2, 1); 1099 let bob_done = SessionDone::new(1, 2); 1100 1101 // Verify all serialize correctly 1102 assert!(serialize(MessageType::MsgData, &alice_msg_0).is_ok()); 1103 assert!(serialize(MessageType::MsgData, &alice_msg_1).is_ok()); 1104 assert!(serialize(MessageType::MsgData, &bob_msg_0).is_ok()); 1105 assert!(serialize(MessageType::MsgAck, &ack_0).is_ok()); 1106 assert!(serialize(MessageType::MsgAck, &ack_1).is_ok()); 1107 assert!(serialize(MessageType::SessionDone, &alice_done).is_ok()); 1108 assert!(serialize(MessageType::SessionDone, &bob_done).is_ok()); 1109 } 1110 1111 #[test] 1112 fn test_ping_pong_flow() { 1113 // Alice sends ping 1114 let ping = Ping::new(); 1115 let ping_bytes = serialize_ping(&ping).unwrap(); 1116 1117 // Bob receives and responds 1118 let (msg_type, payload) = deserialize(&ping_bytes).unwrap(); 1119 assert_eq!(msg_type, MessageType::Ping); 1120 1121 let received_ping: Ping = bincode::deserialize(payload).unwrap(); 1122 let pong = Pong::from_ping(&received_ping); 1123 let pong_bytes = serialize_pong(&pong).unwrap(); 1124 1125 // Alice receives pong 1126 let (msg_type, _): (_, Pong) = deserialize_typed(&pong_bytes).unwrap(); 1127 assert_eq!(msg_type, MessageType::Pong); 1128 } 1129 }