// node/src/message_stream.rs
//! Stream messages with backfilling.
//!
//! [`BackfilledMessageStream`] is a continuous stream that seamlessly
//! switches from sending historical messages (read from storage) to
//! sending live messages.
//! It can also start streaming from a sequence number not yet
//! produced: in this case it waits until the live stream reaches
//! the target sequence number.
//!
//! ```txt
//!         ┌───────────┐
//!         │           │
//! Live────►           ├───► Backfilled
//!         │           │     Stream
//!         └─────▲─────┘
//!               │
//!               │
//!            Message
//!            Storage
//! ```
 21  
 22  use std::{collections::VecDeque, marker::PhantomData, pin::Pin, task::Poll, time::Duration};
 23  
 24  use apibara_core::stream::{MessageData, RawMessageData, Sequence, StreamMessage};
 25  use futures::Stream;
 26  use pin_project::pin_project;
 27  use tokio::time::Instant;
 28  use tokio_util::sync::CancellationToken;
 29  use tracing::debug;
 30  
 31  use crate::message_storage::MessageStorage;
 32  
 33  pub type LiveStreamItem<M> = std::result::Result<StreamMessage<M>, Box<dyn std::error::Error>>;
 34  
 35  #[pin_project]
 36  pub struct BackfilledMessageStream<M, S, L>
 37  where
 38      M: MessageData,
 39      S: MessageStorage<M>,
 40      L: Stream<Item = LiveStreamItem<M>>,
 41  {
 42      storage: S,
 43      #[pin]
 44      live: L,
 45      // state is in its own struct to play nicely with pin.
 46      state: State<M>,
 47      pending_interval: Option<Duration>,
 48      ct: CancellationToken,
 49      _phantom: PhantomData<M>,
 50  }
 51  
 52  #[derive(Debug, thiserror::Error)]
 53  pub enum BackfilledMessageStreamError {
 54      #[error("invalid live message sequence number")]
 55      InvalidLiveSequence { expected: u64, actual: u64 },
 56      #[error("message with sequence {sequence} not found")]
 57      MessageNotFound { sequence: u64 },
 58      #[error("error retrieving data from message storage")]
 59      Storage(Box<dyn std::error::Error>),
 60      #[error("error retrieving data from live stream")]
 61      LiveStream(Box<dyn std::error::Error>),
 62  }
 63  
 64  pub type Result<T> = std::result::Result<T, BackfilledMessageStreamError>;
 65  
 66  /// Deadline for sending pending messages.
 67  #[derive(Debug)]
 68  enum PendingDeadline {
 69      /// Don't send any pending messages.
 70      None,
 71      /// Send pending message immediately, ignoring any deadline.
 72      /// Used to send a pending messages immediately after a new data message.
 73      Immediately,
 74      /// Send after instant.
 75      Deadline(Instant),
 76  }
 77  
 78  #[derive(Debug)]
 79  struct State<M: MessageData> {
 80      current: Sequence,
 81      latest: Sequence,
 82      buffer: VecDeque<(Sequence, RawMessageData<M>)>,
 83      pending_deadline: PendingDeadline,
 84      pending_interval: Option<Duration>,
 85  }
 86  
 87  impl<M, S, L> BackfilledMessageStream<M, S, L>
 88  where
 89      M: MessageData,
 90      S: MessageStorage<M>,
 91      L: Stream<Item = LiveStreamItem<M>>,
 92  {
 93      /// Creates a new `MessageStreamer`.
 94      ///
 95      /// Start streaming from the `current` message (inclusive), using `latest` as
 96      /// hint about the most recently stored message.
 97      /// Messages that are not `live` are streamed from the `storage`.
 98      pub fn new(
 99          current: Sequence,
100          latest: Sequence,
101          storage: S,
102          live: L,
103          pending_interval: Option<Duration>,
104          ct: CancellationToken,
105      ) -> Self {
106          BackfilledMessageStream {
107              storage,
108              live,
109              state: State::new(current, latest, pending_interval),
110              pending_interval,
111              ct,
112              _phantom: PhantomData,
113          }
114      }
115  }
116  
117  impl<M, S, L> Stream for BackfilledMessageStream<M, S, L>
118  where
119      M: MessageData,
120      S: MessageStorage<M>,
121      L: Stream<Item = LiveStreamItem<M>>,
122  {
123      type Item = Result<StreamMessage<M>>;
124  
125      fn poll_next(
126          self: Pin<&mut Self>,
127          cx: &mut std::task::Context<'_>,
128      ) -> Poll<Option<Self::Item>> {
129          // always check cancellation
130          if self.ct.is_cancelled() {
131              return Poll::Ready(None);
132          }
133  
134          // when receiving a `StreamMessage::Data` message the stream can perform
135          // three possible actions, depending on the stream state:
136          //
137          // current < latest:
138          //   live stream: used to keep latest updated
139          //   storage: used to read backfilled data and send it to stream
140          // current == latest:
141          //   live stream: used to keep state updated and send data to stream
142          //   storage: not used
143          // current > latest:
144          //   live stream: used to keep track of state, but data is not sent
145          //   storage: not used
146          //
147          // when receiving a `StreamMessage::Invalidate` message the stream
148          // can perform:
149          //
150          // current < invalidate:
151          //   update latest from invalidate
152          // current >= invalidate:
153          //   update current and latest from invalidate
154          //   send invalidate message to stream
155  
156          let current = self.state.current;
157          let latest = self.state.latest;
158  
159          let mut this = self.project();
160  
161          let live_message = {
162              match Pin::new(&mut this.live).poll_next(cx) {
163                  Poll::Pending => {
164                      // return pending and wake when live stream is ready
165                      if current > latest {
166                          return Poll::Pending;
167                      }
168                      None
169                  }
170                  Poll::Ready(None) => {
171                      // live stream closed, try to keep sending backfilled messages
172                      // but if it's live simply close this stream too.
173                      if current >= latest {
174                          return Poll::Ready(None);
175                      }
176                      None
177                  }
178                  Poll::Ready(Some(message)) => Some(message),
179              }
180          };
181  
182          if let Some(message) = live_message {
183              match message {
184                  Err(err) => {
185                      return Poll::Ready(Some(Err(BackfilledMessageStreamError::LiveStream(err))))
186                  }
187                  Ok(StreamMessage::Invalidate { sequence }) => {
188                      debug!(sequence = ?sequence, "live invalidate");
189                      // clear buffer just in case
190                      this.state.clear_buffer();
191                      this.state.update_latest(sequence);
192  
193                      // all messages after `sequence` (inclusive) are now invalidated.
194                      // forward invalidate message
195                      if current >= sequence {
196                          debug!(sequence = ?sequence, "send live invalidate");
197                          this.state.update_current(sequence);
198                          let message = StreamMessage::Invalidate { sequence };
199                          return Poll::Ready(Some(Ok(message)));
200                      }
201  
202                      // buffer data was invalidated and new state updated
203                      // let's stop here and start again
204                      cx.waker().wake_by_ref();
205                      return Poll::Pending;
206                  }
207                  Ok(StreamMessage::Data { sequence, data }) => {
208                      debug!(sequence = ?sequence, "live data");
209                      this.state.update_latest(sequence);
210  
211                      // just send the message to the stream if it's the current one
212                      if current == sequence {
213                          debug!(sequence = ?sequence, "send live data");
214                          this.state.update_current(sequence);
215                          this.state.increment_current();
216                          this.state.reset_pending_deadline_to_immediately();
217                          let message = StreamMessage::Data { sequence, data };
218                          return Poll::Ready(Some(Ok(message)));
219                      }
220  
221                      // no point in adding messages that won't be sent
222                      if current < sequence {
223                          // add message to buffer
224                          this.state.add_live_message(sequence, data);
225                      }
226                  }
227                  Ok(StreamMessage::Pending { sequence, data }) => {
228                      debug!(sequence = ?sequence, "live pending");
229                      if sequence == current && this.state.is_pending_deadline_exceeded() {
230                          debug!(sequence = ?sequence, "send live pending");
231                          let message = StreamMessage::Pending { sequence, data };
232                          this.state.reset_pending_deadline();
233                          return Poll::Ready(Some(Ok(message)));
234                      }
235                  }
236              }
237          }
238  
239          // stream is not interested in any messages we can send, so just
240          // restart and wait for more live data
241          if current > latest {
242              cx.waker().wake_by_ref();
243              return Poll::Pending;
244          }
245  
246          // prioritize sending from buffer
247          if this.state.buffer_has_sequence(this.state.current()) {
248              match this.state.pop_buffer() {
249                  None => {
250                      let sequence = this.state.current().as_u64();
251                      return Poll::Ready(Some(Err(BackfilledMessageStreamError::MessageNotFound {
252                          sequence,
253                      })));
254                  }
255                  Some((sequence, data)) => {
256                      this.state.increment_current();
257                      let message = StreamMessage::Data { sequence, data };
258                      return Poll::Ready(Some(Ok(message)));
259                  }
260              }
261          }
262  
263          // as last resort, send backfilled messages from storage
264          match this.storage.get(this.state.current()) {
265              Err(err) => {
266                  let err = BackfilledMessageStreamError::Storage(Box::new(err));
267                  Poll::Ready(Some(Err(err)))
268              }
269              Ok(None) => {
270                  let sequence = this.state.current().as_u64();
271                  Poll::Ready(Some(Err(BackfilledMessageStreamError::MessageNotFound {
272                      sequence,
273                  })))
274              }
275              Ok(Some(message)) => {
276                  let sequence = *this.state.current();
277                  this.state.increment_current();
278                  let message = StreamMessage::Data {
279                      sequence,
280                      data: message,
281                  };
282                  Poll::Ready(Some(Ok(message)))
283              }
284          }
285      }
286  }
287  
288  impl<M: MessageData> State<M> {
289      fn new(current: Sequence, latest: Sequence, pending_interval: Option<Duration>) -> Self {
290          let pending_deadline = pending_interval
291              .map(|_| PendingDeadline::Immediately)
292              .unwrap_or(PendingDeadline::None);
293          State {
294              current,
295              latest,
296              buffer: VecDeque::default(),
297              pending_interval,
298              pending_deadline,
299          }
300      }
301  
302      fn current(&self) -> &Sequence {
303          &self.current
304      }
305  
306      fn increment_current(&mut self) {
307          self.current = Sequence::from_u64(self.current.as_u64() + 1);
308      }
309  
310      fn update_latest(&mut self, sequence: Sequence) {
311          self.latest = sequence;
312      }
313  
314      fn update_current(&mut self, sequence: Sequence) {
315          self.current = sequence;
316      }
317  
318      fn add_live_message(&mut self, sequence: Sequence, message: RawMessageData<M>) {
319          self.buffer.push_back((sequence, message));
320  
321          // trim buffer size to always be ~50 elements
322          while self.buffer.len() > 50 {
323              self.buffer.pop_front();
324          }
325      }
326  
327      fn clear_buffer(&mut self) {
328          self.buffer.clear();
329      }
330  
331      fn buffer_has_sequence(&self, sequence: &Sequence) -> bool {
332          match self.buffer.front() {
333              None => false,
334              Some((seq, _)) => seq <= sequence,
335          }
336      }
337  
338      fn pop_buffer(&mut self) -> Option<(Sequence, RawMessageData<M>)> {
339          self.buffer.pop_front()
340      }
341  
342      fn reset_pending_deadline_to_immediately(&mut self) {
343          self.pending_deadline = PendingDeadline::Immediately;
344      }
345  
346      fn reset_pending_deadline(&mut self) {
347          self.pending_deadline = self
348              .pending_interval
349              .map(|i| PendingDeadline::Deadline(Instant::now() + i))
350              .unwrap_or(PendingDeadline::None);
351      }
352  
353      fn is_pending_deadline_exceeded(&self) -> bool {
354          match self.pending_deadline {
355              PendingDeadline::None => false,
356              PendingDeadline::Immediately => true,
357              PendingDeadline::Deadline(deadline) => deadline <= Instant::now(),
358          }
359      }
360  }
361  
#[cfg(test)]
mod tests {
    use std::{
        collections::HashMap,
        sync::{Arc, Mutex},
    };

