diff --git a/crates/squawk_ide/src/classify.rs b/crates/squawk_ide/src/classify.rs index 6924bf53..b5d61d99 100644 --- a/crates/squawk_ide/src/classify.rs +++ b/crates/squawk_ide/src/classify.rs @@ -356,6 +356,15 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option return Some(NameRefClass::SelectQualifiedColumnTable); } } + if ast::CreatePolicy::can_cast(ancestor.kind()) + || ast::AlterPolicy::can_cast(ancestor.kind()) + { + if is_base_of_outer_field_expr { + return Some(NameRefClass::PolicyQualifiedColumnTable); + } else { + return Some(NameRefClass::PolicyColumn); + } + } } } diff --git a/crates/squawk_ide/src/goto_definition.rs b/crates/squawk_ide/src/goto_definition.rs index 7fb537b6..dad33262 100644 --- a/crates/squawk_ide/src/goto_definition.rs +++ b/crates/squawk_ide/src/goto_definition.rs @@ -615,6 +615,25 @@ create policy p on t "); } + #[test] + fn goto_create_policy_field_style_function_call() { + assert_snapshot!(goto(" +create table t(c int); +create function x(t) returns int8 + as 'select 1' + language sql; +create policy p on t + with check (t.c > 1 and t.x$0 > 0); +"), @r" + ╭▸ + 3 │ create function x(t) returns int8 + │ ─ 2. destination + ‡ + 7 │ with check (t.c > 1 and t.x > 0); + ╰╴ ─ 1. source + "); + } + #[test] fn goto_alter_policy_qualified_column_table() { assert_snapshot!(goto(" @@ -631,6 +650,56 @@ alter policy p on t "); } + #[test] + fn goto_alter_policy_qualified_column() { + assert_snapshot!(goto(" +create table t(c int, d int); +alter policy p on t + with check (t.c$0 > d); +"), @r" + ╭▸ + 2 │ create table t(c int, d int); + │ ─ 2. destination + 3 │ alter policy p on t + 4 │ with check (t.c > d); + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_create_policy_schema_qualified_table() { + assert_snapshot!(goto(" +create schema foo; +create table foo.t(c int); +create policy p on foo.t + with check (foo.t$0.c > 1); +"), @r" + ╭▸ + 3 │ create table foo.t(c int); + │ ─ 2. destination + 4 │ create policy p on foo.t + 5 │ with check (foo.t.c > 1); + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_create_policy_unqualified_table_with_schema_on_table() { + assert_snapshot!(goto(" +create schema foo; +create table foo.t(c int); +create policy p on foo.t + with check (t$0.c > 1); +"), @r" + ╭▸ + 3 │ create table foo.t(c int); + │ ─ 2. destination + 4 │ create policy p on foo.t + 5 │ with check (t.c > 1); + ╰╴ ─ 1. source + "); + } + #[test] fn goto_drop_event_trigger() { assert_snapshot!(goto(" diff --git a/crates/squawk_ide/src/hover.rs b/crates/squawk_ide/src/hover.rs index ce67925f..3bd2cb95 100644 --- a/crates/squawk_ide/src/hover.rs +++ b/crates/squawk_ide/src/hover.rs @@ -57,7 +57,6 @@ pub fn hover(file: &ast::SourceFile, offset: TextSize) -> Option { | NameRefClass::MergeWhenColumn | NameRefClass::MergeOnColumn | NameRefClass::CheckConstraintColumn - | NameRefClass::PolicyColumn | NameRefClass::GeneratedColumn | NameRefClass::UniqueConstraintColumn | NameRefClass::PrimaryKeyConstraintColumn @@ -78,7 +77,9 @@ pub fn hover(file: &ast::SourceFile, offset: TextSize) -> Option { NameRefClass::CompositeTypeField => { return hover_composite_type_field(root, &name_ref, &binder); } - NameRefClass::SelectColumn | NameRefClass::SelectQualifiedColumn => { + NameRefClass::SelectColumn + | NameRefClass::SelectQualifiedColumn + | NameRefClass::PolicyColumn => { // Try hover as column first if let Some(result) = hover_column(root, &name_ref, &binder) { return Some(result); diff --git a/crates/squawk_ide/src/resolve.rs b/crates/squawk_ide/src/resolve.rs index f7379939..df147341 100644 --- a/crates/squawk_ide/src/resolve.rs +++ b/crates/squawk_ide/src/resolve.rs @@ -285,8 +285,7 @@ pub(crate) fn resolve_name_ref_ptrs( None } })?; - let column_name = Name::from_node(name_ref); - resolve_column_for_path(binder, root, &on_table_path, column_name) + resolve_policy_column_ptr(binder, root, &on_table_path, name_ref) .map(|ptr| smallvec![ptr]) } NameRefClass::PolicyQualifiedColumnTable => { @@ -844,17 +843,43 @@ fn resolve_column_for_path( let schema = extract_schema_name(path); let position = path.syntax().text_range().start(); - if let Some(resolved) = resolve_table_name(binder, root, &table_name, &schema, position) { - match resolved { - ResolvedTableName::View(create_view) => { - find_column_in_create_view(&create_view, &column_name) + let resolved = resolve_table_name(binder, root, &table_name, &schema, position)?; + match resolved { + ResolvedTableName::View(create_view) => { + find_column_in_create_view(&create_view, &column_name) + } + ResolvedTableName::Table(create_table_like) => { + find_column_in_create_table(binder, root, &create_table_like, &column_name) + } + } +} + +fn resolve_policy_column_ptr( + binder: &Binder, + root: &SyntaxNode, + on_table_path: &ast::Path, + column_name_ref: &ast::NameRef, +) -> Option { + let column_name = Name::from_node(column_name_ref); + let (table_name, schema) = extract_table_schema_from_path(on_table_path)?; + let position = column_name_ref.syntax().text_range().start(); + + let resolved = resolve_table_name(binder, root, &table_name, &schema, position)?; + match resolved { + ResolvedTableName::View(create_view) => { + if let Some(ptr) = find_column_in_create_view(&create_view, &column_name) { + return Some(ptr); } - ResolvedTableName::Table(create_table_like) => { + resolve_function(binder, &column_name, &schema, None, position) + } + ResolvedTableName::Table(create_table_like) => { + if let Some(ptr) = find_column_in_create_table(binder, root, &create_table_like, &column_name) + { + return Some(ptr); } + resolve_function(binder, &column_name, &schema, None, position) } - } else { - None } } @@ -1199,30 +1224,26 @@ fn resolve_select_qualified_column_ptr( } } - if let Some(resolved) = resolve_table_name(binder, root, &table_name, &schema, position) { - match resolved { - ResolvedTableName::View(create_view) => { - if let Some(ptr) = find_column_in_create_view(&create_view, &column_name) { - return Some(ptr); - } - - return resolve_function(binder, &column_name, &schema, None, position); + let resolved = resolve_table_name(binder, root, &table_name, &schema, position)?; + match resolved { + ResolvedTableName::View(create_view) => { + if let Some(ptr) = find_column_in_create_view(&create_view, &column_name) { + return Some(ptr); } - ResolvedTableName::Table(create_table_like) => { - // 1. Try to find a matching column (columns take precedence) - if let Some(ptr) = - find_column_in_create_table(binder, root, &create_table_like, &column_name) - { - return Some(ptr); - } - // 2. No column found, check for field-style function call - // e.g., select t.b from t where b is a function that takes t as an argument - return resolve_function(binder, &column_name, &schema, None, position); + return resolve_function(binder, &column_name, &schema, None, position); + } + ResolvedTableName::Table(create_table_like) => { + // 1. Try to find a matching column (columns take precedence) + if let Some(ptr) = + find_column_in_create_table(binder, root, &create_table_like, &column_name) + { + return Some(ptr); } + // 2. No column found, check for field-style function call + // e.g., select t.b from t where b is a function that takes t as an argument + return resolve_function(binder, &column_name, &schema, None, position); } } - - None } enum ResolvedTableName {