// crates/integ/tests/api_tests.rs
  1  use anyhow::bail;
  2  use std::env;
  3  use std::sync::{Arc, OnceLock};
  4  use std::time::Duration;
  5  
  6  use arroyo_openapi::types::{
  7      builder, ConnectionProfilePost, ConnectionSchema, ConnectionTablePost, Format, JsonFormat,
  8      MetricName, PipelinePatch, PipelinePost, SchemaDefinition, StopType, Udf, ValidateQueryPost,
  9      ValidateUdfPost,
 10  };
 11  use arroyo_openapi::Client;
 12  use rand::random;
 13  use rdkafka::admin::{AdminClient, AdminOptions, NewTopic};
 14  use rdkafka::{ClientConfig, ClientContext};
 15  use serde_json::json;
 16  use tracing::info;
 17  
 18  async fn wait_for_state(
 19      client: &Client,
 20      pipeline_id: &str,
 21      expected_state: &str,
 22  ) -> anyhow::Result<()> {
 23      let mut last_state = "None".to_string();
 24      while last_state != expected_state {
 25          let jobs = client
 26              .get_pipeline_jobs()
 27              .id(pipeline_id)
 28              .send()
 29              .await
 30              .unwrap();
 31          let job = jobs.data.first().unwrap();
 32  
 33          let state = job.state.clone();
 34          if last_state != state {
 35              info!("Job transitioned to {}", state);
 36              last_state = state;
 37          }
 38  
 39          if last_state == "Failed" {
 40              bail!("Job transitioned to failed");
 41          }
 42  
 43          tokio::time::sleep(Duration::from_millis(100)).await;
 44      }
 45  
 46      Ok(())
 47  }
 48  
 49  fn get_client() -> Arc<Client> {
 50      static CLIENT: OnceLock<Arc<Client>> = OnceLock::new();
 51      CLIENT
 52          .get_or_init(|| {
 53              let client = reqwest::ClientBuilder::new()
 54                  .timeout(Duration::from_secs(60))
 55                  .build()
 56                  .unwrap();
 57              Arc::new(Client::new_with_client(
 58                  &format!(
 59                      "{}/api",
 60                      env::var("API_ENDPOINT")
 61                          .unwrap_or_else(|_| "http://localhost:5115".to_string())
 62                  ),
 63                  client,
 64              ))
 65          })
 66          .clone()
 67  }
 68  
 69  async fn start_pipeline(run_id: u32, query: &str, udfs: &[&str]) -> anyhow::Result<String> {
 70      let pipeline_name = format!("pipeline_{}", run_id);
 71      info!("Creating pipeline {}", pipeline_name);
 72  
 73      let pipeline_id = get_client()
 74          .create_pipeline()
 75          .body(
 76              PipelinePost::builder()
 77                  .name(pipeline_name)
 78                  .parallelism(1)
 79                  .checkpoint_interval_micros(1_000_000)
 80                  .query(query)
 81                  .udfs(Some(
 82                      udfs.iter()
 83                          .map(|udf| Udf::builder().definition(*udf).try_into().unwrap())
 84                          .collect(),
 85                  )),
 86          )
 87          .send()
 88          .await?
 89          .into_inner()
 90          .id;
 91  
 92      info!("Created pipeline {}", pipeline_id);
 93      Ok(pipeline_id)
 94  }
 95  
 96  async fn start_and_monitor(
 97      run_id: u32,
 98      query: &str,
 99      udfs: &[&str],
100      checkpoints_to_wait: u32,
101  ) -> anyhow::Result<(String, String)> {
102      let api_client = get_client();
103  
104      println!("Starting pipeline");
105      let pipeline_id = start_pipeline(run_id, query, udfs)
106          .await
107          .expect("failed to start pipeline");
108  
109      // wait for job to enter running phase
110      println!("Waiting until running");
111      wait_for_state(&api_client, &pipeline_id, "Running")
112          .await
113          .unwrap();
114  
115      let jobs = api_client
116          .get_pipeline_jobs()
117          .id(&pipeline_id)
118          .send()
119          .await
120          .unwrap();
121      let job = jobs.data.first().unwrap();
122  
123      // wait for a checkpoint
124      println!("Waiting for {checkpoints_to_wait} successful checkpoints");
125      loop {
126          let checkpoints = api_client
127              .get_job_checkpoints()
128              .pipeline_id(&pipeline_id)
129              .job_id(&job.id)
130              .send()
131              .await
132              .unwrap()
133              .into_inner();
134  
135          if let Some(checkpoint) = checkpoints
136              .data
137              .iter()
138              .find(|c| c.epoch == checkpoints_to_wait as i64)
139          {
140              if checkpoint.finish_time.is_some() {
141                  // get details
142                  let details = api_client
143                      .get_checkpoint_details()
144                      .pipeline_id(&pipeline_id)
145                      .job_id(&job.id)
146                      .epoch(checkpoint.epoch)
147                      .send()
148                      .await
149                      .unwrap()
150                      .into_inner();
151  
152                  assert!(!details.data.is_empty());
153  
154                  return Ok((pipeline_id, job.id.clone()));
155              }
156          }
157  
158          tokio::time::sleep(Duration::from_millis(50)).await;
159      }
160  }
161  
162  async fn patch_and_wait(
163      pipeline_id: &str,
164      body: builder::PipelinePatch,
165      expected_state: &str,
166  ) -> anyhow::Result<()> {
167      println!("Patching with {:?}", body);
168      get_client()
169          .patch_pipeline()
170          .id(pipeline_id)
171          .body(body)
172          .send()
173          .await?;
174  
175      println!("Waiting for {}", expected_state);
176      wait_for_state(&get_client(), &pipeline_id, expected_state).await?;
177  
178      Ok(())
179  }
180  
/// End-to-end lifecycle test: creates an impulse source, validates and runs
/// a windowed aggregation query, waits for metrics to flow, then exercises
/// stop, start, rescale, restart, and deletion of the pipeline and source.
#[tokio::test]
async fn basic_pipeline() {
    let api_client = get_client();

    // create a source
    println!("Creating source");
    // random id keeps resource names unique across concurrent test runs
    let run_id: u32 = random();
    let source_name = format!("source_{}", run_id);

    // "impulse" generates synthetic events at a fixed rate, so no external
    // system is needed for the source
    let source_id = api_client
        .create_connection_table()
        .body(
            ConnectionTablePost::builder()
                .config(json!({"event_rate": 10}))
                .connector("impulse")
                .name(source_name.clone()),
        )
        .send()
        .await
        .expect("failed to create connection table")
        .into_inner()
        .id;

    // create a pipeline
    // NOTE(review): `==` is non-standard SQL equality; presumably the SQL
    // dialect in use accepts it like `=` — confirm
    let query = format!(
        r#"
    select count(*) from {} where counter % 2 == 0
    group by hop(interval '2 seconds', interval '10 seconds');
    "#,
        source_name
    );

    // validate the pipeline query first; expect no errors and a plan graph
    let valid = api_client
        .validate_query()
        .body(ValidateQueryPost::builder().query(&query).udfs(vec![]))
        .send()
        .await
        .unwrap()
        .into_inner();

    assert_eq!(valid.errors, Vec::<String>::new());
    assert!(valid.graph.is_some());

    // run the pipeline and wait for 10 completed checkpoints
    let (pipeline_id, job_id) = start_and_monitor(run_id, &query, &[], 10).await.unwrap();

    // get error messages — a healthy job should have none
    let errors = api_client
        .get_job_errors()
        .pipeline_id(&pipeline_id)
        .job_id(&job_id)
        .send()
        .await
        .unwrap()
        .into_inner();
    assert_eq!(errors.data.len(), 0);

    // poll until every plan node reports a metric group and every non-sink
    // operator's latest MessagesSent sample is positive
    loop {
        let metrics = api_client
            .get_operator_metric_groups()
            .pipeline_id(&pipeline_id)
            .job_id(&job_id)
            .send()
            .await
            .unwrap()
            .into_inner();
        if metrics.data.len() == valid.graph.as_ref().unwrap().nodes.len() {
            if metrics
                .data
                .iter()
                .filter(|op| !op.operator_id.contains("sink"))
                .map(|op| {
                    op.metric_groups
                        .iter()
                        .find(|t| t.name == MetricName::MessagesSent)
                })
                .all(|m| {
                    m.map(|m| {
                        m.subtasks[0].metrics.len() > 0
                            && m.subtasks[0].metrics.iter().last().unwrap().value > 0.0
                    })
                    .unwrap_or(false)
                })
            {
                break;
            }
        }
        tokio::time::sleep(Duration::from_millis(500)).await;
    }

    // stop job with a final checkpoint
    patch_and_wait(
        &pipeline_id,
        PipelinePatch::builder().stop(StopType::Checkpoint),
        "Stopped",
    )
    .await
    .unwrap();

    // start job again from the stopped state
    patch_and_wait(
        &pipeline_id,
        PipelinePatch::builder().stop(StopType::None),
        "Running",
    )
    .await
    .unwrap();

    // rescale job
    println!("Rescaling pipeline");
    patch_and_wait(
        &pipeline_id,
        PipelinePatch::builder().parallelism(2),
        "Running",
    )
    .await
    .unwrap();

    // every node in the graph should have picked up the new parallelism
    for node in api_client
        .get_pipeline()
        .id(&pipeline_id)
        .send()
        .await
        .unwrap()
        .into_inner()
        .graph
        .nodes
    {
        assert_eq!(node.parallelism, 2);
    }

    // restart job
    println!("Restarting pipeline");
    api_client
        .restart_pipeline()
        .id(&pipeline_id)
        .send()
        .await
        .unwrap();

    wait_for_state(&api_client, &pipeline_id, "Running")
        .await
        .unwrap();

    // stop job immediately (no final checkpoint)
    patch_and_wait(
        &pipeline_id,
        PipelinePatch::builder().stop(StopType::Immediate),
        "Stopped",
    )
    .await
    .unwrap();

    // delete pipeline
    println!("Deleting pipeline");
    api_client
        .delete_pipeline()
        .id(&pipeline_id)
        .send()
        .await
        .unwrap();

    // delete source
    println!("Deleting connection");
    api_client
        .delete_connection_table()
        .id(&source_id)
        .send()
        .await
        .unwrap();
}
352  
/// Validates a UDF that declares a crate dependency, then runs a pipeline
/// using it through three checkpoints before tearing it down.
#[tokio::test]
async fn udfs() {
    // Rust UDF source; the leading /* ... */ block declares crate
    // dependencies for the UDF build
    let udf = r#"
/*
[dependencies]
regex = "1"
*/

use arroyo_udf_plugin::udf;
use regex::Regex;

#[udf]
fn my_double(x: i64) -> i64 {
    x * 2
}"#;

    // validate UDF — expect no validation errors
    let valid = get_client()
        .validate_udf()
        .body(ValidateUdfPost::builder().definition(udf))
        .send()
        .await
        .unwrap()
        .into_inner();

    assert_eq!(valid.errors, Vec::<String>::new());

    // query defines its own impulse source inline and calls the UDF
    let query = r#"
create table impulse with (
   connector = 'impulse',
   event_rate = '10'
);

select my_double(cast(counter as bigint)) from impulse;
"#;

    let run_id: u32 = random();

    // run the pipeline with the UDF attached and wait for 3 checkpoints
    let (pipeline_id, _job_id) = start_and_monitor(run_id, query, &[udf], 3).await.unwrap();

    // stop job with a final checkpoint
    patch_and_wait(
        &pipeline_id,
        PipelinePatch::builder().stop(StopType::Checkpoint),
        "Stopped",
    )
    .await
    .unwrap();

    // delete pipeline
    println!("Deleting pipeline");
    get_client()
        .delete_pipeline()
        .id(&pipeline_id)
        .send()
        .await
        .unwrap();
}
411  
412  fn create_kafka_admin() -> AdminClient<impl ClientContext> {
413      ClientConfig::new()
414          .set("bootstrap.servers", "localhost:9092")
415          .create()
416          .unwrap()
417  }
418  
419  async fn create_topic(client: &AdminClient<impl ClientContext>, topic: &str) {
420      client
421          .create_topics(
422              [&NewTopic::new(
423                  topic,
424                  1,
425                  rdkafka::admin::TopicReplication::Fixed(1),
426              )],
427              &AdminOptions::new(),
428          )
429          .await
430          .expect("deletion should have worked");
431  }
432  
433  async fn delete_topic(client: &AdminClient<impl ClientContext>, topic: &str) {
434      client
435          .delete_topics(&[topic], &AdminOptions::new())
436          .await
437          .expect("deletion should have worked");
438  }
439  
/// Exercises the connection-table API against Kafka: connector listing,
/// profile validation/creation, topic autocomplete, JSON-schema-backed
/// table creation (checking the derived fields), then runs and tears down
/// a pipeline reading from the table.
#[tokio::test]
async fn connection_table() {
    let api_client = get_client();

    // the Kafka connector must be registered for the rest of the test
    let connectors = api_client
        .get_connectors()
        .send()
        .await
        .unwrap()
        .into_inner();

    assert!(connectors.data.iter().find(|c| c.name == "Kafka").is_some());

    // random id keeps resource names unique across concurrent test runs
    let run_id: u32 = random();
    let table_name = format!("kafka_table_{run_id}");
    let kafka_admin = create_kafka_admin();

    let kafka_topic = format!("kafka_test_{run_id}");
    create_topic(&kafka_admin, &kafka_topic).await;

    // JSON schema for the table: required string "a", optional number "b",
    // optional string-array "c"
    let schema = r#"
{
    "type": "object",
    "properties": {
        "a": {
            "type": "string"
        },
        "b": {
            "type": "number"
        },
        "c": {
            "type": "array",
            "items": {
                "type": "string"
            }
        }
    },
    "required": ["a"]
}"#;

    let connection_schema = ConnectionSchema::builder()
        .fields(vec![])
        .format(Format::Json(JsonFormat::builder().try_into().unwrap()))
        .definition(SchemaDefinition::JsonSchema(schema.to_string()));

    // create a kafka connection
    let profile_post = ConnectionProfilePost::builder()
        .name(format!("kafka_source_{}", run_id))
        .connector("kafka")
        .config(json!( {
            "authentication": {},
            "bootstrapServers": "localhost:9092",
            "schemaRegistryEnum": {}
        }));

    // test the profile before creating it; it should complete without error
    let valid = api_client
        .test_connection_profile()
        .body(profile_post.clone())
        .send()
        .await
        .unwrap()
        .into_inner();

    assert!(valid.done);
    assert!(!valid.error);

    let profile = api_client
        .create_connection_profile()
        .body(profile_post)
        .send()
        .await
        .unwrap()
        .into_inner();

    // autocomplete for the profile should list the topic we just created
    api_client
        .get_connection_profile_autocomplete()
        .id(&profile.id)
        .send()
        .await
        .unwrap()
        .into_inner()
        .values
        .get("topic")
        .unwrap()
        .iter()
        .find(|t| *t == &kafka_topic)
        .expect("autocomplete did not return kafka topic");

    // the schema itself should validate server-side
    api_client
        .test_schema()
        .body(connection_schema.clone())
        .send()
        .await
        .expect("valid schema");

    // table reading the topic from the latest offset, using the profile
    let connection_table = ConnectionTablePost::builder()
        .name(table_name.clone())
        .connector("kafka")
        .schema(Some(connection_schema.try_into().unwrap()))
        .config(json!({
            "type": {
                "offset": "latest",
                "read_mode": "read_uncommitted"
            },
            "topic": kafka_topic
        }))
        .connection_profile_id(Some(profile.id.clone()));

    let mut connection_table = api_client
        .create_connection_table()
        .body(connection_table)
        .send()
        .await
        .expect("failed to create table")
        .into_inner();

    // sort fields by name so the assertion below is order-independent
    connection_table
        .schema
        .fields
        .sort_by_key(|f| f.field_name.clone());

    // the server should have derived exactly these fields from the schema
    assert_eq!(
        serde_json::to_value(connection_table.schema.fields).unwrap(),
        json!([
            {
                "fieldName": "a",
                "fieldType": {
                    "sqlName": "TEXT",
                    "type": {
                        "primitive": "String",
                    },
                },
                "nullable": false
            },
            {
                "fieldName": "b",
                "fieldType": {
                   "sqlName": "DOUBLE",
                    "type": {
                        "primitive": "F64",
                    },
                },
               "nullable": true,
            },
            {
                "fieldName": "c",
                "fieldType": {
                    "type": {
                        "list": {
                            "fieldName": "item",
                            "fieldType": {
                                "sqlName": "TEXT",
                                "type": {
                                    "primitive": "String",
                                }
                            },
                            "nullable": false
                        },
                    },
                },
                "nullable": true,
            }
        ])
    );

    // run a pipeline over the new table and wait for 5 checkpoints
    let (pipeline_id, _) = start_and_monitor(
        run_id,
        &format!("select * from {};", connection_table.name),
        &[],
        5,
    )
    .await
    .unwrap();

    // stop job immediately (no final checkpoint)
    patch_and_wait(
        &pipeline_id,
        PipelinePatch::builder().stop(StopType::Immediate),
        "Stopped",
    )
    .await
    .unwrap();

    // delete pipeline
    println!("Deleting pipeline");
    api_client
        .delete_pipeline()
        .id(&pipeline_id)
        .send()
        .await
        .unwrap();

    // assert removal of pipeline: a lookup should now return 404
    assert_eq!(
        api_client
            .get_pipeline()
            .id(&pipeline_id)
            .send()
            .await
            .unwrap_err()
            .status()
            .unwrap(),
        reqwest::StatusCode::NOT_FOUND
    );

    // delete source
    println!("Deleting connection");
    api_client
        .delete_connection_table()
        .id(&connection_table.id)
        .send()
        .await
        .unwrap();

    // delete topic
    delete_topic(&kafka_admin, &kafka_topic).await;
}