From e4c30b27fa1e61aca7a8276986d2a0a6fead5df7 Mon Sep 17 00:00:00 2001 From: Steve Dignam Date: Sun, 11 Jan 2026 16:58:09 -0700 Subject: [PATCH] ide: refactor binder to be more private --- crates/squawk_ide/src/binder.rs | 104 ++++++++++++++++++++++-- crates/squawk_ide/src/resolve.rs | 134 +++---------------------------- 2 files changed, 111 insertions(+), 127 deletions(-) diff --git a/crates/squawk_ide/src/binder.rs b/crates/squawk_ide/src/binder.rs index e925bc0b..b658e048 100644 --- a/crates/squawk_ide/src/binder.rs +++ b/crates/squawk_ide/src/binder.rs @@ -7,15 +7,16 @@ use squawk_syntax::{SyntaxNodePtr, ast, ast::AstNode}; use crate::scope::{Scope, ScopeId}; use crate::symbols::{Name, Schema, Symbol, SymbolKind}; -pub(crate) struct SearchPathChange { +struct SearchPathChange { position: TextSize, search_path: Vec, } pub(crate) struct Binder { - pub(crate) scopes: Arena, - pub(crate) symbols: Arena, - pub(crate) search_path_changes: Vec, + // TODO: doesn't seem like we need this with our resolve setup + scopes: Arena, + symbols: Arena, + search_path_changes: Vec, } impl Binder { @@ -32,7 +33,7 @@ impl Binder { } } - pub(crate) fn root_scope(&self) -> ScopeId { + fn root_scope(&self) -> ScopeId { self.scopes .iter() .next() @@ -40,6 +41,99 @@ impl Binder { .expect("root scope must exist") } + pub(crate) fn lookup(&self, name: &Name, kind: SymbolKind) -> Option { + let symbols = self.scopes[self.root_scope()].get(name)?; + let symbol_id = symbols.iter().copied().find(|id| { + let symbol = &self.symbols[*id]; + symbol.kind == kind + })?; + Some(self.symbols[symbol_id].ptr) + } + + pub(crate) fn lookup_with( + &self, + name: &Name, + kind: SymbolKind, + position: TextSize, + schema: &Option, + ) -> Option { + let symbols = self.scopes[self.root_scope()].get(name)?; + + let search_paths = match schema { + Some(s) => std::slice::from_ref(s), + None => self.search_path_at(position), + }; + + for search_schema in search_paths { + if let Some(symbol_id) = symbols.iter().copied().find(|id| { + let symbol = &self.symbols[*id]; + symbol.kind == kind && symbol.schema.as_ref() == Some(search_schema) + }) { + return Some(self.symbols[symbol_id].ptr); + } + } + None + } + + pub(crate) fn lookup_with_params( + &self, + name: &Name, + kind: SymbolKind, + position: TextSize, + schema: &Option, + params: Option<&[Name]>, + ) -> Option { + let symbols = self.scopes[self.root_scope()].get(name)?; + + let search_paths = match schema { + Some(s) => std::slice::from_ref(s), + None => self.search_path_at(position), + }; + + for search_schema in search_paths { + if let Some(symbol_id) = symbols.iter().copied().find(|id| { + let symbol = &self.symbols[*id]; + let params_match = match (&symbol.params, params) { + (Some(sym_params), Some(req_params)) => sym_params.as_slice() == req_params, + (None, None) => true, + (_, None) => true, + _ => false, + }; + symbol.kind == kind && symbol.schema.as_ref() == Some(search_schema) && params_match + }) { + return Some(self.symbols[symbol_id].ptr); + } + } + None + } + + pub(crate) fn lookup_info( + &self, + name_str: String, + schema: &Option, + kind: SymbolKind, + position: TextSize, + ) -> Option<(Schema, String)> { + let name_normalized = Name::from_string(name_str.clone()); + let symbols = self.scopes[self.root_scope()].get(&name_normalized)?; + + let search_paths = match schema { + Some(schema_name) => &[Schema::new(schema_name)], + None => self.search_path_at(position), + }; + + for search_schema in search_paths { + if let Some(symbol_id) = symbols.iter().copied().find(|id| { + let symbol = &self.symbols[*id]; + symbol.kind == kind && symbol.schema.as_ref() == Some(search_schema) + }) { + let symbol = &self.symbols[symbol_id]; + return Some((symbol.schema.clone()?, name_str)); + } + } + None + } + fn current_search_path(&self) -> &[Schema] { &self .search_path_changes diff --git a/crates/squawk_ide/src/resolve.rs b/crates/squawk_ide/src/resolve.rs index 5985d40e..c6b0372e 100644 --- a/crates/squawk_ide/src/resolve.rs +++ b/crates/squawk_ide/src/resolve.rs @@ -421,7 +421,7 @@ fn resolve_table_name_ptr( schema: &Option, position: TextSize, ) -> Option { - resolve_for_kind(binder, table_name, schema, position, SymbolKind::Table) + binder.lookup_with(table_name, SymbolKind::Table, position, schema) } fn resolve_index_name_ptr( @@ -430,7 +430,7 @@ fn resolve_index_name_ptr( schema: &Option, position: TextSize, ) -> Option { - resolve_for_kind(binder, index_name, schema, position, SymbolKind::Index) + binder.lookup_with(index_name, SymbolKind::Index, position, schema) } fn resolve_type_name_ptr( @@ -439,7 +439,7 @@ fn resolve_type_name_ptr( schema: &Option, position: TextSize, ) -> Option { - resolve_for_kind(binder, type_name, schema, position, SymbolKind::Type) + binder.lookup_with(type_name, SymbolKind::Type, position, schema) } fn resolve_view_name_ptr( @@ -448,7 +448,7 @@ fn resolve_view_name_ptr( schema: &Option, position: TextSize, ) -> Option { - resolve_for_kind(binder, view_name, schema, position, SymbolKind::View) + binder.lookup_with(view_name, SymbolKind::View, position, schema) } fn resolve_sequence_name_ptr( @@ -457,69 +457,19 @@ fn resolve_sequence_name_ptr( schema: &Option, position: TextSize, ) -> Option { - resolve_for_kind( - binder, - sequence_name, - schema, - position, - SymbolKind::Sequence, - ) + binder.lookup_with(sequence_name, SymbolKind::Sequence, position, schema) } fn resolve_tablespace_name_ptr(binder: &Binder, tablespace_name: &Name) -> Option { - let symbols = binder.scopes[binder.root_scope()].get(tablespace_name)?; - let symbol_id = symbols.iter().copied().find(|id| { - let symbol = &binder.symbols[*id]; - symbol.kind == SymbolKind::Tablespace - })?; - Some(binder.symbols[symbol_id].ptr) + binder.lookup(tablespace_name, SymbolKind::Tablespace) } fn resolve_database_name_ptr(binder: &Binder, database_name: &Name) -> Option { - let symbols = binder.scopes[binder.root_scope()].get(database_name)?; - let symbol_id = symbols.iter().copied().find(|id| { - let symbol = &binder.symbols[*id]; - symbol.kind == SymbolKind::Database - })?; - Some(binder.symbols[symbol_id].ptr) + binder.lookup(database_name, SymbolKind::Database) } fn resolve_server_name_ptr(binder: &Binder, server_name: &Name) -> Option { - let symbols = binder.scopes[binder.root_scope()].get(server_name)?; - let symbol_id = symbols.iter().copied().find(|id| { - let symbol = &binder.symbols[*id]; - symbol.kind == SymbolKind::Server - })?; - Some(binder.symbols[symbol_id].ptr) -} - -fn resolve_for_kind( - binder: &Binder, - name: &Name, - schema: &Option, - position: TextSize, - kind: SymbolKind, -) -> Option { - let symbols = binder.scopes[binder.root_scope()].get(name)?; - - if let Some(schema) = schema { - let symbol_id = symbols.iter().copied().find(|id| { - let symbol = &binder.symbols[*id]; - symbol.kind == kind && symbol.schema.as_ref() == Some(schema) - })?; - return Some(binder.symbols[symbol_id].ptr); - } else { - let search_path = binder.search_path_at(position); - for search_schema in search_path { - if let Some(symbol_id) = symbols.iter().copied().find(|id| { - let symbol = &binder.symbols[*id]; - symbol.kind == kind && symbol.schema.as_ref() == Some(search_schema) - }) { - return Some(binder.symbols[symbol_id].ptr); - } - } - } - None + binder.lookup(server_name, SymbolKind::Server) } fn resolve_for_kind_with_params( @@ -530,38 +480,7 @@ fn resolve_for_kind_with_params( position: TextSize, kind: SymbolKind, ) -> Option { - let symbols = binder.scopes[binder.root_scope()].get(name)?; - - if let Some(schema) = schema { - let symbol_id = symbols.iter().copied().find(|id| { - let symbol = &binder.symbols[*id]; - let params_match = match (&symbol.params, params) { - (Some(sym_params), Some(req_params)) => sym_params.as_slice() == req_params, - (None, None) => true, - (_, None) => true, - _ => false, - }; - symbol.kind == kind && symbol.schema.as_ref() == Some(schema) && params_match - })?; - return Some(binder.symbols[symbol_id].ptr); - } else { - let search_path = binder.search_path_at(position); - for search_schema in search_path { - if let Some(symbol_id) = symbols.iter().copied().find(|id| { - let symbol = &binder.symbols[*id]; - let params_match = match (&symbol.params, params) { - (Some(sym_params), Some(req_params)) => sym_params.as_slice() == req_params, - (None, None) => true, - (_, None) => true, - _ => false, - }; - symbol.kind == kind && symbol.schema.as_ref() == Some(search_schema) && params_match - }) { - return Some(binder.symbols[symbol_id].ptr); - } - } - } - None + binder.lookup_with_params(name, kind, position, schema, params) } fn resolve_function( @@ -616,12 +535,7 @@ fn resolve_procedure( } fn resolve_schema(binder: &Binder, schema_name: &Name) -> Option { - let symbols = binder.scopes[binder.root_scope()].get(schema_name)?; - let symbol_id = symbols.iter().copied().find(|id| { - let symbol = &binder.symbols[*id]; - symbol.kind == SymbolKind::Schema - })?; - Some(binder.symbols[symbol_id].ptr) + binder.lookup(schema_name, SymbolKind::Schema) } fn resolve_create_index_column_ptr( @@ -2819,32 +2733,8 @@ fn resolve_symbol_info( ) -> Option<(Schema, String)> { let name_str = extract_table_name_from_path(path)?; let schema = extract_schema_from_path(path); - - let name_normalized = Name::from_string(name_str.clone()); - let symbols = binder.scopes[binder.root_scope()].get(&name_normalized)?; - - if let Some(schema_name) = schema { - let schema_normalized = Schema::new(schema_name); - let symbol_id = symbols.iter().copied().find(|id| { - let symbol = &binder.symbols[*id]; - symbol.kind == kind && symbol.schema.as_ref() == Some(&schema_normalized) - })?; - let symbol = &binder.symbols[symbol_id]; - return Some((symbol.schema.clone()?, name_str)); - } else { - let position = path.syntax().text_range().start(); - let search_path = binder.search_path_at(position); - for search_schema in search_path { - if let Some(symbol_id) = symbols.iter().copied().find(|id| { - let symbol = &binder.symbols[*id]; - symbol.kind == kind && symbol.schema.as_ref() == Some(search_schema) - }) { - let symbol = &binder.symbols[symbol_id]; - return Some((symbol.schema.clone()?, name_str)); - } - } - } - None + let position = path.syntax().text_range().start(); + binder.lookup_info(name_str, &schema, kind, position) } fn collect_column_names_from_column_list(column_list: &ast::ColumnList) -> Vec {