/ circuit / environment / src / helpers / updatable_count.rs
updatable_count.rs
  1  // Copyright (c) 2019-2025 Alpha-Delta Network Inc.
  2  // This file is part of the deltavm library.
  3  
  4  // Licensed under the Apache License, Version 2.0 (the "License");
  5  // you may not use this file except in compliance with the License.
  6  // You may obtain a copy of the License at:
  7  
  8  // http://www.apache.org/licenses/LICENSE-2.0
  9  
 10  // Unless required by applicable law or agreed to in writing, software
 11  // distributed under the License is distributed on an "AS IS" BASIS,
 12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13  // See the License for the specific language governing permissions and
 14  // limitations under the License.
 15  
 16  use crate::{Constant, Constraints, Measurement, Private, Public};
 17  
 18  use core::fmt::Debug;
 19  use std::{
 20      cmp::Ordering,
 21      collections::{BTreeSet, HashMap},
 22      env,
 23      fmt::Display,
 24      fs,
 25      ops::Range,
 26      path::{Path, PathBuf},
 27      sync::{LazyLock, Mutex, OnceLock},
 28  };
 29  
/// Pending source rewrites, keyed by the `file!()` path recorded in each `UpdatableCount`.
/// Guarded by a `Mutex` because this static is shared by every thread that calls `assert_matches`
/// (e.g. concurrently running tests); a poisoned lock is recovered rather than propagated.
static FILES: LazyLock<Mutex<HashMap<&'static str, FileUpdates>>> = LazyLock::new(Default::default);
/// The workspace root, computed once by `FileUpdates::new` and used to turn relative `file!()` paths into
/// absolute paths.
static WORKSPACE_ROOT: OnceLock<PathBuf> = OnceLock::new();
 32  
 33  /// To update the arguments to `count_is!`, run cargo test with the `UPDATE_COUNT` flag set to the name of the file containing the macro invocation.
 34  /// e.g. `UPDATE_COUNT=boolean cargo test
 35  /// See <https://github.com/ProvableHQ/deltavm/pull/1688> for more details.
 36  #[macro_export]
 37  macro_rules! count_is {
 38      ($num_constants:literal, $num_public:literal, $num_private:literal, $num_constraints:literal) => {
 39          $crate::UpdatableCount {
 40              constant: $crate::Measurement::Exact($num_constants),
 41              public: $crate::Measurement::Exact($num_public),
 42              private: $crate::Measurement::Exact($num_private),
 43              constraints: $crate::Measurement::Exact($num_constraints),
 44              file: file!(),
 45              line: line!(),
 46              column: column!(),
 47          }
 48      };
 49  }
 50  
 51  /// To update the arguments to `count_less_than!`, run cargo test with the `UPDATE_COUNT` flag set to the name of the file containing the macro invocation.
 52  /// e.g. `UPDATE_COUNT=boolean cargo test
 53  /// See <https://github.com/ProvableHQ/deltavm/pull/1688> for more details.
 54  #[macro_export]
 55  macro_rules! count_less_than {
 56      ($num_constants:literal, $num_public:literal, $num_private:literal, $num_constraints:literal) => {
 57          $crate::UpdatableCount {
 58              constant: $crate::Measurement::UpperBound($num_constants),
 59              public: $crate::Measurement::UpperBound($num_public),
 60              private: $crate::Measurement::UpperBound($num_private),
 61              constraints: $crate::Measurement::UpperBound($num_constraints),
 62              file: file!(),
 63              line: line!(),
 64              column: column!(),
 65          }
 66      };
 67  }
 68  
