diff --git a/crates/squawk_ide/src/goto_definition.rs b/crates/squawk_ide/src/goto_definition.rs index 4d168fdc..f8998992 100644 --- a/crates/squawk_ide/src/goto_definition.rs +++ b/crates/squawk_ide/src/goto_definition.rs @@ -4110,6 +4110,25 @@ select a from y; "); } + #[test] + fn goto_recursive_cte_reference_inside_cte() { + assert_snapshot!(goto(" +with recursive nums as ( + select 1 as n + union all + select n + 1 from nums$0 where n < 5 +) +select * from nums; +"), @r" + ╭▸ + 2 │ with recursive nums as ( + │ ──── 2. destination + ‡ + 5 │ select n + 1 from nums where n < 5 + ╰╴ ─ 1. source + "); + } + #[test] fn goto_cte_with_column_list() { assert_snapshot!(goto(" diff --git a/crates/squawk_ide/src/resolve.rs b/crates/squawk_ide/src/resolve.rs index c971f690..e10a14e6 100644 --- a/crates/squawk_ide/src/resolve.rs +++ b/crates/squawk_ide/src/resolve.rs @@ -1886,14 +1886,16 @@ fn find_column_in_create_table_as( fn resolve_cte_table(name_ref: &ast::NameRef, cte_name: &Name) -> Option { let with_clause = find_parent_with_clause(name_ref.syntax())?; + let is_recursive = with_clause.recursive_token().is_some(); for with_table in with_clause.with_tables() { if let Some(name) = with_table.name() && Name::from_node(&name) == *cte_name { - if with_table - .syntax() - .text_range() - .contains_range(name_ref.syntax().text_range()) + if !is_recursive + && with_table + .syntax() + .text_range() + .contains_range(name_ref.syntax().text_range()) { continue; } @@ -1998,16 +2000,18 @@ fn resolve_cte_column( column_name: &Name, ) -> Option { let with_clause = find_parent_with_clause(name_ref.syntax())?; + let is_recursive = with_clause.recursive_token().is_some(); for with_table in with_clause.with_tables() { if let Some(name) = with_table.name() && Name::from_node(&name) == *cte_name { // Skip if we're inside this CTE's definition (CTE doesn't shadow itself) - if with_table - .syntax() - .text_range() - .contains_range(name_ref.syntax().text_range()) + if !is_recursive + && with_table + .syntax() + .text_range() + .contains_range(name_ref.syntax().text_range()) { continue; } @@ -2746,15 +2750,17 @@ fn count_columns_for_from_item( fn count_columns_for_cte(name_ref: &ast::NameRef, cte_name: &Name) -> Option { let with_clause = find_parent_with_clause(name_ref.syntax())?; + let is_recursive = with_clause.recursive_token().is_some(); for with_table in with_clause.with_tables() { if let Some(name) = with_table.name() && Name::from_node(&name) == *cte_name { - if with_table - .syntax() - .text_range() - .contains_range(name_ref.syntax().text_range()) + if !is_recursive + && with_table + .syntax() + .text_range() + .contains_range(name_ref.syntax().text_range()) { return None; }