updatable_count.rs
1 // Copyright (c) 2019-2025 Alpha-Delta Network Inc. 2 // This file is part of the deltavm library. 3 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at: 7 8 // http://www.apache.org/licenses/LICENSE-2.0 9 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 16 use crate::{Constant, Constraints, Measurement, Private, Public}; 17 18 use core::fmt::Debug; 19 use std::{ 20 cmp::Ordering, 21 collections::{BTreeSet, HashMap}, 22 env, 23 fmt::Display, 24 fs, 25 ops::Range, 26 path::{Path, PathBuf}, 27 sync::{LazyLock, Mutex, OnceLock}, 28 }; 29 30 static FILES: LazyLock<Mutex<HashMap<&'static str, FileUpdates>>> = LazyLock::new(Default::default); 31 static WORKSPACE_ROOT: OnceLock<PathBuf> = OnceLock::new(); 32 33 /// To update the arguments to `count_is!`, run cargo test with the `UPDATE_COUNT` flag set to the name of the file containing the macro invocation. 34 /// e.g. `UPDATE_COUNT=boolean cargo test 35 /// See <https://github.com/ProvableHQ/deltavm/pull/1688> for more details. 36 #[macro_export] 37 macro_rules! 
count_is { 38 ($num_constants:literal, $num_public:literal, $num_private:literal, $num_constraints:literal) => { 39 $crate::UpdatableCount { 40 constant: $crate::Measurement::Exact($num_constants), 41 public: $crate::Measurement::Exact($num_public), 42 private: $crate::Measurement::Exact($num_private), 43 constraints: $crate::Measurement::Exact($num_constraints), 44 file: file!(), 45 line: line!(), 46 column: column!(), 47 } 48 }; 49 } 50 51 /// To update the arguments to `count_less_than!`, run cargo test with the `UPDATE_COUNT` flag set to the name of the file containing the macro invocation. 52 /// e.g. `UPDATE_COUNT=boolean cargo test 53 /// See <https://github.com/ProvableHQ/deltavm/pull/1688> for more details. 54 #[macro_export] 55 macro_rules! count_less_than { 56 ($num_constants:literal, $num_public:literal, $num_private:literal, $num_constraints:literal) => { 57 $crate::UpdatableCount { 58 constant: $crate::Measurement::UpperBound($num_constants), 59 public: $crate::Measurement::UpperBound($num_public), 60 private: $crate::Measurement::UpperBound($num_private), 61 constraints: $crate::Measurement::UpperBound($num_constraints), 62 file: file!(), 63 line: line!(), 64 column: column!(), 65 } 66 }; 67 } 68 69 /// A helper struct for tracking the number of constants, public inputs, private inputs, and constraints. 70 /// Warning: Do not construct this struct directly. Instead, use the `count_is!` and `count_less_than!` macros. 
71 #[derive(Copy, Clone, Debug)] 72 pub struct UpdatableCount { 73 pub constant: Constant, 74 pub public: Public, 75 pub private: Private, 76 pub constraints: Constraints, 77 #[doc(hidden)] 78 pub file: &'static str, 79 #[doc(hidden)] 80 pub line: u32, 81 #[doc(hidden)] 82 pub column: u32, 83 } 84 85 impl Display for UpdatableCount { 86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 87 write!( 88 f, 89 "Constants: {}, Public: {}, Private: {}, Constraints: {}", 90 self.constant, self.public, self.private, self.constraints 91 ) 92 } 93 } 94 95 impl UpdatableCount { 96 /// Returns `true` if the values matches the `Measurement`s in `UpdatableCount`. 97 /// 98 /// For an `Exact` metric, `value` must be equal to the exact value defined by the metric. 99 /// For a `Range` metric, `value` must be satisfy lower bound and the upper bound. 100 /// For an `UpperBound` metric, `value` must be satisfy the upper bound. 101 pub fn matches(&self, num_constants: u64, num_public: u64, num_private: u64, num_constraints: u64) -> bool { 102 self.constant.matches(num_constants) 103 && self.public.matches(num_public) 104 && self.private.matches(num_private) 105 && self.constraints.matches(num_constraints) 106 } 107 108 /// If all values match, do nothing. 109 /// If all values metrics do not match: 110 /// - If the update condition is satisfied, then update the macro invocation that constructed this `UpdatableCount`. 111 /// - Otherwise, panic. 112 pub fn assert_matches(&self, num_constants: u64, num_public: u64, num_private: u64, num_constraints: u64) { 113 if !self.matches(num_constants, num_public, num_private, num_constraints) { 114 let mut files = FILES.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); 115 match env::var("UPDATE_COUNT") { 116 // If `UPDATE_COUNT` is set and the `query_string` matches the file containing the macro invocation 117 // that constructed this `UpdatableCount`, then update the macro invocation. 
118 Ok(query_string) if self.file.contains(&query_string) => { 119 files.entry(self.file).or_insert_with(|| FileUpdates::new(self)).update_count( 120 self, 121 num_constants, 122 num_public, 123 num_private, 124 num_constraints, 125 ); 126 } 127 // Otherwise, error. 128 _ => { 129 println!( 130 "\n 131 \x1b[1m\x1b[91merror\x1b[97m: Count does not match\x1b[0m 132 \x1b[1m\x1b[34m-->\x1b[0m {}:{}:{} 133 \x1b[1mExpected\x1b[0m: 134 ---- 135 {} 136 ---- 137 \x1b[1mActual\x1b[0m: 138 ---- 139 Constants: {}, Public: {}, Private: {}, Constraints: {} 140 ---- 141 ", 142 self.file, 143 self.line, 144 self.column, 145 self, 146 num_constants, 147 num_public, 148 num_private, 149 num_constraints, 150 ); 151 // Use resume_unwind instead of panic!() to prevent a backtrace, which is unnecessary noise. 152 std::panic::resume_unwind(Box::new(())); 153 } 154 } 155 } 156 } 157 158 /// Given a string containing the contents of a file, `locate` returns a range delimiting the arguments 159 /// to the macro invocation that constructed this `UpdatableCount`. 160 /// The beginning of the range corresponds to the opening parenthesis of the macro invocation. 161 /// The end of the range corresponds to the closing parenthesis of the macro invocation. 162 /// ```ignore 163 /// count_is!(0, 1, 2, 3) 164 /// ``` ^ ^ 165 /// starting_index ending_index 166 /// 167 /// Note: This function must always invoked with the file contents of the same file as the macro invocation. 168 fn locate(&self, file: &str) -> Range<usize> { 169 // `line_start` is the absolute byte offset from the beginning of the file to the beginning of the current line. 170 let mut line_start = 0; 171 let mut starting_index = None; 172 let mut ending_index = None; 173 for (i, line) in LinesWithEnds::from(file).enumerate() { 174 if i == self.line as usize - 1 { 175 // Seek past the exclamation point, then skip any whitespace and the macro delimiter to get to the opening parentheses. 
176 let mut argument_character_indices = line.char_indices().skip((self.column - 1).try_into().unwrap()) 177 .skip_while(|&(_, c)| c != '!') // Skip up to the exclamation point. 178 .skip(1) // Skip `!`. 179 .skip_while(|(_, c)| c.is_whitespace()); // Skip any whitespace. 180 181 // Set `starting_index` to the absolute position of the opening parenthesis in `file`. 182 starting_index = Some( 183 line_start 184 + argument_character_indices 185 .next() 186 .expect("Could not find the beginning of the macro invocation.") 187 .0, 188 ); 189 } 190 191 if starting_index.is_some() { 192 // At this point, we have found the opening parentheses, so we continue to skip all characters until the closing parentheses. 193 match line.char_indices().find(|&(_, c)| c == ')') { 194 None => (), // Do nothing. This means that the closing parentheses was not found on the same line as the opening parentheses. 195 Some((offset, _)) => { 196 // Note that the `+ 1` is to account for the fact that `std::ops::Range` is exclusive on the upper bound. 197 ending_index = Some(line_start + offset + 1); 198 break; 199 } 200 } 201 } 202 line_start += line.len(); 203 } 204 205 Range { 206 start: starting_index.expect("Could not find the beginning of the macro invocation."), 207 end: ending_index.expect("Could not find the ending of the macro invocation."), 208 } 209 } 210 211 /// Computes the difference between the number of constants, public, private, and constraints of `self` and those of `other`. 212 pub fn difference_between(&self, other: &Self) -> (i64, i64, i64, i64) { 213 let difference = |self_measurement, other_measurement| match (self_measurement, other_measurement) { 214 (Measurement::Exact(self_value), Measurement::Exact(other_value)) 215 | (Measurement::UpperBound(self_value), Measurement::UpperBound(other_value)) => { 216 // Note: This assumes that the number of constants, public, private, and constraints do not exceed `i64::MAX`. 
217 (self_value as i64) - (other_value as i64) 218 } 219 _ => panic!( 220 "Cannot compute difference for `Measurement::Range` or if both measurements are of different types." 221 ), 222 }; 223 ( 224 difference(self.constant, other.constant), 225 difference(self.public, other.public), 226 difference(self.private, other.private), 227 difference(self.constraints, other.constraints), 228 ) 229 } 230 231 /// Initializes an `UpdatableCount` without a specified location. 232 /// This is only used to store intermediate counts as the source file is updated. 233 fn dummy(constant: Constant, public: Public, private: Private, constraints: Constraints) -> Self { 234 Self { 235 constant, 236 public, 237 private, 238 constraints, 239 file: Default::default(), 240 line: Default::default(), 241 column: Default::default(), 242 } 243 } 244 245 /// Returns a string that is intended to replace the arguments to `count_is` or `count_less_than` in the source file. 246 fn as_argument_string(&self) -> String { 247 let generate_arg = |measurement| match measurement { 248 Measurement::Exact(value) => value, 249 Measurement::UpperBound(bound) => bound, 250 Measurement::Range(..) => panic!( 251 "Cannot create an argument string from an `UpdatableCount` that contains a `Measurement::Range`." 252 ), 253 }; 254 format!( 255 "({}, {}, {}, {})", 256 generate_arg(self.constant), 257 generate_arg(self.public), 258 generate_arg(self.private), 259 generate_arg(self.constraints) 260 ) 261 } 262 } 263 264 /// This struct is used to track updates to the `UpdatableCount`s in a file. 265 /// It is used to ensure that the updates are written to the appropriate location in the file as the file is modified. 266 /// This design avoids having to re-read the source file in the event that it has been modified. 267 struct FileUpdates { 268 absolute_path: PathBuf, 269 original_text: String, 270 modified_text: String, 271 /// An ordered set of `Update`s. 272 /// `Update`s are ordered by their starting location. 
273 /// We assume that all `Updates` are made to disjoint ranges in the original file. 274 /// This assumption is valid since invocations of `count_is` and `count_less_than` cannot be nested. 275 updates: BTreeSet<Update>, 276 } 277 278 impl FileUpdates { 279 /// Initializes a new instance of `FileUpdates`. 280 /// This function will read the contents of the file at the specified path and store it in the `original_text` field. 281 /// This function will also initialize the `updates` field to an empty vector. 282 fn new(count: &UpdatableCount) -> Self { 283 let path = Path::new(count.file); 284 let absolute_path = match path.is_absolute() { 285 true => path.to_owned(), 286 false => { 287 // Append `path` to the workspace root. 288 WORKSPACE_ROOT 289 .get_or_init(|| { 290 // Heuristic, see https://github.com/rust-lang/cargo/issues/3946 291 Path::new(&env!("CARGO_MANIFEST_DIR")) 292 .ancestors() 293 .filter(|it| it.join("Cargo.toml").exists()) 294 .last() 295 .unwrap() 296 .to_path_buf() 297 }) 298 .join(path) 299 } 300 }; 301 let original_text = fs::read_to_string(&absolute_path).unwrap(); 302 let modified_text = original_text.clone(); 303 let updates = Default::default(); 304 Self { absolute_path, original_text, modified_text, updates } 305 } 306 307 /// This function will update the `modified_text` field with the new text that is being inserted. 308 /// The resulting `modified_text` is written to the file at the specified path. 309 /// This implementation allows us to avoid re-reading the source file in the case where multiple updates 310 /// are being made to the same location in the source code. 311 fn update_count( 312 &mut self, 313 count: &UpdatableCount, 314 num_constants: u64, 315 num_public: u64, 316 num_private: u64, 317 num_constraints: u64, 318 ) { 319 // Get the location of arguments in the macro invocation. 
320 let range = count.locate(&self.original_text); 321 322 let mut new_range = range.clone(); 323 let mut update_with_same_start = None; 324 325 // Shift the range to account for changes made to the original file. 326 // Note that the `Update`s in self.updates are ordered by their starting location. 327 for previous_update in &self.updates { 328 let amount_deleted = previous_update.end - previous_update.start; 329 let amount_inserted = previous_update.argument_string.len(); 330 331 match previous_update.start.cmp(&range.start) { 332 // If an update was made in a location preceding the range in the original file, we need to shift the range by the length of the text that was changed. 333 Ordering::Less => { 334 new_range.start = new_range.start - amount_deleted + amount_inserted; 335 new_range.end = new_range.end - amount_deleted + amount_inserted; 336 } 337 // If an update was made at the same location as the range in the original file, we need to shift the end of the range by the amount of text that was changed. 338 Ordering::Equal => { 339 new_range.end = new_range.end - amount_deleted + amount_inserted; 340 update_with_same_start = Some(previous_update); 341 } 342 // We do not need to shift the range if an update was made in a location following the range in the original file. 343 Ordering::Greater => { 344 break; 345 } 346 } 347 } 348 349 // If the original `UpdatableCount` has been modified, then check if the counts satisfy the most recent `UpdatableCount`. 350 // If so, then there is no need to write to update the file and we can return early. 351 if let Some(update) = update_with_same_start { 352 if update.count.matches(num_constants, num_public, num_private, num_constraints) { 353 return; 354 } 355 } 356 357 // Construct the new update. 
358 let new_update = match update_with_same_start { 359 None => Update::new(&range, count, num_constants, num_public, num_private, num_constraints), 360 Some(update) => Update::new(&range, &update.count, num_constants, num_public, num_private, num_constraints), 361 }; 362 363 // Apply the update at the adjusted location. 364 self.modified_text.replace_range(new_range, &new_update.argument_string); 365 366 // Print the difference between the original and updated counts. 367 let difference = new_update.count.difference_between(count); 368 println!( 369 "\n 370 \x1b[1m\x1b[33mwarning\x1b[97m: Updated count\x1b[0m 371 \x1b[1m\x1b[34m-->\x1b[0m {}:{}:{} 372 \x1b[1mOriginal count\x1b[0m: 373 ---- 374 {} 375 ---- 376 \x1b[1mUpdated count\x1b[0m: 377 ---- 378 {} 379 ---- 380 \x1b[1mDifference between updated and original\x1b[0m: 381 ---- 382 Constants: {}, Public: {}, Private: {}, Constraints: {} 383 ---- 384 ", 385 count.file, 386 count.line, 387 count.column, 388 count, 389 new_update.count, 390 difference.0, 391 difference.1, 392 difference.2, 393 difference.3 394 ); 395 396 // Add the new update to the set of updates. 397 self.updates.replace(new_update); 398 399 // Update the original file with the modified text. 400 fs::write(&self.absolute_path, &self.modified_text).unwrap() 401 } 402 } 403 404 /// Helper struct to keep track of updates to the original file. 405 #[derive(Debug)] 406 struct Update { 407 /// Starting location in the original file. 408 start: usize, 409 /// Ending location in the original file. 410 end: usize, 411 /// A dummy count with the new `Measurement`s. 412 count: UpdatableCount, 413 /// A string representation of `count`. 414 argument_string: String, 415 } 416 417 impl Update { 418 fn new( 419 range: &Range<usize>, 420 old_count: &UpdatableCount, 421 num_constants: u64, 422 num_public: u64, 423 num_private: u64, 424 num_constraints: u64, 425 ) -> Self { 426 // Helper function to determine the new `Measurement` based on the expected value. 
427 let generate_new_measurement = |measurement: Measurement<u64>, expected: u64| match measurement { 428 Measurement::Exact(..) => Measurement::Exact(expected), 429 Measurement::Range(..) => panic!("UpdatableCount does not support ranges."), 430 Measurement::UpperBound(bound) => Measurement::UpperBound(std::cmp::max(expected, bound)), 431 }; 432 let count = UpdatableCount::dummy( 433 generate_new_measurement(old_count.constant, num_constants), 434 generate_new_measurement(old_count.public, num_public), 435 generate_new_measurement(old_count.private, num_private), 436 generate_new_measurement(old_count.constraints, num_constraints), 437 ); 438 Self { start: range.start, end: range.end, count, argument_string: count.as_argument_string() } 439 } 440 } 441 442 impl PartialEq for Update { 443 fn eq(&self, other: &Self) -> bool { 444 self.start == other.start 445 } 446 } 447 impl Eq for Update {} 448 impl PartialOrd for Update { 449 fn partial_cmp(&self, other: &Self) -> Option<Ordering> { 450 Some(self.cmp(other)) 451 } 452 } 453 impl Ord for Update { 454 fn cmp(&self, other: &Self) -> Ordering { 455 self.start.cmp(&other.start) 456 } 457 } 458 459 /// A struct that provides an iterator over the lines in a string, while preserving the original line endings. 460 /// This is necessary as `str::lines` does not preserve the original line endings. 
struct LinesWithEnds<'a> {
    /// The not-yet-consumed remainder of the input.
    text: &'a str,
}