    use apibara_core::stream::{RawMessageData, Sequence, StreamMessage};
    use futures::StreamExt;
    use prost::Message;
    use tokio::sync::mpsc;
    use tokio_stream::wrappers::ReceiverStream;
    use tokio_util::sync::CancellationToken;

    use crate::message_storage::MessageStorage;

    use super::{BackfilledMessageStream, LiveStreamItem};

    #[derive(Clone, prost::Message)]
    pub struct TestMessage {
        #[prost(uint64, tag = "1")]
        pub sequence: u64,
    }

    impl TestMessage {
        pub fn new(sequence: u64) -> TestMessage {
            TestMessage { sequence }
        }

        pub fn new_raw(sequence: u64) -> RawMessageData<TestMessage> {
            RawMessageData::from_vec(Self::new(sequence).encode_to_vec())
        }
    }

    /// In-memory message storage keyed by sequence number.
    #[derive(Debug, Default)]
    pub struct TestMessageStorage {
        messages: HashMap<Sequence, RawMessageData<TestMessage>>,
    }

    #[derive(Debug, thiserror::Error)]
    pub enum TestMessageStorageError {}

    impl TestMessageStorage {
        pub fn insert(&mut self, sequence: &Sequence, message: &TestMessage) {
            self.insert_raw(sequence, RawMessageData::from_vec(message.encode_to_vec()));
        }

