Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/squawk_ide/src/binder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ fn bind_create_function(b: &mut Binder, create_function: ast::CreateFunction) {

let name_ptr = path_to_ptr(&path);

let Some(schema) = b.current_search_path().first().cloned() else {
let Some(schema) = schema_name(b, &path, false) else {
return;
};

Expand Down
138 changes: 138 additions & 0 deletions crates/squawk_ide/src/hover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ pub fn hover(file: &ast::SourceFile, offset: TextSize) -> Option<String> {
if is_index_ref(&name_ref) {
return hover_index(file, &name_ref, &binder);
}

if is_function_ref(&name_ref) {
return hover_function(file, &name_ref, &binder);
}
}

if let Some(name) = ast::Name::cast(parent) {
Expand All @@ -38,6 +42,14 @@ pub fn hover(file: &ast::SourceFile, offset: TextSize) -> Option<String> {
if let Some(create_index) = name.syntax().ancestors().find_map(ast::CreateIndex::cast) {
return format_create_index(&create_index, &binder);
}

if let Some(create_function) = name
.syntax()
.ancestors()
.find_map(ast::CreateFunction::cast)
{
return format_create_function(&create_function, &binder);
}
}

None
Expand Down Expand Up @@ -278,6 +290,68 @@ fn is_index_ref(name_ref: &ast::NameRef) -> bool {
false
}

fn is_function_ref(name_ref: &ast::NameRef) -> bool {
for ancestor in name_ref.syntax().ancestors() {
if ast::DropFunction::can_cast(ancestor.kind()) {
return true;
}
}
false
}

fn hover_function(
file: &ast::SourceFile,
name_ref: &ast::NameRef,
binder: &binder::Binder,
) -> Option<String> {
let function_ptr = resolve::resolve_name_ref(binder, name_ref)?;

let root = file.syntax();
let function_name_node = function_ptr.to_node(root);

let create_function = function_name_node
.ancestors()
.find_map(ast::CreateFunction::cast)?;

format_create_function(&create_function, binder)
}

fn format_create_function(
create_function: &ast::CreateFunction,
binder: &binder::Binder,
) -> Option<String> {
let path = create_function.path()?;
let segment = path.segment()?;
let name = segment.name()?;
let function_name = name.syntax().text().to_string();

let schema = if let Some(qualifier) = path.qualifier() {
qualifier.syntax().text().to_string()
} else {
function_schema(create_function, binder)?
};

let param_list = create_function.param_list()?;
let params = param_list.syntax().text().to_string();

let ret_type = create_function.ret_type()?;
let return_type = ret_type.syntax().text().to_string();

Some(format!(
"function {}.{}{} {}",
schema, function_name, params, return_type
))
}

fn function_schema(
create_function: &ast::CreateFunction,
binder: &binder::Binder,
) -> Option<String> {
let position = create_function.syntax().text_range().start();
let search_path = binder.search_path_at(position);
search_path.first().map(|s| s.to_string())
}

#[cfg(test)]
mod test {
use crate::hover::hover;
Expand Down Expand Up @@ -718,4 +792,68 @@ drop index idx_x$0;
╰╴ ─ hover
");
}

#[test]
fn hover_on_drop_function() {
assert_snapshot!(check_hover("
create function foo() returns int as $$ select 1 $$ language sql;
drop function foo$0();
"), @r"
hover: function public.foo() returns int
╭▸
3 │ drop function foo();
╰╴ ─ hover
");
}

#[test]
fn hover_on_drop_function_with_schema() {
assert_snapshot!(check_hover("
create function myschema.foo() returns int as $$ select 1 $$ language sql;
drop function myschema.foo$0();
"), @r"
hover: function myschema.foo() returns int
╭▸
3 │ drop function myschema.foo();
╰╴ ─ hover
");
}

#[test]
fn hover_on_create_function_definition() {
assert_snapshot!(check_hover("
create function foo$0() returns int as $$ select 1 $$ language sql;
"), @r"
hover: function public.foo() returns int
╭▸
2 │ create function foo() returns int as $$ select 1 $$ language sql;
╰╴ ─ hover
");
}

#[test]
fn hover_on_create_function_with_explicit_schema() {
assert_snapshot!(check_hover("
create function myschema.foo$0() returns int as $$ select 1 $$ language sql;
"), @r"
hover: function myschema.foo() returns int
╭▸
2 │ create function myschema.foo() returns int as $$ select 1 $$ language sql;
╰╴ ─ hover
");
}

#[test]
fn hover_on_drop_function_with_search_path() {
assert_snapshot!(check_hover(r#"
set search_path to myschema;
create function foo() returns int as $$ select 1 $$ language sql;
drop function foo$0();
"#), @r"
hover: function myschema.foo() returns int
╭▸
4 │ drop function foo();
╰╴ ─ hover
");
}
}
64 changes: 21 additions & 43 deletions crates/squawk_ide/src/resolve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,26 +84,7 @@ fn resolve_table(
schema: &Option<Schema>,
position: TextSize,
) -> Option<SyntaxNodePtr> {
let symbols = binder.scopes[binder.root_scope()].get(table_name)?;

if let Some(schema) = schema {
let symbol_id = symbols.iter().copied().find(|id| {
let symbol = &binder.symbols[*id];
symbol.kind == SymbolKind::Table && &symbol.schema == schema
})?;
return Some(binder.symbols[symbol_id].ptr);
} else {
let search_path = binder.search_path_at(position);
for search_schema in search_path {
if let Some(symbol_id) = symbols.iter().copied().find(|id| {
let symbol = &binder.symbols[*id];
symbol.kind == SymbolKind::Table && &symbol.schema == search_schema
}) {
return Some(binder.symbols[symbol_id].ptr);
}
}
}
None
resolve_for_kind(binder, table_name, schema, position, SymbolKind::Table)
}

fn resolve_index(
Expand All @@ -112,20 +93,30 @@ fn resolve_index(
schema: &Option<Schema>,
position: TextSize,
) -> Option<SyntaxNodePtr> {
let symbols = binder.scopes[binder.root_scope()].get(index_name)?;
resolve_for_kind(binder, index_name, schema, position, SymbolKind::Index)
}

fn resolve_for_kind(
binder: &Binder,
name: &Name,
schema: &Option<Schema>,
position: TextSize,
kind: SymbolKind,
) -> Option<SyntaxNodePtr> {
let symbols = binder.scopes[binder.root_scope()].get(name)?;

if let Some(schema) = schema {
let symbol_id = symbols.iter().copied().find(|id| {
let symbol = &binder.symbols[*id];
symbol.kind == SymbolKind::Index && &symbol.schema == schema
symbol.kind == kind && &symbol.schema == schema
})?;
return Some(binder.symbols[symbol_id].ptr);
} else {
let search_path = binder.search_path_at(position);
for search_schema in search_path {
if let Some(symbol_id) = symbols.iter().copied().find(|id| {
let symbol = &binder.symbols[*id];
symbol.kind == SymbolKind::Index && &symbol.schema == search_schema
symbol.kind == kind && &symbol.schema == search_schema
}) {
return Some(binder.symbols[symbol_id].ptr);
}
Expand All @@ -140,26 +131,13 @@ fn resolve_function(
schema: &Option<Schema>,
position: TextSize,
) -> Option<SyntaxNodePtr> {
let symbols = binder.scopes[binder.root_scope()].get(function_name)?;

if let Some(schema) = schema {
let symbol_id = symbols.iter().copied().find(|id| {
let symbol = &binder.symbols[*id];
symbol.kind == SymbolKind::Function && &symbol.schema == schema
})?;
return Some(binder.symbols[symbol_id].ptr);
} else {
let search_path = binder.search_path_at(position);
for search_schema in search_path {
if let Some(symbol_id) = symbols.iter().copied().find(|id| {
let symbol = &binder.symbols[*id];
symbol.kind == SymbolKind::Function && &symbol.schema == search_schema
}) {
return Some(binder.symbols[symbol_id].ptr);
}
}
}
None
resolve_for_kind(
binder,
function_name,
schema,
position,
SymbolKind::Function,
)
}

fn resolve_create_index_column(binder: &Binder, name_ref: &ast::NameRef) -> Option<SyntaxNodePtr> {
Expand Down
Loading