/ src / languages / python.rs
python.rs
  1  use crate::languages::LanguageSupport;
  2  use crate::types::EnvSourceKind;
  3  use std::sync::OnceLock;
  4  use tracing::error;
  5  use tree_sitter::{Language, Node, Query};
  6  
  7  pub struct Python;
  8  
  9  static REFERENCE_QUERY: OnceLock<Query> = OnceLock::new();
 10  static BINDING_QUERY: OnceLock<Query> = OnceLock::new();
 11  static IMPORT_QUERY: OnceLock<Query> = OnceLock::new();
 12  static COMPLETION_QUERY: OnceLock<Query> = OnceLock::new();
 13  static REASSIGNMENT_QUERY: OnceLock<Query> = OnceLock::new();
 14  static IDENTIFIER_QUERY: OnceLock<Query> = OnceLock::new();
 15  static EXPORT_QUERY: OnceLock<Query> = OnceLock::new();
 16  
 17  static ASSIGNMENT_QUERY: OnceLock<Query> = OnceLock::new();
 18  static DESTRUCTURE_QUERY: OnceLock<Query> = OnceLock::new();
 19  static SCOPE_QUERY: OnceLock<Query> = OnceLock::new();
 20  
 21  /// Compiles a tree-sitter query and fails fast on errors to surface invalid language query definitions early.
 22  fn compile_query(grammar: &Language, source: &str, query_name: &str) -> Query {
 23      match Query::new(grammar, source) {
 24          Ok(query) => query,
 25          Err(e) => {
 26              error!(
 27                  language = "python",
 28                  query = query_name,
 29                  error = %e,
 30                  "Failed to compile query, failing fast"
 31              );
 32              panic!("Failed to compile query '{}': {}", query_name, e)
 33          }
 34      }
 35  }
 36  
 37  impl LanguageSupport for Python {
 38      fn id(&self) -> &'static str {
 39          "python"
 40      }
 41  
 42      fn is_standard_env_object(&self, name: &str) -> bool {
 43          name == "os.environ" || name == "os"
 44      }
 45  
 46      fn default_env_object_name(&self) -> Option<&'static str> {
 47          Some("os.environ")
 48      }
 49  
 50      fn is_scope_node(&self, node: Node) -> bool {
 51          matches!(
 52              node.kind(),
 53              "module"
 54                  | "function_definition"
 55                  | "class_definition"
 56                  | "for_statement"
 57                  | "if_statement"
 58                  | "try_statement"
 59                  | "with_statement"
 60                  | "while_statement"
 61          )
 62      }
 63  
 64      fn extensions(&self) -> &'static [&'static str] {
 65          &["py"]
 66      }
 67  
 68      fn language_ids(&self) -> &'static [&'static str] {
 69          &["python"]
 70      }
 71  
 72      fn grammar(&self) -> Language {
 73          tree_sitter_python::LANGUAGE.into()
 74      }
 75  
 76      fn reference_query(&self) -> &Query {
 77          REFERENCE_QUERY.get_or_init(|| {
 78              compile_query(
 79                  &self.grammar(),
 80                  include_str!("../../queries/python/references.scm"),
 81                  "references",
 82              )
 83          })
 84      }
 85  
 86      fn binding_query(&self) -> Option<&Query> {
 87          Some(BINDING_QUERY.get_or_init(|| {
 88              compile_query(
 89                  &self.grammar(),
 90                  include_str!("../../queries/python/bindings.scm"),
 91                  "bindings",
 92              )
 93          }))
 94      }
 95  
 96      fn import_query(&self) -> Option<&Query> {
 97          Some(IMPORT_QUERY.get_or_init(|| {
 98              compile_query(
 99                  &self.grammar(),
100                  include_str!("../../queries/python/imports.scm"),
101                  "imports",
102              )
103          }))
104      }
105  
106      fn completion_query(&self) -> Option<&Query> {
107          Some(COMPLETION_QUERY.get_or_init(|| {
108              compile_query(
109                  &self.grammar(),
110                  include_str!("../../queries/python/completion.scm"),
111                  "completion",
112              )
113          }))
114      }
115  
116      fn reassignment_query(&self) -> Option<&Query> {
117          Some(REASSIGNMENT_QUERY.get_or_init(|| {
118              compile_query(
119                  &self.grammar(),
120                  include_str!("../../queries/python/reassignments.scm"),
121                  "reassignments",
122              )
123          }))
124      }
125  
126      fn identifier_query(&self) -> Option<&Query> {
127          Some(IDENTIFIER_QUERY.get_or_init(|| {
128              compile_query(
129                  &self.grammar(),
130                  include_str!("../../queries/python/identifiers.scm"),
131                  "identifiers",
132              )
133          }))
134      }
135  
136      fn export_query(&self) -> Option<&Query> {
137          Some(EXPORT_QUERY.get_or_init(|| {
138              compile_query(
139                  &self.grammar(),
140                  include_str!("../../queries/python/exports.scm"),
141                  "exports",
142              )
143          }))
144      }
145  
146      fn assignment_query(&self) -> Option<&Query> {
147          Some(ASSIGNMENT_QUERY.get_or_init(|| {
148              compile_query(
149                  &self.grammar(),
150                  include_str!("../../queries/python/assignments.scm"),
151                  "assignments",
152              )
153          }))
154      }
155  
156      fn destructure_query(&self) -> Option<&Query> {
157          Some(DESTRUCTURE_QUERY.get_or_init(|| {
158              compile_query(
159                  &self.grammar(),
160                  include_str!("../../queries/python/destructures.scm"),
161                  "destructures",
162              )
163          }))
164      }
165  
166      fn scope_query(&self) -> Option<&Query> {
167          Some(SCOPE_QUERY.get_or_init(|| {
168              compile_query(
169                  &self.grammar(),
170                  include_str!("../../queries/python/scopes.scm"),
171                  "scopes",
172              )
173          }))
174      }
175  
176      fn is_env_source_node(&self, node: Node, source: &[u8]) -> Option<EnvSourceKind> {
177          if node.kind() == "attribute" {
178              let object = node.child_by_field_name("object")?;
179              let attribute = node.child_by_field_name("attribute")?;
180  
181              let object_text = object.utf8_text(source).ok()?;
182              let attribute_text = attribute.utf8_text(source).ok()?;
183  
184              if object_text == "os" && attribute_text == "environ" {
185                  return Some(EnvSourceKind::Object {
186                      canonical_name: "os.environ".into(),
187                  });
188              }
189          }
190  
191          if node.kind() == "identifier" {
192              let text = node.utf8_text(source).ok()?;
193              if text == "environ" {
194                  return Some(EnvSourceKind::Object {
195                      canonical_name: "os.environ".into(),
196                  });
197              }
198          }
199  
200          None
201      }
202  
203      fn known_env_modules(&self) -> &'static [&'static str] {
204          &["os", "dotenv", "decouple"]
205      }
206  
207      fn completion_trigger_characters(&self) -> &'static [&'static str] {
208          &[".", "[\"", "['", "(\"", "('"]
209      }
210  
211      fn strip_quotes<'a>(&self, text: &'a str) -> &'a str {
212          text.trim_matches(|c| c == '"' || c == '\'')
213      }
214  
215      fn extract_property_access(
216          &self,
217          tree: &tree_sitter::Tree,
218          content: &str,
219          byte_offset: usize,
220      ) -> Option<(compact_str::CompactString, compact_str::CompactString)> {
221          let node = tree
222              .root_node()
223              .descendant_for_byte_range(byte_offset, byte_offset)?;
224  
225          let attr_node = if node.kind() == "attribute" {
226              node
227          } else if let Some(parent) = node.parent() {
228              if parent.kind() == "attribute" {
229                  parent
230              } else {
231                  return None;
232              }
233          } else {
234              return None;
235          };
236  
237          let object_node = attr_node.child_by_field_name("object")?;
238          let attribute_node = attr_node.child_by_field_name("attribute")?;
239  
240          if object_node.kind() != "identifier" {
241              return None;
242          }
243  
244          let object_name = object_node.utf8_text(content.as_bytes()).ok()?;
245          let property_name = attribute_node.utf8_text(content.as_bytes()).ok()?;
246  
247          Some((object_name.into(), property_name.into()))
248      }
249  }
250  
#[cfg(test)]
mod tests {
    use super::*;

