mod.rs
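//! Job controller: tracks the live state of a running pipeline (workers, tasks, and
//! checkpoints) and drives checkpointing, committing, compaction, and cleanup.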
use std::str::FromStr;
use std::sync::Arc;
use std::{
    collections::HashMap,
    time::{Duration, Instant, SystemTime},
};

use crate::types::public::StopMode as SqlStopMode;
use anyhow::bail;
use arroyo_rpc::grpc::rpc::{
    worker_grpc_client::WorkerGrpcClient, CheckpointReq, CommitReq, JobFinishedReq, LabelPair,
    LoadCompactedDataReq, MetricsReq, StopExecutionReq, StopMode, TaskCheckpointEventType,
};
use arroyo_state::{BackingStore, StateBackend};
use arroyo_types::{to_micros, WorkerId};
use cornucopia_async::DatabaseSource;
use rand::{thread_rng, Rng};

use time::OffsetDateTime;

use crate::job_controller::job_metrics::{get_metric_name, JobMetrics};
use crate::types::public::CheckpointState as DbCheckpointState;
use crate::{queries::controller_queries, JobConfig, JobMessage, RunningMessage};
use arroyo_datastream::logical::LogicalProgram;
use arroyo_rpc::api_types::metrics::MetricName;
use arroyo_rpc::config::config;
use arroyo_rpc::notify_db;
use arroyo_rpc::public_ids::{generate_id, IdTypes};
use arroyo_state::checkpoint_state::CheckpointState;
use arroyo_state::committing_state::CommittingState;
use arroyo_state::parquet::ParquetBackend;
use tokio::{sync::mpsc::Receiver, task::JoinHandle};
use tonic::{transport::Channel, Request};
use tracing::{debug, error, info, warn};

use self::checkpointer::CheckpointingOrCommittingState;

mod checkpointer;
pub mod job_metrics;

const CHECKPOINTS_TO_KEEP: u32 = 4;
const CHECKPOINT_ROWS_TO_KEEP: u32 = 100;
const COMPACT_EVERY: u32 = 2;

#[derive(Debug, PartialEq, Eq)]
pub enum WorkerState {
    Running,
    Stopped,
}

#[allow(unused)]
pub struct WorkerStatus {
    id: WorkerId,
    connect: WorkerGrpcClient<Channel>,
    last_heartbeat: Instant,
    state: WorkerState,
}

impl WorkerStatus {
    fn heartbeat_timeout(&self) -> bool {
        self.last_heartbeat.elapsed() > *config().pipeline.worker_heartbeat_timeout
    }
}

#[derive(Debug, PartialEq, Eq)]
pub enum TaskState {
    Running,
    Finished,
    Failed(String),
}

#[derive(Debug)]
pub struct TaskStatus {
    state: TaskState,
}

// Stores a model of the current state of a running job to use in the state machine
#[derive(Debug, PartialEq, Eq)]
pub enum JobState {
    Running,
    Stopped,
}

pub struct RunningJobModel {
    job_id: Arc<String>,
    state: JobState,
    program: Arc<LogicalProgram>,
    checkpoint_state: Option<CheckpointingOrCommittingState>,
    epoch: u32,
    min_epoch: u32,
    last_checkpoint: Instant,
    workers: HashMap<WorkerId, WorkerStatus>,
    tasks: HashMap<(String, u32), TaskStatus>,
    operator_parallelism: HashMap<String, usize>,
    metrics: JobMetrics,
    metric_update_task: Option<JoinHandle<()>>,
    last_updated_metrics: Instant,
}

impl std::fmt::Debug for RunningJobModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RunningJobModel")
            .field("job_id", &self.job_id)
            .field("state", &self.state)
            .field("checkpointing", &self.checkpoint_state.is_some())
            .field("epoch", &self.epoch)
            .field("min_epoch", &self.min_epoch)
            .field("last_checkpoint", &self.last_checkpoint)
            .finish()
    }
}

impl RunningJobModel {
    pub async fn update_db(
        checkpoint_state: &CheckpointState,
        db: &DatabaseSource,
    ) -> anyhow::Result<()> {
        let c = db.client().await?;

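        // write the latest per-operator progress to the checkpoint row, leaving the
        // checkpoint marked in-progress (no finish time yet)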
        controller_queries::execute_update_checkpoint(
            &c,
            &serde_json::to_value(&checkpoint_state.operator_details).unwrap(),
            &None,
            &DbCheckpointState::inprogress,
            &checkpoint_state.checkpoint_id(),
        )
        .await?;

        Ok(())
    }

    pub async fn update_checkpoint_in_db(
        checkpoint_state: &CheckpointState,
        db: &DatabaseSource,
        db_checkpoint_state: DbCheckpointState,
    ) -> anyhow::Result<()> {
        let c = db.client().await?;
        let finish_time = if db_checkpoint_state == DbCheckpointState::ready {
            Some(SystemTime::now().into())
        } else {
            None
        };
        let operator_state = serde_json::to_value(&checkpoint_state.operator_details).unwrap();
        controller_queries::execute_update_checkpoint(
            &c,
            &operator_state,
            &finish_time,
            &db_checkpoint_state,
            &checkpoint_state.checkpoint_id(),
        )
        .await?;

        Ok(())
    }

    pub async fn finish_committing(checkpoint_id: &str, db: &DatabaseSource) -> anyhow::Result<()> {
        info!("finishing committing");
        let finish_time = SystemTime::now();

        let c = db.client().await?;
        controller_queries::execute_commit_checkpoint(&c, &finish_time.into(), &checkpoint_id)
            .await?;

        Ok(())
    }

    pub async fn handle_message(
        &mut self,
        msg: RunningMessage,
        db: &DatabaseSource,
    ) -> anyhow::Result<()> {
        match msg {
            RunningMessage::TaskCheckpointEvent(c) => {
                if let Some(checkpoint_state) = &mut self.checkpoint_state {
                    if c.epoch != self.epoch {
                        warn!(
                            message = "Received checkpoint event for wrong epoch",
                            epoch = c.epoch,
                            expected = self.epoch,
                            job_id = *self.job_id,
                        );
                    } else {
                        match checkpoint_state {
                            CheckpointingOrCommittingState::Checkpointing(checkpoint_state) => {
                                checkpoint_state.checkpoint_event(c)?;
                                Self::update_db(checkpoint_state, db).await?
                            }
                            CheckpointingOrCommittingState::Committing(committing_state) => {
                                if matches!(c.event_type(), TaskCheckpointEventType::FinishedCommit)
                                {
                                    committing_state
                                        .subtask_committed(c.operator_id.clone(), c.subtask_index);
                                    self.compact_state().await?;
                                } else {
                                    warn!("unexpected checkpoint event type {:?}", c.event_type())
                                }
                            }
                        };
                    }
                } else {
                    debug!(
                        message = "Received checkpoint event but not checkpointing",
                        job_id = *self.job_id,
                        event = format!("{:?}", c)
                    )
                }
            }
            RunningMessage::TaskCheckpointFinished(c) => {
                if let Some(checkpoint_state) = &mut self.checkpoint_state {
                    if c.epoch != self.epoch {
                        warn!(
                            message = "Received checkpoint finished for wrong epoch",
                            epoch = c.epoch,
                            expected = self.epoch,
                            job_id = *self.job_id,
                        );
                    } else {
                        let CheckpointingOrCommittingState::Checkpointing(checkpoint_state) =
                            checkpoint_state
                        else {
                            bail!("Received checkpoint finished but not checkpointing");
                        };
                        checkpoint_state.checkpoint_finished(c).await?;
                        Self::update_db(checkpoint_state, db).await?;
                    }
                } else {
                    warn!(
                        message = "Received checkpoint finished but not checkpointing",
                        job_id = *self.job_id
                    )
                }
            }
            RunningMessage::TaskFinished {
                worker_id: _,
                time: _,
                operator_id,
                subtask_index,
            } => {
                let key = (operator_id, subtask_index);
                if let Some(status) = self.tasks.get_mut(&key) {
                    status.state = TaskState::Finished;
                } else {
                    warn!(
                        message = "Received task finished for unknown task",
                        job_id = *self.job_id,
                        operator_id = key.0,
                        subtask_index
                    );
                }
            }
            RunningMessage::TaskFailed {
                operator_id,
                subtask_index,
                reason,
                ..
            } => {
                let key = (operator_id, subtask_index);
                if let Some(status) = self.tasks.get_mut(&key) {
                    status.state = TaskState::Failed(reason);
                } else {
                    warn!(
                        message = "Received task failed message for unknown task",
                        job_id = *self.job_id,
                        operator_id = key.0,
                        subtask_index,
                        reason,
                    );
                }
            }
            RunningMessage::WorkerHeartbeat { worker_id, time } => {
                if let Some(worker) = self.workers.get_mut(&worker_id) {
                    worker.last_heartbeat = time;
                } else {
                    warn!(
                        message = "Received heartbeat for unknown worker",
                        job_id = *self.job_id,
                        worker_id = worker_id.0
                    );
                }
            }
            RunningMessage::WorkerFinished { worker_id } => {
                if let Some(worker) = self.workers.get_mut(&worker_id) {
                    worker.state = WorkerState::Stopped;
                } else {
                    warn!(
                        message = "Received finish message for unknown worker",
                        job_id = *self.job_id,
                        worker_id = worker_id.0
                    );
                }
            }
        }

        if self.state == JobState::Running
            && self.all_tasks_finished()
            && self.checkpoint_state.is_none()
        {
            for w in &mut self.workers.values_mut() {
                if let Err(e) = w.connect.job_finished(JobFinishedReq {}).await {
                    warn!(
                        message = "Failed to connect to worker to send job finish",
                        job_id = *self.job_id,
                        worker_id = w.id.0,
                        error = format!("{:?}", e),
                    )
                }
            }
            self.state = JobState::Stopped;
        }

        Ok(())
    }

    pub async fn start_checkpoint(
        &mut self,
        organization_id: &str,
        db: &DatabaseSource,
        then_stop: bool,
    ) -> anyhow::Result<()> {
        self.epoch += 1;

        info!(
            message = "Starting checkpointing",
            job_id = *self.job_id,
            epoch = self.epoch,
            then_stop
        );

        // TODO: maybe parallelize
        for worker in self.workers.values_mut() {
            worker
                .connect
                .checkpoint(Request::new(CheckpointReq {
                    epoch: self.epoch,
                    timestamp: to_micros(SystemTime::now()),
                    min_epoch: self.min_epoch,
                    then_stop,
                    is_commit: false,
                }))
                .await?;
        }

        let checkpoint_id = generate_id(IdTypes::Checkpoint);

        let c = db.client().await?;
        controller_queries::execute_create_checkpoint(
            &c,
            &checkpoint_id,
            &organization_id,
            &*self.job_id,
            &StateBackend::name().to_string(),
            &(self.epoch as i32),
            &(self.min_epoch as i32),
            &OffsetDateTime::now_utc(),
        )
        .await?;

        let state = CheckpointState::new(
            self.job_id.clone(),
            checkpoint_id,
            self.epoch,
            self.min_epoch,
            self.program.tasks_per_operator(),
        );

        self.checkpoint_state = Some(CheckpointingOrCommittingState::Checkpointing(state));

        Ok(())
    }

    async fn compact_state(&mut self) -> anyhow::Result<()> {
        if !config().pipeline.compaction.enabled {
            debug!("Compaction is disabled, skipping compaction");
            return Ok(());
        }

        info!(
            message = "Compacting state",
            job_id = *self.job_id,
            epoch = self.epoch,
        );

        let mut worker_clients: Vec<WorkerGrpcClient<Channel>> =
            self.workers.values().map(|w| w.connect.clone()).collect();
        for operator_id in self.operator_parallelism.keys() {
            let compacted_tables = ParquetBackend::compact_operator(
                // compact the operator's state and notify the workers to load the new files
                self.job_id.clone(),
                operator_id.clone(),
                self.epoch,
            )
            .await?;

            if compacted_tables.is_empty() {
                continue;
            }

            // TODO: these should be put on separate tokio tasks.
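            // broadcast the compacted metadata for this operator to every worker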
            for worker_client in &mut worker_clients {
                worker_client
                    .load_compacted_data(LoadCompactedDataReq {
                        operator_id: operator_id.clone(),
                        compacted_metadata: compacted_tables.clone(),
                    })
                    .await?;
            }
        }

        info!(
            message = "Finished compaction",
            job_id = *self.job_id,
            epoch = self.epoch,
        );
        Ok(())
    }

    pub async fn finish_checkpoint_if_done(&mut self, db: &DatabaseSource) -> anyhow::Result<()> {
        if self.checkpoint_state.as_ref().unwrap().done() {
            let state = self.checkpoint_state.take().unwrap();
            match state {
                CheckpointingOrCommittingState::Checkpointing(checkpointing) => {
                    checkpointing.save_state().await?;

                    let committing_state = checkpointing.committing_state();
                    let duration = checkpointing
                        .start_time()
                        .elapsed()
                        .unwrap_or(Duration::ZERO)
                        .as_secs_f32();
                    // shortcut if committing is unnecessary
                    if committing_state.done() {
                        Self::update_checkpoint_in_db(&checkpointing, db, DbCheckpointState::ready)
                            .await?;
                        self.last_checkpoint = Instant::now();
                        self.checkpoint_state = None;
                        self.compact_state().await?;

                        info!(
                            message = "Finished checkpointing",
                            job_id = *self.job_id,
                            epoch = self.epoch,
                            duration
                        );
                        // trigger a DB backup now that we're done checkpointing
                        notify_db();
                    } else {
                        Self::update_checkpoint_in_db(
                            &checkpointing,
                            db,
                            DbCheckpointState::committing,
                        )
                        .await?;
                        let committing_data = committing_state.committing_data();
                        self.checkpoint_state =
                            Some(CheckpointingOrCommittingState::Committing(committing_state));
                        info!(
                            message = "Committing checkpoint",
                            job_id = *self.job_id,
                            epoch = self.epoch,
                        );
                        for worker in self.workers.values_mut() {
                            worker
                                .connect
                                .commit(Request::new(CommitReq {
                                    epoch: self.epoch,
                                    committing_data: committing_data.clone(),
                                }))
                                .await?;
                        }
                    }
                }
                CheckpointingOrCommittingState::Committing(committing) => {
                    Self::finish_committing(committing.checkpoint_id(), db).await?;
                    self.last_checkpoint = Instant::now();
                    self.checkpoint_state = None;
                    info!(
                        message = "Finished committing checkpoint",
                        job_id = *self.job_id,
                        epoch = self.epoch,
                    );
                    // trigger a DB backup now that we're done checkpointing
                    notify_db();
                }
            }
        }
        Ok(())
    }

    pub fn cleanup_needed(&self) -> Option<u32> {
        if self.epoch - self.min_epoch > CHECKPOINTS_TO_KEEP && self.epoch % COMPACT_EVERY == 0 {
            Some(self.epoch - CHECKPOINTS_TO_KEEP)
        } else {
            None
        }
    }

    pub fn failed(&self) -> bool {
        for (worker, status) in &self.workers {
            if status.heartbeat_timeout() {
                error!(
                    message = "worker failed to heartbeat",
                    job_id = *self.job_id,
                    worker_id = worker.0
                );
                return true;
            }
        }

        for ((operator_id, subtask), status) in &self.tasks {
            if let TaskState::Failed(reason) = &status.state {
                error!(
                    message = "task failed",
                    job_id = *self.job_id,
                    operator_id,
                    subtask,
                    reason,
                );
                return true;
            }
        }

        false
    }

    pub fn any_finished_sources(&self) -> bool {
        let source_tasks = self.program.sources();

        self.tasks.iter().any(|((operator, _), t)| {
            source_tasks.contains(operator.as_str()) && t.state == TaskState::Finished
        })
    }

    pub fn all_tasks_finished(&self) -> bool {
        self.tasks
            .iter()
            .all(|(_, t)| t.state == TaskState::Finished)
    }
}

pub struct JobController {
    db: DatabaseSource,
    config: JobConfig,
    model: RunningJobModel,
    cleanup_task: Option<JoinHandle<anyhow::Result<u32>>>,
}

impl std::fmt::Debug for JobController {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("JobController")
            .field("config", &self.config)
            .field("model", &self.model)
            .field("cleaning", &self.cleanup_task.is_some())
            .finish()
    }
}

pub enum ControllerProgress {
    Continue,
    Finishing,
}

impl JobController {
    pub fn new(
        db: DatabaseSource,
        config: JobConfig,
        program: Arc<LogicalProgram>,
        epoch: u32,
        min_epoch: u32,
        worker_connects: HashMap<WorkerId, WorkerGrpcClient<Channel>>,
        commit_state: Option<CommittingState>,
        metrics: JobMetrics,
    ) -> Self {
        Self {
            db,
            model: RunningJobModel {
                job_id: config.id.clone(),
                state: JobState::Running,
                checkpoint_state: commit_state.map(CheckpointingOrCommittingState::Committing),
                epoch,
                min_epoch,
                // delay the initial checkpoint by a random amount so that on controller restart,
                // checkpoint times are staggered across jobs
                last_checkpoint: Instant::now()
                    + Duration::from_millis(
                        thread_rng().gen_range(0..config.checkpoint_interval.as_millis() as u64),
                    ),
                workers: worker_connects
                    .into_iter()
                    .map(|(id, connect)| {
                        (
                            id,
                            WorkerStatus {
                                id,
                                connect,
                                last_heartbeat: Instant::now(),
                                state: WorkerState::Running,
                            },
                        )
                    })
                    .collect(),
                tasks: program
                    .graph
                    .node_weights()
                    .flat_map(|node| {
                        (0..node.parallelism).map(|idx| {
                            (
                                (node.operator_id.clone(), idx as u32),
                                TaskStatus {
                                    state: TaskState::Running,
                                },
                            )
                        })
                    })
                    .collect(),
                operator_parallelism: program.tasks_per_operator(),
                metrics,
                metric_update_task: None,
                last_updated_metrics: Instant::now(),
                program,
            },
            config,
            cleanup_task: None,
        }
    }

    pub fn update_config(&mut self, config: JobConfig) {
        self.config = config;
    }

    pub async fn handle_message(&mut self, msg: RunningMessage) -> anyhow::Result<()> {
        self.model.handle_message(msg, &self.db).await
    }

    async fn update_metrics(&mut self) {
        if self.model.metric_update_task.is_some()
            && !self
                .model
                .metric_update_task
                .as_ref()
                .unwrap()
                .is_finished()
        {
            return;
        }

        let job_metrics = self.model.metrics.clone();
        let workers: Vec<_> = self
            .model
            .workers
            .iter()
            .filter(|(_, w)| w.state == WorkerState::Running)
            .map(|(id, w)| (*id, w.connect.clone()))
            .collect();
        let program = self.model.program.clone();

        self.model.metric_update_task = Some(tokio::spawn(async move {
            let mut metrics: HashMap<(u32, u32), HashMap<MetricName, u64>> = HashMap::new();

            for (id, mut connect) in workers {
                let Ok(e) = connect.get_metrics(MetricsReq {}).await else {
                    warn!("Failed to collect metrics from worker {:?}", id);
                    return;
                };

                fn find_label<'a>(labels: &'a [LabelPair], name: &'static str) -> Option<&'a str> {
                    Some(
                        labels
                            .iter()
                            .find(|t| t.name.as_ref().map(|t| t == name).unwrap_or(false))?
                            .value
                            .as_ref()?
                            .as_str(),
                    )
                }

                e.into_inner()
                    .metrics
                    .into_iter()
                    .filter_map(|f| Some((get_metric_name(&f.name?)?, f.metric)))
                    .flat_map(|(metric, values)| {
                        let program = program.clone();
                        values.into_iter().filter_map(move |m| {
                            let subtask_idx =
                                u32::from_str(find_label(&m.label, "subtask_idx")?).ok()?;
                            let operator_idx =
                                program.operator_index(find_label(&m.label, "operator_id")?)?;
                            let value = m
                                .counter
                                .map(|c| c.value)
                                .or_else(|| m.gauge.map(|g| g.value))??
                                as u64;
                            Some(((operator_idx, subtask_idx), (metric, value)))
                        })
                    })
                    .for_each(|(subtask_idx, (metric, value))| {
                        metrics
                            .entry(subtask_idx)
                            .or_default()
                            .insert(metric, value);
                    });
            }

            for ((operator_idx, subtask_idx), values) in metrics {
                job_metrics.update(operator_idx, subtask_idx, &values).await;
            }
        }));
    }

    pub async fn progress(&mut self) -> anyhow::Result<ControllerProgress> {
        // have any of our workers failed?
        if self.model.failed() {
            bail!("worker failed");
        }

        // have any of our tasks finished?
        if self.model.any_finished_sources() {
            return Ok(ControllerProgress::Finishing);
        }

        // check on compaction
        if self.cleanup_task.is_some() && self.cleanup_task.as_ref().unwrap().is_finished() {
            let task = self.cleanup_task.take().unwrap();

            match task.await {
                Ok(Ok(min_epoch)) => {
                    info!(
                        message = "setting new min epoch",
                        min_epoch,
                        job_id = *self.config.id
                    );
                    self.model.min_epoch = min_epoch;
                }
                Ok(Err(e)) => {
                    error!(
                        message = "cleanup failed",
                        job_id = *self.config.id,
                        error = format!("{:?}", e)
                    );

                    // wait a bit before trying again
                    tokio::time::sleep(Duration::from_millis(100)).await;
                }
                Err(e) => {
                    error!(
                        message = "cleanup panicked",
                        job_id = *self.config.id,
                        error = format!("{:?}", e)
                    );

                    // wait a bit before trying again
                    tokio::time::sleep(Duration::from_millis(100)).await;
                }
            }
        }

        if let Some(new_epoch) = self.model.cleanup_needed() {
            if self.cleanup_task.is_none() && self.model.checkpoint_state.is_none() {
                self.cleanup_task = Some(self.start_cleanup(new_epoch));
            }
        }

        // check on checkpointing
        if self.model.checkpoint_state.is_some() {
            self.model.finish_checkpoint_if_done(&self.db).await?;
        } else if self.model.last_checkpoint.elapsed() > self.config.checkpoint_interval
            && self.cleanup_task.is_none()
        {
            // or do we need to start checkpointing?
            self.checkpoint(false).await?;
        }

        // update metrics
        if self.model.last_updated_metrics.elapsed() > job_metrics::COLLECTION_RATE {
            self.update_metrics().await;
            self.model.last_updated_metrics = Instant::now();
        }

        Ok(ControllerProgress::Continue)
    }

    pub async fn stop_job(&mut self, stop_mode: StopMode) -> anyhow::Result<()> {
        for c in self.model.workers.values_mut() {
            c.connect
                .stop_execution(StopExecutionReq {
                    stop_mode: stop_mode as i32,
                })
                .await?;
        }

        Ok(())
    }

    pub async fn checkpoint(&mut self, then_stop: bool) -> anyhow::Result<bool> {
        if self.model.checkpoint_state.is_none() {
            self.model
                .start_checkpoint(&self.config.organization_id, &self.db, then_stop)
                .await?;
            Ok(true)
        } else {
            Ok(false)
        }
    }

    pub fn finished(&self) -> bool {
        self.model.all_tasks_finished()
    }

    pub async fn checkpoint_finished(&mut self) -> anyhow::Result<bool> {
        if self.model.checkpoint_state.is_some() {
            self.model.finish_checkpoint_if_done(&self.db).await?;
        }
        Ok(self.model.checkpoint_state.is_none())
    }

    pub async fn send_commit_messages(&mut self) -> anyhow::Result<()> {
        let Some(CheckpointingOrCommittingState::Committing(committing)) =
            &self.model.checkpoint_state
        else {
            bail!("should be committing")
        };
        for worker in self.model.workers.values_mut() {
            worker
                .connect
                .commit(CommitReq {
                    epoch: self.model.epoch,
                    committing_data: committing.committing_data(),
                })
                .await?;
        }
        Ok(())
    }

    pub async fn wait_for_finish(&mut self, rx: &mut Receiver<JobMessage>) -> anyhow::Result<()> {
        loop {
            if self.model.all_tasks_finished() {
                return Ok(());
            }

            match rx
                .recv()
                .await
                .ok_or_else(|| anyhow::anyhow!("channel closed while receiving"))?
            {
                JobMessage::RunningMessage(msg) => {
                    self.model.handle_message(msg, &self.db).await?;
                }
                JobMessage::ConfigUpdate(c) => {
                    if c.stop_mode == SqlStopMode::immediate {
                        info!(
                            message = "stopping job immediately",
                            job_id = *self.config.id
                        );
                        self.stop_job(StopMode::Immediate).await?;
                    }
                }
                _ => {
                    // ignore other messages
                }
            }
        }
    }

    pub fn operator_parallelism(&self, op: &str) -> Option<usize> {
        self.model.operator_parallelism.get(op).cloned()
    }

    fn start_cleanup(&mut self, new_min: u32) -> JoinHandle<anyhow::Result<u32>> {
        let min_epoch = self.model.min_epoch.max(1);
        let job_id = self.config.id.clone();
        let db = self.db.clone();

        info!(
            message = "Starting cleaning",
            job_id = *job_id,
            min_epoch,
            new_min
        );
        let start = Instant::now();
        let cur_epoch = self.model.epoch;

        tokio::spawn(async move {
            let checkpoint = StateBackend::load_checkpoint_metadata(&job_id, cur_epoch).await?;

            controller_queries::execute_mark_compacting(
                &db.client().await?,
                &*job_id,
                &(min_epoch as i32),
                &(new_min as i32),
            )
            .await?;

            StateBackend::cleanup_checkpoint(checkpoint, min_epoch, new_min).await?;

            controller_queries::execute_mark_checkpoints_compacted(
                &db.client().await?,
                &*job_id,
                &(new_min as i32),
            )
            .await?;

            if let Some(epoch_to_filter_before) = min_epoch.checked_sub(CHECKPOINT_ROWS_TO_KEEP) {
                controller_queries::execute_drop_old_checkpoint_rows(
                    &db.client().await?,
                    &*job_id,
                    &(epoch_to_filter_before as i32),
                )
                .await?;
            }

            info!(
                message = "Finished cleaning",
                job_id = *job_id,
                min_epoch,
                new_min,
                duration = start.elapsed().as_secs_f32()
            );

            Ok(new_min)
        })
    }
}