        pub fn insert_raw(&mut self, sequence: &Sequence, message: RawMessageData<TestMessage>) {
            self.messages.insert(*sequence, message);
        }
    }

    impl MessageStorage<TestMessage> for Arc<Mutex<TestMessageStorage>> {
        type Error = TestMessageStorageError;

        fn get(
            &self,
            sequence: &Sequence,
        ) -> Result<Option<RawMessageData<TestMessage>>, Self::Error> {
            let guard = self.lock().unwrap();
            let found = guard.messages.get(sequence).cloned();
            Ok(found)
        }
    }

    type TestStorage = Arc<Mutex<TestMessageStorage>>;
    type TestLiveSender = mpsc::Sender<LiveStreamItem<TestMessage>>;
    type TestStream = BackfilledMessageStream<
        TestMessage,
        TestStorage,
        ReceiverStream<LiveStreamItem<TestMessage>>,
    >;

    /// Builds a stream (pending messages disabled) starting at `current` and
    /// returns it together with the sender feeding its live side.
    fn new_stream(current: u64, latest: u64, storage: &TestStorage) -> (TestLiveSender, TestStream) {
        let (live_tx, live_rx) = mpsc::channel(256);
        let stream = BackfilledMessageStream::new(
            Sequence::from_u64(current),
            Sequence::from_u64(latest),
            storage.clone(),
            ReceiverStream::new(live_rx),
            None,
            CancellationToken::new(),
        );
        (live_tx, stream)
    }

