diff --git a/crates/squawk_ide/src/binder.rs b/crates/squawk_ide/src/binder.rs index 543ffa4e..fa8d0af0 100644 --- a/crates/squawk_ide/src/binder.rs +++ b/crates/squawk_ide/src/binder.rs @@ -181,6 +181,23 @@ impl Binder { // default search path &self.search_path_changes[0].search_path } + + pub(crate) fn all_symbols_by_kind(&self, kind: SymbolKind) -> Vec<&Name> { + let root_scope = self.root_scope(); + let scope = &self.scopes[root_scope]; + + let mut names = vec![]; + for (name, symbol_ids) in &scope.entries { + for symbol_id in symbol_ids { + let symbol = &self.symbols[*symbol_id]; + if symbol.kind == kind { + names.push(name); + break; + } + } + } + names + } } pub(crate) fn bind(file: &ast::SourceFile) -> Binder { diff --git a/crates/squawk_ide/src/completion.rs b/crates/squawk_ide/src/completion.rs index 2580537c..038b2066 100644 --- a/crates/squawk_ide/src/completion.rs +++ b/crates/squawk_ide/src/completion.rs @@ -1,14 +1,17 @@ use rowan::TextSize; use squawk_syntax::ast::{self, AstNode}; +use squawk_syntax::{SyntaxKind, SyntaxToken}; +use crate::binder; +use crate::resolve; +use crate::symbols::SymbolKind; use crate::tokens::is_string_or_comment; pub fn completion(file: &ast::SourceFile, offset: TextSize) -> Vec { - let Some(token) = file.syntax().token_at_offset(offset).right_biased() else { + let Some(token) = token_at_offset(file, offset) else { // empty file - return top_level_completions(); + return default_completions(true); }; - // We don't support completions inside comments since we don't have doc // comments a la JSDoc. // And we don't have string literal types so we bail out early for strings too. @@ -16,13 +19,178 @@ pub fn completion(file: &ast::SourceFile, offset: TextSize) -> Vec table_completions(&binder), + CompletionContext::Default(is_nested) => default_completions(!is_nested), + CompletionContext::SelectClause(select_clause) => { + select_completions(binder, file, select_clause) + } + } } -fn top_level_completions() -> Vec { - ["select", "table"] - .map(|x| CompletionItem::keyword(x.to_owned())) - .to_vec() +fn select_completions( + binder: binder::Binder, + file: &ast::SourceFile, + select_clause: ast::SelectClause, +) -> Vec { + let mut completions = vec![]; + let functions = binder.all_symbols_by_kind(SymbolKind::Function); + completions.extend(functions.into_iter().map(|name| CompletionItem { + label: name.to_string(), + kind: CompletionItemKind::Function, + detail: None, + insert_text: None, + insert_text_format: None, + trigger_completion_after_insert: false, + })); + + let tables = binder.all_symbols_by_kind(SymbolKind::Table); + completions.extend(tables.into_iter().map(|name| CompletionItem { + label: name.to_string(), + kind: CompletionItemKind::Table, + detail: None, + insert_text: None, + insert_text_format: None, + trigger_completion_after_insert: false, + })); + + if let Some(parent) = select_clause.syntax().parent() + && let Some(select) = ast::Select::cast(parent) + && let Some(from_clause) = select.from_clause() + { + for table_ptr in resolve::table_ptrs_from_clause(&binder, &from_clause) { + if let Some(create_table) = table_ptr + .to_node(file.syntax()) + .ancestors() + .find_map(ast::CreateTableLike::cast) + { + let columns = resolve::collect_table_columns(&binder, file.syntax(), &create_table); + completions.extend(columns.into_iter().filter_map(|column| { + let name = column.name()?; + Some(CompletionItem { + label: crate::symbols::Name::from_node(&name).to_string(), + kind: CompletionItemKind::Column, + detail: None, + insert_text: None, + insert_text_format: None, + trigger_completion_after_insert: false, + }) + })); + } + } + } + + return completions; +} + +fn table_completions(binder: &binder::Binder) -> Vec { + // We're in a TRUNCATE or TABLE statement, return table names + let tables = binder.all_symbols_by_kind(SymbolKind::Table); + tables + .into_iter() + .map(|name| CompletionItem { + label: name.to_string(), + kind: CompletionItemKind::Table, + detail: None, + insert_text: None, + insert_text_format: None, + trigger_completion_after_insert: false, + }) + .collect() +} + +enum CompletionContext { + TableOnly, + Default(bool), + SelectClause(ast::SelectClause), +} + +fn completion_context(token: SyntaxToken) -> CompletionContext { + let mut node = token.parent(); + let mut is_nested = false; + let mut kind = None; + while let Some(current_node) = node { + if ast::Stmt::can_cast(current_node.kind()) + && current_node + .parent() + .is_some_and(|x| x.kind() == SyntaxKind::SOURCE_FILE) + { + is_nested = true; + } + if ast::Truncate::can_cast(current_node.kind()) || ast::Table::can_cast(current_node.kind()) + { + if kind.is_none() { + kind = Some(CompletionContext::TableOnly) + }; + } + if let Some(select_clause) = ast::SelectClause::cast(current_node.clone()) { + if kind.is_none() { + kind = Some(CompletionContext::SelectClause(select_clause)) + }; + } + node = current_node.parent(); + } + kind.unwrap_or_else(|| CompletionContext::Default(is_nested)) +} + +fn token_at_offset(file: &ast::SourceFile, offset: TextSize) -> Option { + let Some(mut token) = file.syntax().token_at_offset(offset).left_biased() else { + // empty file - definitely at top level + return None; + }; + while token.kind() == SyntaxKind::WHITESPACE { + if let Some(tk) = token.prev_token() { + token = tk; + } + } + Some(token) +} + +fn default_completions(at_top_level: bool) -> Vec { + let select_insert_text = if at_top_level { + "select $0;" + } else { + "select $0" + }; + + let table_insert_text = if at_top_level { + "table $0;" + } else { + "table $0" + }; + + let mut completions = vec![ + CompletionItem { + label: "select".to_owned(), + kind: CompletionItemKind::Keyword, + detail: None, + insert_text: Some(select_insert_text.to_owned()), + insert_text_format: Some(CompletionInsertTextFormat::Snippet), + trigger_completion_after_insert: false, + }, + CompletionItem { + label: "table".to_owned(), + kind: CompletionItemKind::Keyword, + detail: None, + insert_text: Some(table_insert_text.to_owned()), + insert_text_format: Some(CompletionInsertTextFormat::Snippet), + trigger_completion_after_insert: true, + }, + ]; + + if at_top_level { + completions.push(CompletionItem { + label: "truncate".to_owned(), + kind: CompletionItemKind::Keyword, + detail: None, + insert_text: Some("truncate $0;".to_owned()), + insert_text_format: Some(CompletionInsertTextFormat::Snippet), + trigger_completion_after_insert: true, + }); + } + + completions } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -49,18 +217,7 @@ pub struct CompletionItem { pub detail: Option, pub insert_text: Option, pub insert_text_format: Option, -} - -impl CompletionItem { - fn keyword(text: String) -> CompletionItem { - CompletionItem { - label: text, - kind: CompletionItemKind::Keyword, - detail: None, - insert_text: None, - insert_text_format: None, - } - } + pub trigger_completion_after_insert: bool, } #[cfg(test)] @@ -104,6 +261,7 @@ mod tests { item.label, format!("{:?}", item.kind), item.detail.unwrap_or_default(), + item.insert_text.unwrap_or_default(), ] }) .collect(); @@ -111,7 +269,7 @@ mod tests { rows.sort(); let mut builder = Builder::default(); - builder.push_record(["label", "kind", "detail"]); + builder.push_record(["label", "kind", "detail", "insert_text"]); for row in rows { builder.push_record(row); } @@ -124,10 +282,11 @@ mod tests { #[test] fn completion_at_start() { assert_snapshot!(completions("$0"), @r" - label | kind | detail - --------+---------+-------- - select | Keyword | - table | Keyword | + label | kind | detail | insert_text + ----------+---------+--------+-------------- + select | Keyword | | select $0; + table | Keyword | | table $0; + truncate | Keyword | | truncate $0; "); } @@ -140,4 +299,65 @@ mod tests { fn completion_in_comment() { completions_not_found("-- $0 "); } + + #[test] + fn completion_after_truncate() { + assert_snapshot!(completions(" +create table users (id int); +truncate $0; +"), @r" + label | kind | detail | insert_text + -------+-------+--------+------------- + users | Table | | + "); + } + + #[test] + fn completion_table_at_top_level() { + assert_snapshot!(completions("$0"), @r" + label | kind | detail | insert_text + ----------+---------+--------+-------------- + select | Keyword | | select $0; + table | Keyword | | table $0; + truncate | Keyword | | truncate $0; + "); + } + + #[test] + fn completion_table_nested() { + assert_snapshot!(completions("select * from ($0)"), @r" + label | kind | detail | insert_text + --------+---------+--------+------------- + select | Keyword | | select $0 + table | Keyword | | table $0 + "); + } + + #[test] + fn completion_after_table() { + assert_snapshot!(completions(" +create table users (id int); +table $0; +"), @r" + label | kind | detail | insert_text + -------+-------+--------+------------- + users | Table | | + "); + } + + #[test] + fn completion_after_select() { + assert_snapshot!(completions(" +create table t(a text, b int); +create function f() returns text as 'select 1::text' language sql; +select $0 from t; +"), @r" + label | kind | detail | insert_text + -------+----------+--------+------------- + a | Column | | + b | Column | | + f | Function | | + t | Table | | + "); + } } diff --git a/crates/squawk_ide/src/find_references.rs b/crates/squawk_ide/src/find_references.rs index 2d7bb074..1776d48f 100644 --- a/crates/squawk_ide/src/find_references.rs +++ b/crates/squawk_ide/src/find_references.rs @@ -22,7 +22,7 @@ pub fn find_references(file: &ast::SourceFile, offset: TextSize) -> Vec { - if let Some(found_refs) = resolve::resolve_name_ref(&binder, root, &name_ref) + if let Some(found_refs) = resolve::resolve_name_ref_ptrs(&binder, root, &name_ref) && found_refs.iter().any(|ptr| targets.contains(ptr)) { refs.push(name_ref.syntax().text_range()); @@ -57,7 +57,7 @@ fn find_targets( } if let Some(name_ref) = ast::NameRef::cast(parent.clone()) { - return resolve::resolve_name_ref(binder, root, &name_ref); + return resolve::resolve_name_ref_ptrs(binder, root, &name_ref); } None diff --git a/crates/squawk_ide/src/goto_definition.rs b/crates/squawk_ide/src/goto_definition.rs index e9a204a6..135c0bf8 100644 --- a/crates/squawk_ide/src/goto_definition.rs +++ b/crates/squawk_ide/src/goto_definition.rs @@ -59,7 +59,7 @@ pub fn goto_definition(file: ast::SourceFile, offset: TextSize) -> SmallVec<[Tex if let Some(name_ref) = ast::NameRef::cast(parent.clone()) { let binder_output = binder::bind(&file); let root = file.syntax(); - if let Some(ptrs) = resolve::resolve_name_ref(&binder_output, root, &name_ref) { + if let Some(ptrs) = resolve::resolve_name_ref_ptrs(&binder_output, root, &name_ref) { return ptrs .iter() .map(|ptr| ptr.to_node(file.syntax()).text_range()) diff --git a/crates/squawk_ide/src/hover.rs b/crates/squawk_ide/src/hover.rs index 083a52e8..3d592724 100644 --- a/crates/squawk_ide/src/hover.rs +++ b/crates/squawk_ide/src/hover.rs @@ -293,7 +293,7 @@ fn hover_column( name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let column_ptrs = resolve::resolve_name_ref(binder, root, name_ref)?; + let column_ptrs = resolve::resolve_name_ref_ptrs(binder, root, name_ref)?; let results: Vec = column_ptrs .iter() @@ -383,7 +383,7 @@ fn hover_composite_type_field( name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let field_ptr = resolve::resolve_name_ref(binder, root, name_ref)? + let field_ptr = resolve::resolve_name_ref_ptrs(binder, root, name_ref)? .into_iter() .next()?; let field_name_node = field_ptr.to_node(root); @@ -435,7 +435,7 @@ fn hover_table( return Some(result); } - let table_ptr = resolve::resolve_name_ref(binder, root, name_ref)? + let table_ptr = resolve::resolve_name_ref_ptrs(binder, root, name_ref)? .into_iter() .next()?; @@ -737,7 +737,7 @@ fn hover_index( name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let index_ptr = resolve::resolve_name_ref(binder, root, name_ref)? + let index_ptr = resolve::resolve_name_ref_ptrs(binder, root, name_ref)? .into_iter() .next()?; @@ -755,7 +755,7 @@ fn hover_sequence( name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let sequence_ptr = resolve::resolve_name_ref(binder, root, name_ref)? + let sequence_ptr = resolve::resolve_name_ref_ptrs(binder, root, name_ref)? .into_iter() .next()?; @@ -773,7 +773,7 @@ fn hover_trigger( name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let trigger_ptr = resolve::resolve_name_ref(binder, root, name_ref)? + let trigger_ptr = resolve::resolve_name_ref_ptrs(binder, root, name_ref)? .into_iter() .next()?; @@ -791,7 +791,7 @@ fn hover_event_trigger( name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let event_trigger_ptr = resolve::resolve_name_ref(binder, root, name_ref)? + let event_trigger_ptr = resolve::resolve_name_ref_ptrs(binder, root, name_ref)? .into_iter() .next()?; @@ -809,7 +809,7 @@ fn hover_tablespace( name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let tablespace_ptr = resolve::resolve_name_ref(binder, root, name_ref)? + let tablespace_ptr = resolve::resolve_name_ref_ptrs(binder, root, name_ref)? .into_iter() .next()?; let tablespace_name_node = tablespace_ptr.to_node(root); @@ -821,7 +821,7 @@ fn hover_database( name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let database_ptr = resolve::resolve_name_ref(binder, root, name_ref)? + let database_ptr = resolve::resolve_name_ref_ptrs(binder, root, name_ref)? .into_iter() .next()?; let database_name_node = database_ptr.to_node(root); @@ -833,7 +833,7 @@ fn hover_server( name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let server_ptr = resolve::resolve_name_ref(binder, root, name_ref)? + let server_ptr = resolve::resolve_name_ref_ptrs(binder, root, name_ref)? .into_iter() .next()?; let server_name_node = server_ptr.to_node(root); @@ -845,7 +845,7 @@ fn hover_extension( name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let extension_ptr = resolve::resolve_name_ref(binder, root, name_ref)? + let extension_ptr = resolve::resolve_name_ref_ptrs(binder, root, name_ref)? .into_iter() .next()?; let extension_name_node = extension_ptr.to_node(root); @@ -857,7 +857,7 @@ fn hover_role( name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let role_ptr = resolve::resolve_name_ref(binder, root, name_ref)? + let role_ptr = resolve::resolve_name_ref_ptrs(binder, root, name_ref)? .into_iter() .next()?; let role_name_node = role_ptr.to_node(root); @@ -869,7 +869,7 @@ fn hover_cursor( name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let cursor_ptr = resolve::resolve_name_ref(binder, root, name_ref)? + let cursor_ptr = resolve::resolve_name_ref_ptrs(binder, root, name_ref)? .into_iter() .next()?; let cursor_name_node = cursor_ptr.to_node(root); @@ -882,7 +882,7 @@ fn hover_prepared_statement( name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let statement_ptr = resolve::resolve_name_ref(binder, root, name_ref)? + let statement_ptr = resolve::resolve_name_ref_ptrs(binder, root, name_ref)? .into_iter() .next()?; let statement_name_node = statement_ptr.to_node(root); @@ -897,7 +897,7 @@ fn hover_channel( name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let channel_ptr = resolve::resolve_name_ref(binder, root, name_ref)? + let channel_ptr = resolve::resolve_name_ref_ptrs(binder, root, name_ref)? .into_iter() .next()?; let channel_name_node = channel_ptr.to_node(root); @@ -910,7 +910,7 @@ fn hover_type( name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let type_ptr = resolve::resolve_name_ref(binder, root, name_ref)? + let type_ptr = resolve::resolve_name_ref_ptrs(binder, root, name_ref)? .into_iter() .next()?; @@ -1135,7 +1135,7 @@ fn hover_schema( name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let schema_ptr = resolve::resolve_name_ref(binder, root, name_ref)? + let schema_ptr = resolve::resolve_name_ref_ptrs(binder, root, name_ref)? .into_iter() .next()?; @@ -1169,7 +1169,7 @@ fn hover_function( name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let function_ptr = resolve::resolve_name_ref(binder, root, name_ref)? + let function_ptr = resolve::resolve_name_ref_ptrs(binder, root, name_ref)? .into_iter() .next()?; @@ -1187,7 +1187,7 @@ fn hover_named_arg_parameter( name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let param_ptr = resolve::resolve_name_ref(binder, root, name_ref)? + let param_ptr = resolve::resolve_name_ref_ptrs(binder, root, name_ref)? .into_iter() .next()?; let param_name_node = param_ptr.to_node(root); @@ -1271,7 +1271,7 @@ fn hover_aggregate( name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let aggregate_ptr = resolve::resolve_name_ref(binder, root, name_ref)? + let aggregate_ptr = resolve::resolve_name_ref_ptrs(binder, root, name_ref)? .into_iter() .next()?; @@ -1302,7 +1302,7 @@ fn hover_procedure( name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let procedure_ptr = resolve::resolve_name_ref(binder, root, name_ref)? + let procedure_ptr = resolve::resolve_name_ref_ptrs(binder, root, name_ref)? .into_iter() .next()?; @@ -1333,7 +1333,7 @@ fn hover_routine( name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let routine_ptr = resolve::resolve_name_ref(binder, root, name_ref)? + let routine_ptr = resolve::resolve_name_ref_ptrs(binder, root, name_ref)? .into_iter() .next()?; let routine_name = routine_ptr.to_node(root); diff --git a/crates/squawk_ide/src/inlay_hints.rs b/crates/squawk_ide/src/inlay_hints.rs index 794f7385..50bc69b1 100644 --- a/crates/squawk_ide/src/inlay_hints.rs +++ b/crates/squawk_ide/src/inlay_hints.rs @@ -54,7 +54,7 @@ fn inlay_hint_call_expr( ast::FieldExpr::cast(expr.syntax().clone())?.field()? }; - let function_ptr = resolve::resolve_name_ref(binder, root, &name_ref)? + let function_ptr = resolve::resolve_name_ref_ptrs(binder, root, &name_ref)? .into_iter() .next()?; diff --git a/crates/squawk_ide/src/resolve.rs b/crates/squawk_ide/src/resolve.rs index a8328351..9cc5e6c0 100644 --- a/crates/squawk_ide/src/resolve.rs +++ b/crates/squawk_ide/src/resolve.rs @@ -11,7 +11,7 @@ use crate::column_name::ColumnName; pub(crate) use crate::symbols::Schema; use crate::symbols::{Name, SymbolKind}; -pub(crate) fn resolve_name_ref( +pub(crate) fn resolve_name_ref_ptrs( binder: &Binder, root: &SyntaxNode, name_ref: &ast::NameRef, @@ -2296,16 +2296,7 @@ pub(crate) fn resolve_unqualified_star_table_ptrs( for ancestor in target.syntax().ancestors() { if let Some(select) = ast::Select::cast(ancestor.clone()) { let from_clause = select.from_clause()?; - let mut results = vec![]; - - for from_item in from_clause.from_items() { - collect_tables_from_item(binder, position, &from_item, &mut results); - } - - for join_expr in from_clause.join_exprs() { - collect_table_ptrs_from_join_expr(binder, position, &join_expr, &mut results); - } - + let results = table_ptrs_from_clause(binder, &from_clause); if results.is_empty() { return None; } @@ -2339,49 +2330,54 @@ pub(crate) fn resolve_unqualified_star_in_arg_list_ptrs( ) -> Option> { let select = arg_list.syntax().ancestors().find_map(ast::Select::cast)?; let from_clause = select.from_clause()?; - let position = arg_list.syntax().text_range().start(); + let results = table_ptrs_from_clause(binder, &from_clause); + if results.is_empty() { + return None; + } + + Some(results) +} + +pub(crate) fn table_ptrs_from_clause( + binder: &Binder, + from_clause: &ast::FromClause, +) -> Vec { let mut results = vec![]; for from_item in from_clause.from_items() { - collect_tables_from_item(binder, position, &from_item, &mut results); + collect_tables_from_item(binder, &from_item, &mut results); } for join_expr in from_clause.join_exprs() { - collect_table_ptrs_from_join_expr(binder, position, &join_expr, &mut results); - } - - if results.is_empty() { - return None; + collect_table_ptrs_from_join_expr(binder, &join_expr, &mut results); } - Some(results) + results } fn collect_table_ptrs_from_join_expr( binder: &Binder, - position: TextSize, join_expr: &ast::JoinExpr, results: &mut Vec, ) { if let Some(nested) = join_expr.join_expr() { - collect_table_ptrs_from_join_expr(binder, position, &nested, results); + collect_table_ptrs_from_join_expr(binder, &nested, results); } if let Some(from_item) = join_expr.from_item() { - collect_tables_from_item(binder, position, &from_item, results); + collect_tables_from_item(binder, &from_item, results); } if let Some(join) = join_expr.join() && let Some(from_item) = join.from_item() { - collect_tables_from_item(binder, position, &from_item, results); + collect_tables_from_item(binder, &from_item, results); } } fn collect_tables_from_item( binder: &Binder, - position: TextSize, from_item: &ast::FromItem, results: &mut Vec, ) { @@ -2394,6 +2390,7 @@ fn collect_tables_from_item( return; }; + let position = from_item.syntax().text_range().start(); if let Some(table_name_ptr) = resolve_table_name_ptr(binder, &table_name, &schema, position) { results.push(table_name_ptr); return; diff --git a/crates/squawk_server/src/lsp_utils.rs b/crates/squawk_server/src/lsp_utils.rs index 28e0fdfe..fb7f9f4e 100644 --- a/crates/squawk_server/src/lsp_utils.rs +++ b/crates/squawk_server/src/lsp_utils.rs @@ -95,12 +95,23 @@ pub(crate) fn completion_item( CompletionInsertTextFormat::Snippet => lsp_types::InsertTextFormat::SNIPPET, }); + let command = if item.trigger_completion_after_insert { + Some(lsp_types::Command { + title: "Trigger Completion".to_owned(), + command: "editor.action.triggerSuggest".to_owned(), + arguments: None, + }) + } else { + None + }; + lsp_types::CompletionItem { label: item.label, kind: Some(kind), detail: item.detail, insert_text: item.insert_text, insert_text_format, + command, ..Default::default() } }