/ fedimint-server / src / net / queue.rs
queue.rs
  1  use std::collections::VecDeque;
  2  
  3  use serde::{Deserialize, Serialize};
  4  use tracing::{debug, trace};
  5  
  6  #[derive(Debug, Clone, Eq, PartialEq)]
  7  pub struct MessageQueue<M> {
  8      pub(super) queue: VecDeque<UniqueMessage<M>>,
  9      pub(super) next_id: MessageId,
 10  }
 11  
 12  #[derive(Debug, Copy, Clone, Eq, PartialEq, Serialize, Deserialize, Ord, PartialOrd)]
 13  pub struct MessageId(pub u64);
 14  
 15  #[derive(Debug, Copy, Clone, Eq, PartialEq, Serialize, Deserialize, Ord, PartialOrd)]
 16  pub struct UniqueMessage<M> {
 17      pub id: MessageId,
 18      pub msg: M,
 19  }
 20  
 21  impl MessageId {
 22      pub fn increment(self) -> MessageId {
 23          MessageId(self.0 + 1)
 24      }
 25  }
 26  
 27  impl<M> Default for MessageQueue<M> {
 28      fn default() -> Self {
 29          MessageQueue {
 30              queue: Default::default(),
 31              next_id: MessageId(1),
 32          }
 33      }
 34  }
 35  
 36  impl<M> MessageQueue<M>
 37  where
 38      M: Clone,
 39  {
 40      pub fn push(&mut self, msg: M) -> UniqueMessage<M> {
 41          let id_msg = UniqueMessage {
 42              id: self.next_id,
 43              msg,
 44          };
 45  
 46          self.queue.push_back(id_msg.clone());
 47          self.next_id = self.next_id.increment();
 48  
 49          id_msg
 50      }
 51  
 52      pub fn ack(&mut self, msg_id: MessageId) {
 53          debug!("Received ACK for {:?}", msg_id);
 54          while self
 55              .queue
 56              .front()
 57              .map(|msg| msg.id <= msg_id)
 58              .unwrap_or(false)
 59          {
 60              let msg = self.queue.pop_front().expect("Checked in while head");
 61              trace!("Removing message {:?} from resend buffer", msg.id);
 62          }
 63      }
 64  
 65      pub fn iter(&self) -> impl Iterator<Item = &UniqueMessage<M>> {
 66          self.queue.iter()
 67      }
 68  }
 69  
 70  #[cfg(test)]
 71  mod tests {
 72      use crate::net::queue::{MessageId, MessageQueue};
 73  
 74      #[test]
 75      fn test_queue() {
 76          let mut queue = MessageQueue::default();
 77  
 78          for i in 0u64..10 {
 79              let umsg = queue.push(42 * i);
 80              assert_eq!(umsg.msg, 42 * i);
 81              assert_eq!(umsg.id.0, i + 1);
 82          }
 83  
 84          fn assert_contains(queue: &MessageQueue<u64>, iter: impl Iterator<Item = u64>) {
 85              let mut queue_iter = queue.iter();
 86  
 87              for i in iter {
 88                  let umsg = queue_iter.next().unwrap();
 89                  assert_eq!(umsg.msg, 42 * i);
 90                  assert_eq!(umsg.id.0, i + 1);
 91              }
 92  
 93              assert_eq!(queue_iter.next(), None);
 94          }
 95  
 96          assert_eq!(queue.iter().count(), 10);
 97          assert_contains(&queue, 0..10);
 98  
 99          queue.ack(MessageId(1));
100          assert_contains(&queue, 1..10);
101  
102          queue.ack(MessageId(4));
103          assert_contains(&queue, 4..10);
104  
105          queue.ack(MessageId(2)); // TODO: should that throw an error?
106          assert_contains(&queue, 4..10);
107      }
108  }