diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs index 055c70fbe..618b2d786 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_ref_tags.rs @@ -266,8 +266,16 @@ pub fn analyze_return(analyzer: &mut DocAnalyzer, tag: LuaDocTagReturn) -> Optio pub fn analyze_return_cast(analyzer: &mut DocAnalyzer, tag: LuaDocTagReturnCast) -> Option<()> { if let Some(LuaSemanticDeclId::Signature(signature_id)) = get_owner_id(analyzer, None, false) { - let name_token = tag.get_name_token()?; - let name = name_token.get_name_text(); + // Extract name from either name_token or key_expr + let name = if let Some(key_expr) = tag.get_key_expr() { + // Handle multi-level expressions like self.xxx + extract_name_from_expr(&key_expr) + } else if let Some(name_token) = tag.get_name_token() { + // Fallback to simple name token + name_token.get_name_text().to_string() + } else { + return None; + }; let op_types: Vec<_> = tag.get_op_types().collect(); let cast_op_type = op_types.first()?; @@ -297,7 +305,7 @@ pub fn analyze_return_cast(analyzer: &mut DocAnalyzer, tag: LuaDocTagReturnCast) analyzer.db.get_flow_index_mut().add_signature_cast( analyzer.file_id, signature_id, - name.to_string(), + name, cast_op_type.to_ptr(), fallback_cast, ); @@ -308,6 +316,45 @@ pub fn analyze_return_cast(analyzer: &mut DocAnalyzer, tag: LuaDocTagReturnCast) Some(()) } +// Helper function to extract name string from expression +fn extract_name_from_expr(expr: &LuaExpr) -> String { + match expr { + LuaExpr::NameExpr(name_expr) => { + if let Some(token) = name_expr.get_name_token() { + token.get_name_text().to_string() + } else { + String::new() + } + } + LuaExpr::IndexExpr(index_expr) => { + // Recursively build the path like "self.xxx" or "self.a.b" + let prefix = if let Some(prefix_expr) = index_expr.get_prefix_expr() { + extract_name_from_expr(&prefix_expr) + } else { + String::new() + }; + + let suffix = if let Some(key) = index_expr.get_index_key() { + match key { + emmylua_parser::LuaIndexKey::Name(name_token) => name_token.get_name_text().to_string(), + _ => String::new(), + } + } else { + String::new() + }; + + if prefix.is_empty() { + suffix + } else if suffix.is_empty() { + prefix + } else { + format!("{}.{}", prefix, suffix) + } + } + _ => String::new(), + } +} + pub fn analyze_overload(analyzer: &mut DocAnalyzer, tag: LuaDocTagOverload) -> Option<()> { if let Some(decl_id) = analyzer.current_type_id.clone() { let type_ref = infer_type(analyzer, tag.get_type()?); diff --git a/crates/emmylua_code_analysis/src/compilation/test/flow.rs b/crates/emmylua_code_analysis/src/compilation/test/flow.rs index 2abeda8b1..195c7d1d7 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/flow.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/flow.rs @@ -1476,4 +1476,135 @@ _2 = a[1] "#, ); } + + #[test] + fn test_return_cast_self_field() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class MyClass + ---@field value string|number + local MyClass = {} + + ---Check if value field is string + ---@param self MyClass + ---@return_cast self.value string + function MyClass:check_string() + return type(self.value) == "string" + end + + ---@param obj MyClass + function test(obj) + if obj:check_string() then + a = obj.value + else + b = obj.value + end + end + "#, + ); + + let a = ws.expr_ty("a"); + let a_expected = ws.ty("string"); + assert_eq!(a, a_expected); + + let b = ws.expr_ty("b"); + let b_expected = ws.ty("number"); + assert_eq!(b, b_expected); + } + + #[test] + fn test_return_cast_self_field_with_fallback() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class MyClass + ---@field data table|nil + local MyClass = {} + + ---Check if data exists + ---@param self MyClass + ---@return_cast self.data table else nil + function MyClass:has_data() + return self.data ~= nil + end + + ---@param obj MyClass + function test(obj) + if obj:has_data() then + c = obj.data + else + d = obj.data + end + end + "#, + ); + + let c = ws.expr_ty("c"); + let c_str = ws.humanize_type(c); + assert_eq!(c_str, "table"); + + let d = ws.expr_ty("d"); + let d_expected = ws.ty("nil"); + assert_eq!(d, d_expected); + } + + #[test] + fn test_return_cast_self_field_complex() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class Vehicle + ---@field type "car"|"bike"|"truck" + ---@field engine string|nil + local Vehicle = {} + + ---@param self Vehicle + ---@return_cast self.type "car" + function Vehicle:is_car() + return self.type == "car" + end + + ---@param self Vehicle + ---@return_cast self.engine string else nil + function Vehicle:has_engine() + return self.engine ~= nil + end + + ---@param v Vehicle + function test(v) + if v:is_car() then + e = v.type + else + f = v.type + end + + if v:has_engine() then + g = v.engine + else + h = v.engine + end + end + "#, + ); + + let e = ws.expr_ty("e"); + let e_expected = ws.ty("\"car\""); + assert_eq!(e, e_expected); + + let f = ws.expr_ty("f"); + let f_expected = ws.ty("\"bike\"|\"truck\""); + assert_eq!(f, f_expected); + + let g = ws.expr_ty("g"); + let g_expected = ws.ty("string"); + assert_eq!(g, g_expected); + + let h = ws.expr_ty("h"); + let h_expected = ws.ty("nil"); + assert_eq!(h, h_expected); + } } diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/undefined_global.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/undefined_global.rs index ba4159828..13ab27176 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/undefined_global.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/undefined_global.rs @@ -1,6 +1,8 @@ use std::collections::HashSet; -use emmylua_parser::{LuaAstNode, LuaClosureExpr, LuaNameExpr}; +use emmylua_parser::{ + LuaAst, LuaAstNode, LuaClosureExpr, LuaComment, LuaDocTagReturnCast, LuaNameExpr, +}; use rowan::TextRange; use crate::{DiagnosticCode, LuaSignatureId, SemanticModel}; @@ -94,6 +96,7 @@ fn check_name_expr( } fn check_self_name(semantic_model: &SemanticModel, name_expr: LuaNameExpr) -> Option<()> { + // Check if self is in a method context (regular Lua code) let closure_expr = name_expr.ancestors::(); for closure_expr in closure_expr { let signature_id = @@ -105,6 +108,47 @@ fn check_self_name(semantic_model: &SemanticModel, name_expr: LuaNameExpr) -> Op if signature.is_method(semantic_model, None) { return Some(()); } + + // Check if self is a parameter of this function (from @param self) + if signature.find_param_idx("self").is_some() { + return Some(()); + } } + + // Check if self is in @return_cast tag + // The name_expr might be inside a doc comment, not inside actual Lua code + for ancestor in name_expr.syntax().ancestors() { + if let Some(return_cast_tag) = LuaDocTagReturnCast::cast(ancestor.clone()) { + // Find the LuaComment that contains this tag + for comment_ancestor in return_cast_tag.syntax().ancestors() { + if let Some(comment) = LuaComment::cast(comment_ancestor) { + // Get the owner (function) of this comment + if let Some(owner) = comment.get_owner() { + if let LuaAst::LuaClosureExpr(closure) = owner { + let sig_id = LuaSignatureId::from_closure( + semantic_model.get_file_id(), + &closure, + ); + if let Some(sig) = + semantic_model.get_db().get_signature_index().get(&sig_id) + { + // Check if the owner function is a method + if sig.is_method(semantic_model, None) { + return Some(()); + } + // Check if self is a parameter + if sig.find_param_idx("self").is_some() { + return Some(()); + } + } + } + } + break; + } + } + break; + } + } + None } diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/undefined_global_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/undefined_global_test.rs index 123b1e7a9..c29ce3b16 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/undefined_global_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/undefined_global_test.rs @@ -20,4 +20,41 @@ mod test { "# )); } + + #[test] + fn test_return_cast_self_no_undefined_global() { + let mut ws = VirtualWorkspace::new(); + // @return_cast self should not produce undefined-global error in methods + assert!(!ws.check_code_for( + DiagnosticCode::UndefinedGlobal, + r#" + ---@class MyClass + local MyClass = {} + + ---@return_cast self MyClass + function MyClass:check1() + return true + end + "# + )); + } + + #[test] + fn test_return_cast_self_field_no_undefined_global() { + let mut ws = VirtualWorkspace::new(); + // @return_cast self.field should not produce undefined-global error in methods + assert!(!ws.check_code_for( + DiagnosticCode::UndefinedGlobal, + r#" + ---@class MyClass + ---@field value string|number + local MyClass = {} + + ---@return_cast self.value string + function MyClass:check_string() + return type(self.value) == "string" + end + "# + )); + } } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs index 8ad1b8b0e..3ff8908ab 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs @@ -93,8 +93,9 @@ pub fn get_type_at_call_expr( return Ok(ResultTypeOrContinue::Continue); }; - match signature_cast.name.as_str() { - "self" => get_type_at_call_expr_by_signature_self( + // Check if name starts with "self." for multi-level field access + if signature_cast.name.starts_with("self.") { + get_type_at_call_expr_by_signature_self_field( db, tree, cache, @@ -105,20 +106,35 @@ pub fn get_type_at_call_expr( signature_cast, signature_id, condition_flow, - ), - name => get_type_at_call_expr_by_signature_param_name( - db, - tree, - cache, - root, - var_ref_id, - flow_node, - call_expr, - signature_cast, - signature_id, - name, - condition_flow, - ), + ) + } else { + match signature_cast.name.as_str() { + "self" => get_type_at_call_expr_by_signature_self( + db, + tree, + cache, + root, + var_ref_id, + flow_node, + prefix_expr, + signature_cast, + signature_id, + condition_flow, + ), + name => get_type_at_call_expr_by_signature_param_name( + db, + tree, + cache, + root, + var_ref_id, + flow_node, + call_expr, + signature_cast, + signature_id, + name, + condition_flow, + ), + } } } _ => { @@ -275,6 +291,128 @@ fn get_type_at_call_expr_by_signature_self( Ok(ResultTypeOrContinue::Result(result_type)) } +#[allow(clippy::too_many_arguments)] +fn get_type_at_call_expr_by_signature_self_field( + db: &DbIndex, + tree: &FlowTree, + cache: &mut LuaInferCache, + root: &LuaChunk, + var_ref_id: &VarRefId, + flow_node: &FlowNode, + call_prefix: LuaExpr, + signature_cast: &LuaSignatureCast, + signature_id: LuaSignatureId, + condition_flow: InferConditionFlow, +) -> Result { + // Extract the field path after "self." (e.g., "self.xxx" -> "xxx") + let field_path = signature_cast.name.strip_prefix("self.").unwrap_or(""); + if field_path.is_empty() { + return Ok(ResultTypeOrContinue::Continue); + } + + let LuaExpr::IndexExpr(call_prefix_index) = call_prefix else { + return Ok(ResultTypeOrContinue::Continue); + }; + + let Some(self_expr) = call_prefix_index.get_prefix_expr() else { + return Ok(ResultTypeOrContinue::Continue); + }; + + // Get the var_ref_id of the self expression (e.g., "obj") + let Some(self_var_ref_id) = get_var_expr_var_ref_id(db, cache, self_expr) else { + return Ok(ResultTypeOrContinue::Continue); + }; + + // Check if the tracked variable matches the pattern "self.field" + // For example, if we're tracking "obj.value", we check: + // 1. Does it start with "obj" (self_var_ref_id)? + // 2. Does the path end with ".value" (field_path)? + let matches = match var_ref_id { + VarRefId::IndexRef(decl_or_member, path) => { + // Check if the base matches + let base_matches = match &self_var_ref_id { + VarRefId::VarRef(self_decl_id) => { + decl_or_member.as_decl_id() == Some(*self_decl_id) + } + VarRefId::SelfRef(self_decl_or_member) => { + decl_or_member == self_decl_or_member + } + VarRefId::IndexRef(self_decl_or_member, _) => { + decl_or_member == self_decl_or_member + } + }; + + if !base_matches { + return Ok(ResultTypeOrContinue::Continue); + } + + // Check if the path ends with the field_path + // path might be "obj.value" or "something.obj.value" + // field_path is "value" + path.ends_with(field_path) + } + _ => false, + }; + + if !matches { + return Ok(ResultTypeOrContinue::Continue); + } + + let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + let antecedent_type = get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; + + let Some(syntax_tree) = db.get_vfs().get_syntax_tree(&signature_id.get_file_id()) else { + return Ok(ResultTypeOrContinue::Continue); + }; + + let signature_root = syntax_tree.get_chunk_node(); + + // Choose the appropriate cast based on condition_flow and whether fallback exists + let result_type = match condition_flow { + InferConditionFlow::TrueCondition => { + let Some(cast_op_type) = signature_cast.cast.to_node(&signature_root) else { + return Ok(ResultTypeOrContinue::Continue); + }; + cast_type( + db, + signature_id.get_file_id(), + cast_op_type, + antecedent_type, + condition_flow, + )? + } + InferConditionFlow::FalseCondition => { + // Use fallback_cast if available, otherwise use the default behavior + if let Some(fallback_cast_ptr) = &signature_cast.fallback_cast { + let Some(fallback_op_type) = fallback_cast_ptr.to_node(&signature_root) else { + return Ok(ResultTypeOrContinue::Continue); + }; + cast_type( + db, + signature_id.get_file_id(), + fallback_op_type, + antecedent_type.clone(), + InferConditionFlow::TrueCondition, // Apply fallback as force cast + )? + } else { + // Original behavior: remove the true type from antecedent + let Some(cast_op_type) = signature_cast.cast.to_node(&signature_root) else { + return Ok(ResultTypeOrContinue::Continue); + }; + cast_type( + db, + signature_id.get_file_id(), + cast_op_type, + antecedent_type, + condition_flow, + )? + } + } + }; + + Ok(ResultTypeOrContinue::Result(result_type)) +} + #[allow(clippy::too_many_arguments)] fn get_type_at_call_expr_by_signature_param_name( db: &DbIndex, diff --git a/crates/emmylua_parser/src/grammar/doc/tag.rs b/crates/emmylua_parser/src/grammar/doc/tag.rs index 10f18363a..2cafedae9 100644 --- a/crates/emmylua_parser/src/grammar/doc/tag.rs +++ b/crates/emmylua_parser/src/grammar/doc/tag.rs @@ -365,12 +365,23 @@ fn parse_tag_return(p: &mut LuaDocParser) -> DocParseResult { // ---@return_cast // ---@return_cast else +// ---@return_cast self.xxx fn parse_tag_return_cast(p: &mut LuaDocParser) -> DocParseResult { - p.set_lexer_state(LuaDocLexerState::Normal); + p.set_lexer_state(LuaDocLexerState::CastExpr); let m = p.mark(LuaSyntaxKind::DocTagReturnCast); p.bump(); - expect_token(p, LuaTokenKind::TkName)?; + + // Parse param name or expression (like self.xxx) + if p.current_token() == LuaTokenKind::TkName { + match parse_cast_expr(p) { + Ok(_) => {} + Err(e) => { + return Err(e); + } + } + } + p.set_lexer_state(LuaDocLexerState::Normal); parse_op_type(p)?; // Allow optional second type after 'else' for false condition diff --git a/crates/emmylua_parser/src/syntax/node/doc/tag.rs b/crates/emmylua_parser/src/syntax/node/doc/tag.rs index c6d67c129..c91a25f60 100644 --- a/crates/emmylua_parser/src/syntax/node/doc/tag.rs +++ b/crates/emmylua_parser/src/syntax/node/doc/tag.rs @@ -1500,6 +1500,10 @@ impl LuaDocTagReturnCast { pub fn get_name_token(&self) -> Option { self.token() } + + pub fn get_key_expr(&self) -> Option { + self.child() + } } #[derive(Debug, Clone, PartialEq, Eq, Hash)]