mod.rs
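//! Schedulers for running pipeline workers: the [`Scheduler`] trait, a
//! [`ProcessScheduler`] that spawns local worker processes, and a
//! [`NodeScheduler`] that places workers on registered nodes over gRPC.
//! Embedded and Kubernetes implementations live in the submodules below.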
use anyhow::bail;
use arroyo_datastream::logical::LogicalProgram;
use arroyo_rpc::config::config;
use arroyo_rpc::grpc::rpc::node_grpc_client::NodeGrpcClient;
use arroyo_rpc::grpc::rpc::{
    HeartbeatNodeReq, RegisterNodeReq, StartWorkerReq, StopWorkerReq, StopWorkerStatus,
    WorkerFinishedReq,
};
use arroyo_types::{NodeId, WorkerId, JOB_ID_ENV, RUN_ID_ENV};
use lazy_static::lazy_static;
use prometheus::{register_gauge, Gauge};
use std::collections::HashMap;
use std::env::current_exe;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::process::Command;
use tokio::sync::{oneshot, Mutex};
use tonic::{Request, Status};
use tracing::{info, warn};

pub mod embedded;
pub mod kubernetes;

lazy_static! {
    static ref FREE_SLOTS: Gauge =
        register_gauge!("arroyo_controller_free_slots", "number of free task slots").unwrap();
    static ref REGISTERED_SLOTS: Gauge = register_gauge!(
        "arroyo_controller_registered_slots",
        "total number of registered task slots"
    )
    .unwrap();
    static ref REGISTERED_NODES: Gauge = register_gauge!(
        "arroyo_controller_registered_nodes",
        "total number of registered nodes"
    )
    .unwrap();
}

/// A strategy for starting, tracking, and stopping the workers that run a pipeline
#[async_trait::async_trait]
pub trait Scheduler: Send + Sync {
    async fn start_workers(
        &self,
        start_pipeline_req: StartPipelineReq,
    ) -> Result<(), SchedulerError>;

    async fn register_node(&self, req: RegisterNodeReq);
    async fn heartbeat_node(&self, req: HeartbeatNodeReq) -> Result<(), Status>;
    async fn worker_finished(&self, req: WorkerFinishedReq);
    async fn stop_workers(
        &self,
        job_id: &str,
        run_id: Option<i64>,
        force: bool,
    ) -> anyhow::Result<()>;
    async fn workers_for_job(
        &self,
        job_id: &str,
        run_id: Option<i64>,
    ) -> anyhow::Result<Vec<WorkerId>>;
}
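// A sketch of how a controller might drive a `Scheduler` (hypothetical call
// site; `req` is assumed to be a fully-built `StartPipelineReq`):
//
//     let scheduler: Arc<dyn Scheduler> = Arc::new(ProcessScheduler::new());
//     if let Err(SchedulerError::NotEnoughSlots { slots_needed }) =
//         scheduler.start_workers(req).await
//     {
//         warn!("need {} more slots before this pipeline can start", slots_needed);
//     }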
pub struct ProcessWorker {
    job_id: Arc<String>,
    run_id: i64,
    shutdown_tx: oneshot::Sender<()>,
}

/// This Scheduler starts new processes to run the worker nodes
pub struct ProcessScheduler {
    workers: Arc<Mutex<HashMap<WorkerId, ProcessWorker>>>,
    worker_counter: AtomicU64,
}

impl ProcessScheduler {
    pub fn new() -> Self {
        Self {
            workers: Arc::new(Mutex::new(HashMap::new())),
            // worker IDs start at 100 so that they are all the same length
            worker_counter: AtomicU64::new(100),
        }
    }
}

pub struct StartPipelineReq {
    pub name: String,
    pub program: LogicalProgram,
    pub wasm_path: String,
    pub job_id: Arc<String>,
    pub hash: String,
    pub run_id: i64,
    pub slots: usize,
    pub env_vars: HashMap<String, String>,
}

#[async_trait::async_trait]
impl Scheduler for ProcessScheduler {
    async fn start_workers(
        &self,
        start_pipeline_req: StartPipelineReq,
    ) -> Result<(), SchedulerError> {
        // one process per `slots_per_process` slots, rounded up; e.g. 10 slots with
        // slots_per_process = 4 yields ceil(10 / 4) = 3 processes taking 4, 4, and 2 slots
        let workers = (start_pipeline_req.slots as f32
            / config().process_scheduler.slots_per_process as f32)
            .ceil() as usize;

        let mut slots_scheduled = 0;

        let base_path = PathBuf::from_str(&format!(
            "/tmp/arroyo-process/{}",
            start_pipeline_req.job_id
        ))
        .unwrap();

        for _ in 0..workers {
            let path = base_path.clone();

            let slots_here = (start_pipeline_req.slots - slots_scheduled)
                .min(config().process_scheduler.slots_per_process as usize);

            let worker_id = self.worker_counter.fetch_add(1, Ordering::SeqCst);

            let (tx, rx) = oneshot::channel();

            {
                let mut workers = self.workers.lock().await;
                workers.insert(
                    WorkerId(worker_id),
                    ProcessWorker {
                        job_id: start_pipeline_req.job_id.clone(),
                        run_id: start_pipeline_req.run_id,
                        shutdown_tx: tx,
                    },
                );
            }

            slots_scheduled += slots_here;
            let job_id = start_pipeline_req.job_id.clone();
            let workers = self.workers.clone();
            let env_map = start_pipeline_req.env_vars.clone();

            tokio::spawn(async move {
                let mut command =
                    Command::new(current_exe().expect("Could not get path of worker binary"));

                for (env, value) in env_map {
                    command.env(env, value);
                }
                let mut child = command
                    .arg("worker")
                    .env("ARROYO__ADMIN__HTTP_PORT", "0")
                    .env("ARROYO__WORKER__TASK_SLOTS", format!("{}", slots_here))
                    .env("ARROYO__WORKER__ID", format!("{}", worker_id))
                    .env(
                        "ARROYO__CONTROLLER_ENDPOINT",
                        config().controller_endpoint(),
                    )
                    .env("UNDER_PROCESS_SCHEDULER", "true")
                    .env(JOB_ID_ENV, &*job_id)
                    .env(RUN_ID_ENV, format!("{}", start_pipeline_req.run_id))
                    .kill_on_drop(true)
                    .spawn()
                    .unwrap();

                tokio::select! {
                    status = child.wait() => {
                        info!("Child ({:?}) exited with status {:?}", path, status);
                    }
                    _ = rx => {
                        info!(message = "Killing child", worker_id = worker_id, job_id = *job_id);
                        child.kill().await.unwrap();
                    }
                }

                let mut state = workers.lock().await;
                state.remove(&WorkerId(worker_id));
            });
        }

        Ok(())
    }

    async fn register_node(&self, _: RegisterNodeReq) {}
    async fn heartbeat_node(&self, _: HeartbeatNodeReq) -> Result<(), Status> {
        Ok(())
    }
    async fn worker_finished(&self, _: WorkerFinishedReq) {}

    async fn workers_for_job(
        &self,
        job_id: &str,
        run_id: Option<i64>,
    ) -> anyhow::Result<Vec<WorkerId>> {
        Ok(self
            .workers
            .lock()
            .await
            .iter()
            .filter(|(_, w)| {
                *w.job_id == job_id && (run_id.is_none() || w.run_id == run_id.unwrap())
            })
            .map(|(k, _)| *k)
            .collect())
    }

    async fn stop_workers(
        &self,
        job_id: &str,
        run_id: Option<i64>,
        _force: bool,
    ) -> anyhow::Result<()> {
        for worker_id in self.workers_for_job(job_id, run_id).await? {
            let worker = {
                let mut state = self.workers.lock().await;
                let Some(worker) = state.remove(&worker_id) else {
                    // this worker already exited on its own; move on to the rest
                    continue;
                };
                worker
            };

            let _ = worker.shutdown_tx.send(());
        }

        Ok(())
    }
}
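// The rest of this module implements the node scheduler. Slot accounting on a
// node maintains a simple invariant: `free_slots` plus the sum of
// `scheduled_slots` always equals the node's registered capacity. For example
// (hypothetical numbers), a node registered with 8 slots has free_slots == 3
// after take_slots(worker, 5), and release_slots(worker, 5) restores it to 8.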
#[derive(Debug, Clone)]
struct NodeStatus {
    id: NodeId,
    free_slots: usize,
    scheduled_slots: HashMap<WorkerId, usize>,
    addr: String,
    last_heartbeat: Instant,
}

impl NodeStatus {
    fn new(id: NodeId, slots: usize, addr: String) -> NodeStatus {
        FREE_SLOTS.add(slots as f64);
        REGISTERED_SLOTS.add(slots as f64);

        NodeStatus {
            id,
            free_slots: slots,
            scheduled_slots: HashMap::new(),
            addr,
            last_heartbeat: Instant::now(),
        }
    }

    fn take_slots(&mut self, worker: WorkerId, slots: usize) {
        if let Some(v) = self.free_slots.checked_sub(slots) {
            FREE_SLOTS.sub(slots as f64);
            self.free_slots = v;
            self.scheduled_slots.insert(worker, slots);
        } else {
            panic!(
                "Attempted to schedule more slots than are available on node {} ({} < {})",
                self.addr, self.free_slots, slots
            );
        }
    }

    fn release_slots(&mut self, worker_id: WorkerId, slots: usize) {
        if let Some(freed) = self.scheduled_slots.remove(&worker_id) {
            assert_eq!(
                freed, slots,
                "Controller and node disagree about how many slots are scheduled for worker {:?} ({} != {})",
                worker_id, freed, slots
            );

            self.free_slots += slots;

            FREE_SLOTS.add(slots as f64);
        } else {
            warn!(
                "Received release request for unknown worker {:?}",
                worker_id
            );
        }
    }
}

#[derive(Clone)]
struct NodeWorker {
    job_id: Arc<String>,
    node_id: NodeId,
    run_id: i64,
    running: bool,
}

#[derive(Default)]
pub struct NodeSchedulerState {
    nodes: HashMap<NodeId, NodeStatus>,
    workers: HashMap<WorkerId, NodeWorker>,
}

impl NodeSchedulerState {
    fn expire_nodes(&mut self, expiration_time: Instant) {
        let expired_nodes: Vec<_> = self
            .nodes
            .iter()
            .filter_map(|(node_id, status)| {
                if status.last_heartbeat >= expiration_time {
                    None
                } else {
                    Some(*node_id)
                }
            })
            .collect();
        for node_id in expired_nodes {
            warn!("expiring node {:?} from scheduler state", node_id);
            self.nodes.remove(&node_id);
        }
    }
}
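// NodeScheduler places workers on nodes that have registered themselves via
// `register_node` and that keep themselves alive via `heartbeat_node`; a node
// that misses heartbeats for 30 seconds is expired from the state (see
// `expire_nodes` above and the matching filter in `start_workers` below).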
pub struct NodeScheduler {
    state: Arc<Mutex<NodeSchedulerState>>,
}

pub enum SchedulerError {
    NotEnoughSlots { slots_needed: usize },
    Other(String),
}

impl NodeScheduler {
    pub fn new() -> Self {
        Self {
            state: Arc::new(Mutex::new(NodeSchedulerState::default())),
        }
    }

    /// Attempts to stop a single worker; returns `Ok(Some(worker_id))` if the worker
    /// is presumed dead (unknown, on an unknown node, or unreachable) and `Ok(None)`
    /// if the stop request was delivered successfully
    async fn stop_worker(
        &self,
        job_id: &str,
        worker_id: WorkerId,
        force: bool,
    ) -> anyhow::Result<Option<WorkerId>> {
        let state = self.state.lock().await;

        let Some(worker) = state.workers.get(&worker_id) else {
            // assume it's already finished
            return Ok(Some(worker_id));
        };

        let Some(node) = state.nodes.get(&worker.node_id) else {
            warn!(
                message = "node not found for stop worker",
                node_id = worker.node_id.0
            );
            return Ok(Some(worker_id));
        };

        let worker = worker.clone();
        let node = node.clone();
        drop(state);

        info!(
            message = "stopping worker",
            job_id = *worker.job_id,
            node_id = worker.node_id.0,
            node_addr = node.addr,
            worker_id = worker_id.0
        );

        let Ok(mut client) = NodeGrpcClient::connect(format!("http://{}", node.addr)).await else {
            warn!("Failed to connect to worker to stop; this likely means it is dead");
            return Ok(Some(worker_id));
        };

        let Ok(resp) = client
            .stop_worker(Request::new(StopWorkerReq {
                job_id: job_id.to_string(),
                worker_id: worker_id.0,
                force,
            }))
            .await
        else {
            warn!("Failed to connect to worker to stop; this likely means it is dead");
            return Ok(Some(worker_id));
        };

        match (resp.get_ref().status(), force) {
            (StopWorkerStatus::NotFound, false) => {
                bail!("couldn't find worker, will only continue if force")
            }
            (StopWorkerStatus::StopFailed, _) => bail!("tried to kill and couldn't"),
            _ => Ok(None),
        }
    }
}
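// Placement strategy used by `start_workers` below: repeatedly pick the live
// node with the most free slots and fill it. For example (hypothetical
// numbers), scheduling 10 slots across nodes with 6, 4, and 4 free slots
// starts one worker with 6 slots on the first node and one with 4 slots on
// one of the others.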
#[async_trait::async_trait]
impl Scheduler for NodeScheduler {
    async fn register_node(&self, req: RegisterNodeReq) {
        let mut state = self.state.lock().await;
        if let std::collections::hash_map::Entry::Vacant(e) = state.nodes.entry(NodeId(req.node_id))
        {
            e.insert(NodeStatus::new(
                NodeId(req.node_id),
                req.task_slots as usize,
                req.addr,
            ));
        }
    }

    async fn heartbeat_node(&self, req: HeartbeatNodeReq) -> Result<(), Status> {
        let mut state = self.state.lock().await;
        if let Some(node) = state.nodes.get_mut(&NodeId(req.node_id)) {
            node.last_heartbeat = Instant::now();
            Ok(())
        } else {
            warn!(
                "Received heartbeat for unregistered node {}, failing request",
                req.node_id
            );
            Err(Status::not_found(format!(
                "node {} not in scheduler's collection of nodes",
                req.node_id
            )))
        }
    }

    async fn worker_finished(&self, req: WorkerFinishedReq) {
        let mut state = self.state.lock().await;
        let worker_id = WorkerId(req.worker_id);

        if let Some(node) = state.nodes.get_mut(&NodeId(req.node_id)) {
            node.release_slots(worker_id, req.slots as usize);
        } else {
            warn!(
                "Got worker finished message for unknown node {}",
                req.node_id
            );
        }

        if state.workers.remove(&worker_id).is_none() {
            warn!(
                "Got worker finished message for unknown worker {}",
                worker_id.0
            );
        }
    }

    async fn workers_for_job(
        &self,
        job_id: &str,
        run_id: Option<i64>,
    ) -> anyhow::Result<Vec<WorkerId>> {
        let state = self.state.lock().await;
        Ok(state
            .workers
            .iter()
            .filter(|(_, v)| {
                *v.job_id == job_id
                    && v.running
                    && (run_id.is_none() || v.run_id == run_id.unwrap())
            })
            .map(|(w, _)| *w)
            .collect())
    }

    #[allow(unreachable_code, unused)]
    async fn start_workers(
        &self,
        start_pipeline_req: StartPipelineReq,
    ) -> Result<(), SchedulerError> {
        // TODO: make this locking more fine-grained
        let mut state = self.state.lock().await;

        state.expire_nodes(Instant::now() - Duration::from_secs(30));

        let free_slots = state.nodes.values().map(|n| n.free_slots).sum::<usize>();
        let slots = start_pipeline_req.slots;
        if slots > free_slots {
            return Err(SchedulerError::NotEnoughSlots {
                slots_needed: slots - free_slots,
            });
        }

        let mut to_schedule = slots;
        let mut slots_assigned = vec![];
        while to_schedule > 0 {
            // find the node with the most free slots and fill it
            let node = {
                if let Some(status) = state
                    .nodes
                    .values()
                    .filter(|n| {
                        n.free_slots > 0 && n.last_heartbeat.elapsed() < Duration::from_secs(30)
                    })
                    .max_by_key(|n| n.free_slots)
                    .cloned()
                {
                    status
                } else {
                    // the capacity check above guarantees a live node with free slots exists
                    unreachable!();
                }
            };

            let slots_for_this_one = node.free_slots.min(to_schedule);
            info!(
                "Scheduling {} slots on node {}",
                slots_for_this_one, node.addr
            );

            let mut client = NodeGrpcClient::connect(format!("http://{}", node.addr))
                .await
                // TODO: handle this more gracefully by trying other nodes
                .map_err(|e| {
                    // release the slots that have already been scheduled
                    slots_assigned
                        .iter()
                        .for_each(|(node_id, worker_id, slots)| {
                            state
                                .nodes
                                .get_mut(node_id)
                                .unwrap()
                                .release_slots(*worker_id, *slots);
                        });
                    SchedulerError::Other(format!(
                        "Failed to connect to node {}: {:?}",
                        node.addr, e
                    ))
                })?;

            let req = StartWorkerReq {
                name: start_pipeline_req.name.clone(),
                job_id: (*start_pipeline_req.job_id).clone(),
                slots: slots_for_this_one as u64,
                node_id: node.id.0,
                run_id: start_pipeline_req.run_id as u64,
                env_vars: start_pipeline_req.env_vars.clone(),
            };

            let res = client
                .start_worker(Request::new(req))
                .await
                .map_err(|e| {
                    // release the slots that have already been scheduled
                    slots_assigned
                        .iter()
                        .for_each(|(node_id, worker_id, slots)| {
                            state
                                .nodes
                                .get_mut(node_id)
                                .unwrap()
                                .release_slots(*worker_id, *slots);
                        });
                    SchedulerError::Other(format!(
                        "Failed to start worker on node {}: {:?}",
                        node.addr, e
                    ))
                })?
                .into_inner();

            state
                .nodes
                .get_mut(&node.id)
                .unwrap()
                .take_slots(WorkerId(res.worker_id), slots_for_this_one);

            state.workers.insert(
                WorkerId(res.worker_id),
                NodeWorker {
                    job_id: start_pipeline_req.job_id.clone(),
                    run_id: start_pipeline_req.run_id,
                    node_id: node.id,
                    running: true,
                },
            );

            slots_assigned.push((node.id, WorkerId(res.worker_id), slots_for_this_one));

            to_schedule -= slots_for_this_one;
        }
        Ok(())
    }

    async fn stop_workers(
        &self,
        job_id: &str,
        run_id: Option<i64>,
        force: bool,
    ) -> anyhow::Result<()> {
        // stop each of the workers returned by workers_for_job (the futures are
        // awaited one at a time)
        let workers = self.workers_for_job(job_id, run_id).await?;
        let mut futures = vec![];
        for worker_id in workers {
            futures.push(self.stop_worker(job_id, worker_id, force));
        }

        for f in futures {
            match f.await? {
                Some(worker_id) => {
                    let mut state = self.state.lock().await;
                    if let Some(worker) = state.workers.get_mut(&worker_id) {
                        worker.running = false;
                    }
                }
                None => {
                    bail!("Failed to stop worker");
                }
            }
        }

        Ok(())
    }
}
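// Sketch of the controller-side stop path (hypothetical job id):
//
//     scheduler.stop_workers("job_123", Some(run_id), false).await?;
//
// With force = false, a worker that reports `StopWorkerStatus::NotFound` fails
// the call; with force = true, NotFound is tolerated and only `StopFailed` is
// treated as an error (see `stop_worker` above).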