impl<'a> Iterator for LinesWithEnds<'a> {
    type Item = &'a str;

    /// Yields the next line including its trailing `\n`, if any. If the input does not end with `\n`,
    /// the final line is yielded without a terminator; an empty remainder yields `None`.
    fn next(&mut self) -> Option<&'a str> {
        if self.text.is_empty() {
            return None;
        }
        let end = match self.text.find('\n') {
            Some(newline) => newline + 1,
            None => self.text.len(),
        };
        let (line, rest) = self.text.split_at(end);
        self.text = rest;
        Some(line)
    }
}

impl<'a> From<&'a str> for LinesWithEnds<'a> {
    fn from(text: &'a str) -> Self {
        LinesWithEnds { text }
    }
}

#[cfg(test)]
mod test {
    use serial_test::serial;
    use std::env;

    /// Sets `UPDATE_COUNT` for as long as the guard is alive and removes it when the guard is dropped.
    ///
    /// Removing the variable in `Drop` (rather than at the end of the test body) guarantees that it does
    /// not leak into later tests when a test unwinds out of `assert_matches`.
    struct UpdateCountVar;

    impl UpdateCountVar {
        fn set(value: &str) -> Self {
            env::set_var("UPDATE_COUNT", value);
            Self
        }
    }

    impl Drop for UpdateCountVar {
        fn drop(&mut self) {
            env::remove_var("UPDATE_COUNT");
        }
    }

