diff --git a/crates/squawk_ide/src/classify.rs b/crates/squawk_ide/src/classify.rs index 272eb928..6c856b31 100644 --- a/crates/squawk_ide/src/classify.rs +++ b/crates/squawk_ide/src/classify.rs @@ -1,3 +1,4 @@ +use crate::symbols::Name; use squawk_syntax::{ SyntaxKind, ast::{self, AstNode}, @@ -106,6 +107,7 @@ pub(crate) enum NameRefClass { TriggerFunctionCall, TriggerProcedureCall, AlterEventTrigger, + OperatorFunctionRef, } fn is_special_fn(kind: SyntaxKind) -> bool { @@ -356,6 +358,24 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option return Some(NameRefClass::SchemaQualifier); } + // Check for function/procedure reference in CREATE OPERATOR before the type check + for ancestor in name_ref.syntax().ancestors() { + if let Some(attr_option) = ast::AttributeOption::cast(ancestor.clone()) + && let Some(name) = attr_option.name() + { + let attr_name = Name::from_node(&name); + if attr_name == Name::from_string("function") + || attr_name == Name::from_string("procedure") + { + for outer in attr_option.syntax().ancestors() { + if ast::CreateOperator::can_cast(outer.kind()) { + return Some(NameRefClass::OperatorFunctionRef); + } + } + } + } + } + let mut in_type = false; for ancestor in name_ref.syntax().ancestors() { if ast::PathType::can_cast(ancestor.kind()) || ast::ExprType::can_cast(ancestor.kind()) { diff --git a/crates/squawk_ide/src/goto_definition.rs b/crates/squawk_ide/src/goto_definition.rs index 048c812d..71272f1c 100644 --- a/crates/squawk_ide/src/goto_definition.rs +++ b/crates/squawk_ide/src/goto_definition.rs @@ -7780,4 +7780,32 @@ select foo(wrong_param$0 := 5); ", ); } + + #[test] + fn goto_operator_function_ref() { + assert_snapshot!(goto(" +create function pg_catalog.tsvector_concat(tsvector, tsvector) returns tsvector language internal; +create operator pg_catalog.|| (leftarg = tsvector, rightarg = tsvector, function = pg_catalog.tsvector_concat$0); +"), @r" + ╭▸ + 2 │ create function pg_catalog.tsvector_concat(tsvector, tsvector) returns tsvector language internal; + │ ─────────────── 2. destination + 3 │ create operator pg_catalog.|| (leftarg = tsvector, rightarg = tsvector, function = pg_catalog.tsvector_concat); + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_operator_procedure_ref() { + assert_snapshot!(goto(" +create function f(int, int) returns int language internal; +create operator ||| (leftarg = int, rightarg = int, procedure = f$0); +"), @r" + ╭▸ + 2 │ create function f(int, int) returns int language internal; + │ ─ 2. destination + 3 │ create operator ||| (leftarg = int, rightarg = int, procedure = f); + ╰╴ ─ 1. source + "); + } } diff --git a/crates/squawk_ide/src/hover.rs b/crates/squawk_ide/src/hover.rs index 828d0d23..b1771a61 100644 --- a/crates/squawk_ide/src/hover.rs +++ b/crates/squawk_ide/src/hover.rs @@ -154,7 +154,8 @@ pub fn hover(file: &ast::SourceFile, offset: TextSize) -> Option { } NameRefClass::DropFunction | NameRefClass::DefaultConstraintFunctionCall - | NameRefClass::TriggerFunctionCall => { + | NameRefClass::TriggerFunctionCall + | NameRefClass::OperatorFunctionRef => { return hover_function(root, &name_ref, &binder); } NameRefClass::DropAggregate => return hover_aggregate(root, &name_ref, &binder), diff --git a/crates/squawk_ide/src/resolve.rs b/crates/squawk_ide/src/resolve.rs index b8d8c398..bdd4ec38 100644 --- a/crates/squawk_ide/src/resolve.rs +++ b/crates/squawk_ide/src/resolve.rs @@ -390,6 +390,18 @@ pub(crate) fn resolve_name_ref_ptrs( resolve_procedure(binder, &procedure_name, &schema, None, position) .map(|ptr| smallvec![ptr]) } + NameRefClass::OperatorFunctionRef => { + let path_type = name_ref + .syntax() + .ancestors() + .find_map(ast::PathType::cast)?; + let path = path_type.path()?; + let function_name = extract_table_name(&path)?; + let schema = extract_schema_name(&path); + let position = name_ref.syntax().text_range().start(); + resolve_function(binder, &function_name, &schema, None, position) + .map(|ptr| smallvec![ptr]) + } NameRefClass::SelectFunctionCall => { let schema = if let Some(parent_node) = name_ref.syntax().parent() && let Some(field_expr) = ast::FieldExpr::cast(parent_node)