// crates/arroyo-operator/src/context.rs
  1  use crate::{server_for_hash_array, RateLimiter};
  2  use arrow::array::{make_builder, Array, ArrayBuilder, PrimitiveArray, RecordBatch};
  3  use arrow::compute::{partition, sort_to_indices, take};
  4  use arrow::datatypes::{SchemaRef, UInt64Type};
  5  use arroyo_formats::de::ArrowDeserializer;
  6  use arroyo_formats::should_flush;
  7  use arroyo_metrics::{register_queue_gauge, QueueGauges, TaskCounters};
  8  use arroyo_rpc::config::config;
  9  use arroyo_rpc::df::ArroyoSchema;
 10  use arroyo_rpc::formats::{BadData, Format, Framing};
 11  use arroyo_rpc::grpc::rpc::{CheckpointMetadata, TableConfig, TaskCheckpointEventType};
 12  use arroyo_rpc::schema_resolver::SchemaResolver;
 13  use arroyo_rpc::{get_hasher, CompactionResult, ControlMessage, ControlResp};
 14  use arroyo_state::tables::table_manager::TableManager;
 15  use arroyo_state::{BackingStore, StateBackend};
 16  use arroyo_types::{
 17      from_micros, ArrowMessage, CheckpointBarrier, SourceError, TaskInfo, UserError, Watermark,
 18  };
 19  use datafusion::common::hash_utils;
 20  use rand::Rng;
 21  use std::collections::HashMap;
 22  use std::mem::size_of_val;
 23  use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
 24  use std::sync::Arc;
 25  use std::time::{Instant, SystemTime};
 26  use tokio::sync::mpsc::error::SendError;
 27  use tokio::sync::mpsc::{unbounded_channel, Receiver, Sender, UnboundedReceiver, UnboundedSender};
 28  use tokio::sync::Notify;
 29  use tracing::warn;
 30  
/// Item flowing through inter-operator queues; currently just an
/// `ArrowMessage` (a data batch or a control signal).
pub type QueueItem = ArrowMessage;
 32  
/// Tracks the per-input-partition watermarks for a task and combines them
/// into a single current watermark (the minimum across partitions).
pub struct WatermarkHolder {
    // This is the last watermark with an actual value; this helps us keep track of the watermark we're at even
    // if we're currently idle
    last_present_watermark: Option<SystemTime>,
    // combined watermark across all partitions (None until every partition has reported)
    cur_watermark: Option<Watermark>,
    // latest watermark reported by each input partition
    watermarks: Vec<Option<Watermark>>,
}
 40  
 41  impl WatermarkHolder {
 42      pub fn new(watermarks: Vec<Option<Watermark>>) -> Self {
 43          let mut s = Self {
 44              last_present_watermark: None,
 45              cur_watermark: None,
 46              watermarks,
 47          };
 48          s.update_watermark();
 49  
 50          s
 51      }
 52  
 53      pub fn watermark(&self) -> Option<Watermark> {
 54          self.cur_watermark
 55      }
 56  
 57      pub fn last_present_watermark(&self) -> Option<SystemTime> {
 58          self.last_present_watermark
 59      }
 60  
 61      fn update_watermark(&mut self) {
 62          self.cur_watermark =
 63              self.watermarks
 64                  .iter()
 65                  .try_fold(Watermark::Idle, |current, next| match (current, (*next)?) {
 66                      (Watermark::EventTime(cur), Watermark::EventTime(next)) => {
 67                          Some(Watermark::EventTime(cur.min(next)))
 68                      }
 69                      (Watermark::Idle, Watermark::EventTime(t))
 70                      | (Watermark::EventTime(t), Watermark::Idle) => Some(Watermark::EventTime(t)),
 71                      (Watermark::Idle, Watermark::Idle) => Some(Watermark::Idle),
 72                  });
 73  
 74          if let Some(Watermark::EventTime(t)) = self.cur_watermark {
 75              self.last_present_watermark = Some(t);
 76          }
 77      }
 78  
 79      pub fn set(&mut self, idx: usize, watermark: Watermark) -> Option<Option<Watermark>> {
 80          *(self.watermarks.get_mut(idx)?) = Some(watermark);
 81          self.update_watermark();
 82          Some(self.cur_watermark)
 83      }
 84  }
 85  
