diff --git a/crates/squawk_ide/src/binder.rs b/crates/squawk_ide/src/binder.rs index 8c0e8e96..073fb854 100644 --- a/crates/squawk_ide/src/binder.rs +++ b/crates/squawk_ide/src/binder.rs @@ -58,7 +58,8 @@ fn bind_create_table(b: &mut Binder, create_table: ast::CreateTable) { return; }; let name_ptr = path_to_ptr(&path); - let schema = schema_name(&path); + let is_temp = create_table.temp_token().is_some() || create_table.temporary_token().is_some(); + let schema = schema_name(&path, is_temp); let table_id = b.symbols.alloc(Symbol { kind: SymbolKind::Table, @@ -95,12 +96,11 @@ fn path_to_ptr(path: &ast::Path) -> SyntaxNodePtr { SyntaxNodePtr::new(path.syntax()) } -fn schema_name(path: &ast::Path) -> Schema { - let Some(qualifier) = path.qualifier() else { - return Schema::Public; - }; - let Some(segment) = qualifier.segment() else { - return Schema::Public; +fn schema_name(path: &ast::Path, is_temp: bool) -> Schema { + let default_schema = if is_temp { "pg_temp" } else { "public" }; + + let Some(segment) = path.qualifier().and_then(|q| q.segment()) else { + return Schema::new(default_schema); }; let schema_name = if let Some(name) = segment.name() { @@ -108,8 +108,8 @@ fn schema_name(path: &ast::Path) -> Schema { } else if let Some(name_ref) = segment.name_ref() { Name::new(name_ref.syntax().text().to_string()) } else { - return Schema::Public; + return Schema::new(default_schema); }; - Schema::from_name(schema_name) + Schema(schema_name) } diff --git a/crates/squawk_ide/src/goto_definition.rs b/crates/squawk_ide/src/goto_definition.rs index d5b2ae00..afe83d47 100644 --- a/crates/squawk_ide/src/goto_definition.rs +++ b/crates/squawk_ide/src/goto_definition.rs @@ -259,8 +259,81 @@ create table t(); drop table foo.t$0; ", ); + } + + #[test] + fn goto_drop_temp_table() { + assert_snapshot!(goto(" +create temp table t(); +drop table t$0; +"), @r" + ╭▸ + 2 │ create temp table t(); + │ ─ 2. destination + 3 │ drop table t; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_drop_temporary_table() { + assert_snapshot!(goto(" +create temporary table t(); +drop table t$0; +"), @r" + ╭▸ + 2 │ create temporary table t(); + │ ─ 2. destination + 3 │ drop table t; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_drop_temp_table_with_pg_temp_schema() { + assert_snapshot!(goto(" +create temp table t(); +drop table pg_temp.t$0; +"), @r" + ╭▸ + 2 │ create temp table t(); + │ ─ 2. destination + 3 │ drop table pg_temp.t; + ╰╴ ─ 1. source + "); + } - // todo: temp tables + #[test] + fn goto_drop_temp_table_shadows_public() { + // temp tables shadow public tables when no schema is specified + assert_snapshot!(goto(" +create table t(); +create temp table t(); +drop table t$0; +"), @r" + ╭▸ + 3 │ create temp table t(); + │ ─ 2. destination + 4 │ drop table t; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_drop_public_table_when_temp_exists() { + // can still access public table explicitly + assert_snapshot!(goto(" +create table t(); +create temp table t(); +drop table public.t$0; +"), @r" + ╭▸ + 2 │ create table t(); + │ ─ 2. destination + 3 │ create temp table t(); + 4 │ drop table public.t; + ╰╴ ─ 1. source + "); } #[test] diff --git a/crates/squawk_ide/src/resolve.rs b/crates/squawk_ide/src/resolve.rs index 18be8280..ab576269 100644 --- a/crates/squawk_ide/src/resolve.rs +++ b/crates/squawk_ide/src/resolve.rs @@ -34,16 +34,30 @@ fn classify_name_ref_context(name_ref: &ast::NameRef) -> Option None } -fn resolve_table(binder: &Binder, table_name: &Name, schema: &Schema) -> Option { - let symbol_id = binder.scopes[binder.root_scope()] - .get(table_name)? - .iter() - .copied() - .find(|id| { +fn resolve_table( + binder: &Binder, + table_name: &Name, + schema: &Option, +) -> Option { + let symbols = binder.scopes[binder.root_scope()].get(table_name)?; + + if let Some(schema) = schema { + let symbol_id = symbols.iter().copied().find(|id| { let symbol = &binder.symbols[*id]; symbol.kind == SymbolKind::Table && &symbol.schema == schema })?; - Some(binder.symbols[symbol_id].ptr) + return Some(binder.symbols[symbol_id].ptr); + } else { + for search_schema in [Schema::new("pg_temp"), Schema::new("public")] { + if let Some(symbol_id) = symbols.iter().copied().find(|id| { + let symbol = &binder.symbols[*id]; + symbol.kind == SymbolKind::Table && symbol.schema == search_schema + }) { + return Some(binder.symbols[symbol_id].ptr); + } + } + } + None } fn find_containing_path(name_ref: &ast::NameRef) -> Option { @@ -61,15 +75,9 @@ fn extract_table_name(path: &ast::Path) -> Option { Some(Name::new(name_ref.syntax().text().to_string())) } -fn extract_schema_name(path: &ast::Path) -> Schema { - let Some(qualifier) = path.qualifier() else { - return Schema::Public; - }; - let Some(segment) = qualifier.segment() else { - return Schema::Public; - }; - let Some(name_ref) = segment.name_ref() else { - return Schema::Public; - }; - Schema::from_name(Name::new(name_ref.syntax().text().to_string())) +fn extract_schema_name(path: &ast::Path) -> Option { + path.qualifier() + .and_then(|q| q.segment()) + .and_then(|s| s.name_ref()) + .map(|name_ref| Schema(Name::new(name_ref.syntax().text().to_string()))) } diff --git a/crates/squawk_ide/src/symbols.rs b/crates/squawk_ide/src/symbols.rs index ac408556..4287597b 100644 --- a/crates/squawk_ide/src/symbols.rs +++ b/crates/squawk_ide/src/symbols.rs @@ -6,18 +6,11 @@ use squawk_syntax::SyntaxNodePtr; pub(crate) struct Name(pub(crate) SmolStr); #[derive(Clone, Debug, PartialEq, Eq)] -pub(crate) enum Schema { - Public, - Custom(Name), -} +pub(crate) struct Schema(pub(crate) Name); impl Schema { - pub(crate) fn from_name(name: Name) -> Self { - if name == Name::new("public") { - Schema::Public - } else { - Schema::Custom(name) - } + pub(crate) fn new(name: impl Into) -> Self { + Schema(Name::new(name)) } }