diff --git a/crates/squawk_ide/src/classify.rs b/crates/squawk_ide/src/classify.rs index 7f92b2b3..5454f258 100644 --- a/crates/squawk_ide/src/classify.rs +++ b/crates/squawk_ide/src/classify.rs @@ -308,19 +308,31 @@ pub(crate) fn classify_name_ref(node: &SyntaxNode) -> Option { return Some(NameRefClass::Schema); } - // Check for function/procedure reference in CREATE OPERATOR before the type check + // Check for function/procedure reference in CREATE OPERATOR / CREATE AGGREGATE + // before the type check for ancestor in node.ancestors() { if let Some(attr_option) = ast::AttributeOption::cast(ancestor.clone()) && let Some(name) = attr_option.name() { let attr_name = Name::from_node(&name); - if attr_name == Name::from_string("function") - || attr_name == Name::from_string("procedure") - { - for outer in attr_option.syntax().ancestors() { - if ast::CreateOperator::can_cast(outer.kind()) { - return Some(NameRefClass::FunctionName); - } + for outer in attr_option.syntax().ancestors() { + if ast::CreateOperator::can_cast(outer.kind()) + && (attr_name == Name::from_string("function") + || attr_name == Name::from_string("procedure")) + { + return Some(NameRefClass::FunctionName); + } + if ast::CreateAggregate::can_cast(outer.kind()) + && (attr_name == Name::from_string("sfunc") + || attr_name == Name::from_string("finalfunc") + || attr_name == Name::from_string("combinefunc") + || attr_name == Name::from_string("serialfunc") + || attr_name == Name::from_string("deserialfunc") + || attr_name == Name::from_string("msfunc") + || attr_name == Name::from_string("minvfunc") + || attr_name == Name::from_string("mfinalfunc")) + { + return Some(NameRefClass::FunctionName); } } } diff --git a/crates/squawk_ide/src/goto_definition.rs b/crates/squawk_ide/src/goto_definition.rs index c9e27c96..e62c53ca 100644 --- a/crates/squawk_ide/src/goto_definition.rs +++ b/crates/squawk_ide/src/goto_definition.rs @@ -4053,6 +4053,52 @@ select foo$0(1); "); } + #[test] + fn goto_create_aggregate_sfunc() { + assert_snapshot!(goto(" +create function pg_catalog.int8inc(bigint) returns bigint + language internal; + +create aggregate pg_catalog.count(*) ( + sfunc = int8inc$0, + stype = bigint, + combinefunc = int8pl, + initcond = '0' +); +"), @r" + ╭▸ + 2 │ create function pg_catalog.int8inc(bigint) returns bigint + │ ─────── 2. destination + ‡ + 6 │ sfunc = int8inc, + ╰╴ ─ 1. source + " + ); + } + + #[test] + fn goto_create_aggregate_combinefunc() { + assert_snapshot!(goto(" +create function pg_catalog.int8pl(bigint, bigint) returns bigint + language internal; + +create aggregate pg_catalog.count(*) ( + sfunc = int8inc, + stype = bigint, + combinefunc = int8pl$0, + initcond = '0' +); +"), @r" + ╭▸ + 2 │ create function pg_catalog.int8pl(bigint, bigint) returns bigint + │ ────── 2. destination + ‡ + 8 │ combinefunc = int8pl, + ╰╴ ─ 1. source + " + ); + } + #[test] fn goto_default_constraint_function_call() { assert_snapshot!(goto("