From 7ce5956cb12b26b06e14c87876fe671dc1475ca7 Mon Sep 17 00:00:00 2001 From: Steve Dignam Date: Wed, 21 Jan 2026 20:05:51 -0500 Subject: [PATCH] ide: include types in completions --- crates/squawk_ide/src/completion.rs | 124 +++++++++++++++---- crates/squawk_ide/src/infer.rs | 185 ++++++++++++++++++++++++++++ crates/squawk_ide/src/lib.rs | 1 + crates/squawk_ide/src/resolve.rs | 82 ++++++++++++ 4 files changed, 369 insertions(+), 23 deletions(-) create mode 100644 crates/squawk_ide/src/infer.rs diff --git a/crates/squawk_ide/src/completion.rs b/crates/squawk_ide/src/completion.rs index 5440bc62..a49c2870 100644 --- a/crates/squawk_ide/src/completion.rs +++ b/crates/squawk_ide/src/completion.rs @@ -14,7 +14,8 @@ pub fn completion(file: &ast::SourceFile, offset: TextSize) -> Vec { + let columns = + resolve::collect_table_columns(&binder, file.syntax(), &create_table); + completions.extend(columns.into_iter().filter_map(|column| { + let name = column.name()?; + let detail = column.ty().map(|t| t.syntax().text().to_string()); + Some(CompletionItem { + label: Name::from_node(&name).to_string(), + kind: CompletionItemKind::Column, + detail, + insert_text: None, + insert_text_format: None, + trigger_completion_after_insert: false, + sort_text: None, + }) + })); + } + Some(resolve::TableSource::WithTable(with_table)) => { + let columns = resolve::collect_with_table_columns_with_types(&with_table); + completions.extend(columns.into_iter().map(|(name, ty)| CompletionItem { + label: name.to_string(), + kind: CompletionItemKind::Column, + detail: ty.map(|t| t.to_string()), + insert_text: None, + insert_text_format: None, + trigger_completion_after_insert: false, + sort_text: None, + })); + } + Some(resolve::TableSource::CreateView(create_view)) => { + let columns = resolve::collect_view_columns_with_types(&create_view); + completions.extend(columns.into_iter().map(|(name, ty)| CompletionItem { + label: name.to_string(), kind: CompletionItemKind::Column, - detail: None, + detail: ty.map(|t| t.to_string()), insert_text: None, insert_text_format: None, trigger_completion_after_insert: false, sort_text: None, - }) - })); + })); + } + Some(resolve::TableSource::CreateMaterializedView(create_materialized_view)) => { + let columns = resolve::collect_materialized_view_columns_with_types( + &create_materialized_view, + ); + completions.extend(columns.into_iter().map(|(name, ty)| CompletionItem { + label: name.to_string(), + kind: CompletionItemKind::Column, + detail: ty.map(|t| t.to_string()), + insert_text: None, + insert_text_format: None, + trigger_completion_after_insert: false, + sort_text: None, + })); + } + None => {} } } } @@ -262,10 +303,11 @@ fn delete_expr_completions( let columns = resolve::collect_table_columns(&binder, file.syntax(), &create_table); completions.extend(columns.into_iter().filter_map(|column| { let name = column.name()?; + let detail = column.ty().map(|t| t.syntax().text().to_string()); Some(CompletionItem { label: Name::from_node(&name).to_string(), kind: CompletionItemKind::Column, - detail: None, + detail, insert_text: None, insert_text_format: None, trigger_completion_after_insert: false, @@ -604,8 +646,8 @@ select $0 from t; "), @r" label | kind | detail | insert_text --------------------+----------+-------------------------+------------- - a | Column | | - b | Column | | + a | Column | text | + b | Column | int | t | Table | | f() | Function | public.f() returns text | public | Schema | | @@ -616,6 +658,42 @@ select $0 from t; "); } + #[test] + fn completion_after_select_with_cte() { + assert_snapshot!(completions(" +with t as (select 1 a) +select $0 from t; +"), @r" + label | kind | detail | insert_text + --------------------+--------+---------+------------- + a | Column | integer | + public | Schema | | + pg_catalog | Schema | | + pg_temp | Schema | | + pg_toast | Schema | | + information_schema | Schema | | + "); + } + + #[test] + fn completion_values_cte() { + assert_snapshot!(completions(" +with t as (values (1, 'foo', false)) +select $0 from t; +"), @r" + label | kind | detail | insert_text + --------------------+--------+---------+------------- + column1 | Column | integer | + column2 | Column | text | + column3 | Column | boolean | + public | Schema | | + pg_catalog | Schema | | + pg_temp | Schema | | + pg_toast | Schema | | + information_schema | Schema | | + "); + } + #[test] fn completion_with_schema_qualifier() { assert_snapshot!(completions(" @@ -681,8 +759,8 @@ delete from t where $0; "), @r" label | kind | detail | insert_text -------------+----------+---------------------------------+------------- - id | Column | | - name | Column | | + id | Column | int | + name | Column | text | t | Table | | is_active() | Function | public.is_active() returns bool | ") @@ -696,8 +774,8 @@ delete from t returning $0; "), @r" label | kind | detail | insert_text -------+--------+--------+------------- - id | Column | | - name | Column | | + id | Column | int | + name | Column | text | t | Table | | "); } @@ -717,8 +795,8 @@ delete from t where t.$0; "), @r" label | kind | detail | insert_text -------+----------+--------+------------- - a | Column | | - b | Column | | + a | Column | int | + b | Column | text | f | Function | | "); } diff --git a/crates/squawk_ide/src/infer.rs b/crates/squawk_ide/src/infer.rs new file mode 100644 index 00000000..3d19e204 --- /dev/null +++ b/crates/squawk_ide/src/infer.rs @@ -0,0 +1,185 @@ +use std::fmt; + +use squawk_syntax::{ + SyntaxKind, + ast::{self, AstNode}, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum Type { + Integer, + Numeric, + Text, + Bit, + Boolean, + Unknown, + Record, + Array(Box), + Other(String), +} + +impl fmt::Display for Type { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Type::Integer => write!(f, "integer"), + Type::Numeric => write!(f, "numeric"), + Type::Text => write!(f, "text"), + Type::Bit => write!(f, "bit"), + Type::Boolean => write!(f, "boolean"), + Type::Unknown => write!(f, "unknown"), + Type::Record => write!(f, "record"), + Type::Array(inner) => write!(f, "{inner}[]"), + Type::Other(s) => write!(f, "{s}"), + } + } +} + +pub(crate) fn infer_type_from_expr(expr: &ast::Expr) -> Option { + match expr { + ast::Expr::CastExpr(cast_expr) => infer_type_from_ty(&cast_expr.ty()?), + ast::Expr::ArrayExpr(array_expr) => { + let first_elem = array_expr.exprs().next()?; + let elem_ty = infer_type_from_expr(&first_elem)?; + Some(Type::Array(Box::new(elem_ty))) + } + ast::Expr::BinExpr(_bin_expr) => todo!(), + ast::Expr::Literal(literal) => infer_type_from_literal(literal), + ast::Expr::ParenExpr(paren) => paren.expr().and_then(|e| infer_type_from_expr(&e)), + ast::Expr::TupleExpr(_) => Some(Type::Record), + _ => None, + } +} + +fn infer_type_from_ty(ty: &ast::Type) -> Option { + match ty { + ast::Type::CharType(_) => Some(Type::Text), + ast::Type::BitType(_) => Some(Type::Bit), + ast::Type::PathType(path_type) => { + let name = path_type.path()?.segment()?.name_ref()?; + Some(Type::Other(name.syntax().text().to_string())) + } + _ => None, + } +} + +fn infer_type_from_literal(literal: &ast::Literal) -> Option { + let token = literal.syntax().first_token()?; + match token.kind() { + SyntaxKind::INT_NUMBER => Some(Type::Integer), + SyntaxKind::FLOAT_NUMBER => Some(Type::Numeric), + SyntaxKind::STRING + | SyntaxKind::DOLLAR_QUOTED_STRING + | SyntaxKind::ESC_STRING + | SyntaxKind::UNICODE_ESC_STRING => Some(Type::Text), + SyntaxKind::BIT_STRING | SyntaxKind::BYTE_STRING => Some(Type::Bit), + SyntaxKind::TRUE_KW | SyntaxKind::FALSE_KW => Some(Type::Boolean), + SyntaxKind::NULL_KW => Some(Type::Unknown), + _ => None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use insta::assert_snapshot; + + fn infer(sql: &str) -> String { + let parse = ast::SourceFile::parse(sql); + for stmt in parse.tree().stmts() { + match stmt { + ast::Stmt::Select(select) => { + let select_clause = select.select_clause().expect("expected select clause"); + let target_list = select_clause.target_list().expect("expected target list"); + + if let Some(target) = target_list.targets().next() { + let expr = target.expr().expect("expected expr"); + let ty = infer_type_from_expr(&expr).expect("expected type"); + return format!("{ty}"); + } + } + _ => unreachable!("unexpected stmt type"), + } + } + unreachable!("should always have at least one target") + } + + #[test] + fn integer_literal() { + assert_snapshot!(infer("select 1"), @"integer"); + } + + #[test] + fn float_literal() { + assert_snapshot!(infer("select 1.5"), @"numeric"); + } + + #[test] + fn string_literal() { + assert_snapshot!(infer("select 'hello'"), @"text"); + } + + #[test] + fn dollar_quoted_string() { + assert_snapshot!(infer("select $$hello$$"), @"text"); + } + + #[test] + fn escape_string() { + assert_snapshot!(infer("select E'hello'"), @"text"); + } + + #[test] + fn boolean_true() { + assert_snapshot!(infer("select true"), @"boolean"); + } + + #[test] + fn boolean_false() { + assert_snapshot!(infer("select false"), @"boolean"); + } + + #[test] + fn null_literal() { + assert_snapshot!(infer("select null"), @"unknown"); + } + + #[test] + fn cast_expr() { + assert_snapshot!(infer("select 1::text"), @"text"); + } + + #[test] + fn cast_expr_varchar() { + assert_snapshot!(infer("select 1::varchar(255)"), @"text"); + } + + #[test] + fn bit_string() { + assert_snapshot!(infer("select b'100'"), @"bit"); + } + + #[test] + fn bit_varying() { + assert_snapshot!(infer("select b'100'::bit varying"), @"bit"); + } + + #[test] + fn array() { + assert_snapshot!(infer("select array['foo', 'bar']"), @"text[]"); + } + + #[test] + fn record() { + assert_snapshot!(infer("select (1, 2)"), @"record"); + } + + #[test] + fn paren_expr() { + assert_snapshot!(infer("select (1)"), @"integer"); + } + + #[test] + fn nested_paren_expr() { + assert_snapshot!(infer("select ((1.5))"), @"numeric"); + } +} diff --git a/crates/squawk_ide/src/lib.rs b/crates/squawk_ide/src/lib.rs index 19ed9650..22103535 100644 --- a/crates/squawk_ide/src/lib.rs +++ b/crates/squawk_ide/src/lib.rs @@ -9,6 +9,7 @@ pub mod find_references; mod generated; pub mod goto_definition; pub mod hover; +mod infer; pub mod inlay_hints; mod offsets; mod quote; diff --git a/crates/squawk_ide/src/resolve.rs b/crates/squawk_ide/src/resolve.rs index bdd4ec38..f2320f11 100644 --- a/crates/squawk_ide/src/resolve.rs +++ b/crates/squawk_ide/src/resolve.rs @@ -8,6 +8,7 @@ use squawk_syntax::{ use crate::binder::Binder; use crate::classify::{NameRefClass, classify_name_ref}; use crate::column_name::ColumnName; +use crate::infer::{Type, infer_type_from_expr}; pub(crate) use crate::symbols::Schema; use crate::symbols::{Name, SymbolKind}; @@ -2946,6 +2947,22 @@ pub(crate) fn collect_view_column_names(create_view: &ast::CreateView) -> Vec Vec<(Name, Option)> { + let Some(select) = select_from_view_query(create_view) else { + return vec![]; + }; + let Some(select_clause) = select.select_clause() else { + return vec![]; + }; + let Some(target_list) = select_clause.target_list() else { + return vec![]; + }; + + collect_target_list_columns_with_types(&target_list) +} + pub(crate) fn collect_materialized_view_column_names( create_materialized_view: &ast::CreateMaterializedView, ) -> Vec { @@ -2969,6 +2986,22 @@ pub(crate) fn collect_materialized_view_column_names( collect_target_list_column_names(&target_list) } +pub(crate) fn collect_materialized_view_columns_with_types( + create_materialized_view: &ast::CreateMaterializedView, +) -> Vec<(Name, Option)> { + let Some(select) = select_from_materialized_view_query(create_materialized_view) else { + return vec![]; + }; + let Some(select_clause) = select.select_clause() else { + return vec![]; + }; + let Some(target_list) = select_clause.target_list() else { + return vec![]; + }; + + collect_target_list_columns_with_types(&target_list) +} + fn select_from_materialized_view_query( create_materialized_view: &ast::CreateMaterializedView, ) -> Option { @@ -3074,6 +3107,55 @@ fn collect_target_list_column_names(target_list: &ast::TargetList) -> Vec columns } +pub(crate) fn collect_with_table_columns_with_types( + with_table: &ast::WithTable, +) -> Vec<(Name, Option)> { + let Some(query) = with_table.query() else { + return vec![]; + }; + + if let ast::WithQuery::Values(values) = query { + let mut results = vec![]; + if let Some(row_list) = values.row_list() + && let Some(first_row) = row_list.rows().next() + { + for (idx, expr) in first_row.exprs().enumerate() { + let name = Name::from_string(format!("column{}", idx + 1)); + let ty = infer_type_from_expr(&expr); + results.push((name, ty)); + } + } + return results; + } + + let Some(cte_select) = select_from_with_query(query) else { + return vec![]; + }; + let Some(select_clause) = cte_select.select_clause() else { + return vec![]; + }; + let Some(target_list) = select_clause.target_list() else { + return vec![]; + }; + + collect_target_list_columns_with_types(&target_list) +} + +fn collect_target_list_columns_with_types( + target_list: &ast::TargetList, +) -> Vec<(Name, Option)> { + let mut columns = vec![]; + for target in target_list.targets() { + if let Some((col_name, _node)) = ColumnName::from_target(target.clone()) + && let Some(col_name_str) = col_name.to_string() + { + let ty = target.expr().and_then(|e| infer_type_from_expr(&e)); + columns.push((Name::from_string(col_name_str), ty)); + } + } + columns +} + fn select_from_view_query(create_view: &ast::CreateView) -> Option { let query = create_view.query()?; match query {