Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 172 additions & 0 deletions crates/squawk_ide/src/goto_definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3686,6 +3686,178 @@ select foo$0();
");
}

#[test]
fn goto_select_column_from_function_return_table() {
assert_snapshot!(goto(r#"
create function dup(int) returns table(f1 int, f2 text)
as ''
language sql;

select f1$0 from dup(42);
"#), @r"
╭▸
2 │ create function dup(int) returns table(f1 int, f2 text)
│ ── 2. destination
6 │ select f1 from dup(42);
╰╴ ─ 1. source
");
}

#[test]
fn goto_select_column_from_function_return_table_with_schema() {
assert_snapshot!(goto(r#"
create function myschema.dup(int) returns table(f1 int, f2 text)
as ''
language sql;
create function otherschema.dup(int) returns table(f1 int, f2 text)
as ''
language sql;

select f1$0 from myschema.dup(42);
"#), @r"
╭▸
2 │ create function myschema.dup(int) returns table(f1 int, f2 text)
│ ── 2. destination
9 │ select f1 from myschema.dup(42);
╰╴ ─ 1. source
");
}

#[test]
fn goto_select_column_from_function_return_table_paren() {
assert_snapshot!(goto(r#"
create function dup(int) returns table(f1 int, f2 text)
as ''
language sql;

select (dup(42)).f2$0;
"#), @r"
╭▸
2 │ create function dup(int) returns table(f1 int, f2 text)
│ ── 2. destination
6 │ select (dup(42)).f2;
╰╴ ─ 1. source
");
}

#[test]
fn goto_select_column_from_function_return_table_qualified() {
assert_snapshot!(goto(r#"
create function dup(int) returns table(f1 int, f2 text)
as ''
language sql;

select dup.f1$0 from dup(42);
"#), @r"
╭▸
2 │ create function dup(int) returns table(f1 int, f2 text)
│ ── 2. destination
6 │ select dup.f1 from dup(42);
╰╴ ─ 1. source
");
}

#[test]
fn goto_select_column_from_function_return_table_qualified_function_name() {
assert_snapshot!(goto(r#"
create function dup(int) returns table(f1 int, f2 text)
as ''
language sql;

select dup$0.f1 from dup(42);
"#), @r"
╭▸
2 │ create function dup(int) returns table(f1 int, f2 text)
│ ─── 2. destination
6 │ select dup.f1 from dup(42);
╰╴ ─ 1. source
");
}

#[test]
fn goto_select_column_from_function_return_table_qualified_function_name_with_alias() {
assert_snapshot!(goto(r#"
create function dup(int) returns table(f1 int, f2 text)
as ''
language sql;

select dup$0.f2 from dup(42) as dup;
"#), @r"
╭▸
6 │ select dup.f2 from dup(42) as dup;
╰╴ ─ 1. source ─── 2. destination
");
}

#[test]
fn goto_select_column_from_function_return_table_alias_list() {
assert_snapshot!(goto(r#"
create function dup(int) returns table(f1 int, f2 text)
as ''
language sql;

select a$0 from dup(42) t(a, b);
"#), @r"
╭▸
6 │ select a from dup(42) t(a, b);
╰╴ ─ 1. source ─ 2. destination
");
}

#[test]
fn goto_select_column_from_function_return_table_alias_list_qualified_partial() {
assert_snapshot!(goto(r#"
create function dup(int) returns table(f1 int, f2 text)
as ''
language sql;

select u.f2$0 from dup(42) as u(x);
"#), @r"
╭▸
2 │ create function dup(int) returns table(f1 int, f2 text)
│ ── 2. destination
6 │ select u.f2 from dup(42) as u(x);
╰╴ ─ 1. source
");
}

#[test]
fn goto_select_column_from_function_return_table_alias_list_unqualified_partial() {
assert_snapshot!(goto(r#"
create function dup(int) returns table(f1 int, f2 text)
as ''
language sql;

select f2$0 from dup(42) as u(x);
"#), @r"
╭▸
2 │ create function dup(int) returns table(f1 int, f2 text)
│ ── 2. destination
6 │ select f2 from dup(42) as u(x);
╰╴ ─ 1. source
");
}

#[test]
fn goto_select_column_from_function_return_table_alias_list_unqualified_not_found() {
goto_not_found(
r#"
create function dup(int) returns table(f1 int, f2 text)
as ''
language sql;

select f2$0 from dup(42) as u(x, y);
"#,
);
}

#[test]
fn goto_select_aggregate_call() {
assert_snapshot!(goto("
Expand Down
147 changes: 139 additions & 8 deletions crates/squawk_ide/src/resolve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -835,11 +835,6 @@ fn resolve_select_qualified_column_table_name_ptr(
None
};

if let Some(schema) = explicit_schema {
let position = table_name_ref.syntax().text_range().start();
return resolve_table_name_ptr(binder, &table_name, &Some(schema), position);
}

let select = table_name_ref
.syntax()
.ancestors()
Expand All @@ -853,6 +848,21 @@ fn resolve_select_qualified_column_table_name_ptr(
return Some(SyntaxNodePtr::new(alias_name.syntax()));
}

if let Some(call_expr) = from_item.call_expr()
&& let Some((function_name, function_schema)) =
function_name_and_schema_from_call_expr(&call_expr)
&& function_name == table_name
&& function_schema == explicit_schema
{
let position = table_name_ref.syntax().text_range().start();
return resolve_function(binder, &function_name, &function_schema, None, position);
}

if let Some(schema) = explicit_schema {
let position = table_name_ref.syntax().text_range().start();
return resolve_table_name_ptr(binder, &table_name, &Some(schema), position);
}

let (table_name, schema) = if let Some(name_ref_node) = from_item.name_ref() {
if let Some(cte_ptr) = resolve_cte_table(table_name_ref, &table_name) {
return Some(cte_ptr);
Expand Down Expand Up @@ -1031,6 +1041,19 @@ fn resolve_select_qualified_column_ptr(
let from_clause = select.from_clause()?;
let from_item = find_from_item_in_from_clause(&from_clause, &column_table_name)?;

if let Some(call_expr) = from_item.call_expr()
&& let Some(ptr) = resolve_column_from_call_expr_return_table(
binder,
root,
&call_expr,
column_name_ref,
&column_name,
0,
)
{
return Some(ptr);
}

// `from t as u`
// `from t as u(a, b, c)`
if let Some(alias) = from_item.alias()
Expand Down Expand Up @@ -1255,7 +1278,7 @@ fn resolve_from_item_column_ptr(
);
}

if let Some(alias) = from_item.alias()
let alias_len = if let Some(alias) = from_item.alias()
&& let Some(column_list) = alias.column_list()
{
for col in column_list.columns() {
Expand All @@ -1265,6 +1288,22 @@ fn resolve_from_item_column_ptr(
return Some(SyntaxNodePtr::new(col_name.syntax()));
}
}
column_list.columns().count()
} else {
0
};

if let Some(call_expr) = from_item.call_expr()
&& let Some(ptr) = resolve_column_from_call_expr_return_table(
binder,
root,
&call_expr,
column_name_ref,
&column_name,
alias_len,
)
{
return Some(ptr);
}

let (table_name, schema) = table_and_schema_from_from_item(from_item)?;
Expand Down Expand Up @@ -1654,6 +1693,13 @@ fn is_from_item_match(from_item: &ast::FromItem, qualifier: &Name) -> bool {
return true;
}

if let Some(call_expr) = from_item.call_expr()
&& let Some((function_name, _schema)) = function_name_and_schema_from_call_expr(&call_expr)
&& function_name == *qualifier
{
return true;
}

if let Some(name_ref) = from_item.name_ref()
&& Name::from_node(&name_ref) == *qualifier
{
Expand Down Expand Up @@ -2943,6 +2989,19 @@ fn resolve_column_from_paren_expr(
return None;
}

if let Some(ast::Expr::CallExpr(call_expr)) = paren_expr.expr()
&& let Some(ptr) = resolve_column_from_call_expr_return_table(
binder,
root,
&call_expr,
name_ref,
column_name,
0,
)
{
return Some(ptr);
}

if let Some(ast::Expr::ParenExpr(paren_expr)) = paren_expr.expr() {
return resolve_column_from_paren_expr(binder, root, &paren_expr, name_ref, column_name);
}
Expand All @@ -2964,6 +3023,45 @@ fn resolve_column_from_paren_expr(
None
}

fn resolve_column_from_call_expr_return_table(
binder: &Binder,
root: &SyntaxNode,
call_expr: &ast::CallExpr,
name_ref: &ast::NameRef,
column_name: &Name,
min_index: usize,
) -> Option<SyntaxNodePtr> {
let position = name_ref.syntax().text_range().start();
let function_ptr = resolve_function_ptr_from_call_expr(binder, call_expr, position)?;
find_column_in_function_return_table_min_index(root, function_ptr, column_name, min_index)
}

fn resolve_function_ptr_from_call_expr(
binder: &Binder,
call_expr: &ast::CallExpr,
position: TextSize,
) -> Option<SyntaxNodePtr> {
let (function_name, schema) = function_name_and_schema_from_call_expr(call_expr)?;
resolve_function(binder, &function_name, &schema, None, position)
}

fn function_name_and_schema_from_call_expr(
call_expr: &ast::CallExpr,
) -> Option<(Name, Option<Schema>)> {
match call_expr.expr()? {
ast::Expr::NameRef(name_ref) => Some((Name::from_node(&name_ref), None)),
ast::Expr::FieldExpr(field_expr) => {
let function_name = Name::from_node(&field_expr.field()?);
let ast::Expr::NameRef(schema_name_ref) = field_expr.base()? else {
return None;
};
let schema = Schema(Name::from_node(&schema_name_ref));
Some((function_name, Some(schema)))
}
_ => None,
}
}

pub(crate) fn resolve_table_info(binder: &Binder, path: &ast::Path) -> Option<(Schema, String)> {
resolve_symbol_info(binder, path, SymbolKind::Table)
}
Expand Down Expand Up @@ -3433,7 +3531,7 @@ fn extract_param_signature(node: &impl ast::HasParamList) -> Option<Vec<Name>> {
(!params.is_empty()).then_some(params)
}

fn unwrap_paren_expr(expr: ast::Expr) -> Option<ast::NameRef> {
fn unwrap_paren_expr_to_name_ref(expr: ast::Expr) -> Option<ast::NameRef> {
let mut current = expr;
for _ in 0..10_000 {
match current {
Expand All @@ -3459,7 +3557,14 @@ fn resolve_composite_type_field_ptr(
.and_then(ast::FieldExpr::cast)?;
let base = field_expr.base()?;

let base_name_ref = unwrap_paren_expr(base)?;
if let ast::Expr::ParenExpr(ref paren_expr) = base
&& let Some(ptr) =
resolve_column_from_paren_expr(binder, root, paren_expr, field_name_ref, &field_name)
{
return Some(ptr);
}

let base_name_ref = unwrap_paren_expr_to_name_ref(base.clone())?;

let column_ptr = resolve_select_column_ptr(binder, root, &base_name_ref)?;
let column_node = column_ptr.to_node(root);
Expand Down Expand Up @@ -3746,3 +3851,29 @@ fn find_param_in_func_def(

None
}

fn find_column_in_function_return_table_min_index(
root: &SyntaxNode,
function_ptr: SyntaxNodePtr,
column_name: &Name,
min_index: usize,
) -> Option<SyntaxNodePtr> {
let function_node = function_ptr.to_node(root);
let create_function = function_node
.ancestors()
.find_map(ast::CreateFunction::cast)?;
let mut index = 0usize;
for arg in create_function.ret_type()?.table_arg_list()?.args() {
if let ast::TableArg::Column(column) = arg {
if let Some(name) = column.name()
&& Name::from_node(&name) == *column_name
&& index >= min_index
{
return Some(SyntaxNodePtr::new(name.syntax()));
}
index += 1;
}
}

None
}
Loading
Loading