/ compiler / ast / src / common / graph / mod.rs
mod.rs
  1  // Copyright (C) 2019-2025 ADnet Contributors
  2  // This file is part of the ADL library.
  3  
  4  // The ADL library is free software: you can redistribute it and/or modify
  5  // it under the terms of the GNU General Public License as published by
  6  // the Free Software Foundation, either version 3 of the License, or
  7  // (at your option) any later version.
  8  
  9  // The ADL library is distributed in the hope that it will be useful,
 10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
 11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 12  // GNU General Public License for more details.
 13  
 14  // You should have received a copy of the GNU General Public License
 15  // along with the ADL library. If not, see <https://www.gnu.org/licenses/>.
 16  
 17  use crate::Location;
 18  use adl_span::Symbol;
 19  
 20  use indexmap::{IndexMap, IndexSet};
 21  use std::{fmt::Debug, hash::Hash, rc::Rc};
 22  
/// A composite dependency graph.
/// Each node is a `Vec<Symbol>` holding the absolute path to a composite.
pub type CompositeGraph = DiGraph<Vec<Symbol>>;

/// A call graph.
/// Each node is a `Location`.
pub type CallGraph = DiGraph<Location>;

/// An import dependency graph.
/// Each node is a `Symbol`.
pub type ImportGraph = DiGraph<Symbol>;
 32  
 33  /// A node in a graph.
 34  pub trait GraphNode: Clone + 'static + Eq + PartialEq + Debug + Hash {}
 35  
 36  impl<T> GraphNode for T where T: 'static + Clone + Eq + PartialEq + Debug + Hash {}
 37  
 38  /// Errors in directed graph operations.
 39  #[derive(Debug)]
 40  pub enum DiGraphError<N: GraphNode> {
 41      /// An error that is emitted when a cycle is detected in the directed graph. Contains the path of the cycle.
 42      CycleDetected(Vec<N>),
 43  }
 44  
 45  /// A directed graph using reference-counted nodes.
 46  #[derive(Clone, Debug, PartialEq, Eq)]
 47  pub struct DiGraph<N: GraphNode> {
 48      /// The set of nodes in the graph.
 49      nodes: IndexSet<Rc<N>>,
 50  
 51      /// The directed edges in the graph.
 52      /// Each entry in the map is a node in the graph, and the set of nodes that it points to.
 53      edges: IndexMap<Rc<N>, IndexSet<Rc<N>>>,
 54  }
 55  
 56  impl<N: GraphNode> Default for DiGraph<N> {
 57      fn default() -> Self {
 58          Self { nodes: IndexSet::new(), edges: IndexMap::new() }
 59      }
 60  }
 61  
 62  impl<N: GraphNode> DiGraph<N> {
 63      /// Initializes a new `DiGraph` from a set of source nodes.
 64      pub fn new(nodes: IndexSet<N>) -> Self {
 65          let nodes: IndexSet<_> = nodes.into_iter().map(Rc::new).collect();
 66          Self { nodes, edges: IndexMap::new() }
 67      }
 68  
 69      /// Adds a node to the graph.
 70      pub fn add_node(&mut self, node: N) {
 71          self.nodes.insert(Rc::new(node));
 72      }
 73  
 74      /// Returns an iterator over the nodes in the graph.
 75      pub fn nodes(&self) -> impl Iterator<Item = &N> {
 76          self.nodes.iter().map(|rc| rc.as_ref())
 77      }
 78  
 79      /// Adds an edge to the graph.
 80      pub fn add_edge(&mut self, from: N, to: N) {
 81          // Add `from` and `to` to the set of nodes if they are not already in the set.
 82          let from_rc = self.get_or_insert(from);
 83          let to_rc = self.get_or_insert(to);
 84  
 85          // Add the edge to the adjacency list.
 86          self.edges.entry(from_rc).or_default().insert(to_rc);
 87      }
 88  
 89      /// Removes a node and all associated edges from the graph.
 90      pub fn remove_node(&mut self, node: &N) -> bool {
 91          if let Some(rc_node) = self.nodes.shift_take(&Rc::new(node.clone())) {
 92              // Remove all outgoing edges from the node
 93              self.edges.shift_remove(&rc_node);
 94  
 95              // Remove all incoming edges to the node
 96              for targets in self.edges.values_mut() {
 97                  targets.shift_remove(&rc_node);
 98              }
 99              true
100          } else {
101              false
102          }
103      }
104  
105      /// Returns an iterator to the immediate neighbors of a given node.
106      pub fn neighbors(&self, node: &N) -> impl Iterator<Item = &N> {
107          self.edges
108              .get(node) // ← no Rc::from() needed!
109              .into_iter()
110              .flat_map(|neighbors| neighbors.iter().map(|rc| rc.as_ref()))
111      }
112  
113      /// Returns `true` if the graph contains the given node.
114      pub fn contains_node(&self, node: N) -> bool {
115          self.nodes.contains(&Rc::new(node))
116      }
117  
118      /// Returns the post-order ordering of the graph.
119      /// Detects if there is a cycle in the graph.
120      pub fn post_order(&self) -> Result<IndexSet<N>, DiGraphError<N>> {
121          self.post_order_with_filter(|_| true)
122      }
123  
124      /// Returns the post-order ordering of the graph but only considering a subset of the nodes that
125      /// satisfy the given filter.
126      ///
127      /// Detects if there is a cycle in the graph.
128      pub fn post_order_with_filter<F>(&self, filter: F) -> Result<IndexSet<N>, DiGraphError<N>>
129      where
130          F: Fn(&N) -> bool,
131      {
132          // The set of nodes that do not need to be visited again.
133          let mut finished = IndexSet::with_capacity(self.nodes.len());
134  
135          // Perform a depth-first search of the graph, starting from `node`, for each node in the graph that satisfies
136          // `is_entry_point`.
137          for node_rc in self.nodes.iter().filter(|n| filter(n.as_ref())) {
138              // If the node has not been explored, explore it.
139              if !finished.contains(node_rc) {
140                  // The set of nodes that are on the path to the current node in the searc
141                  let mut discovered = IndexSet::new();
142                  // Check if there is a cycle in the graph starting from `node`.
143                  if let Some(cycle_node) = self.contains_cycle_from(node_rc, &mut discovered, &mut finished) {
144                      let mut path = vec![cycle_node.as_ref().clone()];
145                      // Backtrack through the discovered nodes to find the cycle.
146                      while let Some(next) = discovered.pop() {
147                          // Add the node to the path.
148                          path.push(next.as_ref().clone());
149                          // If the node is the same as the first node in the path, we have found the cycle.
150                          if Rc::ptr_eq(&next, &cycle_node) {
151                              break;
152                          }
153                      }
154                      // Reverse the path to get the cycle in the correct order.
155                      path.reverse();
156                      // A cycle was detected. Return the path of the cycle.
157                      return Err(DiGraphError::CycleDetected(path));
158                  }
159              }
160          }
161  
162          // No cycle was found. Return the set of nodes in topological order.
163          Ok(finished.iter().map(|rc| (**rc).clone()).collect())
164      }
165  
166      /// Retains a subset of the nodes, and removes all edges in which the source or destination is not in the subset.
167      pub fn retain_nodes(&mut self, keep: &IndexSet<N>) {
168          let keep: IndexSet<_> = keep.iter().map(|n| Rc::new(n.clone())).collect();
169          // Remove the nodes from the set of nodes.
170          self.nodes.retain(|n| keep.contains(n));
171          self.edges.retain(|n, _| keep.contains(n));
172          // Remove the edges that reference the nodes.
173          for targets in self.edges.values_mut() {
174              targets.retain(|t| keep.contains(t));
175          }
176      }
177  
178      // Detects if there is a cycle in the graph starting from the given node, via a recursive depth-first search.
179      // If there is no cycle, returns `None`.
180      // If there is a cycle, returns the node that was most recently discovered.
181      // Nodes are added to `finished` in post-order order.
182      fn contains_cycle_from(
183          &self,
184          node: &Rc<N>,
185          discovered: &mut IndexSet<Rc<N>>,
186          finished: &mut IndexSet<Rc<N>>,
187      ) -> Option<Rc<N>> {
188          // Add the node to the set of discovered nodes.
189          discovered.insert(node.clone());
190  
191          // Check each outgoing edge of the node.
192          if let Some(children) = self.edges.get(node) {
193              for child in children {
194                  // If the node already been discovered, there is a cycle.
195                  if discovered.contains(child) {
196                      // Insert the child node into the set of discovered nodes; this is used to reconstruct the cycle.
197                      // Note that this case is always hit when there is a cycle.
198                      return Some(child.clone());
199                  }
200                  // If the node has not been explored, explore it.
201                  if !finished.contains(child)
202                      && let Some(cycle_node) = self.contains_cycle_from(child, discovered, finished)
203                  {
204                      return Some(cycle_node);
205                  }
206              }
207          }
208  
209          // Remove the node from the set of discovered nodes.
210          discovered.pop();
211          // Add the node to the set of finished nodes.
212          finished.insert(node.clone());
213          None
214      }
215  
216      /// Helper: get or insert Rc<N> into the graph.
217      fn get_or_insert(&mut self, node: N) -> Rc<N> {
218          if let Some(existing) = self.nodes.get(&node) {
219              return existing.clone();
220          }
221          let rc = Rc::new(node);
222          self.nodes.insert(rc.clone());
223          rc
224      }
225  }
226  
#[cfg(test)]
mod test {
    use super::*;

