Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 121 additions & 17 deletions crates/squawk_ide/src/classify.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use crate::symbols::Name;
use squawk_syntax::{
SyntaxKind,
SyntaxKind, SyntaxNode,
ast::{self, AstNode},
};

#[derive(Debug)]
#[derive(Debug, Clone, Copy)]
pub(crate) enum NameRefClass {
Aggregate,
AlterColumn,
Expand Down Expand Up @@ -91,7 +91,7 @@ fn is_special_fn(kind: SyntaxKind) -> bool {
)
}

pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option<NameRefClass> {
pub(crate) fn classify_name_ref(node: &SyntaxNode) -> Option<NameRefClass> {
let mut in_function_name = false;
let mut in_arg_list = false;
let mut in_column_list = false;
Expand All @@ -111,13 +111,13 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option<NameRefClass>
let mut in_conflict_target = false;

// TODO: can we combine this if and the one that follows?
if let Some(parent) = name_ref.syntax().parent()
if let Some(parent) = node.parent()
&& let Some(field_expr) = ast::FieldExpr::cast(parent.clone())
&& let Some(base) = field_expr.base()
&& let ast::Expr::NameRef(base_name_ref) = base
// check that the name_ref we're looking at in the field expr is the
// base name_ref, i.e., the schema, rather than the item
&& base_name_ref.syntax() == name_ref.syntax()
&& base_name_ref.syntax() == node
{
let is_function_call = field_expr
.syntax()
Expand Down Expand Up @@ -217,13 +217,13 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option<NameRefClass>
return Some(NameRefClass::Schema);
}

if let Some(parent) = name_ref.syntax().parent()
if let Some(parent) = node.parent()
&& let Some(field_expr) = ast::FieldExpr::cast(parent.clone())
&& field_expr
.field()
// we're at the field in a FieldExpr, i.e., foo.bar
// ^^^
.is_some_and(|field_name_ref| field_name_ref.syntax() == name_ref.syntax())
.is_some_and(|field_name_ref| field_name_ref.syntax() == node)
// we're not inside a call expr
&& field_expr.star_token().is_none()
&& field_expr
Expand Down Expand Up @@ -296,7 +296,7 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option<NameRefClass>
}
}

if let Some(parent) = name_ref.syntax().parent()
if let Some(parent) = node.parent()
&& let Some(inner_path) = ast::PathSegment::cast(parent)
.and_then(|p| p.syntax().parent().and_then(ast::Path::cast))
&& let Some(outer_path) = inner_path
Expand All @@ -309,7 +309,7 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option<NameRefClass>
}