/// A wrapper for an UnboundedSender<QueueItem> that bounds by the number of rows within
/// a batch rather than the number of batches
#[derive(Clone)]
pub struct BatchSender {
    // maximum number of queued rows before senders block
    size: u32,
    tx: UnboundedSender<QueueItem>,
    // rows currently queued; shared with the receiver
    queued_messages: Arc<AtomicU32>,
    // bytes currently queued; shared with the receiver
    queued_bytes: Arc<AtomicU64>,
    // receiver signals this after consuming, waking blocked senders
    notify: Arc<Notify>,
}
 96  
 97  #[inline]
 98  fn message_count(item: &QueueItem, size: u32) -> u32 {
 99      match item {
100          QueueItem::Data(d) => (d.num_rows() as u32).min(size),
101          QueueItem::Signal(_) => 1,
102      }
103  }
104  
105  #[inline]
106  fn message_bytes(item: &QueueItem) -> u64 {
107      match item {
108          QueueItem::Data(d) => d.get_array_memory_size() as u64,
109          QueueItem::Signal(s) => size_of_val(s) as u64,
110      }
111  }
112  
113  impl BatchSender {
114      pub async fn send(&self, item: QueueItem) -> Result<(), SendError<QueueItem>> {
115          // Ensure that every message is sendable, even if it's bigger than our max size
116          let count = message_count(&item, self.size);
117          loop {
118              if self.tx.is_closed() {
119                  return Err(SendError(item));
120              }
121  
122              let cur = self.queued_messages.load(Ordering::Acquire);
123              if cur as usize + count as usize <= self.size as usize {
124                  match self.queued_messages.compare_exchange(
125                      cur,
126                      cur + count,
127                      Ordering::SeqCst,
128                      Ordering::SeqCst,
129                  ) {
130                      Ok(_) => {
131                          self.queued_bytes
132                              .fetch_add(message_bytes(&item), Ordering::AcqRel);
133                          return self.tx.send(item);
134                      }
135                      Err(_) => {
136                          // try again
137                          continue;
138                      }
139                  }
140              } else {
141                  // not enough room in the queue, wait to be notified that the receiver has
142                  // consumed
143                  self.notify.notified().await;
144              }
145          }
146      }
147  
148      pub fn capacity(&self) -> u32 {
149          self.size
150              .saturating_sub(self.queued_messages.load(Ordering::Relaxed))
151      }
152  
153      pub fn queued_bytes(&self) -> u64 {
154          self.queued_bytes.load(Ordering::Relaxed)
155      }
156  
157      pub fn size(&self) -> u32 {
158          self.size
159      }
160  }
161  
/// Receiving half of a row-bounded queue; see [`BatchSender`].
pub struct BatchReceiver {
    // same row capacity as the sender (used to clamp per-message counts)
    size: u32,
    rx: UnboundedReceiver<QueueItem>,
    // shared with the sender for capacity accounting
    queued_messages: Arc<AtomicU32>,
    queued_bytes: Arc<AtomicU64>,
    // notified after consuming, to wake blocked senders
    notify: Arc<Notify>,
}
169  
170  impl BatchReceiver {
171      pub async fn recv(&mut self) -> Option<QueueItem> {
172          let item = self.rx.recv().await;
173          if let Some(item) = &item {
174              let count = message_count(item, self.size);
175              self.queued_messages.fetch_sub(count, Ordering::SeqCst);
176              self.queued_bytes
177                  .fetch_sub(message_bytes(item), Ordering::AcqRel);
178              self.notify.notify_waiters();
179          }
180          item
181      }
182  }
183  
184  pub fn batch_bounded(size: u32) -> (BatchSender, BatchReceiver) {
185      let (tx, rx) = unbounded_channel();
186      let notify = Arc::new(Notify::new());
187      let queued_messages = Arc::new(AtomicU32::new(0));
188      let queued_bytes = Arc::new(AtomicU64::new(0));
189      (
190          BatchSender {
191              size,
192              tx,
193              queued_messages: queued_messages.clone(),
194              queued_bytes: queued_bytes.clone(),
195              notify: notify.clone(),
196          },
197          BatchReceiver {
198              size,
199              rx,
200              notify,
201              queued_bytes,
202              queued_messages,
203          },
204      )
205  }
206  
/// Accumulates deserialized rows into Arrow array builders until a flush is
/// triggered (by row count or by age).
struct ContextBuffer {
    // one builder per field of `schema`
    buffer: Vec<Box<dyn ArrayBuilder>>,
    // creation time, used for time-based flush decisions
    created: Instant,
    schema: SchemaRef,
}
212  
213  impl ContextBuffer {
214      fn new(schema: SchemaRef) -> Self {
215          let buffer = schema
216              .fields
217              .iter()
218              .map(|f| make_builder(f.data_type(), 16))
219              .collect();
220  
221          Self {
222              buffer,
223              created: Instant::now(),
224              schema,
225          }
226      }
227  
228      pub fn size(&self) -> usize {
229          self.buffer[0].len()
230      }
231  
232      pub fn should_flush(&self) -> bool {
233          should_flush(self.size(), self.created)
234      }
235  
236      pub fn finish(self) -> RecordBatch {
237          RecordBatch::try_new(
238              self.schema,
239              self.buffer.into_iter().map(|mut a| a.finish()).collect(),
240          )
241          .unwrap()
242      }
243  }
244  
/// Per-task runtime context handed to each operator: control channels,
/// watermark tracking, output collection/buffering, deserialization, and
/// state tables.
pub struct ArrowContext {
    pub task_info: Arc<TaskInfo>,
    // control messages from the engine (checkpoints, etc.)
    pub control_rx: Receiver<ControlMessage>,
    // responses/events back to the engine control thread
    pub control_tx: Sender<ControlResp>,
    pub error_reporter: ErrorReporter,
    pub watermarks: WatermarkHolder,
    pub in_schemas: Vec<ArroyoSchema>,
    pub out_schema: Option<ArroyoSchema>,
    pub collector: ArrowCollector,
    // local row buffer for deserialized data; None when there is no out schema
    buffer: Option<ContextBuffer>,
    // error stashed during broadcast(); surfaced by the next flush_buffer()
    buffered_error: Option<UserError>,
    // rate-limits repeated bad-data error reports
    error_rate_limiter: RateLimiter,
    deserializer: Option<ArrowDeserializer>,
    pub table_manager: TableManager,
}
260  
/// Cloneable handle for reporting operator errors to the controller via the
/// control channel.
#[derive(Clone)]
pub struct ErrorReporter {
    pub tx: Sender<ControlResp>,
    pub task_info: Arc<TaskInfo>,
}
266  
267  impl ErrorReporter {
268      pub async fn report_error(&mut self, message: impl Into<String>, details: impl Into<String>) {
269          self.tx
270              .send(ControlResp::Error {
271                  operator_id: self.task_info.operator_id.clone(),
272                  task_index: self.task_info.task_index,
273                  message: message.into(),
274                  details: details.into(),
275              })
276              .await
277              .unwrap();
278      }
279  }
280  
/// Routes an operator's output batches to its downstream queues, updating
/// per-queue metrics along the way.
#[derive(Clone)]
pub struct ArrowCollector {
    task_info: Arc<TaskInfo>,
    // schema of emitted batches; None for operators that emit nothing
    out_schema: Option<ArroyoSchema>,
    // optional column projection applied before emitting
    projection: Option<Vec<usize>>,
    // out_qs[node][partition] -> queue to that downstream subtask
    out_qs: Vec<Vec<BatchSender>>,
    tx_queue_rem_gauges: QueueGauges,
    tx_queue_size_gauges: QueueGauges,
    tx_queue_bytes_gauges: QueueGauges,
}
291  
292  fn repartition<'a>(
293      record: &'a RecordBatch,
294      keys: &'a Option<Vec<usize>>,
295      qs: usize,
296  ) -> impl Iterator<Item = (usize, RecordBatch)> + 'a {
297      let mut buf = vec![0; record.num_rows()];
298  
299      if let Some(keys) = keys {
300          let keys: Vec<_> = keys.iter().map(|i| record.column(*i).clone()).collect();
301  
302          hash_utils::create_hashes(&keys[..], &get_hasher(), &mut buf).unwrap();
303          let buf_array = PrimitiveArray::from(buf);
304  
305          let servers = server_for_hash_array(&buf_array, qs).unwrap();
306  
307          let indices = sort_to_indices(&servers, None, None).unwrap();
308          let columns = record
309              .columns()
310              .iter()
311              .map(|c| take(c, &indices, None).unwrap())
312              .collect();
313          let sorted = RecordBatch::try_new(record.schema(), columns).unwrap();
314          let sorted_keys = take(&servers, &indices, None).unwrap();
315  
316          let partition: arrow::compute::Partitions =
317              partition(vec![sorted_keys.clone()].as_slice()).unwrap();
318          let typed_keys: &PrimitiveArray<UInt64Type> = sorted_keys.as_any().downcast_ref().unwrap();
319          let result: Vec<_> = partition
320              .ranges()
321              .into_iter()
322              .map(|range| {
323                  let server_batch = sorted.slice(range.start, range.end - range.start);
324                  let server_id = typed_keys.value(range.start) as usize;
325                  (server_id, server_batch)
326              })
327              .collect();
328          result.into_iter()
329      } else {
330          let range_size = record.num_rows() / qs + 1;
331          let rotation = rand::thread_rng().gen_range(0..qs);
332          let result: Vec<_> = (0..qs)
333              .filter_map(|i| {
334                  let start = i * range_size;
335                  let end = (i + 1) * range_size;
336                  if start >= record.num_rows() {
337                      None
338                  } else {
339                      let server_batch = record.slice(start, end.min(record.num_rows()) - start);
340                      Some(((i + rotation) % qs, server_batch))
341                  }
342              })
343              .collect();
344          result.into_iter()
345      }
346  }
347  
348  impl ArrowCollector {
349      pub async fn collect(&mut self, record: RecordBatch) {
350          TaskCounters::MessagesSent
351              .for_task(&self.task_info, |c| c.inc_by(record.num_rows() as u64));
352          TaskCounters::BatchesSent.for_task(&self.task_info, |c| c.inc());
353          TaskCounters::BytesSent.for_task(&self.task_info, |c| {
354              c.inc_by(record.get_array_memory_size() as u64)
355          });
356  
357          let out_schema = self.out_schema.as_ref().unwrap();
358  
359          let record = if let Some(projection) = &self.projection {
360              record.project(projection).unwrap_or_else(|e| {
361                  panic!(
362                      "failed to project for operator {}: {}",
363                      self.task_info.operator_id, e
364                  )
365              })
366          } else {
367              record
368          };
369  
370          let record = RecordBatch::try_new(out_schema.schema.clone(), record.columns().to_vec())
371              .unwrap_or_else(|e| {
372                  panic!(
373                      "Data does not match expected schema for {}: {:?}. expected schema:\n{:#?}\n, actual schema:\n{:#?}",
374                      self.task_info.operator_id, e, out_schema.schema, record.schema()
375                  );
376              });
377  
378          for (i, out_q) in self.out_qs.iter_mut().enumerate() {
379              let partitions = repartition(&record, &out_schema.key_indices, out_q.len());
380  
381              for (partition, batch) in partitions {
382                  out_q[partition]
383                      .send(ArrowMessage::Data(batch))
384                      .await
385                      .unwrap();
386  
387                  self.tx_queue_rem_gauges[i][partition]
388                      .iter()
389                      .for_each(|g| g.set(out_q[partition].capacity() as i64));
390  
391                  self.tx_queue_size_gauges[i][partition]
392                      .iter()
393                      .for_each(|g| g.set(out_q[partition].size() as i64));
394  
395                  self.tx_queue_bytes_gauges[i][partition]
396                      .iter()
397                      .for_each(|g| g.set(out_q[partition].queued_bytes() as i64));
398              }
399          }
400      }
401  
402      pub async fn broadcast(&mut self, message: ArrowMessage) {
403          for out_node in &self.out_qs {
404              for q in out_node {
405                  q.send(message.clone()).await.unwrap_or_else(|e| {
406                      panic!(
407                          "failed to broadcast message <{:?}> for operator {}: {}",
408                          message, self.task_info.operator_id, e
409                      )
410                  });
411              }
412          }
413      }
414  }
415  
416  impl ArrowContext {
    /// Builds the per-task context, restoring watermark and table state from
    /// `restore_from` when the task is recovering from a checkpoint.
    ///
    /// Panics (via expect/unwrap) if checkpoint metadata is missing or the
    /// TableManager cannot be constructed.
    #[allow(clippy::too_many_arguments)]
    pub async fn new(
        task_info: TaskInfo,
        restore_from: Option<CheckpointMetadata>,
        control_rx: Receiver<ControlMessage>,
        control_tx: Sender<ControlResp>,
        input_partitions: usize,
        in_schemas: Vec<ArroyoSchema>,
        out_schema: Option<ArroyoSchema>,
        projection: Option<Vec<usize>>,
        out_qs: Vec<Vec<BatchSender>>,
        tables: HashMap<String, TableConfig>,
    ) -> Self {
        // On restore, load this operator's checkpoint metadata and recover the
        // minimum watermark it had reached at checkpoint time.
        let (watermark, metadata) = if let Some(metadata) = restore_from {
            let (watermark, operator_metadata) = {
                let metadata = StateBackend::load_operator_metadata(
                    &task_info.job_id,
                    &task_info.operator_id,
                    metadata.epoch,
                )
                .await
                .expect("lookup should succeed")
                .expect("require metadata");
                (
                    metadata
                        .operator_metadata
                        .as_ref()
                        .unwrap()
                        .min_watermark
                        .map(from_micros),
                    metadata,
                )
            };

            (watermark, Some(operator_metadata))
        } else {
            (None, None)
        };

        let tx_queue_size_gauges = register_queue_gauge(
            "arroyo_worker_tx_queue_size",
            "Size of a tx queue",
            &task_info,
            &out_qs,
            config().worker.queue_size as i64,
        );

        let tx_queue_rem_gauges = register_queue_gauge(
            "arroyo_worker_tx_queue_rem",
            "Remaining space in a tx queue",
            &task_info,
            &out_qs,
            config().worker.queue_size as i64,
        );

        let tx_queue_bytes_gauges = register_queue_gauge(
            "arroyo_worker_tx_bytes",
            "Number of bytes queued in a tx queue",
            &task_info,
            &out_qs,
            0,
        );

        let task_info = Arc::new(task_info);

        // initialize counters so that tasks that never produce data still report 0
        for m in TaskCounters::variants() {
            // just initialize it
            m.for_task(&task_info, |_| {});
        }

        let table_manager =
            TableManager::new(task_info.clone(), tables, control_tx.clone(), metadata)
                .await
                .expect("should be able to create TableManager");

        Self {
            task_info: task_info.clone(),
            control_rx,
            control_tx: control_tx.clone(),
            // every input partition starts at the restored watermark (if any)
            watermarks: WatermarkHolder::new(vec![
                watermark.map(Watermark::EventTime);
                input_partitions
            ]),
            in_schemas,
            out_schema: out_schema.clone(),
            collector: ArrowCollector {
                task_info: task_info.clone(),
                out_qs,
                tx_queue_rem_gauges,
                tx_queue_size_gauges,
                tx_queue_bytes_gauges,
                out_schema: out_schema.clone(),
                projection,
            },
            error_reporter: ErrorReporter {
                tx: control_tx,
                task_info,
            },
            // only operators with an output schema buffer rows locally
            buffer: out_schema.map(|t| ContextBuffer::new(t.schema)),
            error_rate_limiter: RateLimiter::new(),
            deserializer: None,
            buffered_error: None,
            table_manager,
        }
    }
523  
    /// The current combined watermark across all input partitions.
    pub fn watermark(&self) -> Option<Watermark> {
        self.watermarks.watermark()
    }

    /// The most recent non-idle (event-time) watermark seen.
    pub fn last_present_watermark(&self) -> Option<SystemTime> {
        self.watermarks.last_present_watermark()
    }
531  
    /// Flushes the local row buffer and the deserializer's internal buffer to
    /// the collector, then surfaces any error stashed by `broadcast()`.
    pub async fn flush_buffer(&mut self) -> Result<(), UserError> {
        // no out schema -> nothing is ever buffered
        if self.buffer.is_none() {
            return Ok(());
        }

        if self.buffer.as_ref().unwrap().size() > 0 {
            // finish() consumes the buffer, so take it and install a fresh one
            let buffer = self.buffer.take().unwrap();
            let batch = buffer.finish();
            self.collector.collect(batch).await;
            self.buffer = Some(ContextBuffer::new(
                self.out_schema.as_ref().map(|t| t.schema.clone()).unwrap(),
            ));
        }

        if let Some(deserializer) = self.deserializer.as_mut() {
            if let Some(buffer) = deserializer.flush_buffer() {
                match buffer {
                    Ok(batch) => {
                        self.collector.collect(batch).await;
                    }
                    Err(e) => {
                        // route through the bad-data policy (drop vs. fail)
                        self.collect_source_errors(vec![e]).await?;
                    }
                }
            }
        }

        // an error stashed by broadcast() is reported to the caller here
        if let Some(error) = self.buffered_error.take() {
            return Err(error);
        }

        Ok(())
    }
565  
566      pub async fn collect(&mut self, record: RecordBatch) {
567          self.collector.collect(record).await;
568      }
569  
570      pub fn should_flush(&self) -> bool {
571          self.buffer
572              .as_ref()
573              .map(|b| b.should_flush())
574              .unwrap_or(false)
575              || self
576                  .deserializer
577                  .as_ref()
578                  .map(|d| d.should_flush())
579                  .unwrap_or(false)
580      }
581  
582      pub async fn broadcast(&mut self, message: ArrowMessage) {
583          if let Err(e) = self.flush_buffer().await {
584              self.buffered_error.replace(e);
585          }
586          self.collector.broadcast(message).await;
587      }
588  
589      pub async fn report_error(&mut self, message: impl Into<String>, details: impl Into<String>) {
590          self.error_reporter.report_error(message, details).await;
591      }
592  
593      pub async fn report_user_error(&mut self, error: UserError) {
594          self.control_tx
595              .send(ControlResp::Error {
596                  operator_id: self.task_info.operator_id.clone(),
597                  task_index: self.task_info.task_index,
598                  message: error.name,
599                  details: error.details,
600              })
601              .await
602              .unwrap();
603      }
604  
605      pub async fn send_checkpoint_event(
606          &mut self,
607          barrier: CheckpointBarrier,
608          event_type: TaskCheckpointEventType,
609      ) {
610          // These messages are received by the engine control thread,
611          // which then sends a TaskCheckpointEventReq to the controller.
612          self.control_tx
613              .send(ControlResp::CheckpointEvent(arroyo_rpc::CheckpointEvent {
614                  checkpoint_epoch: barrier.epoch,
615                  operator_id: self.task_info.operator_id.clone(),
616                  subtask_index: self.task_info.task_index as u32,
617                  time: SystemTime::now(),
618                  event_type,
619              }))
620              .await
621              .unwrap();
622      }
623  
624      pub async fn load_compacted(&mut self, compaction: CompactionResult) {
625          //TODO: support compaction in the table manager
626          self.table_manager
627              .load_compacted(compaction)
628              .await
629              .expect("should be able to load compacted");
630      }
631  
    /// Initializes the deserializer for this task's input format.
    ///
    /// # Panics
    /// Panics if called more than once, or if the context has no out schema.
    pub fn initialize_deserializer(
        &mut self,
        format: Format,
        framing: Option<Framing>,
        bad_data: Option<BadData>,
    ) {
        if self.deserializer.is_some() {
            panic!("Deserialize already initialized");
        }

        self.deserializer = Some(ArrowDeserializer::new(
            format,
            self.out_schema.as_ref().expect("no out schema").clone(),
            framing,
            bad_data.unwrap_or_default(),
        ));
    }
649  
650      pub fn initialize_deserializer_with_resolver(
651          &mut self,
652          format: Format,
653          framing: Option<Framing>,
654          bad_data: Option<BadData>,
655          schema_resolver: Arc<dyn SchemaResolver + Sync>,
656      ) {
657          self.deserializer = Some(ArrowDeserializer::with_schema_resolver(
658              format,
659              framing,
660              self.out_schema.as_ref().expect("no out schema").clone(),
661              bad_data.unwrap_or_default(),
662              schema_resolver,
663          ));
664      }
665  
    /// Deserializes `msg` into the local row buffer, stamping rows with
    /// `time`; any per-record failures are handled according to the
    /// configured bad-data policy.
    ///
    /// # Panics
    /// Panics if the deserializer was not initialized or there is no out schema.
    pub async fn deserialize_slice(
        &mut self,
        msg: &[u8],
        time: SystemTime,
    ) -> Result<(), UserError> {
        let deserializer = self
            .deserializer
            .as_mut()
            .expect("deserializer not initialized!");
        let errors = deserializer
            .deserialize_slice(
                &mut self.buffer.as_mut().expect("no out schema").buffer,
                msg,
                time,
            )
            .await;
        self.collect_source_errors(errors).await?;

        Ok(())
    }
686  
    /// Handling errors and rate limiting error reporting.
    /// Considers the `bad_data` option to determine whether to drop or fail on bad data.
    ///
    /// Returns `Err` on the first fatal error (BadData::Fail or any
    /// SourceError::Other); with BadData::Drop, bad records are counted and
    /// reported (rate-limited) but processing continues.
    async fn collect_source_errors(&mut self, errors: Vec<SourceError>) -> Result<(), UserError> {
        let bad_data = self
            .deserializer
            .as_ref()
            .expect("deserializer not initialized")
            .bad_data();
        for error in errors {
            match error {
                SourceError::BadData { details } => match bad_data {
                    BadData::Drop {} => {
                        // log + report at a limited rate so a flood of bad
                        // records doesn't overwhelm the control channel
                        self.error_rate_limiter
                            .rate_limit(|| async {
                                warn!("Dropping invalid data: {}", details.clone());
                                self.control_tx
                                    .send(ControlResp::Error {
                                        operator_id: self.task_info.operator_id.clone(),
                                        task_index: self.task_info.task_index,
                                        message: "Dropping invalid data".to_string(),
                                        details,
                                    })
                                    .await
                                    .unwrap();
                            })
                            .await;
                        TaskCounters::DeserializationErrors.for_task(&self.task_info, |c| c.inc())
                    }
                    BadData::Fail {} => {
                        return Err(UserError::new("Deserialization error", details));
                    }
                },
                SourceError::Other { name, details } => {
                    return Err(UserError::new(name, details));
                }
            }
        }

        Ok(())
    }
727  }
728  
729  #[cfg(test)]
730  mod tests {
731      use arrow::array::{ArrayRef, Int64Array, TimestampNanosecondArray, UInt64Array};
732      use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
733      use arroyo_types::to_nanos;
734      use std::time::Duration;
735  
736      use super::*;
737  
    // Verifies that the combined watermark is None until all partitions have
    // reported, is the minimum event time otherwise, ignores idle partitions,
    // and becomes Idle only when every partition is idle.
    #[test]
    fn test_watermark_holder() {
        let t1 = SystemTime::UNIX_EPOCH;
        let t2 = t1 + Duration::from_secs(1);
        let t3 = t2 + Duration::from_secs(1);

        let mut w = WatermarkHolder::new(vec![None, None, None]);

        assert!(w.watermark().is_none());

        w.set(0, Watermark::EventTime(t1));
        w.set(1, Watermark::EventTime(t2));

        // one partition still unreported -> no combined watermark yet
        assert!(w.watermark().is_none());

        w.set(2, Watermark::EventTime(t3));

        // minimum across partitions wins
        assert_eq!(w.watermark(), Some(Watermark::EventTime(t1)));

        // idle partitions are skipped when computing the minimum
        w.set(0, Watermark::Idle);
        assert_eq!(w.watermark(), Some(Watermark::EventTime(t2)));

        w.set(1, Watermark::Idle);
        w.set(2, Watermark::Idle);
        assert_eq!(w.watermark(), Some(Watermark::Idle));
    }
764  
    // Verifies keyed repartitioning: rows with the same key hash must all be
    // routed to the same downstream queue.
    #[tokio::test]
    async fn test_shuffles() {
        let timestamp = SystemTime::now();

        // two distinct key values, interleaved
        let data = vec![0, 101, 0, 101, 0, 101, 0, 0];

        let columns: Vec<ArrayRef> = vec![
            Arc::new(UInt64Array::from(data.clone())),
            Arc::new(TimestampNanosecondArray::from(
                data.iter()
                    .map(|_| to_nanos(timestamp) as i64)
                    .collect::<Vec<_>>(),
            )),
        ];

        let schema = Arc::new(Schema::new(vec![
            Field::new("key", DataType::UInt64, false),
            Field::new(
                "time",
                DataType::Timestamp(TimeUnit::Nanosecond, None),
                false,
            ),
        ]));

        let (tx1, mut rx1) = batch_bounded(8);
        let (tx2, mut rx2) = batch_bounded(8);

        let record = RecordBatch::try_new(schema.clone(), columns).unwrap();

        let task_info = Arc::new(TaskInfo {
            job_id: "test-job".to_string(),
            operator_name: "test-operator".to_string(),
            operator_id: "test-operator-1".to_string(),
            task_index: 0,
            parallelism: 1,
            key_range: 0..=1,
        });

        // one downstream node with two partitions
        let out_qs = vec![vec![tx1, tx2]];

        let tx_queue_size_gauges = register_queue_gauge(
            "arroyo_worker_tx_queue_size",
            "Size of a tx queue",
            &task_info,
            &out_qs,
            0,
        );

        let tx_queue_rem_gauges = register_queue_gauge(
            "arroyo_worker_tx_queue_rem",
            "Remaining space in a tx queue",
            &task_info,
            &out_qs,
            0,
        );

        let tx_queue_bytes_gauges = register_queue_gauge(
            "arroyo_worker_tx_bytes",
            "Number of bytes queued in a tx queue",
            &task_info,
            &out_qs,
            0,
        );

        let mut collector = ArrowCollector {
            task_info,
            // key on column 0 so repartition hashes by "key"
            out_schema: Some(ArroyoSchema::new_keyed(schema, 1, vec![0])),
            projection: None,
            out_qs,
            tx_queue_rem_gauges,
            tx_queue_size_gauges,
            tx_queue_bytes_gauges,
        };

        collector.collect(record).await;

        // dropping the collector closes the senders so recv() terminates
        drop(collector);

        // pull all messages out of the two queues
        let mut q1 = vec![];
        while let Some(m) = rx1.recv().await {
            q1.push(m);
        }

        let mut q2 = vec![];
        while let Some(m) = rx2.recv().await {
            q2.push(m);
        }

        // each queue must have received only one distinct key's rows
        let v1 = &q1[0];
        for v in &q1[1..] {
            assert_eq!(v1, v);
        }

        let v2 = &q2[0];
        for v in &q2[1..] {
            assert_eq!(v2, v);
        }
    }
864  
    // Verifies row-based capacity accounting: two 4-row batches fill an
    // 8-row queue, and receiving them restores the full capacity.
    #[tokio::test]
    async fn test_batch_queues() {
        let (tx, mut rx) = batch_bounded(8);
        let msg = RecordBatch::try_new(
            Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)])),
            vec![Arc::new(Int64Array::from(vec![1, 2, 3, 4]))],
        )
        .unwrap();

        tx.send(ArrowMessage::Data(msg.clone())).await.unwrap();
        tx.send(ArrowMessage::Data(msg.clone())).await.unwrap();

        // 2 batches x 4 rows == the full capacity of 8
        assert_eq!(tx.capacity(), 0);

        rx.recv().await.unwrap();
        rx.recv().await.unwrap();

        assert_eq!(tx.capacity(), 8);
    }
882          assert_eq!(tx.capacity(), 8);
883      }
884  }