Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions crates/squawk_ide/src/binder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ fn bind_create_table(b: &mut Binder, create_table: ast::CreateTable) {
return;
};
let name_ptr = path_to_ptr(&path);
let schema = schema_name(&path);
let is_temp = create_table.temp_token().is_some() || create_table.temporary_token().is_some();
let schema = schema_name(&path, is_temp);

let table_id = b.symbols.alloc(Symbol {
kind: SymbolKind::Table,
Expand Down Expand Up @@ -95,21 +96,20 @@ fn path_to_ptr(path: &ast::Path) -> SyntaxNodePtr {
SyntaxNodePtr::new(path.syntax())
}

fn schema_name(path: &ast::Path) -> Schema {
let Some(qualifier) = path.qualifier() else {
return Schema::Public;
};
let Some(segment) = qualifier.segment() else {
return Schema::Public;
fn schema_name(path: &ast::Path, is_temp: bool) -> Schema {
let default_schema = if is_temp { "pg_temp" } else { "public" };

let Some(segment) = path.qualifier().and_then(|q| q.segment()) else {
return Schema::new(default_schema);
};

let schema_name = if let Some(name) = segment.name() {
Name::new(name.syntax().text().to_string())
} else if let Some(name_ref) = segment.name_ref() {
Name::new(name_ref.syntax().text().to_string())
} else {
return Schema::Public;
return Schema::new(default_schema);
};

Schema::from_name(schema_name)
Schema(schema_name)
}
75 changes: 74 additions & 1 deletion crates/squawk_ide/src/goto_definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,81 @@ create table t();
drop table foo.t$0;
",
);
}

#[test]
fn goto_drop_temp_table() {
assert_snapshot!(goto("
create temp table t();
drop table t$0;
"), @r"
╭▸
2 │ create temp table t();
│ ─ 2. destination
3 │ drop table t;
╰╴ ─ 1. source
");
}

#[test]
fn goto_drop_temporary_table() {
assert_snapshot!(goto("
create temporary table t();
drop table t$0;
"), @r"
╭▸
2 │ create temporary table t();
│ ─ 2. destination
3 │ drop table t;
╰╴ ─ 1. source
");
}

#[test]
fn goto_drop_temp_table_with_pg_temp_schema() {
assert_snapshot!(goto("
create temp table t();
drop table pg_temp.t$0;
"), @r"
╭▸
2 │ create temp table t();
│ ─ 2. destination
3 │ drop table pg_temp.t;
╰╴ ─ 1. source
");
}

// todo: temp tables
#[test]
fn goto_drop_temp_table_shadows_public() {
// temp tables shadow public tables when no schema is specified
assert_snapshot!(goto("
create table t();
create temp table t();
drop table t$0;
"), @r"
╭▸
3 │ create temp table t();
│ ─ 2. destination
4 │ drop table t;
╰╴ ─ 1. source
");
}

#[test]
fn goto_drop_public_table_when_temp_exists() {
// can still access public table explicitly
assert_snapshot!(goto("
create table t();
create temp table t();
drop table public.t$0;
"), @r"
╭▸
2 │ create table t();
│ ─ 2. destination
3 │ create temp table t();
4 │ drop table public.t;
╰╴ ─ 1. source
");
}

#[test]
Expand Down
44 changes: 26 additions & 18 deletions crates/squawk_ide/src/resolve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,30 @@ fn classify_name_ref_context(name_ref: &ast::NameRef) -> Option<NameRefContext>
None
}

fn resolve_table(binder: &Binder, table_name: &Name, schema: &Schema) -> Option<SyntaxNodePtr> {
let symbol_id = binder.scopes[binder.root_scope()]
.get(table_name)?
.iter()
.copied()
.find(|id| {
fn resolve_table(
binder: &Binder,
table_name: &Name,
schema: &Option<Schema>,
) -> Option<SyntaxNodePtr> {
let symbols = binder.scopes[binder.root_scope()].get(table_name)?;

if let Some(schema) = schema {
let symbol_id = symbols.iter().copied().find(|id| {
let symbol = &binder.symbols[*id];
symbol.kind == SymbolKind::Table && &symbol.schema == schema
})?;
Some(binder.symbols[symbol_id].ptr)
return Some(binder.symbols[symbol_id].ptr);
} else {
for search_schema in [Schema::new("pg_temp"), Schema::new("public")] {
if let Some(symbol_id) = symbols.iter().copied().find(|id| {
let symbol = &binder.symbols[*id];
symbol.kind == SymbolKind::Table && symbol.schema == search_schema
}) {
return Some(binder.symbols[symbol_id].ptr);
}
}
}
None
}

fn find_containing_path(name_ref: &ast::NameRef) -> Option<ast::Path> {
Expand All @@ -61,15 +75,9 @@ fn extract_table_name(path: &ast::Path) -> Option<Name> {
Some(Name::new(name_ref.syntax().text().to_string()))
}

fn extract_schema_name(path: &ast::Path) -> Schema {
let Some(qualifier) = path.qualifier() else {
return Schema::Public;
};
let Some(segment) = qualifier.segment() else {
return Schema::Public;
};
let Some(name_ref) = segment.name_ref() else {
return Schema::Public;
};
Schema::from_name(Name::new(name_ref.syntax().text().to_string()))
fn extract_schema_name(path: &ast::Path) -> Option<Schema> {
path.qualifier()
.and_then(|q| q.segment())
.and_then(|s| s.name_ref())
.map(|name_ref| Schema(Name::new(name_ref.syntax().text().to_string())))
}
13 changes: 3 additions & 10 deletions crates/squawk_ide/src/symbols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,11 @@ use squawk_syntax::SyntaxNodePtr;
pub(crate) struct Name(pub(crate) SmolStr);

#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) enum Schema {
Public,
Custom(Name),
}
pub(crate) struct Schema(pub(crate) Name);

impl Schema {
pub(crate) fn from_name(name: Name) -> Self {
if name == Name::new("public") {
Schema::Public
} else {
Schema::Custom(name)
}
pub(crate) fn new(name: impl Into<SmolStr>) -> Self {
Schema(Name::new(name))
}
}

Expand Down
Loading