/ test-framework / src / lib.rs
lib.rs
  1  // Copyright (C) 2019-2025 ADnet Contributors
  2  // This file is part of the ADL library.
  3  
  4  // The ADL library is free software: you can redistribute it and/or modify
  5  // it under the terms of the GNU General Public License as published by
  6  // the Free Software Foundation, either version 3 of the License, or
  7  // (at your option) any later version.
  8  
  9  // The ADL library is distributed in the hope that it will be useful,
 10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
 11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 12  // GNU General Public License for more details.
 13  
 14  // You should have received a copy of the GNU General Public License
 15  // along with the ADL library. If not, see <https://www.gnu.org/licenses/>.
 16  
 17  //! This is a simple test framework for the ADL compiler.
 18  
 19  #[cfg(not(feature = "no_parallel"))]
 20  use rayon::prelude::*;
 21  
 22  use std::{fs, path::PathBuf};
 23  use walkdir::WalkDir;
 24  
 25  enum TestFailure {
 26      Panicked(String),
 27      Mismatch { got: String, expected: String },
 28  }
 29  
 30  /// Pulls tests from `category`, running them through the `runner` and
 31  /// comparing them against expectations in previous runs.
 32  ///
 33  /// The tests are `.adl` files in `tests/{category}`, and the
 34  /// runner receives the contents of each of them as a `&str`,
 35  /// returning a `String` result. A test is considered to have failed
 36  /// if it panics or if results differ from the previous run.
 37  ///
 38  ///
 39  /// If no corresponding `.out` file is found in `expecations/{category}`,
 40  /// or if the environment variable `REWRITE_EXPECTATIONS` is set, no
 41  /// comparison to a previous result is done and the result of the current
 42  /// run is written to the file.
 43  pub fn run_tests(category: &str, runner: fn(&str) -> String) {
 44      // This ensures error output doesn't try to display colors.
 45      unsafe {
 46          // SAFETY: Safety issues around `set_var` are surprisingly complicated.
 47          // For now, I think marking tests as `serial` may be sufficient to
 48          // address this, and in the future we'll try to think of an alternative for
 49          // error output.
 50          std::env::set_var("NOCOLOR", "x");
 51      }
 52  
 53      let base_tests_dir: PathBuf = [env!("CARGO_MANIFEST_DIR"), "..", "tests"].iter().collect();
 54  
 55      let base_tests_dir = base_tests_dir.canonicalize().unwrap();
 56      let tests_dir = base_tests_dir.join("tests").join(category);
 57      let expectations_dir = base_tests_dir.join("expectations").join(category);
 58  
 59      let filter_string = std::env::var("TEST_FILTER").unwrap_or_default();
 60      let rewrite_expectations = std::env::var("REWRITE_EXPECTATIONS").is_ok();
 61  
 62      struct TestResult {
 63          failure: Option<TestFailure>,
 64          name: PathBuf,
 65          wrote: bool,
 66      }
 67  
 68      let paths: Vec<PathBuf> = WalkDir::new(&tests_dir)
 69          .into_iter()
 70          .flatten()
 71          .filter_map(|entry| {
 72              let path = entry.path();
 73  
 74              if path.to_str().is_none() {
 75                  panic!("Path not unicode: {}.", path.display());
 76              };
 77  
 78              let path_str = path.to_str().unwrap();
 79  
 80              if !path_str.contains(&filter_string) || !path_str.ends_with(".adl") {
 81                  return None;
 82              }
 83  
 84              Some(path.into())
 85          })
 86          .collect();
 87  
 88      let run_test = |path: &PathBuf| -> TestResult {
 89          let contents =
 90              fs::read_to_string(path).unwrap_or_else(|e| panic!("Failed to read file {}: {e}.", path.display()));
 91          let result_output = std::panic::catch_unwind(|| runner(&contents));
 92          if let Err(payload) = result_output {
 93              let s1 = payload.downcast_ref::<&str>().map(|s| s.to_string());
 94              let s2 = payload.downcast_ref::<String>().cloned();
 95              let s = s1.or(s2).unwrap_or_else(|| "Unknown panic payload".to_string());
 96  
 97              return TestResult { failure: Some(TestFailure::Panicked(s)), name: path.clone(), wrote: false };
 98          }
 99          let output = result_output.unwrap();
100  
101          let mut expectation_path: PathBuf = expectations_dir.join(path.strip_prefix(&tests_dir).unwrap());
102          expectation_path.set_extension("out");
103  
104          // It may not be ideal to the the IO below in parallel, but I'm thinking it likely won't matter.
105          if rewrite_expectations || !expectation_path.exists() {
106              fs::write(&expectation_path, &output)
107                  .unwrap_or_else(|e| panic!("Failed to write file {}: {e}.", expectation_path.display()));
108              TestResult { failure: None, name: path.clone(), wrote: true }
109          } else {
110              let expected = fs::read_to_string(&expectation_path)
111                  .unwrap_or_else(|e| panic!("Failed to read file {}: {e}.", expectation_path.display()));
112              if output == expected {
113                  TestResult { failure: None, name: path.clone(), wrote: false }
114              } else {
115                  TestResult {
116                      failure: Some(TestFailure::Mismatch { got: output, expected }),
117                      name: path.clone(),
118                      wrote: false,
119                  }
120              }
121          }
122      };
123  
124      #[cfg(feature = "no_parallel")]
125      let results: Vec<TestResult> = paths.iter().map(run_test).collect();
126  
127      #[cfg(not(feature = "no_parallel"))]
128      let results: Vec<TestResult> = paths.par_iter().map(run_test).collect();
129  
130      println!("Ran {} tests.", results.len());
131  
132      let failure_count = results.iter().filter(|test_result| test_result.failure.is_some()).count();
133  
134      if failure_count != 0 {
135          eprintln!("{failure_count}/{} tests failed.", results.len());
136      }
137  
138      let writes = results.iter().filter(|test_result| test_result.wrote).count();
139  
140      for test_result in results.iter() {
141          if let Some(test_failure) = &test_result.failure {
142              eprintln!("FAILURE: {}:", test_result.name.display());
143              match test_failure {
144                  TestFailure::Panicked(s) => eprintln!("Rust panic:\n{s}"),
145                  TestFailure::Mismatch { got, expected } => {
146                      eprintln!("\ngot:\n{got}\nexpected:\n{expected}\n")
147                  }
148              }
149          }
150      }
151  
152      if writes != 0 {
153          println!("Wrote {}/{} expectation files for tests:", writes, results.len());
154      }
155  
156      for test_result in results.iter() {
157          if test_result.wrote {
158              println!("{}", test_result.name.display());
159          }
160      }
161  
162      assert!(failure_count == 0);
163  }