Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions crates/squawk_ide/src/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,22 @@ fn select_completions(
sort_text: None,
}));
}
Some(resolve::TableSource::ParenSelect(paren_select)) => {
let columns = resolve::collect_paren_select_columns_with_types(
&binder,
file.syntax(),
&paren_select,
);
completions.extend(columns.into_iter().map(|(name, ty)| CompletionItem {
label: name.to_string(),
kind: CompletionItemKind::Column,
detail: ty.map(|t| t.to_string()),
insert_text: None,
insert_text_format: None,
trigger_completion_after_insert: false,
sort_text: None,
}));
}
None => {}
}
}
Expand Down Expand Up @@ -694,6 +710,25 @@ select $0 from t;
");
}

#[test]
fn completion_values_subquery() {
assert_snapshot!(completions("
select $0 from (values (1, 'foo', 1.5, false));
"), @r"
label | kind | detail | insert_text
--------------------+--------+---------+-------------
column1 | Column | integer |
column2 | Column | text |
column3 | Column | numeric |
column4 | Column | boolean |
public | Schema | |
pg_catalog | Schema | |
pg_temp | Schema | |
pg_toast | Schema | |
information_schema | Schema | |
");
}

#[test]
fn completion_with_schema_qualifier() {
assert_snapshot!(completions("
Expand Down
77 changes: 67 additions & 10 deletions crates/squawk_ide/src/hover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,9 @@ impl ColumnHover {
fn anon_column(col_name: &str) -> String {
format!("column {}", col_name)
}
fn anon_column_type(col_name: &str, ty: &str) -> String {
format!("column {} {}", col_name, ty)
}
}

fn hover_column(
Expand All @@ -300,7 +303,7 @@ fn hover_column(
.iter()
.filter_map(|column_ptr| {
let column_name_node = column_ptr.to_node(root);
format_hover_for_column_node(binder, &column_name_node, name_ref)
format_hover_for_column_node(binder, root, &column_name_node, name_ref)
})
.collect();

Expand All @@ -313,6 +316,7 @@ fn hover_column(

fn format_hover_for_column_node(
binder: &binder::Binder,
root: &SyntaxNode,
column_name_node: &squawk_syntax::SyntaxNode,
name_ref: &ast::NameRef,
) -> Option<String> {
Expand All @@ -333,16 +337,36 @@ fn format_hover_for_column_node(
&column_name.to_string(),
));
}
if ast::ParenSelect::can_cast(a.kind())
&& let Some(field_expr) = name_ref.syntax().parent().and_then(ast::FieldExpr::cast)
&& let Some(base) = field_expr.base()
&& let ast::Expr::NameRef(table_name_ref) = base
{
let table_name = Name::from_node(&table_name_ref);
let column_name = Name::from_string(column_name_node.text().to_string());
return Some(ColumnHover::table_column(
&table_name.to_string(),
if let Some(paren_select) = ast::ParenSelect::cast(a.clone()) {
// Qualified access like `t.a`
if let Some(field_expr) = name_ref.syntax().parent().and_then(ast::FieldExpr::cast)
&& let Some(base) = field_expr.base()
&& let ast::Expr::NameRef(table_name_ref) = base
{
let table_name = Name::from_node(&table_name_ref);
let column_name = Name::from_string(column_name_node.text().to_string());
return Some(ColumnHover::table_column(
&table_name.to_string(),
&column_name.to_string(),
));
}
// Unqualified access like `a` from `select a from (select 1 a)`
// For VALUES, use name_ref since column_name_node is the expression
let column_name = if column_name_node
.ancestors()
.any(|a| ast::Values::can_cast(a.kind()))
{
Name::from_node(name_ref)
} else {
Name::from_string(column_name_node.text().to_string())
};
let ty = resolve::collect_paren_select_columns_with_types(binder, root, &paren_select)
.into_iter()
.find(|(name, _)| *name == column_name)
.and_then(|(_, ty)| ty)?;
return Some(ColumnHover::anon_column_type(
&column_name.to_string(),
&ty.to_string(),
));
}

Expand Down Expand Up @@ -459,6 +483,7 @@ fn hover_table_from_ptr(
resolve::TableSource::CreateTable(create_table) => {
format_create_table(&create_table, binder)
}
resolve::TableSource::ParenSelect(paren_select) => format_paren_select(&paren_select),
}
}

Expand Down Expand Up @@ -553,6 +578,9 @@ fn hover_qualified_star_columns(
resolve::TableSource::CreateMaterializedView(create_materialized_view) => {
hover_qualified_star_columns_from_materialized_view(&create_materialized_view, binder)
}
resolve::TableSource::ParenSelect(paren_select) => {
hover_qualified_star_columns_from_subquery(root, &paren_select, binder)
}
}
}

Expand Down Expand Up @@ -1028,6 +1056,11 @@ fn format_with_table(with_table: &ast::WithTable) -> Option<String> {
Some(format!("with {} as ({})", name, query))
}

fn format_paren_select(paren_select: &ast::ParenSelect) -> Option<String> {
let query = paren_select.select()?.syntax().text().to_string();
Some(format!("({})", query))
}

fn format_create_index(create_index: &ast::CreateIndex, binder: &binder::Binder) -> Option<String> {
let index_name = create_index.name()?.syntax().text().to_string();

Expand Down Expand Up @@ -2914,6 +2947,30 @@ select COLUMN1$0, COLUMN2 from t;
");
}

#[test]
fn hover_on_subquery_column() {
assert_snapshot!(check_hover("
select a$0 from (select 1 a);
"), @r"
hover: column a integer
╭▸
2 │ select a from (select 1 a);
╰╴ ─ hover
");
}

#[test]
fn hover_on_subquery_values_column() {
assert_snapshot!(check_hover("
select column1$0 from (values (1, 'foo'));
"), @r"
hover: column column1 integer
╭▸
2 │ select column1 from (values (1, 'foo'));
╰╴ ─ hover
");
}

#[test]
fn hover_on_cte_qualified_star() {
assert_snapshot!(check_hover("
Expand Down
2 changes: 1 addition & 1 deletion crates/squawk_ide/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ pub(crate) fn infer_type_from_expr(expr: &ast::Expr) -> Option<Type> {
}
}

fn infer_type_from_ty(ty: &ast::Type) -> Option<Type> {
pub(crate) fn infer_type_from_ty(ty: &ast::Type) -> Option<Type> {
match ty {
ast::Type::CharType(_) => Some(Type::Text),
ast::Type::BitType(_) => Some(Type::Bit),
Expand Down
112 changes: 111 additions & 1 deletion crates/squawk_ide/src/resolve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use squawk_syntax::{
use crate::binder::Binder;
use crate::classify::{NameRefClass, classify_name_ref};
use crate::column_name::ColumnName;
use crate::infer::{Type, infer_type_from_expr};
use crate::infer::{Type, infer_type_from_expr, infer_type_from_ty};
pub(crate) use crate::symbols::Schema;
use crate::symbols::{Name, SymbolKind};

Expand Down Expand Up @@ -2178,6 +2178,7 @@ fn resolve_subquery_column_ptr(
}
}

// TODO: this should just be a match stmt
if let ast::SelectVariant::Table(table) = select_variant {
let path = table.relation_name()?.path()?;
let (table_name, schema) = extract_table_schema_from_path(&path)?;
Expand All @@ -2199,6 +2200,19 @@ fn resolve_subquery_column_ptr(
);
}

if let ast::SelectVariant::Values(values) = select_variant {
if let Some(num_str) = column_name.0.strip_prefix("column")
&& let Ok(col_num) = num_str.parse::<usize>()
&& col_num > 0
&& let Some(row_list) = values.row_list()
&& let Some(first_row) = row_list.rows().next()
&& let Some(expr) = first_row.exprs().nth(col_num - 1)
{
return Some(SyntaxNodePtr::new(expr.syntax()));
}
return None;
}

let ast::SelectVariant::Select(subquery_select) = select_variant else {
return None;
};
Expand Down Expand Up @@ -2527,9 +2541,14 @@ pub(crate) enum TableSource {
CreateView(ast::CreateView),
CreateMaterializedView(ast::CreateMaterializedView),
CreateTable(ast::CreateTableLike),
ParenSelect(ast::ParenSelect),
}

pub(crate) fn find_table_source(node: &SyntaxNode) -> Option<TableSource> {
if let Some(paren_select) = ast::ParenSelect::cast(node.clone()) {
return Some(TableSource::ParenSelect(paren_select));
}

for ancestor in node.ancestors() {
if let Some(with_table) = ast::WithTable::cast(ancestor.clone()) {
return Some(TableSource::WithTable(with_table));
Expand Down Expand Up @@ -3156,6 +3175,97 @@ fn collect_target_list_columns_with_types(
columns
}

pub(crate) fn collect_paren_select_columns_with_types(
binder: &Binder,
root: &SyntaxNode,
paren_select: &ast::ParenSelect,
) -> Vec<(Name, Option<Type>)> {
let Some(select_variant) = paren_select.select() else {
return vec![];
};
collect_select_variant_columns_with_types(binder, root, &select_variant)
}

fn collect_select_variant_columns_with_types(
binder: &Binder,
root: &SyntaxNode,
select_variant: &ast::SelectVariant,
) -> Vec<(Name, Option<Type>)> {
match select_variant {
ast::SelectVariant::Values(values) => {
let mut results = vec![];
if let Some(row_list) = values.row_list()
&& let Some(first_row) = row_list.rows().next()
{
for (idx, expr) in first_row.exprs().enumerate() {
let name = Name::from_string(format!("column{}", idx + 1));
let ty = infer_type_from_expr(&expr);
results.push((name, ty));
}
}
results
}
ast::SelectVariant::Select(select) => {
let Some(select_clause) = select.select_clause() else {
return vec![];
};
let Some(target_list) = select_clause.target_list() else {
return vec![];
};
collect_target_list_columns_with_types(&target_list)
}
ast::SelectVariant::SelectInto(select_into) => {
let Some(select_clause) = select_into.select_clause() else {
return vec![];
};
let Some(target_list) = select_clause.target_list() else {
return vec![];
};
collect_target_list_columns_with_types(&target_list)
}
ast::SelectVariant::ParenSelect(nested) => {
collect_paren_select_columns_with_types(binder, root, nested)
}
ast::SelectVariant::CompoundSelect(compound) => {
let Some(lhs) = compound.lhs() else {
return vec![];
};
collect_select_variant_columns_with_types(binder, root, &lhs)
}
ast::SelectVariant::Table(table) => {
let Some(path) = table.relation_name().and_then(|r| r.path()) else {
return vec![];
};
let Some(table_name) = extract_table_name(&path) else {
return vec![];
};
let schema = extract_schema_name(&path);
let position = table.syntax().text_range().start();
let Some(table_ptr) =
binder.lookup_with(&table_name, SymbolKind::Table, position, &schema)
else {
return vec![];
};
let Some(create_table) = table_ptr
.to_node(root)
.ancestors()
.find_map(ast::CreateTableLike::cast)
else {
return vec![];
};
let columns = collect_table_columns(binder, root, &create_table);
columns
.into_iter()
.filter_map(|col| {
let name = Name::from_node(&col.name()?);
let ty = col.ty().and_then(|t| infer_type_from_ty(&t));
Some((name, ty))
})
.collect()
}
}
}

fn select_from_view_query(create_view: &ast::CreateView) -> Option<ast::Select> {
let query = create_view.query()?;
match query {
Expand Down
15 changes: 14 additions & 1 deletion crates/squawk_server/src/lsp_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,23 @@ pub(crate) fn completion_item(
None
};

let label_details = item
.detail
.map(|detail| lsp_types::CompletionItemLabelDetails {
detail: None,
// Use description instead of detail so VSCode puts it to the right
// of the item's name instead of smushing them together.
description: Some(detail),
});

lsp_types::CompletionItem {
label: item.label,
kind: Some(kind),
detail: item.detail,
// We use label_details instead of detail so that VSCode shows the type
// info / function signature when the completion list is open, instead
// of waiting until you select the given field.
detail: None,
label_details,
insert_text: item.insert_text,
insert_text_format,
sort_text,
Expand Down
Loading