ast.rs
  1  // Copyright (C) 2019-2025 ADnet Contributors
  2  // This file is part of the ADL library.
  3  
  4  // The ADL library is free software: you can redistribute it and/or modify
  5  // it under the terms of the GNU General Public License as published by
  6  // the Free Software Foundation, either version 3 of the License, or
  7  // (at your option) any later version.
  8  
  9  // The ADL library is distributed in the hope that it will be useful,
 10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
 11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 12  // GNU General Public License for more details.
 13  
 14  // You should have received a copy of the GNU General Public License
 15  // along with the ADL library. If not, see <https://www.gnu.org/licenses/>.
 16  
 17  use super::ProcessingAsyncVisitor;
 18  use crate::{CompilerState, Replacer};
 19  use adl_ast::{
 20      AstReconstructor,
 21      AstVisitor,
 22      AsyncExpression,
 23      Block,
 24      CallExpression,
 25      Expression,
 26      Function,
 27      Identifier,
 28      Input,
 29      IterationStatement,
 30      Location,
 31      Node,
 32      Path,
 33      ProgramVisitor,
 34      Statement,
 35      TupleAccess,
 36      TupleExpression,
 37      TupleType,
 38      Type,
 39      Variant,
 40  };
 41  use adl_span::{Span, Symbol};
 42  use indexmap::{IndexMap, IndexSet};
 43  
 44  /// Collects all symbol accesses within an async block,
 45  /// including both direct variable identifiers (`x`) and tuple field accesses (`x.0`, `x.1`, etc.).
 46  /// Each access is recorded as a pair: (Symbol, Option<usize>).
 47  /// - `None` means a direct variable access.
 48  /// - `Some(index)` means a tuple field access.
 49  struct SymbolAccessCollector<'a> {
 50      state: &'a CompilerState,
 51      symbol_accesses: IndexSet<(Vec<Symbol>, Option<usize>)>,
 52  }
 53  
 54  impl AstVisitor for SymbolAccessCollector<'_> {
 55      type AdditionalInput = ();
 56      type Output = ();
 57  
 58      fn visit_path(&mut self, input: &Path, _: &Self::AdditionalInput) -> Self::Output {
 59          self.symbol_accesses.insert((input.absolute_path(), None));
 60      }
 61  
 62      fn visit_tuple_access(&mut self, input: &TupleAccess, _: &Self::AdditionalInput) -> Self::Output {
 63          // Here we assume that we can't have nested tuples which is currently guaranteed by type
 64          // checking. This may change in the future.
 65          if let Expression::Path(path) = &input.tuple {
 66              // Futures aren't accessed by field; treat the whole thing as a direct variable
 67              if let Some(Type::Future(_)) = self.state.type_table.get(&input.tuple.id()) {
 68                  self.symbol_accesses.insert((path.absolute_path(), None));
 69              } else {
 70                  self.symbol_accesses.insert((path.absolute_path(), Some(input.index.value())));
 71              }
 72          } else {
 73              self.visit_expression(&input.tuple, &());
 74          }
 75      }
 76  }
 77  
 78  impl ProgramVisitor for SymbolAccessCollector<'_> {}
 79  
 80  impl AstReconstructor for ProcessingAsyncVisitor<'_> {
 81      type AdditionalInput = ();
 82      type AdditionalOutput = ();
 83  
 84      /// Transforms an `AsyncExpression` into a standalone async `Function` and returns
 85      /// a call to this function. This process:
 86      /// - Collects all referenced symbol accesses in the async block.
 87      /// - Filters out mappings and constructs typed input parameters.
 88      /// - Reconstructs an async function with those inputs and the original block.
 89      /// - Builds and returns a `CallExpression` that invokes the new function.
 90      fn reconstruct_async(&mut self, input: AsyncExpression, _additional: &()) -> (Expression, Self::AdditionalOutput) {
 91          // Step 1: Generate a unique name for the async function
 92          let finalize_fn_name = self.state.assigner.unique_symbol(self.current_function, "$");
 93  
 94          // Step 2: Collect all symbol accesses in the async block
 95          let mut access_collector = SymbolAccessCollector { state: self.state, symbol_accesses: IndexSet::new() };
 96          access_collector.visit_async(&input, &());
 97  
 98          // Stores mapping from accessed symbol (and optional index) to the expression used in replacement
 99          let mut replacements: IndexMap<(Symbol, Option<usize>), Expression> = IndexMap::new();
100  
101          // Helper to create a fresh `Identifier`
102          let make_identifier = |slf: &mut Self, symbol: Symbol| Identifier {
103              name: symbol,
104              span: Span::default(),
105              id: slf.state.node_builder.next_id(),
106          };
107  
108          // Generates a set of `Input`s and corresponding call-site `Expression`s for a given symbol access.
109          //
110          // This function handles both:
111          // - Direct variable accesses (e.g., `foo`)
112          // - Tuple element accesses (e.g., `foo.0`)
113          //
114          // For tuple accesses:
115          // - If a single element (e.g. `foo.0`) is accessed, it generates a synthetic input like `"foo.0"`.
116          // - If the whole tuple (e.g. `foo`) is accessed, it ensures all elements are covered by:
117          //     - Reusing existing inputs from `replacements` if already generated via prior field access.
118          //     - Creating new inputs and arguments for any missing elements.
119          // - The entire tuple is reconstructed in `replacements` using the individual elements as a `TupleExpression`.
120          //
121          // This function also ensures deduplication by consulting the `replacements` map:
122          // - If a given `(symbol, index)` has already been processed, no duplicate input or argument is generated.
123          // - This prevents repeated parameters for accesses like both `foo` and `foo.0`.
124          //
125          // # Parameters
126          // - `symbol`: The symbol being accessed.
127          // - `var_type`: The type of the symbol (may be a tuple or base type).
128          // - `index_opt`: `Some(index)` for a tuple field (e.g., `.0`), or `None` for full-variable access.
129          //
130          // # Returns
131          // A `Vec<(Input, Expression)>`, where:
132          // - `Input` is a parameter for the generated async function.
133          // - `Expression` is the call-site argument expression used to invoke that parameter.
134          let mut make_inputs_and_arguments =
135              |slf: &mut Self, symbol: Symbol, var_type: &Type, index_opt: Option<usize>| -> Vec<(Input, Expression)> {
136                  if replacements.contains_key(&(symbol, index_opt)) {
137                      return vec![]; // No new input needed; argument already exists
138                  }
139  
140                  match index_opt {
141                      Some(index) => {
142                          let Type::Tuple(TupleType { elements }) = var_type else {
143                              panic!("Expected tuple type when accessing tuple field: {symbol}.{index}");
144                          };
145  
146                          let synthetic_name = format!("\"{symbol}.{index}\"");
147                          let synthetic_symbol = Symbol::intern(&synthetic_name);
148                          let identifier = make_identifier(slf, synthetic_symbol);
149  
150                          let input = Input {
151                              identifier,
152                              mode: adl_ast::Mode::None,
153                              type_: elements[index].clone(),
154                              span: Span::default(),
155                              id: slf.state.node_builder.next_id(),
156                          };
157  
158                          replacements.insert((symbol, Some(index)), Path::from(identifier).into_absolute().into());
159  
160                          vec![(
161                              input,
162                              TupleAccess {
163                                  tuple: Path::from(make_identifier(slf, symbol)).into_absolute().into(),
164                                  index: index.into(),
165                                  span: Span::default(),
166                                  id: slf.state.node_builder.next_id(),
167                              }
168                              .into(),
169                          )]
170                      }
171  
172                      None => match var_type {
173                          Type::Tuple(TupleType { elements }) => {
174                              let mut inputs_and_arguments = Vec::with_capacity(elements.len());
175                              let mut tuple_elements = Vec::with_capacity(elements.len());
176  
177                              for (i, element_type) in elements.iter().enumerate() {
178                                  let key = (symbol, Some(i));
179  
180                                  // Skip if this field is already handled
181                                  if let Some(existing_expr) = replacements.get(&key) {
182                                      tuple_elements.push(existing_expr.clone());
183                                      continue;
184                                  }
185  
186                                  // Otherwise, synthesize identifier and input
187                                  let synthetic_name = format!("\"{symbol}.{i}\"");
188                                  let synthetic_symbol = Symbol::intern(&synthetic_name);
189                                  let identifier = make_identifier(slf, synthetic_symbol);
190  
191                                  let input = Input {
192                                      identifier,
193                                      mode: adl_ast::Mode::None,
194                                      type_: element_type.clone(),
195                                      span: Span::default(),
196                                      id: slf.state.node_builder.next_id(),
197                                  };
198  
199                                  let expr: Expression = Path::from(identifier).into_absolute().into();
200  
201                                  replacements.insert(key, expr.clone());
202                                  tuple_elements.push(expr.clone());
203                                  inputs_and_arguments.push((
204                                      input,
205                                      TupleAccess {
206                                          tuple: Path::from(make_identifier(slf, symbol)).into_absolute().into(),
207                                          index: i.into(),
208                                          span: Span::default(),
209                                          id: slf.state.node_builder.next_id(),
210                                      }
211                                      .into(),
212                                  ));
213                              }
214  
215                              // Now insert the full tuple (even if all fields were already there)
216                              replacements.insert(
217                                  (symbol, None),
218                                  Expression::Tuple(TupleExpression {
219                                      elements: tuple_elements,
220                                      span: Span::default(),
221                                      id: slf.state.node_builder.next_id(),
222                                  }),
223                              );
224  
225                              inputs_and_arguments
226                          }
227  
228                          _ => {
229                              let identifier = make_identifier(slf, symbol);
230                              let input = Input {
231                                  identifier,
232                                  mode: adl_ast::Mode::None,
233                                  type_: var_type.clone(),
234                                  span: Span::default(),
235                                  id: slf.state.node_builder.next_id(),
236                              };
237  
238                              replacements.insert((symbol, None), Path::from(identifier).into_absolute().into());
239  
240                              let argument = Path::from(make_identifier(slf, symbol)).into_absolute().into();
241                              vec![(input, argument)]
242                          }
243                      },
244                  }
245              };
246  
247          // Step 3: Resolve symbol accesses into inputs and call arguments
248          let (inputs, arguments): (Vec<_>, Vec<_>) = access_collector
249              .symbol_accesses
250              .iter()
251              .filter_map(|(path, index)| {
252                  // Skip globals and variables that are local to this block or to one of its children.
253  
254                  // Skip globals.
255                  if self.state.symbol_table.lookup_global(&Location::new(self.current_program, path.to_vec())).is_some()
256                  {
257                      return None;
258                  }
259  
260                  // Skip variables that are local to this block or to one of its children.
261                  let local_var_name = *path.last().expect("all paths must have at least one segment.");
262                  if self.state.symbol_table.is_local_to_or_in_child_scope(input.block.id(), local_var_name) {
263                      return None;
264                  }
265  
266                  // All other variables become parameters to the async function being built.
267                  let var = self.state.symbol_table.lookup_local(local_var_name)?;
268                  Some(make_inputs_and_arguments(self, local_var_name, &var.type_, *index))
269              })
270              .flatten()
271              .unzip();
272  
273          // Step 4: Replacement logic used to patch the async block
274          let replace_expr = |expr: &Expression| -> Expression {
275              match expr {
276                  Expression::Path(path) => {
277                      replacements.get(&(path.identifier().name, None)).cloned().unwrap_or_else(|| expr.clone())
278                  }
279  
280                  Expression::TupleAccess(ta) => {
281                      if let Expression::Path(path) = &ta.tuple {
282                          replacements
283                              .get(&(path.identifier().name, Some(ta.index.value())))
284                              .cloned()
285                              .unwrap_or_else(|| expr.clone())
286                      } else {
287                          expr.clone()
288                      }
289                  }
290  
291                  _ => expr.clone(),
292              }
293          };
294  
295          // Step 5: Reconstruct the block with replaced references
296          let mut replacer = Replacer::new(replace_expr, true /* refresh IDs */, self.state);
297          let new_block = replacer.reconstruct_block(input.block.clone()).0;
298  
299          // Ensure we're not trying to capture too many variables
300          if inputs.len() > self.max_inputs {
301              self.state.handler.emit_err(adl_errors::StaticAnalyzerError::async_block_capturing_too_many_vars(
302                  inputs.len(),
303                  self.max_inputs,
304                  input.span,
305              ));
306          }
307  
308          // Step 6: Define the new async function
309          let function = Function {
310              annotations: vec![],
311              variant: Variant::AsyncFunction,
312              identifier: make_identifier(self, finalize_fn_name),
313              const_parameters: vec![],
314              input: inputs,
315              output: vec![],          // `async function`s can't have returns
316              output_type: Type::Unit, // Always the case for `async function`s
317              block: new_block,
318              span: input.span,
319              id: self.state.node_builder.next_id(),
320          };
321  
322          // Register the generated function
323          self.new_async_functions.push((finalize_fn_name, function));
324  
325          // Step 7: Create the call expression to invoke the async function
326          let call_to_finalize = CallExpression {
327              function: Path::new(
328                  vec![],
329                  make_identifier(self, finalize_fn_name),
330                  true,
331                  Some(vec![finalize_fn_name]), // the finalize function lives in the top level program scope
332                  Span::default(),
333                  self.state.node_builder.next_id(),
334              ),
335              const_arguments: vec![],
336              arguments,
337              program: Some(self.current_program),
338              span: input.span,
339              id: self.state.node_builder.next_id(),
340          };
341  
342          self.modified = true;
343  
344          (call_to_finalize.into(), ())
345      }
346  
347      fn reconstruct_block(&mut self, input: Block) -> (Block, Self::AdditionalOutput) {
348          self.in_scope(input.id(), |slf| {
349              (
350                  Block {
351                      statements: input.statements.into_iter().map(|s| slf.reconstruct_statement(s).0).collect(),
352                      span: input.span,
353                      id: input.id,
354                  },
355                  Default::default(),
356              )
357          })
358      }
359  
360      fn reconstruct_iteration(&mut self, input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
361          self.in_scope(input.id(), |slf| {
362              (
363                  IterationStatement {
364                      type_: input.type_.map(|ty| slf.reconstruct_type(ty).0),
365                      start: slf.reconstruct_expression(input.start, &()).0,
366                      stop: slf.reconstruct_expression(input.stop, &()).0,
367                      block: slf.reconstruct_block(input.block).0,
368                      ..input
369                  }
370                  .into(),
371                  Default::default(),
372              )
373          })
374      }
375  }