    /// Persists message `sequence` in storage only.
    fn store(storage: &TestStorage, sequence: u64) {
        let message = TestMessage::new(sequence);
        storage
            .lock()
            .unwrap()
            .insert(&Sequence::from_u64(sequence), &message);
    }

    /// Publishes data message `sequence` on the live stream only.
    async fn publish(live_tx: &TestLiveSender, sequence: u64) {
        let data = TestMessage::new_raw(sequence);
        let message = StreamMessage::new_data(Sequence::from_u64(sequence), data);
        live_tx.send(Ok(message)).await.unwrap();
    }

    /// Simulates the node: persists message `sequence`, then publishes it live.
    async fn store_and_publish(storage: &TestStorage, live_tx: &TestLiveSender, sequence: u64) {
        let data = TestMessage::new_raw(sequence);
        let sequence = Sequence::from_u64(sequence);
        storage
            .lock()
            .unwrap()
            .insert_raw(&sequence, data.clone());
        live_tx
            .send(Ok(StreamMessage::new_data(sequence, data)))
            .await
            .unwrap();
    }

    /// Publishes an invalidation of every message from `sequence` (inclusive).
    async fn publish_invalidate(live_tx: &TestLiveSender, sequence: u64) {
        let message = StreamMessage::new_invalidate(Sequence::from_u64(sequence));
        live_tx.send(Ok(message)).await.unwrap();
    }

