diff --git a/crates/squawk_ide/src/code_actions.rs b/crates/squawk_ide/src/code_actions.rs index 7d9b96a1..2de3c04a 100644 --- a/crates/squawk_ide/src/code_actions.rs +++ b/crates/squawk_ide/src/code_actions.rs @@ -448,48 +448,33 @@ fn add_schema( offset: TextSize, ) -> Option<()> { let token = token_from_offset(file, offset)?; - let (range, has_qualifier) = token.parent_ancestors().find_map(|node| { - if let Some(create_table) = ast::CreateTableLike::cast(node.clone()) { - let path = create_table.path()?; - return Some((path.syntax().text_range(), path.qualifier().is_some())); - } - if let Some(create_function) = ast::CreateFunction::cast(node.clone()) { - let path = create_function.path()?; - return Some((path.syntax().text_range(), path.qualifier().is_some())); + let range = token.parent_ancestors().find_map(|node| { + if let Some(path) = ast::Path::cast(node.clone()) { + if path.qualifier().is_some() { + return None; + } + return Some(path.syntax().text_range()); } - if let Some(table) = ast::Table::cast(node.clone()) { - let path = table.relation_name()?.path()?; - return Some((path.syntax().text_range(), path.qualifier().is_some())); + if let Some(from_item) = ast::FromItem::cast(node.clone()) { + let name_ref = from_item.name_ref()?; + return Some(name_ref.syntax().text_range()); } - if let Some(field_expr) = ast::FieldExpr::cast(node.clone()) { - let ast::Expr::NameRef(name_ref) = field_expr.base()? else { + if let Some(call_expr) = ast::CallExpr::cast(node) { + let ast::Expr::NameRef(name_ref) = call_expr.expr()? else { return None; }; - return Some((name_ref.syntax().text_range(), false)); - } - if let Some(from_item) = ast::FromItem::cast(node) { - let name_ref = from_item.name_ref()?; - return Some((name_ref.syntax().text_range(), false)); + return Some(name_ref.syntax().text_range()); } None })?; - // Already have a schema (or maybe table) set - // - // TODO: we'll need to change this when we want to support things like: - // `select t.c from t; -> select public.t.c from t;` - if has_qualifier { - return None; - } - if !range.contains(offset) { return None; } - let position = token.text_range().start(); + let position = token.text_range().start(); let binder = binder::bind(file); let schema = binder.search_path_at(position).first()?.to_string(); - let replacement = format!("{}.", schema); actions.push(CodeAction { @@ -1071,6 +1056,15 @@ mod test { ); } + #[test] + fn add_schema_create_type() { + assert_snapshot!(apply_code_action( + add_schema, + "create type t$0 as enum ();"), + @"create type public.t as enum ();" + ); + } + #[test] fn add_schema_table_stmt() { assert_snapshot!(apply_code_action( @@ -1091,15 +1085,37 @@ mod test { ); } + #[test] + fn add_schema_select_table_value() { + // we can't insert the schema here because: + // `select public.t from t` isn't valid + assert!(code_action_not_applicable( + add_schema, + "create table t(a text, b int); + select t$0 from t;" + )); + } + + #[test] + fn add_schema_select_unqualified_column() { + // not applicable since we don't have the table name set + // we'll have another quick action to insert table names + assert!(code_action_not_applicable( + add_schema, + "create table t(a text, b int); + select a$0 from t;" + )); + } + #[test] fn add_schema_select_qualified_column() { - assert_snapshot!(apply_code_action( + // not valid because we haven't specified the schema on the table name + // `select public.t.c from t` isn't valid sql + assert!(code_action_not_applicable( add_schema, "create table t(c text); - select t$0.c from t;"), - @"create table t(c text); - select public.t.c from t;" - ); + select t$0.c from t;" + )); } #[test] @@ -1125,6 +1141,35 @@ create table myschema.t(a text, b int);" )); } + #[test] + fn add_schema_function_call() { + assert_snapshot!(apply_code_action( + add_schema, + " +create function f() returns int8 + as 'select 1' + language sql; + +select f$0();"), + @" +create function f() returns int8 + as 'select 1' + language sql; + +select public.f();" + ); + } + + #[test] + fn add_schema_function_call_not_applicable_with_schema() { + assert!(code_action_not_applicable( + add_schema, + " +create function f() returns int8 as 'select 1' language sql; +select myschema.f$0();" + )); + } + #[test] fn rewrite_select_as_table_not_applicable_with_distinct() { assert!(code_action_not_applicable( diff --git a/crates/squawk_syntax/src/ast/generated/nodes.rs b/crates/squawk_syntax/src/ast/generated/nodes.rs index 4266f75a..c728be6f 100644 --- a/crates/squawk_syntax/src/ast/generated/nodes.rs +++ b/crates/squawk_syntax/src/ast/generated/nodes.rs @@ -16583,6 +16583,8 @@ pub enum AlterColumnOption { DropExpression(DropExpression), DropIdentity(DropIdentity), DropNotNull(DropNotNull), + Inherit(Inherit), + NoInherit(NoInherit), ResetOptions(ResetOptions), Restart(Restart), SetCompression(SetCompression), @@ -28587,6 +28589,8 @@ impl AstNode for AlterColumnOption { | SyntaxKind::DROP_EXPRESSION | SyntaxKind::DROP_IDENTITY | SyntaxKind::DROP_NOT_NULL + | SyntaxKind::INHERIT + | SyntaxKind::NO_INHERIT | SyntaxKind::RESET_OPTIONS | SyntaxKind::RESTART | SyntaxKind::SET_COMPRESSION @@ -28613,6 +28617,8 @@ impl AstNode for AlterColumnOption { } SyntaxKind::DROP_IDENTITY => AlterColumnOption::DropIdentity(DropIdentity { syntax }), SyntaxKind::DROP_NOT_NULL => AlterColumnOption::DropNotNull(DropNotNull { syntax }), + SyntaxKind::INHERIT => AlterColumnOption::Inherit(Inherit { syntax }), + SyntaxKind::NO_INHERIT => AlterColumnOption::NoInherit(NoInherit { syntax }), SyntaxKind::RESET_OPTIONS => AlterColumnOption::ResetOptions(ResetOptions { syntax }), SyntaxKind::RESTART => AlterColumnOption::Restart(Restart { syntax }), SyntaxKind::SET_COMPRESSION => { @@ -28653,6 +28659,8 @@ impl AstNode for AlterColumnOption { AlterColumnOption::DropExpression(it) => &it.syntax, AlterColumnOption::DropIdentity(it) => &it.syntax, AlterColumnOption::DropNotNull(it) => &it.syntax, + AlterColumnOption::Inherit(it) => &it.syntax, + AlterColumnOption::NoInherit(it) => &it.syntax, AlterColumnOption::ResetOptions(it) => &it.syntax, AlterColumnOption::Restart(it) => &it.syntax, AlterColumnOption::SetCompression(it) => &it.syntax, @@ -28700,6 +28708,18 @@ impl From for AlterColumnOption { AlterColumnOption::DropNotNull(node) } } +impl From for AlterColumnOption { + #[inline] + fn from(node: Inherit) -> AlterColumnOption { + AlterColumnOption::Inherit(node) + } +} +impl From for AlterColumnOption { + #[inline] + fn from(node: NoInherit) -> AlterColumnOption { + AlterColumnOption::NoInherit(node) + } +} impl From for AlterColumnOption { #[inline] fn from(node: ResetOptions) -> AlterColumnOption { diff --git a/crates/squawk_syntax/src/postgresql.ungram b/crates/squawk_syntax/src/postgresql.ungram index e489737f..e534be42 100644 --- a/crates/squawk_syntax/src/postgresql.ungram +++ b/crates/squawk_syntax/src/postgresql.ungram @@ -2995,6 +2995,8 @@ AlterColumnOption = | SetStorage | SetCompression | SetNotNull +| Inherit +| NoInherit AlterConstraint = 'alter' 'constraint' option:AlterColumnOption