Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,16 @@ pub fn analyze_return(analyzer: &mut DocAnalyzer, tag: LuaDocTagReturn) -> Optio

pub fn analyze_return_cast(analyzer: &mut DocAnalyzer, tag: LuaDocTagReturnCast) -> Option<()> {
if let Some(LuaSemanticDeclId::Signature(signature_id)) = get_owner_id(analyzer, None, false) {
let name_token = tag.get_name_token()?;
let name = name_token.get_name_text();
// Extract name from either name_token or key_expr
let name = if let Some(key_expr) = tag.get_key_expr() {
// Handle multi-level expressions like self.xxx
extract_name_from_expr(&key_expr)
} else if let Some(name_token) = tag.get_name_token() {
// Fallback to simple name token
name_token.get_name_text().to_string()
} else {
return None;
};

let op_types: Vec<_> = tag.get_op_types().collect();
let cast_op_type = op_types.first()?;
Expand Down Expand Up @@ -297,7 +305,7 @@ pub fn analyze_return_cast(analyzer: &mut DocAnalyzer, tag: LuaDocTagReturnCast)
analyzer.db.get_flow_index_mut().add_signature_cast(
analyzer.file_id,
signature_id,
name.to_string(),
name,
cast_op_type.to_ptr(),
fallback_cast,
);
Expand All @@ -308,6 +316,45 @@ pub fn analyze_return_cast(analyzer: &mut DocAnalyzer, tag: LuaDocTagReturnCast)
Some(())
}

// Helper function to extract name string from expression
fn extract_name_from_expr(expr: &LuaExpr) -> String {
match expr {
LuaExpr::NameExpr(name_expr) => {
if let Some(token) = name_expr.get_name_token() {
token.get_name_text().to_string()
} else {
String::new()
}
}
LuaExpr::IndexExpr(index_expr) => {
// Recursively build the path like "self.xxx" or "self.a.b"
let prefix = if let Some(prefix_expr) = index_expr.get_prefix_expr() {
extract_name_from_expr(&prefix_expr)
} else {
String::new()
};

let suffix = if let Some(key) = index_expr.get_index_key() {
match key {
emmylua_parser::LuaIndexKey::Name(name_token) => name_token.get_name_text().to_string(),
_ => String::new(),
}
} else {
String::new()
};

if prefix.is_empty() {
suffix
} else if suffix.is_empty() {
prefix
} else {
format!("{}.{}", prefix, suffix)
}
}
_ => String::new(),
}
}

pub fn analyze_overload(analyzer: &mut DocAnalyzer, tag: LuaDocTagOverload) -> Option<()> {
if let Some(decl_id) = analyzer.current_type_id.clone() {
let type_ref = infer_type(analyzer, tag.get_type()?);
Expand Down
131 changes: 131 additions & 0 deletions crates/emmylua_code_analysis/src/compilation/test/flow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1476,4 +1476,135 @@ _2 = a[1]
"#,
);
}

#[test]
fn test_return_cast_self_field() {
let mut ws = VirtualWorkspace::new();

ws.def(
r#"
---@class MyClass
---@field value string|number
local MyClass = {}

---Check if value field is string
---@param self MyClass
---@return_cast self.value string
function MyClass:check_string()
return type(self.value) == "string"
end

---@param obj MyClass
function test(obj)
if obj:check_string() then
a = obj.value
else
b = obj.value
end
end
"#,
);

let a = ws.expr_ty("a");
let a_expected = ws.ty("string");
assert_eq!(a, a_expected);

let b = ws.expr_ty("b");
let b_expected = ws.ty("number");
assert_eq!(b, b_expected);
}

#[test]
fn test_return_cast_self_field_with_fallback() {
let mut ws = VirtualWorkspace::new();

ws.def(
r#"
---@class MyClass
---@field data table|nil
local MyClass = {}

---Check if data exists
---@param self MyClass
---@return_cast self.data table else nil
function MyClass:has_data()
return self.data ~= nil
end

---@param obj MyClass
function test(obj)
if obj:has_data() then
c = obj.data
else
d = obj.data
end
end
"#,
);

let c = ws.expr_ty("c");
let c_str = ws.humanize_type(c);
assert_eq!(c_str, "table");

let d = ws.expr_ty("d");
let d_expected = ws.ty("nil");
assert_eq!(d, d_expected);
}

#[test]
fn test_return_cast_self_field_complex() {
let mut ws = VirtualWorkspace::new();

ws.def(
r#"
---@class Vehicle
---@field type "car"|"bike"|"truck"
---@field engine string|nil
local Vehicle = {}

---@param self Vehicle
---@return_cast self.type "car"
function Vehicle:is_car()
return self.type == "car"
end

---@param self Vehicle
---@return_cast self.engine string else nil
function Vehicle:has_engine()
return self.engine ~= nil
end

---@param v Vehicle
function test(v)
if v:is_car() then
e = v.type
else
f = v.type
end

if v:has_engine() then
g = v.engine
else
h = v.engine
end
end
"#,
);

let e = ws.expr_ty("e");
let e_expected = ws.ty("\"car\"");
assert_eq!(e, e_expected);

let f = ws.expr_ty("f");
let f_expected = ws.ty("\"bike\"|\"truck\"");
assert_eq!(f, f_expected);

let g = ws.expr_ty("g");
let g_expected = ws.ty("string");
assert_eq!(g, g_expected);

let h = ws.expr_ty("h");
let h_expected = ws.ty("nil");
assert_eq!(h, h_expected);
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::collections::HashSet;

use emmylua_parser::{LuaAstNode, LuaClosureExpr, LuaNameExpr};
use emmylua_parser::{
LuaAst, LuaAstNode, LuaClosureExpr, LuaComment, LuaDocTagReturnCast, LuaNameExpr,
};
use rowan::TextRange;

use crate::{DiagnosticCode, LuaSignatureId, SemanticModel};
Expand Down Expand Up @@ -94,6 +96,7 @@ fn check_name_expr(
}

fn check_self_name(semantic_model: &SemanticModel, name_expr: LuaNameExpr) -> Option<()> {
// Check if self is in a method context (regular Lua code)
let closure_expr = name_expr.ancestors::<LuaClosureExpr>();
for closure_expr in closure_expr {
let signature_id =
Expand All @@ -105,6 +108,47 @@ fn check_self_name(semantic_model: &SemanticModel, name_expr: LuaNameExpr) -> Op
if signature.is_method(semantic_model, None) {
return Some(());
}

// Check if self is a parameter of this function (from @param self)
if signature.find_param_idx("self").is_some() {
return Some(());
}
}

// Check if self is in @return_cast tag
// The name_expr might be inside a doc comment, not inside actual Lua code
for ancestor in name_expr.syntax().ancestors() {
if let Some(return_cast_tag) = LuaDocTagReturnCast::cast(ancestor.clone()) {
// Find the LuaComment that contains this tag
for comment_ancestor in return_cast_tag.syntax().ancestors() {
if let Some(comment) = LuaComment::cast(comment_ancestor) {
// Get the owner (function) of this comment
if let Some(owner) = comment.get_owner() {
if let LuaAst::LuaClosureExpr(closure) = owner {
let sig_id = LuaSignatureId::from_closure(
semantic_model.get_file_id(),
&closure,
);
if let Some(sig) =
semantic_model.get_db().get_signature_index().get(&sig_id)
{
// Check if the owner function is a method
if sig.is_method(semantic_model, None) {
return Some(());
}
// Check if self is a parameter
if sig.find_param_idx("self").is_some() {
return Some(());
}
}
}
}
break;
}
}
break;
}
}
Comment on lines +120 to +151

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The nested for loops for traversing ancestors can be made more concise and readable by using ancestors::<T>().next() and if let chains. This would flatten the code and make the intent clearer.

    if let Some(return_cast_tag) = name_expr.ancestors::<LuaDocTagReturnCast>().next() {
        if let Some(comment) = return_cast_tag.ancestors::<LuaComment>().next() {
            if let Some(owner) = comment.get_owner() {
                if let LuaAst::LuaClosureExpr(closure) = owner {
                    let sig_id = LuaSignatureId::from_closure(
                        semantic_model.get_file_id(),
                        &closure,
                    );
                    if let Some(sig) =
                        semantic_model.get_db().get_signature_index().get(&sig_id)
                    {
                        // Check if the owner function is a method or has a 'self' param
                        if sig.is_method(semantic_model, None) || sig.find_param_idx("self").is_some() {
                            return Some(());
                        }
                    }
                }
            }
        }
    }


None
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,41 @@ mod test {
"#
));
}

#[test]
fn test_return_cast_self_no_undefined_global() {
let mut ws = VirtualWorkspace::new();
// @return_cast self should not produce undefined-global error in methods
assert!(!ws.check_code_for(
DiagnosticCode::UndefinedGlobal,
r#"
---@class MyClass
local MyClass = {}

---@return_cast self MyClass
function MyClass:check1()
return true
end
"#
));
}

#[test]
fn test_return_cast_self_field_no_undefined_global() {
let mut ws = VirtualWorkspace::new();
// @return_cast self.field should not produce undefined-global error in methods
assert!(!ws.check_code_for(
DiagnosticCode::UndefinedGlobal,
r#"
---@class MyClass
---@field value string|number
local MyClass = {}

---@return_cast self.value string
function MyClass:check_string()
return type(self.value) == "string"
end
"#
));
}
}
Loading
Loading