diff --git a/crates/squawk_ide/src/binder.rs b/crates/squawk_ide/src/binder.rs index 89e0167d..e925bc0b 100644 --- a/crates/squawk_ide/src/binder.rs +++ b/crates/squawk_ide/src/binder.rs @@ -126,6 +126,7 @@ fn bind_create_table(b: &mut Binder, create_table: ast::CreateTable) { b.scopes[root].insert(table_name, table_id); } +// TODO: combine with bind_create_table fn bind_create_foreign_table(b: &mut Binder, create_foreign_table: ast::CreateForeignTable) { let Some(path) = create_foreign_table.path() else { return; @@ -337,6 +338,7 @@ fn bind_create_view(b: &mut Binder, create_view: ast::CreateView) { b.scopes[root].insert(view_name, view_id); } +// TODO: combine with create_view fn bind_create_materialized_view(b: &mut Binder, create_view: ast::CreateMaterializedView) { let Some(path) = create_view.path() else { return; diff --git a/crates/squawk_ide/src/code_actions.rs b/crates/squawk_ide/src/code_actions.rs index 769a1f33..8b1d511b 100644 --- a/crates/squawk_ide/src/code_actions.rs +++ b/crates/squawk_ide/src/code_actions.rs @@ -36,6 +36,8 @@ pub fn code_actions(file: ast::SourceFile, offset: TextSize) -> Option, + file: &ast::SourceFile, + offset: TextSize, +) -> Option<()> { + let token = token_from_offset(file, offset)?; + let cast_expr = token.parent_ancestors().find_map(ast::CastExpr::cast)?; + + if cast_expr.colon_colon().is_some() { + return None; + } + + let expr = cast_expr.expr()?; + let ty = cast_expr.ty()?; + + let expr_text = expr.syntax().text(); + let type_text = ty.syntax().text(); + + let replacement = format!("{}::{}", expr_text, type_text); + + actions.push(CodeAction { + title: "Rewrite cast function".to_owned(), + edits: vec![Edit::replace(cast_expr.syntax().text_range(), replacement)], + kind: ActionKind::RefactorRewrite, + }); + + Some(()) +} + +fn rewrite_double_colon_to_cast( + actions: &mut Vec, + file: &ast::SourceFile, + offset: TextSize, +) -> Option<()> { + let token = token_from_offset(file, offset)?; + let cast_expr = token.parent_ancestors().find_map(ast::CastExpr::cast)?; + + cast_expr.colon_colon()?; + + let expr = cast_expr.expr()?; + let ty = cast_expr.ty()?; + + let expr_text = expr.syntax().text(); + let type_text = ty.syntax().text(); + + let replacement = format!("cast({} as {})", expr_text, type_text); + + actions.push(CodeAction { + title: "Rewrite as cast operator".to_owned(), + edits: vec![Edit::replace(cast_expr.syntax().text_range(), replacement)], + kind: ActionKind::RefactorRewrite, + }); + + Some(()) +} + #[cfg(test)] mod test { use super::*; @@ -1145,4 +1203,126 @@ mod test { "select col_name$0 from t;" )); } + + #[test] + fn rewrite_cast_to_double_colon_simple() { + assert_snapshot!(apply_code_action( + rewrite_cast_to_double_colon, + "select ca$0st(foo as text) from t;"), + @"select foo::text from t;" + ); + } + + #[test] + fn rewrite_cast_to_double_colon_on_column() { + assert_snapshot!(apply_code_action( + rewrite_cast_to_double_colon, + "select cast(col_na$0me as int) from t;"), + @"select col_name::int from t;" + ); + } + + #[test] + fn rewrite_cast_to_double_colon_on_type() { + assert_snapshot!(apply_code_action( + rewrite_cast_to_double_colon, + "select cast(x as bigi$0nt) from t;"), + @"select x::bigint from t;" + ); + } + + #[test] + fn rewrite_cast_to_double_colon_qualified_type() { + assert_snapshot!(apply_code_action( + rewrite_cast_to_double_colon, + "select cast(x as pg_cata$0log.text) from t;"), + @"select x::pg_catalog.text from t;" + ); + } + + #[test] + fn rewrite_cast_to_double_colon_expression() { + assert_snapshot!(apply_code_action( + rewrite_cast_to_double_colon, + "select ca$0st(1 + 2 as bigint) from t;"), + @"select 1 + 2::bigint from t;" + ); + } + + #[test] + fn rewrite_cast_to_double_colon_not_applicable_already_double_colon() { + assert!(code_action_not_applicable( + rewrite_cast_to_double_colon, + "select foo::te$0xt from t;" + )); + } + + #[test] + fn rewrite_cast_to_double_colon_not_applicable_outside_cast() { + assert!(code_action_not_applicable( + rewrite_cast_to_double_colon, + "select fo$0o from t;" + )); + } + + #[test] + fn rewrite_double_colon_to_cast_simple() { + assert_snapshot!(apply_code_action( + rewrite_double_colon_to_cast, + "select foo::te$0xt from t;"), + @"select cast(foo as text) from t;" + ); + } + + #[test] + fn rewrite_double_colon_to_cast_on_column() { + assert_snapshot!(apply_code_action( + rewrite_double_colon_to_cast, + "select col_na$0me::int from t;"), + @"select cast(col_name as int) from t;" + ); + } + + #[test] + fn rewrite_double_colon_to_cast_on_type() { + assert_snapshot!(apply_code_action( + rewrite_double_colon_to_cast, + "select x::bigi$0nt from t;"), + @"select cast(x as bigint) from t;" + ); + } + + #[test] + fn rewrite_double_colon_to_cast_qualified_type() { + assert_snapshot!(apply_code_action( + rewrite_double_colon_to_cast, + "select x::pg_cata$0log.text from t;"), + @"select cast(x as pg_catalog.text) from t;" + ); + } + + #[test] + fn rewrite_double_colon_to_cast_expression() { + assert_snapshot!(apply_code_action( + rewrite_double_colon_to_cast, + "select 1 + 2::bigi$0nt from t;"), + @"select 1 + cast(2 as bigint) from t;" + ); + } + + #[test] + fn rewrite_double_colon_to_cast_not_applicable_already_cast() { + assert!(code_action_not_applicable( + rewrite_double_colon_to_cast, + "select ca$0st(foo as text) from t;" + )); + } + + #[test] + fn rewrite_double_colon_to_cast_not_applicable_outside_cast() { + assert!(code_action_not_applicable( + rewrite_double_colon_to_cast, + "select fo$0o from t;" + )); + } } diff --git a/crates/squawk_ide/src/document_symbols.rs b/crates/squawk_ide/src/document_symbols.rs index 4b25d32c..b7fbcb21 100644 --- a/crates/squawk_ide/src/document_symbols.rs +++ b/crates/squawk_ide/src/document_symbols.rs @@ -3,8 +3,8 @@ use squawk_syntax::ast::{self, AstNode}; use crate::binder::{self, extract_string_literal}; use crate::resolve::{ - resolve_aggregate_info, resolve_function_info, resolve_materialized_view_info, - resolve_procedure_info, resolve_table_info, resolve_type_info, resolve_view_info, + resolve_aggregate_info, resolve_function_info, resolve_procedure_info, resolve_table_info, + resolve_type_info, resolve_view_info, }; #[derive(Debug)] @@ -122,23 +122,13 @@ fn create_cte_table_symbol(with_table: ast::WithTable) -> Option let full_range = with_table.syntax().text_range(); let focus_range = name_node.syntax().text_range(); - let mut children = vec![]; - if let Some(column_list) = with_table.column_list() { - for column in column_list.columns() { - if let Some(column_symbol) = create_column_symbol(column) { - children.push(column_symbol); - } - } - } - - Some(DocumentSymbol { + symbols_from_column_list( + with_table.column_list(), name, - detail: None, - kind: DocumentSymbolKind::Table, full_range, focus_range, - children, - }) + DocumentSymbolKind::Table, + ) } fn create_schema_symbol(create_schema: ast::CreateSchema) -> Option { @@ -221,8 +211,24 @@ fn create_view_symbol( let full_range = create_view.syntax().text_range(); let focus_range = name_node.syntax().text_range(); + symbols_from_column_list( + create_view.column_list(), + name, + full_range, + focus_range, + DocumentSymbolKind::View, + ) +} + +fn symbols_from_column_list( + column_list: Option, + name: String, + full_range: TextRange, + focus_range: TextRange, + kind: DocumentSymbolKind, +) -> Option { let mut children = vec![]; - if let Some(column_list) = create_view.column_list() { + if let Some(column_list) = column_list { for column in column_list.columns() { if let Some(column_symbol) = create_column_symbol(column) { children.push(column_symbol); @@ -233,13 +239,14 @@ fn create_view_symbol( Some(DocumentSymbol { name, detail: None, - kind: DocumentSymbolKind::View, + kind, full_range, focus_range, children, }) } +// TODO: combine with create_view_symbol fn create_materialized_view_symbol( binder: &binder::Binder, create_view: ast::CreateMaterializedView, @@ -248,29 +255,19 @@ fn create_materialized_view_symbol( let segment = path.segment()?; let name_node = segment.name()?; - let (schema, view_name) = resolve_materialized_view_info(binder, &path)?; + let (schema, view_name) = resolve_view_info(binder, &path)?; let name = format!("{}.{}", schema.0, view_name); let full_range = create_view.syntax().text_range(); let focus_range = name_node.syntax().text_range(); - let mut children = vec![]; - if let Some(column_list) = create_view.column_list() { - for column in column_list.columns() { - if let Some(column_symbol) = create_column_symbol(column) { - children.push(column_symbol); - } - } - } - - Some(DocumentSymbol { + symbols_from_column_list( + create_view.column_list(), name, - detail: None, - kind: DocumentSymbolKind::MaterializedView, full_range, focus_range, - children, - }) + DocumentSymbolKind::MaterializedView, + ) } fn create_function_symbol( diff --git a/crates/squawk_ide/src/resolve.rs b/crates/squawk_ide/src/resolve.rs index bb7a551f..5985d40e 100644 --- a/crates/squawk_ide/src/resolve.rs +++ b/crates/squawk_ide/src/resolve.rs @@ -2625,13 +2625,6 @@ pub(crate) fn resolve_view_info(binder: &Binder, path: &ast::Path) -> Option<(Sc resolve_symbol_info(binder, path, SymbolKind::View) } -pub(crate) fn resolve_materialized_view_info( - binder: &Binder, - path: &ast::Path, -) -> Option<(Schema, String)> { - resolve_symbol_info(binder, path, SymbolKind::View) -} - pub(crate) fn resolve_sequence_info(binder: &Binder, path: &ast::Path) -> Option<(Schema, String)> { resolve_symbol_info(binder, path, SymbolKind::Sequence) }