From 2b12b679f728cb898157230a352eadac465e8962 Mon Sep 17 00:00:00 2001 From: Steve Dignam Date: Sun, 11 Jan 2026 21:56:24 -0700 Subject: [PATCH] ide: goto def on named params & special functions --- crates/squawk_ide/src/classify.rs | 61 ++++++++++- crates/squawk_ide/src/goto_definition.rs | 128 +++++++++++++++++++++++ crates/squawk_ide/src/hover.rs | 113 +++++++++++++++++++- crates/squawk_ide/src/resolve.rs | 82 +++++++++++++++ 4 files changed, 381 insertions(+), 3 deletions(-) diff --git a/crates/squawk_ide/src/classify.rs b/crates/squawk_ide/src/classify.rs index 1ad6191f..e6b8e297 100644 --- a/crates/squawk_ide/src/classify.rs +++ b/crates/squawk_ide/src/classify.rs @@ -1,4 +1,7 @@ -use squawk_syntax::ast::{self, AstNode}; +use squawk_syntax::{ + SyntaxKind, + ast::{self, AstNode}, +}; #[derive(Debug)] pub(crate) enum NameRefClass { @@ -88,6 +91,39 @@ pub(crate) enum NameRefClass { ReindexDatabase, ReindexSystem, AttachPartition, + NamedArgParameter, +} + +fn is_special_fn(kind: SyntaxKind) -> bool { + matches!( + kind, + SyntaxKind::EXTRACT_FN + | SyntaxKind::JSON_EXISTS_FN + | SyntaxKind::JSON_ARRAY_FN + | SyntaxKind::JSON_OBJECT_FN + | SyntaxKind::JSON_OBJECT_AGG_FN + | SyntaxKind::JSON_ARRAY_AGG_FN + | SyntaxKind::JSON_QUERY_FN + | SyntaxKind::JSON_SCALAR_FN + | SyntaxKind::JSON_SERIALIZE_FN + | SyntaxKind::JSON_VALUE_FN + | SyntaxKind::JSON_FN + | SyntaxKind::SUBSTRING_FN + | SyntaxKind::POSITION_FN + | SyntaxKind::OVERLAY_FN + | SyntaxKind::TRIM_FN + | SyntaxKind::XML_ROOT_FN + | SyntaxKind::XML_SERIALIZE_FN + | SyntaxKind::XML_ELEMENT_FN + | SyntaxKind::XML_FOREST_FN + | SyntaxKind::XML_EXISTS_FN + | SyntaxKind::XML_PARSE_FN + | SyntaxKind::XML_PI_FN + | SyntaxKind::SOME_FN + | SyntaxKind::ANY_FN + | SyntaxKind::ALL_FN + | SyntaxKind::EXISTS_FN + ) } pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option { @@ -106,6 +142,7 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option let mut in_using_clause = false; let mut in_returning_clause = false; let mut in_when_clause = false; + let mut in_special_sql_fn = false; // TODO: can we combine this if and the one that follows? if let Some(parent) = name_ref.syntax().parent() @@ -530,6 +567,12 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option if ast::PartitionOf::can_cast(ancestor.kind()) { return Some(NameRefClass::PartitionOfTable); } + if is_special_fn(ancestor.kind()) { + in_special_sql_fn = true; + } + if ast::NamedArg::can_cast(ancestor.kind()) { + return Some(NameRefClass::NamedArgParameter); + } if ast::ArgList::can_cast(ancestor.kind()) { in_arg_list = true; } @@ -546,7 +589,7 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option in_from_clause = true; } if ast::Select::can_cast(ancestor.kind()) { - if in_call_expr && !in_arg_list { + if in_call_expr && !in_arg_list && !in_special_sql_fn { return Some(NameRefClass::SelectFunctionCall); } if in_from_clause && !in_on_clause { @@ -743,3 +786,17 @@ pub(crate) fn classify_name(name: &ast::Name) -> Option { None } + +#[test] +fn special_function() { + for kind in (0..SyntaxKind::__LAST as u16) + .map(SyntaxKind::from) + .filter(|kind| format!("{:?}", kind).ends_with("_FN")) + { + assert!( + is_special_fn(kind), + "unhandled special function kind: {:?}. Please update is_special_fn", + kind + ) + } +} diff --git a/crates/squawk_ide/src/goto_definition.rs b/crates/squawk_ide/src/goto_definition.rs index 5914cc52..bf20c509 100644 --- a/crates/squawk_ide/src/goto_definition.rs +++ b/crates/squawk_ide/src/goto_definition.rs @@ -6693,4 +6693,132 @@ returning x, u.y$0; ╰╴ ─ 1. source "); } + + #[test] + fn goto_overlay_with_cte_column() { + assert_snapshot!(goto(" +with t as ( + select '1' a, '2' b, 3 start +) +select overlay(a placing b$0 from start) from t; + "), @r" + ╭▸ + 3 │ select '1' a, '2' b, 3 start + │ ─ 2. destination + 4 │ ) + 5 │ select overlay(a placing b from start) from t; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_overlay_with_cte_column_first_arg() { + assert_snapshot!(goto(" +with t as ( + select '1' a, '2' b, 3 start +) +select overlay(a$0 placing b from start) from t; + "), @r" + ╭▸ + 3 │ select '1' a, '2' b, 3 start + │ ─ 2. destination + 4 │ ) + 5 │ select overlay(a placing b from start) from t; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_overlay_with_cte_column_from_arg() { + assert_snapshot!(goto(" +with t as ( + select '1' a, '2' b, 3 start +) +select overlay(a placing b from start$0) from t; + "), @r" + ╭▸ + 3 │ select '1' a, '2' b, 3 start + │ ───── 2. destination + 4 │ ) + 5 │ select overlay(a placing b from start) from t; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_named_arg_to_param() { + assert_snapshot!(goto(" +create function foo(bar_param int) returns int as 'select 1' language sql; +select foo(bar_param$0 := 5); +"), @r" + ╭▸ + 2 │ create function foo(bar_param int) returns int as 'select 1' language sql; + │ ───────── 2. destination + 3 │ select foo(bar_param := 5); + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_named_arg_schema_qualified() { + assert_snapshot!(goto(" +create schema s; +create function s.foo(my_param int) returns int as 'select 1' language sql; +select s.foo(my_param$0 := 10); +"), @r" + ╭▸ + 3 │ create function s.foo(my_param int) returns int as 'select 1' language sql; + │ ──────── 2. destination + 4 │ select s.foo(my_param := 10); + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_named_arg_multiple_params() { + assert_snapshot!(goto(" +create function foo(a int, b int, c int) returns int as 'select 1' language sql; +select foo(b$0 := 2, a := 1); +"), @r" + ╭▸ + 2 │ create function foo(a int, b int, c int) returns int as 'select 1' language sql; + │ ─ 2. destination + 3 │ select foo(b := 2, a := 1); + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_named_arg_procedure() { + assert_snapshot!(goto(" +create procedure proc(param_x int) as 'select 1' language sql; +call proc(param_x$0 := 42); +"), @r" + ╭▸ + 2 │ create procedure proc(param_x int) as 'select 1' language sql; + │ ─────── 2. destination + 3 │ call proc(param_x := 42); + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_named_arg_not_found_unnamed_param() { + goto_not_found( + " +create function foo(int) returns int as 'select 1' language sql; +select foo(bar$0 := 5); +", + ); + } + + #[test] + fn goto_named_arg_not_found_wrong_name() { + goto_not_found( + " +create function foo(correct_param int) returns int as 'select 1' language sql; +select foo(wrong_param$0 := 5); +", + ); + } } diff --git a/crates/squawk_ide/src/hover.rs b/crates/squawk_ide/src/hover.rs index 27ba878e..548236b6 100644 --- a/crates/squawk_ide/src/hover.rs +++ b/crates/squawk_ide/src/hover.rs @@ -2,7 +2,10 @@ use crate::classify::{NameClass, NameRefClass, classify_name, classify_name_ref} use crate::column_name::ColumnName; use crate::offsets::token_from_offset; use crate::resolve; -use crate::{binder, symbols::Name}; +use crate::{ + binder, + symbols::{Name, Schema}, +}; use rowan::TextSize; use squawk_syntax::SyntaxNode; use squawk_syntax::{ @@ -160,6 +163,9 @@ pub fn hover(file: &ast::SourceFile, offset: TextSize) -> Option { | NameRefClass::ReindexSchema => { return hover_schema(root, &name_ref, &binder); } + NameRefClass::NamedArgParameter => { + return hover_named_arg_parameter(root, &name_ref, &binder); + } } } @@ -998,6 +1004,71 @@ fn hover_function( format_create_function(&create_function, binder) } +fn hover_named_arg_parameter( + root: &SyntaxNode, + name_ref: &ast::NameRef, + binder: &binder::Binder, +) -> Option { + let param_ptr = resolve::resolve_name_ref(binder, root, name_ref)? + .into_iter() + .next()?; + let param_name_node = param_ptr.to_node(root); + let param = param_name_node.ancestors().find_map(ast::Param::cast)?; + let param_name = param.name().map(|name| Name::from_node(&name))?; + let param_type = param.ty().map(|ty| ty.syntax().text().to_string()); + + for ancestor in param_name_node.ancestors() { + if let Some(create_function) = ast::CreateFunction::cast(ancestor.clone()) { + let path = create_function.path()?; + let (schema, function_name) = resolve::resolve_function_info(binder, &path)?; + return Some(format_param_hover( + schema, + function_name, + param_name, + param_type, + )); + } + if let Some(create_procedure) = ast::CreateProcedure::cast(ancestor.clone()) { + let path = create_procedure.path()?; + let (schema, procedure_name) = resolve::resolve_procedure_info(binder, &path)?; + return Some(format_param_hover( + schema, + procedure_name, + param_name, + param_type, + )); + } + if let Some(create_aggregate) = ast::CreateAggregate::cast(ancestor) { + let path = create_aggregate.path()?; + let (schema, aggregate_name) = resolve::resolve_aggregate_info(binder, &path)?; + return Some(format_param_hover( + schema, + aggregate_name, + param_name, + param_type, + )); + } + } + + None +} + +fn format_param_hover( + schema: Schema, + routine_name: String, + param_name: Name, + param_type: Option, +) -> String { + if let Some(param_type) = param_type { + return format!( + "parameter {}.{}.{} {}", + schema, routine_name, param_name, param_type + ); + } + + format!("parameter {}.{}.{}", schema, routine_name, param_name) +} + fn format_create_function( create_function: &ast::CreateFunction, binder: &binder::Binder, @@ -1909,6 +1980,46 @@ select add$0(1, 2); "); } + #[test] + fn hover_on_named_arg_param() { + assert_snapshot!(check_hover(" +create function foo(bar_param int) returns int as $$ select 1 $$ language sql; +select foo(bar_param$0 := 5); +"), @r" + hover: parameter public.foo.bar_param int + ╭▸ + 3 │ select foo(bar_param := 5); + ╰╴ ─ hover + "); + } + + #[test] + fn hover_on_named_arg_param_schema_qualified() { + assert_snapshot!(check_hover(" +create schema s; +create function s.foo(my_param int) returns int as $$ select 1 $$ language sql; +select s.foo(my_param$0 := 10); +"), @r" + hover: parameter s.foo.my_param int + ╭▸ + 4 │ select s.foo(my_param := 10); + ╰╴ ─ hover + "); + } + + #[test] + fn hover_on_named_arg_param_procedure() { + assert_snapshot!(check_hover(" +create procedure proc(param_x int) as 'select 1' language sql; +call proc(param_x$0 := 42); +"), @r" + hover: parameter public.proc.param_x int + ╭▸ + 3 │ call proc(param_x := 42); + ╰╴ ─ hover + "); + } + #[test] fn hover_on_function_call_style_column_access() { assert_snapshot!(check_hover(" diff --git a/crates/squawk_ide/src/resolve.rs b/crates/squawk_ide/src/resolve.rs index 6d5013cc..46225164 100644 --- a/crates/squawk_ide/src/resolve.rs +++ b/crates/squawk_ide/src/resolve.rs @@ -40,6 +40,24 @@ pub(crate) fn resolve_name_ref( let position = name_ref.syntax().text_range().start(); resolve_table_name_ptr(binder, &table_name, &schema, position).map(|ptr| smallvec![ptr]) } + NameRefClass::NamedArgParameter => { + let (function_name, schema) = find_func_call_from_named_arg(name_ref)?; + let param_name = Name::from_node(name_ref); + let position = name_ref.syntax().text_range().start(); + + // TODO: this should be one lookup + let function_ptr = binder + .lookup_with(&function_name, SymbolKind::Function, position, &schema) + .or_else(|| { + binder.lookup_with(&function_name, SymbolKind::Procedure, position, &schema) + }) + .or_else(|| { + binder.lookup_with(&function_name, SymbolKind::Aggregate, position, &schema) + })?; + + let param_ptr = find_param_in_func_def(root, function_ptr, ¶m_name)?; + Some(smallvec![param_ptr]) + } NameRefClass::SelectFromTable | NameRefClass::UpdateFromTable | NameRefClass::MergeUsingTable @@ -3069,3 +3087,67 @@ fn resolve_merge_table_name_ptr( merge.returning_clause(), ) } + +fn find_func_call_from_named_arg(name_ref: &ast::NameRef) -> Option<(Name, Option)> { + for a in name_ref.syntax().ancestors() { + if let Some(call_expr) = ast::CallExpr::cast(a.clone()) { + return match call_expr.expr()? { + ast::Expr::NameRef(func_name_ref) => { + let func_name = Name::from_node(&func_name_ref); + Some((func_name, None)) + } + ast::Expr::FieldExpr(field_expr) => { + let func_name_ref = field_expr.field()?; + let func_name = Name::from_node(&func_name_ref); + + let schema = if let Some(base) = field_expr.base() + && let ast::Expr::NameRef(schema_name_ref) = base + { + Some(Schema(Name::from_node(&schema_name_ref))) + } else { + None + }; + + Some((func_name, schema)) + } + _ => None, + }; + } else if let Some(call) = ast::Call::cast(a) { + let path = call.path()?; + let function_name = extract_table_name(&path)?; + let schema = extract_schema_name(&path); + return Some((function_name, schema)); + } + } + None +} + +fn find_param_in_func_def( + root: &SyntaxNode, + function_ptr: SyntaxNodePtr, + param_name: &Name, +) -> Option { + let function_node = function_ptr.to_node(root); + + let param_list = function_node.ancestors().find_map(|a| { + if let Some(create_func) = ast::CreateFunction::cast(a.clone()) { + create_func.param_list() + } else if let Some(create_proc) = ast::CreateProcedure::cast(a.clone()) { + create_proc.param_list() + } else if let Some(create_aggregate) = ast::CreateAggregate::cast(a) { + create_aggregate.param_list() + } else { + None + } + })?; + + for param in param_list.params() { + if let Some(name) = param.name() + && Name::from_node(&name) == *param_name + { + return Some(SyntaxNodePtr::new(name.syntax())); + } + } + + None +}