diff --git a/crates/squawk_ide/src/goto_definition.rs b/crates/squawk_ide/src/goto_definition.rs index 50fe5ecc..10d95904 100644 --- a/crates/squawk_ide/src/goto_definition.rs +++ b/crates/squawk_ide/src/goto_definition.rs @@ -3686,6 +3686,178 @@ select foo$0(); "); } + #[test] + fn goto_select_column_from_function_return_table() { + assert_snapshot!(goto(r#" +create function dup(int) returns table(f1 int, f2 text) + as '' + language sql; + +select f1$0 from dup(42); +"#), @r" + ╭▸ + 2 │ create function dup(int) returns table(f1 int, f2 text) + │ ── 2. destination + ‡ + 6 │ select f1 from dup(42); + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_select_column_from_function_return_table_with_schema() { + assert_snapshot!(goto(r#" +create function myschema.dup(int) returns table(f1 int, f2 text) + as '' + language sql; +create function otherschema.dup(int) returns table(f1 int, f2 text) + as '' + language sql; + +select f1$0 from myschema.dup(42); +"#), @r" + ╭▸ + 2 │ create function myschema.dup(int) returns table(f1 int, f2 text) + │ ── 2. destination + ‡ + 9 │ select f1 from myschema.dup(42); + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_select_column_from_function_return_table_paren() { + assert_snapshot!(goto(r#" +create function dup(int) returns table(f1 int, f2 text) + as '' + language sql; + +select (dup(42)).f2$0; +"#), @r" + ╭▸ + 2 │ create function dup(int) returns table(f1 int, f2 text) + │ ── 2. destination + ‡ + 6 │ select (dup(42)).f2; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_select_column_from_function_return_table_qualified() { + assert_snapshot!(goto(r#" +create function dup(int) returns table(f1 int, f2 text) + as '' + language sql; + +select dup.f1$0 from dup(42); +"#), @r" + ╭▸ + 2 │ create function dup(int) returns table(f1 int, f2 text) + │ ── 2. destination + ‡ + 6 │ select dup.f1 from dup(42); + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_select_column_from_function_return_table_qualified_function_name() { + assert_snapshot!(goto(r#" +create function dup(int) returns table(f1 int, f2 text) + as '' + language sql; + +select dup$0.f1 from dup(42); +"#), @r" + ╭▸ + 2 │ create function dup(int) returns table(f1 int, f2 text) + │ ─── 2. destination + ‡ + 6 │ select dup.f1 from dup(42); + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_select_column_from_function_return_table_qualified_function_name_with_alias() { + assert_snapshot!(goto(r#" +create function dup(int) returns table(f1 int, f2 text) + as '' + language sql; + +select dup$0.f2 from dup(42) as dup; +"#), @r" + ╭▸ + 6 │ select dup.f2 from dup(42) as dup; + ╰╴ ─ 1. source ─── 2. destination + "); + } + + #[test] + fn goto_select_column_from_function_return_table_alias_list() { + assert_snapshot!(goto(r#" +create function dup(int) returns table(f1 int, f2 text) + as '' + language sql; + +select a$0 from dup(42) t(a, b); +"#), @r" + ╭▸ + 6 │ select a from dup(42) t(a, b); + ╰╴ ─ 1. source ─ 2. destination + "); + } + + #[test] + fn goto_select_column_from_function_return_table_alias_list_qualified_partial() { + assert_snapshot!(goto(r#" +create function dup(int) returns table(f1 int, f2 text) + as '' + language sql; + +select u.f2$0 from dup(42) as u(x); +"#), @r" + ╭▸ + 2 │ create function dup(int) returns table(f1 int, f2 text) + │ ── 2. destination + ‡ + 6 │ select u.f2 from dup(42) as u(x); + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_select_column_from_function_return_table_alias_list_unqualified_partial() { + assert_snapshot!(goto(r#" +create function dup(int) returns table(f1 int, f2 text) + as '' + language sql; + +select f2$0 from dup(42) as u(x); +"#), @r" + ╭▸ + 2 │ create function dup(int) returns table(f1 int, f2 text) + │ ── 2. destination + ‡ + 6 │ select f2 from dup(42) as u(x); + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_select_column_from_function_return_table_alias_list_unqualified_not_found() { + goto_not_found( + r#" +create function dup(int) returns table(f1 int, f2 text) + as '' + language sql; + +select f2$0 from dup(42) as u(x, y); +"#, + ); + } + #[test] fn goto_select_aggregate_call() { assert_snapshot!(goto(" diff --git a/crates/squawk_ide/src/resolve.rs b/crates/squawk_ide/src/resolve.rs index 69a5e906..d1932ce9 100644 --- a/crates/squawk_ide/src/resolve.rs +++ b/crates/squawk_ide/src/resolve.rs @@ -835,11 +835,6 @@ fn resolve_select_qualified_column_table_name_ptr( None }; - if let Some(schema) = explicit_schema { - let position = table_name_ref.syntax().text_range().start(); - return resolve_table_name_ptr(binder, &table_name, &Some(schema), position); - } - let select = table_name_ref .syntax() .ancestors() @@ -853,6 +848,21 @@ fn resolve_select_qualified_column_table_name_ptr( return Some(SyntaxNodePtr::new(alias_name.syntax())); } + if let Some(call_expr) = from_item.call_expr() + && let Some((function_name, function_schema)) = + function_name_and_schema_from_call_expr(&call_expr) + && function_name == table_name + && function_schema == explicit_schema + { + let position = table_name_ref.syntax().text_range().start(); + return resolve_function(binder, &function_name, &function_schema, None, position); + } + + if let Some(schema) = explicit_schema { + let position = table_name_ref.syntax().text_range().start(); + return resolve_table_name_ptr(binder, &table_name, &Some(schema), position); + } + let (table_name, schema) = if let Some(name_ref_node) = from_item.name_ref() { if let Some(cte_ptr) = resolve_cte_table(table_name_ref, &table_name) { return Some(cte_ptr); @@ -1031,6 +1041,19 @@ fn resolve_select_qualified_column_ptr( let from_clause = select.from_clause()?; let from_item = find_from_item_in_from_clause(&from_clause, &column_table_name)?; + if let Some(call_expr) = from_item.call_expr() + && let Some(ptr) = resolve_column_from_call_expr_return_table( + binder, + root, + &call_expr, + column_name_ref, + &column_name, + 0, + ) + { + return Some(ptr); + } + // `from t as u` // `from t as u(a, b, c)` if let Some(alias) = from_item.alias() @@ -1255,7 +1278,7 @@ fn resolve_from_item_column_ptr( ); } - if let Some(alias) = from_item.alias() + let alias_len = if let Some(alias) = from_item.alias() && let Some(column_list) = alias.column_list() { for col in column_list.columns() { @@ -1265,6 +1288,22 @@ fn resolve_from_item_column_ptr( return Some(SyntaxNodePtr::new(col_name.syntax())); } } + column_list.columns().count() + } else { + 0 + }; + + if let Some(call_expr) = from_item.call_expr() + && let Some(ptr) = resolve_column_from_call_expr_return_table( + binder, + root, + &call_expr, + column_name_ref, + &column_name, + alias_len, + ) + { + return Some(ptr); } let (table_name, schema) = table_and_schema_from_from_item(from_item)?; @@ -1654,6 +1693,13 @@ fn is_from_item_match(from_item: &ast::FromItem, qualifier: &Name) -> bool { return true; } + if let Some(call_expr) = from_item.call_expr() + && let Some((function_name, _schema)) = function_name_and_schema_from_call_expr(&call_expr) + && function_name == *qualifier + { + return true; + } + if let Some(name_ref) = from_item.name_ref() && Name::from_node(&name_ref) == *qualifier { @@ -2943,6 +2989,19 @@ fn resolve_column_from_paren_expr( return None; } + if let Some(ast::Expr::CallExpr(call_expr)) = paren_expr.expr() + && let Some(ptr) = resolve_column_from_call_expr_return_table( + binder, + root, + &call_expr, + name_ref, + column_name, + 0, + ) + { + return Some(ptr); + } + if let Some(ast::Expr::ParenExpr(paren_expr)) = paren_expr.expr() { return resolve_column_from_paren_expr(binder, root, &paren_expr, name_ref, column_name); } @@ -2964,6 +3023,45 @@ fn resolve_column_from_paren_expr( None } +fn resolve_column_from_call_expr_return_table( + binder: &Binder, + root: &SyntaxNode, + call_expr: &ast::CallExpr, + name_ref: &ast::NameRef, + column_name: &Name, + min_index: usize, +) -> Option { + let position = name_ref.syntax().text_range().start(); + let function_ptr = resolve_function_ptr_from_call_expr(binder, call_expr, position)?; + find_column_in_function_return_table_min_index(root, function_ptr, column_name, min_index) +} + +fn resolve_function_ptr_from_call_expr( + binder: &Binder, + call_expr: &ast::CallExpr, + position: TextSize, +) -> Option { + let (function_name, schema) = function_name_and_schema_from_call_expr(call_expr)?; + resolve_function(binder, &function_name, &schema, None, position) +} + +fn function_name_and_schema_from_call_expr( + call_expr: &ast::CallExpr, +) -> Option<(Name, Option)> { + match call_expr.expr()? { + ast::Expr::NameRef(name_ref) => Some((Name::from_node(&name_ref), None)), + ast::Expr::FieldExpr(field_expr) => { + let function_name = Name::from_node(&field_expr.field()?); + let ast::Expr::NameRef(schema_name_ref) = field_expr.base()? else { + return None; + }; + let schema = Schema(Name::from_node(&schema_name_ref)); + Some((function_name, Some(schema))) + } + _ => None, + } +} + pub(crate) fn resolve_table_info(binder: &Binder, path: &ast::Path) -> Option<(Schema, String)> { resolve_symbol_info(binder, path, SymbolKind::Table) } @@ -3433,7 +3531,7 @@ fn extract_param_signature(node: &impl ast::HasParamList) -> Option> { (!params.is_empty()).then_some(params) } -fn unwrap_paren_expr(expr: ast::Expr) -> Option { +fn unwrap_paren_expr_to_name_ref(expr: ast::Expr) -> Option { let mut current = expr; for _ in 0..10_000 { match current { @@ -3459,7 +3557,14 @@ fn resolve_composite_type_field_ptr( .and_then(ast::FieldExpr::cast)?; let base = field_expr.base()?; - let base_name_ref = unwrap_paren_expr(base)?; + if let ast::Expr::ParenExpr(ref paren_expr) = base + && let Some(ptr) = + resolve_column_from_paren_expr(binder, root, paren_expr, field_name_ref, &field_name) + { + return Some(ptr); + } + + let base_name_ref = unwrap_paren_expr_to_name_ref(base.clone())?; let column_ptr = resolve_select_column_ptr(binder, root, &base_name_ref)?; let column_node = column_ptr.to_node(root); @@ -3746,3 +3851,29 @@ fn find_param_in_func_def( None } + +fn find_column_in_function_return_table_min_index( + root: &SyntaxNode, + function_ptr: SyntaxNodePtr, + column_name: &Name, + min_index: usize, +) -> Option { + let function_node = function_ptr.to_node(root); + let create_function = function_node + .ancestors() + .find_map(ast::CreateFunction::cast)?; + let mut index = 0usize; + for arg in create_function.ret_type()?.table_arg_list()?.args() { + if let ast::TableArg::Column(column) = arg { + if let Some(name) = column.name() + && Name::from_node(&name) == *column_name + && index >= min_index + { + return Some(SyntaxNodePtr::new(name.syntax())); + } + index += 1; + } + } + + None +} diff --git a/crates/squawk_parser/src/grammar.rs b/crates/squawk_parser/src/grammar.rs index 84a3fa46..c65b41ed 100644 --- a/crates/squawk_parser/src/grammar.rs +++ b/crates/squawk_parser/src/grammar.rs @@ -388,7 +388,7 @@ fn substring_fn(p: &mut Parser<'_>) -> CompletedMarker { _ if p.eat(COMMA) => { opt_expr_list(p); } - _ => {} + _ => (), } p.expect(R_PAREN); let m = m.complete(p, SUBSTRING_FN).precede(p);