    fn get_python() -> Python {
        Python
    }

    /// Parses `code` with the Python grammar, failing the test if setup or parsing fails.
    fn parse(code: &str) -> tree_sitter::Tree {
        let mut parser = tree_sitter::Parser::new();
        parser
            .set_language(&get_python().grammar())
            .expect("Python grammar should load into the parser");
        parser.parse(code, None).expect("parsing should produce a tree")
    }

    /// Depth-first search for the first node of `kind` at or below `node`.
    fn find_node_of_kind<'a>(
        node: tree_sitter::Node<'a>,
        kind: &str,
    ) -> Option<tree_sitter::Node<'a>> {
        if node.kind() == kind {
            return Some(node);
        }
        (0..node.child_count())
            .filter_map(|i| node.child(i))
            .find_map(|child| find_node_of_kind(child, kind))
    }

    /// Returns true if any `attribute` node at or below `node` is classified as `os.environ`.
    fn contains_os_environ(node: tree_sitter::Node<'_>, py: &Python, source: &[u8]) -> bool {
        if node.kind() == "attribute" {
            if let Some(EnvSourceKind::Object { canonical_name }) =
                py.is_env_source_node(node, source)
            {
                if canonical_name == "os.environ" {
                    return true;
                }
            }
        }
        (0..node.child_count())
            .filter_map(|i| node.child(i))
            .any(|child| contains_os_environ(child, py, source))
    }

    #[test]
    fn test_id() {
        assert_eq!(get_python().id(), "python");
    }

    #[test]
    fn test_extensions() {
        let exts = get_python().extensions();
        assert!(exts.contains(&"py"));
    }

    #[test]
    fn test_language_ids() {
        let ids = get_python().language_ids();
        assert!(ids.contains(&"python"));
    }

    #[test]
    fn test_is_standard_env_object() {
        let py = get_python();
        assert!(py.is_standard_env_object("os.environ"));
        assert!(py.is_standard_env_object("os")); // "os" is valid for function-call patterns like os.getenv()
        assert!(!py.is_standard_env_object("process"));
    }

    #[test]
    fn test_default_env_object_name() {
        assert_eq!(get_python().default_env_object_name(), Some("os.environ"));
    }

    #[test]
    fn test_known_env_modules() {
        let modules = get_python().known_env_modules();
        assert!(modules.contains(&"os"));
    }

    #[test]
    fn test_grammar_compiles() {
        let mut parser = tree_sitter::Parser::new();
        assert!(
            parser.set_language(&get_python().grammar()).is_ok(),
            "Python grammar should be accepted by the parser"
        );
    }

    #[test]
    fn test_reference_query_compiles() {
        let py = get_python();
        let _query = py.reference_query();
    }

    #[test]
    fn test_binding_query_compiles() {
        let py = get_python();
        assert!(py.binding_query().is_some());
    }

    #[test]
    fn test_import_query_compiles() {
        let py = get_python();
        assert!(py.import_query().is_some());
    }

    #[test]
    fn test_completion_query_compiles() {
        let py = get_python();
        assert!(py.completion_query().is_some());
    }

    #[test]
    fn test_reassignment_query_compiles() {
        let py = get_python();
        assert!(py.reassignment_query().is_some());
    }

    #[test]
    fn test_identifier_query_compiles() {
        let py = get_python();
        assert!(py.identifier_query().is_some());
    }

    #[test]
    fn test_export_query_compiles() {
        let py = get_python();
        assert!(py.export_query().is_some());
    }

    #[test]
    fn test_assignment_query_compiles() {
        let py = get_python();
        assert!(py.assignment_query().is_some());
    }

    #[test]
    fn test_scope_query_compiles() {
        let py = get_python();
        assert!(py.scope_query().is_some());
    }

    #[test]
    fn test_destructure_query_compiles() {
        let py = get_python();
        assert!(py.destructure_query().is_some());
    }

    #[test]
    fn test_strip_quotes() {
        let py = get_python();
        assert_eq!(py.strip_quotes("\"hello\""), "hello");
        assert_eq!(py.strip_quotes("'world'"), "world");
        assert_eq!(py.strip_quotes("noquotes"), "noquotes");
        assert_eq!(py.strip_quotes("\"\"\"triple\"\"\""), "triple");
    }

    #[test]
    fn test_is_env_source_node_os_environ() {
        let py = get_python();
        let code = "import os\nx = os.environ";
        let tree = parse(code);

        let found = contains_os_environ(tree.root_node(), &py, code.as_bytes());
        assert!(found, "Should detect os.environ as env source");
    }

    #[test]
    fn test_is_env_source_node_bare_environ() {
        let py = get_python();
        let code = "x = environ";
        let tree = parse(code);

        let rhs = find_node_of_kind(tree.root_node(), "assignment")
            .and_then(|assignment| assignment.child_by_field_name("right"))
            .expect("assignment should have a right-hand side");
        assert!(matches!(
            py.is_env_source_node(rhs, code.as_bytes()),
            Some(EnvSourceKind::Object { .. })
        ));
    }

    #[test]
    fn test_extract_property_access() {
        let py = get_python();
        let code = "x = env.DATABASE_URL";
        let tree = parse(code);

        let offset = code.find("DATABASE_URL").unwrap();
        let (obj, prop) = py
            .extract_property_access(&tree, code, offset)
            .expect("cursor on the attribute name should resolve the access");
        assert_eq!(obj.as_str(), "env");
        assert_eq!(prop.as_str(), "DATABASE_URL");
    }

    #[test]
    fn test_is_scope_node() {
        let py = get_python();
        let tree = parse("def test():\n    pass");
        let root = tree.root_node();

        // `expect` instead of `if let` so the test fails rather than passing vacuously
        // when the expected node is missing.
        let func = find_node_of_kind(root, "function_definition")
            .expect("source should contain a function_definition");
        assert!(py.is_scope_node(func));
        assert!(py.is_scope_node(root), "the module root should be a scope");

        let pass = find_node_of_kind(root, "pass_statement")
            .expect("source should contain a pass_statement");
        assert!(!py.is_scope_node(pass));
    }
}