diff --git a/crates/squawk_ide/src/goto_definition.rs b/crates/squawk_ide/src/goto_definition.rs index a10de07e..30335efa 100644 --- a/crates/squawk_ide/src/goto_definition.rs +++ b/crates/squawk_ide/src/goto_definition.rs @@ -6037,4 +6037,256 @@ merge into x ╰╴ ─ 1. source "); } + + #[test] + fn goto_insert_returning_old_table() { + assert_snapshot!(goto(" +create table t(a int, b int); +insert into t values (1, 2), (3, 4) +returning old$0.a, new.b; +" + ), @r" + ╭▸ + 2 │ create table t(a int, b int); + │ ─ 2. destination + 3 │ insert into t values (1, 2), (3, 4) + 4 │ returning old.a, new.b; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_insert_returning_old_column() { + assert_snapshot!(goto(" +create table t(a int, b int); +insert into t values (1, 2), (3, 4) +returning old.a$0, new.b; +" + ), @r" + ╭▸ + 2 │ create table t(a int, b int); + │ ─ 2. destination + 3 │ insert into t values (1, 2), (3, 4) + 4 │ returning old.a, new.b; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_insert_returning_new_table() { + assert_snapshot!(goto(" +create table t(a int, b int); +insert into t values (1, 2), (3, 4) +returning old.a, new$0.b; +" + ), @r" + ╭▸ + 2 │ create table t(a int, b int); + │ ─ 2. destination + 3 │ insert into t values (1, 2), (3, 4) + 4 │ returning old.a, new.b; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_insert_returning_new_column() { + assert_snapshot!(goto(" +create table t(a int, b int); +insert into t values (1, 2), (3, 4) +returning old.a, new.b$0; +" + ), @r" + ╭▸ + 2 │ create table t(a int, b int); + │ ─ 2. destination + 3 │ insert into t values (1, 2), (3, 4) + 4 │ returning old.a, new.b; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_update_returning_old_table() { + assert_snapshot!(goto(" +create table t(a int, b int); +update t set a = 42 +returning old$0.a, new.b; +" + ), @r" + ╭▸ + 2 │ create table t(a int, b int); + │ ─ 2. destination + 3 │ update t set a = 42 + 4 │ returning old.a, new.b; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_update_returning_old_column() { + assert_snapshot!(goto(" +create table t(a int, b int); +update t set a = 42 +returning old.a$0, new.b; +" + ), @r" + ╭▸ + 2 │ create table t(a int, b int); + │ ─ 2. destination + 3 │ update t set a = 42 + 4 │ returning old.a, new.b; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_update_returning_new_table() { + assert_snapshot!(goto(" +create table t(a int, b int); +update t set a = 42 +returning old.a, new$0.b; +" + ), @r" + ╭▸ + 2 │ create table t(a int, b int); + │ ─ 2. destination + 3 │ update t set a = 42 + 4 │ returning old.a, new.b; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_update_returning_new_column() { + assert_snapshot!(goto(" +create table t(a int, b int); +update t set a = 42 +returning old.a, new.b$0; +" + ), @r" + ╭▸ + 2 │ create table t(a int, b int); + │ ─ 2. destination + 3 │ update t set a = 42 + 4 │ returning old.a, new.b; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_delete_returning_old_table() { + assert_snapshot!(goto(" +create table t(a int, b int); +delete from t +returning old$0.a, new.b; +" + ), @r" + ╭▸ + 2 │ create table t(a int, b int); + │ ─ 2. destination + 3 │ delete from t + 4 │ returning old.a, new.b; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_delete_returning_old_column() { + assert_snapshot!(goto(" +create table t(a int, b int); +delete from t +returning old.a$0, new.b; +" + ), @r" + ╭▸ + 2 │ create table t(a int, b int); + │ ─ 2. destination + 3 │ delete from t + 4 │ returning old.a, new.b; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_delete_returning_new_table() { + assert_snapshot!(goto(" +create table t(a int, b int); +delete from t +returning old.a, new$0.b; +" + ), @r" + ╭▸ + 2 │ create table t(a int, b int); + │ ─ 2. destination + 3 │ delete from t + 4 │ returning old.a, new.b; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_delete_returning_new_column() { + assert_snapshot!(goto(" +create table t(a int, b int); +delete from t +returning old.a, new.b$0; +" + ), @r" + ╭▸ + 2 │ create table t(a int, b int); + │ ─ 2. destination + 3 │ delete from t + 4 │ returning old.a, new.b; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_insert_as_old_alias() { + assert_snapshot!(goto(" +create table t(a int, b int); +insert into t as old values (1, 2) +returning old$0.a, new.a; +" + ), @r" + ╭▸ + 3 │ insert into t as old values (1, 2) + │ ─── 2. destination + 4 │ returning old.a, new.a; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_delete_as_old_alias() { + assert_snapshot!(goto(" +create table t(a int, b int); +delete from t as old +returning old$0.a, new.a; +" + ), @r" + ╭▸ + 3 │ delete from t as old + │ ─── 2. destination + 4 │ returning old.a, new.a; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_update_as_old_alias() { + assert_snapshot!(goto(" +create table t(a int, b int); +update t as old set a = 42 +returning old$0.a, new.a; +" + ), @r" + ╭▸ + 3 │ update t as old set a = 42 + │ ─── 2. destination + 4 │ returning old.a, new.a; + ╰╴ ─ 1. source + "); + } } diff --git a/crates/squawk_ide/src/resolve.rs b/crates/squawk_ide/src/resolve.rs index 8b58f765..a538d783 100644 --- a/crates/squawk_ide/src/resolve.rs +++ b/crates/squawk_ide/src/resolve.rs @@ -383,12 +383,11 @@ pub(crate) fn resolve_name_ref( NameRefClass::InsertReturningColumn => { resolve_insert_column_ptr(binder, root, name_ref).map(|ptr| smallvec![ptr]) } - NameRefClass::MergeReturningColumn | NameRefClass::MergeOnColumn => { + NameRefClass::MergeWhenColumn + | NameRefClass::MergeReturningColumn + | NameRefClass::MergeOnColumn => { resolve_merge_column_ptr(binder, root, name_ref).map(|ptr| smallvec![ptr]) } - NameRefClass::MergeWhenColumn => { - resolve_merge_when_column_ptr(binder, root, name_ref).map(|ptr| smallvec![ptr]) - } NameRefClass::MergeQualifiedColumnTable | NameRefClass::MergeReturningQualifiedColumnTable => { resolve_merge_table_name_ptr(binder, name_ref).map(|ptr| smallvec![ptr]) @@ -403,8 +402,7 @@ pub(crate) fn resolve_name_ref( .syntax() .ancestors() .find_map(ast::AlterTable::cast)?; - let relation_name = alter_table.relation_name()?; - let path = relation_name.path()?; + let path = alter_table.relation_name()?.path()?; resolve_column_for_path(binder, root, &path, column_name).map(|ptr| smallvec![ptr]) } NameRefClass::ReindexSchema => { @@ -635,8 +633,7 @@ fn resolve_create_index_column_ptr( .syntax() .ancestors() .find_map(ast::CreateIndex::cast)?; - let relation_name = create_index.relation_name()?; - let path = relation_name.path()?; + let path = create_index.relation_name()?.path()?; resolve_column_for_path(binder, root, &path, column_name) } @@ -762,6 +759,56 @@ fn resolve_select_qualified_column_table_name_ptr( resolve_table_name_ptr(binder, &table_name, &schema, position) } +enum ReturningClauseMatch { + ReturningAlias(ast::Name), + TableAlias(ast::Name), + PseudoTable, + Table, +} + +fn match_table_in_returning_clause( + table_name: &Name, + stmt_table_name: &Name, + alias: Option, + returning_clause: Option, +) -> Option { + // Check `returning with (old as alias, new as alias)` + if let Some(returning_clause) = returning_clause + && let Some(option_list) = returning_clause.returning_option_list() + { + for option in option_list.returning_options() { + if let Some(name) = option.name() + && Name::from_node(&name) == *table_name + { + return Some(ReturningClauseMatch::ReturningAlias(name)); + } + } + } + + if let Some(alias) = alias + && let Some(alias_name) = alias.name() + && Name::from_node(&alias_name) == *table_name + { + return Some(ReturningClauseMatch::TableAlias(alias_name)); + } + + let old_name = Name::from_string("old"); + let new_name = Name::from_string("new"); + if *table_name == old_name || *table_name == new_name { + return Some(ReturningClauseMatch::PseudoTable); + } + + if *stmt_table_name == *table_name { + return Some(ReturningClauseMatch::Table); + } + + None +} + +fn extract_table_schema_from_path(path: &ast::Path) -> Option<(Name, Option)> { + Some((extract_table_name(path)?, extract_schema_name(path))) +} + // TODO: this is similar to resolve_from_item_for_column, maybe we can simplify fn resolve_select_qualified_column_ptr( binder: &Binder, @@ -805,62 +852,54 @@ fn resolve_select_qualified_column_ptr( .ancestors() .find_map(ast::Merge::cast) { - // Handle MERGE statements - let relation_name = merge.relation_name()?; - let path = relation_name.path()?; - let merge_table_name = extract_table_name(&path)?; - - // Check `returning with (old as alias, new as alias)` - if let Some(returning_clause) = merge.returning_clause() - && let Some(option_list) = returning_clause.returning_option_list() - && is_table_name_in_option_list(option_list, &column_table_name) + let found_in_using = if let Some(using_on) = merge.using_on_clause() + && let Some(from_item) = using_on.from_item() { - (merge_table_name, None) - // Check if this is the alias or the table name - } else if let Some(alias) = merge.alias() - && let Some(alias_name) = alias.name() - && Name::from_node(&alias_name) == column_table_name - { - (merge_table_name, extract_schema_name(&path)) - } else if merge_table_name == column_table_name { - (merge_table_name, extract_schema_name(&path)) - } else { - // Try to find in USING clause first - let found_in_using = if let Some(using_on) = merge.using_on_clause() - && let Some(from_item) = using_on.from_item() + if let Some(item_name_ref) = from_item.name_ref() + && let item_name = Name::from_node(&item_name_ref) + && item_name == column_table_name { - if let Some(item_name_ref) = from_item.name_ref() - && let item_name = Name::from_node(&item_name_ref) - && item_name == column_table_name - { - Some((item_name, None)) - } else if let Some(alias) = from_item.alias() - && let Some(alias_name) = alias.name() - && let alias_name = Name::from_node(&alias_name) - && alias_name == column_table_name - { - Some((alias_name, None)) - } else { - None - } + Some((item_name, None)) + } else if let Some(alias) = from_item.alias() + && let Some(alias_name) = alias.name() + && let alias_name = Name::from_node(&alias_name) + && alias_name == column_table_name + { + Some((alias_name, None)) } else { None - }; - - if let Some(result) = found_in_using { - result - } else { - // Fallback: check for OLD and NEW pseudo-tables in MERGE RETURNING clause - let old_name = Name::from_string("old"); - let new_name = Name::from_string("new"); - if column_table_name == old_name || column_table_name == new_name { - // Both OLD and NEW refer to the target table - (merge_table_name, extract_schema_name(&path)) - } else { - return None; - } } + } else { + None + }; + + if let Some(result) = found_in_using { + result + } else { + let path = merge.relation_name()?.path()?; + extract_table_schema_from_path(&path)? } + } else if let Some(insert) = column_name_ref + .syntax() + .ancestors() + .find_map(ast::Insert::cast) + { + let path = insert.path()?; + extract_table_schema_from_path(&path)? + } else if let Some(update) = column_name_ref + .syntax() + .ancestors() + .find_map(ast::Update::cast) + { + let path = update.relation_name()?.path()?; + extract_table_schema_from_path(&path)? + } else if let Some(delete) = column_name_ref + .syntax() + .ancestors() + .find_map(ast::Delete::cast) + { + let path = delete.relation_name()?.path()?; + extract_table_schema_from_path(&path)? } else { let select = column_name_ref .syntax() @@ -984,18 +1023,6 @@ fn resolve_select_qualified_column_ptr( None } -fn is_table_name_in_option_list(option_list: ast::ReturningOptionList, table_name: &Name) -> bool { - for option in option_list.returning_options() { - if let Some(name) = option.name() - && let option_name = Name::from_node(&name) - && option_name == *table_name - { - return true; - } - } - false -} - enum ResolvedTableName { View(ast::CreateView), Table(ast::CreateTableLike), @@ -1232,8 +1259,7 @@ fn resolve_delete_column_ptr( .syntax() .ancestors() .find_map(ast::Delete::cast)?; - let relation_name = delete.relation_name()?; - let path = relation_name.path()?; + let path = delete.relation_name()?.path()?; resolve_column_for_path(binder, root, &path, column_name) } @@ -1315,8 +1341,7 @@ fn resolve_update_column_ptr( } // `update t set a = b` - let relation_name = update.relation_name()?; - let path = relation_name.path()?; + let path = update.relation_name()?.path()?; resolve_column_for_path(binder, root, &path, column_name) } @@ -2718,13 +2743,13 @@ fn extract_type_name_and_schema(ty: &ast::Type) -> Option<(Name, Option) } } -fn resolve_merge_when_column_ptr( +fn resolve_merge_column_ptr( binder: &Binder, root: &SyntaxNode, - table_name_ref: &ast::NameRef, + column_name_ref: &ast::NameRef, ) -> Option { - let column_name = Name::from_node(table_name_ref); - let merge = table_name_ref + let column_name = Name::from_node(column_name_ref); + let merge = column_name_ref .syntax() .ancestors() .find_map(ast::Merge::cast)?; @@ -2732,109 +2757,99 @@ fn resolve_merge_when_column_ptr( // Try resolving in source table first if let Some(using_on) = merge.using_on_clause() && let Some(from_item) = using_on.from_item() - && let Some(ptr) = resolve_from_item_column_ptr(binder, root, &from_item, table_name_ref) + && let Some(ptr) = resolve_from_item_column_ptr(binder, root, &from_item, column_name_ref) { return Some(ptr); } - let relation_name = merge.relation_name()?; - let path = relation_name.path()?; + let path = merge.relation_name()?.path()?; resolve_column_for_path(binder, root, &path, column_name) } -fn resolve_merge_column_ptr( +// TODO: I think we could use trait(s) here to simplify this and have the +// callers pass in the stmt instead of the fields. +fn resolve_table_in_returning_clause( binder: &Binder, - root: &SyntaxNode, - column_name_ref: &ast::NameRef, + table_name_ref: &ast::NameRef, + alias: Option, + path: &ast::Path, + returning_clause: Option, ) -> Option { - let column_name = Name::from_node(column_name_ref); - let merge = column_name_ref - .syntax() - .ancestors() - .find_map(ast::Merge::cast)?; + let table_name = Name::from_node(table_name_ref); + let stmt_table_name = extract_table_name(path)?; - // Try resolving in source table first - if let Some(using_on) = merge.using_on_clause() - && let Some(from_item) = using_on.from_item() - && let Some(ptr) = resolve_from_item_column_ptr(binder, root, &from_item, column_name_ref) - { - return Some(ptr); - } + let matched = + match_table_in_returning_clause(&table_name, &stmt_table_name, alias, returning_clause)?; - let relation_name = merge.relation_name()?; - let path = relation_name.path()?; - resolve_column_for_path(binder, root, &path, column_name) + let schema = extract_schema_name(path); + let position = table_name_ref.syntax().text_range().start(); + + match matched { + ReturningClauseMatch::ReturningAlias(name) => Some(SyntaxNodePtr::new(name.syntax())), + ReturningClauseMatch::TableAlias(alias_name) => { + Some(SyntaxNodePtr::new(alias_name.syntax())) + } + ReturningClauseMatch::PseudoTable => { + resolve_table_name_ptr(binder, &stmt_table_name, &schema, position) + } + ReturningClauseMatch::Table => { + resolve_table_name_ptr(binder, &table_name, &schema, position) + } + } } fn resolve_insert_table_name_ptr( binder: &Binder, table_name_ref: &ast::NameRef, ) -> Option { - let table_name = Name::from_node(table_name_ref); let insert = table_name_ref .syntax() .ancestors() .find_map(ast::Insert::cast)?; - - if let Some(alias) = insert.alias() - && let Some(alias_name) = alias.name() - && Name::from_node(&alias_name) == table_name - { - return Some(SyntaxNodePtr::new(alias_name.syntax())); - } - let path = insert.path()?; - let schema = extract_schema_name(&path); - let position = table_name_ref.syntax().text_range().start(); - resolve_table_name_ptr(binder, &table_name, &schema, position) + resolve_table_in_returning_clause( + binder, + table_name_ref, + insert.alias(), + &path, + insert.returning_clause(), + ) } fn resolve_delete_table_name_ptr( binder: &Binder, table_name_ref: &ast::NameRef, ) -> Option { - let table_name = Name::from_node(table_name_ref); let delete = table_name_ref .syntax() .ancestors() .find_map(ast::Delete::cast)?; - - if let Some(alias) = delete.alias() - && let Some(alias_name) = alias.name() - && Name::from_node(&alias_name) == table_name - { - return Some(SyntaxNodePtr::new(alias_name.syntax())); - } - - let relation_name = delete.relation_name()?; - let path = relation_name.path()?; - let schema = extract_schema_name(&path); - let position = table_name_ref.syntax().text_range().start(); - resolve_table_name_ptr(binder, &table_name, &schema, position) + let path = delete.relation_name()?.path()?; + resolve_table_in_returning_clause( + binder, + table_name_ref, + delete.alias(), + &path, + delete.returning_clause(), + ) } fn resolve_update_table_name_ptr( binder: &Binder, table_name_ref: &ast::NameRef, ) -> Option { - let table_name = Name::from_node(table_name_ref); let update = table_name_ref .syntax() .ancestors() .find_map(ast::Update::cast)?; - - if let Some(alias) = update.alias() - && let Some(alias_name) = alias.name() - && Name::from_node(&alias_name) == table_name - { - return Some(SyntaxNodePtr::new(alias_name.syntax())); - } - - let relation_name = update.relation_name()?; - let path = relation_name.path()?; - let schema = extract_schema_name(&path); - let position = table_name_ref.syntax().text_range().start(); - resolve_table_name_ptr(binder, &table_name, &schema, position) + let path = update.relation_name()?.path()?; + resolve_table_in_returning_clause( + binder, + table_name_ref, + update.alias(), + &path, + update.returning_clause(), + ) } fn resolve_merge_table_name_ptr( @@ -2847,32 +2862,9 @@ fn resolve_merge_table_name_ptr( .ancestors() .find_map(ast::Merge::cast)?; - let relation_name = merge.relation_name()?; - let path = relation_name.path()?; - let merge_table_name = extract_table_name(&path)?; + let path = merge.relation_name()?.path()?; - // Check `returning with (old as alias, new as alias)` - if let Some(returning_clause) = merge.returning_clause() - && let Some(option_list) = returning_clause.returning_option_list() - { - for option in option_list.returning_options() { - if let Some(name) = option.name() - && Name::from_node(&name) == table_name - { - return Some(SyntaxNodePtr::new(name.syntax())); - } - } - } - - // Check target table alias - if let Some(alias) = merge.alias() - && let Some(alias_name) = alias.name() - && Name::from_node(&alias_name) == table_name - { - return Some(SyntaxNodePtr::new(alias_name.syntax())); - } - - // Check USING clause for the source table + // Check USING clause for the source table - MERGE-specific if let Some(using_on) = merge.using_on_clause() && let Some(from_item) = using_on.from_item() { @@ -2892,23 +2884,11 @@ fn resolve_merge_table_name_ptr( } } - if merge_table_name == table_name { - let schema = extract_schema_name(&path); - let position = table_name_ref.syntax().text_range().start(); - return resolve_table_name_ptr(binder, &table_name, &schema, position); - } - - // Check for OLD and NEW pseudo-tables in MERGE RETURNING clause (fallback) - let old_name = Name::from_string("old"); - let new_name = Name::from_string("new"); - if table_name == old_name || table_name == new_name { - // Both OLD and NEW refer to the target table - let schema = extract_schema_name(&path); - let position = table_name_ref.syntax().text_range().start(); - return resolve_table_name_ptr(binder, &merge_table_name, &schema, position); - } - - let schema = extract_schema_name(&path); - let position = table_name_ref.syntax().text_range().start(); - resolve_table_name_ptr(binder, &table_name, &schema, position) + resolve_table_in_returning_clause( + binder, + table_name_ref, + merge.alias(), + &path, + merge.returning_clause(), + ) }