exchange.rs
1 //! Message exchange state machine for Dead Drop protocol. 2 //! 3 //! This module implements the bidirectional message exchange that occurs 4 //! after a successful handshake. Both peers exchange message counts, then 5 //! transmit their queued messages and acknowledgements. 6 //! 7 //! # Exchange Flow 8 //! 9 //! ```text 10 //! Initiator Responder 11 //! | | 12 //! |────── MSG_COUNT(n=2) ───────────────>| (1) Exchange counts 13 //! |<───── MSG_COUNT(n=1) ────────────────| 14 //! | | 15 //! |────── MSG_DATA(0/2) ────────────────>| (2) Send messages 16 //! |────── MSG_DATA(1/2) ────────────────>| 17 //! |<───── MSG_DATA(0/1) ─────────────────| 18 //! | | 19 //! |<───── MSG_ACK(0) ────────────────────| (3) Acknowledgements 20 //! |<───── MSG_ACK(1) ────────────────────| 21 //! |────── MSG_ACK(0) ───────────────────>| 22 //! | | 23 //! |────── SESSION_DONE ─────────────────>| (4) Completion 24 //! |<───── SESSION_DONE ──────────────────| 25 //! | | 26 //! ``` 27 //! 28 //! # State Machine 29 //! 30 //! | State | Description | Next States | 31 //! |-------|-------------|-------------| 32 //! | `AwaitingCounts` | Waiting to exchange message counts | `Exchanging` | 33 //! | `Exchanging` | Sending/receiving messages and acks | `Finalizing` | 34 //! | `Finalizing` | Waiting for final acks and done signals | `Complete` | 35 //! | `Complete` | Exchange finished successfully | - | 36 //! | `Failed` | Exchange failed | - | 37 //! 38 //! # Reliability 39 //! 40 //! The exchange protocol handles: 41 //! - Message ordering (via index/total fields) 42 //! - Delivery confirmation (via ACK/NACK) 43 //! - Partial failures (NACK with error code) 44 //! - Session termination (SESSION_DONE from both sides) 45 //! 46 //! # Example 47 //! 48 //! ```ignore 49 //! use dead_drop_core::protocol::exchange::Exchange; 50 //! use dead_drop_core::crypto::noise::NoiseTransport; 51 //! 52 //! // After handshake completes... 53 //! let mut exchange = Exchange::new(transport, contact_id, outbound_messages); 54 //! 55 //! // Drive the exchange 56 //! loop { 57 //! // Send any pending data 58 //! if let Some(data) = exchange.get_next_to_send()? { 59 //! ble_send(&data); 60 //! } 61 //! 62 //! // Process incoming data 63 //! if let Some(incoming) = ble_receive() { 64 //! exchange.process_received(&incoming)?; 65 //! } 66 //! 67 //! if exchange.is_complete() { 68 //! break; 69 //! } 70 //! } 71 //! 72 //! // Get results 73 //! let result = exchange.finalize()?; 74 //! ``` 75 76 use crate::crypto::noise::NoiseTransport; 77 use crate::error::{DeadDropError, Result}; 78 use crate::protocol::messages::{ContactId, EncryptedMessage, MessageId}; 79 use crate::protocol::wire::{ 80 deserialize, serialize, ErrorCode, MessageType, MsgAck, MsgCount, MsgData, 81 MsgNack, SessionDone, 82 }; 83 84 // ============================================================================= 85 // CONSTANTS 86 // ============================================================================= 87 88 /// Maximum number of retries for sending a message. 89 pub const MAX_SEND_RETRIES: u8 = 3; 90 91 /// Timeout for waiting for acknowledgement (in exchange cycles). 92 pub const ACK_TIMEOUT_CYCLES: u8 = 10; 93 94 // ============================================================================= 95 // EXCHANGE STATE 96 // ============================================================================= 97 98 /// State of the message exchange. 99 #[derive(Debug, Clone, PartialEq, Eq)] 100 pub enum ExchangeState { 101 /// Waiting to send our message count. 102 SendingCount, 103 104 /// Waiting to receive peer's message count. 105 AwaitingCount, 106 107 /// Both counts exchanged, now exchanging messages. 108 Exchanging { 109 /// Number of messages we're sending. 110 our_total: u8, 111 /// Number of messages we're receiving. 112 their_total: u8, 113 /// Next message index to send. 114 send_index: u8, 115 /// Next message index we expect to receive. 116 receive_index: u8, 117 }, 118 119 /// All messages sent/received, waiting for final acks. 120 Finalizing { 121 /// Number of acks we're still waiting for. 122 pending_acks: u8, 123 /// Whether we've sent SESSION_DONE. 124 sent_done: bool, 125 /// Whether we've received SESSION_DONE. 126 received_done: bool, 127 }, 128 129 /// Exchange completed successfully. 130 Complete, 131 132 /// Exchange failed. 133 Failed(String), 134 } 135 136 impl ExchangeState { 137 /// Check if the exchange is complete. 138 pub fn is_complete(&self) -> bool { 139 matches!(self, ExchangeState::Complete) 140 } 141 142 /// Check if the exchange has failed. 143 pub fn is_failed(&self) -> bool { 144 matches!(self, ExchangeState::Failed(_)) 145 } 146 147 /// Check if the exchange is still in progress. 148 pub fn is_in_progress(&self) -> bool { 149 !self.is_complete() && !self.is_failed() 150 } 151 } 152 153 // ============================================================================= 154 // MESSAGE TRACKING 155 // ============================================================================= 156 157 /// Tracking information for an outbound message. 158 #[derive(Debug, Clone)] 159 struct OutboundTracker { 160 /// The encrypted message data. 161 message: EncryptedMessage, 162 /// Whether this message has been sent. 163 sent: bool, 164 /// Whether we received an ACK for this message. 165 acknowledged: bool, 166 /// Error code if we received a NACK. 167 error: Option<ErrorCode>, 168 } 169 170 /// Tracking information for an inbound message. 171 #[derive(Debug, Clone)] 172 struct InboundTracker { 173 /// The received encrypted message (if received). 174 message: Option<EncryptedMessage>, 175 /// Whether we sent an ACK/NACK. 176 responded: bool, 177 /// Error code if we sent a NACK. 178 error: Option<ErrorCode>, 179 } 180 181 // ============================================================================= 182 // EXCHANGE RESULT 183 // ============================================================================= 184 185 /// Result of a completed message exchange. 186 #[derive(Debug)] 187 pub struct ExchangeResult { 188 /// IDs of messages that were successfully delivered (acknowledged). 189 pub delivered_ids: Vec<MessageId>, 190 /// IDs of messages that failed to deliver (with error codes). 191 pub failed_ids: Vec<(MessageId, ErrorCode)>, 192 /// Messages that were successfully received. 193 pub received_messages: Vec<EncryptedMessage>, 194 /// Messages that failed to process (with error codes). 195 pub failed_receives: Vec<(u8, ErrorCode)>, 196 /// Total messages we attempted to send. 197 pub total_sent: u8, 198 /// Total messages the peer attempted to send. 199 pub total_received: u8, 200 } 201 202 // ============================================================================= 203 // EXCHANGE 204 // ============================================================================= 205 206 /// Message exchange state machine. 207 /// 208 /// Manages the bidirectional message exchange after handshake completion. 209 /// Handles message count exchange, data transfer, acknowledgements, and 210 /// session termination. 211 pub struct Exchange { 212 /// The secure transport for encryption. 213 transport: NoiseTransport, 214 /// The contact we're exchanging with. 215 contact_id: ContactId, 216 /// Current exchange state. 217 state: ExchangeState, 218 /// Our outbound messages. 219 outbound: Vec<OutboundTracker>, 220 /// Their inbound messages (slots). 221 inbound: Vec<InboundTracker>, 222 /// Whether we've sent our count. 223 sent_count: bool, 224 /// Whether we've received their count. 225 received_count: bool, 226 /// Queue of wire messages to send. 227 send_queue: Vec<Vec<u8>>, 228 } 229 230 impl Exchange { 231 /// Create a new exchange session. 232 /// 233 /// # Arguments 234 /// 235 /// * `transport` - The Noise transport from completed handshake 236 /// * `contact_id` - The contact we're exchanging with 237 /// * `outbound_messages` - Messages queued to send to this contact 238 /// 239 /// # Example 240 /// 241 /// ```ignore 242 /// let exchange = Exchange::new(transport, contact_id, vec![encrypted_msg]); 243 /// ``` 244 pub fn new( 245 transport: NoiseTransport, 246 contact_id: ContactId, 247 outbound_messages: Vec<EncryptedMessage>, 248 ) -> Self { 249 let outbound = outbound_messages 250 .into_iter() 251 .map(|msg| OutboundTracker { 252 message: msg, 253 sent: false, 254 acknowledged: false, 255 error: None, 256 }) 257 .collect(); 258 259 Self { 260 transport, 261 contact_id, 262 state: ExchangeState::SendingCount, 263 outbound, 264 inbound: Vec::new(), 265 sent_count: false, 266 received_count: false, 267 send_queue: Vec::new(), 268 } 269 } 270 271 /// Get the current exchange state. 272 pub fn state(&self) -> &ExchangeState { 273 &self.state 274 } 275 276 /// Get the contact ID. 277 pub fn contact_id(&self) -> &ContactId { 278 &self.contact_id 279 } 280 281 /// Check if the exchange is complete. 282 pub fn is_complete(&self) -> bool { 283 self.state.is_complete() 284 } 285 286 /// Check if the exchange has failed. 287 pub fn is_failed(&self) -> bool { 288 self.state.is_failed() 289 } 290 291 /// Get the next message to send (if any). 292 /// 293 /// Returns encrypted wire data ready for BLE transmission. 294 /// Call this in a loop until it returns `None`. 295 /// 296 /// # Returns 297 /// 298 /// `Some(data)` if there's data to send, `None` otherwise. 299 /// 300 /// # Errors 301 /// 302 /// - `Encryption` if transport encryption fails 303 pub fn get_next_to_send(&mut self) -> Result<Option<Vec<u8>>> { 304 // First, check if we have queued data 305 if !self.send_queue.is_empty() { 306 let data = self.send_queue.remove(0); 307 return Ok(Some(self.encrypt_and_send(&data)?)); 308 } 309 310 // Generate new data based on state 311 match &self.state { 312 ExchangeState::SendingCount => { 313 if !self.sent_count { 314 let count = MsgCount::new(self.outbound.len() as u16); 315 let wire = serialize(MessageType::MsgCount, &count)?; 316 self.sent_count = true; 317 318 // Transition state 319 if self.received_count { 320 self.transition_to_exchanging(); 321 } else { 322 self.state = ExchangeState::AwaitingCount; 323 } 324 325 return Ok(Some(self.encrypt_and_send(&wire)?)); 326 } 327 } 328 329 ExchangeState::AwaitingCount => { 330 // Just waiting, nothing to send 331 } 332 333 ExchangeState::Exchanging { 334 our_total, 335 their_total, 336 send_index, 337 receive_index, 338 } => { 339 let our_total = *our_total; 340 let their_total = *their_total; 341 let send_index = *send_index; 342 let receive_index = *receive_index; 343 344 // Send next message if we have one 345 if send_index < our_total { 346 if let Some(tracker) = self.outbound.get_mut(send_index as usize) { 347 if !tracker.sent { 348 let payload = tracker.message.to_bytes()?; 349 let msg_data = MsgData::new(send_index, our_total, payload); 350 let wire = serialize(MessageType::MsgData, &msg_data)?; 351 tracker.sent = true; 352 353 // Update state 354 self.state = ExchangeState::Exchanging { 355 our_total, 356 their_total, 357 send_index: send_index + 1, 358 receive_index, 359 }; 360 361 return Ok(Some(self.encrypt_and_send(&wire)?)); 362 } 363 } 364 } 365 366 // Check if we should transition to finalizing 367 self.check_transition_to_finalizing(); 368 } 369 370 ExchangeState::Finalizing { 371 pending_acks, 372 sent_done, 373 received_done, 374 } => { 375 let pending_acks = *pending_acks; 376 let sent_done = *sent_done; 377 let received_done = *received_done; 378 379 // Send SESSION_DONE if we haven't 380 if !sent_done && pending_acks == 0 { 381 let sent_count = self.outbound.iter().filter(|t| t.sent).count() as u8; 382 let received_count = 383 self.inbound.iter().filter(|t| t.message.is_some()).count() as u8; 384 385 let done = SessionDone::new(sent_count, received_count); 386 let wire = serialize(MessageType::SessionDone, &done)?; 387 388 self.state = ExchangeState::Finalizing { 389 pending_acks, 390 sent_done: true, 391 received_done, 392 }; 393 394 // Check for completion 395 if received_done { 396 self.state = ExchangeState::Complete; 397 } 398 399 return Ok(Some(self.encrypt_and_send(&wire)?)); 400 } 401 } 402 403 ExchangeState::Complete | ExchangeState::Failed(_) => { 404 // Nothing to send 405 } 406 } 407 408 Ok(None) 409 } 410 411 /// Process received data from the peer. 412 /// 413 /// Handles incoming wire messages and updates the exchange state. 414 /// 415 /// # Arguments 416 /// 417 /// * `encrypted_data` - Encrypted wire data received from BLE 418 /// 419 /// # Errors 420 /// 421 /// - `Decryption` if transport decryption fails 422 /// - `InvalidFormat` if the message cannot be parsed 423 pub fn process_received(&mut self, encrypted_data: &[u8]) -> Result<()> { 424 // Decrypt 425 let wire_data = self.transport.decrypt(encrypted_data)?; 426 427 // Parse message type 428 let (msg_type, payload) = deserialize(&wire_data)?; 429 430 match msg_type { 431 MessageType::MsgCount => self.handle_msg_count(payload)?, 432 MessageType::MsgData => self.handle_msg_data(payload)?, 433 MessageType::MsgAck => self.handle_msg_ack(payload)?, 434 MessageType::MsgNack => self.handle_msg_nack(payload)?, 435 MessageType::SessionDone => self.handle_session_done(payload)?, 436 MessageType::Ping => self.handle_ping(payload)?, 437 MessageType::Pong => { /* Ignore pong responses */ } 438 } 439 440 Ok(()) 441 } 442 443 /// Finalize the exchange and get results. 444 /// 445 /// Should be called after `is_complete()` returns true. 446 /// 447 /// # Returns 448 /// 449 /// An `ExchangeResult` containing delivery status and received messages. 450 /// 451 /// # Errors 452 /// 453 /// - `InvalidStateTransition` if the exchange is not complete 454 pub fn finalize(self) -> Result<ExchangeResult> { 455 let (result, _transport) = self.finalize_with_transport()?; 456 Ok(result) 457 } 458 459 /// Finalize the exchange and return both results and the transport. 460 /// 461 /// Use this when you want to continue using the encrypted channel 462 /// (e.g., for a gossip exchange phase after the direct exchange). 463 /// 464 /// # Returns 465 /// 466 /// A tuple of `(ExchangeResult, NoiseTransport)`. 467 /// 468 /// # Errors 469 /// 470 /// - `InvalidStateTransition` if the exchange is not complete 471 pub fn finalize_with_transport(self) -> Result<(ExchangeResult, NoiseTransport)> { 472 if !self.is_complete() { 473 return Err(DeadDropError::InvalidStateTransition { 474 from: format!("{:?}", self.state), 475 to: "Complete".to_string(), 476 }); 477 } 478 479 let mut delivered_ids = Vec::new(); 480 let mut failed_ids = Vec::new(); 481 482 for tracker in &self.outbound { 483 if tracker.acknowledged { 484 delivered_ids.push(tracker.message.message_id); 485 } else if let Some(error) = tracker.error { 486 failed_ids.push((tracker.message.message_id, error)); 487 } 488 } 489 490 let mut received_messages = Vec::new(); 491 let mut failed_receives = Vec::new(); 492 493 for (index, tracker) in self.inbound.iter().enumerate() { 494 if let Some(msg) = &tracker.message { 495 received_messages.push(msg.clone()); 496 } else if let Some(error) = tracker.error { 497 failed_receives.push((index as u8, error)); 498 } 499 } 500 501 let result = ExchangeResult { 502 delivered_ids, 503 failed_ids, 504 received_messages, 505 failed_receives, 506 total_sent: self.outbound.len() as u8, 507 total_received: self.inbound.len() as u8, 508 }; 509 510 Ok((result, self.transport)) 511 } 512 513 // ========================================================================= 514 // Internal: Encryption 515 // ========================================================================= 516 517 /// Encrypt data for transmission. 518 fn encrypt_and_send(&mut self, data: &[u8]) -> Result<Vec<u8>> { 519 self.transport.encrypt(data) 520 } 521 522 // ========================================================================= 523 // Internal: Message Handlers 524 // ========================================================================= 525 526 /// Handle MSG_COUNT message. 527 fn handle_msg_count(&mut self, payload: &[u8]) -> Result<()> { 528 let count: MsgCount = bincode::deserialize(payload) 529 .map_err(|e| DeadDropError::Deserialization(e.to_string()))?; 530 531 count.validate()?; 532 533 self.received_count = true; 534 535 // Initialize inbound tracking 536 self.inbound = (0..count.count) 537 .map(|_| InboundTracker { 538 message: None, 539 responded: false, 540 error: None, 541 }) 542 .collect(); 543 544 // Transition state 545 if self.sent_count { 546 self.transition_to_exchanging(); 547 } else { 548 self.state = ExchangeState::SendingCount; 549 } 550 551 Ok(()) 552 } 553 554 /// Handle MSG_DATA message. 555 fn handle_msg_data(&mut self, payload: &[u8]) -> Result<()> { 556 let msg_data: MsgData = bincode::deserialize(payload) 557 .map_err(|e| DeadDropError::Deserialization(e.to_string()))?; 558 559 msg_data.validate()?; 560 561 let index = msg_data.index as usize; 562 563 // Ensure we have a slot for this message 564 if index >= self.inbound.len() { 565 // Queue NACK 566 let nack = MsgNack::new(msg_data.index, ErrorCode::InvalidFormat); 567 let wire = serialize(MessageType::MsgNack, &nack)?; 568 self.send_queue.push(wire); 569 return Ok(()); 570 } 571 572 // Try to parse the encrypted message 573 match EncryptedMessage::from_bytes(&msg_data.payload) { 574 Ok(encrypted) => { 575 self.inbound[index].message = Some(encrypted); 576 self.inbound[index].responded = true; 577 578 // Queue ACK 579 let ack = MsgAck::success(msg_data.index); 580 let wire = serialize(MessageType::MsgAck, &ack)?; 581 self.send_queue.push(wire); 582 } 583 Err(e) => { 584 self.inbound[index].error = Some(ErrorCode::InvalidFormat); 585 self.inbound[index].responded = true; 586 587 // Queue NACK 588 let nack = MsgNack::new(msg_data.index, ErrorCode::InvalidFormat); 589 let wire = serialize(MessageType::MsgNack, &nack)?; 590 self.send_queue.push(wire); 591 592 // Log error but continue 593 debug_log!("Failed to parse message {}: {}", index, e); 594 } 595 } 596 597 // Update exchange state 598 if let ExchangeState::Exchanging { 599 our_total, 600 their_total, 601 send_index, 602 receive_index, 603 } = self.state 604 { 605 let new_receive_index = receive_index.max(msg_data.index + 1); 606 self.state = ExchangeState::Exchanging { 607 our_total, 608 their_total, 609 send_index, 610 receive_index: new_receive_index, 611 }; 612 } 613 614 self.check_transition_to_finalizing(); 615 616 Ok(()) 617 } 618 619 /// Handle MSG_ACK message. 620 fn handle_msg_ack(&mut self, payload: &[u8]) -> Result<()> { 621 let ack: MsgAck = bincode::deserialize(payload) 622 .map_err(|e| DeadDropError::Deserialization(e.to_string()))?; 623 624 let index = ack.index as usize; 625 626 if index < self.outbound.len() { 627 self.outbound[index].acknowledged = true; 628 } 629 630 // Update finalizing state 631 if let ExchangeState::Finalizing { 632 pending_acks, 633 sent_done, 634 received_done, 635 } = self.state 636 { 637 let new_pending = pending_acks.saturating_sub(1); 638 self.state = ExchangeState::Finalizing { 639 pending_acks: new_pending, 640 sent_done, 641 received_done, 642 }; 643 } 644 645 self.check_transition_to_finalizing(); 646 647 Ok(()) 648 } 649 650 /// Handle MSG_NACK message. 651 fn handle_msg_nack(&mut self, payload: &[u8]) -> Result<()> { 652 let nack: MsgNack = bincode::deserialize(payload) 653 .map_err(|e| DeadDropError::Deserialization(e.to_string()))?; 654 655 let index = nack.index as usize; 656 657 if index < self.outbound.len() { 658 self.outbound[index].error = ErrorCode::from_byte(nack.error_code); 659 } 660 661 // Update finalizing state (NACK counts as response) 662 if let ExchangeState::Finalizing { 663 pending_acks, 664 sent_done, 665 received_done, 666 } = self.state 667 { 668 let new_pending = pending_acks.saturating_sub(1); 669 self.state = ExchangeState::Finalizing { 670 pending_acks: new_pending, 671 sent_done, 672 received_done, 673 }; 674 } 675 676 self.check_transition_to_finalizing(); 677 678 Ok(()) 679 } 680 681 /// Handle SESSION_DONE message. 682 fn handle_session_done(&mut self, payload: &[u8]) -> Result<()> { 683 let _done: SessionDone = bincode::deserialize(payload) 684 .map_err(|e| DeadDropError::Deserialization(e.to_string()))?; 685 686 // Mark that peer is done 687 match &self.state { 688 ExchangeState::Finalizing { 689 pending_acks, 690 sent_done, 691 .. 692 } => { 693 let pending_acks = *pending_acks; 694 let sent_done = *sent_done; 695 696 self.state = ExchangeState::Finalizing { 697 pending_acks, 698 sent_done, 699 received_done: true, 700 }; 701 702 // Check for completion 703 if sent_done { 704 self.state = ExchangeState::Complete; 705 } 706 } 707 ExchangeState::Exchanging { .. } => { 708 // Peer finished early, transition to finalizing 709 self.state = ExchangeState::Finalizing { 710 pending_acks: self.count_pending_acks(), 711 sent_done: false, 712 received_done: true, 713 }; 714 } 715 _ => {} 716 } 717 718 Ok(()) 719 } 720 721 /// Handle PING message. 722 fn handle_ping(&mut self, payload: &[u8]) -> Result<()> { 723 use crate::protocol::wire::{Ping, Pong}; 724 725 let ping: Ping = bincode::deserialize(payload) 726 .map_err(|e| DeadDropError::Deserialization(e.to_string()))?; 727 728 let pong = Pong::from_ping(&ping); 729 let wire = serialize(MessageType::Pong, &pong)?; 730 self.send_queue.push(wire); 731 732 Ok(()) 733 } 734 735 // ========================================================================= 736 // Internal: State Transitions 737 // ========================================================================= 738 739 /// Transition to exchanging state. 740 fn transition_to_exchanging(&mut self) { 741 self.state = ExchangeState::Exchanging { 742 our_total: self.outbound.len() as u8, 743 their_total: self.inbound.len() as u8, 744 send_index: 0, 745 receive_index: 0, 746 }; 747 748 // If no messages to exchange, go straight to finalizing 749 if self.outbound.is_empty() && self.inbound.is_empty() { 750 self.state = ExchangeState::Finalizing { 751 pending_acks: 0, 752 sent_done: false, 753 received_done: false, 754 }; 755 } 756 } 757 758 /// Check if we should transition to finalizing state. 759 fn check_transition_to_finalizing(&mut self) { 760 if let ExchangeState::Exchanging { 761 our_total, 762 their_total, 763 send_index, 764 .. 765 } = self.state 766 { 767 // Check if all messages sent 768 let all_sent = send_index >= our_total; 769 770 // Check if all messages received 771 let all_received = self 772 .inbound 773 .iter() 774 .all(|t| t.message.is_some() || t.error.is_some()); 775 776 if all_sent && (all_received || their_total == 0) { 777 self.state = ExchangeState::Finalizing { 778 pending_acks: self.count_pending_acks(), 779 sent_done: false, 780 received_done: false, 781 }; 782 } 783 } 784 } 785 786 /// Count pending acknowledgements. 787 fn count_pending_acks(&self) -> u8 { 788 self.outbound 789 .iter() 790 .filter(|t| t.sent && !t.acknowledged && t.error.is_none()) 791 .count() as u8 792 } 793 } 794 795 impl std::fmt::Debug for Exchange { 796 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 797 f.debug_struct("Exchange") 798 .field("contact_id", &hex::encode(&self.contact_id[..8])) 799 .field("state", &self.state) 800 .field("outbound_count", &self.outbound.len()) 801 .field("inbound_count", &self.inbound.len()) 802 .field("sent_count", &self.sent_count) 803 .field("received_count", &self.received_count) 804 .finish() 805 } 806 } 807 808 #[cfg(test)] 809 mod tests { 810 use super::*; 811 use crate::crypto::keys::{ExchangeKeyPair, IdentityKeyPair}; 812 use crate::crypto::noise::perform_handshake; 813 use crate::protocol::messages::{encrypt_message, PlaintextMessage}; 814 815 /// Create test encrypted message. 816 fn create_test_message(text: &str) -> EncryptedMessage { 817 let sender = IdentityKeyPair::generate(); 818 let recipient = ExchangeKeyPair::generate(); 819 let plaintext = PlaintextMessage::text(text); 820 encrypt_message(&plaintext, &sender, &recipient.public_bytes()).unwrap() 821 } 822 823 /// Create connected transports for testing. 824 fn create_transports() -> (NoiseTransport, NoiseTransport) { 825 let alice = ExchangeKeyPair::generate(); 826 let bob = ExchangeKeyPair::generate(); 827 perform_handshake(&alice, &bob).unwrap() 828 } 829 830 // ==================== Creation Tests ==================== 831 832 #[test] 833 fn test_exchange_creation() { 834 let (transport, _) = create_transports(); 835 let contact_id = [0u8; 16]; 836 let messages = vec![create_test_message("Hello")]; 837 838 let exchange = Exchange::new(transport, contact_id, messages); 839 840 assert_eq!(*exchange.state(), ExchangeState::SendingCount); 841 assert!(!exchange.is_complete()); 842 assert!(!exchange.is_failed()); 843 } 844 845 #[test] 846 fn test_exchange_creation_no_messages() { 847 let (transport, _) = create_transports(); 848 let contact_id = [0u8; 16]; 849 850 let exchange = Exchange::new(transport, contact_id, vec![]); 851 852 assert_eq!(*exchange.state(), ExchangeState::SendingCount); 853 } 854 855 // ==================== Count Exchange Tests ==================== 856 857 #[test] 858 fn test_count_exchange() { 859 let (alice_transport, bob_transport) = create_transports(); 860 let contact_id = [0u8; 16]; 861 862 let alice_msgs = vec![create_test_message("From Alice")]; 863 let bob_msgs = vec![create_test_message("From Bob 1"), create_test_message("From Bob 2")]; 864 865 let mut alice = Exchange::new(alice_transport, contact_id, alice_msgs); 866 let mut bob = Exchange::new(bob_transport, contact_id, bob_msgs); 867 868 // Alice sends count 869 let alice_count_data = alice.get_next_to_send().unwrap().unwrap(); 870 871 // Bob sends count 872 let bob_count_data = bob.get_next_to_send().unwrap().unwrap(); 873 874 // Exchange counts 875 alice.process_received(&bob_count_data).unwrap(); 876 bob.process_received(&alice_count_data).unwrap(); 877 878 // Both should be in exchanging state 879 assert!(matches!(alice.state(), ExchangeState::Exchanging { .. })); 880 assert!(matches!(bob.state(), ExchangeState::Exchanging { .. })); 881 } 882 883 // ==================== Full Exchange Tests ==================== 884 885 #[test] 886 fn test_full_exchange_one_message_each() { 887 let (alice_transport, bob_transport) = create_transports(); 888 let contact_id = [0u8; 16]; 889 890 let alice_msgs = vec![create_test_message("Hello Bob")]; 891 let bob_msgs = vec![create_test_message("Hello Alice")]; 892 893 let mut alice = Exchange::new(alice_transport, contact_id, alice_msgs); 894 let mut bob = Exchange::new(bob_transport, contact_id, bob_msgs); 895 896 // Run the exchange 897 let mut iterations = 0; 898 while (!alice.is_complete() || !bob.is_complete()) && iterations < 100 { 899 // Alice sends 900 while let Some(data) = alice.get_next_to_send().unwrap() { 901 bob.process_received(&data).unwrap(); 902 } 903 904 // Bob sends 905 while let Some(data) = bob.get_next_to_send().unwrap() { 906 alice.process_received(&data).unwrap(); 907 } 908 909 iterations += 1; 910 } 911 912 assert!(alice.is_complete(), "Alice not complete after {} iterations", iterations); 913 assert!(bob.is_complete(), "Bob not complete after {} iterations", iterations); 914 915 // Check results 916 let alice_result = alice.finalize().unwrap(); 917 let bob_result = bob.finalize().unwrap(); 918 919 assert_eq!(alice_result.delivered_ids.len(), 1); 920 assert_eq!(alice_result.received_messages.len(), 1); 921 922 assert_eq!(bob_result.delivered_ids.len(), 1); 923 assert_eq!(bob_result.received_messages.len(), 1); 924 } 925 926 #[test] 927 fn test_full_exchange_no_messages() { 928 let (alice_transport, bob_transport) = create_transports(); 929 let contact_id = [0u8; 16]; 930 931 let mut alice = Exchange::new(alice_transport, contact_id, vec![]); 932 let mut bob = Exchange::new(bob_transport, contact_id, vec![]); 933 934 // Run the exchange 935 let mut iterations = 0; 936 while (!alice.is_complete() || !bob.is_complete()) && iterations < 100 { 937 while let Some(data) = alice.get_next_to_send().unwrap() { 938 bob.process_received(&data).unwrap(); 939 } 940 941 while let Some(data) = bob.get_next_to_send().unwrap() { 942 alice.process_received(&data).unwrap(); 943 } 944 945 iterations += 1; 946 } 947 948 assert!(alice.is_complete()); 949 assert!(bob.is_complete()); 950 951 let alice_result = alice.finalize().unwrap(); 952 assert!(alice_result.delivered_ids.is_empty()); 953 assert!(alice_result.received_messages.is_empty()); 954 } 955 956 #[test] 957 fn test_full_exchange_asymmetric() { 958 let (alice_transport, bob_transport) = create_transports(); 959 let contact_id = [0u8; 16]; 960 961 // Alice has 3 messages, Bob has 0 962 let alice_msgs = vec![ 963 create_test_message("Msg 1"), 964 create_test_message("Msg 2"), 965 create_test_message("Msg 3"), 966 ]; 967 968 let mut alice = Exchange::new(alice_transport, contact_id, alice_msgs); 969 let mut bob = Exchange::new(bob_transport, contact_id, vec![]); 970 971 // Run the exchange 972 let mut iterations = 0; 973 while (!alice.is_complete() || !bob.is_complete()) && iterations < 100 { 974 while let Some(data) = alice.get_next_to_send().unwrap() { 975 bob.process_received(&data).unwrap(); 976 } 977 978 while let Some(data) = bob.get_next_to_send().unwrap() { 979 alice.process_received(&data).unwrap(); 980 } 981 982 iterations += 1; 983 } 984 985 assert!(alice.is_complete()); 986 assert!(bob.is_complete()); 987 988 let alice_result = alice.finalize().unwrap(); 989 assert_eq!(alice_result.delivered_ids.len(), 3); 990 assert!(alice_result.received_messages.is_empty()); 991 992 let bob_result = bob.finalize().unwrap(); 993 assert!(bob_result.delivered_ids.is_empty()); 994 assert_eq!(bob_result.received_messages.len(), 3); 995 } 996 997 // ==================== State Tests ==================== 998 999 #[test] 1000 fn test_exchange_state_is_complete() { 1001 assert!(ExchangeState::Complete.is_complete()); 1002 assert!(!ExchangeState::SendingCount.is_complete()); 1003 assert!(!ExchangeState::Failed("error".to_string()).is_complete()); 1004 } 1005 1006 #[test] 1007 fn test_exchange_state_is_failed() { 1008 assert!(ExchangeState::Failed("error".to_string()).is_failed()); 1009 assert!(!ExchangeState::Complete.is_failed()); 1010 assert!(!ExchangeState::SendingCount.is_failed()); 1011 } 1012 1013 #[test] 1014 fn test_exchange_state_is_in_progress() { 1015 assert!(ExchangeState::SendingCount.is_in_progress()); 1016 assert!(ExchangeState::AwaitingCount.is_in_progress()); 1017 assert!(!ExchangeState::Complete.is_in_progress()); 1018 assert!(!ExchangeState::Failed("error".to_string()).is_in_progress()); 1019 } 1020 1021 // ==================== Error Cases ==================== 1022 1023 #[test] 1024 fn test_finalize_not_complete() { 1025 let (transport, _) = create_transports(); 1026 let contact_id = [0u8; 16]; 1027 1028 let exchange = Exchange::new(transport, contact_id, vec![]); 1029 1030 let result = exchange.finalize(); 1031 assert!(result.is_err()); 1032 } 1033 1034 // ==================== Debug Tests ==================== 1035 1036 #[test] 1037 fn test_exchange_debug() { 1038 let (transport, _) = create_transports(); 1039 let contact_id = [0u8; 16]; 1040 1041 let exchange = Exchange::new(transport, contact_id, vec![]); 1042 1043 let debug_str = format!("{:?}", exchange); 1044 assert!(debug_str.contains("Exchange")); 1045 assert!(debug_str.contains("state")); 1046 } 1047 }