// crates/arroyo-sql-testing/src/smoke_tests.rs
  1  use anyhow::{bail, Result};
  2  use arroyo_datastream::logical::{
  3      LogicalEdge, LogicalEdgeType, LogicalGraph, LogicalNode, LogicalProgram, OperatorName,
  4  };
  5  use arroyo_df::{parse_and_get_arrow_program, ArroyoSchemaProvider, SqlConfig};
  6  use arroyo_state::parquet::ParquetBackend;
  7  use petgraph::algo::has_path_connecting;
  8  use petgraph::visit::EdgeRef;
  9  use rstest::rstest;
 10  use std::collections::{BTreeMap, HashMap, HashSet};
 11  use std::path::{Path, PathBuf};
 12  use std::sync::Arc;
 13  use std::time::Duration;
 14  use std::{env, time::SystemTime};
 15  use tokio::sync::mpsc::Receiver;
 16  
 17  use crate::udfs::get_udfs;
 18  use arroyo_rpc::grpc::rpc::{StopMode, TaskCheckpointCompletedReq, TaskCheckpointEventReq};
 19  use arroyo_rpc::{CompactionResult, ControlMessage, ControlResp};
 20  use arroyo_state::checkpoint_state::CheckpointState;
 21  use arroyo_types::{to_micros, CheckpointBarrier};
 22  use arroyo_udf_host::LocalUdf;
 23  use arroyo_worker::engine::{Engine, StreamConfig};
 24  use arroyo_worker::engine::{Program, RunningEngine};
 25  use petgraph::{Direction, Graph};
 26  use serde_json::Value;
 27  use test_log::test as test_log;
 28  use tokio::fs::read_to_string;
 29  use tokio::sync::mpsc::error::TryRecvError;
 30  use tracing::info;
 31  
 32  #[test_log(rstest)]
 33  fn for_each_file(#[files("src/test/queries/*.sql")] path: PathBuf) {
 34      tokio::runtime::Builder::new_current_thread()
 35          .enable_all()
 36          .build()
 37          .unwrap()
 38          .block_on(async {
 39              run_smoketest(&path).await;
 40          });
 41  }
 42  
 43  async fn run_smoketest(path: &Path) {
 44      // read text at path
 45      let test_name = path
 46          .file_name()
 47          .unwrap()
 48          .to_str()
 49          .unwrap()
 50          .split('.')
 51          .next()
 52          .unwrap();
 53      let query = read_to_string(path).await.unwrap();
 54      let fail = query.starts_with("--fail");
 55      let error_message = query.starts_with("--fail=").then(|| {
 56          query
 57              .lines()
 58              .next()
 59              .unwrap()
 60              .split_once('=')
 61              .unwrap()
 62              .1
 63              .trim()
 64      });
 65      match correctness_run_codegen(test_name, query.clone(), 20).await {
 66          Ok(_) => {}
 67          Err(err) => {
 68              if fail {
 69                  if let Some(error_message) = error_message {
 70                      assert!(
 71                          err.to_string().contains(error_message),
 72                          "expected error message '{}' not found; instead got '{}'",
 73                          error_message,
 74                          err
 75                      );
 76                  }
 77              } else {
 78                  panic!("smoke test failed: {}", err);
 79              }
 80          }
 81      }
 82  }
 83  
/// State threaded through [`checkpoint`] so that repeated checkpoints of one
/// running pipeline can share the engine handle and control-message receiver.
struct SmokeTestContext<'a> {
    // Job identifier shared with the checkpoint state backend.
    job_id: Arc<String>,
    // Handle to the locally running engine (used to reach source controls).
    engine: &'a RunningEngine,
    // Receives checkpoint events/completions emitted by operator tasks.
    control_rx: &'a mut Receiver<ControlResp>,
    // Subtask count per operator id; tells CheckpointState when it is done.
    tasks_per_operator: HashMap<String, usize>,
}
 90  
