Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 74 additions & 2 deletions crates/squawk_ide/src/binder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -470,13 +470,85 @@ fn bind_create_type(b: &mut Binder, create_type: ast::CreateType) {
let type_id = b.symbols.alloc(Symbol {
kind: SymbolKind::Type,
ptr: name_ptr,
schema: Some(schema),
schema: Some(schema.clone()),
params: None,
table: None,
});

let root = b.root_scope();
b.scopes[root].insert(type_name, type_id);
b.scopes[root].insert(type_name.clone(), type_id);

if create_type.range_token().is_some() {
if let Some((multirange_name, multirange_ptr, multirange_schema)) =
multirange_type_from_range(b, create_type, type_name, schema, name_ptr)
{
let multirange_id = b.symbols.alloc(Symbol {
kind: SymbolKind::Type,
ptr: multirange_ptr,
schema: Some(multirange_schema),
params: None,
table: None,
});
b.scopes[root].insert(multirange_name, multirange_id);
}
}
}

fn multirange_type_from_range(
b: &Binder,
create_type: ast::CreateType,
type_name: Name,
schema: Schema,
fallback_ptr: SyntaxNodePtr,
) -> Option<(Name, SyntaxNodePtr, Schema)> {
if let Some(attribute_list) = create_type.attribute_list() {
let multirange_key = Name::from_string("multirange_type_name");
for option in attribute_list.attribute_options() {
let Some(name) = option.name() else {
continue;
};
if Name::from_node(&name) != multirange_key {
continue;
}
if let Some(attribute_value) = option.attribute_value() {
if let Some(literal) = attribute_value.literal()
&& let Some(string_value) = extract_string_literal(&literal)
{
let multirange_name = Name::from_string(string_value);
return Some((multirange_name, fallback_ptr, schema));
}
if let Some(ast::Type::PathType(path_type)) = attribute_value.ty()
&& let Some(path) = path_type.path()
&& let Some(multirange_name) = item_name(&path)
{
let multirange_schema = if path.qualifier().is_some() {
schema_name(b, &path, false)?
} else {
schema
};
return Some((multirange_name, fallback_ptr, multirange_schema));
}
}
}
}

let multirange_name = derive_multirange_name(type_name);
Some((multirange_name, fallback_ptr, schema))
}

// from postgres docs:
// > If the range type name contains the substring range, then the multirange type
// > name is formed by replacement of the range substring with multirange in the
// > range type name.
// > Otherwise, the multirange type name is formed by appending a
// > _multirange suffix to the range type name.
fn derive_multirange_name(range_name: Name) -> Name {
let range_text = range_name.0.as_str();
if range_text.contains("range") {
Name::from_string(range_text.replacen("range", "multirange", 1))
} else {
Name::from_string(format!("{range_text}_multirange"))
}
}

fn bind_create_view(b: &mut Binder, create_view: ast::CreateView) {
Expand Down
175 changes: 175 additions & 0 deletions crates/squawk_ide/src/goto_definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,22 @@ pub fn goto_definition(file: ast::SourceFile, offset: TextSize) -> SmallVec<[Tex
}
}

let type_node = ast::Type::cast(parent.clone()).or_else(|| {
// special case if we're at the timezone clause inside a timezone type
if ast::Timezone::can_cast(parent.kind()) {
parent.parent().and_then(ast::Type::cast)
} else {
None
}
});
if let Some(ty) = type_node {
let binder_output = binder::bind(&file);
let position = token.text_range().start();
if let Some(ptr) = resolve::resolve_type_ptr_from_type(&binder_output, &ty, position) {
return smallvec![ptr.to_node(file.syntax()).text_range()];
}
}

smallvec![]
}

Expand Down Expand Up @@ -1745,6 +1761,34 @@ create function b(t$0) returns int as 'select 1' language sql;
");
}

