Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 59 additions & 2 deletions crates/squawk_ide/src/classify.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use squawk_syntax::ast::{self, AstNode};
use squawk_syntax::{
SyntaxKind,
ast::{self, AstNode},
};

#[derive(Debug)]
pub(crate) enum NameRefClass {
Expand Down Expand Up @@ -88,6 +91,39 @@ pub(crate) enum NameRefClass {
ReindexDatabase,
ReindexSystem,
AttachPartition,
NamedArgParameter,
}

fn is_special_fn(kind: SyntaxKind) -> bool {
matches!(
kind,
SyntaxKind::EXTRACT_FN
| SyntaxKind::JSON_EXISTS_FN
| SyntaxKind::JSON_ARRAY_FN
| SyntaxKind::JSON_OBJECT_FN
| SyntaxKind::JSON_OBJECT_AGG_FN
| SyntaxKind::JSON_ARRAY_AGG_FN
| SyntaxKind::JSON_QUERY_FN
| SyntaxKind::JSON_SCALAR_FN
| SyntaxKind::JSON_SERIALIZE_FN
| SyntaxKind::JSON_VALUE_FN
| SyntaxKind::JSON_FN
| SyntaxKind::SUBSTRING_FN
| SyntaxKind::POSITION_FN
| SyntaxKind::OVERLAY_FN
| SyntaxKind::TRIM_FN
| SyntaxKind::XML_ROOT_FN
| SyntaxKind::XML_SERIALIZE_FN
| SyntaxKind::XML_ELEMENT_FN
| SyntaxKind::XML_FOREST_FN
| SyntaxKind::XML_EXISTS_FN
| SyntaxKind::XML_PARSE_FN
| SyntaxKind::XML_PI_FN
| SyntaxKind::SOME_FN
| SyntaxKind::ANY_FN
| SyntaxKind::ALL_FN
| SyntaxKind::EXISTS_FN
)
}

pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option<NameRefClass> {
Expand All @@ -106,6 +142,7 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option<NameRefClass>
let mut in_using_clause = false;
let mut in_returning_clause = false;
let mut in_when_clause = false;
let mut in_special_sql_fn = false;

// TODO: can we combine this if and the one that follows?
if let Some(parent) = name_ref.syntax().parent()
Expand Down Expand Up @@ -530,6 +567,12 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option<NameRefClass>
if ast::PartitionOf::can_cast(ancestor.kind()) {
return Some(NameRefClass::PartitionOfTable);
}
if is_special_fn(ancestor.kind()) {
in_special_sql_fn = true;
}
if ast::NamedArg::can_cast(ancestor.kind()) {
return Some(NameRefClass::NamedArgParameter);
}
if ast::ArgList::can_cast(ancestor.kind()) {
in_arg_list = true;
}
Expand All @@ -546,7 +589,7 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option<NameRefClass>
in_from_clause = true;
}
if ast::Select::can_cast(ancestor.kind()) {
if in_call_expr && !in_arg_list {
if in_call_expr && !in_arg_list && !in_special_sql_fn {
return Some(NameRefClass::SelectFunctionCall);
}
if in_from_clause && !in_on_clause {
Expand Down Expand Up @@ -743,3 +786,17 @@ pub(crate) fn classify_name(name: &ast::Name) -> Option<NameClass> {

None
}

#[test]
fn special_function() {
for kind in (0..SyntaxKind::__LAST as u16)
.map(SyntaxKind::from)
.filter(|kind| format!("{:?}", kind).ends_with("_FN"))
{
assert!(
is_special_fn(kind),
"unhandled special function kind: {:?}. Please update is_special_fn",
kind
)
}
}
128 changes: 128 additions & 0 deletions crates/squawk_ide/src/goto_definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6693,4 +6693,132 @@ returning x, u.y$0;
╰╴ ─ 1. source
");
}

#[test]
fn goto_overlay_with_cte_column() {
assert_snapshot!(goto("
with t as (
select '1' a, '2' b, 3 start
)
select overlay(a placing b$0 from start) from t;
"), @r"
╭▸
3 │ select '1' a, '2' b, 3 start
│ ─ 2. destination
4 │ )
5 │ select overlay(a placing b from start) from t;
╰╴ ─ 1. source
");
}

#[test]
fn goto_overlay_with_cte_column_first_arg() {
assert_snapshot!(goto("
with t as (
select '1' a, '2' b, 3 start
)
select overlay(a$0 placing b from start) from t;
"), @r"
╭▸
3 │ select '1' a, '2' b, 3 start
│ ─ 2. destination
4 │ )
5 │ select overlay(a placing b from start) from t;
╰╴ ─ 1. source
");
}

#[test]
fn goto_overlay_with_cte_column_from_arg() {
assert_snapshot!(goto("
with t as (
select '1' a, '2' b, 3 start
)
select overlay(a placing b from start$0) from t;
"), @r"
╭▸
3 │ select '1' a, '2' b, 3 start
│ ───── 2. destination
4 │ )
5 │ select overlay(a placing b from start) from t;
╰╴ ─ 1. source
");
}

#[test]
fn goto_named_arg_to_param() {
assert_snapshot!(goto("
create function foo(bar_param int) returns int as 'select 1' language sql;
select foo(bar_param$0 := 5);
"), @r"
╭▸
2 │ create function foo(bar_param int) returns int as 'select 1' language sql;
│ ───────── 2. destination
3 │ select foo(bar_param := 5);
╰╴ ─ 1. source
");
}

#[test]
fn goto_named_arg_schema_qualified() {
assert_snapshot!(goto("
create schema s;
create function s.foo(my_param int) returns int as 'select 1' language sql;
select s.foo(my_param$0 := 10);
"), @r"
╭▸
3 │ create function s.foo(my_param int) returns int as 'select 1' language sql;
│ ──────── 2. destination
4 │ select s.foo(my_param := 10);
╰╴ ─ 1. source
");
}

#[test]
fn goto_named_arg_multiple_params() {
assert_snapshot!(goto("
create function foo(a int, b int, c int) returns int as 'select 1' language sql;
select foo(b$0 := 2, a := 1);
"), @r"
╭▸
2 │ create function foo(a int, b int, c int) returns int as 'select 1' language sql;
│ ─ 2. destination
3 │ select foo(b := 2, a := 1);
╰╴ ─ 1. source
");
}

#[test]
fn goto_named_arg_procedure() {
assert_snapshot!(goto("
create procedure proc(param_x int) as 'select 1' language sql;
call proc(param_x$0 := 42);
"), @r"
╭▸
2 │ create procedure proc(param_x int) as 'select 1' language sql;
│ ─────── 2. destination
3 │ call proc(param_x := 42);
╰╴ ─ 1. source
");
}

#[test]
fn goto_named_arg_not_found_unnamed_param() {
goto_not_found(
"
create function foo(int) returns int as 'select 1' language sql;
select foo(bar$0 := 5);
",
);
}

#[test]
fn goto_named_arg_not_found_wrong_name() {
goto_not_found(
"
create function foo(correct_param int) returns int as 'select 1' language sql;
select foo(wrong_param$0 := 5);
",
);
}
}
113 changes: 112 additions & 1 deletion crates/squawk_ide/src/hover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ use crate::classify::{NameClass, NameRefClass, classify_name, classify_name_ref}
use crate::column_name::ColumnName;
use crate::offsets::token_from_offset;
use crate::resolve;
use crate::{binder, symbols::Name};
use crate::{
binder,
symbols::{Name, Schema},
};
use rowan::TextSize;
use squawk_syntax::SyntaxNode;
use squawk_syntax::{
Expand Down Expand Up @@ -160,6 +163,9 @@ pub fn hover(file: &ast::SourceFile, offset: TextSize) -> Option<String> {
| NameRefClass::ReindexSchema => {
return hover_schema(root, &name_ref, &binder);
}
NameRefClass::NamedArgParameter => {
return hover_named_arg_parameter(root, &name_ref, &binder);
}
}
}

Expand Down Expand Up @@ -998,6 +1004,71 @@ fn hover_function(
format_create_function(&create_function, binder)
}

fn hover_named_arg_parameter(
root: &SyntaxNode,
name_ref: &ast::NameRef,
binder: &binder::Binder,
) -> Option<String> {
let param_ptr = resolve::resolve_name_ref(binder, root, name_ref)?
.into_iter()
.next()?;
let param_name_node = param_ptr.to_node(root);
let param = param_name_node.ancestors().find_map(ast::Param::cast)?;
let param_name = param.name().map(|name| Name::from_node(&name))?;
let param_type = param.ty().map(|ty| ty.syntax().text().to_string());

for ancestor in param_name_node.ancestors() {
if let Some(create_function) = ast::CreateFunction::cast(ancestor.clone()) {
let path = create_function.path()?;
let (schema, function_name) = resolve::resolve_function_info(binder, &path)?;
return Some(format_param_hover(
schema,
function_name,
param_name,
param_type,
));
}
if let Some(create_procedure) = ast::CreateProcedure::cast(ancestor.clone()) {
let path = create_procedure.path()?;
let (schema, procedure_name) = resolve::resolve_procedure_info(binder, &path)?;
return Some(format_param_hover(
schema,
procedure_name,
param_name,
param_type,
));
}
if let Some(create_aggregate) = ast::CreateAggregate::cast(ancestor) {
let path = create_aggregate.path()?;
let (schema, aggregate_name) = resolve::resolve_aggregate_info(binder, &path)?;
return Some(format_param_hover(
schema,
aggregate_name,
param_name,
param_type,
));
}
}

None
}

fn format_param_hover(
schema: Schema,
routine_name: String,
param_name: Name,
param_type: Option<String>,
) -> String {
if let Some(param_type) = param_type {
return format!(
"parameter {}.{}.{} {}",
schema, routine_name, param_name, param_type
);
}

format!("parameter {}.{}.{}", schema, routine_name, param_name)
}

fn format_create_function(
create_function: &ast::CreateFunction,
binder: &binder::Binder,
Expand Down Expand Up @@ -1909,6 +1980,46 @@ select add$0(1, 2);
");
}

#[test]
fn hover_on_named_arg_param() {
assert_snapshot!(check_hover("
create function foo(bar_param int) returns int as $$ select 1 $$ language sql;
select foo(bar_param$0 := 5);
"), @r"
hover: parameter public.foo.bar_param int
╭▸
3 │ select foo(bar_param := 5);
╰╴ ─ hover
");
}

#[test]
fn hover_on_named_arg_param_schema_qualified() {
assert_snapshot!(check_hover("
create schema s;
create function s.foo(my_param int) returns int as $$ select 1 $$ language sql;
select s.foo(my_param$0 := 10);
"), @r"
hover: parameter s.foo.my_param int
╭▸
4 │ select s.foo(my_param := 10);
╰╴ ─ hover
");
}

#[test]
fn hover_on_named_arg_param_procedure() {
assert_snapshot!(check_hover("
create procedure proc(param_x int) as 'select 1' language sql;
call proc(param_x$0 := 42);
"), @r"
hover: parameter public.proc.param_x int
╭▸
3 │ call proc(param_x := 42);
╰╴ ─ hover
");
}

#[test]
fn hover_on_function_call_style_column_access() {
assert_snapshot!(check_hover("
Expand Down
Loading
Loading