// Check for function/procedure reference in CREATE OPERATOR before the type check
for ancestor in name_ref.syntax().ancestors() {
for ancestor in node.ancestors() {
if let Some(attr_option) = ast::AttributeOption::cast(ancestor.clone())
&& let Some(name) = attr_option.name()
{
Expand All @@ -327,7 +327,7 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option<NameRefClass>
}

let mut in_type = false;
for ancestor in name_ref.syntax().ancestors() {
for ancestor in node.ancestors() {
if ast::PathType::can_cast(ancestor.kind()) || ast::ExprType::can_cast(ancestor.kind()) {
in_type = true;
}
Expand Down Expand Up @@ -462,15 +462,15 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option<NameRefClass>
&& to_columns
.syntax()
.text_range()
.contains_range(name_ref.syntax().text_range())
.contains_range(node.text_range())
{
return Some(NameRefClass::ForeignKeyColumn);
}
if let Some(from_columns) = foreign_key.from_columns()
&& from_columns
.syntax()
.text_range()
.contains_range(name_ref.syntax().text_range())
.contains_range(node.text_range())
{
return Some(NameRefClass::ConstraintColumn);
}
Expand All @@ -487,15 +487,12 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option<NameRefClass>
&& column_ref
.syntax()
.text_range()
.contains_range(name_ref.syntax().text_range())
.contains_range(node.text_range())
{
return Some(NameRefClass::ForeignKeyColumn);
}
if let Some(path) = references_constraint.table()
&& path
.syntax()
.text_range()
.contains_range(name_ref.syntax().text_range())
&& path.syntax().text_range().contains_range(node.text_range())
{
return Some(NameRefClass::ForeignKeyTable);
}
Expand Down Expand Up @@ -826,6 +823,113 @@ pub(crate) fn classify_name(name: &ast::Name) -> Option<NameClass> {
None
}

pub(crate) fn classify_def_node(def_node: &SyntaxNode) -> Option<NameRefClass> {
let mut in_column = false;
let mut in_column_list = false;
for ancestor in def_node.ancestors() {
if ast::Column::can_cast(ancestor.kind()) {
in_column = true;
}
if ast::ColumnList::can_cast(ancestor.kind()) {
in_column_list = true;
}
if ast::Param::can_cast(ancestor.kind()) {
return Some(NameRefClass::NamedArgParameter);
}
if ast::CreateTableLike::can_cast(ancestor.kind()) {
if in_column {
return Some(NameRefClass::SelectColumn);
}
return Some(NameRefClass::Table);
}
if ast::CreateType::can_cast(ancestor.kind()) {
if in_column {
return Some(NameRefClass::CompositeTypeField);
}
return Some(NameRefClass::Type);
}
if ast::CreateFunction::can_cast(ancestor.kind()) {
return Some(NameRefClass::Function);
}
if ast::CreateProcedure::can_cast(ancestor.kind()) {
return Some(NameRefClass::Procedure);
}
if ast::WithTable::can_cast(ancestor.kind()) {
if in_column_list {
return Some(NameRefClass::SelectColumn);
}
return Some(NameRefClass::Table);
}
if ast::CreateTableAs::can_cast(ancestor.kind()) {
return Some(NameRefClass::Table);
}
if ast::CreateIndex::can_cast(ancestor.kind()) {
return Some(NameRefClass::Index);
}
if ast::CreateSequence::can_cast(ancestor.kind()) {
return Some(NameRefClass::Sequence);
}
if ast::CreateTrigger::can_cast(ancestor.kind()) {
return Some(NameRefClass::Trigger);
}
if ast::CreateEventTrigger::can_cast(ancestor.kind()) {
return Some(NameRefClass::EventTrigger);
}
if ast::CreateTablespace::can_cast(ancestor.kind()) {
return Some(NameRefClass::Tablespace);
}
if ast::CreateDatabase::can_cast(ancestor.kind()) {
return Some(NameRefClass::Database);
}
if ast::CreateServer::can_cast(ancestor.kind()) {
return Some(NameRefClass::Server);
}
if ast::CreateExtension::can_cast(ancestor.kind()) {
return Some(NameRefClass::Extension);
}
if ast::CreateRole::can_cast(ancestor.kind()) {
return Some(NameRefClass::Role);
}
if ast::CreateAggregate::can_cast(ancestor.kind()) {
return Some(NameRefClass::Aggregate);
}
if ast::CreateSchema::can_cast(ancestor.kind()) {
return Some(NameRefClass::Schema);
}
if ast::CreateView::can_cast(ancestor.kind())
|| ast::CreateMaterializedView::can_cast(ancestor.kind())
{
if in_column_list {
return Some(NameRefClass::SelectColumn);
}
return Some(NameRefClass::View);
}
if ast::CreatePolicy::can_cast(ancestor.kind()) {
return Some(NameRefClass::Policy);
}
if ast::Declare::can_cast(ancestor.kind()) {
return Some(NameRefClass::Cursor);
}
if ast::Prepare::can_cast(ancestor.kind()) {
return Some(NameRefClass::PreparedStatement);
}
if ast::Listen::can_cast(ancestor.kind()) {
return Some(NameRefClass::Channel);
}
if ast::Alias::can_cast(ancestor.kind()) {
return Some(NameRefClass::FromTable);
}
if ast::AsName::can_cast(ancestor.kind())
|| ast::ParenSelect::can_cast(ancestor.kind())
|| ast::Values::can_cast(ancestor.kind())
|| ast::Select::can_cast(ancestor.kind())
{
return Some(NameRefClass::SelectColumn);
}
}
None
}

#[test]
fn special_function() {
for kind in (0..SyntaxKind::__LAST as u16)
Expand Down
1 change: 1 addition & 0 deletions crates/squawk_ide/src/code_actions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,7 @@ fn add_schema(
}

let position = token.text_range().start();
// TODO: we should salsa this
let binder = binder::bind(file);
let schema = binder.search_path_at(position).first()?.to_string();
let replacement = format!("{}.", schema);
Expand Down
7 changes: 7 additions & 0 deletions crates/squawk_ide/src/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ fn select_completions(
select_clause: ast::SelectClause,
token: &SyntaxToken,
) -> Vec<CompletionItem> {
// TODO: we should salsa this
let binder = binder::bind(file);
let mut completions = vec![];
let schema = schema_qualifier_at_token(token);
Expand Down Expand Up @@ -246,6 +247,7 @@ fn select_clauses_completions(select: &ast::Select) -> Vec<CompletionItem> {
}

fn limit_completions(file: &ast::SourceFile, token: &SyntaxToken) -> Vec<CompletionItem> {
// TODO: we should salsa this
let binder = binder::bind(file);
let schema = schema_qualifier_at_token(token);
let position = token.text_range().start();
Expand All @@ -265,6 +267,7 @@ fn limit_completions(file: &ast::SourceFile, token: &SyntaxToken) -> Vec<Complet
}

fn offset_completions(file: &ast::SourceFile, token: &SyntaxToken) -> Vec<CompletionItem> {
// TODO: we should salsa this
let binder = binder::bind(file);
let schema = schema_qualifier_at_token(token);
let position = token.text_range().start();
Expand All @@ -277,6 +280,7 @@ fn select_expr_completions(
select: &ast::Select,
token: &SyntaxToken,
) -> Vec<CompletionItem> {
// TODO: we should salsa this
let binder = binder::bind(file);
let mut completions = vec![];
let schema = schema_qualifier_at_token(token);
Expand Down Expand Up @@ -449,6 +453,7 @@ fn schema_completions(binder: &binder::Binder) -> Vec<CompletionItem> {
}

fn table_completions(file: &ast::SourceFile, token: &SyntaxToken) -> Vec<CompletionItem> {
// TODO: we should salsa this
let binder = binder::bind(file);
let schema = schema_qualifier_at_token(token);
let tables = binder.all_symbols_by_kind(SymbolKind::Table, schema.as_ref());
Expand Down Expand Up @@ -528,6 +533,7 @@ fn delete_expr_completions(
delete: &ast::Delete,
token: &SyntaxToken,
) -> Vec<CompletionItem> {
// TODO: we should salsa this
let binder = binder::bind(file);
let mut completions = vec![];

Expand Down Expand Up @@ -738,6 +744,7 @@ fn file_with_completion_marker(file: &ast::SourceFile, offset: TextSize) -> ast:
let offset = u32::from(offset) as usize;
let offset = offset.min(sql.len());
sql.insert_str(offset, COMPLETION_MARKER);
// TODO: should this be cached
ast::SourceFile::parse(&sql).tree()
}

Expand Down
1 change: 1 addition & 0 deletions crates/squawk_ide/src/document_symbols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ pub struct DocumentSymbol {
}

pub fn document_symbols(file: &ast::SourceFile) -> Vec<DocumentSymbol> {
// TODO: we should salsa this
let binder = binder::bind(file);
let mut symbols = vec![];

Expand Down
3 changes: 3 additions & 0 deletions crates/squawk_ide/src/find_references.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@ use squawk_syntax::{
};

pub fn find_references(file: &ast::SourceFile, offset: TextSize) -> Vec<Location> {
// TODO: we should salsa this
let current_binder = binder::bind(file);

// TODO: we should salsa this
let builtins_tree = ast::SourceFile::parse(BUILTINS_SQL).tree();
// TODO: we should salsa this
let builtins_binder = binder::bind(&builtins_tree);

let Some((target_file, target_defs)) = find_target_defs(
Expand Down
4 changes: 4 additions & 0 deletions crates/squawk_ide/src/goto_definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,10 @@ pub fn goto_definition(file: &ast::SourceFile, offset: TextSize) -> SmallVec<[Lo
for file_id in [FileId::Current, FileId::Builtins] {
let file = match file_id {
FileId::Current => file,
// TODO: we should salsa this
FileId::Builtins => &ast::SourceFile::parse(BUILTINS_SQL).tree(),
};
// TODO: we should salsa this
let binder_output = binder::bind(file);
let root = file.syntax();
if let Some(ptrs) = resolve::resolve_name_ref_ptrs(&binder_output, root, &name_ref) {
Expand Down Expand Up @@ -89,8 +91,10 @@ pub fn goto_definition(file: &ast::SourceFile, offset: TextSize) -> SmallVec<[Lo
for file_id in [FileId::Current, FileId::Builtins] {
let file = match file_id {
FileId::Current => file,
// TODO: we should salsa this
FileId::Builtins => &ast::SourceFile::parse(BUILTINS_SQL).tree(),
};
// TODO: we should salsa this
let binder_output = binder::bind(file);
let position = token.text_range().start();
if let Some(ptr) = resolve::resolve_type_ptr_from_type(&binder_output, &ty, position) {
Expand Down
Loading
Loading