Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/squawk_ide/src/binder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ fn bind_create_table(b: &mut Binder, create_table: ast::CreateTable) {
b.scopes[root].insert(table_name, table_id);
}

// TODO: combine with bind_create_table
fn bind_create_foreign_table(b: &mut Binder, create_foreign_table: ast::CreateForeignTable) {
let Some(path) = create_foreign_table.path() else {
return;
Expand Down Expand Up @@ -337,6 +338,7 @@ fn bind_create_view(b: &mut Binder, create_view: ast::CreateView) {
b.scopes[root].insert(view_name, view_id);
}

// TODO: combine with create_view
fn bind_create_materialized_view(b: &mut Binder, create_view: ast::CreateMaterializedView) {
let Some(path) = create_view.path() else {
return;
Expand Down
180 changes: 180 additions & 0 deletions crates/squawk_ide/src/code_actions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ pub fn code_actions(file: ast::SourceFile, offset: TextSize) -> Option<Vec<CodeA
unquote_identifier(&mut actions, &file, offset);
add_explicit_alias(&mut actions, &file, offset);
remove_redundant_alias(&mut actions, &file, offset);
rewrite_cast_to_double_colon(&mut actions, &file, offset);
rewrite_double_colon_to_cast(&mut actions, &file, offset);
Some(actions)
}

Expand Down Expand Up @@ -434,6 +436,62 @@ fn remove_redundant_alias(
Some(())
}

fn rewrite_cast_to_double_colon(
actions: &mut Vec<CodeAction>,
file: &ast::SourceFile,
offset: TextSize,
) -> Option<()> {
let token = token_from_offset(file, offset)?;
let cast_expr = token.parent_ancestors().find_map(ast::CastExpr::cast)?;

if cast_expr.colon_colon().is_some() {
return None;
}

let expr = cast_expr.expr()?;
let ty = cast_expr.ty()?;

let expr_text = expr.syntax().text();
let type_text = ty.syntax().text();

let replacement = format!("{}::{}", expr_text, type_text);

actions.push(CodeAction {
title: "Rewrite cast function".to_owned(),
edits: vec![Edit::replace(cast_expr.syntax().text_range(), replacement)],
kind: ActionKind::RefactorRewrite,
});

Some(())
}

fn rewrite_double_colon_to_cast(
actions: &mut Vec<CodeAction>,
file: &ast::SourceFile,
offset: TextSize,
) -> Option<()> {
let token = token_from_offset(file, offset)?;
let cast_expr = token.parent_ancestors().find_map(ast::CastExpr::cast)?;

cast_expr.colon_colon()?;

let expr = cast_expr.expr()?;
let ty = cast_expr.ty()?;

let expr_text = expr.syntax().text();
let type_text = ty.syntax().text();

let replacement = format!("cast({} as {})", expr_text, type_text);

actions.push(CodeAction {
title: "Rewrite as cast operator".to_owned(),
edits: vec![Edit::replace(cast_expr.syntax().text_range(), replacement)],
kind: ActionKind::RefactorRewrite,
});

Some(())
}

#[cfg(test)]
mod test {
use super::*;
Expand Down Expand Up @@ -1145,4 +1203,126 @@ mod test {
"select col_name$0 from t;"
));
}

#[test]
fn rewrite_cast_to_double_colon_simple() {
assert_snapshot!(apply_code_action(
rewrite_cast_to_double_colon,
"select ca$0st(foo as text) from t;"),
@"select foo::text from t;"
);
}

#[test]
fn rewrite_cast_to_double_colon_on_column() {
assert_snapshot!(apply_code_action(
rewrite_cast_to_double_colon,
"select cast(col_na$0me as int) from t;"),
@"select col_name::int from t;"
);
}

#[test]
fn rewrite_cast_to_double_colon_on_type() {
assert_snapshot!(apply_code_action(
rewrite_cast_to_double_colon,
"select cast(x as bigi$0nt) from t;"),
@"select x::bigint from t;"
);
}

#[test]
fn rewrite_cast_to_double_colon_qualified_type() {
assert_snapshot!(apply_code_action(
rewrite_cast_to_double_colon,
"select cast(x as pg_cata$0log.text) from t;"),
@"select x::pg_catalog.text from t;"
);
}

#[test]
fn rewrite_cast_to_double_colon_expression() {
assert_snapshot!(apply_code_action(
rewrite_cast_to_double_colon,
"select ca$0st(1 + 2 as bigint) from t;"),
@"select 1 + 2::bigint from t;"
);
}

#[test]
fn rewrite_cast_to_double_colon_not_applicable_already_double_colon() {
assert!(code_action_not_applicable(
rewrite_cast_to_double_colon,
"select foo::te$0xt from t;"
));
}

#[test]
fn rewrite_cast_to_double_colon_not_applicable_outside_cast() {
assert!(code_action_not_applicable(
rewrite_cast_to_double_colon,
"select fo$0o from t;"
));
}

#[test]
fn rewrite_double_colon_to_cast_simple() {
assert_snapshot!(apply_code_action(
rewrite_double_colon_to_cast,
"select foo::te$0xt from t;"),
@"select cast(foo as text) from t;"
);
}

#[test]
fn rewrite_double_colon_to_cast_on_column() {
assert_snapshot!(apply_code_action(
rewrite_double_colon_to_cast,
"select col_na$0me::int from t;"),
@"select cast(col_name as int) from t;"
);
}

#[test]
fn rewrite_double_colon_to_cast_on_type() {
assert_snapshot!(apply_code_action(
rewrite_double_colon_to_cast,
"select x::bigi$0nt from t;"),
@"select cast(x as bigint) from t;"
);
}

#[test]
fn rewrite_double_colon_to_cast_qualified_type() {
assert_snapshot!(apply_code_action(
rewrite_double_colon_to_cast,
"select x::pg_cata$0log.text from t;"),
@"select cast(x as pg_catalog.text) from t;"
);
}

#[test]
fn rewrite_double_colon_to_cast_expression() {
assert_snapshot!(apply_code_action(
rewrite_double_colon_to_cast,
"select 1 + 2::bigi$0nt from t;"),
@"select 1 + cast(2 as bigint) from t;"
);
}

#[test]
fn rewrite_double_colon_to_cast_not_applicable_already_cast() {
assert!(code_action_not_applicable(
rewrite_double_colon_to_cast,
"select ca$0st(foo as text) from t;"
));
}

#[test]
fn rewrite_double_colon_to_cast_not_applicable_outside_cast() {
assert!(code_action_not_applicable(
rewrite_double_colon_to_cast,
"select fo$0o from t;"
));
}
}
63 changes: 30 additions & 33 deletions crates/squawk_ide/src/document_symbols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use squawk_syntax::ast::{self, AstNode};

use crate::binder::{self, extract_string_literal};
use crate::resolve::{
resolve_aggregate_info, resolve_function_info, resolve_materialized_view_info,
resolve_procedure_info, resolve_table_info, resolve_type_info, resolve_view_info,
resolve_aggregate_info, resolve_function_info, resolve_procedure_info, resolve_table_info,
resolve_type_info, resolve_view_info,
};

#[derive(Debug)]
Expand Down Expand Up @@ -122,23 +122,13 @@ fn create_cte_table_symbol(with_table: ast::WithTable) -> Option<DocumentSymbol>
let full_range = with_table.syntax().text_range();
let focus_range = name_node.syntax().text_range();

let mut children = vec![];
if let Some(column_list) = with_table.column_list() {
for column in column_list.columns() {
if let Some(column_symbol) = create_column_symbol(column) {
children.push(column_symbol);
}
}
}

Some(DocumentSymbol {
symbols_from_column_list(
with_table.column_list(),
name,
detail: None,
kind: DocumentSymbolKind::Table,
full_range,
focus_range,
children,
})
DocumentSymbolKind::Table,
)
}

