From b8a4ba4216095e08e6b0e260d47f373bd529d900 Mon Sep 17 00:00:00 2001 From: Steve Dignam Date: Mon, 26 Jan 2026 22:58:50 -0500 Subject: [PATCH] ide: create table as --- crates/squawk_ide/src/binder.rs | 36 +++ crates/squawk_ide/src/document_symbols.rs | 50 ++++ crates/squawk_ide/src/goto_definition.rs | 28 +++ crates/squawk_ide/src/hover.rs | 28 +++ crates/squawk_ide/src/resolve.rs | 291 +++++++++++----------- 5 files changed, 293 insertions(+), 140 deletions(-) diff --git a/crates/squawk_ide/src/binder.rs b/crates/squawk_ide/src/binder.rs index 4ee11834..7c23cafc 100644 --- a/crates/squawk_ide/src/binder.rs +++ b/crates/squawk_ide/src/binder.rs @@ -249,6 +249,7 @@ fn bind_file(b: &mut Binder, file: &ast::SourceFile) { fn bind_stmt(b: &mut Binder, stmt: ast::Stmt) { match stmt { ast::Stmt::CreateTable(create_table) => bind_create_table(b, create_table), + ast::Stmt::CreateTableAs(create_table_as) => bind_create_table_as(b, create_table_as), ast::Stmt::CreateForeignTable(create_foreign_table) => { bind_create_table(b, create_foreign_table) } @@ -318,6 +319,41 @@ fn bind_create_table(b: &mut Binder, create_table: impl ast::HasCreateTable) { b.scopes[root].insert(table_name, type_id); } +fn bind_create_table_as(b: &mut Binder, create_table_as: ast::CreateTableAs) { + let Some(path) = create_table_as.path() else { + return; + }; + let Some(table_name) = item_name(&path) else { + return; + }; + let name_ptr = path_to_ptr(&path); + let is_temp = + create_table_as.temp_token().is_some() || create_table_as.temporary_token().is_some(); + let Some(schema) = schema_name(b, &path, is_temp) else { + return; + }; + + let table_id = b.symbols.alloc(Symbol { + kind: SymbolKind::Table, + ptr: name_ptr, + schema: Some(schema.clone()), + params: None, + table: None, + }); + + let type_id = b.symbols.alloc(Symbol { + kind: SymbolKind::Type, + ptr: name_ptr, + schema: Some(schema), + params: None, + table: None, + }); + + let root = b.root_scope(); + b.scopes[root].insert(table_name.clone(), table_id); + b.scopes[root].insert(table_name, type_id); +} + fn bind_create_index(b: &mut Binder, create_index: ast::CreateIndex) { let Some(name) = create_index.name() else { return; diff --git a/crates/squawk_ide/src/document_symbols.rs b/crates/squawk_ide/src/document_symbols.rs index 365126ef..974fe3df 100644 --- a/crates/squawk_ide/src/document_symbols.rs +++ b/crates/squawk_ide/src/document_symbols.rs @@ -65,6 +65,11 @@ pub fn document_symbols(file: &ast::SourceFile) -> Vec { symbols.push(symbol); } } + ast::Stmt::CreateTableAs(create_table_as) => { + if let Some(symbol) = create_table_as_symbol(&binder, create_table_as) { + symbols.push(symbol); + } + } ast::Stmt::CreateForeignTable(create_foreign_table) => { if let Some(symbol) = create_table_symbol(&binder, create_foreign_table) { symbols.push(symbol); @@ -289,6 +294,36 @@ fn create_table_symbol( }) } +fn create_table_as_symbol( + binder: &binder::Binder, + create_table_as: ast::CreateTableAs, +) -> Option { + let path = create_table_as.path()?; + let segment = path.segment()?; + let name_node = if let Some(name) = segment.name() { + name.syntax().clone() + } else { + return None; + }; + + let (schema, table_name) = resolve_table_info(binder, &path)?; + let name = format!("{}.{}", schema.0, table_name); + + let full_range = create_table_as.syntax().text_range(); + let focus_range = name_node.text_range(); + + Some(DocumentSymbol { + name, + detail: None, + kind: DocumentSymbolKind::Table, + full_range, + focus_range, + // TODO: infer the column names, we need the same for views without + // explicit column lists + children: vec![], + }) +} + fn create_view_symbol( binder: &binder::Binder, create_view: ast::CreateView, @@ -959,6 +994,21 @@ create table users ( "); } + #[test] + fn create_table_as() { + assert_snapshot!(symbols(" +create table t as select 1 a; +"), @r" + info: table: public.t + ╭▸ + 2 │ create table t as select 1 a; + │ ┬────────────┯────────────── + │ │ │ + │ │ focus range + ╰╴full range + "); + } + #[test] fn create_schema() { assert_snapshot!(symbols(" diff --git a/crates/squawk_ide/src/goto_definition.rs b/crates/squawk_ide/src/goto_definition.rs index dad33262..4d4b2fc8 100644 --- a/crates/squawk_ide/src/goto_definition.rs +++ b/crates/squawk_ide/src/goto_definition.rs @@ -2247,6 +2247,34 @@ select v.a$0 from v; "); } + #[test] + fn goto_create_table_as_column() { + assert_snapshot!(goto(" +create table t as select 1 a; +select a$0 from t; +"), @r" + ╭▸ + 2 │ create table t as select 1 a; + │ ─ 2. destination + 3 │ select a from t; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_select_from_create_table_as() { + assert_snapshot!(goto(" +create table t as select 1 a; +select a from t$0; +"), @r" + ╭▸ + 2 │ create table t as select 1 a; + │ ─ 2. destination + 3 │ select a from t; + ╰╴ ─ 1. source + "); + } + #[test] fn goto_view_with_explicit_column_list() { assert_snapshot!(goto(" diff --git a/crates/squawk_ide/src/hover.rs b/crates/squawk_ide/src/hover.rs index 3bd2cb95..0d3fc7c3 100644 --- a/crates/squawk_ide/src/hover.rs +++ b/crates/squawk_ide/src/hover.rs @@ -385,6 +385,21 @@ fn format_hover_for_column_node( { return format_view_column(&create_view, column_name, binder); } + + if let Some(create_table_as) = ast::CreateTableAs::cast(a.clone()) { + let column_name = if let Some(name) = ast::Name::cast(column_name_node.clone()) { + Name::from_node(&name) + } else { + continue; + }; + let path = create_table_as.path()?; + let (schema, table_name) = resolve::resolve_table_info(binder, &path)?; + return Some(ColumnHover::schema_table_column( + &schema.to_string(), + &table_name, + &column_name.to_string(), + )); + } } let column = column_name_node.ancestors().find_map(ast::Column::cast)?; @@ -3614,6 +3629,19 @@ select a$0, b from v; "); } + #[test] + fn hover_on_select_column_from_create_table_as() { + assert_snapshot!(check_hover(" +create table t as select 1 a; +select a$0 from t; +"), @r" + hover: column public.t.a + ╭▸ + 3 │ select a from t; + ╰╴ ─ hover + "); + } + #[test] fn hover_on_select_from_view_with_schema() { assert_snapshot!(check_hover(" diff --git a/crates/squawk_ide/src/resolve.rs b/crates/squawk_ide/src/resolve.rs index df147341..274b9611 100644 --- a/crates/squawk_ide/src/resolve.rs +++ b/crates/squawk_ide/src/resolve.rs @@ -2,7 +2,7 @@ use rowan::TextSize; use smallvec::{SmallVec, smallvec}; use squawk_syntax::{ SyntaxNode, SyntaxNodePtr, - ast::{self, AstNode}, + ast::{self, AstNode, SelectVariant}, }; use crate::binder::Binder; @@ -37,8 +37,7 @@ pub(crate) fn resolve_name_ref_ptrs( | NameRefClass::MergeTable | NameRefClass::AttachPartition => { let path = find_containing_path(name_ref)?; - let table_name = extract_table_name(&path)?; - let schema = extract_schema_name(&path); + let (table_name, schema) = extract_table_schema_from_path(&path)?; let position = name_ref.syntax().text_range().start(); resolve_table_name_ptr(binder, &table_name, &schema, position).map(|ptr| smallvec![ptr]) } @@ -107,8 +106,7 @@ pub(crate) fn resolve_name_ref_ptrs( } NameRefClass::DropIndex | NameRefClass::ReindexIndex => { let path = find_containing_path(name_ref)?; - let index_name = extract_table_name(&path)?; - let schema = extract_schema_name(&path); + let (index_name, schema) = extract_table_schema_from_path(&path)?; let position = name_ref.syntax().text_range().start(); resolve_index_name_ptr(binder, &index_name, &schema, position).map(|ptr| smallvec![ptr]) } @@ -130,8 +128,7 @@ pub(crate) fn resolve_name_ref_ptrs( (type_name, schema) } else { let path = find_containing_path(name_ref)?; - let type_name = extract_table_name(&path)?; - let schema = extract_schema_name(&path); + let (type_name, schema) = extract_table_schema_from_path(&path)?; (type_name, schema) }; let position = name_ref.syntax().text_range().start(); @@ -141,15 +138,13 @@ pub(crate) fn resolve_name_ref_ptrs( | NameRefClass::DropMaterializedView | NameRefClass::RefreshMaterializedView => { let path = find_containing_path(name_ref)?; - let view_name = extract_table_name(&path)?; - let schema = extract_schema_name(&path); + let (view_name, schema) = extract_table_schema_from_path(&path)?; let position = name_ref.syntax().text_range().start(); resolve_view_name_ptr(binder, &view_name, &schema, position).map(|ptr| smallvec![ptr]) } NameRefClass::DropSequence => { let path = find_containing_path(name_ref)?; - let sequence_name = extract_table_name(&path)?; - let schema = extract_schema_name(&path); + let (sequence_name, schema) = extract_table_schema_from_path(&path)?; let position = name_ref.syntax().text_range().start(); resolve_sequence_name_ptr(binder, &sequence_name, &schema, position) .map(|ptr| smallvec![ptr]) @@ -160,8 +155,7 @@ pub(crate) fn resolve_name_ref_ptrs( .ancestors() .find_map(ast::DropTrigger::cast)?; let path = drop_trigger.path()?; - let trigger_name = extract_table_name(&path)?; - let mut schema = extract_schema_name(&path); + let (trigger_name, mut schema) = extract_table_schema_from_path(&path)?; let on_table_path = drop_trigger .on_table() .and_then(|on_table| on_table.path())?; @@ -183,8 +177,7 @@ pub(crate) fn resolve_name_ref_ptrs( })?; let policy_name = Name::from_node(&policy_name?); let on_table_path = on_table.and_then(|on_table| on_table.path())?; - let schema = extract_schema_name(&on_table_path); - let table_name = extract_table_name(&on_table_path)?; + let (table_name, schema) = extract_table_schema_from_path(&on_table_path)?; let position = name_ref.syntax().text_range().start(); resolve_policy_name_ptr(binder, &policy_name, &schema, position, table_name) .map(|ptr| smallvec![ptr]) @@ -234,8 +227,7 @@ pub(crate) fn resolve_name_ref_ptrs( } NameRefClass::ForeignKeyTable => { let path = name_ref.syntax().ancestors().find_map(ast::Path::cast)?; - let table_name = extract_table_name(&path)?; - let schema = extract_schema_name(&path); + let (table_name, schema) = extract_table_schema_from_path(&path)?; let position = name_ref.syntax().text_range().start(); resolve_table_name_ptr(binder, &table_name, &schema, position).map(|ptr| smallvec![ptr]) } @@ -285,7 +277,8 @@ pub(crate) fn resolve_name_ref_ptrs( None } })?; - resolve_policy_column_ptr(binder, root, &on_table_path, name_ref) + let column_name = Name::from_node(name_ref); + resolve_column_for_path(binder, root, &on_table_path, column_name) .map(|ptr| smallvec![ptr]) } NameRefClass::PolicyQualifiedColumnTable => { @@ -298,8 +291,7 @@ pub(crate) fn resolve_name_ref_ptrs( None } })?; - let table_name = extract_table_name(&on_table_path)?; - let schema = extract_schema_name(&on_table_path); + let (table_name, schema) = extract_table_schema_from_path(&on_table_path)?; let position = name_ref.syntax().text_range().start(); resolve_table_name_ptr(binder, &table_name, &schema, position).map(|ptr| smallvec![ptr]) } @@ -309,8 +301,7 @@ pub(crate) fn resolve_name_ref_ptrs( .ancestors() .find_map(ast::LikeClause::cast)?; let path = like_clause.path()?; - let table_name = extract_table_name(&path)?; - let schema = extract_schema_name(&path); + let (table_name, schema) = extract_table_schema_from_path(&path)?; let position = name_ref.syntax().text_range().start(); resolve_table_name_ptr(binder, &table_name, &schema, position).map(|ptr| smallvec![ptr]) } @@ -320,8 +311,7 @@ pub(crate) fn resolve_name_ref_ptrs( .ancestors() .find_map(ast::FunctionSig::cast)?; let path = function_sig.path()?; - let function_name = extract_table_name(&path)?; - let schema = extract_schema_name(&path); + let (function_name, schema) = extract_table_schema_from_path(&path)?; let params = extract_param_signature(&function_sig); let position = name_ref.syntax().text_range().start(); resolve_function(binder, &function_name, &schema, params.as_deref(), position) @@ -333,8 +323,7 @@ pub(crate) fn resolve_name_ref_ptrs( .ancestors() .find_map(ast::Aggregate::cast)?; let path = aggregate.path()?; - let aggregate_name = extract_table_name(&path)?; - let schema = extract_schema_name(&path); + let (aggregate_name, schema) = extract_table_schema_from_path(&path)?; let params = extract_param_signature(&aggregate); let position = name_ref.syntax().text_range().start(); resolve_aggregate( @@ -352,8 +341,7 @@ pub(crate) fn resolve_name_ref_ptrs( .ancestors() .find_map(ast::FunctionSig::cast)?; let path = function_sig.path()?; - let procedure_name = extract_table_name(&path)?; - let schema = extract_schema_name(&path); + let (procedure_name, schema) = extract_table_schema_from_path(&path)?; let params = extract_param_signature(&function_sig); let position = name_ref.syntax().text_range().start(); resolve_procedure( @@ -371,8 +359,7 @@ pub(crate) fn resolve_name_ref_ptrs( .ancestors() .find_map(ast::FunctionSig::cast)?; let path = function_sig.path()?; - let routine_name = extract_table_name(&path)?; - let schema = extract_schema_name(&path); + let (routine_name, schema) = extract_table_schema_from_path(&path)?; let params = extract_param_signature(&function_sig); let position = name_ref.syntax().text_range().start(); @@ -394,8 +381,7 @@ pub(crate) fn resolve_name_ref_ptrs( NameRefClass::CallProcedure => { let call = name_ref.syntax().ancestors().find_map(ast::Call::cast)?; let path = call.path()?; - let procedure_name = extract_table_name(&path)?; - let schema = extract_schema_name(&path); + let (procedure_name, schema) = extract_table_schema_from_path(&path)?; let position = name_ref.syntax().text_range().start(); resolve_procedure(binder, &procedure_name, &schema, None, position) .map(|ptr| smallvec![ptr]) @@ -441,8 +427,7 @@ pub(crate) fn resolve_name_ref_ptrs( .ancestors() .find_map(ast::PathType::cast)?; let path = path_type.path()?; - let function_name = extract_table_name(&path)?; - let schema = extract_schema_name(&path); + let (function_name, schema) = extract_table_schema_from_path(&path)?; let position = name_ref.syntax().text_range().start(); resolve_function(binder, &function_name, &schema, None, position) .map(|ptr| smallvec![ptr]) @@ -618,9 +603,7 @@ fn type_name_and_schema_from_type(ty: &ast::Type) -> Option<(Name, Option Some((Name::from_string("interval"), None)), ast::Type::PathType(path_type) => { let path = path_type.path()?; - let type_name = extract_table_name(&path)?; - let schema = extract_schema_name(&path); - Some((type_name, schema)) + extract_table_schema_from_path(&path) } ast::Type::ExprType(expr_type) => { let expr = expr_type.expr()?; @@ -839,48 +822,10 @@ fn resolve_column_for_path( path: &ast::Path, column_name: Name, ) -> Option { - let table_name = extract_table_name(path)?; - let schema = extract_schema_name(path); + let (table_name, schema) = extract_table_schema_from_path(path)?; let position = path.syntax().text_range().start(); - let resolved = resolve_table_name(binder, root, &table_name, &schema, position)?; - match resolved { - ResolvedTableName::View(create_view) => { - find_column_in_create_view(&create_view, &column_name) - } - ResolvedTableName::Table(create_table_like) => { - find_column_in_create_table(binder, root, &create_table_like, &column_name) - } - } -} - -fn resolve_policy_column_ptr( - binder: &Binder, - root: &SyntaxNode, - on_table_path: &ast::Path, - column_name_ref: &ast::NameRef, -) -> Option { - let column_name = Name::from_node(column_name_ref); - let (table_name, schema) = extract_table_schema_from_path(on_table_path)?; - let position = column_name_ref.syntax().text_range().start(); - - let resolved = resolve_table_name(binder, root, &table_name, &schema, position)?; - match resolved { - ResolvedTableName::View(create_view) => { - if let Some(ptr) = find_column_in_create_view(&create_view, &column_name) { - return Some(ptr); - } - resolve_function(binder, &column_name, &schema, None, position) - } - ResolvedTableName::Table(create_table_like) => { - if let Some(ptr) = - find_column_in_create_table(binder, root, &create_table_like, &column_name) - { - return Some(ptr); - } - resolve_function(binder, &column_name, &schema, None, position) - } - } + resolve_column_for_table(binder, root, &table_name, &schema, &column_name, position) } fn resolve_insert_column_ptr( @@ -1030,7 +975,6 @@ fn extract_table_schema_from_path(path: &ast::Path) -> Option<(Name, Option, + column_name: &Name, + position: TextSize, +) -> Option { + let resolved = resolve_table_name(binder, root, table_name, schema, position)?; match resolved { ResolvedTableName::View(create_view) => { - if let Some(ptr) = find_column_in_create_view(&create_view, &column_name) { + if let Some(ptr) = find_column_in_create_view(&create_view, column_name) { return Some(ptr); } - return resolve_function(binder, &column_name, &schema, None, position); + return resolve_function(binder, column_name, schema, None, position); } ResolvedTableName::Table(create_table_like) => { // 1. Try to find a matching column (columns take precedence) if let Some(ptr) = - find_column_in_create_table(binder, root, &create_table_like, &column_name) + find_column_in_create_table(binder, root, &create_table_like, column_name) { return Some(ptr); } // 2. No column found, check for field-style function call // e.g., select t.b from t where b is a function that takes t as an argument - return resolve_function(binder, &column_name, &schema, None, position); + return resolve_function(binder, column_name, schema, None, position); + } + ResolvedTableName::TableAs(create_table_as) => { + if let Some(ptr) = find_column_in_create_table_as(&create_table_as, column_name) { + return Some(ptr); + } + return resolve_function(binder, column_name, schema, None, position); } } } @@ -1249,6 +1210,7 @@ fn resolve_select_qualified_column_ptr( enum ResolvedTableName { View(ast::CreateView), Table(ast::CreateTableLike), + TableAs(ast::CreateTableAs), } fn resolve_table_name( binder: &Binder, @@ -1266,6 +1228,12 @@ fn resolve_table_name( { return Some(Table(create_table)); } + if let Some(create_table_as) = table_name_node + .ancestors() + .find_map(ast::CreateTableAs::cast) + { + return Some(TableAs(create_table_as)); + } } if let Some(view_name_ptr) = resolve_view_name_ptr(binder, table_name, schema, position) { @@ -1415,8 +1383,8 @@ fn resolve_column_from_table_or_view_impl( && let Some(partition_of) = create_table_node.partition_of() && let Some(parent_path) = partition_of.path() { - let parent_table_name = extract_table_name(&parent_path)?; - let parent_schema = extract_schema_name(&parent_path); + let (parent_table_name, parent_schema) = + extract_table_schema_from_path(&parent_path)?; return resolve_column_from_table_or_view_impl( binder, root, @@ -1438,6 +1406,19 @@ fn resolve_column_from_table_or_view_impl( return Some(table_name_ptr); } } + + if let Some(create_table_as) = table_name_node + .ancestors() + .find_map(ast::CreateTableAs::cast) + { + if let Some(ptr) = find_column_in_create_table_as(&create_table_as, column_name) { + return Some(ptr); + } + + if column_name == table_name { + return Some(table_name_ptr); + } + } } // ditto as above but with view @@ -1840,21 +1821,27 @@ fn find_column_in_create_table_impl( } ast::TableArg::LikeClause(like_clause) => { let path = like_clause.path()?; - let table_name = extract_table_name(&path)?; - let schema = extract_schema_name(&path); + let (table_name, schema) = extract_table_schema_from_path(&path)?; let position = path.syntax().text_range().start(); - if let Some(ResolvedTableName::Table(source_table)) = + if let Some(resolved) = resolve_table_name(binder, root, &table_name, &schema, position) - && let Some(ptr) = find_column_in_create_table_impl( - binder, - root, - &source_table, - column_name, - depth + 1, - ) { - return Some(ptr); + if let Some(ptr) = match resolved { + ResolvedTableName::Table(source_table) => find_column_in_create_table_impl( + binder, + root, + &source_table, + column_name, + depth + 1, + ), + ResolvedTableName::TableAs(create_table_as) => { + find_column_in_create_table_as(&create_table_as, column_name) + } + ResolvedTableName::View(_) => None, + } { + return Some(ptr); + } } } ast::TableArg::TableConstraint(_) => (), @@ -1863,21 +1850,26 @@ fn find_column_in_create_table_impl( if let Some(inherits) = create_table.inherits() { for path in inherits.paths() { - let table_name = extract_table_name(&path)?; - let schema = extract_schema_name(&path); + let (table_name, schema) = extract_table_schema_from_path(&path)?; let position = path.syntax().text_range().start(); - if let Some(ResolvedTableName::Table(parent_table)) = - resolve_table_name(binder, root, &table_name, &schema, position) - && let Some(ptr) = find_column_in_create_table_impl( - binder, - root, - &parent_table, - column_name, - depth + 1, - ) + if let Some(resolved) = resolve_table_name(binder, root, &table_name, &schema, position) { - return Some(ptr); + if let Some(ptr) = match resolved { + ResolvedTableName::Table(parent_table) => find_column_in_create_table_impl( + binder, + root, + &parent_table, + column_name, + depth + 1, + ), + ResolvedTableName::TableAs(create_table_as) => { + find_column_in_create_table_as(&create_table_as, column_name) + } + ResolvedTableName::View(_) => None, + } { + return Some(ptr); + } } } } @@ -1903,14 +1895,7 @@ fn find_column_in_create_view( 0 }; - let select = match create_view.query()? { - ast::SelectVariant::Select(s) => s, - ast::SelectVariant::ParenSelect(ps) => match ps.select()? { - ast::SelectVariant::Select(s) => s, - _ => return None, - }, - _ => return None, - }; + let select = resolve_select_clause(create_view.query()?)?; let select_clause = select.select_clause()?; let target_list = select_clause.target_list()?; @@ -1932,6 +1917,25 @@ fn find_column_in_create_view( None } +fn find_column_in_create_table_as( + create_table_as: &ast::CreateTableAs, + column_name: &Name, +) -> Option { + let select = resolve_select_clause(create_table_as.query()?)?; + + for target in select.select_clause()?.target_list()?.targets() { + if let Some((col_name, node)) = ColumnName::from_target(target.clone()) { + if let Some(col_name_str) = col_name.to_string() + && Name::from_string(col_name_str) == *column_name + { + return Some(SyntaxNodePtr::new(&node)); + } + } + } + + None +} + fn resolve_cte_table(name_ref: &ast::NameRef, cte_name: &Name) -> Option { let with_clause = find_parent_with_clause(name_ref.syntax())?; for with_table in with_clause.with_tables() { @@ -1963,8 +1967,7 @@ fn find_parent_with_clause(node: &SyntaxNode) -> Option { } fn count_columns_for_path(binder: &Binder, root: &SyntaxNode, path: &ast::Path) -> Option { - let table_name = extract_table_name(path)?; - let schema = extract_schema_name(path); + let (table_name, schema) = extract_table_schema_from_path(path)?; let position = path.syntax().text_range().start(); count_columns_for_table_name(binder, root, &table_name, &schema, position) @@ -1994,6 +1997,17 @@ fn count_columns_for_table_name( } return Some(count); } + + if let Some(create_table_as) = table_name_node + .ancestors() + .find_map(ast::CreateTableAs::cast) + { + let select = resolve_select_clause(create_table_as.query()?)?; + + if let Some(target_list) = select.select_clause().and_then(|c| c.target_list()) { + return Some(target_list.targets().count()); + } + } } if let Some(view_name_ptr) = resolve_view_name_ptr(binder, table_name, schema, position) { @@ -2004,14 +2018,7 @@ fn count_columns_for_table_name( return Some(column_list.columns().count()); } - let select = match create_view.query()? { - ast::SelectVariant::Select(s) => s, - ast::SelectVariant::ParenSelect(ps) => match ps.select()? { - ast::SelectVariant::Select(s) => s, - _ => return None, - }, - _ => return None, - }; + let select = resolve_select_clause(create_view.query()?)?; if let Some(target_list) = select.select_clause().and_then(|c| c.target_list()) { // This is not quite right if there's a `*` in the view definition. @@ -2024,6 +2031,17 @@ fn count_columns_for_table_name( None } +fn resolve_select_clause(query: SelectVariant) -> Option { + match query { + ast::SelectVariant::Select(s) => Some(s), + ast::SelectVariant::ParenSelect(ps) => match ps.select()? { + ast::SelectVariant::Select(s) => Some(s), + _ => return None, + }, + _ => return None, + } +} + fn resolve_cte_column( binder: &Binder, root: &SyntaxNode, @@ -2914,8 +2932,7 @@ pub(crate) fn resolve_insert_create_table( insert: &ast::Insert, ) -> Option { let path = insert.path()?; - let table_name = extract_table_name(&path)?; - let schema = extract_schema_name(&path); + let (table_name, schema) = extract_table_schema_from_path(&path)?; let position = insert.syntax().text_range().start(); let table_name_ptr = resolve_table_name_ptr(binder, &table_name, &schema, position)?; @@ -2984,8 +3001,7 @@ fn collect_table_columns_impl( if let Some(inherits) = create_table.inherits() { for path in inherits.paths() { - if let Some(table_name) = extract_table_name(&path) { - let schema = extract_schema_name(&path); + if let Some((table_name, schema)) = extract_table_schema_from_path(&path) { let position = path.syntax().text_range().start(); if let Some(ResolvedTableName::Table(parent_table)) = resolve_table_name(binder, root, &table_name, &schema, position) @@ -3006,9 +3022,8 @@ fn collect_table_columns_impl( } ast::TableArg::LikeClause(like_clause) => { if let Some(path) = like_clause.path() - && let Some(table_name) = extract_table_name(&path) + && let Some((table_name, schema)) = extract_table_schema_from_path(&path) { - let schema = extract_schema_name(&path); let position = path.syntax().text_range().start(); if let Some(ResolvedTableName::Table(source_table)) = resolve_table_name(binder, root, &table_name, &schema, position) @@ -3318,10 +3333,9 @@ fn collect_select_variant_columns_with_types( let Some(path) = table.relation_name().and_then(|r| r.path()) else { return vec![]; }; - let Some(table_name) = extract_table_name(&path) else { + let Some((table_name, schema)) = extract_table_schema_from_path(&path) else { return vec![]; }; - let schema = extract_schema_name(&path); let position = table.syntax().text_range().start(); let Some(table_ptr) = binder.lookup_with(&table_name, SymbolKind::Table, position, &schema) @@ -3477,8 +3491,7 @@ fn extract_type_name_and_schema(ty: &ast::Type) -> Option<(Name, Option) match ty { ast::Type::PathType(path_type) => { let path = path_type.path()?; - let type_name = extract_table_name(&path)?; - let schema = extract_schema_name(&path); + let (type_name, schema) = extract_table_schema_from_path(&path)?; Some((type_name, schema)) } ast::Type::ExprType(expr_type) => { @@ -3531,12 +3544,11 @@ fn resolve_table_in_returning_clause( returning_clause: Option, ) -> Option { let table_name = Name::from_node(table_name_ref); - let stmt_table_name = extract_table_name(path)?; + let (stmt_table_name, schema) = extract_table_schema_from_path(path)?; let matched = match_table_in_returning_clause(&table_name, &stmt_table_name, alias, returning_clause)?; - let schema = extract_schema_name(path); let position = table_name_ref.syntax().text_range().start(); match matched { @@ -3677,8 +3689,7 @@ fn find_func_call_from_named_arg(name_ref: &ast::NameRef) -> Option<(Name, Optio }; } else if let Some(call) = ast::Call::cast(a) { let path = call.path()?; - let function_name = extract_table_name(&path)?; - let schema = extract_schema_name(&path); + let (function_name, schema) = extract_table_schema_from_path(&path)?; return Some((function_name, schema)); } }