    /// Builds a graph by adding the given `(from, to)` edges in order.
    fn graph_from_edges(edges: &[(u32, u32)]) -> DiGraph<u32> {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());
        for &(from, to) in edges {
            graph.add_edge(from, to);
        }
        graph
    }

    /// Asserts that the graph is acyclic and that its post-order matches `expected` exactly.
    fn check_post_order<N: GraphNode>(graph: &DiGraph<N>, expected: &[N]) {
        let order: Vec<N> = graph.post_order().expect("graph should be acyclic").into_iter().collect();
        assert_eq!(order, expected);
    }

    #[test]
    fn test_post_order() {
        let graph = graph_from_edges(&[(1, 2), (1, 3), (2, 4), (3, 4), (4, 5)]);

        check_post_order(&graph, &[5, 4, 2, 3, 1]);

        let graph = graph_from_edges(&[
            // F -> B
            (6, 2),
            // B -> A
            (2, 1),
            // B -> D
            (2, 4),
            // D -> C
            (4, 3),
            // D -> E
            (4, 5),
            // F -> G
            (6, 7),
            // G -> I
            (7, 9),
            // I -> H
            (9, 8),
        ]);

        // A, C, E, D, B, H, I, G, F.
        check_post_order(&graph, &[1, 3, 5, 4, 2, 8, 9, 7, 6]);
    }

    #[test]
    fn test_post_order_with_filter() {
        let graph = graph_from_edges(&[(1, 2), (3, 4)]);

        // Only node 1 is a starting point, so only 1 and what it reaches are ordered.
        let order: Vec<u32> =
            graph.post_order_with_filter(|n| *n == 1).expect("graph should be acyclic").into_iter().collect();
        assert_eq!(order, vec![2, 1]);
    }

    #[test]
    fn test_cycle() {
        let graph = graph_from_edges(&[(1, 2), (2, 3), (2, 4), (4, 1)]);

        let Err(DiGraphError::CycleDetected(cycle)) = graph.post_order() else {
            panic!("expected a cycle to be detected");
        };
        let expected = Vec::from([1u32, 2, 4, 1]);
        assert_eq!(cycle, expected);
    }

    #[test]
    fn test_unconnected_graph() {
        let graph = DiGraph::<u32>::new(IndexSet::from([1, 2, 3, 4, 5]));

        check_post_order(&graph, &[1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_add_node_is_idempotent() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_node(1);
        graph.add_node(1);
        graph.add_node(2);

        assert_eq!(graph.nodes().copied().collect::<Vec<_>>(), vec![1, 2]);
    }

    #[test]
    fn test_remove_node() {
        let mut graph = graph_from_edges(&[(1, 2), (2, 3), (3, 1)]);

        // The first removal succeeds; the second finds nothing to remove.
        assert!(graph.remove_node(&2));
        assert!(!graph.remove_node(&2));

        assert!(!graph.contains_node(2));
        assert!(graph.contains_node(1));
        assert!(graph.contains_node(3));

        // Both the outgoing edge 2 -> 3 and the incoming edge 1 -> 2 are gone.
        assert_eq!(graph.neighbors(&1).count(), 0);
        assert_eq!(graph.neighbors(&2).count(), 0);
        assert_eq!(graph.neighbors(&3).copied().collect::<Vec<_>>(), vec![1]);

        // Removing node 2 broke the cycle 1 -> 2 -> 3 -> 1.
        check_post_order(&graph, &[1, 3]);
    }

    #[test]
    fn test_retain_nodes() {
        let mut graph = graph_from_edges(&[(1, 2), (1, 3), (1, 5), (2, 3), (2, 4), (2, 5), (3, 4), (4, 5)]);

        let nodes = IndexSet::from([1, 2, 3]);

        graph.retain_nodes(&nodes);

        let mut expected = graph_from_edges(&[(1, 2), (1, 3), (2, 3)]);
        // Node 3 keeps its (now empty) adjacency entry: its only outgoing edge pointed at a dropped node.
        expected.edges.insert(3.into(), IndexSet::new());

        assert_eq!(graph, expected);
    }
}