diff --git a/crates/squawk_ide/src/code_actions.rs b/crates/squawk_ide/src/code_actions.rs index eae5e85f..7d9b96a1 100644 --- a/crates/squawk_ide/src/code_actions.rs +++ b/crates/squawk_ide/src/code_actions.rs @@ -8,6 +8,7 @@ use squawk_syntax::{ use std::iter; use crate::{ + binder, column_name::ColumnName, offsets::token_from_offset, quote::{quote_column_alias, unquote_ident}, @@ -36,6 +37,7 @@ pub fn code_actions(file: ast::SourceFile, offset: TextSize) -> Option, + file: &ast::SourceFile, + offset: TextSize, +) -> Option<()> { + let token = token_from_offset(file, offset)?; + let (range, has_qualifier) = token.parent_ancestors().find_map(|node| { + if let Some(create_table) = ast::CreateTableLike::cast(node.clone()) { + let path = create_table.path()?; + return Some((path.syntax().text_range(), path.qualifier().is_some())); + } + if let Some(create_function) = ast::CreateFunction::cast(node.clone()) { + let path = create_function.path()?; + return Some((path.syntax().text_range(), path.qualifier().is_some())); + } + if let Some(table) = ast::Table::cast(node.clone()) { + let path = table.relation_name()?.path()?; + return Some((path.syntax().text_range(), path.qualifier().is_some())); + } + if let Some(field_expr) = ast::FieldExpr::cast(node.clone()) { + let ast::Expr::NameRef(name_ref) = field_expr.base()? else { + return None; + }; + return Some((name_ref.syntax().text_range(), false)); + } + if let Some(from_item) = ast::FromItem::cast(node) { + let name_ref = from_item.name_ref()?; + return Some((name_ref.syntax().text_range(), false)); + } + None + })?; + + // Already have a schema (or maybe table) set + // + // TODO: we'll need to change this when we want to support things like: + // `select t.c from t; -> select public.t.c from t;` + if has_qualifier { + return None; + } + + if !range.contains(offset) { + return None; + } + let position = token.text_range().start(); + + let binder = binder::bind(file); + let schema = binder.search_path_at(position).first()?.to_string(); + + let replacement = format!("{}.", schema); + + actions.push(CodeAction { + title: "Add schema".to_owned(), + edits: vec![Edit::insert(replacement, position)], + kind: ActionKind::RefactorRewrite, + }); + + Some(()) +} + fn rewrite_cast_to_double_colon( actions: &mut Vec, file: &ast::SourceFile, @@ -981,6 +1042,89 @@ mod test { )); } + #[test] + fn add_schema_simple() { + assert_snapshot!(apply_code_action( + add_schema, + "create table t$0(a text, b int);"), + @"create table public.t(a text, b int);" + ); + } + + #[test] + fn add_schema_create_foreign_table() { + assert_snapshot!(apply_code_action( + add_schema, + "create foreign table t$0(a text, b int) server foo;"), + @"create foreign table public.t(a text, b int) server foo;" + ); + } + + #[test] + fn add_schema_create_function() { + assert_snapshot!(apply_code_action( + add_schema, + "create function f$0() returns int8\n as 'select 1'\n language sql;"), + @"create function public.f() returns int8 + as 'select 1' + language sql;" + ); + } + + #[test] + fn add_schema_table_stmt() { + assert_snapshot!(apply_code_action( + add_schema, + "table t$0;"), + @"table public.t;" + ); + } + + #[test] + fn add_schema_select_from() { + assert_snapshot!(apply_code_action( + add_schema, + "create table t(a text, b int); + select t from t$0;"), + @"create table t(a text, b int); + select t from public.t;" + ); + } + + #[test] + fn add_schema_select_qualified_column() { + assert_snapshot!(apply_code_action( + add_schema, + "create table t(c text); + select t$0.c from t;"), + @"create table t(c text); + select public.t.c from t;" + ); + } + + #[test] + fn add_schema_with_search_path() { + assert_snapshot!( + apply_code_action( + add_schema, + " +set search_path to myschema; +create table t$0(a text, b int);" + ), + @" +set search_path to myschema; +create table myschema.t(a text, b int);" + ); + } + + #[test] + fn add_schema_not_applicable_with_schema() { + assert!(code_action_not_applicable( + add_schema, + "create table myschema.t$0(a text, b int);" + )); + } + #[test] fn rewrite_select_as_table_not_applicable_with_distinct() { assert!(code_action_not_applicable(