context.rs
1 use crate::{server_for_hash_array, RateLimiter}; 2 use arrow::array::{make_builder, Array, ArrayBuilder, PrimitiveArray, RecordBatch}; 3 use arrow::compute::{partition, sort_to_indices, take}; 4 use arrow::datatypes::{SchemaRef, UInt64Type}; 5 use arroyo_formats::de::ArrowDeserializer; 6 use arroyo_formats::should_flush; 7 use arroyo_metrics::{register_queue_gauge, QueueGauges, TaskCounters}; 8 use arroyo_rpc::config::config; 9 use arroyo_rpc::df::ArroyoSchema; 10 use arroyo_rpc::formats::{BadData, Format, Framing}; 11 use arroyo_rpc::grpc::rpc::{CheckpointMetadata, TableConfig, TaskCheckpointEventType}; 12 use arroyo_rpc::schema_resolver::SchemaResolver; 13 use arroyo_rpc::{get_hasher, CompactionResult, ControlMessage, ControlResp}; 14 use arroyo_state::tables::table_manager::TableManager; 15 use arroyo_state::{BackingStore, StateBackend}; 16 use arroyo_types::{ 17 from_micros, ArrowMessage, CheckpointBarrier, SourceError, TaskInfo, UserError, Watermark, 18 }; 19 use datafusion::common::hash_utils; 20 use rand::Rng; 21 use std::collections::HashMap; 22 use std::mem::size_of_val; 23 use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; 24 use std::sync::Arc; 25 use std::time::{Instant, SystemTime}; 26 use tokio::sync::mpsc::error::SendError; 27 use tokio::sync::mpsc::{unbounded_channel, Receiver, Sender, UnboundedReceiver, UnboundedSender}; 28 use tokio::sync::Notify; 29 use tracing::warn; 30 31 pub type QueueItem = ArrowMessage; 32 33 pub struct WatermarkHolder { 34 // This is the last watermark with an actual value; this helps us keep track of the watermark we're at even 35 // if we're currently idle 36 last_present_watermark: Option<SystemTime>, 37 cur_watermark: Option<Watermark>, 38 watermarks: Vec<Option<Watermark>>, 39 } 40 41 impl WatermarkHolder { 42 pub fn new(watermarks: Vec<Option<Watermark>>) -> Self { 43 let mut s = Self { 44 last_present_watermark: None, 45 cur_watermark: None, 46 watermarks, 47 }; 48 s.update_watermark(); 49 50 s 51 } 52 53 pub fn watermark(&self) -> Option<Watermark> { 54 self.cur_watermark 55 } 56 57 pub fn last_present_watermark(&self) -> Option<SystemTime> { 58 self.last_present_watermark 59 } 60 61 fn update_watermark(&mut self) { 62 self.cur_watermark = 63 self.watermarks 64 .iter() 65 .try_fold(Watermark::Idle, |current, next| match (current, (*next)?) { 66 (Watermark::EventTime(cur), Watermark::EventTime(next)) => { 67 Some(Watermark::EventTime(cur.min(next))) 68 } 69 (Watermark::Idle, Watermark::EventTime(t)) 70 | (Watermark::EventTime(t), Watermark::Idle) => Some(Watermark::EventTime(t)), 71 (Watermark::Idle, Watermark::Idle) => Some(Watermark::Idle), 72 }); 73 74 if let Some(Watermark::EventTime(t)) = self.cur_watermark { 75 self.last_present_watermark = Some(t); 76 } 77 } 78 79 pub fn set(&mut self, idx: usize, watermark: Watermark) -> Option<Option<Watermark>> { 80 *(self.watermarks.get_mut(idx)?) = Some(watermark); 81 self.update_watermark(); 82 Some(self.cur_watermark) 83 } 84 } 85 86 /// A wrapper for an UnboundedSender<QueueItem> that bounds by the number of rows within 87 /// a batch rather than the number of batches 88 #[derive(Clone)] 89 pub struct BatchSender { 90 size: u32, 91 tx: UnboundedSender<QueueItem>, 92 queued_messages: Arc<AtomicU32>, 93 queued_bytes: Arc<AtomicU64>, 94 notify: Arc<Notify>, 95 } 96 97 #[inline] 98 fn message_count(item: &QueueItem, size: u32) -> u32 { 99 match item { 100 QueueItem::Data(d) => (d.num_rows() as u32).min(size), 101 QueueItem::Signal(_) => 1, 102 } 103 } 104 105 #[inline] 106 fn message_bytes(item: &QueueItem) -> u64 { 107 match item { 108 QueueItem::Data(d) => d.get_array_memory_size() as u64, 109 QueueItem::Signal(s) => size_of_val(s) as u64, 110 } 111 } 112 113 impl BatchSender { 114 pub async fn send(&self, item: QueueItem) -> Result<(), SendError<QueueItem>> { 115 // Ensure that every message is sendable, even if it's bigger than our max size 116 let count = message_count(&item, self.size); 117 loop { 118 if self.tx.is_closed() { 119 return Err(SendError(item)); 120 } 121 122 let cur = self.queued_messages.load(Ordering::Acquire); 123 if cur as usize + count as usize <= self.size as usize { 124 match self.queued_messages.compare_exchange( 125 cur, 126 cur + count, 127 Ordering::SeqCst, 128 Ordering::SeqCst, 129 ) { 130 Ok(_) => { 131 self.queued_bytes 132 .fetch_add(message_bytes(&item), Ordering::AcqRel); 133 return self.tx.send(item); 134 } 135 Err(_) => { 136 // try again 137 continue; 138 } 139 } 140 } else { 141 // not enough room in the queue, wait to be notified that the receiver has 142 // consumed 143 self.notify.notified().await; 144 } 145 } 146 } 147 148 pub fn capacity(&self) -> u32 { 149 self.size 150 .saturating_sub(self.queued_messages.load(Ordering::Relaxed)) 151 } 152 153 pub fn queued_bytes(&self) -> u64 { 154 self.queued_bytes.load(Ordering::Relaxed) 155 } 156 157 pub fn size(&self) -> u32 { 158 self.size 159 } 160 } 161 162 pub struct BatchReceiver { 163 size: u32, 164 rx: UnboundedReceiver<QueueItem>, 165 queued_messages: Arc<AtomicU32>, 166 queued_bytes: Arc<AtomicU64>, 167 notify: Arc<Notify>, 168 } 169 170 impl BatchReceiver { 171 pub async fn recv(&mut self) -> Option<QueueItem> { 172 let item = self.rx.recv().await; 173 if let Some(item) = &item { 174 let count = message_count(item, self.size); 175 self.queued_messages.fetch_sub(count, Ordering::SeqCst); 176 self.queued_bytes 177 .fetch_sub(message_bytes(item), Ordering::AcqRel); 178 self.notify.notify_waiters(); 179 } 180 item 181 } 182 } 183 184 pub fn batch_bounded(size: u32) -> (BatchSender, BatchReceiver) { 185 let (tx, rx) = unbounded_channel(); 186 let notify = Arc::new(Notify::new()); 187 let queued_messages = Arc::new(AtomicU32::new(0)); 188 let queued_bytes = Arc::new(AtomicU64::new(0)); 189 ( 190 BatchSender { 191 size, 192 tx, 193 queued_messages: queued_messages.clone(), 194 queued_bytes: queued_bytes.clone(), 195 notify: notify.clone(), 196 }, 197 BatchReceiver { 198 size, 199 rx, 200 notify, 201 queued_bytes, 202 queued_messages, 203 }, 204 ) 205 } 206 207 struct ContextBuffer { 208 buffer: Vec<Box<dyn ArrayBuilder>>, 209 created: Instant, 210 schema: SchemaRef, 211 } 212 213 impl ContextBuffer { 214 fn new(schema: SchemaRef) -> Self { 215 let buffer = schema 216 .fields 217 .iter() 218 .map(|f| make_builder(f.data_type(), 16)) 219 .collect(); 220 221 Self { 222 buffer, 223 created: Instant::now(), 224 schema, 225 } 226 } 227 228 pub fn size(&self) -> usize { 229 self.buffer[0].len() 230 } 231 232 pub fn should_flush(&self) -> bool { 233 should_flush(self.size(), self.created) 234 } 235 236 pub fn finish(self) -> RecordBatch { 237 RecordBatch::try_new( 238 self.schema, 239 self.buffer.into_iter().map(|mut a| a.finish()).collect(), 240 ) 241 .unwrap() 242 } 243 } 244 245 pub struct ArrowContext { 246 pub task_info: Arc<TaskInfo>, 247 pub control_rx: Receiver<ControlMessage>, 248 pub control_tx: Sender<ControlResp>, 249 pub error_reporter: ErrorReporter, 250 pub watermarks: WatermarkHolder, 251 pub in_schemas: Vec<ArroyoSchema>, 252 pub out_schema: Option<ArroyoSchema>, 253 pub collector: ArrowCollector, 254 buffer: Option<ContextBuffer>, 255 buffered_error: Option<UserError>, 256 error_rate_limiter: RateLimiter, 257 deserializer: Option<ArrowDeserializer>, 258 pub table_manager: TableManager, 259 } 260 261 #[derive(Clone)] 262 pub struct ErrorReporter { 263 pub tx: Sender<ControlResp>, 264 pub task_info: Arc<TaskInfo>, 265 } 266 267 impl ErrorReporter { 268 pub async fn report_error(&mut self, message: impl Into<String>, details: impl Into<String>) { 269 self.tx 270 .send(ControlResp::Error { 271 operator_id: self.task_info.operator_id.clone(), 272 task_index: self.task_info.task_index, 273 message: message.into(), 274 details: details.into(), 275 }) 276 .await 277 .unwrap(); 278 } 279 } 280 281 #[derive(Clone)] 282 pub struct ArrowCollector { 283 task_info: Arc<TaskInfo>, 284 out_schema: Option<ArroyoSchema>, 285 projection: Option<Vec<usize>>, 286 out_qs: Vec<Vec<BatchSender>>, 287 tx_queue_rem_gauges: QueueGauges, 288 tx_queue_size_gauges: QueueGauges, 289 tx_queue_bytes_gauges: QueueGauges, 290 } 291 292 fn repartition<'a>( 293 record: &'a RecordBatch, 294 keys: &'a Option<Vec<usize>>, 295 qs: usize, 296 ) -> impl Iterator<Item = (usize, RecordBatch)> + 'a { 297 let mut buf = vec![0; record.num_rows()]; 298 299 if let Some(keys) = keys { 300 let keys: Vec<_> = keys.iter().map(|i| record.column(*i).clone()).collect(); 301 302 hash_utils::create_hashes(&keys[..], &get_hasher(), &mut buf).unwrap(); 303 let buf_array = PrimitiveArray::from(buf); 304 305 let servers = server_for_hash_array(&buf_array, qs).unwrap(); 306 307 let indices = sort_to_indices(&servers, None, None).unwrap(); 308 let columns = record 309 .columns() 310 .iter() 311 .map(|c| take(c, &indices, None).unwrap()) 312 .collect(); 313 let sorted = RecordBatch::try_new(record.schema(), columns).unwrap(); 314 let sorted_keys = take(&servers, &indices, None).unwrap(); 315 316 let partition: arrow::compute::Partitions = 317 partition(vec![sorted_keys.clone()].as_slice()).unwrap(); 318 let typed_keys: &PrimitiveArray<UInt64Type> = sorted_keys.as_any().downcast_ref().unwrap(); 319 let result: Vec<_> = partition 320 .ranges() 321 .into_iter() 322 .map(|range| { 323 let server_batch = sorted.slice(range.start, range.end - range.start); 324 let server_id = typed_keys.value(range.start) as usize; 325 (server_id, server_batch) 326 }) 327 .collect(); 328 result.into_iter() 329 } else { 330 let range_size = record.num_rows() / qs + 1; 331 let rotation = rand::thread_rng().gen_range(0..qs); 332 let result: Vec<_> = (0..qs) 333 .filter_map(|i| { 334 let start = i * range_size; 335 let end = (i + 1) * range_size; 336 if start >= record.num_rows() { 337 None 338 } else { 339 let server_batch = record.slice(start, end.min(record.num_rows()) - start); 340 Some(((i + rotation) % qs, server_batch)) 341 } 342 }) 343 .collect(); 344 result.into_iter() 345 } 346 } 347 348 impl ArrowCollector { 349 pub async fn collect(&mut self, record: RecordBatch) { 350 TaskCounters::MessagesSent 351 .for_task(&self.task_info, |c| c.inc_by(record.num_rows() as u64)); 352 TaskCounters::BatchesSent.for_task(&self.task_info, |c| c.inc()); 353 TaskCounters::BytesSent.for_task(&self.task_info, |c| { 354 c.inc_by(record.get_array_memory_size() as u64) 355 }); 356 357 let out_schema = self.out_schema.as_ref().unwrap(); 358 359 let record = if let Some(projection) = &self.projection { 360 record.project(projection).unwrap_or_else(|e| { 361 panic!( 362 "failed to project for operator {}: {}", 363 self.task_info.operator_id, e 364 ) 365 }) 366 } else { 367 record 368 }; 369 370 let record = RecordBatch::try_new(out_schema.schema.clone(), record.columns().to_vec()) 371 .unwrap_or_else(|e| { 372 panic!( 373 "Data does not match expected schema for {}: {:?}. expected schema:\n{:#?}\n, actual schema:\n{:#?}", 374 self.task_info.operator_id, e, out_schema.schema, record.schema() 375 ); 376 }); 377 378 for (i, out_q) in self.out_qs.iter_mut().enumerate() { 379 let partitions = repartition(&record, &out_schema.key_indices, out_q.len()); 380 381 for (partition, batch) in partitions { 382 out_q[partition] 383 .send(ArrowMessage::Data(batch)) 384 .await 385 .unwrap(); 386 387 self.tx_queue_rem_gauges[i][partition] 388 .iter() 389 .for_each(|g| g.set(out_q[partition].capacity() as i64)); 390 391 self.tx_queue_size_gauges[i][partition] 392 .iter() 393 .for_each(|g| g.set(out_q[partition].size() as i64)); 394 395 self.tx_queue_bytes_gauges[i][partition] 396 .iter() 397 .for_each(|g| g.set(out_q[partition].queued_bytes() as i64)); 398 } 399 } 400 } 401 402 pub async fn broadcast(&mut self, message: ArrowMessage) { 403 for out_node in &self.out_qs { 404 for q in out_node { 405 q.send(message.clone()).await.unwrap_or_else(|e| { 406 panic!( 407 "failed to broadcast message <{:?}> for operator {}: {}", 408 message, self.task_info.operator_id, e 409 ) 410 }); 411 } 412 } 413 } 414 } 415 416 impl ArrowContext { 417 #[allow(clippy::too_many_arguments)] 418 pub async fn new( 419 task_info: TaskInfo, 420 restore_from: Option<CheckpointMetadata>, 421 control_rx: Receiver<ControlMessage>, 422 control_tx: Sender<ControlResp>, 423 input_partitions: usize, 424 in_schemas: Vec<ArroyoSchema>, 425 out_schema: Option<ArroyoSchema>, 426 projection: Option<Vec<usize>>, 427 out_qs: Vec<Vec<BatchSender>>, 428 tables: HashMap<String, TableConfig>, 429 ) -> Self { 430 let (watermark, metadata) = if let Some(metadata) = restore_from { 431 let (watermark, operator_metadata) = { 432 let metadata = StateBackend::load_operator_metadata( 433 &task_info.job_id, 434 &task_info.operator_id, 435 metadata.epoch, 436 ) 437 .await 438 .expect("lookup should succeed") 439 .expect("require metadata"); 440 ( 441 metadata 442 .operator_metadata 443 .as_ref() 444 .unwrap() 445 .min_watermark 446 .map(from_micros), 447 metadata, 448 ) 449 }; 450 451 (watermark, Some(operator_metadata)) 452 } else { 453 (None, None) 454 }; 455 456 let tx_queue_size_gauges = register_queue_gauge( 457 "arroyo_worker_tx_queue_size", 458 "Size of a tx queue", 459 &task_info, 460 &out_qs, 461 config().worker.queue_size as i64, 462 ); 463 464 let tx_queue_rem_gauges = register_queue_gauge( 465 "arroyo_worker_tx_queue_rem", 466 "Remaining space in a tx queue", 467 &task_info, 468 &out_qs, 469 config().worker.queue_size as i64, 470 ); 471 472 let tx_queue_bytes_gauges = register_queue_gauge( 473 "arroyo_worker_tx_bytes", 474 "Number of bytes queued in a tx queue", 475 &task_info, 476 &out_qs, 477 0, 478 ); 479 480 let task_info = Arc::new(task_info); 481 482 // initialize counters so that tasks that never produce data still report 0 483 for m in TaskCounters::variants() { 484 // just initialize it 485 m.for_task(&task_info, |_| {}); 486 } 487 488 let table_manager = 489 TableManager::new(task_info.clone(), tables, control_tx.clone(), metadata) 490 .await 491 .expect("should be able to create TableManager"); 492 493 Self { 494 task_info: task_info.clone(), 495 control_rx, 496 control_tx: control_tx.clone(), 497 watermarks: WatermarkHolder::new(vec![ 498 watermark.map(Watermark::EventTime); 499 input_partitions 500 ]), 501 in_schemas, 502 out_schema: out_schema.clone(), 503 collector: ArrowCollector { 504 task_info: task_info.clone(), 505 out_qs, 506 tx_queue_rem_gauges, 507 tx_queue_size_gauges, 508 tx_queue_bytes_gauges, 509 out_schema: out_schema.clone(), 510 projection, 511 }, 512 error_reporter: ErrorReporter { 513 tx: control_tx, 514 task_info, 515 }, 516 buffer: out_schema.map(|t| ContextBuffer::new(t.schema)), 517 error_rate_limiter: RateLimiter::new(), 518 deserializer: None, 519 buffered_error: None, 520 table_manager, 521 } 522 } 523 524 pub fn watermark(&self) -> Option<Watermark> { 525 self.watermarks.watermark() 526 } 527 528 pub fn last_present_watermark(&self) -> Option<SystemTime> { 529 self.watermarks.last_present_watermark() 530 } 531 532 pub async fn flush_buffer(&mut self) -> Result<(), UserError> { 533 if self.buffer.is_none() { 534 return Ok(()); 535 } 536 537 if self.buffer.as_ref().unwrap().size() > 0 { 538 let buffer = self.buffer.take().unwrap(); 539 let batch = buffer.finish(); 540 self.collector.collect(batch).await; 541 self.buffer = Some(ContextBuffer::new( 542 self.out_schema.as_ref().map(|t| t.schema.clone()).unwrap(), 543 )); 544 } 545 546 if let Some(deserializer) = self.deserializer.as_mut() { 547 if let Some(buffer) = deserializer.flush_buffer() { 548 match buffer { 549 Ok(batch) => { 550 self.collector.collect(batch).await; 551 } 552 Err(e) => { 553 self.collect_source_errors(vec![e]).await?; 554 } 555 } 556 } 557 } 558 559 if let Some(error) = self.buffered_error.take() { 560 return Err(error); 561 } 562 563 Ok(()) 564 } 565 566 pub async fn collect(&mut self, record: RecordBatch) { 567 self.collector.collect(record).await; 568 } 569 570 pub fn should_flush(&self) -> bool { 571 self.buffer 572 .as_ref() 573 .map(|b| b.should_flush()) 574 .unwrap_or(false) 575 || self 576 .deserializer 577 .as_ref() 578 .map(|d| d.should_flush()) 579 .unwrap_or(false) 580 } 581 582 pub async fn broadcast(&mut self, message: ArrowMessage) { 583 if let Err(e) = self.flush_buffer().await { 584 self.buffered_error.replace(e); 585 } 586 self.collector.broadcast(message).await; 587 } 588 589 pub async fn report_error(&mut self, message: impl Into<String>, details: impl Into<String>) { 590 self.error_reporter.report_error(message, details).await; 591 } 592 593 pub async fn report_user_error(&mut self, error: UserError) { 594 self.control_tx 595 .send(ControlResp::Error { 596 operator_id: self.task_info.operator_id.clone(), 597 task_index: self.task_info.task_index, 598 message: error.name, 599 details: error.details, 600 }) 601 .await 602 .unwrap(); 603 } 604 605 pub async fn send_checkpoint_event( 606 &mut self, 607 barrier: CheckpointBarrier, 608 event_type: TaskCheckpointEventType, 609 ) { 610 // These messages are received by the engine control thread, 611 // which then sends a TaskCheckpointEventReq to the controller. 612 self.control_tx 613 .send(ControlResp::CheckpointEvent(arroyo_rpc::CheckpointEvent { 614 checkpoint_epoch: barrier.epoch, 615 operator_id: self.task_info.operator_id.clone(), 616 subtask_index: self.task_info.task_index as u32, 617 time: SystemTime::now(), 618 event_type, 619 })) 620 .await 621 .unwrap(); 622 } 623 624 pub async fn load_compacted(&mut self, compaction: CompactionResult) { 625 //TODO: support compaction in the table manager 626 self.table_manager 627 .load_compacted(compaction) 628 .await 629 .expect("should be able to load compacted"); 630 } 631 632 pub fn initialize_deserializer( 633 &mut self, 634 format: Format, 635 framing: Option<Framing>, 636 bad_data: Option<BadData>, 637 ) { 638 if self.deserializer.is_some() { 639 panic!("Deserialize already initialized"); 640 } 641 642 self.deserializer = Some(ArrowDeserializer::new( 643 format, 644 self.out_schema.as_ref().expect("no out schema").clone(), 645 framing, 646 bad_data.unwrap_or_default(), 647 )); 648 } 649 650 pub fn initialize_deserializer_with_resolver( 651 &mut self, 652 format: Format, 653 framing: Option<Framing>, 654 bad_data: Option<BadData>, 655 schema_resolver: Arc<dyn SchemaResolver + Sync>, 656 ) { 657 self.deserializer = Some(ArrowDeserializer::with_schema_resolver( 658 format, 659 framing, 660 self.out_schema.as_ref().expect("no out schema").clone(), 661 bad_data.unwrap_or_default(), 662 schema_resolver, 663 )); 664 } 665 666 pub async fn deserialize_slice( 667 &mut self, 668 msg: &[u8], 669 time: SystemTime, 670 ) -> Result<(), UserError> { 671 let deserializer = self 672 .deserializer 673 .as_mut() 674 .expect("deserializer not initialized!"); 675 let errors = deserializer 676 .deserialize_slice( 677 &mut self.buffer.as_mut().expect("no out schema").buffer, 678 msg, 679 time, 680 ) 681 .await; 682 self.collect_source_errors(errors).await?; 683 684 Ok(()) 685 } 686 687 /// Handling errors and rate limiting error reporting. 688 /// Considers the `bad_data` option to determine whether to drop or fail on bad data. 689 async fn collect_source_errors(&mut self, errors: Vec<SourceError>) -> Result<(), UserError> { 690 let bad_data = self 691 .deserializer 692 .as_ref() 693 .expect("deserializer not initialized") 694 .bad_data(); 695 for error in errors { 696 match error { 697 SourceError::BadData { details } => match bad_data { 698 BadData::Drop {} => { 699 self.error_rate_limiter 700 .rate_limit(|| async { 701 warn!("Dropping invalid data: {}", details.clone()); 702 self.control_tx 703 .send(ControlResp::Error { 704 operator_id: self.task_info.operator_id.clone(), 705 task_index: self.task_info.task_index, 706 message: "Dropping invalid data".to_string(), 707 details, 708 }) 709 .await 710 .unwrap(); 711 }) 712 .await; 713 TaskCounters::DeserializationErrors.for_task(&self.task_info, |c| c.inc()) 714 } 715 BadData::Fail {} => { 716 return Err(UserError::new("Deserialization error", details)); 717 } 718 }, 719 SourceError::Other { name, details } => { 720 return Err(UserError::new(name, details)); 721 } 722 } 723 } 724 725 Ok(()) 726 } 727 } 728 729 #[cfg(test)] 730 mod tests { 731 use arrow::array::{ArrayRef, Int64Array, TimestampNanosecondArray, UInt64Array}; 732 use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; 733 use arroyo_types::to_nanos; 734 use std::time::Duration; 735 736 use super::*; 737 738 #[test] 739 fn test_watermark_holder() { 740 let t1 = SystemTime::UNIX_EPOCH; 741 let t2 = t1 + Duration::from_secs(1); 742 let t3 = t2 + Duration::from_secs(1); 743 744 let mut w = WatermarkHolder::new(vec![None, None, None]); 745 746 assert!(w.watermark().is_none()); 747 748 w.set(0, Watermark::EventTime(t1)); 749 w.set(1, Watermark::EventTime(t2)); 750 751 assert!(w.watermark().is_none()); 752 753 w.set(2, Watermark::EventTime(t3)); 754 755 assert_eq!(w.watermark(), Some(Watermark::EventTime(t1))); 756 757 w.set(0, Watermark::Idle); 758 assert_eq!(w.watermark(), Some(Watermark::EventTime(t2))); 759 760 w.set(1, Watermark::Idle); 761 w.set(2, Watermark::Idle); 762 assert_eq!(w.watermark(), Some(Watermark::Idle)); 763 } 764 765 #[tokio::test] 766 async fn test_shuffles() { 767 let timestamp = SystemTime::now(); 768 769 let data = vec![0, 101, 0, 101, 0, 101, 0, 0]; 770 771 let columns: Vec<ArrayRef> = vec![ 772 Arc::new(UInt64Array::from(data.clone())), 773 Arc::new(TimestampNanosecondArray::from( 774 data.iter() 775 .map(|_| to_nanos(timestamp) as i64) 776 .collect::<Vec<_>>(), 777 )), 778 ]; 779 780 let schema = Arc::new(Schema::new(vec![ 781 Field::new("key", DataType::UInt64, false), 782 Field::new( 783 "time", 784 DataType::Timestamp(TimeUnit::Nanosecond, None), 785 false, 786 ), 787 ])); 788 789 let (tx1, mut rx1) = batch_bounded(8); 790 let (tx2, mut rx2) = batch_bounded(8); 791 792 let record = RecordBatch::try_new(schema.clone(), columns).unwrap(); 793 794 let task_info = Arc::new(TaskInfo { 795 job_id: "test-job".to_string(), 796 operator_name: "test-operator".to_string(), 797 operator_id: "test-operator-1".to_string(), 798 task_index: 0, 799 parallelism: 1, 800 key_range: 0..=1, 801 }); 802 803 let out_qs = vec![vec![tx1, tx2]]; 804 805 let tx_queue_size_gauges = register_queue_gauge( 806 "arroyo_worker_tx_queue_size", 807 "Size of a tx queue", 808 &task_info, 809 &out_qs, 810 0, 811 ); 812 813 let tx_queue_rem_gauges = register_queue_gauge( 814 "arroyo_worker_tx_queue_rem", 815 "Remaining space in a tx queue", 816 &task_info, 817 &out_qs, 818 0, 819 ); 820 821 let tx_queue_bytes_gauges = register_queue_gauge( 822 "arroyo_worker_tx_bytes", 823 "Number of bytes queued in a tx queue", 824 &task_info, 825 &out_qs, 826 0, 827 ); 828 829 let mut collector = ArrowCollector { 830 task_info, 831 out_schema: Some(ArroyoSchema::new_keyed(schema, 1, vec![0])), 832 projection: None, 833 out_qs, 834 tx_queue_rem_gauges, 835 tx_queue_size_gauges, 836 tx_queue_bytes_gauges, 837 }; 838 839 collector.collect(record).await; 840 841 drop(collector); 842 843 // pull all messages out of the two queues 844 let mut q1 = vec![]; 845 while let Some(m) = rx1.recv().await { 846 q1.push(m); 847 } 848 849 let mut q2 = vec![]; 850 while let Some(m) = rx2.recv().await { 851 q2.push(m); 852 } 853 854 let v1 = &q1[0]; 855 for v in &q1[1..] { 856 assert_eq!(v1, v); 857 } 858 859 let v2 = &q2[0]; 860 for v in &q2[1..] { 861 assert_eq!(v2, v); 862 } 863 } 864 865 #[tokio::test] 866 async fn test_batch_queues() { 867 let (tx, mut rx) = batch_bounded(8); 868 let msg = RecordBatch::try_new( 869 Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)])), 870 vec![Arc::new(Int64Array::from(vec![1, 2, 3, 4]))], 871 ) 872 .unwrap(); 873 874 tx.send(ArrowMessage::Data(msg.clone())).await.unwrap(); 875 tx.send(ArrowMessage::Data(msg.clone())).await.unwrap(); 876 877 assert_eq!(tx.capacity(), 0); 878 879 rx.recv().await.unwrap(); 880 rx.recv().await.unwrap(); 881 882 assert_eq!(tx.capacity(), 8); 883 } 884 }