/// A helper struct for tracking the number of constants, public inputs, private inputs, and constraints.
///
/// Each count is a `Measurement`: `count_is!` uses `Measurement::Exact` and `count_less_than!` uses
/// `Measurement::UpperBound`. The source location of the macro invocation is recorded so that `assert_matches`
/// can rewrite the invocation when the `UPDATE_COUNT` environment variable selects its file.
///
/// Warning: Do not construct this struct directly. Instead, use the `count_is!` and `count_less_than!` macros.
#[derive(Copy, Clone, Debug)]
pub struct UpdatableCount {
    /// The expected number of constants.
    pub constant: Constant,
    /// The expected number of public inputs.
    pub public: Public,
    /// The expected number of private inputs.
    pub private: Private,
    /// The expected number of constraints.
    pub constraints: Constraints,
    /// The source file of the macro invocation, as reported by `file!()`.
    #[doc(hidden)]
    pub file: &'static str,
    /// The 1-based line of the macro invocation, as reported by `line!()`.
    #[doc(hidden)]
    pub line: u32,
    /// The 1-based column of the macro invocation, as reported by `column!()`.
    #[doc(hidden)]
    pub column: u32,
}
 84  
 85  impl Display for UpdatableCount {
 86      fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 87          write!(
 88              f,
 89              "Constants: {}, Public: {}, Private: {}, Constraints: {}",
 90              self.constant, self.public, self.private, self.constraints
 91          )
 92      }
 93  }
 94  
 95  impl UpdatableCount {
 96      /// Returns `true` if the values matches the `Measurement`s in `UpdatableCount`.
 97      ///
 98      /// For an `Exact` metric, `value` must be equal to the exact value defined by the metric.
 99      /// For a `Range` metric, `value` must be satisfy lower bound and the upper bound.
100      /// For an `UpperBound` metric, `value` must be satisfy the upper bound.
101      pub fn matches(&self, num_constants: u64, num_public: u64, num_private: u64, num_constraints: u64) -> bool {
102          self.constant.matches(num_constants)
103              && self.public.matches(num_public)
104              && self.private.matches(num_private)
105              && self.constraints.matches(num_constraints)
106      }
107  
108      /// If all values match, do nothing.
109      /// If all values metrics do not match:
110      ///    - If the update condition is satisfied, then update the macro invocation that constructed this `UpdatableCount`.
111      ///    - Otherwise, panic.
112      pub fn assert_matches(&self, num_constants: u64, num_public: u64, num_private: u64, num_constraints: u64) {
113          if !self.matches(num_constants, num_public, num_private, num_constraints) {
114              let mut files = FILES.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
115              match env::var("UPDATE_COUNT") {
116                  // If `UPDATE_COUNT` is set and the `query_string` matches the file containing the macro invocation
117                  // that constructed this `UpdatableCount`, then update the macro invocation.
118                  Ok(query_string) if self.file.contains(&query_string) => {
119                      files.entry(self.file).or_insert_with(|| FileUpdates::new(self)).update_count(
120                          self,
121                          num_constants,
122                          num_public,
123                          num_private,
124                          num_constraints,
125                      );
126                  }
127                  // Otherwise, error.
128                  _ => {
129                      println!(
130                          "\n
131  \x1b[1m\x1b[91merror\x1b[97m: Count does not match\x1b[0m
132     \x1b[1m\x1b[34m-->\x1b[0m {}:{}:{}
133  \x1b[1mExpected\x1b[0m:
134  ----
135  {}
136  ----
137  \x1b[1mActual\x1b[0m:
138  ----
139  Constants: {}, Public: {}, Private: {}, Constraints: {}
140  ----
141  ",
142                          self.file,
143                          self.line,
144                          self.column,
145                          self,
146                          num_constants,
147                          num_public,
148                          num_private,
149                          num_constraints,
150                      );
151                      // Use resume_unwind instead of panic!() to prevent a backtrace, which is unnecessary noise.
152                      std::panic::resume_unwind(Box::new(()));
153                  }
154              }
155          }
156      }
157  
158      /// Given a string containing the contents of a file, `locate` returns a range delimiting the arguments
159      /// to the macro invocation that constructed this `UpdatableCount`.
160      /// The beginning of the range corresponds to the opening parenthesis of the macro invocation.
161      /// The end of the range corresponds to the closing parenthesis of the macro invocation.
162      /// ```ignore
163      ///              count_is!(0, 1, 2, 3)
164      /// ```                   ^          ^
165      ///           starting_index     ending_index
166      ///
167      /// Note: This function must always invoked with the file contents of the same file as the macro invocation.
168      fn locate(&self, file: &str) -> Range<usize> {
169          // `line_start` is the absolute byte offset from the beginning of the file to the beginning of the current line.
170          let mut line_start = 0;
171          let mut starting_index = None;
172          let mut ending_index = None;
173          for (i, line) in LinesWithEnds::from(file).enumerate() {
174              if i == self.line as usize - 1 {
175                  // Seek past the exclamation point, then skip any whitespace and the macro delimiter to get to the opening parentheses.
176                  let mut argument_character_indices = line.char_indices().skip((self.column - 1).try_into().unwrap())
177                      .skip_while(|&(_, c)| c != '!') // Skip up to the exclamation point.
178                      .skip(1) // Skip `!`.
179                      .skip_while(|(_, c)| c.is_whitespace()); // Skip any whitespace.
180  
181                  // Set `starting_index` to the absolute position of the opening parenthesis in `file`.
182                  starting_index = Some(
183                      line_start
184                          + argument_character_indices
185                              .next()
186                              .expect("Could not find the beginning of the macro invocation.")
187                              .0,
188                  );
189              }
190  
191              if starting_index.is_some() {
192                  // At this point, we have found the opening parentheses, so we continue to skip all characters until the closing parentheses.
193                  match line.char_indices().find(|&(_, c)| c == ')') {
194                      None => (), // Do nothing. This means that the closing parentheses was not found on the same line as the opening parentheses.
195                      Some((offset, _)) => {
196                          // Note that the `+ 1` is to account for the fact that `std::ops::Range` is exclusive on the upper bound.
197                          ending_index = Some(line_start + offset + 1);
198                          break;
199                      }
200                  }
201              }
202              line_start += line.len();
203          }
204  
205          Range {
206              start: starting_index.expect("Could not find the beginning of the macro invocation."),
207              end: ending_index.expect("Could not find the ending of the macro invocation."),
208          }
209      }
210  
211      /// Computes the difference between the number of constants, public, private, and constraints of `self` and those of `other`.
212      pub fn difference_between(&self, other: &Self) -> (i64, i64, i64, i64) {
213          let difference = |self_measurement, other_measurement| match (self_measurement, other_measurement) {
214              (Measurement::Exact(self_value), Measurement::Exact(other_value))
215              | (Measurement::UpperBound(self_value), Measurement::UpperBound(other_value)) => {
216                  // Note: This assumes that the number of constants, public, private, and constraints do not exceed `i64::MAX`.
217                  (self_value as i64) - (other_value as i64)
218              }
219              _ => panic!(
220                  "Cannot compute difference for `Measurement::Range` or if both measurements are of different types."
221              ),
222          };
223          (
224              difference(self.constant, other.constant),
225              difference(self.public, other.public),
226              difference(self.private, other.private),
227              difference(self.constraints, other.constraints),
228          )
229      }
230  
231      /// Initializes an `UpdatableCount` without a specified location.
232      /// This is only used to store intermediate counts as the source file is updated.
233      fn dummy(constant: Constant, public: Public, private: Private, constraints: Constraints) -> Self {
234          Self {
235              constant,
236              public,
237              private,
238              constraints,
239              file: Default::default(),
240              line: Default::default(),
241              column: Default::default(),
242          }
243      }
244  
245      /// Returns a string that is intended to replace the arguments to `count_is` or `count_less_than` in the source file.
246      fn as_argument_string(&self) -> String {
247          let generate_arg = |measurement| match measurement {
248              Measurement::Exact(value) => value,
249              Measurement::UpperBound(bound) => bound,
250              Measurement::Range(..) => panic!(
251                  "Cannot create an argument string from an `UpdatableCount` that contains a `Measurement::Range`."
252              ),
253          };
254          format!(
255              "({}, {}, {}, {})",
256              generate_arg(self.constant),
257              generate_arg(self.public),
258              generate_arg(self.private),
259              generate_arg(self.constraints)
260          )
261      }
262  }
263  
/// This struct is used to track updates to the `UpdatableCount`s in a file.
/// It is used to ensure that the updates are written to the appropriate location in the file as the file is modified.
/// This design avoids having to re-read the source file in the event that it has been modified.
struct FileUpdates {
    /// Path of the source file; relative `file!()` paths are joined onto the workspace root.
    absolute_path: PathBuf,
    /// The file contents as first read. Macro invocations are always located in this text, so the recorded
    /// `line`/`column` of each `UpdatableCount` stay valid no matter how many rewrites have been applied.
    original_text: String,
    /// The file contents with every recorded `Update` applied; written back to `absolute_path` after each update.
    modified_text: String,
    /// An ordered set of `Update`s.
    /// `Update`s are ordered (and deduplicated) by their starting location in `original_text`.
    /// We assume that all `Update`s are made to disjoint ranges in the original file.
    /// This assumption is valid since invocations of `count_is` and `count_less_than` cannot be nested.
    updates: BTreeSet<Update>,
}
277  
278  impl FileUpdates {
279      /// Initializes a new instance of `FileUpdates`.
280      /// This function will read the contents of the file at the specified path and store it in the `original_text` field.
281      /// This function will also initialize the `updates` field to an empty vector.
282      fn new(count: &UpdatableCount) -> Self {
283          let path = Path::new(count.file);
284          let absolute_path = match path.is_absolute() {
285              true => path.to_owned(),
286              false => {
287                  // Append `path` to the workspace root.
288                  WORKSPACE_ROOT
289                      .get_or_init(|| {
290                          // Heuristic, see https://github.com/rust-lang/cargo/issues/3946
291                          Path::new(&env!("CARGO_MANIFEST_DIR"))
292                              .ancestors()
293                              .filter(|it| it.join("Cargo.toml").exists())
294                              .last()
295                              .unwrap()
296                              .to_path_buf()
297                      })
298                      .join(path)
299              }
300          };
301          let original_text = fs::read_to_string(&absolute_path).unwrap();
302          let modified_text = original_text.clone();
303          let updates = Default::default();
304          Self { absolute_path, original_text, modified_text, updates }
305      }
306  
307      /// This function will update the `modified_text` field with the new text that is being inserted.
308      /// The resulting `modified_text` is written to the file at the specified path.
309      /// This implementation allows us to avoid re-reading the source file in the case where multiple updates
310      /// are being made to the same location in the source code.
311      fn update_count(
312          &mut self,
313          count: &UpdatableCount,
314          num_constants: u64,
315          num_public: u64,
316          num_private: u64,
317          num_constraints: u64,
318      ) {
319          // Get the location of arguments in the macro invocation.
320          let range = count.locate(&self.original_text);
321  
322          let mut new_range = range.clone();
323          let mut update_with_same_start = None;
324  
325          // Shift the range to account for changes made to the original file.
326          // Note that the `Update`s in self.updates are ordered by their starting location.
327          for previous_update in &self.updates {
328              let amount_deleted = previous_update.end - previous_update.start;
329              let amount_inserted = previous_update.argument_string.len();
330  
331              match previous_update.start.cmp(&range.start) {
332                  // If an update was made in a location preceding the range in the original file, we need to shift the range by the length of the text that was changed.
333                  Ordering::Less => {
334                      new_range.start = new_range.start - amount_deleted + amount_inserted;
335                      new_range.end = new_range.end - amount_deleted + amount_inserted;
336                  }
337                  // If an update was made at the same location as the range in the original file, we need to shift the end of the range by the amount of text that was changed.
338                  Ordering::Equal => {
339                      new_range.end = new_range.end - amount_deleted + amount_inserted;
340                      update_with_same_start = Some(previous_update);
341                  }
342                  // We do not need to shift the range if an update was made in a location following the range in the original file.
343                  Ordering::Greater => {
344                      break;
345                  }
346              }
347          }
348  
349          // If the original `UpdatableCount` has been modified, then check if the counts satisfy the most recent `UpdatableCount`.
350          // If so, then there is no need to write to update the file and we can return early.
351          if let Some(update) = update_with_same_start {
352              if update.count.matches(num_constants, num_public, num_private, num_constraints) {
353                  return;
354              }
355          }
356  
357          // Construct the new update.
358          let new_update = match update_with_same_start {
359              None => Update::new(&range, count, num_constants, num_public, num_private, num_constraints),
360              Some(update) => Update::new(&range, &update.count, num_constants, num_public, num_private, num_constraints),
361          };
362  
363          // Apply the update at the adjusted location.
364          self.modified_text.replace_range(new_range, &new_update.argument_string);
365  
366          // Print the difference between the original and updated counts.
367          let difference = new_update.count.difference_between(count);
368          println!(
369              "\n
370  \x1b[1m\x1b[33mwarning\x1b[97m: Updated count\x1b[0m
371     \x1b[1m\x1b[34m-->\x1b[0m {}:{}:{}
372  \x1b[1mOriginal count\x1b[0m:
373  ----
374  {}
375  ----
376  \x1b[1mUpdated count\x1b[0m:
377  ----
378  {}
379  ----
380  \x1b[1mDifference between updated and original\x1b[0m:
381  ----
382  Constants: {}, Public: {}, Private: {}, Constraints: {}
383  ----
384  ",
385              count.file,
386              count.line,
387              count.column,
388              count,
389              new_update.count,
390              difference.0,
391              difference.1,
392              difference.2,
393              difference.3
394          );
395  
396          // Add the new update to the set of updates.
397          self.updates.replace(new_update);
398  
399          // Update the original file with the modified text.
400          fs::write(&self.absolute_path, &self.modified_text).unwrap()
401      }
402  }
403  
/// Helper struct to keep track of updates to the original file.
///
/// Equality and ordering consider only `start`, so a `BTreeSet<Update>` holds at most one update per macro
/// invocation, and `BTreeSet::replace` swaps a newer update in for an older one at the same location.
#[derive(Debug)]
struct Update {
    /// Starting byte offset in the original file (the opening parenthesis of the macro arguments).
    start: usize,
    /// Ending byte offset in the original file (exclusive; one past the closing parenthesis).
    end: usize,
    /// A dummy count with the new `Measurement`s.
    count: UpdatableCount,
    /// A string representation of `count`, i.e. the text that replaces `start..end`.
    argument_string: String,
}
416  
417  impl Update {
418      fn new(
419          range: &Range<usize>,
420          old_count: &UpdatableCount,
421          num_constants: u64,
422          num_public: u64,
423          num_private: u64,
424          num_constraints: u64,
425      ) -> Self {
426          // Helper function to determine the new `Measurement` based on the expected value.
427          let generate_new_measurement = |measurement: Measurement<u64>, expected: u64| match measurement {
428              Measurement::Exact(..) => Measurement::Exact(expected),
429              Measurement::Range(..) => panic!("UpdatableCount does not support ranges."),
430              Measurement::UpperBound(bound) => Measurement::UpperBound(std::cmp::max(expected, bound)),
431          };
432          let count = UpdatableCount::dummy(
433              generate_new_measurement(old_count.constant, num_constants),
434              generate_new_measurement(old_count.public, num_public),
435              generate_new_measurement(old_count.private, num_private),
436              generate_new_measurement(old_count.constraints, num_constraints),
437          );
438          Self { start: range.start, end: range.end, count, argument_string: count.as_argument_string() }
439      }
440  }
441  
442  impl PartialEq for Update {
443      fn eq(&self, other: &Self) -> bool {
444          self.start == other.start
445      }
446  }
447  impl Eq for Update {}
448  impl PartialOrd for Update {
449      fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
450          Some(self.cmp(other))
451      }
452  }
453  impl Ord for Update {
454      fn cmp(&self, other: &Self) -> Ordering {
455          self.start.cmp(&other.start)
456      }
457  }
458  
/// An iterator over the lines of a string slice that keeps each line's terminator.
///
/// Unlike `str::lines`, which strips line endings, every yielded slice still ends with its original `\n`
/// (only the final line may lack one). The yielded slices therefore concatenate back to the input, so their
/// byte lengths can be summed to compute absolute offsets.
struct LinesWithEnds<'a> {
    text: &'a str,
}

impl<'a> Iterator for LinesWithEnds<'a> {
    type Item = &'a str;

    fn next(&mut self) -> Option<&'a str> {
        if self.text.is_empty() {
            return None;
        }
        let split_point = match self.text.find('\n') {
            Some(newline) => newline + 1,
            None => self.text.len(),
        };
        let (line, rest) = self.text.split_at(split_point);
        self.text = rest;
        Some(line)
    }
}

impl<'a> From<&'a str> for LinesWithEnds<'a> {
    fn from(text: &'a str) -> Self {
        Self { text }
    }
}
486  
#[cfg(test)]
mod test {
    use serial_test::serial;
    use std::env;

    /// Sets an environment variable on construction and removes it when dropped.
    ///
    /// Several tests below expect `assert_matches` to unwind. Without a guard, their trailing `env::remove_var`
    /// call would be skipped and `UPDATE_COUNT` would leak into every test that runs afterwards.
    struct EnvVarGuard {
        key: &'static str,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            env::set_var(key, value);
            Self { key }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            env::remove_var(self.key);
        }
    }

    #[test]
    fn check_position() {
        // The expected line is derived from `line!()` so that edits elsewhere in this file do not break the test.
        let count = count_is!(0, 0, 0, 0);
        let expected_line = line!() - 1;
        assert_eq!(count.file, "circuit/environment/src/helpers/updatable_count.rs");
        assert_eq!(count.line, expected_line);
        assert_eq!(count.column, 21);
    }

    // Note: The below tests must be run serially since the behavior of `assert_matches` depends on whether or not
    // the environment variable `UPDATE_COUNT` is set.

    #[test]
    #[serial]
    fn check_count_passes() {
        let count = count_is!(1, 2, 3, 4);
        let num_constants = 1;
        let num_public = 2;
        let num_private = 3;
        let num_inputs = 4;
        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }

    #[test]
    #[serial]
    #[should_panic]
    fn check_count_fails() {
        let count = count_is!(1, 2, 3, 4);
        let num_constants = 5;
        let num_public = 6;
        let num_private = 7;
        let num_inputs = 8;

        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }

    #[test]
    #[serial]
    #[should_panic]
    fn check_count_does_not_update_if_env_var_is_not_set_correctly() {
        let count = count_is!(1, 2, 3, 4);
        let num_constants = 5;
        let num_public = 6;
        let num_private = 7;
        let num_inputs = 8;

        // Set the environment variable to a value that does not match this file.
        // The guard removes it again even though `assert_matches` unwinds below.
        let _guard = EnvVarGuard::set("UPDATE_COUNT", "1");

        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }

    #[test]
    #[serial]
    fn check_count_updates_correctly() {
        // `count` was originally `count_is!(1, 2, 3, 4)`; running this test rewrote it to the observed counts.
        let count = count_is!(11, 12, 13, 14);
        let num_constants = 11;
        let num_public = 12;
        let num_private = 13;
        let num_inputs = 14;

        // Set the environment variable to update the file.
        let _guard = EnvVarGuard::set("UPDATE_COUNT", "updatable_count.rs");

        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }

    #[test]
    #[serial]
    fn check_count_updates_correctly_multiple_times() {
        // `count` was originally `count_is!(1, 2, 3, 4)`. Each mismatching call below rewrites it in place,
        // and the last call restores it to `count_is!(17, 18, 19, 20)`.
        let count = count_is!(17, 18, 19, 20);

        let _guard = EnvVarGuard::set("UPDATE_COUNT", "updatable_count.rs");

        let (num_constants, num_public, num_private, num_inputs) = (5, 6, 7, 8);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);

        let (num_constants, num_public, num_private, num_inputs) = (9, 10, 11, 12);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);

        let (num_constants, num_public, num_private, num_inputs) = (13, 14, 15, 16);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);

        let (num_constants, num_public, num_private, num_inputs) = (17, 18, 19, 20);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }

    #[test]
    #[serial]
    fn check_count_less_than_selects_maximum() {
        // `count` is initially `count_less_than!(1, 2, 3, 4)`.
        // After counts are updated, `count` is `count_less_than!(17, 18, 19, 20)`.
        // In other words, count is updated to be the maximum of the original and updated counts.
        let count = count_less_than!(17, 18, 19, 20);

        // Set the environment variable to update the file.
        let _guard = EnvVarGuard::set("UPDATE_COUNT", "updatable_count.rs");

        let (num_constants, num_public, num_private, num_inputs) = (5, 18, 7, 8);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);

        let (num_constants, num_public, num_private, num_inputs) = (17, 10, 11, 12);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);

        let (num_constants, num_public, num_private, num_inputs) = (13, 6, 19, 16);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);

        let (num_constants, num_public, num_private, num_inputs) = (9, 18, 15, 20);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }
}