/ compiler / ast / src / passes / visitor.rs
visitor.rs
  1  // Copyright (C) 2019-2025 ADnet Contributors
  2  // This file is part of the ADL library.
  3  
  4  // The ADL library is free software: you can redistribute it and/or modify
  5  // it under the terms of the GNU General Public License as published by
  6  // the Free Software Foundation, either version 3 of the License, or
  7  // (at your option) any later version.
  8  
  9  // The ADL library is distributed in the hope that it will be useful,
 10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
 11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 12  // GNU General Public License for more details.
 13  
 14  // You should have received a copy of the GNU General Public License
 15  // along with the ADL library. If not, see <https://www.gnu.org/licenses/>.
 16  
//! This module defines the Visitor traits for the AST.
//! Each trait provides a default method for every kind of node; the default
//! walks that node's children, so implementors only override the nodes they care about.
 20  
 21  use crate::*;
 22  
// TODO: The Visitor and Reconstructor patterns need a redesign so that an overriding method can still easily invoke the default implementation.
// Here is a pattern that seems to work:
 25  // trait ProgramVisitor {
 26  //     // The trait method that can be overridden
 27  //     fn visit_program_scope(&mut self);
 28  //
//     // Helper method holding the default implementation (trait methods cannot be private, so this is public by convention)
 30  //     fn default_visit_program_scope(&mut self) {
 31  //         println!("Do default stuff");
 32  //     }
 33  // }
 34  //
 35  // struct YourStruct;
 36  //
 37  // impl ProgramVisitor for YourStruct {
 38  //     fn visit_program_scope(&mut self) {
 39  //         println!("Do custom stuff.");
 40  //         // Call the default implementation
 41  //         self.default_visit_program_scope();
 42  //     }
 43  // }
 44  
 45  /// A Visitor trait for types in the AST.
 46  pub trait AstVisitor {
 47      /* Types */
 48      fn visit_type(&mut self, input: &Type) {
 49          match input {
 50              Type::Array(array_type) => self.visit_array_type(array_type),
 51              Type::Composite(composite_type) => self.visit_composite_type(composite_type),
 52              Type::Future(future_type) => self.visit_future_type(future_type),
 53              Type::Mapping(mapping_type) => self.visit_mapping_type(mapping_type),
 54              Type::Optional(optional_type) => self.visit_optional_type(optional_type),
 55              Type::Tuple(tuple_type) => self.visit_tuple_type(tuple_type),
 56              Type::Vector(array_type) => self.visit_vector_type(array_type),
 57              Type::Address
 58              | Type::Boolean
 59              | Type::Field
 60              | Type::Group
 61              | Type::Identifier(_)
 62              | Type::Integer(_)
 63              | Type::Scalar
 64              | Type::Signature
 65              | Type::String
 66              | Type::Numeric
 67              | Type::Unit
 68              | Type::Err => {}
 69          }
 70      }
 71  
 72      fn visit_array_type(&mut self, input: &ArrayType) {
 73          self.visit_type(&input.element_type);
 74          self.visit_expression(&input.length, &Default::default());
 75      }
 76  
 77      fn visit_composite_type(&mut self, input: &CompositeType) {
 78          input.const_arguments.iter().for_each(|expr| {
 79              self.visit_expression(expr, &Default::default());
 80          });
 81      }
 82  
 83      fn visit_future_type(&mut self, input: &FutureType) {
 84          input.inputs.iter().for_each(|input| self.visit_type(input));
 85      }
 86  
 87      fn visit_mapping_type(&mut self, input: &MappingType) {
 88          self.visit_type(&input.key);
 89          self.visit_type(&input.value);
 90      }
 91  
 92      fn visit_optional_type(&mut self, input: &OptionalType) {
 93          self.visit_type(&input.inner);
 94      }
 95  
 96      fn visit_tuple_type(&mut self, input: &TupleType) {
 97          input.elements().iter().for_each(|input| self.visit_type(input));
 98      }
 99  
100      fn visit_vector_type(&mut self, input: &VectorType) {
101          self.visit_type(&input.element_type);
102      }
103  
104      /* Expressions */
105      type AdditionalInput: Default;
106      type Output: Default;
107  
108      fn visit_expression(&mut self, input: &Expression, additional: &Self::AdditionalInput) -> Self::Output {
109          match input {
110              Expression::Array(array) => self.visit_array(array, additional),
111              Expression::ArrayAccess(access) => self.visit_array_access(access, additional),
112              Expression::Async(async_) => self.visit_async(async_, additional),
113              Expression::Binary(binary) => self.visit_binary(binary, additional),
114              Expression::Call(call) => self.visit_call(call, additional),
115              Expression::Cast(cast) => self.visit_cast(cast, additional),
116              Expression::Composite(composite_) => self.visit_composite_init(composite_, additional),
117              Expression::Err(err) => self.visit_err(err, additional),
118              Expression::Path(path) => self.visit_path(path, additional),
119              Expression::Literal(literal) => self.visit_literal(literal, additional),
120              Expression::Locator(locator) => self.visit_locator(locator, additional),
121              Expression::MemberAccess(access) => self.visit_member_access(access, additional),
122              Expression::Repeat(repeat) => self.visit_repeat(repeat, additional),
123              Expression::Ternary(ternary) => self.visit_ternary(ternary, additional),
124              Expression::Tuple(tuple) => self.visit_tuple(tuple, additional),
125              Expression::TupleAccess(access) => self.visit_tuple_access(access, additional),
126              Expression::Unary(unary) => self.visit_unary(unary, additional),
127              Expression::Unit(unit) => self.visit_unit(unit, additional),
128              Expression::Intrinsic(intr) => self.visit_intrinsic(intr, additional),
129          }
130      }
131  
132      fn visit_array_access(&mut self, input: &ArrayAccess, _additional: &Self::AdditionalInput) -> Self::Output {
133          self.visit_expression(&input.array, &Default::default());
134          self.visit_expression(&input.index, &Default::default());
135          Default::default()
136      }
137  
138      fn visit_member_access(&mut self, input: &MemberAccess, _additional: &Self::AdditionalInput) -> Self::Output {
139          self.visit_expression(&input.inner, &Default::default());
140          Default::default()
141      }
142  
143      fn visit_tuple_access(&mut self, input: &TupleAccess, _additional: &Self::AdditionalInput) -> Self::Output {
144          self.visit_expression(&input.tuple, &Default::default());
145          Default::default()
146      }
147  
148      fn visit_array(&mut self, input: &ArrayExpression, _additional: &Self::AdditionalInput) -> Self::Output {
149          input.elements.iter().for_each(|expr| {
150              self.visit_expression(expr, &Default::default());
151          });
152          Default::default()
153      }
154  
155      fn visit_async(&mut self, input: &AsyncExpression, _additional: &Self::AdditionalInput) -> Self::Output {
156          self.visit_block(&input.block);
157          Default::default()
158      }
159  
160      fn visit_binary(&mut self, input: &BinaryExpression, _additional: &Self::AdditionalInput) -> Self::Output {
161          self.visit_expression(&input.left, &Default::default());
162          self.visit_expression(&input.right, &Default::default());
163          Default::default()
164      }
165  
166      fn visit_call(&mut self, input: &CallExpression, _additional: &Self::AdditionalInput) -> Self::Output {
167          input.const_arguments.iter().for_each(|expr| {
168              self.visit_expression(expr, &Default::default());
169          });
170          input.arguments.iter().for_each(|expr| {
171              self.visit_expression(expr, &Default::default());
172          });
173          Default::default()
174      }
175  
176      fn visit_intrinsic(&mut self, input: &IntrinsicExpression, _additional: &Self::AdditionalInput) -> Self::Output {
177          input.arguments.iter().for_each(|arg| {
178              self.visit_expression(arg, &Default::default());
179          });
180          Default::default()
181      }
182  
183      fn visit_cast(&mut self, input: &CastExpression, _additional: &Self::AdditionalInput) -> Self::Output {
184          self.visit_expression(&input.expression, &Default::default());
185          Default::default()
186      }
187  
188      fn visit_composite_init(
189          &mut self,
190          input: &CompositeExpression,
191          _additional: &Self::AdditionalInput,
192      ) -> Self::Output {
193          input.const_arguments.iter().for_each(|expr| {
194              self.visit_expression(expr, &Default::default());
195          });
196          for CompositeFieldInitializer { expression, .. } in input.members.iter() {
197              if let Some(expression) = expression {
198                  self.visit_expression(expression, &Default::default());
199              }
200          }
201          Default::default()
202      }
203  
204      fn visit_err(&mut self, _input: &ErrExpression, _additional: &Self::AdditionalInput) -> Self::Output {
205          panic!("`ErrExpression`s should not be in the AST at this phase of compilation.")
206      }
207  
208      fn visit_path(&mut self, _input: &Path, _additional: &Self::AdditionalInput) -> Self::Output {
209          Default::default()
210      }
211  
212      fn visit_literal(&mut self, _input: &Literal, _additional: &Self::AdditionalInput) -> Self::Output {
213          Default::default()
214      }
215  
216      fn visit_locator(&mut self, _input: &LocatorExpression, _additional: &Self::AdditionalInput) -> Self::Output {
217          Default::default()
218      }
219  
220      fn visit_repeat(&mut self, input: &RepeatExpression, _additional: &Self::AdditionalInput) -> Self::Output {
221          self.visit_expression(&input.expr, &Default::default());
222          self.visit_expression(&input.count, &Default::default());
223          Default::default()
224      }
225  
226      fn visit_ternary(&mut self, input: &TernaryExpression, _additional: &Self::AdditionalInput) -> Self::Output {
227          self.visit_expression(&input.condition, &Default::default());
228          self.visit_expression(&input.if_true, &Default::default());
229          self.visit_expression(&input.if_false, &Default::default());
230          Default::default()
231      }
232  
233      fn visit_tuple(&mut self, input: &TupleExpression, _additional: &Self::AdditionalInput) -> Self::Output {
234          input.elements.iter().for_each(|expr| {
235              self.visit_expression(expr, &Default::default());
236          });
237          Default::default()
238      }
239  
240      fn visit_unary(&mut self, input: &UnaryExpression, _additional: &Self::AdditionalInput) -> Self::Output {
241          self.visit_expression(&input.receiver, &Default::default());
242          Default::default()
243      }
244  
245      fn visit_unit(&mut self, _input: &UnitExpression, _additional: &Self::AdditionalInput) -> Self::Output {
246          Default::default()
247      }
248  
249      /* Statements */
250      fn visit_statement(&mut self, input: &Statement) {
251          match input {
252              Statement::Assert(stmt) => self.visit_assert(stmt),
253              Statement::Assign(stmt) => self.visit_assign(stmt),
254              Statement::Block(stmt) => self.visit_block(stmt),
255              Statement::Conditional(stmt) => self.visit_conditional(stmt),
256              Statement::Const(stmt) => self.visit_const(stmt),
257              Statement::Definition(stmt) => self.visit_definition(stmt),
258              Statement::Expression(stmt) => self.visit_expression_statement(stmt),
259              Statement::Iteration(stmt) => self.visit_iteration(stmt),
260              Statement::Return(stmt) => self.visit_return(stmt),
261          }
262      }
263  
264      fn visit_assert(&mut self, input: &AssertStatement) {
265          match &input.variant {
266              AssertVariant::Assert(expr) => self.visit_expression(expr, &Default::default()),
267              AssertVariant::AssertEq(left, right) | AssertVariant::AssertNeq(left, right) => {
268                  self.visit_expression(left, &Default::default());
269                  self.visit_expression(right, &Default::default())
270              }
271          };
272      }
273  
274      fn visit_assign(&mut self, input: &AssignStatement) {
275          self.visit_expression(&input.place, &Default::default());
276          self.visit_expression(&input.value, &Default::default());
277      }
278  
279      fn visit_block(&mut self, input: &Block) {
280          input.statements.iter().for_each(|stmt| self.visit_statement(stmt));
281      }
282  
283      fn visit_conditional(&mut self, input: &ConditionalStatement) {
284          self.visit_expression(&input.condition, &Default::default());
285          self.visit_block(&input.then);
286          if let Some(stmt) = input.otherwise.as_ref() {
287              self.visit_statement(stmt);
288          }
289      }
290  
291      fn visit_const(&mut self, input: &ConstDeclaration) {
292          self.visit_type(&input.type_);
293          self.visit_expression(&input.value, &Default::default());
294      }
295  
296      fn visit_definition(&mut self, input: &DefinitionStatement) {
297          if let Some(ty) = input.type_.as_ref() {
298              self.visit_type(ty)
299          }
300          self.visit_expression(&input.value, &Default::default());
301      }
302  
303      fn visit_expression_statement(&mut self, input: &ExpressionStatement) {
304          self.visit_expression(&input.expression, &Default::default());
305      }
306  
307      fn visit_iteration(&mut self, input: &IterationStatement) {
308          if let Some(ty) = input.type_.as_ref() {
309              self.visit_type(ty)
310          }
311          self.visit_expression(&input.start, &Default::default());
312          self.visit_expression(&input.stop, &Default::default());
313          self.visit_block(&input.block);
314      }
315  
316      fn visit_return(&mut self, input: &ReturnStatement) {
317          self.visit_expression(&input.expression, &Default::default());
318      }
319  }
320  
321  /// A Visitor trait for the program represented by the AST.
322  pub trait ProgramVisitor: AstVisitor {
323      fn visit_program(&mut self, input: &Program) {
324          input.program_scopes.values().for_each(|scope| self.visit_program_scope(scope));
325          input.modules.values().for_each(|module| self.visit_module(module));
326          input.stubs.values().for_each(|stub| self.visit_stub(stub));
327      }
328  
329      fn visit_program_scope(&mut self, input: &ProgramScope) {
330          input.consts.iter().for_each(|(_, c)| self.visit_const(c));
331          input.composites.iter().for_each(|(_, c)| self.visit_composite(c));
332          input.mappings.iter().for_each(|(_, c)| self.visit_mapping(c));
333          input.storage_variables.iter().for_each(|(_, c)| self.visit_storage_variable(c));
334          input.functions.iter().for_each(|(_, c)| self.visit_function(c));
335          if let Some(c) = input.constructor.as_ref() {
336              self.visit_constructor(c);
337          }
338      }
339  
340      fn visit_module(&mut self, input: &Module) {
341          input.consts.iter().for_each(|(_, c)| self.visit_const(c));
342          input.composites.iter().for_each(|(_, c)| self.visit_composite(c));
343          input.functions.iter().for_each(|(_, c)| self.visit_function(c));
344      }
345  
346      fn visit_stub(&mut self, _input: &Stub) {}
347  
348      fn visit_import(&mut self, input: &Program) {
349          self.visit_program(input)
350      }
351  
352      fn visit_composite(&mut self, input: &Composite) {
353          input.const_parameters.iter().for_each(|input| self.visit_type(&input.type_));
354          input.members.iter().for_each(|member| self.visit_type(&member.type_));
355      }
356  
357      fn visit_mapping(&mut self, input: &Mapping) {
358          self.visit_type(&input.key_type);
359          self.visit_type(&input.value_type);
360      }
361  
362      fn visit_storage_variable(&mut self, input: &StorageVariable) {
363          self.visit_type(&input.type_);
364      }
365  
366      fn visit_function(&mut self, input: &Function) {
367          input.const_parameters.iter().for_each(|input| self.visit_type(&input.type_));
368          input.input.iter().for_each(|input| self.visit_type(&input.type_));
369          input.output.iter().for_each(|output| self.visit_type(&output.type_));
370          self.visit_type(&input.output_type);
371          self.visit_block(&input.block);
372      }
373  
374      fn visit_constructor(&mut self, input: &Constructor) {
375          self.visit_block(&input.block);
376      }
377  
378      fn visit_function_stub(&mut self, _input: &FunctionStub) {}
379  
380      fn visit_composite_stub(&mut self, _input: &Composite) {}
381  }