#[test]
fn goto_function_param_time_type() {
assert_snapshot!(goto("
create type timestamp;
create function f(timestamp$0 without time zone) returns text language internal;
"), @r"
╭▸
2 │ create type timestamp;
│ ───────── 2. destination
3 │ create function f(timestamp without time zone) returns text language internal;
╰╴ ─ 1. source
");
}

#[test]
fn goto_function_param_time_type_no_timezone() {
assert_snapshot!(goto("
create type time;
create function f(time$0) returns text language internal;
"), @r"
╭▸
2 │ create type time;
│ ──── 2. destination
3 │ create function f(time) returns text language internal;
╰╴ ─ 1. source
");
}

#[test]
fn goto_create_table_type_reference_enum() {
assert_snapshot!(goto("
Expand Down Expand Up @@ -2090,6 +2134,109 @@ select x::public.baz$0;
");
}

#[test]
fn goto_cast_timestamp_without_time_zone() {
assert_snapshot!(goto("
create type pg_catalog.timestamp;
select ''::timestamp without$0 time zone;
"), @r"
╭▸
2 │ create type pg_catalog.timestamp;
│ ───────── 2. destination
3 │ select ''::timestamp without time zone;
╰╴ ─ 1. source
");
}

#[test]
fn goto_cast_timestamp_with_time_zone() {
assert_snapshot!(goto("
create type pg_catalog.timestamptz;
select ''::timestamp with$0 time zone;
"), @r"
╭▸
2 │ create type pg_catalog.timestamptz;
│ ─────────── 2. destination
3 │ select ''::timestamp with time zone;
╰╴ ─ 1. source
");
}

#[test]
fn goto_cast_multirange_type_from_range() {
assert_snapshot!(goto("
create type floatrange as range (
subtype = float8,
subtype_diff = float8mi
);
select '{[1.234, 5.678]}'::floatmultirange$0;
"), @r"
╭▸
2 │ create type floatrange as range (
│ ────────── 2. destination
6 │ select '{[1.234, 5.678]}'::floatmultirange;
╰╴ ─ 1. source
");
}

#[test]
fn goto_cast_multirange_special_type_name_string() {
assert_snapshot!(goto("
create type floatrange as range (
subtype = float8,
subtype_diff = float8mi,
multirange_type_name = 'floatmulirangething'
);
select '{[1.234, 5.678]}'::floatmulirangething$0;
"), @r"
╭▸
2 │ create type floatrange as range (
│ ────────── 2. destination
7 │ select '{[1.234, 5.678]}'::floatmulirangething;
╰╴ ─ 1. source
");
}

#[test]
fn goto_cast_multirange_special_type_name_ident() {
assert_snapshot!(goto("
create type floatrange as range (
subtype = float8,
subtype_diff = float8mi,
multirange_type_name = floatrangemutirange
);
select '{[1.234, 5.678]}'::floatrangemutirange$0;
"), @r"
╭▸
2 │ create type floatrange as range (
│ ────────── 2. destination
7 │ select '{[1.234, 5.678]}'::floatrangemutirange;
╰╴ ─ 1. source
");
}

#[test]
fn goto_cast_multirange_edge_case_type_from_range() {
// make sure we're calculating the multirange correctly
assert_snapshot!(goto("
create type floatrangerange as range (
subtype = float8,
subtype_diff = float8mi
);
select '{[1.234, 5.678]}'::floatmultirangerange$0;
"), @r"
╭▸
2 │ create type floatrangerange as range (
│ ─────────────── 2. destination
6 │ select '{[1.234, 5.678]}'::floatmultirangerange;
╰╴ ─ 1. source
");
}

#[test]
fn goto_cast_bigint_falls_back_to_int8() {
assert_snapshot!(goto("
Expand Down Expand Up @@ -2246,6 +2393,34 @@ select 1::smallint$0;
");
}

#[test]
fn goto_cast_double_precision_falls_back_to_float8() {
assert_snapshot!(goto("
create type pg_catalog.float8;
select '1'::double precision[]$0;
"), @r"
╭▸
2 │ create type pg_catalog.float8;
│ ────── 2. destination
3 │ select '1'::double precision[];
╰╴ ─ 1. source
");
}

#[test]
fn goto_cast_varchar_with_modifier() {
assert_snapshot!(goto("
create type pg_catalog.varchar;
select '1'::varchar$0(1);
"), @r"
╭▸
2 │ create type pg_catalog.varchar;
│ ─────── 2. destination
3 │ select '1'::varchar(1);
╰╴ ─ 1. source
");
}

#[test]
fn goto_cast_composite_type() {
assert_snapshot!(goto("
Expand Down
68 changes: 68 additions & 0 deletions crates/squawk_ide/src/resolve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,74 @@ fn resolve_type_name_ptr(
None
}

pub(crate) fn resolve_type_ptr_from_type(
binder: &Binder,
ty: &ast::Type,
position: TextSize,
) -> Option<SyntaxNodePtr> {
let (type_name, schema) = type_name_and_schema_from_type(ty)?;
resolve_type_name_ptr(binder, &type_name, &schema, position)
}

fn type_name_and_schema_from_type(ty: &ast::Type) -> Option<(Name, Option<Schema>)> {
match ty {
ast::Type::ArrayType(array_type) => {
let inner = array_type.ty()?;
type_name_and_schema_from_type(&inner)
}
ast::Type::BitType(bit_type) => {
let name = if bit_type.varying_token().is_some() {
"varbit"
} else {
"bit"
};
Some((Name::from_string(name), None))
}
ast::Type::IntervalType(_) => Some((Name::from_string("interval"), None)),
ast::Type::PathType(path_type) => {
let path = path_type.path()?;
let type_name = extract_table_name(&path)?;
let schema = extract_schema_name(&path);
Some((type_name, schema))
}
ast::Type::ExprType(expr_type) => {
let expr = expr_type.expr()?;
if let ast::Expr::FieldExpr(field_expr) = expr
&& let Some(field) = field_expr.field()
&& let Some(ast::Expr::NameRef(schema_name_ref)) = field_expr.base()
{
let type_name = Name::from_node(&field);
let schema = Some(Schema(Name::from_node(&schema_name_ref)));
Some((type_name, schema))
} else {
None
}
}
ast::Type::CharType(char_type) => {
let name = if char_type.varchar_token().is_some() || char_type.varying_token().is_some()
{
"varchar"
} else {
"bpchar"
};
Some((Name::from_string(name), None))
}
ast::Type::DoubleType(_) => Some((Name::from_string("float8"), None)),
ast::Type::TimeType(time_type) => {
let mut name = if time_type.timestamp_token().is_some() {
"timestamp".to_string()
} else {
"time".to_string()
};
if let Some(ast::Timezone::WithTimezone(_)) = time_type.timezone() {
name.push_str("tz");
}
Some((Name::from_string(name), None))
}
ast::Type::PercentType(_) => None,
}
}

fn fallback_type_alias(type_name: &Name) -> Option<Name> {
match type_name.0.as_str() {
"bigint" | "bigserial" | "serial8" => Some(Name::from_string("int8")),
Expand Down
Loading