fn create_schema_symbol(create_schema: ast::CreateSchema) -> Option<DocumentSymbol> {
Expand Down Expand Up @@ -221,8 +211,24 @@ fn create_view_symbol(
let full_range = create_view.syntax().text_range();
let focus_range = name_node.syntax().text_range();

symbols_from_column_list(
create_view.column_list(),
name,
full_range,
focus_range,
DocumentSymbolKind::View,
)
}

fn symbols_from_column_list(
column_list: Option<ast::ColumnList>,
name: String,
full_range: TextRange,
focus_range: TextRange,
kind: DocumentSymbolKind,
) -> Option<DocumentSymbol> {
let mut children = vec![];
if let Some(column_list) = create_view.column_list() {
if let Some(column_list) = column_list {
for column in column_list.columns() {
if let Some(column_symbol) = create_column_symbol(column) {
children.push(column_symbol);
Expand All @@ -233,13 +239,14 @@ fn create_view_symbol(
Some(DocumentSymbol {
name,
detail: None,
kind: DocumentSymbolKind::View,
kind,
full_range,
focus_range,
children,
})
}

// TODO: combine with create_view_symbol
fn create_materialized_view_symbol(
binder: &binder::Binder,
create_view: ast::CreateMaterializedView,
Expand All @@ -248,29 +255,19 @@ fn create_materialized_view_symbol(
let segment = path.segment()?;
let name_node = segment.name()?;

let (schema, view_name) = resolve_materialized_view_info(binder, &path)?;
let (schema, view_name) = resolve_view_info(binder, &path)?;
let name = format!("{}.{}", schema.0, view_name);

let full_range = create_view.syntax().text_range();
let focus_range = name_node.syntax().text_range();

let mut children = vec![];
if let Some(column_list) = create_view.column_list() {
for column in column_list.columns() {
if let Some(column_symbol) = create_column_symbol(column) {
children.push(column_symbol);
}
}
}

Some(DocumentSymbol {
symbols_from_column_list(
create_view.column_list(),
name,
detail: None,
kind: DocumentSymbolKind::MaterializedView,
full_range,
focus_range,
children,
})
DocumentSymbolKind::MaterializedView,
)
}

fn create_function_symbol(
Expand Down
7 changes: 0 additions & 7 deletions crates/squawk_ide/src/resolve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2625,13 +2625,6 @@ pub(crate) fn resolve_view_info(binder: &Binder, path: &ast::Path) -> Option<(Sc
resolve_symbol_info(binder, path, SymbolKind::View)
}

pub(crate) fn resolve_materialized_view_info(
binder: &Binder,
path: &ast::Path,
) -> Option<(Schema, String)> {
resolve_symbol_info(binder, path, SymbolKind::View)
}

pub(crate) fn resolve_sequence_info(binder: &Binder, path: &ast::Path) -> Option<(Schema, String)> {
resolve_symbol_info(binder, path, SymbolKind::Sequence)
}
Expand Down
Loading