/// Triggers checkpoint `epoch` on every source and plays the controller's
/// role: collects the resulting checkpoint events and completions from
/// `control_rx`, feeds them into a `CheckpointState`, and persists the
/// finished checkpoint's metadata so a later run can restore from it.
async fn checkpoint(ctx: &mut SmokeTestContext<'_>, epoch: u32) {
    // In this harness the epoch number doubles as the checkpoint id.
    let checkpoint_id = epoch as i64;
    let mut checkpoint_state = CheckpointState::new(
        ctx.job_id.clone(),
        checkpoint_id.to_string(),
        epoch,
        0, // min_epoch: nothing has been compacted away yet
        ctx.tasks_per_operator.clone(),
    );

    // trigger a checkpoint, pass the messages to the CheckpointState

    let barrier = CheckpointBarrier {
        epoch,
        min_epoch: 0,
        timestamp: SystemTime::now(),
        then_stop: false,
    };

    // Sources inject the barrier, which then flows through the dataflow.
    for source in ctx.engine.source_controls() {
        source
            .send(ControlMessage::Checkpoint(barrier))
            .await
            .unwrap();
    }

    // Pump control messages until every subtask has reported in.
    while !checkpoint_state.done() {
        let c: ControlResp = ctx.control_rx.recv().await.unwrap();

        match c {
            ControlResp::CheckpointEvent(c) => {
                // Build the same request a worker would send to the controller.
                let req = TaskCheckpointEventReq {
                    worker_id: 1,
                    time: to_micros(c.time),
                    job_id: (*ctx.job_id).clone(),
                    operator_id: c.operator_id,
                    subtask_index: c.subtask_index,
                    epoch: c.checkpoint_epoch,
                    event_type: c.event_type as i32,
                };
                checkpoint_state.checkpoint_event(req).unwrap();
            }
            ControlResp::CheckpointCompleted(c) => {
                let req = TaskCheckpointCompletedReq {
                    worker_id: 1,
                    time: c.subtask_metadata.finish_time,
                    job_id: (*ctx.job_id).clone(),
                    operator_id: c.operator_id,
                    epoch: c.checkpoint_epoch,
                    needs_commit: false,
                    metadata: Some(c.subtask_metadata),
                };
                checkpoint_state.checkpoint_finished(req).await.unwrap();
            }
            // Other control traffic (errors, etc.) is irrelevant here.
            _ => {}
        }
    }

    // Persist checkpoint metadata; restore tests depend on this.
    checkpoint_state.save_state().await.unwrap();

    info!("Smoke test checkpoint completed");
}
153  
154  async fn compact(
155      job_id: Arc<String>,
156      running_engine: &RunningEngine,
157      tasks_per_operator: HashMap<String, usize>,
158      epoch: u32,
159  ) {
160      let operator_controls = running_engine.operator_controls();
161      for (operator, _) in tasks_per_operator {
162          if let Ok(compacted) =
163              ParquetBackend::compact_operator(job_id.clone(), operator.clone(), epoch).await
164          {
165              let operator_controls = operator_controls.get(&operator).unwrap();
166              for s in operator_controls {
167                  s.send(ControlMessage::LoadCompacted {
168                      compacted: CompactionResult {
169                          operator_id: operator.to_string(),
170                          compacted_tables: compacted.clone(),
171                      },
172                  })
173                  .await
174                  .unwrap();
175              }
176          }
177      }
178  }
179  
180  async fn advance(engine: &RunningEngine, count: i32) {
181      // let the engine run for a bit, process some records
182      for source in engine.source_controls() {
183          for _ in 0..count {
184              let _ = source.send(ControlMessage::NoOp).await;
185          }
186      }
187  }
188  
189  async fn run_until_finished(engine: &RunningEngine, control_rx: &mut Receiver<ControlResp>) {
190      while control_rx.try_recv().is_ok()
191          || control_rx
192              .try_recv()
193              .is_err_and(|e| e == TryRecvError::Empty)
194      {
195          advance(engine, 10).await;
196          tokio::time::sleep(Duration::from_millis(1)).await;
197      }
198  }
199  
200  fn set_internal_parallelism(graph: &mut Graph<LogicalNode, LogicalEdge>, parallelism: usize) {
201      let watermark_nodes: HashSet<_> = graph
202          .node_indices()
203          .filter(|index| {
204              let operator_name = graph.node_weight(*index).unwrap().operator_name;
205              matches!(operator_name, OperatorName::ExpressionWatermark)
206          })
207          .collect();
208      let indices: Vec<_> = graph
209          .node_indices()
210          .filter(
211              |index| match graph.node_weight(*index).unwrap().operator_name {
212                  OperatorName::ExpressionWatermark
213                  | OperatorName::ConnectorSource
214                  | OperatorName::ConnectorSink => false,
215                  _ => {
216                      for watermark_node in watermark_nodes.iter() {
217                          if has_path_connecting(&graph.clone(), *watermark_node, *index, None) {
218                              return true;
219                          }
220                      }
221                      false
222                  }
223              },
224          )
225          .collect();
226      for node in indices {
227          graph.node_weight_mut(node).unwrap().parallelism = parallelism;
228      }
229      if parallelism > 1 {
230          let mut edges_to_make_shuffle = vec![];
231          for node in graph.externals(Direction::Outgoing) {
232              for edge in graph.edges_directed(node, Direction::Incoming) {
233                  edges_to_make_shuffle.push(edge.id());
234              }
235          }
236          for node in graph.node_indices() {
237              if graph.node_weight(node).unwrap().operator_name == OperatorName::ExpressionWatermark {
238                  for edge in graph.edges_directed(node, Direction::Outgoing) {
239                      edges_to_make_shuffle.push(edge.id());
240                  }
241              }
242          }
243          for edge in edges_to_make_shuffle {
244              graph.edge_weight_mut(edge).unwrap().edge_type = LogicalEdgeType::Shuffle;
245          }
246      }
247  }
248  
/// Runs `program` while exercising the checkpoint/compaction path:
/// checkpoints at epochs 1 and 2, compacts epoch 2, takes checkpoint 3 (which
/// references the compacted files), then stops the sources gracefully and
/// drains the engine. `finish_from_checkpoint` later restores from epoch 3.
async fn run_and_checkpoint(job_id: Arc<String>, program: Program, checkpoint_interval: i32) {
    let tasks_per_operator = program.tasks_per_operator();
    let engine = Engine::for_local(program, job_id.to_string());
    let (running_engine, mut control_rx) = engine
        .start(StreamConfig {
            restore_epoch: None,
        })
        .await;
    info!("Smoke test checkpointing enabled");
    // NOTE(review): process-global env var, set *after* engine start —
    // presumably read lazily by the compaction path; confirm the ordering.
    env::set_var(
        "ARROYO__CONTROLLER__COMPACTION__CHECKPOINTS_TO_COMPACT",
        "2",
    );

    let ctx = &mut SmokeTestContext {
        job_id: job_id.clone(),
        engine: &running_engine,
        control_rx: &mut control_rx,
        tasks_per_operator: tasks_per_operator.clone(),
    };

    // trigger a couple checkpoints
    advance(&running_engine, checkpoint_interval).await;
    checkpoint(ctx, 1).await;
    advance(&running_engine, checkpoint_interval).await;
    checkpoint(ctx, 2).await;
    advance(&running_engine, checkpoint_interval).await;

    compact(job_id, &running_engine, tasks_per_operator.clone(), 2).await;

    // trigger checkpoint 3, which will include the compacted files
    advance(&running_engine, checkpoint_interval).await;
    checkpoint(ctx, 3).await;
    // shut down the engine
    for source in running_engine.source_controls() {
        source
            .send(ControlMessage::Stop {
                mode: StopMode::Graceful,
            })
            .await
            .unwrap();
    }
    run_until_finished(&running_engine, &mut control_rx).await;
}
293  
294  async fn finish_from_checkpoint(job_id: &str, program: Program) {
295      let engine = Engine::for_local(program, job_id.to_string());
296      let (running_engine, mut control_rx) = engine
297          .start(StreamConfig {
298              restore_epoch: Some(3),
299          })
300          .await;
301  
302      info!("Restored engine, running until finished");
303      run_until_finished(&running_engine, &mut control_rx).await;
304  }
305  
306  async fn run_pipeline_and_assert_outputs(
307      job_id: &str,
308      mut graph: LogicalGraph,
309      checkpoint_interval: i32,
310      output_location: String,
311      golden_output_location: String,
312      udfs: &[LocalUdf],
313  ) {
314      // remove output_location before running the pipeline
315      if std::path::Path::new(&output_location).exists() {
316          std::fs::remove_file(&output_location).unwrap();
317      }
318  
319      let get_program =
320          |graph: &LogicalGraph| Program::local_from_logical(job_id.to_string(), graph, udfs);
321  
322      run_completely(
323          job_id,
324          get_program(&graph),
325          output_location.clone(),
326          golden_output_location.clone(),
327      )
328      .await;
329  
330      set_internal_parallelism(&mut graph, 2);
331  
332      run_and_checkpoint(
333          Arc::new(job_id.to_string()),
334          get_program(&graph),
335          checkpoint_interval,
336      )
337      .await;
338  
339      set_internal_parallelism(&mut graph, 3);
340  
341      finish_from_checkpoint(job_id, get_program(&graph)).await;
342  
343      check_output_files(
344          "resuming from checkpointing",
345          output_location,
346          golden_output_location,
347      )
348      .await;
349  }
350  
351  async fn run_completely(
352      job_id: &str,
353      program: Program,
354      output_location: String,
355      golden_output_location: String,
356  ) {
357      let engine = Engine::for_local(program, job_id.to_string());
358      let (running_engine, mut control_rx) = engine
359          .start(StreamConfig {
360              restore_epoch: None,
361          })
362          .await;
363  
364      run_until_finished(&running_engine, &mut control_rx).await;
365  
366      check_output_files(
367          "initial run",
368          output_location.clone(),
369          golden_output_location,
370      )
371      .await;
372      if std::path::Path::new(&output_location).exists() {
373          std::fs::remove_file(&output_location).unwrap();
374      }
375  }
376  
377  // return the inner value and whether it is a retract
378  fn decode_debezium(value: &Value) -> Result<(Value, bool)> {
379      if !is_debezium(value) {
380          bail!("not a debezium record");
381      }
382      let op = value.get("op").unwrap().as_str().unwrap();
383      match op {
384          "c" => Ok((value.get("after").unwrap().clone(), false)),
385          "d" => Ok((value.get("before").unwrap().clone(), true)),
386          _ => bail!("unknown op {}", op),
387      }
388  }
389  
390  fn is_debezium(value: &Value) -> bool {
391      let Some(op) = value.get("op") else {
392          return false;
393      };
394      op.as_str().is_some()
395  }
396  
397  fn check_debezium(
398      output_location: String,
399      golden_output_location: String,
400      output_lines: Vec<Value>,
401      golden_output_lines: Vec<Value>,
402  ) {
403      let output_deduped = dedup_debezium(output_lines);
404      let golden_output_deduped = dedup_debezium(golden_output_lines);
405      assert_eq!(
406          output_deduped, golden_output_deduped,
407          "failed to check debezium equality for\noutput: {}\ngolden: {}",
408          output_location, golden_output_location
409      );
410  }
411  
412  fn dedup_debezium(values: Vec<Value>) -> HashMap<String, i64> {
413      let mut deduped = HashMap::new();
414      for value in &values {
415          let (row_data, value) = decode_debezium(value).unwrap();
416          let row_data_str = roundtrip(&row_data);
417          let count = deduped.entry(row_data_str.clone()).or_insert(0);
418          if value {
419              *count -= 1;
420          } else {
421              *count += 1;
422          }
423          if *count == 0 {
424              deduped.remove(&row_data_str);
425          }
426      }
427      deduped
428  }
429  
430  fn roundtrip(v: &Value) -> String {
431      // round trip string through a btreemap to get consistent key ordering
432      serde_json::to_string(&serde_json::from_value::<BTreeMap<String, Value>>(v.clone()).unwrap())
433          .unwrap()
434  }
435  
/// Asserts that the JSON-lines file at `output_location` matches the golden
/// file. Lines are compared order-insensitively (both sides are sorted by a
/// canonical key-ordered rendering). If the line counts differ and the output
/// is Debezium-formatted, falls back to netting creates against retracts
/// before comparing; otherwise panics with a diagnostic.
async fn check_output_files(
    check_name: &str,
    output_location: String,
    golden_output_location: String,
) {
    let mut output_lines: Vec<Value> = read_to_string(output_location.clone())
        .await
        .unwrap_or_else(|_| panic!("output file not found at {}", output_location))
        .lines()
        .map(|s| serde_json::from_str(s).unwrap())
        .collect();

    let mut golden_output_lines: Vec<Value> = read_to_string(golden_output_location.clone())
        .await
        .unwrap_or_else(|_| {
            panic!(
                "golden output file not found at {}, want to compare to {}",
                golden_output_location, output_location
            )
        })
        .lines()
        .map(|s| serde_json::from_str(s).unwrap())
        .collect();
    if output_lines.len() != golden_output_lines.len() {
        // might be updating, in which case lets see if we can cancel out rows
        let Some(first_output) = output_lines.first() else {
            panic!(
                "failed at check {}, output has 0 lines, expect {} lines.\noutput: {}\ngolden: {}",
                check_name,
                golden_output_lines.len(),
                output_location,
                golden_output_location
            );
        };
        if is_debezium(first_output) {
            check_debezium(
                output_location,
                golden_output_location,
                output_lines,
                golden_output_lines,
            );
            return;
        }

        panic!(
            "failed at check {}, output has {} lines, expect {} lines.\noutput: {}\ngolden: {}",
            check_name,
            output_lines.len(),
            golden_output_lines.len(),
            output_location,
            golden_output_location
        );
    }

    // Sort both sides canonically so pipeline output order doesn't matter.
    output_lines.sort_by_cached_key(roundtrip);
    golden_output_lines.sort_by_cached_key(roundtrip);
    output_lines
        .into_iter()
        .zip(golden_output_lines.into_iter())
        .enumerate()
        .for_each(|(i, (output_line, golden_output_line))| {
            assert_eq!(
                output_line, golden_output_line,
                "check {}: line {} of output and golden output differ\nactual:{}\nexpected:{})",
                check_name, i, output_location, golden_output_location
            )
        });
}
504  
505  pub async fn correctness_run_codegen(
506      test_name: impl Into<String>,
507      query: impl Into<String>,
508      checkpoint_interval: i32,
509  ) -> Result<()> {
510      let test_name = test_name.into();
511      let parent_directory = std::env::current_dir()
512          .unwrap()
513          .to_string_lossy()
514          .to_string();
515  
516      // Depending on run location the directory might end with arroyo-sql-testing.
517      // If so, remove it.
518      let parent_directory = if parent_directory.ends_with("arroyo-sql-testing") {
519          parent_directory
520              .strip_suffix("arroyo-sql-testing")
521              .unwrap()
522              .to_string()
523      } else {
524          parent_directory
525      };
526  
527      // replace $input_file with the current directory and then inputs/query_name.json
528      let physical_input_dir = format!("{}/arroyo-sql-testing/inputs/", parent_directory,);
529  
530      let query_string = query.into().replace("$input_dir", &physical_input_dir);
531  
532      // replace $output_file with the current directory and then outputs/query_name.json
533      let physical_output = format!(
534          "{}/arroyo-sql-testing/outputs/{}.json",
535          parent_directory, test_name
536      );
537  
538      let query_string = query_string.replace("$output_path", &physical_output);
539      let golden_output_location = format!(
540          "{}/arroyo-sql-testing/golden_outputs/{}.json",
541          parent_directory, test_name
542      );
543  
544      let udfs = get_udfs();
545  
546      let logical_program = get_graph(query_string.clone(), &udfs).await?;
547      run_pipeline_and_assert_outputs(
548          &test_name,
549          logical_program.graph,
550          checkpoint_interval,
551          physical_output,
552          golden_output_location,
553          &udfs,
554      )
555      .await;
556      Ok(())
557  }
558  
559  async fn get_graph(query_string: String, udfs: &[LocalUdf]) -> Result<LogicalProgram> {
560      let mut schema_provider = ArroyoSchemaProvider::new();
561      for udf in udfs {
562          schema_provider
563              .add_rust_udf(udf.def, udf.config.name.as_str())
564              .unwrap();
565      }
566  
567      // TODO: test with higher parallelism
568      let program = parse_and_get_arrow_program(
569          query_string,
570          schema_provider,
571          SqlConfig {
572              default_parallelism: 1,
573          },
574      )
575      .await?
576      .program;
577      Ok(program)
578  }