diff --git a/crates/squawk_ide/src/hover.rs b/crates/squawk_ide/src/hover.rs index cea61cdb..18a550cf 100644 --- a/crates/squawk_ide/src/hover.rs +++ b/crates/squawk_ide/src/hover.rs @@ -10,6 +10,7 @@ pub fn hover(file: &ast::SourceFile, offset: TextSize) -> Option { let binder = binder::bind(file); + // TODO: can we use the classify_name_ref_context function from goto def here? if let Some(name_ref) = ast::NameRef::cast(parent.clone()) { if is_column_ref(&name_ref) { return hover_column(file, &name_ref, &binder); @@ -36,6 +37,10 @@ pub fn hover(file: &ast::SourceFile, offset: TextSize) -> Option { return hover_table(file, &name_ref, &binder); } + if is_update_from_table(&name_ref) { + return hover_table(file, &name_ref, &binder); + } + if is_index_ref(&name_ref) { return hover_index(file, &name_ref, &binder); } @@ -306,6 +311,7 @@ fn is_column_ref(name_ref: &ast::NameRef) -> bool { let mut in_partition_item = false; let mut in_column_list = false; let mut in_where_clause = false; + let mut in_set_clause = false; for ancestor in name_ref.syntax().ancestors() { if ast::PartitionItem::can_cast(ancestor.kind()) { @@ -323,9 +329,15 @@ fn is_column_ref(name_ref: &ast::NameRef) -> bool { if ast::WhereClause::can_cast(ancestor.kind()) { in_where_clause = true; } + if ast::SetClause::can_cast(ancestor.kind()) { + in_set_clause = true; + } if ast::Delete::can_cast(ancestor.kind()) { return in_where_clause; } + if ast::Update::can_cast(ancestor.kind()) { + return in_where_clause || in_set_clause; + } } false } @@ -334,6 +346,8 @@ fn is_table_ref(name_ref: &ast::NameRef) -> bool { let mut in_partition_item = false; let mut in_column_list = false; let mut in_where_clause = false; + let mut in_set_clause = false; + let mut in_from_clause = false; for ancestor in name_ref.syntax().ancestors() { if ast::DropTable::can_cast(ancestor.kind()) { @@ -351,9 +365,18 @@ fn is_table_ref(name_ref: &ast::NameRef) -> bool { if ast::WhereClause::can_cast(ancestor.kind()) { in_where_clause = true; } + if ast::SetClause::can_cast(ancestor.kind()) { + in_set_clause = true; + } + if ast::FromClause::can_cast(ancestor.kind()) { + in_from_clause = true; + } if ast::Delete::can_cast(ancestor.kind()) { return !in_where_clause; } + if ast::Update::can_cast(ancestor.kind()) { + return !in_where_clause && !in_set_clause && !in_from_clause; + } if ast::DropIndex::can_cast(ancestor.kind()) { return false; } @@ -453,6 +476,20 @@ fn is_select_from_table(name_ref: &ast::NameRef) -> bool { false } +fn is_update_from_table(name_ref: &ast::NameRef) -> bool { + let mut in_from_clause = false; + + for ancestor in name_ref.syntax().ancestors() { + if ast::FromClause::can_cast(ancestor.kind()) { + in_from_clause = true; + } + if ast::Update::can_cast(ancestor.kind()) && in_from_clause { + return true; + } + } + false +} + fn is_select_column(name_ref: &ast::NameRef) -> bool { let mut in_call_expr = false; let mut in_arg_list = false; @@ -2237,4 +2274,110 @@ drop routine foo$0(int); ╰╴ ─ hover "); } + + #[test] + fn hover_on_update_table() { + assert_snapshot!(check_hover(" +create table users(id int, email text); +update users$0 set email = 'new@example.com'; +"), @r" + hover: table public.users(id int, email text) + ╭▸ + 3 │ update users set email = 'new@example.com'; + ╰╴ ─ hover + "); + } + + #[test] + fn hover_on_update_table_with_schema() { + assert_snapshot!(check_hover(" +create table public.users(id int, email text); +update public.users$0 set email = 'new@example.com'; +"), @r" + hover: table public.users(id int, email text) + ╭▸ + 3 │ update public.users set email = 'new@example.com'; + ╰╴ ─ hover + "); + } + + #[test] + fn hover_on_update_set_column() { + assert_snapshot!(check_hover(" +create table users(id int, email text); +update users set email$0 = 'new@example.com' where id = 1; +"), @r" + hover: column public.users.email text + ╭▸ + 3 │ update users set email = 'new@example.com' where id = 1; + ╰╴ ─ hover + "); + } + + #[test] + fn hover_on_update_set_column_with_schema() { + assert_snapshot!(check_hover(" +create table public.users(id int, email text); +update public.users set email$0 = 'new@example.com' where id = 1; +"), @r" + hover: column public.users.email text + ╭▸ + 3 │ update public.users set email = 'new@example.com' where id = 1; + ╰╴ ─ hover + "); + } + + #[test] + fn hover_on_update_where_column() { + assert_snapshot!(check_hover(" +create table users(id int, email text); +update users set email = 'new@example.com' where id$0 = 1; +"), @r" + hover: column public.users.id int + ╭▸ + 3 │ update users set email = 'new@example.com' where id = 1; + ╰╴ ─ hover + "); + } + + #[test] + fn hover_on_update_where_column_with_schema() { + assert_snapshot!(check_hover(" +create table public.users(id int, email text); +update public.users set email = 'new@example.com' where id$0 = 1; +"), @r" + hover: column public.users.id int + ╭▸ + 3 │ update public.users set email = 'new@example.com' where id = 1; + ╰╴ ─ hover + "); + } + + #[test] + fn hover_on_update_from_table() { + assert_snapshot!(check_hover(" +create table users(id int, email text); +create table messages(id int, user_id int, email text); +update users set email = messages.email from messages$0 where users.id = messages.user_id; +"), @r" + hover: table public.messages(id int, user_id int, email text) + ╭▸ + 4 │ update users set email = messages.email from messages where users.id = messages.user_id; + ╰╴ ─ hover + "); + } + + #[test] + fn hover_on_update_from_table_with_schema() { + assert_snapshot!(check_hover(" +create table users(id int, email text); +create table public.messages(id int, user_id int, email text); +update users set email = messages.email from public.messages$0 where users.id = messages.user_id; +"), @r" + hover: table public.messages(id int, user_id int, email text) + ╭▸ + 4 │ update users set email = messages.email from public.messages where users.id = messages.user_id; + ╰╴ ─ hover + "); + } } diff --git a/crates/squawk_ide/src/inlay_hints.rs b/crates/squawk_ide/src/inlay_hints.rs index 3e1e9f9a..78e826c5 100644 --- a/crates/squawk_ide/src/inlay_hints.rs +++ b/crates/squawk_ide/src/inlay_hints.rs @@ -87,22 +87,22 @@ fn inlay_hint_insert( let row_list = values.row_list()?; let columns: Vec<(Name, Option)> = if let Some(column_list) = insert.column_list() { - let table_arg_list = resolve::resolve_insert_table_columns(file, binder, &insert); - + let create_table = resolve::resolve_insert_create_table(file, binder, &insert); column_list .columns() .filter_map(|col| { let col_name = resolve::extract_column_name(&col)?; - let target = table_arg_list + let target = create_table .as_ref() - .and_then(|list| resolve::find_column_in_table(list, &col_name)); + .and_then(|x| resolve::find_column_in_create_table(x, &col_name)) + .map(|x| x.text_range()); Some((col_name, target)) }) .collect() } else { - let table_arg_list = resolve::resolve_insert_table_columns(file, binder, &insert)?; - - table_arg_list + let create_table = resolve::resolve_insert_create_table(file, binder, &insert)?; + create_table + .table_arg_list()? .args() .filter_map(|arg| { if let ast::TableArg::Column(column) = arg diff --git a/crates/squawk_ide/src/resolve.rs b/crates/squawk_ide/src/resolve.rs index f076bbb3..18417574 100644 --- a/crates/squawk_ide/src/resolve.rs +++ b/crates/squawk_ide/src/resolve.rs @@ -1,4 +1,4 @@ -use rowan::{TextRange, TextSize}; +use rowan::TextSize; use squawk_syntax::{ SyntaxNode, SyntaxNodePtr, ast::{self, AstNode}, @@ -206,8 +206,9 @@ pub(crate) fn resolve_name_ref(binder: &Binder, name_ref: &ast::NameRef) -> Opti NameRefContext::SelectQualifiedColumn => resolve_select_qualified_column(binder, name_ref), NameRefContext::InsertColumn => resolve_insert_column(binder, name_ref), NameRefContext::DeleteWhereColumn => resolve_delete_where_column(binder, name_ref), - NameRefContext::UpdateWhereColumn => resolve_update_where_column(binder, name_ref), - NameRefContext::UpdateSetColumn => resolve_update_set_column(binder, name_ref), + NameRefContext::UpdateWhereColumn | NameRefContext::UpdateSetColumn => { + resolve_update_where_column(binder, name_ref) + } NameRefContext::UpdateFromTable => { let table_name = Name::from_node(name_ref); let schema = if let Some(parent) = name_ref.syntax().parent() @@ -582,29 +583,28 @@ fn resolve_create_index_column(binder: &Binder, name_ref: &ast::NameRef) -> Opti let relation_name = create_index.relation_name()?; let path = relation_name.path()?; - let table_name = extract_table_name(&path)?; - let schema = extract_schema_name(&path); - let position = name_ref.syntax().text_range().start(); + resolve_column_for_path(binder, &path, column_name) +} + +fn resolve_column_for_path( + binder: &Binder, + path: &ast::Path, + column_name: Name, +) -> Option { + let table_name = extract_table_name(path)?; + let schema = extract_schema_name(path); + let position = path.syntax().text_range().start(); let table_ptr = resolve_table(binder, &table_name, &schema, position)?; - let root = &name_ref.syntax().ancestors().last()?; + let root = &path.syntax().ancestors().last()?; let table_name_node = table_ptr.to_node(root); let create_table = table_name_node .ancestors() .find_map(ast::CreateTable::cast)?; - for arg in create_table.table_arg_list()?.args() { - if let ast::TableArg::Column(column) = arg - && let Some(col_name) = column.name() - && Name::from_node(&col_name) == column_name - { - return Some(SyntaxNodePtr::new(col_name.syntax())); - } - } - - None + find_column_in_create_table(&create_table, &column_name) } fn resolve_insert_column(binder: &Binder, name_ref: &ast::NameRef) -> Option { @@ -613,29 +613,7 @@ fn resolve_insert_column(binder: &Binder, name_ref: &ast::NameRef) -> Option Opti let relation_name = delete.relation_name()?; let path = relation_name.path()?; - let table_name = extract_table_name(&path)?; - let schema = extract_schema_name(&path); - let position = name_ref.syntax().text_range().start(); - - let table_ptr = resolve_table(binder, &table_name, &schema, position)?; - - let root = &name_ref.syntax().ancestors().last()?; - let table_name_node = table_ptr.to_node(root); - - let create_table = table_name_node - .ancestors() - .find_map(ast::CreateTable::cast)?; - - for arg in create_table.table_arg_list()?.args() { - if let ast::TableArg::Column(column) = arg - && let Some(col_name) = column.name() - && Name::from_node(&col_name) == column_name - { - return Some(SyntaxNodePtr::new(col_name.syntax())); - } - } - - None + resolve_column_for_path(binder, &path, column_name) } fn resolve_update_where_column(binder: &Binder, name_ref: &ast::NameRef) -> Option { @@ -981,61 +928,7 @@ fn resolve_update_where_column(binder: &Binder, name_ref: &ast::NameRef) -> Opti let relation_name = update.relation_name()?; let path = relation_name.path()?; - let table_name = extract_table_name(&path)?; - let schema = extract_schema_name(&path); - let position = name_ref.syntax().text_range().start(); - - let table_ptr = resolve_table(binder, &table_name, &schema, position)?; - - let root = &name_ref.syntax().ancestors().last()?; - let table_name_node = table_ptr.to_node(root); - - let create_table = table_name_node - .ancestors() - .find_map(ast::CreateTable::cast)?; - - for arg in create_table.table_arg_list()?.args() { - if let ast::TableArg::Column(column) = arg - && let Some(col_name) = column.name() - && Name::from_node(&col_name) == column_name - { - return Some(SyntaxNodePtr::new(col_name.syntax())); - } - } - - None -} - -fn resolve_update_set_column(binder: &Binder, name_ref: &ast::NameRef) -> Option { - let column_name = Name::from_node(name_ref); - - let update = name_ref.syntax().ancestors().find_map(ast::Update::cast)?; - let relation_name = update.relation_name()?; - let path = relation_name.path()?; - - let table_name = extract_table_name(&path)?; - let schema = extract_schema_name(&path); - let position = name_ref.syntax().text_range().start(); - - let table_ptr = resolve_table(binder, &table_name, &schema, position)?; - - let root = &name_ref.syntax().ancestors().last()?; - let table_name_node = table_ptr.to_node(root); - - let create_table = table_name_node - .ancestors() - .find_map(ast::CreateTable::cast)?; - - for arg in create_table.table_arg_list()?.args() { - if let ast::TableArg::Column(column) = arg - && let Some(col_name) = column.name() - && Name::from_node(&col_name) == column_name - { - return Some(SyntaxNodePtr::new(col_name.syntax())); - } - } - - None + resolve_column_for_path(binder, &path, column_name) } fn resolve_fn_call_column(binder: &Binder, name_ref: &ast::NameRef) -> Option { @@ -1099,16 +992,8 @@ fn resolve_from_item_for_fn_call_column( let create_table = table_name_node .ancestors() .find_map(ast::CreateTable::cast)?; - for arg in create_table.table_arg_list()?.args() { - if let ast::TableArg::Column(column) = arg - && let Some(col_name) = column.name() - && Name::from_node(&col_name) == *column_name - { - return Some(SyntaxNodePtr::new(col_name.syntax())); - } - } - None + find_column_in_create_table(&create_table, column_name) } fn is_from_item_match(from_item: &ast::FromItem, qualifier: &Name) -> bool { @@ -1211,16 +1096,16 @@ pub(crate) fn extract_column_name(col: &ast::Column) -> Option { Some(name) } -pub(crate) fn find_column_in_table( - table_arg_list: &ast::TableArgList, - col_name: &Name, -) -> Option { - table_arg_list.args().find_map(|arg| { +pub(crate) fn find_column_in_create_table( + create_table: &ast::CreateTable, + column_name: &Name, +) -> Option { + create_table.table_arg_list()?.args().find_map(|arg| { if let ast::TableArg::Column(column) = arg && let Some(name) = column.name() - && Name::from_node(&name) == *col_name + && Name::from_node(&name) == *column_name { - Some(name.syntax().text_range()) + return Some(SyntaxNodePtr::new(name.syntax())); } else { None } @@ -1413,11 +1298,11 @@ fn resolve_column_from_paren_expr( None } -pub(crate) fn resolve_insert_table_columns( +pub(crate) fn resolve_insert_create_table( file: &ast::SourceFile, binder: &Binder, insert: &ast::Insert, -) -> Option { +) -> Option { let path = insert.path()?; let table_name = extract_table_name(&path)?; let schema = extract_schema_name(&path); @@ -1427,11 +1312,7 @@ pub(crate) fn resolve_insert_table_columns( let root = file.syntax(); let table_name_node = table_ptr.to_node(root); - let create_table = table_name_node - .ancestors() - .find_map(ast::CreateTable::cast)?; - - create_table.table_arg_list() + table_name_node.ancestors().find_map(ast::CreateTable::cast) } pub(crate) fn resolve_table_info(binder: &Binder, path: &ast::Path) -> Option<(Schema, String)> {