diff --git a/crates/squawk_ide/src/goto_definition.rs b/crates/squawk_ide/src/goto_definition.rs index 9e70589b..a10de07e 100644 --- a/crates/squawk_ide/src/goto_definition.rs +++ b/crates/squawk_ide/src/goto_definition.rs @@ -5951,4 +5951,90 @@ merge into old ╰╴ ─ 1. source "); } + + #[test] + fn goto_merge_returning_with_aliases_before_table() { + assert_snapshot!(goto(" +create table x(a int, b int); +create table y(c int, d int); +merge into x + using y on true + when matched then do nothing + returning + with (old as before, new as after) + before$0.a, after.a; +" + ), @r" + ╭▸ + 8 │ with (old as before, new as after) + │ ────── 2. destination + 9 │ before.a, after.a; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_merge_returning_with_aliases_before_column() { + assert_snapshot!(goto(" +create table x(a int, b int); +create table y(c int, d int); +merge into x + using y on true + when matched then do nothing + returning + with (old as before, new as after) + before.a$0, after.a; +" + ), @r" + ╭▸ + 2 │ create table x(a int, b int); + │ ─ 2. destination + ‡ + 9 │ before.a, after.a; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_merge_returning_with_aliases_after_table() { + assert_snapshot!(goto(" +create table x(a int, b int); +create table y(c int, d int); +merge into x + using y on true + when matched then do nothing + returning + with (old as before, new as after) + before.a, after$0.a; +" + ), @r" + ╭▸ + 8 │ with (old as before, new as after) + │ ───── 2. destination + 9 │ before.a, after.a; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_merge_returning_with_aliases_after_column() { + assert_snapshot!(goto(" +create table x(a int, b int); +create table y(c int, d int); +merge into x + using y on true + when matched then do nothing + returning + with (old as before, new as after) + before.a, after.a$0; +" + ), @r" + ╭▸ + 2 │ create table x(a int, b int); + │ ─ 2. destination + ‡ + 9 │ before.a, after.a; + ╰╴ ─ 1. source + "); + } } diff --git a/crates/squawk_ide/src/resolve.rs b/crates/squawk_ide/src/resolve.rs index 071dc467..8b58f765 100644 --- a/crates/squawk_ide/src/resolve.rs +++ b/crates/squawk_ide/src/resolve.rs @@ -810,8 +810,14 @@ fn resolve_select_qualified_column_ptr( let path = relation_name.path()?; let merge_table_name = extract_table_name(&path)?; - // Check if this is the alias or the table name - if let Some(alias) = merge.alias() + // Check `returning with (old as alias, new as alias)` + if let Some(returning_clause) = merge.returning_clause() + && let Some(option_list) = returning_clause.returning_option_list() + && is_table_name_in_option_list(option_list, &column_table_name) + { + (merge_table_name, None) + // Check if this is the alias or the table name + } else if let Some(alias) = merge.alias() && let Some(alias_name) = alias.name() && Name::from_node(&alias_name) == column_table_name { @@ -978,6 +984,18 @@ fn resolve_select_qualified_column_ptr( None } +fn is_table_name_in_option_list(option_list: ast::ReturningOptionList, table_name: &Name) -> bool { + for option in option_list.returning_options() { + if let Some(name) = option.name() + && let option_name = Name::from_node(&name) + && option_name == *table_name + { + return true; + } + } + false +} + enum ResolvedTableName { View(ast::CreateView), Table(ast::CreateTableLike), @@ -2833,6 +2851,19 @@ fn resolve_merge_table_name_ptr( let path = relation_name.path()?; let merge_table_name = extract_table_name(&path)?; + // Check `returning with (old as alias, new as alias)` + if let Some(returning_clause) = merge.returning_clause() + && let Some(option_list) = returning_clause.returning_option_list() + { + for option in option_list.returning_options() { + if let Some(name) = option.name() + && Name::from_node(&name) == table_name + { + return Some(SyntaxNodePtr::new(name.syntax())); + } + } + } + // Check target table alias if let Some(alias) = merge.alias() && let Some(alias_name) = alias.name()