queue.rs
1 use std::collections::VecDeque; 2 3 use serde::{Deserialize, Serialize}; 4 use tracing::{debug, trace}; 5 6 #[derive(Debug, Clone, Eq, PartialEq)] 7 pub struct MessageQueue<M> { 8 pub(super) queue: VecDeque<UniqueMessage<M>>, 9 pub(super) next_id: MessageId, 10 } 11 12 #[derive(Debug, Copy, Clone, Eq, PartialEq, Serialize, Deserialize, Ord, PartialOrd)] 13 pub struct MessageId(pub u64); 14 15 #[derive(Debug, Copy, Clone, Eq, PartialEq, Serialize, Deserialize, Ord, PartialOrd)] 16 pub struct UniqueMessage<M> { 17 pub id: MessageId, 18 pub msg: M, 19 } 20 21 impl MessageId { 22 pub fn increment(self) -> MessageId { 23 MessageId(self.0 + 1) 24 } 25 } 26 27 impl<M> Default for MessageQueue<M> { 28 fn default() -> Self { 29 MessageQueue { 30 queue: Default::default(), 31 next_id: MessageId(1), 32 } 33 } 34 } 35 36 impl<M> MessageQueue<M> 37 where 38 M: Clone, 39 { 40 pub fn push(&mut self, msg: M) -> UniqueMessage<M> { 41 let id_msg = UniqueMessage { 42 id: self.next_id, 43 msg, 44 }; 45 46 self.queue.push_back(id_msg.clone()); 47 self.next_id = self.next_id.increment(); 48 49 id_msg 50 } 51 52 pub fn ack(&mut self, msg_id: MessageId) { 53 debug!("Received ACK for {:?}", msg_id); 54 while self 55 .queue 56 .front() 57 .map(|msg| msg.id <= msg_id) 58 .unwrap_or(false) 59 { 60 let msg = self.queue.pop_front().expect("Checked in while head"); 61 trace!("Removing message {:?} from resend buffer", msg.id); 62 } 63 } 64 65 pub fn iter(&self) -> impl Iterator<Item = &UniqueMessage<M>> { 66 self.queue.iter() 67 } 68 } 69 70 #[cfg(test)] 71 mod tests { 72 use crate::net::queue::{MessageId, MessageQueue}; 73 74 #[test] 75 fn test_queue() { 76 let mut queue = MessageQueue::default(); 77 78 for i in 0u64..10 { 79 let umsg = queue.push(42 * i); 80 assert_eq!(umsg.msg, 42 * i); 81 assert_eq!(umsg.id.0, i + 1); 82 } 83 84 fn assert_contains(queue: &MessageQueue<u64>, iter: impl Iterator<Item = u64>) { 85 let mut queue_iter = queue.iter(); 86 87 for i in iter { 88 let umsg = queue_iter.next().unwrap(); 89 assert_eq!(umsg.msg, 42 * i); 90 assert_eq!(umsg.id.0, i + 1); 91 } 92 93 assert_eq!(queue_iter.next(), None); 94 } 95 96 assert_eq!(queue.iter().count(), 10); 97 assert_contains(&queue, 0..10); 98 99 queue.ack(MessageId(1)); 100 assert_contains(&queue, 1..10); 101 102 queue.ack(MessageId(4)); 103 assert_contains(&queue, 4..10); 104 105 queue.ack(MessageId(2)); // TODO: should that throw an error? 106 assert_contains(&queue, 4..10); 107 } 108 }