Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 101 additions & 23 deletions crates/squawk_ide/src/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ pub fn completion(file: &ast::SourceFile, offset: TextSize) -> Vec<CompletionIte
};
// We don't support completions inside comments since we don't have doc
// comments a la JSDoc.
// And we don't have string literal types so we bail out early for strings too.
// And we don't support enums aka string literal types yet so we bail out
// early for strings as well
if is_string_or_comment(token.kind()) {
return vec![];
}
Expand Down Expand Up @@ -70,24 +71,64 @@ fn select_completions(
&& let Some(from_clause) = select.from_clause()
{
for table_ptr in resolve::table_ptrs_from_clause(&binder, &from_clause) {
if let Some(create_table) = table_ptr
.to_node(file.syntax())
.ancestors()
.find_map(ast::CreateTableLike::cast)
{
let columns = resolve::collect_table_columns(&binder, file.syntax(), &create_table);
completions.extend(columns.into_iter().filter_map(|column| {
let name = column.name()?;
Some(CompletionItem {
label: crate::symbols::Name::from_node(&name).to_string(),
let table_node = table_ptr.to_node(file.syntax());
match resolve::find_table_source(&table_node) {
Some(resolve::TableSource::CreateTable(create_table)) => {
let columns =
resolve::collect_table_columns(&binder, file.syntax(), &create_table);
completions.extend(columns.into_iter().filter_map(|column| {
let name = column.name()?;
let detail = column.ty().map(|t| t.syntax().text().to_string());
Some(CompletionItem {
label: Name::from_node(&name).to_string(),
kind: CompletionItemKind::Column,
detail,
insert_text: None,
insert_text_format: None,
trigger_completion_after_insert: false,
sort_text: None,
})
}));
}
Some(resolve::TableSource::WithTable(with_table)) => {
let columns = resolve::collect_with_table_columns_with_types(&with_table);
completions.extend(columns.into_iter().map(|(name, ty)| CompletionItem {
label: name.to_string(),
kind: CompletionItemKind::Column,
detail: ty.map(|t| t.to_string()),
insert_text: None,
insert_text_format: None,
trigger_completion_after_insert: false,
sort_text: None,
}));
}
Some(resolve::TableSource::CreateView(create_view)) => {
let columns = resolve::collect_view_columns_with_types(&create_view);
completions.extend(columns.into_iter().map(|(name, ty)| CompletionItem {
label: name.to_string(),
kind: CompletionItemKind::Column,
detail: None,
detail: ty.map(|t| t.to_string()),
insert_text: None,
insert_text_format: None,
trigger_completion_after_insert: false,
sort_text: None,
})
}));
}));
}
Some(resolve::TableSource::CreateMaterializedView(create_materialized_view)) => {
let columns = resolve::collect_materialized_view_columns_with_types(
&create_materialized_view,
);
completions.extend(columns.into_iter().map(|(name, ty)| CompletionItem {
label: name.to_string(),
kind: CompletionItemKind::Column,
detail: ty.map(|t| t.to_string()),
insert_text: None,
insert_text_format: None,
trigger_completion_after_insert: false,
sort_text: None,
}));
}
None => {}
}
}
}
Expand Down Expand Up @@ -262,10 +303,11 @@ fn delete_expr_completions(
let columns = resolve::collect_table_columns(&binder, file.syntax(), &create_table);
completions.extend(columns.into_iter().filter_map(|column| {
let name = column.name()?;
let detail = column.ty().map(|t| t.syntax().text().to_string());
Some(CompletionItem {
label: Name::from_node(&name).to_string(),
kind: CompletionItemKind::Column,
detail: None,
detail,
insert_text: None,
insert_text_format: None,
trigger_completion_after_insert: false,
Expand Down Expand Up @@ -604,8 +646,8 @@ select $0 from t;
"), @r"
label | kind | detail | insert_text
--------------------+----------+-------------------------+-------------
a | Column | |
b | Column | |
a | Column | text |
b | Column | int |
t | Table | |
f() | Function | public.f() returns text |
public | Schema | |
Expand All @@ -616,6 +658,42 @@ select $0 from t;
");
}

#[test]
fn completion_after_select_with_cte() {
assert_snapshot!(completions("
with t as (select 1 a)
select $0 from t;
"), @r"
label | kind | detail | insert_text
--------------------+--------+---------+-------------
a | Column | integer |
public | Schema | |
pg_catalog | Schema | |
pg_temp | Schema | |
pg_toast | Schema | |
information_schema | Schema | |
");
}

#[test]
fn completion_values_cte() {
assert_snapshot!(completions("
with t as (values (1, 'foo', false))
select $0 from t;
"), @r"
label | kind | detail | insert_text
--------------------+--------+---------+-------------
column1 | Column | integer |
column2 | Column | text |
column3 | Column | boolean |
public | Schema | |
pg_catalog | Schema | |
pg_temp | Schema | |
pg_toast | Schema | |
information_schema | Schema | |
");
}

#[test]
fn completion_with_schema_qualifier() {
assert_snapshot!(completions("
Expand Down Expand Up @@ -681,8 +759,8 @@ delete from t where $0;
"), @r"
label | kind | detail | insert_text
-------------+----------+---------------------------------+-------------
id | Column | |
name | Column | |
id | Column | int |
name | Column | text |
t | Table | |
is_active() | Function | public.is_active() returns bool |
")
Expand All @@ -696,8 +774,8 @@ delete from t returning $0;
"), @r"
label | kind | detail | insert_text
-------+--------+--------+-------------
id | Column | |
name | Column | |
id | Column | int |
name | Column | text |
t | Table | |
");
}
Expand All @@ -717,8 +795,8 @@ delete from t where t.$0;
"), @r"
label | kind | detail | insert_text
-------+----------+--------+-------------
a | Column | |
b | Column | |
a | Column | int |
b | Column | text |
f | Function | |
");
}
Expand Down
185 changes: 185 additions & 0 deletions crates/squawk_ide/src/infer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
use std::fmt;

use squawk_syntax::{
SyntaxKind,
ast::{self, AstNode},
};

#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum Type {
Integer,
Numeric,
Text,
Bit,
Boolean,
Unknown,
Record,
Array(Box<Type>),
Other(String),
}

impl fmt::Display for Type {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Type::Integer => write!(f, "integer"),
Type::Numeric => write!(f, "numeric"),
Type::Text => write!(f, "text"),
Type::Bit => write!(f, "bit"),
Type::Boolean => write!(f, "boolean"),
Type::Unknown => write!(f, "unknown"),
Type::Record => write!(f, "record"),
Type::Array(inner) => write!(f, "{inner}[]"),
Type::Other(s) => write!(f, "{s}"),
}
}
}

pub(crate) fn infer_type_from_expr(expr: &ast::Expr) -> Option<Type> {
match expr {
ast::Expr::CastExpr(cast_expr) => infer_type_from_ty(&cast_expr.ty()?),
ast::Expr::ArrayExpr(array_expr) => {
let first_elem = array_expr.exprs().next()?;
let elem_ty = infer_type_from_expr(&first_elem)?;
Some(Type::Array(Box::new(elem_ty)))
}
ast::Expr::BinExpr(_bin_expr) => todo!(),
ast::Expr::Literal(literal) => infer_type_from_literal(literal),
ast::Expr::ParenExpr(paren) => paren.expr().and_then(|e| infer_type_from_expr(&e)),
ast::Expr::TupleExpr(_) => Some(Type::Record),
_ => None,
}
}

fn infer_type_from_ty(ty: &ast::Type) -> Option<Type> {
match ty {
ast::Type::CharType(_) => Some(Type::Text),
ast::Type::BitType(_) => Some(Type::Bit),
ast::Type::PathType(path_type) => {
let name = path_type.path()?.segment()?.name_ref()?;
Some(Type::Other(name.syntax().text().to_string()))
}
_ => None,
}
}

fn infer_type_from_literal(literal: &ast::Literal) -> Option<Type> {
let token = literal.syntax().first_token()?;
match token.kind() {
SyntaxKind::INT_NUMBER => Some(Type::Integer),
SyntaxKind::FLOAT_NUMBER => Some(Type::Numeric),
SyntaxKind::STRING
| SyntaxKind::DOLLAR_QUOTED_STRING
| SyntaxKind::ESC_STRING
| SyntaxKind::UNICODE_ESC_STRING => Some(Type::Text),
SyntaxKind::BIT_STRING | SyntaxKind::BYTE_STRING => Some(Type::Bit),
SyntaxKind::TRUE_KW | SyntaxKind::FALSE_KW => Some(Type::Boolean),
SyntaxKind::NULL_KW => Some(Type::Unknown),
_ => None,
}
}

#[cfg(test)]
mod tests {
use super::*;
use insta::assert_snapshot;

fn infer(sql: &str) -> String {
let parse = ast::SourceFile::parse(sql);
for stmt in parse.tree().stmts() {
match stmt {
ast::Stmt::Select(select) => {
let select_clause = select.select_clause().expect("expected select clause");
let target_list = select_clause.target_list().expect("expected target list");

if let Some(target) = target_list.targets().next() {
let expr = target.expr().expect("expected expr");
let ty = infer_type_from_expr(&expr).expect("expected type");
return format!("{ty}");
}
}
_ => unreachable!("unexpected stmt type"),
}
}
unreachable!("should always have at least one target")
}

#[test]
fn integer_literal() {
assert_snapshot!(infer("select 1"), @"integer");
}

#[test]
fn float_literal() {
assert_snapshot!(infer("select 1.5"), @"numeric");
}

#[test]
fn string_literal() {
assert_snapshot!(infer("select 'hello'"), @"text");
}

#[test]
fn dollar_quoted_string() {
assert_snapshot!(infer("select $$hello$$"), @"text");
}

#[test]
fn escape_string() {
assert_snapshot!(infer("select E'hello'"), @"text");
}

#[test]
fn boolean_true() {
assert_snapshot!(infer("select true"), @"boolean");
}

#[test]
fn boolean_false() {
assert_snapshot!(infer("select false"), @"boolean");
}

#[test]
fn null_literal() {
assert_snapshot!(infer("select null"), @"unknown");
}

#[test]
fn cast_expr() {
assert_snapshot!(infer("select 1::text"), @"text");
}

#[test]
fn cast_expr_varchar() {
assert_snapshot!(infer("select 1::varchar(255)"), @"text");
}

#[test]
fn bit_string() {
assert_snapshot!(infer("select b'100'"), @"bit");
}

#[test]
fn bit_varying() {
assert_snapshot!(infer("select b'100'::bit varying"), @"bit");
}

#[test]
fn array() {
assert_snapshot!(infer("select array['foo', 'bar']"), @"text[]");
}

#[test]
fn record() {
assert_snapshot!(infer("select (1, 2)"), @"record");
}

#[test]
fn paren_expr() {
assert_snapshot!(infer("select (1)"), @"integer");
}

#[test]
fn nested_paren_expr() {
assert_snapshot!(infer("select ((1.5))"), @"numeric");
}
}
1 change: 1 addition & 0 deletions crates/squawk_ide/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub mod find_references;
mod generated;
pub mod goto_definition;
pub mod hover;
mod infer;
pub mod inlay_hints;
mod offsets;
mod quote;
Expand Down
Loading
Loading