From 0fed51d57ab1071347f39bf1e4d892cc2d3cba1f Mon Sep 17 00:00:00 2001 From: Steve Dignam Date: Mon, 5 Jan 2026 00:21:24 -0500 Subject: [PATCH 1/2] ide: basic returning support --- crates/squawk_ide/src/classify.rs | 12 + crates/squawk_ide/src/goto_definition.rs | 227 ++++++++++++++++++ crates/squawk_ide/src/hover.rs | 31 ++- crates/squawk_ide/src/resolve.rs | 113 ++++++--- .../squawk_syntax/src/ast/generated/nodes.rs | 6 +- crates/squawk_syntax/src/postgresql.ungram | 3 +- 6 files changed, 361 insertions(+), 31 deletions(-) diff --git a/crates/squawk_ide/src/classify.rs b/crates/squawk_ide/src/classify.rs index 6e179078..e98ebd5d 100644 --- a/crates/squawk_ide/src/classify.rs +++ b/crates/squawk_ide/src/classify.rs @@ -49,6 +49,7 @@ pub(crate) enum NameRefClass { InsertColumn, DeleteTable, DeleteWhereColumn, + DeleteUsingTable, UpdateTable, UpdateWhereColumn, UpdateSetColumn, @@ -61,6 +62,7 @@ pub(crate) enum NameRefClass { VacuumTable, AlterTable, AlterTableColumn, + AlterTableDropColumn, RefreshMaterializedView, ReindexTable, ReindexIndex, @@ -82,6 +84,7 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option let mut in_constraint_where_clause = false; let mut in_partition_item = false; let mut in_set_null_columns = false; + let mut in_using_clause = false; // TODO: can we combine this if and the one that follows? if let Some(parent) = name_ref.syntax().parent() @@ -215,6 +218,9 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option if ast::AlterColumn::can_cast(ancestor.kind()) { return Some(NameRefClass::AlterTableColumn); } + if ast::DropColumn::can_cast(ancestor.kind()) { + return Some(NameRefClass::AlterTableDropColumn); + } if ast::AlterTable::can_cast(ancestor.kind()) { return Some(NameRefClass::AlterTable); } @@ -457,10 +463,16 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option if ast::SetClause::can_cast(ancestor.kind()) { in_set_clause = true; } + if ast::UsingClause::can_cast(ancestor.kind()) { + in_using_clause = true; + } if ast::Delete::can_cast(ancestor.kind()) { if in_where_clause { return Some(NameRefClass::DeleteWhereColumn); } + if in_using_clause { + return Some(NameRefClass::DeleteUsingTable); + } return Some(NameRefClass::DeleteTable); } if ast::Update::can_cast(ancestor.kind()) { diff --git a/crates/squawk_ide/src/goto_definition.rs b/crates/squawk_ide/src/goto_definition.rs index 29ff9f79..e44285a0 100644 --- a/crates/squawk_ide/src/goto_definition.rs +++ b/crates/squawk_ide/src/goto_definition.rs @@ -2809,6 +2809,77 @@ select x$0 from t; "); } + #[test] + fn goto_cte_insert_returning_star_column() { + assert_snapshot!(goto(" +create table t(a int, b int); +with inserted as ( + insert into t values (1, 2), (3, 4) + returning * +) +select a$0 from inserted; +"), @r" + ╭▸ + 2 │ create table t(a int, b int); + │ ─ 2. destination + ‡ + 7 │ select a from inserted; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_cte_delete_returning_star_column() { + assert_snapshot!(goto(" +create table t(a int, b int); +with deleted as ( + delete from t + returning * +) +select a$0 from deleted; +"), @r" + ╭▸ + 2 │ create table t(a int, b int); + │ ─ 2. destination + ‡ + 7 │ select a from deleted; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_cte_update_returning_star_column() { + assert_snapshot!(goto(" +create table t(a int, b int); +with updated as ( + update t set a = 42 + returning * +) +select a$0 from updated; +"), @r" + ╭▸ + 2 │ create table t(a int, b int); + │ ─ 2. destination + ‡ + 7 │ select a from updated; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_cte_update_returning_column_list_overwrites_column() { + goto_not_found( + " +create table t(a int, b int); +with updated(c) as ( + update t set a = 10 + returning a +) +select a$0 from updated; +", + ); + } + #[test] fn goto_cte_column_list_overwrites_column() { goto_not_found( @@ -3146,6 +3217,52 @@ delete from users where id$0 = 1 and active = true; "); } + #[test] + fn goto_delete_using_table() { + assert_snapshot!(goto(" +create table t(id int, f_id int); +create table f(id int, name text); +delete from t using f$0 where f_id = f.id and f.name = 'foo'; +"), @r" + ╭▸ + 3 │ create table f(id int, name text); + │ ─ 2. destination + 4 │ delete from t using f where f_id = f.id and f.name = 'foo'; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_delete_using_table_with_schema() { + assert_snapshot!(goto(" +create table t(id int, f_id int); +create table public.f(id int, name text); +delete from t using public.f$0 where f_id = f.id; +"), @r" + ╭▸ + 3 │ create table public.f(id int, name text); + │ ─ 2. destination + 4 │ delete from t using public.f where f_id = f.id; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_delete_using_column_in_where() { + assert_snapshot!(goto(" +create table t(id int, f_id int); +create table f(id int, name text); +delete from t using f where f_id = f.id$0 and f.name = 'foo'; +"), @r" + ╭▸ + 2 │ create table t(id int, f_id int); + │ ── 2. destination + 3 │ create table f(id int, name text); + 4 │ delete from t using f where f_id = f.id and f.name = 'foo'; + ╰╴ ─ 1. source + "); + } + #[test] fn goto_select_from_table() { assert_snapshot!(goto(" @@ -3741,6 +3858,44 @@ select a$0 from t; "); } + #[test] + fn goto_cte_insert_returning_column() { + assert_snapshot!(goto(" +create table t(a int, b int); +with inserted as ( + insert into t values (1, 2), (3, 4) + returning a, b +) +select a$0 from inserted; +"), @r" + ╭▸ + 5 │ returning a, b + │ ─ 2. destination + 6 │ ) + 7 │ select a from inserted; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_cte_insert_returning_aliased_column() { + assert_snapshot!(goto(" +create table t(a int, b int); +with inserted as ( + insert into t values (1, 2), (3, 4) + returning a as x +) +select x$0 from inserted; +"), @r" + ╭▸ + 5 │ returning a as x + │ ─ 2. destination + 6 │ ) + 7 │ select x from inserted; + ╰╴ ─ 1. source + "); + } + #[test] fn goto_drop_aggregate() { assert_snapshot!(goto(" @@ -4938,6 +5093,48 @@ alter table users alter column email$0 set not null; "); } + #[test] + fn goto_alter_table_add_column() { + assert_snapshot!(goto(" +create table users(id int); +alter table users$0 add column email text; +"), @r" + ╭▸ + 2 │ create table users(id int); + │ ───── 2. destination + 3 │ alter table users add column email text; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_alter_table_drop_column() { + assert_snapshot!(goto(" +create table users(id int, email text); +alter table users drop column email$0; +"), @r" + ╭▸ + 2 │ create table users(id int, email text); + │ ───── 2. destination + 3 │ alter table users drop column email; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_alter_table_drop_column_table_name() { + assert_snapshot!(goto(" +create table users(id int, email text); +alter table users$0 drop column email; +"), @r" + ╭▸ + 2 │ create table users(id int, email text); + │ ───── 2. destination + 3 │ alter table users drop column email; + ╰╴ ─ 1. source + "); + } + #[test] fn goto_refresh_materialized_view() { assert_snapshot!(goto(" @@ -5051,4 +5248,34 @@ reindex system systemdb$0; ╰╴ ─ 1. source "); } + + #[test] + fn goto_merge_returning_aliased_column() { + assert_snapshot!(goto( + r#" +create table t(a int, b int); +with u(x, y) as ( + select 1, 2 +), +merged as ( + merge into t + using u + on t.a = u.x + when matched then + do nothing + when not matched then + do nothing + returning a as x, b as y +) +select x$0 from merged; +"#, + ), @r" + ╭▸ + 14 │ returning a as x, b as y + │ ─ 2. destination + 15 │ ) + 16 │ select x from merged; + ╰╴ ─ 1. source + "); + } } diff --git a/crates/squawk_ide/src/hover.rs b/crates/squawk_ide/src/hover.rs index e53e3dfe..0f4c80cb 100644 --- a/crates/squawk_ide/src/hover.rs +++ b/crates/squawk_ide/src/hover.rs @@ -58,7 +58,8 @@ pub fn hover(file: &ast::SourceFile, offset: TextSize) -> Option { | NameRefClass::ForeignKeyColumn | NameRefClass::ForeignKeyLocalColumn | NameRefClass::SequenceOwnedByColumn - | NameRefClass::AlterTableColumn => { + | NameRefClass::AlterTableColumn + | NameRefClass::AlterTableDropColumn => { return hover_column(root, &name_ref, &binder); } NameRefClass::TypeReference | NameRefClass::DropType => { @@ -86,6 +87,7 @@ pub fn hover(file: &ast::SourceFile, offset: TextSize) -> Option { | NameRefClass::CreateIndex | NameRefClass::InsertTable | NameRefClass::DeleteTable + | NameRefClass::DeleteUsingTable | NameRefClass::UpdateTable | NameRefClass::SelectFromTable | NameRefClass::UpdateFromTable @@ -3378,4 +3380,31 @@ reindex index idx$0; ╰╴ ─ hover "); } + + #[test] + fn hover_merge_returning_star_from_cte() { + assert_snapshot!(check_hover(" +create table t(a int, b int); +with u(x, y) as ( + select 1, 2 +), +merged as ( + merge into t + using u + on t.a = u.x + when matched then + do nothing + when not matched then + do nothing + returning a as x, b as y +) +select *$0 from merged; +"), @r" + hover: column merged.x + column merged.y + ╭▸ + 16 │ select * from merged; + ╰╴ ─ hover + "); + } } diff --git a/crates/squawk_ide/src/resolve.rs b/crates/squawk_ide/src/resolve.rs index f909ddcd..98102358 100644 --- a/crates/squawk_ide/src/resolve.rs +++ b/crates/squawk_ide/src/resolve.rs @@ -38,7 +38,9 @@ pub(crate) fn resolve_name_ref( let position = name_ref.syntax().text_range().start(); resolve_table(binder, &table_name, &schema, position).map(|ptr| smallvec![ptr]) } - NameRefClass::SelectFromTable => { + NameRefClass::SelectFromTable + | NameRefClass::UpdateFromTable + | NameRefClass::DeleteUsingTable => { let table_name = Name::from_node(name_ref); let schema = if let Some(parent) = name_ref.syntax().parent() && let Some(field_expr) = ast::FieldExpr::cast(parent) @@ -359,33 +361,7 @@ pub(crate) fn resolve_name_ref( resolve_update_where_column(binder, root, name_ref).map(|ptr| smallvec![ptr]) } NameRefClass::JoinUsingColumn => resolve_join_using_columns(binder, root, name_ref), - NameRefClass::UpdateFromTable => { - let table_name = Name::from_node(name_ref); - let schema = if let Some(parent) = name_ref.syntax().parent() - && let Some(field_expr) = ast::FieldExpr::cast(parent) - && let Some(base) = field_expr.base() - && let Some(schema_name_ref) = ast::NameRef::cast(base.syntax().clone()) - { - Some(Schema(Name::from_node(&schema_name_ref))) - } else { - None - }; - - if schema.is_none() - && let Some(cte_ptr) = resolve_cte_table(name_ref, &table_name) - { - return Some(smallvec![cte_ptr]); - } - - let position = name_ref.syntax().text_range().start(); - - if let Some(ptr) = resolve_table(binder, &table_name, &schema, position) { - return Some(smallvec![ptr]); - } - - resolve_view(binder, &table_name, &schema, position).map(|ptr| smallvec![ptr]) - } - NameRefClass::AlterTableColumn => { + NameRefClass::AlterTableColumn | NameRefClass::AlterTableDropColumn => { let column_name = Name::from_node(name_ref); let alter_table = name_ref .syntax() @@ -1436,6 +1412,9 @@ fn resolve_cte_column( let query = with_table.query()?; if let ast::WithQuery::Values(values) = query { + if column_list_len > 0 { + continue; + } if let Some(num_str) = column_name.0.strip_prefix("column") && let Ok(col_num) = num_str.parse::() && col_num > 0 @@ -1448,6 +1427,12 @@ fn resolve_cte_column( continue; } + if column_list_len == 0 + && let Some(column) = column_in_with_query(&query, binder, root, column_name) + { + return Some(column); + } + let Some(cte_select) = select_from_with_query(query) else { continue; }; @@ -1517,6 +1502,53 @@ fn resolve_cte_column( None } +fn column_in_with_query( + query: &ast::WithQuery, + binder: &Binder, + root: &SyntaxNode, + column_name: &Name, +) -> Option { + let (returning_clause, path) = match query { + ast::WithQuery::Delete(delete) => ( + delete.returning_clause(), + delete.relation_name().and_then(|r| r.path()), + ), + ast::WithQuery::Insert(insert) => (insert.returning_clause(), insert.path()), + ast::WithQuery::Merge(merge) => ( + merge.returning_clause(), + merge.relation_name().and_then(|r| r.path()), + ), + ast::WithQuery::Update(update) => ( + update.returning_clause(), + update.relation_name().and_then(|r| r.path()), + ), + ast::WithQuery::Select(_) + | ast::WithQuery::CompoundSelect(_) + | ast::WithQuery::Table(_) + | ast::WithQuery::Values(_) + | ast::WithQuery::ParenSelect(_) => return None, + }; + + let target_list = returning_clause?.target_list()?; + let path = path?; + for target in target_list.targets() { + if let Some((col_name, node)) = ColumnName::from_target(target) { + if let Some(col_name_str) = col_name.to_string() + && Name::from_string(col_name_str) == *column_name + { + return Some(SyntaxNodePtr::new(&node)); + } + if matches!(col_name, ColumnName::Star) + && let Some(ptr) = resolve_column_for_path(binder, root, &path, column_name.clone()) + { + return Some(ptr); + } + } + } + + None +} + fn resolve_subquery_column( binder: &Binder, root: &SyntaxNode, @@ -2190,6 +2222,10 @@ pub(crate) fn collect_with_table_column_names(with_table: &ast::WithTable) -> Ve return results; } + if let Some(columns) = columns_from_returning_clause(&query) { + return columns; + } + let Some(cte_select) = select_from_with_query(query) else { return vec![]; }; @@ -2203,6 +2239,27 @@ pub(crate) fn collect_with_table_column_names(with_table: &ast::WithTable) -> Ve collect_target_list_column_names(&target_list) } +fn columns_from_returning_clause(query: &ast::WithQuery) -> Option> { + let returning_clause = match query { + ast::WithQuery::Delete(delete) => delete.returning_clause(), + ast::WithQuery::Insert(insert) => insert.returning_clause(), + ast::WithQuery::Merge(merge) => merge.returning_clause(), + ast::WithQuery::Update(update) => update.returning_clause(), + ast::WithQuery::Select(_) + | ast::WithQuery::CompoundSelect(_) + | ast::WithQuery::Table(_) + | ast::WithQuery::Values(_) + | ast::WithQuery::ParenSelect(_) => None, + }; + if let Some(returning_clause) = returning_clause { + if let Some(target_list) = returning_clause.target_list() { + return Some(collect_target_list_column_names(&target_list)); + } + return Some(vec![]); + } + None +} + fn resolve_symbol_info( binder: &Binder, path: &ast::Path, diff --git a/crates/squawk_syntax/src/ast/generated/nodes.rs b/crates/squawk_syntax/src/ast/generated/nodes.rs index 93c8f78b..9f6e5a3c 100644 --- a/crates/squawk_syntax/src/ast/generated/nodes.rs +++ b/crates/squawk_syntax/src/ast/generated/nodes.rs @@ -9395,7 +9395,7 @@ impl JsonObjectAggFn { support::child(&self.syntax) } #[inline] - pub fn returning_clause(&self) -> Option { + pub fn json_returning_clause(&self) -> Option { support::child(&self.syntax) } #[inline] @@ -10211,6 +10211,10 @@ impl Merge { support::child(&self.syntax) } #[inline] + pub fn returning_clause(&self) -> Option { + support::child(&self.syntax) + } + #[inline] pub fn using_on_clause(&self) -> Option { support::child(&self.syntax) } diff --git a/crates/squawk_syntax/src/postgresql.ungram b/crates/squawk_syntax/src/postgresql.ungram index 090de625..b2e81977 100644 --- a/crates/squawk_syntax/src/postgresql.ungram +++ b/crates/squawk_syntax/src/postgresql.ungram @@ -157,7 +157,7 @@ JsonObjectAggFn = JsonKeyValue? JsonNullClause? JsonKeysUniqueClause? - ReturningClause? + JsonReturningClause? ')' JsonArrayAggFn = @@ -1406,6 +1406,7 @@ Merge = 'merge' 'into' RelationName Alias? UsingOnClause MergeWhenClause* + ReturningClause? Declare = 'declare' Name From 8ec07a8f5421ca07ca4eafd0f3cf439487ab9327 Mon Sep 17 00:00:00 2001 From: Steve Dignam Date: Mon, 5 Jan 2026 23:51:46 -0500 Subject: [PATCH 2/2] ide: add merge & returning clause support --- crates/squawk_ide/src/classify.rs | 144 ++++ crates/squawk_ide/src/goto_definition.rs | 677 ++++++++++++++++++- crates/squawk_ide/src/hover.rs | 16 + crates/squawk_ide/src/resolve.rs | 796 +++++++++++++++++------ 4 files changed, 1442 insertions(+), 191 deletions(-) diff --git a/crates/squawk_ide/src/classify.rs b/crates/squawk_ide/src/classify.rs index e98ebd5d..4e9c7202 100644 --- a/crates/squawk_ide/src/classify.rs +++ b/crates/squawk_ide/src/classify.rs @@ -47,13 +47,29 @@ pub(crate) enum NameRefClass { CompositeTypeField, InsertTable, InsertColumn, + InsertQualifiedColumnTable, DeleteTable, DeleteWhereColumn, + DeleteQualifiedColumnTable, DeleteUsingTable, UpdateTable, UpdateWhereColumn, UpdateSetColumn, UpdateFromTable, + UpdateSetQualifiedColumnTable, + UpdateReturningColumn, + InsertReturningColumn, + DeleteReturningColumn, + MergeReturningColumn, + MergeWhenColumn, + MergeOnColumn, + MergeQualifiedColumnTable, + MergeUsingTable, + MergeTable, + UpdateReturningQualifiedColumnTable, + InsertReturningQualifiedColumnTable, + DeleteReturningQualifiedColumnTable, + MergeReturningQualifiedColumnTable, JoinUsingColumn, SchemaQualifier, TypeReference, @@ -85,6 +101,8 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option let mut in_partition_item = false; let mut in_set_null_columns = false; let mut in_using_clause = false; + let mut in_returning_clause = false; + let mut in_when_clause = false; // TODO: can we combine this if and the one that follows? if let Some(parent) = name_ref.syntax().parent() @@ -108,6 +126,10 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option let mut in_from_clause = false; let mut in_on_clause = false; + let mut in_returning_clause = false; + let mut in_set_clause = false; + let mut in_where_clause = false; + let mut in_when_clause = false; for ancestor in parent.ancestors() { if ast::OnClause::can_cast(ancestor.kind()) { in_on_clause = true; @@ -115,6 +137,78 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option if ast::FromClause::can_cast(ancestor.kind()) { in_from_clause = true; } + if ast::ReturningClause::can_cast(ancestor.kind()) { + in_returning_clause = true; + } + if ast::SetClause::can_cast(ancestor.kind()) { + in_set_clause = true; + } + if ast::WhereClause::can_cast(ancestor.kind()) { + in_where_clause = true; + } + if ast::MergeWhenMatched::can_cast(ancestor.kind()) { + in_when_clause = true; + } + if ast::Update::can_cast(ancestor.kind()) { + if in_returning_clause { + if is_function_call || is_schema_table_col { + return Some(NameRefClass::SchemaQualifier); + } else { + return Some(NameRefClass::UpdateReturningQualifiedColumnTable); + } + } else if in_set_clause || in_where_clause || in_from_clause { + if is_function_call || is_schema_table_col { + return Some(NameRefClass::SchemaQualifier); + } else { + return Some(NameRefClass::UpdateSetQualifiedColumnTable); + } + } + } + if ast::Insert::can_cast(ancestor.kind()) { + if in_returning_clause { + if is_function_call || is_schema_table_col { + return Some(NameRefClass::SchemaQualifier); + } else { + return Some(NameRefClass::InsertReturningQualifiedColumnTable); + } + } else if !in_from_clause && !in_on_clause { + if is_function_call || is_schema_table_col { + return Some(NameRefClass::SchemaQualifier); + } else { + return Some(NameRefClass::InsertQualifiedColumnTable); + } + } + } + if ast::Delete::can_cast(ancestor.kind()) { + if in_returning_clause { + if is_function_call || is_schema_table_col { + return Some(NameRefClass::SchemaQualifier); + } else { + return Some(NameRefClass::DeleteReturningQualifiedColumnTable); + } + } else if in_where_clause || in_using_clause { + if is_function_call || is_schema_table_col { + return Some(NameRefClass::SchemaQualifier); + } else { + return Some(NameRefClass::DeleteQualifiedColumnTable); + } + } + } + if ast::Merge::can_cast(ancestor.kind()) { + if in_returning_clause { + if is_function_call || is_schema_table_col { + return Some(NameRefClass::SchemaQualifier); + } else { + return Some(NameRefClass::MergeReturningQualifiedColumnTable); + } + } else if in_on_clause || in_when_clause { + if is_function_call || is_schema_table_col { + return Some(NameRefClass::SchemaQualifier); + } else { + return Some(NameRefClass::MergeQualifiedColumnTable); + } + } + } if ast::Select::can_cast(ancestor.kind()) && (!in_from_clause || in_on_clause) { if is_function_call || is_schema_table_col { return Some(NameRefClass::SchemaQualifier); @@ -150,6 +244,8 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option let mut in_from_clause = false; let mut in_on_clause = false; let mut in_cast_expr = false; + let mut in_when_clause = false; + let mut in_returning_clause = false; for ancestor in parent.ancestors() { if ast::OnClause::can_cast(ancestor.kind()) { in_on_clause = true; @@ -160,6 +256,21 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option if ast::FromClause::can_cast(ancestor.kind()) { in_from_clause = true; } + if ast::MergeWhenMatched::can_cast(ancestor.kind()) { + in_when_clause = true; + } + if ast::ReturningClause::can_cast(ancestor.kind()) { + in_returning_clause = true; + } + if ast::Merge::can_cast(ancestor.kind()) + && (in_on_clause || in_when_clause || in_returning_clause) + { + if let Some(base) = field_expr.base() + && matches!(base, ast::Expr::NameRef(_) | ast::Expr::FieldExpr(_)) + { + return Some(NameRefClass::SelectQualifiedColumn); + } + } if ast::Select::can_cast(ancestor.kind()) && (!in_from_clause || in_on_clause) { if in_cast_expr { return Some(NameRefClass::TypeReference); @@ -449,6 +560,9 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option in_partition_item = true; } if ast::Insert::can_cast(ancestor.kind()) { + if in_returning_clause { + return Some(NameRefClass::InsertReturningColumn); + } if in_column_list { return Some(NameRefClass::InsertColumn); } @@ -466,7 +580,16 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option if ast::UsingClause::can_cast(ancestor.kind()) { in_using_clause = true; } + if ast::UsingOnClause::can_cast(ancestor.kind()) { + in_using_clause = true; + } + if ast::ReturningClause::can_cast(ancestor.kind()) { + in_returning_clause = true; + } if ast::Delete::can_cast(ancestor.kind()) { + if in_returning_clause { + return Some(NameRefClass::DeleteReturningColumn); + } if in_where_clause { return Some(NameRefClass::DeleteWhereColumn); } @@ -476,6 +599,9 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option return Some(NameRefClass::DeleteTable); } if ast::Update::can_cast(ancestor.kind()) { + if in_returning_clause { + return Some(NameRefClass::UpdateReturningColumn); + } if in_where_clause { return Some(NameRefClass::UpdateWhereColumn); } @@ -487,6 +613,24 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option } return Some(NameRefClass::UpdateTable); } + if ast::MergeWhenMatched::can_cast(ancestor.kind()) { + in_when_clause = true; + } + if ast::Merge::can_cast(ancestor.kind()) { + if in_returning_clause { + return Some(NameRefClass::MergeReturningColumn); + } + if in_when_clause { + return Some(NameRefClass::MergeWhenColumn); + } + if in_on_clause { + return Some(NameRefClass::MergeOnColumn); + } + if in_using_clause { + return Some(NameRefClass::MergeUsingTable); + } + return Some(NameRefClass::MergeTable); + } } None diff --git a/crates/squawk_ide/src/goto_definition.rs b/crates/squawk_ide/src/goto_definition.rs index e44285a0..9e70589b 100644 --- a/crates/squawk_ide/src/goto_definition.rs +++ b/crates/squawk_ide/src/goto_definition.rs @@ -5252,7 +5252,7 @@ reindex system systemdb$0; #[test] fn goto_merge_returning_aliased_column() { assert_snapshot!(goto( - r#" + " create table t(a int, b int); with u(x, y) as ( select 1, 2 @@ -5268,7 +5268,7 @@ merged as ( returning a as x, b as y ) select x$0 from merged; -"#, +", ), @r" ╭▸ 14 │ returning a as x, b as y @@ -5278,4 +5278,677 @@ select x$0 from merged; ╰╴ ─ 1. source "); } + + #[test] + fn goto_cte_update_returning_column() { + assert_snapshot!(goto(" +create table t(a int, b int); +with updated(c) as ( + update t set a = 10 + returning a, b +) +select c, b$0 from updated;" + ), @r" + ╭▸ + 5 │ returning a, b + │ ─ 2. destination + 6 │ ) + 7 │ select c, b from updated; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_update_returning_column_to_table_def() { + assert_snapshot!(goto(" +create table t(a int, b int); +with updated(c) as ( + update t set a = 10 + returning a, b$0 +) +select c, b from updated;" + ), @r" + ╭▸ + 2 │ create table t(a int, b int); + │ ─ 2. destination + ‡ + 5 │ returning a, b + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_insert_returning_column_to_table_def() { + assert_snapshot!(goto(" +create table t(a int, b int); +with inserted as ( + insert into t values (1, 2) + returning a$0, b +) +select a, b from inserted;" + ), @r" + ╭▸ + 2 │ create table t(a int, b int); + │ ─ 2. destination + ‡ + 5 │ returning a, b + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_delete_returning_column_to_table_def() { + assert_snapshot!(goto(" +create table t(a int, b int); +with deleted as ( + delete from t + returning a, b$0 +) +select a, b from deleted;" + ), @r" + ╭▸ + 2 │ create table t(a int, b int); + │ ─ 2. destination + ‡ + 5 │ returning a, b + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_update_returning_qualified_star_table() { + assert_snapshot!(goto(" +create table t(a int, b int); +update t set a = 10 +returning t$0.*;" + ), @r" + ╭▸ + 2 │ create table t(a int, b int); + │ ─ 2. destination + 3 │ update t set a = 10 + 4 │ returning t.*; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_insert_returning_qualified_star_table() { + assert_snapshot!(goto(" +create table t(a int, b int); +insert into t values (1, 2) +returning t$0.*;" + ), @r" + ╭▸ + 2 │ create table t(a int, b int); + │ ─ 2. destination + 3 │ insert into t values (1, 2) + 4 │ returning t.*; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_delete_returning_qualified_star_table() { + assert_snapshot!(goto(" +create table t(a int, b int); +delete from t +returning t$0.*;" + ), @r" + ╭▸ + 2 │ create table t(a int, b int); + │ ─ 2. destination + 3 │ delete from t + 4 │ returning t.*; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_update_alias_in_set_clause() { + assert_snapshot!(goto(" +create table t(a int, b int); +update t as f set f$0.a = 10;" + ), @r" + ╭▸ + 3 │ update t as f set f.a = 10; + │ ┬ ─ 1. source + │ │ + ╰╴ 2. destination + "); + } + + #[test] + fn goto_update_alias_in_where_clause() { + assert_snapshot!(goto(" +create table t(a int, b int); +update t as f set a = 10 where f$0.b = 5;" + ), @r" + ╭▸ + 3 │ update t as f set a = 10 where f.b = 5; + ╰╴ ─ 2. destination ─ 1. source + "); + } + + #[test] + fn goto_update_alias_in_from_clause() { + assert_snapshot!(goto(" +create table t(a int, b int); +create table u(c int); +update t as f set a = 10 from u where f$0.b = u.c;" + ), @r" + ╭▸ + 4 │ update t as f set a = 10 from u where f.b = u.c; + ╰╴ ─ 2. destination ─ 1. source + "); + } + + #[test] + fn goto_insert_alias_in_on_conflict() { + assert_snapshot!(goto(" +create table t(a int primary key, b int); +insert into t as f values (1, 2) on conflict (f$0.a) do nothing;" + ), @r" + ╭▸ + 3 │ insert into t as f values (1, 2) on conflict (f.a) do nothing; + ╰╴ ─ 2. destination ─ 1. source + "); + } + + #[test] + fn goto_insert_alias_in_returning() { + assert_snapshot!(goto(" +create table t(a int, b int); +insert into t as f values (1, 2) returning f$0.a;" + ), @r" + ╭▸ + 3 │ insert into t as f values (1, 2) returning f.a; + ╰╴ ─ 2. destination ─ 1. source + "); + } + + #[test] + fn goto_insert_alias_returning_column() { + assert_snapshot!(goto(" +create table t(a int, b int); +insert into t as f values (1, 2) returning f.a$0;" + ), @r" + ╭▸ + 2 │ create table t(a int, b int); + │ ─ 2. destination + 3 │ insert into t as f values (1, 2) returning f.a; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_delete_from_alias() { + assert_snapshot!(goto(" +create table t(a int, b int); +delete from t as f where f$0.a = 10;" + ), @r" + ╭▸ + 3 │ delete from t as f where f.a = 10; + │ ┬ ─ 1. source + │ │ + ╰╴ 2. destination + "); + } + + #[test] + fn goto_delete_from_alias_column() { + assert_snapshot!(goto(" +create table t(a int, b int); +delete from t as f where f.a$0 = 10;" + ), @r" + ╭▸ + 2 │ create table t(a int, b int); + │ ─ 2. destination + 3 │ delete from t as f where f.a = 10; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_delete_from_alias_returning() { + assert_snapshot!(goto(" +create table t(a int, b int); +delete from t as f returning f$0.a" + ), @r" + ╭▸ + 3 │ delete from t as f returning f.a + │ ┬ ─ 1. source + │ │ + ╰╴ 2. destination + "); + + assert_snapshot!(goto(" +create table t(a int, b int); +delete from t as f returning f.a$0" + ), @r" + ╭▸ + 2 │ create table t(a int, b int); + │ ─ 2. destination + 3 │ delete from t as f returning f.a + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_merge_alias_on_table() { + assert_snapshot!(goto(" +create table t(a int, b int); +create table u(a int, b int); +merge into t as f + using u on u.a = f$0.a + when matched then do nothing; +" + + ), @r" + ╭▸ + 4 │ merge into t as f + │ ─ 2. destination + 5 │ using u on u.a = f.a + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_merge_alias_on_column() { + assert_snapshot!(goto(" +create table t(a int, b int); +create table u(a int, b int); +merge into t as f + using u on u.a = f.a$0 + when matched then do nothing; +" + + ), @r" + ╭▸ + 2 │ create table t(a int, b int); + │ ─ 2. destination + ‡ + 5 │ using u on u.a = f.a + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_merge_alias_returning() { + assert_snapshot!(goto(" +create table t(a int, b int); +create table u(a int, b int); +merge into t as f + using u on u.a = f.a + when matched then do nothing + returning f$0.a; +" + + ), @r" + ╭▸ + 4 │ merge into t as f + │ ─ 2. destination + ‡ + 7 │ returning f.a; + ╰╴ ─ 1. source + "); + + assert_snapshot!(goto(" +create table t(a int, b int); +create table u(a int, b int); +merge into t as f + using u on u.a = f.a + when matched then do nothing + returning f.a$0; +" + ), @r" + ╭▸ + 2 │ create table t(a int, b int); + │ ─ 2. destination + ‡ + 7 │ returning f.a; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_merge_using_table_in_when_clause() { + assert_snapshot!(goto(" +create table t(a int, b int); +create table u(a int, b int); +merge into t + using u on true + when matched and u$0.a = t.a + then do nothing; +" + ), @r" + ╭▸ + 3 │ create table u(a int, b int); + │ ─ 2. destination + ‡ + 6 │ when matched and u.a = t.a + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_merge_using_table_column_in_when_clause() { + assert_snapshot!(goto(" +create table t(a int, b int); +create table u(a int, b int); +merge into t + using u on true + when matched and u.a$0 = t.a + then do nothing; +" + ), @r" + ╭▸ + 3 │ create table u(a int, b int); + │ ─ 2. destination + ‡ + 6 │ when matched and u.a = t.a + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_merge_unqualified_column_target_table() { + assert_snapshot!(goto(" +create table x(a int, b int); +create table y(c int, d int); +merge into x + using y + on true + when matched and a$0 = c + then do nothing; +" + ), @r" + ╭▸ + 2 │ create table x(a int, b int); + │ ─ 2. destination + ‡ + 7 │ when matched and a = c + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_merge_unqualified_column_source_table() { + assert_snapshot!(goto(" +create table x(a int, b int); +create table y(c int, d int); +merge into x + using y + on true + when matched and a = c$0 + then do nothing; +" + ), @r" + ╭▸ + 3 │ create table y(c int, d int); + │ ─ 2. destination + ‡ + 7 │ when matched and a = c + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_merge_into_table() { + assert_snapshot!(goto(" +create table x(a int, b int); +create table y(c int, d int); +merge into x$0 + using y + on true + when matched and a = c + then do nothing; +" + ), @r" + ╭▸ + 2 │ create table x(a int, b int); + │ ─ 2. destination + 3 │ create table y(c int, d int); + 4 │ merge into x + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_merge_using_clause_table() { + assert_snapshot!(goto(" +create table x(a int, b int); +create table y(c int, d int); +merge into x + using y$0 + on true + when matched and a = c + then do nothing; +" + ), @r" + ╭▸ + 3 │ create table y(c int, d int); + │ ─ 2. destination + 4 │ merge into x + 5 │ using y + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_merge_using_clause_alias() { + assert_snapshot!(goto(" +create table x(a int, b int); +create table y(c int, d int); +merge into x as g + using y as k + on true + when matched and a = k.c$0 + then do nothing; +" + ), @r" + ╭▸ + 3 │ create table y(c int, d int); + │ ─ 2. destination + ‡ + 7 │ when matched and a = k.c + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_merge_on_clause_unqualified_source_column() { + assert_snapshot!(goto(" +create table x(a int, b int); +create table y(c int, d int); +merge into x as g + using y as k + on g.a = c$0 and a = c + when matched and g.a = k.c + then do nothing; +" + ), @r" + ╭▸ + 3 │ create table y(c int, d int); + │ ─ 2. destination + ‡ + 6 │ on g.a = c and a = c + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_merge_returning_old_table() { + assert_snapshot!(goto(" +create table x(a int, b int); +create table y(c int, d int); +merge into x as g + using y as k + on g.a = c and a = k.c + when matched and g.a = k.c + then do nothing + returning old$0.a, new.a; +" + ), @r" + ╭▸ + 2 │ create table x(a int, b int); + │ ─ 2. destination + ‡ + 9 │ returning old.a, new.a; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_merge_returning_old_column() { + assert_snapshot!(goto(" +create table x(a int, b int); +create table y(c int, d int); +merge into x as g + using y as k + on g.a = c and a = k.c + when matched and g.a = k.c + then do nothing + returning old.a$0, new.a; +" + ), @r" + ╭▸ + 2 │ create table x(a int, b int); + │ ─ 2. destination + ‡ + 9 │ returning old.a, new.a; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_merge_returning_new_table() { + assert_snapshot!(goto(" +create table x(a int, b int); +create table y(c int, d int); +merge into x as g + using y as k + on g.a = c and a = k.c + when matched and g.a = k.c + then do nothing + returning old.a, new$0.a; +" + ), @r" + ╭▸ + 2 │ create table x(a int, b int); + │ ─ 2. destination + ‡ + 9 │ returning old.a, new.a; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_merge_returning_new_column() { + assert_snapshot!(goto(" +create table x(a int, b int); +create table y(c int, d int); +merge into x as g + using y as k + on g.a = c and a = k.c + when matched and g.a = k.c + then do nothing + returning old.a, new.a$0; +" + ), @r" + ╭▸ + 2 │ create table x(a int, b int); + │ ─ 2. destination + ‡ + 9 │ returning old.a, new.a; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_merge_with_tables_named_old_new_old_table() { + assert_snapshot!(goto(" +create table old(a int, b int); +create table new(c int, d int); +merge into old + using new + on true + when matched + then do nothing + returning old$0.a, new.d; +" + ), @r" + ╭▸ + 2 │ create table old(a int, b int); + │ ─── 2. destination + ‡ + 9 │ returning old.a, new.d; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_merge_with_tables_named_old_new_old_column() { + assert_snapshot!(goto(" +create table old(a int, b int); +create table new(c int, d int); +merge into old + using new + on true + when matched + then do nothing + returning old.a$0, new.d; +" + ), @r" + ╭▸ + 2 │ create table old(a int, b int); + │ ─ 2. destination + ‡ + 9 │ returning old.a, new.d; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_merge_with_tables_named_old_new_new_table() { + assert_snapshot!(goto(" +create table old(a int, b int); +create table new(c int, d int); +merge into old + using new + on true + when matched + then do nothing + returning old.a, new$0.d; +" + ), @r" + ╭▸ + 3 │ create table new(c int, d int); + │ ─── 2. destination + ‡ + 9 │ returning old.a, new.d; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_merge_with_tables_named_old_new_new_column() { + assert_snapshot!(goto(" +create table old(a int, b int); +create table new(c int, d int); +merge into old + using new + on true + when matched + then do nothing + returning old.a, new.d$0; +" + ), @r" + ╭▸ + 3 │ create table new(c int, d int); + │ ─ 2. destination + ‡ + 9 │ returning old.a, new.d; + ╰╴ ─ 1. source + "); + } } diff --git a/crates/squawk_ide/src/hover.rs b/crates/squawk_ide/src/hover.rs index 0f4c80cb..f90a554b 100644 --- a/crates/squawk_ide/src/hover.rs +++ b/crates/squawk_ide/src/hover.rs @@ -47,6 +47,12 @@ pub fn hover(file: &ast::SourceFile, offset: TextSize) -> Option { | NameRefClass::DeleteWhereColumn | NameRefClass::UpdateWhereColumn | NameRefClass::UpdateSetColumn + | NameRefClass::UpdateReturningColumn + | NameRefClass::InsertReturningColumn + | NameRefClass::DeleteReturningColumn + | NameRefClass::MergeReturningColumn + | NameRefClass::MergeWhenColumn + | NameRefClass::MergeOnColumn | NameRefClass::CheckConstraintColumn | NameRefClass::GeneratedColumn | NameRefClass::UniqueConstraintColumn @@ -86,12 +92,22 @@ pub fn hover(file: &ast::SourceFile, offset: TextSize) -> Option { | NameRefClass::DropMaterializedView | NameRefClass::CreateIndex | NameRefClass::InsertTable + | NameRefClass::InsertQualifiedColumnTable | NameRefClass::DeleteTable + | NameRefClass::DeleteQualifiedColumnTable | NameRefClass::DeleteUsingTable + | NameRefClass::MergeUsingTable | NameRefClass::UpdateTable | NameRefClass::SelectFromTable | NameRefClass::UpdateFromTable | NameRefClass::SelectQualifiedColumnTable + | NameRefClass::UpdateSetQualifiedColumnTable + | NameRefClass::MergeQualifiedColumnTable + | NameRefClass::UpdateReturningQualifiedColumnTable + | NameRefClass::InsertReturningQualifiedColumnTable + | NameRefClass::DeleteReturningQualifiedColumnTable + | NameRefClass::MergeReturningQualifiedColumnTable + | NameRefClass::MergeTable | NameRefClass::ForeignKeyTable | NameRefClass::LikeTable | NameRefClass::InheritsTable diff --git a/crates/squawk_ide/src/resolve.rs b/crates/squawk_ide/src/resolve.rs index 98102358..071dc467 100644 --- a/crates/squawk_ide/src/resolve.rs +++ b/crates/squawk_ide/src/resolve.rs @@ -31,15 +31,17 @@ pub(crate) fn resolve_name_ref( | NameRefClass::LockTable | NameRefClass::VacuumTable | NameRefClass::AlterTable - | NameRefClass::ReindexTable => { + | NameRefClass::ReindexTable + | NameRefClass::MergeTable => { let path = find_containing_path(name_ref)?; let table_name = extract_table_name(&path)?; let schema = extract_schema_name(&path); let position = name_ref.syntax().text_range().start(); - resolve_table(binder, &table_name, &schema, position).map(|ptr| smallvec![ptr]) + resolve_table_name_ptr(binder, &table_name, &schema, position).map(|ptr| smallvec![ptr]) } NameRefClass::SelectFromTable | NameRefClass::UpdateFromTable + | NameRefClass::MergeUsingTable | NameRefClass::DeleteUsingTable => { let table_name = Name::from_node(name_ref); let schema = if let Some(parent) = name_ref.syntax().parent() @@ -60,18 +62,20 @@ pub(crate) fn resolve_name_ref( let position = name_ref.syntax().text_range().start(); - if let Some(ptr) = resolve_table(binder, &table_name, &schema, position) { - return Some(smallvec![ptr]); + if let Some(table_name_ptr) = + resolve_table_name_ptr(binder, &table_name, &schema, position) + { + return Some(smallvec![table_name_ptr]); } - resolve_view(binder, &table_name, &schema, position).map(|ptr| smallvec![ptr]) + resolve_view_name_ptr(binder, &table_name, &schema, position).map(|ptr| smallvec![ptr]) } NameRefClass::DropIndex | NameRefClass::ReindexIndex => { let path = find_containing_path(name_ref)?; let index_name = extract_table_name(&path)?; let schema = extract_schema_name(&path); let position = name_ref.syntax().text_range().start(); - resolve_index(binder, &index_name, &schema, position).map(|ptr| smallvec![ptr]) + resolve_index_name_ptr(binder, &index_name, &schema, position).map(|ptr| smallvec![ptr]) } NameRefClass::DropType | NameRefClass::TypeReference => { let (type_name, schema) = if let Some(parent) = name_ref.syntax().parent() @@ -96,7 +100,7 @@ pub(crate) fn resolve_name_ref( (type_name, schema) }; let position = name_ref.syntax().text_range().start(); - resolve_type(binder, &type_name, &schema, position).map(|ptr| smallvec![ptr]) + resolve_type_name_ptr(binder, &type_name, &schema, position).map(|ptr| smallvec![ptr]) } NameRefClass::DropView | NameRefClass::DropMaterializedView @@ -105,27 +109,28 @@ pub(crate) fn resolve_name_ref( let view_name = extract_table_name(&path)?; let schema = extract_schema_name(&path); let position = name_ref.syntax().text_range().start(); - resolve_view(binder, &view_name, &schema, position).map(|ptr| smallvec![ptr]) + resolve_view_name_ptr(binder, &view_name, &schema, position).map(|ptr| smallvec![ptr]) } NameRefClass::DropSequence => { let path = find_containing_path(name_ref)?; let sequence_name = extract_table_name(&path)?; let schema = extract_schema_name(&path); let position = name_ref.syntax().text_range().start(); - resolve_sequence(binder, &sequence_name, &schema, position).map(|ptr| smallvec![ptr]) + resolve_sequence_name_ptr(binder, &sequence_name, &schema, position) + .map(|ptr| smallvec![ptr]) } NameRefClass::ReindexDatabase | NameRefClass::ReindexSystem | NameRefClass::DropDatabase => { let database_name = Name::from_node(name_ref); - resolve_database(binder, &database_name).map(|ptr| smallvec![ptr]) + resolve_database_name_ptr(binder, &database_name).map(|ptr| smallvec![ptr]) } NameRefClass::DropServer | NameRefClass::AlterServer | NameRefClass::CreateServer | NameRefClass::ForeignTableServerName => { let server_name = Name::from_node(name_ref); - resolve_server(binder, &server_name).map(|ptr| smallvec![ptr]) + resolve_server_name_ptr(binder, &server_name).map(|ptr| smallvec![ptr]) } NameRefClass::SequenceOwnedByColumn => { let sequence_option = name_ref @@ -140,14 +145,14 @@ pub(crate) fn resolve_name_ref( } NameRefClass::Tablespace => { let tablespace_name = Name::from_node(name_ref); - resolve_tablespace(binder, &tablespace_name).map(|ptr| smallvec![ptr]) + resolve_tablespace_name_ptr(binder, &tablespace_name).map(|ptr| smallvec![ptr]) } NameRefClass::ForeignKeyTable => { let path = name_ref.syntax().ancestors().find_map(ast::Path::cast)?; let table_name = extract_table_name(&path)?; let schema = extract_schema_name(&path); let position = name_ref.syntax().text_range().start(); - resolve_table(binder, &table_name, &schema, position).map(|ptr| smallvec![ptr]) + resolve_table_name_ptr(binder, &table_name, &schema, position).map(|ptr| smallvec![ptr]) } NameRefClass::ForeignKeyColumn => { // TODO: the ast is too flat here @@ -193,7 +198,7 @@ pub(crate) fn resolve_name_ref( let table_name = extract_table_name(&path)?; let schema = extract_schema_name(&path); let position = name_ref.syntax().text_range().start(); - resolve_table(binder, &table_name, &schema, position).map(|ptr| smallvec![ptr]) + resolve_table_name_ptr(binder, &table_name, &schema, position).map(|ptr| smallvec![ptr]) } NameRefClass::DropFunction => { let function_sig = name_ref @@ -337,28 +342,59 @@ pub(crate) fn resolve_name_ref( None } NameRefClass::CreateIndexColumn => { - resolve_create_index_column(binder, root, name_ref).map(|ptr| smallvec![ptr]) + resolve_create_index_column_ptr(binder, root, name_ref).map(|ptr| smallvec![ptr]) } NameRefClass::SelectColumn => { - resolve_select_column(binder, root, name_ref).map(|ptr| smallvec![ptr]) + resolve_select_column_ptr(binder, root, name_ref).map(|ptr| smallvec![ptr]) } NameRefClass::SelectQualifiedColumnTable => { - resolve_select_qualified_column_table(binder, name_ref).map(|ptr| smallvec![ptr]) + resolve_select_qualified_column_table_name_ptr(binder, name_ref) + .map(|ptr| smallvec![ptr]) } NameRefClass::SelectQualifiedColumn => { - resolve_select_qualified_column(binder, root, name_ref).map(|ptr| smallvec![ptr]) + resolve_select_qualified_column_ptr(binder, root, name_ref).map(|ptr| smallvec![ptr]) } NameRefClass::CompositeTypeField => { - resolve_composite_type_field(binder, root, name_ref).map(|ptr| smallvec![ptr]) + resolve_composite_type_field_ptr(binder, root, name_ref).map(|ptr| smallvec![ptr]) } NameRefClass::InsertColumn => { - resolve_insert_column(binder, root, name_ref).map(|ptr| smallvec![ptr]) + resolve_insert_column_ptr(binder, root, name_ref).map(|ptr| smallvec![ptr]) + } + NameRefClass::InsertQualifiedColumnTable => { + resolve_insert_table_name_ptr(binder, name_ref).map(|ptr| smallvec![ptr]) + } + NameRefClass::DeleteWhereColumn | NameRefClass::DeleteReturningColumn => { + resolve_delete_column_ptr(binder, root, name_ref).map(|ptr| smallvec![ptr]) } - NameRefClass::DeleteWhereColumn => { - resolve_delete_where_column(binder, root, name_ref).map(|ptr| smallvec![ptr]) + NameRefClass::DeleteQualifiedColumnTable + | NameRefClass::DeleteReturningQualifiedColumnTable => { + resolve_delete_table_name_ptr(binder, name_ref).map(|ptr| smallvec![ptr]) } NameRefClass::UpdateWhereColumn | NameRefClass::UpdateSetColumn => { - resolve_update_where_column(binder, root, name_ref).map(|ptr| smallvec![ptr]) + resolve_update_column_ptr(binder, root, name_ref).map(|ptr| smallvec![ptr]) + } + NameRefClass::UpdateReturningQualifiedColumnTable + | NameRefClass::UpdateSetQualifiedColumnTable => { + resolve_update_table_name_ptr(binder, name_ref).map(|ptr| smallvec![ptr]) + } + NameRefClass::UpdateReturningColumn => { + resolve_update_column_ptr(binder, root, name_ref).map(|ptr| smallvec![ptr]) + } + NameRefClass::InsertReturningColumn => { + resolve_insert_column_ptr(binder, root, name_ref).map(|ptr| smallvec![ptr]) + } + NameRefClass::MergeReturningColumn | NameRefClass::MergeOnColumn => { + resolve_merge_column_ptr(binder, root, name_ref).map(|ptr| smallvec![ptr]) + } + NameRefClass::MergeWhenColumn => { + resolve_merge_when_column_ptr(binder, root, name_ref).map(|ptr| smallvec![ptr]) + } + NameRefClass::MergeQualifiedColumnTable + | NameRefClass::MergeReturningQualifiedColumnTable => { + resolve_merge_table_name_ptr(binder, name_ref).map(|ptr| smallvec![ptr]) + } + NameRefClass::InsertReturningQualifiedColumnTable => { + resolve_insert_table_name_ptr(binder, name_ref).map(|ptr| smallvec![ptr]) } NameRefClass::JoinUsingColumn => resolve_join_using_columns(binder, root, name_ref), NameRefClass::AlterTableColumn | NameRefClass::AlterTableDropColumn => { @@ -379,7 +415,7 @@ pub(crate) fn resolve_name_ref( } } -fn resolve_table( +fn resolve_table_name_ptr( binder: &Binder, table_name: &Name, schema: &Option, @@ -388,7 +424,7 @@ fn resolve_table( resolve_for_kind(binder, table_name, schema, position, SymbolKind::Table) } -fn resolve_index( +fn resolve_index_name_ptr( binder: &Binder, index_name: &Name, schema: &Option, @@ -397,7 +433,7 @@ fn resolve_index( resolve_for_kind(binder, index_name, schema, position, SymbolKind::Index) } -fn resolve_type( +fn resolve_type_name_ptr( binder: &Binder, type_name: &Name, schema: &Option, @@ -406,7 +442,7 @@ fn resolve_type( resolve_for_kind(binder, type_name, schema, position, SymbolKind::Type) } -fn resolve_view( +fn resolve_view_name_ptr( binder: &Binder, view_name: &Name, schema: &Option, @@ -415,7 +451,7 @@ fn resolve_view( resolve_for_kind(binder, view_name, schema, position, SymbolKind::View) } -fn resolve_sequence( +fn resolve_sequence_name_ptr( binder: &Binder, sequence_name: &Name, schema: &Option, @@ -430,7 +466,7 @@ fn resolve_sequence( ) } -fn resolve_tablespace(binder: &Binder, tablespace_name: &Name) -> Option { +fn resolve_tablespace_name_ptr(binder: &Binder, tablespace_name: &Name) -> Option { let symbols = binder.scopes[binder.root_scope()].get(tablespace_name)?; let symbol_id = symbols.iter().copied().find(|id| { let symbol = &binder.symbols[*id]; @@ -439,7 +475,7 @@ fn resolve_tablespace(binder: &Binder, tablespace_name: &Name) -> Option Option { +fn resolve_database_name_ptr(binder: &Binder, database_name: &Name) -> Option { let symbols = binder.scopes[binder.root_scope()].get(database_name)?; let symbol_id = symbols.iter().copied().find(|id| { let symbol = &binder.symbols[*id]; @@ -448,7 +484,7 @@ fn resolve_database(binder: &Binder, database_name: &Name) -> Option Option { +fn resolve_server_name_ptr(binder: &Binder, server_name: &Name) -> Option { let symbols = binder.scopes[binder.root_scope()].get(server_name)?; let symbol_id = symbols.iter().copied().find(|id| { let symbol = &binder.symbols[*id]; @@ -588,14 +624,14 @@ fn resolve_schema(binder: &Binder, schema_name: &Name) -> Option Some(binder.symbols[symbol_id].ptr) } -fn resolve_create_index_column( +fn resolve_create_index_column_ptr( binder: &Binder, root: &SyntaxNode, - name_ref: &ast::NameRef, + column_name_ref: &ast::NameRef, ) -> Option { - let column_name = Name::from_node(name_ref); + let column_name = Name::from_node(column_name_ref); - let create_index = name_ref + let create_index = column_name_ref .syntax() .ancestors() .find_map(ast::CreateIndex::cast)?; @@ -615,41 +651,50 @@ fn resolve_column_for_path( let schema = extract_schema_name(path); let position = path.syntax().text_range().start(); - let table_ptr = resolve_table(binder, &table_name, &schema, position)?; - - let table_name_node = table_ptr.to_node(root); - - let create_table = table_name_node - .ancestors() - .find_map(ast::CreateTableLike::cast)?; - - find_column_in_create_table(&create_table, &column_name) + if let Some(resolved) = resolve_table_name(binder, root, &table_name, &schema, position) { + match resolved { + ResolvedTableName::View(create_view) => { + find_column_in_create_view(&create_view, &column_name) + } + ResolvedTableName::Table(create_table_like) => { + find_column_in_create_table(&create_table_like, &column_name) + } + } + } else { + None + } } -fn resolve_insert_column( +fn resolve_insert_column_ptr( binder: &Binder, root: &SyntaxNode, - name_ref: &ast::NameRef, + column_name_ref: &ast::NameRef, ) -> Option { - let column_name = Name::from_node(name_ref); + let column_name = Name::from_node(column_name_ref); - let insert = name_ref.syntax().ancestors().find_map(ast::Insert::cast)?; + let insert = column_name_ref + .syntax() + .ancestors() + .find_map(ast::Insert::cast)?; let path = insert.path()?; resolve_column_for_path(binder, root, &path, column_name) } -fn resolve_select_qualified_column_table( +fn resolve_select_qualified_column_table_name_ptr( binder: &Binder, - name_ref: &ast::NameRef, + table_name_ref: &ast::NameRef, ) -> Option { - let table_name = Name::from_node(name_ref); + let table_name = Name::from_node(table_name_ref); - let field_expr = name_ref.syntax().parent().and_then(ast::FieldExpr::cast)?; + let field_expr = table_name_ref + .syntax() + .parent() + .and_then(ast::FieldExpr::cast)?; let explicit_schema = if field_expr .field() - .is_some_and(|f| f.syntax() == name_ref.syntax()) + .is_some_and(|f| f.syntax() == table_name_ref.syntax()) && field_expr.star_token().is_none() { // if we're at the field `bar` in `foo.bar` @@ -670,11 +715,14 @@ fn resolve_select_qualified_column_table( }; if let Some(schema) = explicit_schema { - let position = name_ref.syntax().text_range().start(); - return resolve_table(binder, &table_name, &Some(schema), position); + let position = table_name_ref.syntax().text_range().start(); + return resolve_table_name_ptr(binder, &table_name, &Some(schema), position); } - let select = name_ref.syntax().ancestors().find_map(ast::Select::cast)?; + let select = table_name_ref + .syntax() + .ancestors() + .find_map(ast::Select::cast)?; let from_clause = select.from_clause()?; let from_item = find_from_item_in_from_clause(&from_clause, &table_name)?; @@ -685,7 +733,7 @@ fn resolve_select_qualified_column_table( } let (table_name, schema) = if let Some(name_ref_node) = from_item.name_ref() { - if let Some(cte_ptr) = resolve_cte_table(name_ref, &table_name) { + if let Some(cte_ptr) = resolve_cte_table(table_name_ref, &table_name) { return Some(cte_ptr); } @@ -710,19 +758,22 @@ fn resolve_select_qualified_column_table( (from_table_name, Some(schema)) }; - let position = name_ref.syntax().text_range().start(); - resolve_table(binder, &table_name, &schema, position) + let position = table_name_ref.syntax().text_range().start(); + resolve_table_name_ptr(binder, &table_name, &schema, position) } // TODO: this is similar to resolve_from_item_for_column, maybe we can simplify -fn resolve_select_qualified_column( +fn resolve_select_qualified_column_ptr( binder: &Binder, root: &SyntaxNode, - name_ref: &ast::NameRef, + column_name_ref: &ast::NameRef, ) -> Option { - let column_name = Name::from_node(name_ref); + let column_name = Name::from_node(column_name_ref); - let field_expr = name_ref.syntax().parent().and_then(ast::FieldExpr::cast)?; + let field_expr = column_name_ref + .syntax() + .parent() + .and_then(ast::FieldExpr::cast)?; let (column_table_name, explicit_schema) = // if we're at `base` in `base.field` @@ -745,12 +796,70 @@ fn resolve_select_qualified_column( return None; }; - let position = name_ref.syntax().text_range().start(); + let position = column_name_ref.syntax().text_range().start(); - let (table_name, schema) = if let Some(schema) = explicit_schema { + let (mut table_name, schema) = if let Some(schema) = explicit_schema { (column_table_name, Some(schema)) + } else if let Some(merge) = column_name_ref + .syntax() + .ancestors() + .find_map(ast::Merge::cast) + { + // Handle MERGE statements + let relation_name = merge.relation_name()?; + let path = relation_name.path()?; + let merge_table_name = extract_table_name(&path)?; + + // Check if this is the alias or the table name + if let Some(alias) = merge.alias() + && let Some(alias_name) = alias.name() + && Name::from_node(&alias_name) == column_table_name + { + (merge_table_name, extract_schema_name(&path)) + } else if merge_table_name == column_table_name { + (merge_table_name, extract_schema_name(&path)) + } else { + // Try to find in USING clause first + let found_in_using = if let Some(using_on) = merge.using_on_clause() + && let Some(from_item) = using_on.from_item() + { + if let Some(item_name_ref) = from_item.name_ref() + && let item_name = Name::from_node(&item_name_ref) + && item_name == column_table_name + { + Some((item_name, None)) + } else if let Some(alias) = from_item.alias() + && let Some(alias_name) = alias.name() + && let alias_name = Name::from_node(&alias_name) + && alias_name == column_table_name + { + Some((alias_name, None)) + } else { + None + } + } else { + None + }; + + if let Some(result) = found_in_using { + result + } else { + // Fallback: check for OLD and NEW pseudo-tables in MERGE RETURNING clause + let old_name = Name::from_string("old"); + let new_name = Name::from_string("new"); + if column_table_name == old_name || column_table_name == new_name { + // Both OLD and NEW refer to the target table + (merge_table_name, extract_schema_name(&path)) + } else { + return None; + } + } + } } else { - let select = name_ref.syntax().ancestors().find_map(ast::Select::cast)?; + let select = column_name_ref + .syntax() + .ancestors() + .find_map(ast::Select::cast)?; let from_clause = select.from_clause()?; let from_item = find_from_item_in_from_clause(&from_clause, &column_table_name)?; @@ -765,7 +874,7 @@ fn resolve_select_qualified_column( binder, root, &paren_select, - name_ref, + column_name_ref, &column_name, ); } @@ -787,7 +896,13 @@ fn resolve_select_qualified_column( // ``` if let Some(name_ref_node) = from_item.name_ref() { let cte_name = Name::from_node(&name_ref_node); - return resolve_cte_column(binder, root, name_ref, &cte_name, &column_name); + return resolve_cte_column( + binder, + root, + column_name_ref, + &cte_name, + &column_name, + ); } } @@ -828,70 +943,133 @@ fn resolve_select_qualified_column( } }; - if schema.is_none() - && let Some(cte_column_ptr) = - resolve_cte_column(binder, root, name_ref, &table_name, &column_name) - { - return Some(cte_column_ptr); + if schema.is_none() { + if let Some(cte_column_ptr) = + resolve_cte_column(binder, root, column_name_ref, &table_name, &column_name) + { + return Some(cte_column_ptr); + } + if let Some(alias_table_name) = resolve_alias(column_name_ref, &table_name) { + table_name = alias_table_name; + } + } + + if let Some(resolved) = resolve_table_name(binder, root, &table_name, &schema, position) { + match resolved { + ResolvedTableName::View(create_view) => { + if let Some(ptr) = find_column_in_create_view(&create_view, &column_name) { + return Some(ptr); + } + + return resolve_function(binder, &column_name, &schema, None, position); + } + ResolvedTableName::Table(create_table_like) => { + // 1. Try to find a matching column (columns take precedence) + if let Some(ptr) = find_column_in_create_table(&create_table_like, &column_name) { + return Some(ptr); + } + // 2. No column found, check for field-style function call + // e.g., select t.b from t where b is a function that takes t as an argument + return resolve_function(binder, &column_name, &schema, None, position); + } + } } - if let Some(table_ptr) = resolve_table(binder, &table_name, &schema, position) { - let table_name_node = table_ptr.to_node(root); + None +} + +enum ResolvedTableName { + View(ast::CreateView), + Table(ast::CreateTableLike), +} +fn resolve_table_name( + binder: &Binder, + root: &SyntaxNode, + table_name: &Name, + schema: &Option, + position: TextSize, +) -> Option { + use ResolvedTableName::*; + if let Some(table_name_ptr) = resolve_table_name_ptr(binder, table_name, schema, position) { + let table_name_node = table_name_ptr.to_node(root); if let Some(create_table) = table_name_node .ancestors() .find_map(ast::CreateTableLike::cast) { - // 1. Try to find a matching column (columns take precedence) - if let Some(ptr) = find_column_in_create_table(&create_table, &column_name) { - return Some(ptr); - } - // 2. No column found, check for field-style function call - // e.g., select t.b from t where b is a function that takes t as an argument - return resolve_function(binder, &column_name, &schema, None, position); + return Some(Table(create_table)); } } - // ditto as above but with views - if let Some(view_ptr) = resolve_view(binder, &table_name, &schema, position) { - let view_name_node = view_ptr.to_node(root); - + if let Some(view_name_ptr) = resolve_view_name_ptr(binder, table_name, schema, position) { + let view_name_node = view_name_ptr.to_node(root); if let Some(create_view) = view_name_node.ancestors().find_map(ast::CreateView::cast) { - if let Some(ptr) = find_column_in_create_view(&create_view, &column_name) { - return Some(ptr); - } - - return resolve_function(binder, &column_name, &schema, None, position); + return Some(View(create_view)); } } + None +} +fn resolve_alias(name_ref: &ast::NameRef, table_name: &Name) -> Option { + let from_item = find_parent_alias_from_item(name_ref.syntax())?; + if let Some(alias) = from_item.alias() + && let Some(alias_name) = alias.name() + && Name::from_node(&alias_name) == *table_name + { + let table_name = Name::from_node(&from_item.name_ref()?); + return Some(table_name); + } None } -fn resolve_from_item_for_column( +fn find_parent_alias_from_item(syntax: &SyntaxNode) -> Option { + for a in syntax.ancestors() { + if let Some(merge) = ast::Merge::cast(a) + && let Some(from_item) = merge.using_on_clause().and_then(|c| c.from_item()) + { + return Some(from_item); + } + } + None +} + +fn resolve_from_item_column_ptr( binder: &Binder, root: &SyntaxNode, from_item: &ast::FromItem, - name_ref: &ast::NameRef, + column_name_ref: &ast::NameRef, ) -> Option { - let column_name = Name::from_node(name_ref); + let column_name = Name::from_node(column_name_ref); if let Some(paren_select) = from_item.paren_select() { - return resolve_subquery_column(binder, root, &paren_select, name_ref, &column_name); + return resolve_subquery_column(binder, root, &paren_select, column_name_ref, &column_name); } if let Some(paren_expr) = from_item.paren_expr() { - return resolve_column_from_paren_expr(binder, root, &paren_expr, name_ref, &column_name); + return resolve_column_from_paren_expr( + binder, + root, + &paren_expr, + column_name_ref, + &column_name, + ); } let (table_name, schema) = table_and_schema_from_from_item(from_item)?; if schema.is_none() && let Some(cte_column_ptr) = - resolve_cte_column(binder, root, name_ref, &table_name, &column_name) + resolve_cte_column(binder, root, column_name_ref, &table_name, &column_name) { return Some(cte_column_ptr); } - resolve_column_from_table_or_view(binder, root, name_ref, &table_name, &schema, &column_name) + resolve_column_from_table_or_view( + binder, + root, + column_name_ref, + &table_name, + &schema, + &column_name, + ) } fn resolve_column_from_table_or_view( @@ -904,8 +1082,8 @@ fn resolve_column_from_table_or_view( ) -> Option { let position = name_ref.syntax().text_range().start(); - if let Some(table_ptr) = resolve_table(binder, table_name, schema, position) { - let table_name_node = table_ptr.to_node(root); + if let Some(table_name_ptr) = resolve_table_name_ptr(binder, table_name, schema, position) { + let table_name_node = table_name_ptr.to_node(root); if let Some(create_table) = table_name_node .ancestors() @@ -923,14 +1101,14 @@ fn resolve_column_from_table_or_view( // select t from t; // ``` if column_name == table_name { - return Some(table_ptr); + return Some(table_name_ptr); } } } // ditto as above but with view - if let Some(view_ptr) = resolve_view(binder, table_name, schema, position) { - let view_name_node = view_ptr.to_node(root); + if let Some(view_name_ptr) = resolve_view_name_ptr(binder, table_name, schema, position) { + let view_name_node = view_name_ptr.to_node(root); if let Some(create_view) = view_name_node.ancestors().find_map(ast::CreateView::cast) { if let Some(ptr) = find_column_in_create_view(&create_view, column_name) { @@ -938,7 +1116,7 @@ fn resolve_column_from_table_or_view( } if column_name == table_name { - return Some(view_ptr); + return Some(view_name_ptr); } } } @@ -967,7 +1145,7 @@ fn resolve_from_item_for_cte_star( ); } - resolve_from_item_for_column(binder, root, from_item, name_ref) + resolve_from_item_column_ptr(binder, root, from_item, name_ref) } fn resolve_from_join_expr(join_expr: &ast::JoinExpr, try_resolve: &F) -> Option @@ -993,39 +1171,49 @@ where None } -fn resolve_select_column( +fn resolve_select_column_ptr( binder: &Binder, root: &SyntaxNode, - name_ref: &ast::NameRef, + column_name_ref: &ast::NameRef, ) -> Option { - let select = name_ref.syntax().ancestors().find_map(ast::Select::cast)?; + let select = column_name_ref + .syntax() + .ancestors() + .find_map(ast::Select::cast)?; let from_clause = select.from_clause()?; for from_item in from_clause.from_items() { - if let Some(result) = resolve_from_item_for_column(binder, root, &from_item, name_ref) { - return Some(result); + if let Some(column_ptr) = + resolve_from_item_column_ptr(binder, root, &from_item, column_name_ref) + { + return Some(column_ptr); } } for join_expr in from_clause.join_exprs() { - if let Some(result) = resolve_from_join_expr(&join_expr, &|from_item: &ast::FromItem| { - resolve_from_item_for_column(binder, root, from_item, name_ref) - }) { - return Some(result); + if let Some(column_ptr) = + resolve_from_join_expr(&join_expr, &|from_item: &ast::FromItem| { + resolve_from_item_column_ptr(binder, root, from_item, column_name_ref) + }) + { + return Some(column_ptr); } } None } -fn resolve_delete_where_column( +fn resolve_delete_column_ptr( binder: &Binder, root: &SyntaxNode, - name_ref: &ast::NameRef, + column_name_ref: &ast::NameRef, ) -> Option { - let column_name = Name::from_node(name_ref); + let column_name = Name::from_node(column_name_ref); - let delete = name_ref.syntax().ancestors().find_map(ast::Delete::cast)?; + let delete = column_name_ref + .syntax() + .ancestors() + .find_map(ast::Delete::cast)?; let relation_name = delete.relation_name()?; let path = relation_name.path()?; @@ -1045,7 +1233,7 @@ fn resolve_join_using_columns( let mut results: SmallVec<[SyntaxNodePtr; 1]> = SmallVec::new(); collect_from_join_expr(&join_expr, &mut results, &|from_item: &ast::FromItem| { - resolve_from_item_for_column(binder, root, from_item, name_ref) + resolve_from_item_column_ptr(binder, root, from_item, name_ref) }); (!results.is_empty()).then_some(results) @@ -1075,19 +1263,24 @@ fn collect_from_join_expr( } } -fn resolve_update_where_column( +fn resolve_update_column_ptr( binder: &Binder, root: &SyntaxNode, - name_ref: &ast::NameRef, + column_name_ref: &ast::NameRef, ) -> Option { - let column_name = Name::from_node(name_ref); + let column_name = Name::from_node(column_name_ref); - let update = name_ref.syntax().ancestors().find_map(ast::Update::cast)?; + let update = column_name_ref + .syntax() + .ancestors() + .find_map(ast::Update::cast)?; // `update t set a = b from u` if let Some(from_clause) = update.from_clause() { for from_item in from_clause.from_items() { - if let Some(result) = resolve_from_item_for_column(binder, root, &from_item, name_ref) { + if let Some(result) = + resolve_from_item_column_ptr(binder, root, &from_item, column_name_ref) + { return Some(result); } } @@ -1095,7 +1288,7 @@ fn resolve_update_where_column( for join_expr in from_clause.join_exprs() { if let Some(result) = resolve_from_join_expr(&join_expr, &|from_item: &ast::FromItem| { - resolve_from_item_for_column(binder, root, from_item, name_ref) + resolve_from_item_column_ptr(binder, root, from_item, column_name_ref) }) { return Some(result); @@ -1159,7 +1352,7 @@ fn resolve_from_item_for_fn_call_column( let (table_name, schema) = table_and_schema_from_from_item(from_item)?; let position = name_ref.syntax().text_range().start(); - let table_ptr = resolve_table(binder, &table_name, &schema, position)?; + let table_ptr = resolve_table_name_ptr(binder, &table_name, &schema, position)?; let table_name_node = table_ptr.to_node(root); let create_table = table_name_node @@ -1374,6 +1567,68 @@ fn find_parent_with_clause(node: &SyntaxNode) -> Option { }) } +fn count_columns_for_path(binder: &Binder, root: &SyntaxNode, path: &ast::Path) -> Option { + let table_name = extract_table_name(path)?; + let schema = extract_schema_name(path); + let position = path.syntax().text_range().start(); + + count_columns_for_table_name(binder, root, &table_name, &schema, position) +} + +fn count_columns_for_table_name( + binder: &Binder, + root: &SyntaxNode, + table_name: &Name, + schema: &Option, + position: TextSize, +) -> Option { + if let Some(table_name_ptr) = resolve_table_name_ptr(binder, table_name, schema, position) { + let table_name_node = table_name_ptr.to_node(root); + + if let Some(create_table) = table_name_node + .ancestors() + .find_map(ast::CreateTableLike::cast) + { + let mut count: usize = 0; + if let Some(args) = create_table.table_arg_list() { + for arg in args.args() { + if matches!(arg, ast::TableArg::Column(_)) { + count = count.saturating_add(1); + } + } + } + return Some(count); + } + } + + if let Some(view_name_ptr) = resolve_view_name_ptr(binder, table_name, schema, position) { + let view_name_node = view_name_ptr.to_node(root); + + if let Some(create_view) = view_name_node.ancestors().find_map(ast::CreateView::cast) { + if let Some(column_list) = create_view.column_list() { + return Some(column_list.columns().count()); + } + + let select = match create_view.query()? { + ast::SelectVariant::Select(s) => s, + ast::SelectVariant::ParenSelect(ps) => match ps.select()? { + ast::SelectVariant::Select(s) => s, + _ => return None, + }, + _ => return None, + }; + + if let Some(target_list) = select.select_clause().and_then(|c| c.target_list()) { + // This is not quite right if there's a `*` in the view definition. + // It becomes recursive. + // For now, let's assume simple views. + return Some(target_list.targets().count()); + } + } + } + None +} + fn resolve_cte_column( binder: &Binder, root: &SyntaxNode, @@ -1427,8 +1682,8 @@ fn resolve_cte_column( continue; } - if column_list_len == 0 - && let Some(column) = column_in_with_query(&query, binder, root, column_name) + if let Some(column) = + column_in_with_query(&query, binder, root, column_name, column_list_len) { return Some(column); } @@ -1507,6 +1762,7 @@ fn column_in_with_query( binder: &Binder, root: &SyntaxNode, column_name: &Name, + column_list_len: usize, ) -> Option { let (returning_clause, path) = match query { ast::WithQuery::Delete(delete) => ( @@ -1531,7 +1787,20 @@ fn column_in_with_query( let target_list = returning_clause?.target_list()?; let path = path?; + let mut column_index: usize = 0; for target in target_list.targets() { + let target_column_count = if target.star_token().is_some() { + count_columns_for_path(binder, root, &path).unwrap_or(1) + } else { + 1 + }; + let column_list_end = column_index.saturating_add(target_column_count); + + if column_list_end <= column_list_len { + column_index = column_list_end; + continue; + } + if let Some((col_name, node)) = ColumnName::from_target(target) { if let Some(col_name_str) = col_name.to_string() && Name::from_string(col_name_str) == *column_name @@ -1544,6 +1813,7 @@ fn column_in_with_query( return Some(ptr); } } + column_index = column_list_end; } None @@ -1575,7 +1845,7 @@ fn resolve_subquery_column( if let Some(from_clause) = subquery_select.from_clause() { for from_item in from_clause.from_items() { if let Some(result) = - resolve_from_item_for_column(binder, root, &from_item, name_ref) + resolve_from_item_column_ptr(binder, root, &from_item, name_ref) { return Some(result); } @@ -1584,7 +1854,7 @@ fn resolve_subquery_column( for join_expr in from_clause.join_exprs() { if let Some(result) = resolve_from_join_expr(&join_expr, &|from_item: &ast::FromItem| { - resolve_from_item_for_column(binder, root, from_item, name_ref) + resolve_from_item_column_ptr(binder, root, from_item, name_ref) }) { return Some(result); @@ -1599,7 +1869,7 @@ fn resolve_subquery_column( && let Some(table_name) = qualified_star_table_name(&field_expr) && let Some(from_clause) = subquery_select.from_clause() && let Some(from_item) = find_from_item_in_from_clause(&from_clause, &table_name) - && let Some(result) = resolve_from_item_for_column(binder, root, &from_item, name_ref) + && let Some(result) = resolve_from_item_column_ptr(binder, root, &from_item, name_ref) { return Some(result); } @@ -1635,12 +1905,12 @@ pub(crate) fn resolve_qualified_star_table( let (table_name, schema) = table_and_schema_from_from_item(&from_item)?; let position = field_expr.syntax().text_range().start(); - if let Some(ptr) = resolve_table(binder, &table_name, &schema, position) { - return Some(ptr); + if let Some(table_name_ptr) = resolve_table_name_ptr(binder, &table_name, &schema, position) { + return Some(table_name_ptr); } - if let Some(ptr) = resolve_view(binder, &table_name, &schema, position) { - return Some(ptr); + if let Some(view_name_ptr) = resolve_view_name_ptr(binder, &table_name, &schema, position) { + return Some(view_name_ptr); } if schema.is_none() @@ -1740,13 +2010,13 @@ fn collect_tables_from_item( return; }; - if let Some(ptr) = resolve_table(binder, &table_name, &schema, position) { - results.push(ptr); + if let Some(table_name_ptr) = resolve_table_name_ptr(binder, &table_name, &schema, position) { + results.push(table_name_ptr); return; } - if let Some(ptr) = resolve_view(binder, &table_name, &schema, position) { - results.push(ptr); + if let Some(view_name_ptr) = resolve_view_name_ptr(binder, &table_name, &schema, position) { + results.push(view_name_ptr); return; } @@ -1900,46 +2170,9 @@ fn count_columns_for_from_item( let (table_name, schema) = table_and_schema_from_from_item(from_item)?; let position = name_ref.syntax().text_range().start(); - if let Some(table_ptr) = resolve_table(binder, &table_name, &schema, position) { - let table_name_node = table_ptr.to_node(root); - - if let Some(create_table) = table_name_node - .ancestors() - .find_map(ast::CreateTableLike::cast) - { - let mut count: usize = 0; - if let Some(args) = create_table.table_arg_list() { - for arg in args.args() { - if matches!(arg, ast::TableArg::Column(_)) { - count = count.saturating_add(1); - } - } - } - return Some(count); - } - } - - if let Some(view_ptr) = resolve_view(binder, &table_name, &schema, position) { - let view_name_node = view_ptr.to_node(root); - - if let Some(create_view) = view_name_node.ancestors().find_map(ast::CreateView::cast) { - if let Some(column_list) = create_view.column_list() { - return Some(column_list.columns().count()); - } - - let select = match create_view.query()? { - ast::SelectVariant::Select(s) => s, - ast::SelectVariant::ParenSelect(ps) => match ps.select()? { - ast::SelectVariant::Select(s) => s, - _ => return None, - }, - _ => return None, - }; - - if let Some(target_list) = select.select_clause().and_then(|c| c.target_list()) { - return Some(target_list.targets().count()); - } - } + if let Some(count) = count_columns_for_table_name(binder, root, &table_name, &schema, position) + { + return Some(count); } if schema.is_none() @@ -2083,8 +2316,8 @@ pub(crate) fn resolve_insert_create_table( let schema = extract_schema_name(&path); let position = insert.syntax().text_range().start(); - let table_ptr = resolve_table(binder, &table_name, &schema, position)?; - let table_name_node = table_ptr.to_node(root); + let table_name_ptr = resolve_table_name_ptr(binder, &table_name, &schema, position)?; + let table_name_node = table_name_ptr.to_node(root); table_name_node .ancestors() @@ -2381,18 +2614,21 @@ fn unwrap_paren_expr(expr: ast::Expr) -> Option { None } -fn resolve_composite_type_field( +fn resolve_composite_type_field_ptr( binder: &Binder, root: &SyntaxNode, - name_ref: &ast::NameRef, + field_name_ref: &ast::NameRef, ) -> Option { - let field_name = Name::from_node(name_ref); - let field_expr = name_ref.syntax().parent().and_then(ast::FieldExpr::cast)?; + let field_name = Name::from_node(field_name_ref); + let field_expr = field_name_ref + .syntax() + .parent() + .and_then(ast::FieldExpr::cast)?; let base = field_expr.base()?; let base_name_ref = unwrap_paren_expr(base)?; - let column_ptr = resolve_select_column(binder, root, &base_name_ref)?; + let column_ptr = resolve_select_column_ptr(binder, root, &base_name_ref)?; let column_node = column_ptr.to_node(root); let (type_name, schema) = @@ -2402,9 +2638,9 @@ fn resolve_composite_type_field( resolve_composite_type_from_cast_node(&column_node)? }; - let position = name_ref.syntax().text_range().start(); - let type_ptr = resolve_type(binder, &type_name, &schema, position)?; - let type_node = type_ptr.to_node(root); + let position = field_name_ref.syntax().text_range().start(); + let type_name_ptr = resolve_type_name_ptr(binder, &type_name, &schema, position)?; + let type_node = type_name_ptr.to_node(root); let create_type = type_node.ancestors().find_map(ast::CreateType::cast)?; let column_list = create_type.column_list()?; @@ -2463,3 +2699,185 @@ fn extract_type_name_and_schema(ty: &ast::Type) -> Option<(Name, Option) _ => None, } } + +fn resolve_merge_when_column_ptr( + binder: &Binder, + root: &SyntaxNode, + table_name_ref: &ast::NameRef, +) -> Option { + let column_name = Name::from_node(table_name_ref); + let merge = table_name_ref + .syntax() + .ancestors() + .find_map(ast::Merge::cast)?; + + // Try resolving in source table first + if let Some(using_on) = merge.using_on_clause() + && let Some(from_item) = using_on.from_item() + && let Some(ptr) = resolve_from_item_column_ptr(binder, root, &from_item, table_name_ref) + { + return Some(ptr); + } + + let relation_name = merge.relation_name()?; + let path = relation_name.path()?; + resolve_column_for_path(binder, root, &path, column_name) +} + +fn resolve_merge_column_ptr( + binder: &Binder, + root: &SyntaxNode, + column_name_ref: &ast::NameRef, +) -> Option { + let column_name = Name::from_node(column_name_ref); + let merge = column_name_ref + .syntax() + .ancestors() + .find_map(ast::Merge::cast)?; + + // Try resolving in source table first + if let Some(using_on) = merge.using_on_clause() + && let Some(from_item) = using_on.from_item() + && let Some(ptr) = resolve_from_item_column_ptr(binder, root, &from_item, column_name_ref) + { + return Some(ptr); + } + + let relation_name = merge.relation_name()?; + let path = relation_name.path()?; + resolve_column_for_path(binder, root, &path, column_name) +} + +fn resolve_insert_table_name_ptr( + binder: &Binder, + table_name_ref: &ast::NameRef, +) -> Option { + let table_name = Name::from_node(table_name_ref); + let insert = table_name_ref + .syntax() + .ancestors() + .find_map(ast::Insert::cast)?; + + if let Some(alias) = insert.alias() + && let Some(alias_name) = alias.name() + && Name::from_node(&alias_name) == table_name + { + return Some(SyntaxNodePtr::new(alias_name.syntax())); + } + + let path = insert.path()?; + let schema = extract_schema_name(&path); + let position = table_name_ref.syntax().text_range().start(); + resolve_table_name_ptr(binder, &table_name, &schema, position) +} + +fn resolve_delete_table_name_ptr( + binder: &Binder, + table_name_ref: &ast::NameRef, +) -> Option { + let table_name = Name::from_node(table_name_ref); + let delete = table_name_ref + .syntax() + .ancestors() + .find_map(ast::Delete::cast)?; + + if let Some(alias) = delete.alias() + && let Some(alias_name) = alias.name() + && Name::from_node(&alias_name) == table_name + { + return Some(SyntaxNodePtr::new(alias_name.syntax())); + } + + let relation_name = delete.relation_name()?; + let path = relation_name.path()?; + let schema = extract_schema_name(&path); + let position = table_name_ref.syntax().text_range().start(); + resolve_table_name_ptr(binder, &table_name, &schema, position) +} + +fn resolve_update_table_name_ptr( + binder: &Binder, + table_name_ref: &ast::NameRef, +) -> Option { + let table_name = Name::from_node(table_name_ref); + let update = table_name_ref + .syntax() + .ancestors() + .find_map(ast::Update::cast)?; + + if let Some(alias) = update.alias() + && let Some(alias_name) = alias.name() + && Name::from_node(&alias_name) == table_name + { + return Some(SyntaxNodePtr::new(alias_name.syntax())); + } + + let relation_name = update.relation_name()?; + let path = relation_name.path()?; + let schema = extract_schema_name(&path); + let position = table_name_ref.syntax().text_range().start(); + resolve_table_name_ptr(binder, &table_name, &schema, position) +} + +fn resolve_merge_table_name_ptr( + binder: &Binder, + table_name_ref: &ast::NameRef, +) -> Option { + let table_name = Name::from_node(table_name_ref); + let merge = table_name_ref + .syntax() + .ancestors() + .find_map(ast::Merge::cast)?; + + let relation_name = merge.relation_name()?; + let path = relation_name.path()?; + let merge_table_name = extract_table_name(&path)?; + + // Check target table alias + if let Some(alias) = merge.alias() + && let Some(alias_name) = alias.name() + && Name::from_node(&alias_name) == table_name + { + return Some(SyntaxNodePtr::new(alias_name.syntax())); + } + + // Check USING clause for the source table + if let Some(using_on) = merge.using_on_clause() + && let Some(from_item) = using_on.from_item() + { + if let Some(item_name_ref) = from_item.name_ref() { + let item_name = Name::from_node(&item_name_ref); + if item_name == table_name { + let position = table_name_ref.syntax().text_range().start(); + return resolve_table_name_ptr(binder, &item_name, &None, position); + } + } + // Check for aliased source tables + if let Some(alias) = from_item.alias() + && let Some(alias_name) = alias.name() + && Name::from_node(&alias_name) == table_name + { + return Some(SyntaxNodePtr::new(alias_name.syntax())); + } + } + + if merge_table_name == table_name { + let schema = extract_schema_name(&path); + let position = table_name_ref.syntax().text_range().start(); + return resolve_table_name_ptr(binder, &table_name, &schema, position); + } + + // Check for OLD and NEW pseudo-tables in MERGE RETURNING clause (fallback) + let old_name = Name::from_string("old"); + let new_name = Name::from_string("new"); + if table_name == old_name || table_name == new_name { + // Both OLD and NEW refer to the target table + let schema = extract_schema_name(&path); + let position = table_name_ref.syntax().text_range().start(); + return resolve_table_name_ptr(binder, &merge_table_name, &schema, position); + } + + let schema = extract_schema_name(&path); + let position = table_name_ref.syntax().text_range().start(); + resolve_table_name_ptr(binder, &table_name, &schema, position) +}