    /// Awaits the next item, panicking if the stream ended or yielded an error.
    async fn next_message(stream: &mut TestStream) -> StreamMessage<TestMessage> {
        stream.next().await.unwrap().unwrap()
    }

    #[tokio::test]
    pub async fn test_transition_between_backfilled_and_live() {
        let storage = TestStorage::default();
        (0..10).for_each(|sequence| store(&storage, sequence));

        let (live_tx, mut stream) = new_stream(0, 9, &storage);

        publish(&live_tx, 10).await;

        // first 10 messages come from storage
        for sequence in 0..10 {
            let message = next_message(&mut stream).await;
            assert_eq!(message.sequence().as_u64(), sequence);
        }

        // 11th message was received from the live stream
        let message = next_message(&mut stream).await;
        assert_eq!(message.sequence().as_u64(), 10);

        // simulate node adding messages to storage (for persistence) while
        // publishing to live stream
        for sequence in 11..100 {
            store_and_publish(&storage, &live_tx, sequence).await;
        }

        for sequence in 11..100 {
            let message = next_message(&mut stream).await;
            assert_eq!(message.sequence().as_u64(), sequence);
        }
    }

    #[tokio::test]
    pub async fn test_start_at_future_sequence() {
        let storage = TestStorage::default();
        let (live_tx, mut stream) = new_stream(15, 9, &storage);

        for sequence in 10..20 {
            publish(&live_tx, sequence).await;
        }

        for sequence in 15..20 {
            let message = next_message(&mut stream).await;
            assert_eq!(message.sequence().as_u64(), sequence);
        }
    }

    #[tokio::test]
    pub async fn test_invalidate_data_after_current() {
        let storage = TestStorage::default();
        let (live_tx, mut stream) = new_stream(0, 9, &storage);

        // add some messages to storage
        (0..5).for_each(|sequence| store(&storage, sequence));

        // live stream some messages
        for sequence in 5..10 {
            store_and_publish(&storage, &live_tx, sequence).await;
        }

        // invalidate all messages with sequence >= 8
        publish_invalidate(&live_tx, 8).await;

        // then send some more messages
        for sequence in 8..12 {
            store_and_publish(&storage, &live_tx, sequence).await;
        }

        // no invalidate message is expected: the invalidation happened for a
        // message sequence that was never streamed
        for sequence in 0..12 {
            let message = next_message(&mut stream).await;
            assert_eq!(message.sequence().as_u64(), sequence);
            assert!(message.is_data());
        }
    }

    #[tokio::test]
    pub async fn test_invalidate_before_current() {
        let storage = TestStorage::default();
        let (live_tx, mut stream) = new_stream(0, 9, &storage);

        // add some messages to storage
        (0..5).for_each(|sequence| store(&storage, sequence));

        // live stream some messages
        for sequence in 5..10 {
            store_and_publish(&storage, &live_tx, sequence).await;
        }

        // now stream messages up to 9
        for sequence in 0..10 {
            let message = next_message(&mut stream).await;
            assert_eq!(message.sequence().as_u64(), sequence);
            assert!(message.is_data());
        }

        // invalidate all messages with sequence >= 8
        publish_invalidate(&live_tx, 8).await;

        // then send some more messages
        for sequence in 8..12 {
            store_and_publish(&storage, &live_tx, sequence).await;
        }

        // received invalidate message
        let message = next_message(&mut stream).await;
        assert_eq!(message.sequence().as_u64(), 8);
        assert!(message.is_invalidate());

        // resume messages
        for sequence in 8..12 {
            let message = next_message(&mut stream).await;
            assert_eq!(message.sequence().as_u64(), sequence);
            assert!(message.is_data());
        }
    }
}