lua.rs
1 use crate::languages::LanguageSupport; 2 use crate::types::EnvSourceKind; 3 use std::sync::OnceLock; 4 use tracing::error; 5 use tree_sitter::{Language, Node, Query}; 6 7 pub struct Lua; 8 9 static REFERENCE_QUERY: OnceLock<Query> = OnceLock::new(); 10 static BINDING_QUERY: OnceLock<Query> = OnceLock::new(); 11 static IMPORT_QUERY: OnceLock<Query> = OnceLock::new(); 12 static COMPLETION_QUERY: OnceLock<Query> = OnceLock::new(); 13 static REASSIGNMENT_QUERY: OnceLock<Query> = OnceLock::new(); 14 static IDENTIFIER_QUERY: OnceLock<Query> = OnceLock::new(); 15 static EXPORT_QUERY: OnceLock<Query> = OnceLock::new(); 16 17 static ASSIGNMENT_QUERY: OnceLock<Query> = OnceLock::new(); 18 static DESTRUCTURE_QUERY: OnceLock<Query> = OnceLock::new(); 19 static SCOPE_QUERY: OnceLock<Query> = OnceLock::new(); 20 21 /// Compiles a tree-sitter query and fails fast on errors to surface invalid language query definitions early. 22 fn compile_query(grammar: &Language, source: &str, query_name: &str) -> Query { 23 match Query::new(grammar, source) { 24 Ok(query) => query, 25 Err(e) => { 26 error!( 27 language = "lua", 28 query = query_name, 29 error = %e, 30 "Failed to compile query, failing fast" 31 ); 32 panic!("Failed to compile query '{}': {}", query_name, e) 33 } 34 } 35 } 36 37 impl LanguageSupport for Lua { 38 fn id(&self) -> &'static str { 39 "lua" 40 } 41 42 fn is_standard_env_object(&self, name: &str) -> bool { 43 name == "os" || name == "os.getenv" 44 } 45 46 fn default_env_object_name(&self) -> Option<&'static str> { 47 Some("os.getenv") 48 } 49 50 fn extensions(&self) -> &'static [&'static str] { 51 &["lua"] 52 } 53 54 fn language_ids(&self) -> &'static [&'static str] { 55 &["lua"] 56 } 57 58 fn grammar(&self) -> Language { 59 tree_sitter_lua::LANGUAGE.into() 60 } 61 62 fn reference_query(&self) -> &Query { 63 REFERENCE_QUERY.get_or_init(|| { 64 compile_query( 65 &self.grammar(), 66 include_str!("../../queries/lua/references.scm"), 67 "references", 68 ) 69 }) 70 } 71 72 fn binding_query(&self) -> Option<&Query> { 73 Some(BINDING_QUERY.get_or_init(|| { 74 compile_query( 75 &self.grammar(), 76 include_str!("../../queries/lua/bindings.scm"), 77 "bindings", 78 ) 79 })) 80 } 81 82 fn import_query(&self) -> Option<&Query> { 83 Some(IMPORT_QUERY.get_or_init(|| { 84 compile_query( 85 &self.grammar(), 86 include_str!("../../queries/lua/imports.scm"), 87 "imports", 88 ) 89 })) 90 } 91 92 fn completion_query(&self) -> Option<&Query> { 93 Some(COMPLETION_QUERY.get_or_init(|| { 94 compile_query( 95 &self.grammar(), 96 include_str!("../../queries/lua/completion.scm"), 97 "completion", 98 ) 99 })) 100 } 101 102 fn reassignment_query(&self) -> Option<&Query> { 103 Some(REASSIGNMENT_QUERY.get_or_init(|| { 104 compile_query( 105 &self.grammar(), 106 include_str!("../../queries/lua/reassignments.scm"), 107 "reassignments", 108 ) 109 })) 110 } 111 112 fn identifier_query(&self) -> Option<&Query> { 113 Some(IDENTIFIER_QUERY.get_or_init(|| { 114 compile_query( 115 &self.grammar(), 116 include_str!("../../queries/lua/identifiers.scm"), 117 "identifiers", 118 ) 119 })) 120 } 121 122 fn export_query(&self) -> Option<&Query> { 123 Some(EXPORT_QUERY.get_or_init(|| { 124 compile_query( 125 &self.grammar(), 126 include_str!("../../queries/lua/exports.scm"), 127 "exports", 128 ) 129 })) 130 } 131 132 fn assignment_query(&self) -> Option<&Query> { 133 Some(ASSIGNMENT_QUERY.get_or_init(|| { 134 compile_query( 135 &self.grammar(), 136 include_str!("../../queries/lua/assignments.scm"), 137 "assignments", 138 ) 139 })) 140 } 141 142 fn destructure_query(&self) -> Option<&Query> { 143 Some(DESTRUCTURE_QUERY.get_or_init(|| { 144 compile_query( 145 &self.grammar(), 146 include_str!("../../queries/lua/destructures.scm"), 147 "destructures", 148 ) 149 })) 150 } 151 152 fn scope_query(&self) -> Option<&Query> { 153 Some(SCOPE_QUERY.get_or_init(|| { 154 compile_query( 155 &self.grammar(), 156 include_str!("../../queries/lua/scopes.scm"), 157 "scopes", 158 ) 159 })) 160 } 161 162 fn is_env_source_node(&self, node: Node, source: &[u8]) -> Option<EnvSourceKind> { 163 // Detect os.getenv pattern 164 // In Lua, this is typically a function call: os.getenv("VAR") 165 // We want to detect when we're looking at the "os" identifier that's part of os.getenv 166 if node.kind() == "identifier" { 167 let text = node.utf8_text(source).ok()?; 168 if text == "os" { 169 return Some(EnvSourceKind::Object { 170 canonical_name: "os".into(), 171 }); 172 } 173 } 174 175 None 176 } 177 178 fn known_env_modules(&self) -> &'static [&'static str] { 179 &["os"] 180 } 181 182 fn completion_trigger_characters(&self) -> &'static [&'static str] { 183 // Support both parenthesized and parenthesis-less function calls: 184 // os.getenv(" os.getenv(' os.getenv " os.getenv ' 185 &["(\"", "('", " \"", " '"] 186 } 187 188 fn is_scope_node(&self, node: Node) -> bool { 189 matches!( 190 node.kind(), 191 "function_declaration" 192 | "function_definition" 193 | "do_statement" 194 | "while_statement" 195 | "repeat_statement" 196 | "for_statement" 197 | "if_statement" 198 ) 199 } 200 201 fn strip_quotes<'a>(&self, text: &'a str) -> &'a str { 202 text.trim_matches(|c| c == '"' || c == '\'') 203 } 204 205 fn extract_var_name(&self, node: Node, source: &[u8]) -> Option<compact_str::CompactString> { 206 node.utf8_text(source) 207 .ok() 208 .map(|s| compact_str::CompactString::from(self.strip_quotes(s))) 209 } 210 211 fn extract_property_access( 212 &self, 213 tree: &tree_sitter::Tree, 214 content: &str, 215 byte_offset: usize, 216 ) -> Option<(compact_str::CompactString, compact_str::CompactString)> { 217 let node = tree 218 .root_node() 219 .descendant_for_byte_range(byte_offset, byte_offset)?; 220 221 // In Lua, property access is through dot_index_expression 222 let dot_index = if node.kind() == "dot_index_expression" { 223 node 224 } else if let Some(parent) = node.parent() { 225 if parent.kind() == "dot_index_expression" { 226 parent 227 } else { 228 return None; 229 } 230 } else { 231 return None; 232 }; 233 234 let table_node = dot_index.child_by_field_name("table")?; 235 let field_node = dot_index.child_by_field_name("field")?; 236 237 if table_node.kind() != "identifier" { 238 return None; 239 } 240 241 let table_name = table_node.utf8_text(content.as_bytes()).ok()?; 242 let field_name = field_node.utf8_text(content.as_bytes()).ok()?; 243 244 Some((table_name.into(), field_name.into())) 245 } 246 } 247 248 #[cfg(test)] 249 mod tests { 250 use super::*; 251 252 fn get_lua() -> Lua { 253 Lua 254 } 255 256 #[test] 257 fn test_id() { 258 assert_eq!(get_lua().id(), "lua"); 259 } 260 261 #[test] 262 fn test_extensions() { 263 let exts = get_lua().extensions(); 264 assert!(exts.contains(&"lua")); 265 } 266 267 #[test] 268 fn test_language_ids() { 269 let ids = get_lua().language_ids(); 270 assert!(ids.contains(&"lua")); 271 } 272 273 #[test] 274 fn test_is_standard_env_object() { 275 let lua = get_lua(); 276 assert!(lua.is_standard_env_object("os")); 277 assert!(lua.is_standard_env_object("os.getenv")); 278 assert!(!lua.is_standard_env_object("process")); 279 } 280 281 #[test] 282 fn test_default_env_object_name() { 283 assert_eq!(get_lua().default_env_object_name(), Some("os.getenv")); 284 } 285 286 #[test] 287 fn test_known_env_modules() { 288 let modules = get_lua().known_env_modules(); 289 assert!(modules.contains(&"os")); 290 } 291 292 #[test] 293 fn test_grammar_compiles() { 294 let lua = get_lua(); 295 let _grammar = lua.grammar(); 296 } 297 298 #[test] 299 fn test_reference_query_compiles() { 300 let lua = get_lua(); 301 let _query = lua.reference_query(); 302 } 303 304 #[test] 305 fn test_binding_query_compiles() { 306 let lua = get_lua(); 307 assert!(lua.binding_query().is_some()); 308 } 309 310 #[test] 311 fn test_import_query_compiles() { 312 let lua = get_lua(); 313 assert!(lua.import_query().is_some()); 314 } 315 316 #[test] 317 fn test_completion_query_compiles() { 318 let lua = get_lua(); 319 assert!(lua.completion_query().is_some()); 320 } 321 322 #[test] 323 fn test_reassignment_query_compiles() { 324 let lua = get_lua(); 325 assert!(lua.reassignment_query().is_some()); 326 } 327 328 #[test] 329 fn test_identifier_query_compiles() { 330 let lua = get_lua(); 331 assert!(lua.identifier_query().is_some()); 332 } 333 334 #[test] 335 fn test_export_query_compiles() { 336 let lua = get_lua(); 337 assert!(lua.export_query().is_some()); 338 } 339 340 #[test] 341 fn test_assignment_query_compiles() { 342 let lua = get_lua(); 343 assert!(lua.assignment_query().is_some()); 344 } 345 346 #[test] 347 fn test_scope_query_compiles() { 348 let lua = get_lua(); 349 assert!(lua.scope_query().is_some()); 350 } 351 352 #[test] 353 fn test_destructure_query_compiles() { 354 let lua = get_lua(); 355 assert!(lua.destructure_query().is_some()); 356 } 357 358 #[test] 359 fn test_strip_quotes() { 360 let lua = get_lua(); 361 assert_eq!(lua.strip_quotes("\"hello\""), "hello"); 362 assert_eq!(lua.strip_quotes("'world'"), "world"); 363 assert_eq!(lua.strip_quotes("noquotes"), "noquotes"); 364 } 365 366 #[test] 367 fn test_is_env_source_node_os() { 368 let lua = get_lua(); 369 let mut parser = tree_sitter::Parser::new(); 370 parser.set_language(&lua.grammar()).unwrap(); 371 372 let code = "local x = os.getenv(\"VAR\")"; 373 let tree = parser.parse(code, None).unwrap(); 374 let root = tree.root_node(); 375 376 fn walk_tree(cursor: &mut tree_sitter::TreeCursor, lua: &Lua, code: &str) -> bool { 377 loop { 378 let node = cursor.node(); 379 if node.kind() == "identifier" { 380 if let Some(kind) = lua.is_env_source_node(node, code.as_bytes()) { 381 if let EnvSourceKind::Object { canonical_name } = kind { 382 if canonical_name == "os" { 383 return true; 384 } 385 } 386 } 387 } 388 389 if cursor.goto_first_child() { 390 if walk_tree(cursor, lua, code) { 391 return true; 392 } 393 cursor.goto_parent(); 394 } 395 396 if !cursor.goto_next_sibling() { 397 break; 398 } 399 } 400 false 401 } 402 403 let mut cursor = root.walk(); 404 let found = walk_tree(&mut cursor, &lua, code); 405 assert!(found, "Should detect os as env source"); 406 } 407 408 #[test] 409 fn test_extract_property_access() { 410 let lua = get_lua(); 411 let mut parser = tree_sitter::Parser::new(); 412 parser.set_language(&lua.grammar()).unwrap(); 413 414 let code = "local x = env.DATABASE_URL"; 415 let tree = parser.parse(code, None).unwrap(); 416 417 let offset = code.find("DATABASE_URL").unwrap(); 418 let result = lua.extract_property_access(&tree, code, offset); 419 assert!(result.is_some()); 420 let (table, field) = result.unwrap(); 421 assert_eq!(table.as_str(), "env"); 422 assert_eq!(field.as_str(), "DATABASE_URL"); 423 } 424 425 #[test] 426 fn test_is_scope_node() { 427 let lua = get_lua(); 428 let mut parser = tree_sitter::Parser::new(); 429 parser.set_language(&lua.grammar()).unwrap(); 430 431 let code = "function test()\nend"; 432 let tree = parser.parse(code, None).unwrap(); 433 let root = tree.root_node(); 434 435 fn find_node_of_kind<'a>( 436 node: tree_sitter::Node<'a>, 437 kind: &str, 438 ) -> Option<tree_sitter::Node<'a>> { 439 if node.kind() == kind { 440 return Some(node); 441 } 442 for i in 0..node.child_count() { 443 if let Some(child) = node.child(i) { 444 if let Some(found) = find_node_of_kind(child, kind) { 445 return Some(found); 446 } 447 } 448 } 449 None 450 } 451 452 if let Some(func) = find_node_of_kind(root, "function_declaration") { 453 assert!(lua.is_scope_node(func)); 454 } 455 } 456 457 #[test] 458 fn test_extract_var_name() { 459 let lua = get_lua(); 460 let mut parser = tree_sitter::Parser::new(); 461 parser.set_language(&lua.grammar()).unwrap(); 462 463 let code = "local VAR = \"value\""; 464 let tree = parser.parse(code, None).unwrap(); 465 let root = tree.root_node(); 466 467 fn find_node_of_kind<'a>( 468 node: tree_sitter::Node<'a>, 469 kind: &str, 470 ) -> Option<tree_sitter::Node<'a>> { 471 if node.kind() == kind { 472 return Some(node); 473 } 474 for i in 0..node.child_count() { 475 if let Some(child) = node.child(i) { 476 if let Some(found) = find_node_of_kind(child, kind) { 477 return Some(found); 478 } 479 } 480 } 481 None 482 } 483 484 if let Some(str_lit) = find_node_of_kind(root, "string") { 485 let name = lua.extract_var_name(str_lit, code.as_bytes()); 486 assert!(name.is_some()); 487 assert_eq!(name.unwrap().as_str(), "value"); 488 } 489 } 490 }