diff --git a/crates/ide-assists/src/handlers/add_return_type.rs b/crates/ide-assists/src/handlers/add_return_type.rs index c9022f66d1e2..7934a80bfabb 100644 --- a/crates/ide-assists/src/handlers/add_return_type.rs +++ b/crates/ide-assists/src/handlers/add_return_type.rs @@ -1,3 +1,4 @@ +use either::Either; use hir::HirDisplay; use syntax::{AstNode, SyntaxKind, SyntaxToken, TextRange, TextSize, ast, match_ast}; @@ -133,8 +134,9 @@ fn peel_blocks(mut expr: ast::Expr) -> ast::Expr { } fn extract_tail(ctx: &AssistContext<'_>) -> Option<(FnType, ast::Expr, InsertOrReplace)> { - let (fn_type, tail_expr, return_type_range, action) = - if let Some(closure) = ctx.find_node_at_offset::() { + let node = ctx.find_node_at_offset::>()?; + let (fn_type, tail_expr, return_type_range, action) = match node { + Either::Left(closure) => { let rpipe = closure.param_list()?.syntax().last_token()?; let rpipe_pos = rpipe.text_range().end(); @@ -149,9 +151,8 @@ fn extract_tail(ctx: &AssistContext<'_>) -> Option<(FnType, ast::Expr, InsertOrR let ret_range = TextRange::new(rpipe_pos, body_start); (FnType::Closure { wrap_expr }, tail_expr, ret_range, action) - } else { - let func = ctx.find_node_at_offset::()?; - + } + Either::Right(func) => { let rparen = func.param_list()?.r_paren_token()?; let rparen_pos = rparen.text_range().end(); let action = ret_ty_to_action(func.ret_type(), rparen)?; @@ -163,7 +164,8 @@ fn extract_tail(ctx: &AssistContext<'_>) -> Option<(FnType, ast::Expr, InsertOrR let ret_range_end = stmt_list.l_curly_token()?.text_range().start(); let ret_range = TextRange::new(rparen_pos, ret_range_end); (FnType::Function, tail_expr, ret_range, action) - }; + } + }; let range = ctx.selection_trimmed(); if return_type_range.contains_range(range) { cov_mark::hit!(cursor_in_ret_position); @@ -239,6 +241,24 @@ mod tests { ); } + #[test] + fn infer_return_type_cursor_at_return_type_pos_fn_inside_closure() { + cov_mark::check!(cursor_in_ret_position); + check_assist( + add_return_type, + r#"const _: fn() = || { + fn foo() $0{ + 45 + } +};"#, + r#"const _: fn() = || { + fn foo() -> i32 { + 45 + } +};"#, + ); + } + #[test] fn infer_return_type() { cov_mark::check!(cursor_on_tail);