ast.rs
1 // Copyright (C) 2019-2025 ADnet Contributors 2 // This file is part of the ADL library. 3 4 // The ADL library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 9 // The ADL library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU General Public License for more details. 13 14 // You should have received a copy of the GNU General Public License 15 // along with the ADL library. If not, see <https://www.gnu.org/licenses/>. 16 17 use super::ProcessingAsyncVisitor; 18 use crate::{CompilerState, Replacer}; 19 use adl_ast::{ 20 AstReconstructor, 21 AstVisitor, 22 AsyncExpression, 23 Block, 24 CallExpression, 25 Expression, 26 Function, 27 Identifier, 28 Input, 29 IterationStatement, 30 Location, 31 Node, 32 Path, 33 ProgramVisitor, 34 Statement, 35 TupleAccess, 36 TupleExpression, 37 TupleType, 38 Type, 39 Variant, 40 }; 41 use adl_span::{Span, Symbol}; 42 use indexmap::{IndexMap, IndexSet}; 43 44 /// Collects all symbol accesses within an async block, 45 /// including both direct variable identifiers (`x`) and tuple field accesses (`x.0`, `x.1`, etc.). 46 /// Each access is recorded as a pair: (Symbol, Option<usize>). 47 /// - `None` means a direct variable access. 48 /// - `Some(index)` means a tuple field access. 49 struct SymbolAccessCollector<'a> { 50 state: &'a CompilerState, 51 symbol_accesses: IndexSet<(Vec<Symbol>, Option<usize>)>, 52 } 53 54 impl AstVisitor for SymbolAccessCollector<'_> { 55 type AdditionalInput = (); 56 type Output = (); 57 58 fn visit_path(&mut self, input: &Path, _: &Self::AdditionalInput) -> Self::Output { 59 self.symbol_accesses.insert((input.absolute_path(), None)); 60 } 61 62 fn visit_tuple_access(&mut self, input: &TupleAccess, _: &Self::AdditionalInput) -> Self::Output { 63 // Here we assume that we can't have nested tuples which is currently guaranteed by type 64 // checking. This may change in the future. 65 if let Expression::Path(path) = &input.tuple { 66 // Futures aren't accessed by field; treat the whole thing as a direct variable 67 if let Some(Type::Future(_)) = self.state.type_table.get(&input.tuple.id()) { 68 self.symbol_accesses.insert((path.absolute_path(), None)); 69 } else { 70 self.symbol_accesses.insert((path.absolute_path(), Some(input.index.value()))); 71 } 72 } else { 73 self.visit_expression(&input.tuple, &()); 74 } 75 } 76 } 77 78 impl ProgramVisitor for SymbolAccessCollector<'_> {} 79 80 impl AstReconstructor for ProcessingAsyncVisitor<'_> { 81 type AdditionalInput = (); 82 type AdditionalOutput = (); 83 84 /// Transforms an `AsyncExpression` into a standalone async `Function` and returns 85 /// a call to this function. This process: 86 /// - Collects all referenced symbol accesses in the async block. 87 /// - Filters out mappings and constructs typed input parameters. 88 /// - Reconstructs an async function with those inputs and the original block. 89 /// - Builds and returns a `CallExpression` that invokes the new function. 90 fn reconstruct_async(&mut self, input: AsyncExpression, _additional: &()) -> (Expression, Self::AdditionalOutput) { 91 // Step 1: Generate a unique name for the async function 92 let finalize_fn_name = self.state.assigner.unique_symbol(self.current_function, "$"); 93 94 // Step 2: Collect all symbol accesses in the async block 95 let mut access_collector = SymbolAccessCollector { state: self.state, symbol_accesses: IndexSet::new() }; 96 access_collector.visit_async(&input, &()); 97 98 // Stores mapping from accessed symbol (and optional index) to the expression used in replacement 99 let mut replacements: IndexMap<(Symbol, Option<usize>), Expression> = IndexMap::new(); 100 101 // Helper to create a fresh `Identifier` 102 let make_identifier = |slf: &mut Self, symbol: Symbol| Identifier { 103 name: symbol, 104 span: Span::default(), 105 id: slf.state.node_builder.next_id(), 106 }; 107 108 // Generates a set of `Input`s and corresponding call-site `Expression`s for a given symbol access. 109 // 110 // This function handles both: 111 // - Direct variable accesses (e.g., `foo`) 112 // - Tuple element accesses (e.g., `foo.0`) 113 // 114 // For tuple accesses: 115 // - If a single element (e.g. `foo.0`) is accessed, it generates a synthetic input like `"foo.0"`. 116 // - If the whole tuple (e.g. `foo`) is accessed, it ensures all elements are covered by: 117 // - Reusing existing inputs from `replacements` if already generated via prior field access. 118 // - Creating new inputs and arguments for any missing elements. 119 // - The entire tuple is reconstructed in `replacements` using the individual elements as a `TupleExpression`. 120 // 121 // This function also ensures deduplication by consulting the `replacements` map: 122 // - If a given `(symbol, index)` has already been processed, no duplicate input or argument is generated. 123 // - This prevents repeated parameters for accesses like both `foo` and `foo.0`. 124 // 125 // # Parameters 126 // - `symbol`: The symbol being accessed. 127 // - `var_type`: The type of the symbol (may be a tuple or base type). 128 // - `index_opt`: `Some(index)` for a tuple field (e.g., `.0`), or `None` for full-variable access. 129 // 130 // # Returns 131 // A `Vec<(Input, Expression)>`, where: 132 // - `Input` is a parameter for the generated async function. 133 // - `Expression` is the call-site argument expression used to invoke that parameter. 134 let mut make_inputs_and_arguments = 135 |slf: &mut Self, symbol: Symbol, var_type: &Type, index_opt: Option<usize>| -> Vec<(Input, Expression)> { 136 if replacements.contains_key(&(symbol, index_opt)) { 137 return vec![]; // No new input needed; argument already exists 138 } 139 140 match index_opt { 141 Some(index) => { 142 let Type::Tuple(TupleType { elements }) = var_type else { 143 panic!("Expected tuple type when accessing tuple field: {symbol}.{index}"); 144 }; 145 146 let synthetic_name = format!("\"{symbol}.{index}\""); 147 let synthetic_symbol = Symbol::intern(&synthetic_name); 148 let identifier = make_identifier(slf, synthetic_symbol); 149 150 let input = Input { 151 identifier, 152 mode: adl_ast::Mode::None, 153 type_: elements[index].clone(), 154 span: Span::default(), 155 id: slf.state.node_builder.next_id(), 156 }; 157 158 replacements.insert((symbol, Some(index)), Path::from(identifier).into_absolute().into()); 159 160 vec![( 161 input, 162 TupleAccess { 163 tuple: Path::from(make_identifier(slf, symbol)).into_absolute().into(), 164 index: index.into(), 165 span: Span::default(), 166 id: slf.state.node_builder.next_id(), 167 } 168 .into(), 169 )] 170 } 171 172 None => match var_type { 173 Type::Tuple(TupleType { elements }) => { 174 let mut inputs_and_arguments = Vec::with_capacity(elements.len()); 175 let mut tuple_elements = Vec::with_capacity(elements.len()); 176 177 for (i, element_type) in elements.iter().enumerate() { 178 let key = (symbol, Some(i)); 179 180 // Skip if this field is already handled 181 if let Some(existing_expr) = replacements.get(&key) { 182 tuple_elements.push(existing_expr.clone()); 183 continue; 184 } 185 186 // Otherwise, synthesize identifier and input 187 let synthetic_name = format!("\"{symbol}.{i}\""); 188 let synthetic_symbol = Symbol::intern(&synthetic_name); 189 let identifier = make_identifier(slf, synthetic_symbol); 190 191 let input = Input { 192 identifier, 193 mode: adl_ast::Mode::None, 194 type_: element_type.clone(), 195 span: Span::default(), 196 id: slf.state.node_builder.next_id(), 197 }; 198 199 let expr: Expression = Path::from(identifier).into_absolute().into(); 200 201 replacements.insert(key, expr.clone()); 202 tuple_elements.push(expr.clone()); 203 inputs_and_arguments.push(( 204 input, 205 TupleAccess { 206 tuple: Path::from(make_identifier(slf, symbol)).into_absolute().into(), 207 index: i.into(), 208 span: Span::default(), 209 id: slf.state.node_builder.next_id(), 210 } 211 .into(), 212 )); 213 } 214 215 // Now insert the full tuple (even if all fields were already there) 216 replacements.insert( 217 (symbol, None), 218 Expression::Tuple(TupleExpression { 219 elements: tuple_elements, 220 span: Span::default(), 221 id: slf.state.node_builder.next_id(), 222 }), 223 ); 224 225 inputs_and_arguments 226 } 227 228 _ => { 229 let identifier = make_identifier(slf, symbol); 230 let input = Input { 231 identifier, 232 mode: adl_ast::Mode::None, 233 type_: var_type.clone(), 234 span: Span::default(), 235 id: slf.state.node_builder.next_id(), 236 }; 237 238 replacements.insert((symbol, None), Path::from(identifier).into_absolute().into()); 239 240 let argument = Path::from(make_identifier(slf, symbol)).into_absolute().into(); 241 vec![(input, argument)] 242 } 243 }, 244 } 245 }; 246 247 // Step 3: Resolve symbol accesses into inputs and call arguments 248 let (inputs, arguments): (Vec<_>, Vec<_>) = access_collector 249 .symbol_accesses 250 .iter() 251 .filter_map(|(path, index)| { 252 // Skip globals and variables that are local to this block or to one of its children. 253 254 // Skip globals. 255 if self.state.symbol_table.lookup_global(&Location::new(self.current_program, path.to_vec())).is_some() 256 { 257 return None; 258 } 259 260 // Skip variables that are local to this block or to one of its children. 261 let local_var_name = *path.last().expect("all paths must have at least one segment."); 262 if self.state.symbol_table.is_local_to_or_in_child_scope(input.block.id(), local_var_name) { 263 return None; 264 } 265 266 // All other variables become parameters to the async function being built. 267 let var = self.state.symbol_table.lookup_local(local_var_name)?; 268 Some(make_inputs_and_arguments(self, local_var_name, &var.type_, *index)) 269 }) 270 .flatten() 271 .unzip(); 272 273 // Step 4: Replacement logic used to patch the async block 274 let replace_expr = |expr: &Expression| -> Expression { 275 match expr { 276 Expression::Path(path) => { 277 replacements.get(&(path.identifier().name, None)).cloned().unwrap_or_else(|| expr.clone()) 278 } 279 280 Expression::TupleAccess(ta) => { 281 if let Expression::Path(path) = &ta.tuple { 282 replacements 283 .get(&(path.identifier().name, Some(ta.index.value()))) 284 .cloned() 285 .unwrap_or_else(|| expr.clone()) 286 } else { 287 expr.clone() 288 } 289 } 290 291 _ => expr.clone(), 292 } 293 }; 294 295 // Step 5: Reconstruct the block with replaced references 296 let mut replacer = Replacer::new(replace_expr, true /* refresh IDs */, self.state); 297 let new_block = replacer.reconstruct_block(input.block.clone()).0; 298 299 // Ensure we're not trying to capture too many variables 300 if inputs.len() > self.max_inputs { 301 self.state.handler.emit_err(adl_errors::StaticAnalyzerError::async_block_capturing_too_many_vars( 302 inputs.len(), 303 self.max_inputs, 304 input.span, 305 )); 306 } 307 308 // Step 6: Define the new async function 309 let function = Function { 310 annotations: vec![], 311 variant: Variant::AsyncFunction, 312 identifier: make_identifier(self, finalize_fn_name), 313 const_parameters: vec![], 314 input: inputs, 315 output: vec![], // `async function`s can't have returns 316 output_type: Type::Unit, // Always the case for `async function`s 317 block: new_block, 318 span: input.span, 319 id: self.state.node_builder.next_id(), 320 }; 321 322 // Register the generated function 323 self.new_async_functions.push((finalize_fn_name, function)); 324 325 // Step 7: Create the call expression to invoke the async function 326 let call_to_finalize = CallExpression { 327 function: Path::new( 328 vec![], 329 make_identifier(self, finalize_fn_name), 330 true, 331 Some(vec![finalize_fn_name]), // the finalize function lives in the top level program scope 332 Span::default(), 333 self.state.node_builder.next_id(), 334 ), 335 const_arguments: vec![], 336 arguments, 337 program: Some(self.current_program), 338 span: input.span, 339 id: self.state.node_builder.next_id(), 340 }; 341 342 self.modified = true; 343 344 (call_to_finalize.into(), ()) 345 } 346 347 fn reconstruct_block(&mut self, input: Block) -> (Block, Self::AdditionalOutput) { 348 self.in_scope(input.id(), |slf| { 349 ( 350 Block { 351 statements: input.statements.into_iter().map(|s| slf.reconstruct_statement(s).0).collect(), 352 span: input.span, 353 id: input.id, 354 }, 355 Default::default(), 356 ) 357 }) 358 } 359 360 fn reconstruct_iteration(&mut self, input: IterationStatement) -> (Statement, Self::AdditionalOutput) { 361 self.in_scope(input.id(), |slf| { 362 ( 363 IterationStatement { 364 type_: input.type_.map(|ty| slf.reconstruct_type(ty).0), 365 start: slf.reconstruct_expression(input.start, &()).0, 366 stop: slf.reconstruct_expression(input.stop, &()).0, 367 block: slf.reconstruct_block(input.block).0, 368 ..input 369 } 370 .into(), 371 Default::default(), 372 ) 373 }) 374 } 375 }