smoke_tests.rs
1 use anyhow::{bail, Result}; 2 use arroyo_datastream::logical::{ 3 LogicalEdge, LogicalEdgeType, LogicalGraph, LogicalNode, LogicalProgram, OperatorName, 4 }; 5 use arroyo_df::{parse_and_get_arrow_program, ArroyoSchemaProvider, SqlConfig}; 6 use arroyo_state::parquet::ParquetBackend; 7 use petgraph::algo::has_path_connecting; 8 use petgraph::visit::EdgeRef; 9 use rstest::rstest; 10 use std::collections::{BTreeMap, HashMap, HashSet}; 11 use std::path::{Path, PathBuf}; 12 use std::sync::Arc; 13 use std::time::Duration; 14 use std::{env, time::SystemTime}; 15 use tokio::sync::mpsc::Receiver; 16 17 use crate::udfs::get_udfs; 18 use arroyo_rpc::grpc::rpc::{StopMode, TaskCheckpointCompletedReq, TaskCheckpointEventReq}; 19 use arroyo_rpc::{CompactionResult, ControlMessage, ControlResp}; 20 use arroyo_state::checkpoint_state::CheckpointState; 21 use arroyo_types::{to_micros, CheckpointBarrier}; 22 use arroyo_udf_host::LocalUdf; 23 use arroyo_worker::engine::{Engine, StreamConfig}; 24 use arroyo_worker::engine::{Program, RunningEngine}; 25 use petgraph::{Direction, Graph}; 26 use serde_json::Value; 27 use test_log::test as test_log; 28 use tokio::fs::read_to_string; 29 use tokio::sync::mpsc::error::TryRecvError; 30 use tracing::info; 31 32 #[test_log(rstest)] 33 fn for_each_file(#[files("src/test/queries/*.sql")] path: PathBuf) { 34 tokio::runtime::Builder::new_current_thread() 35 .enable_all() 36 .build() 37 .unwrap() 38 .block_on(async { 39 run_smoketest(&path).await; 40 }); 41 } 42 43 async fn run_smoketest(path: &Path) { 44 // read text at path 45 let test_name = path 46 .file_name() 47 .unwrap() 48 .to_str() 49 .unwrap() 50 .split('.') 51 .next() 52 .unwrap(); 53 let query = read_to_string(path).await.unwrap(); 54 let fail = query.starts_with("--fail"); 55 let error_message = query.starts_with("--fail=").then(|| { 56 query 57 .lines() 58 .next() 59 .unwrap() 60 .split_once('=') 61 .unwrap() 62 .1 63 .trim() 64 }); 65 match correctness_run_codegen(test_name, query.clone(), 20).await { 66 Ok(_) => {} 67 Err(err) => { 68 if fail { 69 if let Some(error_message) = error_message { 70 assert!( 71 err.to_string().contains(error_message), 72 "expected error message '{}' not found; instead got '{}'", 73 error_message, 74 err 75 ); 76 } 77 } else { 78 panic!("smoke test failed: {}", err); 79 } 80 } 81 } 82 } 83 84 struct SmokeTestContext<'a> { 85 job_id: Arc<String>, 86 engine: &'a RunningEngine, 87 control_rx: &'a mut Receiver<ControlResp>, 88 tasks_per_operator: HashMap<String, usize>, 89 } 90 91 async fn checkpoint(ctx: &mut SmokeTestContext<'_>, epoch: u32) { 92 let checkpoint_id = epoch as i64; 93 let mut checkpoint_state = CheckpointState::new( 94 ctx.job_id.clone(), 95 checkpoint_id.to_string(), 96 epoch, 97 0, 98 ctx.tasks_per_operator.clone(), 99 ); 100 101 // trigger a checkpoint, pass the messages to the CheckpointState 102 103 let barrier = CheckpointBarrier { 104 epoch, 105 min_epoch: 0, 106 timestamp: SystemTime::now(), 107 then_stop: false, 108 }; 109 110 for source in ctx.engine.source_controls() { 111 source 112 .send(ControlMessage::Checkpoint(barrier)) 113 .await 114 .unwrap(); 115 } 116 117 while !checkpoint_state.done() { 118 let c: ControlResp = ctx.control_rx.recv().await.unwrap(); 119 120 match c { 121 ControlResp::CheckpointEvent(c) => { 122 let req = TaskCheckpointEventReq { 123 worker_id: 1, 124 time: to_micros(c.time), 125 job_id: (*ctx.job_id).clone(), 126 operator_id: c.operator_id, 127 subtask_index: c.subtask_index, 128 epoch: c.checkpoint_epoch, 129 event_type: c.event_type as i32, 130 }; 131 checkpoint_state.checkpoint_event(req).unwrap(); 132 } 133 ControlResp::CheckpointCompleted(c) => { 134 let req = TaskCheckpointCompletedReq { 135 worker_id: 1, 136 time: c.subtask_metadata.finish_time, 137 job_id: (*ctx.job_id).clone(), 138 operator_id: c.operator_id, 139 epoch: c.checkpoint_epoch, 140 needs_commit: false, 141 metadata: Some(c.subtask_metadata), 142 }; 143 checkpoint_state.checkpoint_finished(req).await.unwrap(); 144 } 145 _ => {} 146 } 147 } 148 149 checkpoint_state.save_state().await.unwrap(); 150 151 info!("Smoke test checkpoint completed"); 152 } 153 154 async fn compact( 155 job_id: Arc<String>, 156 running_engine: &RunningEngine, 157 tasks_per_operator: HashMap<String, usize>, 158 epoch: u32, 159 ) { 160 let operator_controls = running_engine.operator_controls(); 161 for (operator, _) in tasks_per_operator { 162 if let Ok(compacted) = 163 ParquetBackend::compact_operator(job_id.clone(), operator.clone(), epoch).await 164 { 165 let operator_controls = operator_controls.get(&operator).unwrap(); 166 for s in operator_controls { 167 s.send(ControlMessage::LoadCompacted { 168 compacted: CompactionResult { 169 operator_id: operator.to_string(), 170 compacted_tables: compacted.clone(), 171 }, 172 }) 173 .await 174 .unwrap(); 175 } 176 } 177 } 178 } 179 180 async fn advance(engine: &RunningEngine, count: i32) { 181 // let the engine run for a bit, process some records 182 for source in engine.source_controls() { 183 for _ in 0..count { 184 let _ = source.send(ControlMessage::NoOp).await; 185 } 186 } 187 } 188 189 async fn run_until_finished(engine: &RunningEngine, control_rx: &mut Receiver<ControlResp>) { 190 while control_rx.try_recv().is_ok() 191 || control_rx 192 .try_recv() 193 .is_err_and(|e| e == TryRecvError::Empty) 194 { 195 advance(engine, 10).await; 196 tokio::time::sleep(Duration::from_millis(1)).await; 197 } 198 } 199 200 fn set_internal_parallelism(graph: &mut Graph<LogicalNode, LogicalEdge>, parallelism: usize) { 201 let watermark_nodes: HashSet<_> = graph 202 .node_indices() 203 .filter(|index| { 204 let operator_name = graph.node_weight(*index).unwrap().operator_name; 205 matches!(operator_name, OperatorName::ExpressionWatermark) 206 }) 207 .collect(); 208 let indices: Vec<_> = graph 209 .node_indices() 210 .filter( 211 |index| match graph.node_weight(*index).unwrap().operator_name { 212 OperatorName::ExpressionWatermark 213 | OperatorName::ConnectorSource 214 | OperatorName::ConnectorSink => false, 215 _ => { 216 for watermark_node in watermark_nodes.iter() { 217 if has_path_connecting(&graph.clone(), *watermark_node, *index, None) { 218 return true; 219 } 220 } 221 false 222 } 223 }, 224 ) 225 .collect(); 226 for node in indices { 227 graph.node_weight_mut(node).unwrap().parallelism = parallelism; 228 } 229 if parallelism > 1 { 230 let mut edges_to_make_shuffle = vec![]; 231 for node in graph.externals(Direction::Outgoing) { 232 for edge in graph.edges_directed(node, Direction::Incoming) { 233 edges_to_make_shuffle.push(edge.id()); 234 } 235 } 236 for node in graph.node_indices() { 237 if graph.node_weight(node).unwrap().operator_name == OperatorName::ExpressionWatermark { 238 for edge in graph.edges_directed(node, Direction::Outgoing) { 239 edges_to_make_shuffle.push(edge.id()); 240 } 241 } 242 } 243 for edge in edges_to_make_shuffle { 244 graph.edge_weight_mut(edge).unwrap().edge_type = LogicalEdgeType::Shuffle; 245 } 246 } 247 } 248 249 async fn run_and_checkpoint(job_id: Arc<String>, program: Program, checkpoint_interval: i32) { 250 let tasks_per_operator = program.tasks_per_operator(); 251 let engine = Engine::for_local(program, job_id.to_string()); 252 let (running_engine, mut control_rx) = engine 253 .start(StreamConfig { 254 restore_epoch: None, 255 }) 256 .await; 257 info!("Smoke test checkpointing enabled"); 258 env::set_var( 259 "ARROYO__CONTROLLER__COMPACTION__CHECKPOINTS_TO_COMPACT", 260 "2", 261 ); 262 263 let ctx = &mut SmokeTestContext { 264 job_id: job_id.clone(), 265 engine: &running_engine, 266 control_rx: &mut control_rx, 267 tasks_per_operator: tasks_per_operator.clone(), 268 }; 269 270 // trigger a couple checkpoints 271 advance(&running_engine, checkpoint_interval).await; 272 checkpoint(ctx, 1).await; 273 advance(&running_engine, checkpoint_interval).await; 274 checkpoint(ctx, 2).await; 275 advance(&running_engine, checkpoint_interval).await; 276 277 compact(job_id, &running_engine, tasks_per_operator.clone(), 2).await; 278 279 // trigger checkpoint 3, which will include the compacted files 280 advance(&running_engine, checkpoint_interval).await; 281 checkpoint(ctx, 3).await; 282 // shut down the engine 283 for source in running_engine.source_controls() { 284 source 285 .send(ControlMessage::Stop { 286 mode: StopMode::Graceful, 287 }) 288 .await 289 .unwrap(); 290 } 291 run_until_finished(&running_engine, &mut control_rx).await; 292 } 293 294 async fn finish_from_checkpoint(job_id: &str, program: Program) { 295 let engine = Engine::for_local(program, job_id.to_string()); 296 let (running_engine, mut control_rx) = engine 297 .start(StreamConfig { 298 restore_epoch: Some(3), 299 }) 300 .await; 301 302 info!("Restored engine, running until finished"); 303 run_until_finished(&running_engine, &mut control_rx).await; 304 } 305 306 async fn run_pipeline_and_assert_outputs( 307 job_id: &str, 308 mut graph: LogicalGraph, 309 checkpoint_interval: i32, 310 output_location: String, 311 golden_output_location: String, 312 udfs: &[LocalUdf], 313 ) { 314 // remove output_location before running the pipeline 315 if std::path::Path::new(&output_location).exists() { 316 std::fs::remove_file(&output_location).unwrap(); 317 } 318 319 let get_program = 320 |graph: &LogicalGraph| Program::local_from_logical(job_id.to_string(), graph, udfs); 321 322 run_completely( 323 job_id, 324 get_program(&graph), 325 output_location.clone(), 326 golden_output_location.clone(), 327 ) 328 .await; 329 330 set_internal_parallelism(&mut graph, 2); 331 332 run_and_checkpoint( 333 Arc::new(job_id.to_string()), 334 get_program(&graph), 335 checkpoint_interval, 336 ) 337 .await; 338 339 set_internal_parallelism(&mut graph, 3); 340 341 finish_from_checkpoint(job_id, get_program(&graph)).await; 342 343 check_output_files( 344 "resuming from checkpointing", 345 output_location, 346 golden_output_location, 347 ) 348 .await; 349 } 350 351 async fn run_completely( 352 job_id: &str, 353 program: Program, 354 output_location: String, 355 golden_output_location: String, 356 ) { 357 let engine = Engine::for_local(program, job_id.to_string()); 358 let (running_engine, mut control_rx) = engine 359 .start(StreamConfig { 360 restore_epoch: None, 361 }) 362 .await; 363 364 run_until_finished(&running_engine, &mut control_rx).await; 365 366 check_output_files( 367 "initial run", 368 output_location.clone(), 369 golden_output_location, 370 ) 371 .await; 372 if std::path::Path::new(&output_location).exists() { 373 std::fs::remove_file(&output_location).unwrap(); 374 } 375 } 376 377 // return the inner value and whether it is a retract 378 fn decode_debezium(value: &Value) -> Result<(Value, bool)> { 379 if !is_debezium(value) { 380 bail!("not a debezium record"); 381 } 382 let op = value.get("op").unwrap().as_str().unwrap(); 383 match op { 384 "c" => Ok((value.get("after").unwrap().clone(), false)), 385 "d" => Ok((value.get("before").unwrap().clone(), true)), 386 _ => bail!("unknown op {}", op), 387 } 388 } 389 390 fn is_debezium(value: &Value) -> bool { 391 let Some(op) = value.get("op") else { 392 return false; 393 }; 394 op.as_str().is_some() 395 } 396 397 fn check_debezium( 398 output_location: String, 399 golden_output_location: String, 400 output_lines: Vec<Value>, 401 golden_output_lines: Vec<Value>, 402 ) { 403 let output_deduped = dedup_debezium(output_lines); 404 let golden_output_deduped = dedup_debezium(golden_output_lines); 405 assert_eq!( 406 output_deduped, golden_output_deduped, 407 "failed to check debezium equality for\noutput: {}\ngolden: {}", 408 output_location, golden_output_location 409 ); 410 } 411 412 fn dedup_debezium(values: Vec<Value>) -> HashMap<String, i64> { 413 let mut deduped = HashMap::new(); 414 for value in &values { 415 let (row_data, value) = decode_debezium(value).unwrap(); 416 let row_data_str = roundtrip(&row_data); 417 let count = deduped.entry(row_data_str.clone()).or_insert(0); 418 if value { 419 *count -= 1; 420 } else { 421 *count += 1; 422 } 423 if *count == 0 { 424 deduped.remove(&row_data_str); 425 } 426 } 427 deduped 428 } 429 430 fn roundtrip(v: &Value) -> String { 431 // round trip string through a btreemap to get consistent key ordering 432 serde_json::to_string(&serde_json::from_value::<BTreeMap<String, Value>>(v.clone()).unwrap()) 433 .unwrap() 434 } 435 436 async fn check_output_files( 437 check_name: &str, 438 output_location: String, 439 golden_output_location: String, 440 ) { 441 let mut output_lines: Vec<Value> = read_to_string(output_location.clone()) 442 .await 443 .unwrap_or_else(|_| panic!("output file not found at {}", output_location)) 444 .lines() 445 .map(|s| serde_json::from_str(s).unwrap()) 446 .collect(); 447 448 let mut golden_output_lines: Vec<Value> = read_to_string(golden_output_location.clone()) 449 .await 450 .unwrap_or_else(|_| { 451 panic!( 452 "golden output file not found at {}, want to compare to {}", 453 golden_output_location, output_location 454 ) 455 }) 456 .lines() 457 .map(|s| serde_json::from_str(s).unwrap()) 458 .collect(); 459 if output_lines.len() != golden_output_lines.len() { 460 // might be updating, in which case lets see if we can cancel out rows 461 let Some(first_output) = output_lines.first() else { 462 panic!( 463 "failed at check {}, output has 0 lines, expect {} lines.\noutput: {}\ngolden: {}", 464 check_name, 465 golden_output_lines.len(), 466 output_location, 467 golden_output_location 468 ); 469 }; 470 if is_debezium(first_output) { 471 check_debezium( 472 output_location, 473 golden_output_location, 474 output_lines, 475 golden_output_lines, 476 ); 477 return; 478 } 479 480 panic!( 481 "failed at check {}, output has {} lines, expect {} lines.\noutput: {}\ngolden: {}", 482 check_name, 483 output_lines.len(), 484 golden_output_lines.len(), 485 output_location, 486 golden_output_location 487 ); 488 } 489 490 output_lines.sort_by_cached_key(roundtrip); 491 golden_output_lines.sort_by_cached_key(roundtrip); 492 output_lines 493 .into_iter() 494 .zip(golden_output_lines.into_iter()) 495 .enumerate() 496 .for_each(|(i, (output_line, golden_output_line))| { 497 assert_eq!( 498 output_line, golden_output_line, 499 "check {}: line {} of output and golden output differ\nactual:{}\nexpected:{})", 500 check_name, i, output_location, golden_output_location 501 ) 502 }); 503 } 504 505 pub async fn correctness_run_codegen( 506 test_name: impl Into<String>, 507 query: impl Into<String>, 508 checkpoint_interval: i32, 509 ) -> Result<()> { 510 let test_name = test_name.into(); 511 let parent_directory = std::env::current_dir() 512 .unwrap() 513 .to_string_lossy() 514 .to_string(); 515 516 // Depending on run location the directory might end with arroyo-sql-testing. 517 // If so, remove it. 518 let parent_directory = if parent_directory.ends_with("arroyo-sql-testing") { 519 parent_directory 520 .strip_suffix("arroyo-sql-testing") 521 .unwrap() 522 .to_string() 523 } else { 524 parent_directory 525 }; 526 527 // replace $input_file with the current directory and then inputs/query_name.json 528 let physical_input_dir = format!("{}/arroyo-sql-testing/inputs/", parent_directory,); 529 530 let query_string = query.into().replace("$input_dir", &physical_input_dir); 531 532 // replace $output_file with the current directory and then outputs/query_name.json 533 let physical_output = format!( 534 "{}/arroyo-sql-testing/outputs/{}.json", 535 parent_directory, test_name 536 ); 537 538 let query_string = query_string.replace("$output_path", &physical_output); 539 let golden_output_location = format!( 540 "{}/arroyo-sql-testing/golden_outputs/{}.json", 541 parent_directory, test_name 542 ); 543 544 let udfs = get_udfs(); 545 546 let logical_program = get_graph(query_string.clone(), &udfs).await?; 547 run_pipeline_and_assert_outputs( 548 &test_name, 549 logical_program.graph, 550 checkpoint_interval, 551 physical_output, 552 golden_output_location, 553 &udfs, 554 ) 555 .await; 556 Ok(()) 557 } 558 559 async fn get_graph(query_string: String, udfs: &[LocalUdf]) -> Result<LogicalProgram> { 560 let mut schema_provider = ArroyoSchemaProvider::new(); 561 for udf in udfs { 562 schema_provider 563 .add_rust_udf(udf.def, udf.config.name.as_str()) 564 .unwrap(); 565 } 566 567 // TODO: test with higher parallelism 568 let program = parse_and_get_arrow_program( 569 query_string, 570 schema_provider, 571 SqlConfig { 572 default_parallelism: 1, 573 }, 574 ) 575 .await? 576 .program; 577 Ok(program) 578 }