    #[test]
    fn check_position() {
        // The expected line is computed relative to the invocation so the test survives edits elsewhere in the file.
        let count = count_is!(0, 0, 0, 0);
        let expected_line = line!() - 1;
        assert_eq!(count.file, "circuit/environment/src/helpers/updatable_count.rs");
        assert_eq!(count.line, expected_line);
        assert_eq!(count.column, 21);
    }

    // Note: The below tests must be run serially since the behavior `assert_matches` depends on whether or not
    // the environment variable `UPDATE_COUNT` is set.

    #[test]
    #[serial]
    fn check_count_passes() {
        let count = count_is!(1, 2, 3, 4);
        let num_constants = 1;
        let num_public = 2;
        let num_private = 3;
        let num_inputs = 4;
        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }

    #[test]
    #[serial]
    #[should_panic]
    fn check_count_fails() {
        let count = count_is!(1, 2, 3, 4);
        let num_constants = 5;
        let num_public = 6;
        let num_private = 7;
        let num_inputs = 8;

        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }

    #[test]
    #[serial]
    #[should_panic]
    fn check_count_does_not_update_if_env_var_is_not_set_correctly() {
        let count = count_is!(1, 2, 3, 4);
        let num_constants = 5;
        let num_public = 6;
        let num_private = 7;
        let num_inputs = 8;

        // Set `UPDATE_COUNT` to a value that does not match this file, so the mismatch must not trigger an update.
        let _update_count = UpdateCountVar::set("1");

        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }

    #[test]
    #[serial]
    fn check_count_updates_correctly() {
        // This invocation was originally `count_is!(1, 2, 3, 4)`; running this test with updating enabled
        // rewrote it in place to the values asserted below.
        let count = count_is!(11, 12, 13, 14);
        let num_constants = 11;
        let num_public = 12;
        let num_private = 13;
        let num_inputs = 14;

        // Set the environment variable to enable updating this file.
        let _update_count = UpdateCountVar::set("updatable_count.rs");

        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }

    #[test]
    #[serial]
    fn check_count_updates_correctly_multiple_times() {
        // This invocation was originally `count_is!(1, 2, 3, 4)`. Each call below whose counts differ from the
        // latest recorded counts rewrites it; the final call leaves it as `count_is!(17, 18, 19, 20)`.
        let count = count_is!(17, 18, 19, 20);

        let _update_count = UpdateCountVar::set("updatable_count.rs");

        let (num_constants, num_public, num_private, num_inputs) = (5, 6, 7, 8);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);

        let (num_constants, num_public, num_private, num_inputs) = (9, 10, 11, 12);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);

        let (num_constants, num_public, num_private, num_inputs) = (13, 14, 15, 16);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);

        let (num_constants, num_public, num_private, num_inputs) = (17, 18, 19, 20);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }

    #[test]
    #[serial]
    fn check_count_less_than_selects_maximum() {
        // `count` is initially `count_less_than!(1, 2, 3, 4)`.
        // After counts are updated, the invocation is `count_less_than!(17, 18, 19, 20)`.
        // In other words, each bound is updated to be the maximum of the original bound and the observed counts.
        let count = count_less_than!(17, 18, 19, 20);

        // Set the environment variable to enable updating this file.
        let _update_count = UpdateCountVar::set("updatable_count.rs");

        let (num_constants, num_public, num_private, num_inputs) = (5, 18, 7, 8);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);

        let (num_constants, num_public, num_private, num_inputs) = (17, 10, 11, 12);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);

        let (num_constants, num_public, num_private, num_inputs) = (13, 6, 19, 16);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);

        let (num_constants, num_public, num_private, num_inputs) = (9, 18, 15, 20);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }
}