1+ use either:: Either ;
12use hir:: HirDisplay ;
23use syntax:: { AstNode , SyntaxKind , SyntaxToken , TextRange , TextSize , ast, match_ast} ;
34
@@ -133,8 +134,9 @@ fn peel_blocks(mut expr: ast::Expr) -> ast::Expr {
133134}
134135
135136fn extract_tail ( ctx : & AssistContext < ' _ > ) -> Option < ( FnType , ast:: Expr , InsertOrReplace ) > {
136- let ( fn_type, tail_expr, return_type_range, action) =
137- if let Some ( closure) = ctx. find_node_at_offset :: < ast:: ClosureExpr > ( ) {
137+ let node = ctx. find_node_at_offset :: < Either < ast:: ClosureExpr , ast:: Fn > > ( ) ?;
138+ let ( fn_type, tail_expr, return_type_range, action) = match node {
139+ Either :: Left ( closure) => {
138140 let rpipe = closure. param_list ( ) ?. syntax ( ) . last_token ( ) ?;
139141 let rpipe_pos = rpipe. text_range ( ) . end ( ) ;
140142
@@ -149,9 +151,8 @@ fn extract_tail(ctx: &AssistContext<'_>) -> Option<(FnType, ast::Expr, InsertOrR
149151
150152 let ret_range = TextRange :: new ( rpipe_pos, body_start) ;
151153 ( FnType :: Closure { wrap_expr } , tail_expr, ret_range, action)
152- } else {
153- let func = ctx. find_node_at_offset :: < ast:: Fn > ( ) ?;
154-
154+ }
155+ Either :: Right ( func) => {
155156 let rparen = func. param_list ( ) ?. r_paren_token ( ) ?;
156157 let rparen_pos = rparen. text_range ( ) . end ( ) ;
157158 let action = ret_ty_to_action ( func. ret_type ( ) , rparen) ?;
@@ -163,7 +164,8 @@ fn extract_tail(ctx: &AssistContext<'_>) -> Option<(FnType, ast::Expr, InsertOrR
163164 let ret_range_end = stmt_list. l_curly_token ( ) ?. text_range ( ) . start ( ) ;
164165 let ret_range = TextRange :: new ( rparen_pos, ret_range_end) ;
165166 ( FnType :: Function , tail_expr, ret_range, action)
166- } ;
167+ }
168+ } ;
167169 let range = ctx. selection_trimmed ( ) ;
168170 if return_type_range. contains_range ( range) {
169171 cov_mark:: hit!( cursor_in_ret_position) ;
@@ -239,6 +241,24 @@ mod tests {
239241 ) ;
240242 }
241243
244+ #[ test]
245+ fn infer_return_type_cursor_at_return_type_pos_fn_inside_closure ( ) {
246+ cov_mark:: check!( cursor_in_ret_position) ;
247+ check_assist (
248+ add_return_type,
249+ r#"const _: fn() = || {
250+ fn foo() $0{
251+ 45
252+ }
253+ };"# ,
254+ r#"const _: fn() = || {
255+ fn foo() -> i32 {
256+ 45
257+ }
258+ };"# ,
259+ ) ;
260+ }
261+
242262 #[ test]
243263 fn infer_return_type ( ) {
244264 cov_mark:: check!( cursor_on_tail) ;
0 commit comments