diff --git a/crates/squawk_ide/src/binder.rs b/crates/squawk_ide/src/binder.rs index 95fa0391..d0b4a645 100644 --- a/crates/squawk_ide/src/binder.rs +++ b/crates/squawk_ide/src/binder.rs @@ -470,13 +470,85 @@ fn bind_create_type(b: &mut Binder, create_type: ast::CreateType) { let type_id = b.symbols.alloc(Symbol { kind: SymbolKind::Type, ptr: name_ptr, - schema: Some(schema), + schema: Some(schema.clone()), params: None, table: None, }); let root = b.root_scope(); - b.scopes[root].insert(type_name, type_id); + b.scopes[root].insert(type_name.clone(), type_id); + + if create_type.range_token().is_some() { + if let Some((multirange_name, multirange_ptr, multirange_schema)) = + multirange_type_from_range(b, create_type, type_name, schema, name_ptr) + { + let multirange_id = b.symbols.alloc(Symbol { + kind: SymbolKind::Type, + ptr: multirange_ptr, + schema: Some(multirange_schema), + params: None, + table: None, + }); + b.scopes[root].insert(multirange_name, multirange_id); + } + } +} + +fn multirange_type_from_range( + b: &Binder, + create_type: ast::CreateType, + type_name: Name, + schema: Schema, + fallback_ptr: SyntaxNodePtr, +) -> Option<(Name, SyntaxNodePtr, Schema)> { + if let Some(attribute_list) = create_type.attribute_list() { + let multirange_key = Name::from_string("multirange_type_name"); + for option in attribute_list.attribute_options() { + let Some(name) = option.name() else { + continue; + }; + if Name::from_node(&name) != multirange_key { + continue; + } + if let Some(attribute_value) = option.attribute_value() { + if let Some(literal) = attribute_value.literal() + && let Some(string_value) = extract_string_literal(&literal) + { + let multirange_name = Name::from_string(string_value); + return Some((multirange_name, fallback_ptr, schema)); + } + if let Some(ast::Type::PathType(path_type)) = attribute_value.ty() + && let Some(path) = path_type.path() + && let Some(multirange_name) = item_name(&path) + { + let multirange_schema = if path.qualifier().is_some() { + schema_name(b, &path, false)? + } else { + schema + }; + return Some((multirange_name, fallback_ptr, multirange_schema)); + } + } + } + } + + let multirange_name = derive_multirange_name(type_name); + Some((multirange_name, fallback_ptr, schema)) +} + +// from postgres docs: +// > If the range type name contains the substring range, then the multirange type +// > name is formed by replacement of the range substring with multirange in the +// > range type name. +// > Otherwise, the multirange type name is formed by appending a +// > _multirange suffix to the range type name. +fn derive_multirange_name(range_name: Name) -> Name { + let range_text = range_name.0.as_str(); + if range_text.contains("range") { + Name::from_string(range_text.replacen("range", "multirange", 1)) + } else { + Name::from_string(format!("{range_text}_multirange")) + } } fn bind_create_view(b: &mut Binder, create_view: ast::CreateView) { diff --git a/crates/squawk_ide/src/goto_definition.rs b/crates/squawk_ide/src/goto_definition.rs index 4a44accc..d333b820 100644 --- a/crates/squawk_ide/src/goto_definition.rs +++ b/crates/squawk_ide/src/goto_definition.rs @@ -67,6 +67,22 @@ pub fn goto_definition(file: ast::SourceFile, offset: TextSize) -> SmallVec<[Tex } } + let type_node = ast::Type::cast(parent.clone()).or_else(|| { + // special case if we're at the timezone clause inside a timezone type + if ast::Timezone::can_cast(parent.kind()) { + parent.parent().and_then(ast::Type::cast) + } else { + None + } + }); + if let Some(ty) = type_node { + let binder_output = binder::bind(&file); + let position = token.text_range().start(); + if let Some(ptr) = resolve::resolve_type_ptr_from_type(&binder_output, &ty, position) { + return smallvec![ptr.to_node(file.syntax()).text_range()]; + } + } + smallvec![] } @@ -1745,6 +1761,34 @@ create function b(t$0) returns int as 'select 1' language sql; "); } + #[test] + fn goto_function_param_time_type() { + assert_snapshot!(goto(" +create type timestamp; +create function f(timestamp$0 without time zone) returns text language internal; +"), @r" + ╭▸ + 2 │ create type timestamp; + │ ───────── 2. destination + 3 │ create function f(timestamp without time zone) returns text language internal; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_function_param_time_type_no_timezone() { + assert_snapshot!(goto(" +create type time; +create function f(time$0) returns text language internal; +"), @r" + ╭▸ +2 │ create type time; + │ ──── 2. destination +3 │ create function f(time) returns text language internal; + ╰╴ ─ 1. source +"); + } + #[test] fn goto_create_table_type_reference_enum() { assert_snapshot!(goto(" @@ -2090,6 +2134,109 @@ select x::public.baz$0; "); } + #[test] + fn goto_cast_timestamp_without_time_zone() { + assert_snapshot!(goto(" +create type pg_catalog.timestamp; +select ''::timestamp without$0 time zone; +"), @r" + ╭▸ + 2 │ create type pg_catalog.timestamp; + │ ───────── 2. destination + 3 │ select ''::timestamp without time zone; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_cast_timestamp_with_time_zone() { + assert_snapshot!(goto(" +create type pg_catalog.timestamptz; +select ''::timestamp with$0 time zone; +"), @r" + ╭▸ + 2 │ create type pg_catalog.timestamptz; + │ ─────────── 2. destination + 3 │ select ''::timestamp with time zone; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_cast_multirange_type_from_range() { + assert_snapshot!(goto(" +create type floatrange as range ( + subtype = float8, + subtype_diff = float8mi +); +select '{[1.234, 5.678]}'::floatmultirange$0; +"), @r" + ╭▸ + 2 │ create type floatrange as range ( + │ ────────── 2. destination + ‡ + 6 │ select '{[1.234, 5.678]}'::floatmultirange; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_cast_multirange_special_type_name_string() { + assert_snapshot!(goto(" +create type floatrange as range ( + subtype = float8, + subtype_diff = float8mi, + multirange_type_name = 'floatmulirangething' +); +select '{[1.234, 5.678]}'::floatmulirangething$0; +"), @r" + ╭▸ + 2 │ create type floatrange as range ( + │ ────────── 2. destination + ‡ + 7 │ select '{[1.234, 5.678]}'::floatmulirangething; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_cast_multirange_special_type_name_ident() { + assert_snapshot!(goto(" +create type floatrange as range ( + subtype = float8, + subtype_diff = float8mi, + multirange_type_name = floatrangemutirange +); +select '{[1.234, 5.678]}'::floatrangemutirange$0; +"), @r" + ╭▸ + 2 │ create type floatrange as range ( + │ ────────── 2. destination + ‡ + 7 │ select '{[1.234, 5.678]}'::floatrangemutirange; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_cast_multirange_edge_case_type_from_range() { + // make sure we're calculating the multirange correctly + assert_snapshot!(goto(" +create type floatrangerange as range ( + subtype = float8, + subtype_diff = float8mi +); +select '{[1.234, 5.678]}'::floatmultirangerange$0; +"), @r" + ╭▸ + 2 │ create type floatrangerange as range ( + │ ─────────────── 2. destination + ‡ + 6 │ select '{[1.234, 5.678]}'::floatmultirangerange; + ╰╴ ─ 1. source + "); + } + #[test] fn goto_cast_bigint_falls_back_to_int8() { assert_snapshot!(goto(" @@ -2246,6 +2393,34 @@ select 1::smallint$0; "); } + #[test] + fn goto_cast_double_precision_falls_back_to_float8() { + assert_snapshot!(goto(" +create type pg_catalog.float8; +select '1'::double precision[]$0; +"), @r" + ╭▸ + 2 │ create type pg_catalog.float8; + │ ────── 2. destination + 3 │ select '1'::double precision[]; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_cast_varchar_with_modifier() { + assert_snapshot!(goto(" +create type pg_catalog.varchar; +select '1'::varchar$0(1); +"), @r" + ╭▸ + 2 │ create type pg_catalog.varchar; + │ ─────── 2. destination + 3 │ select '1'::varchar(1); + ╰╴ ─ 1. source + "); + } + #[test] fn goto_cast_composite_type() { assert_snapshot!(goto(" diff --git a/crates/squawk_ide/src/resolve.rs b/crates/squawk_ide/src/resolve.rs index 5031772e..36193150 100644 --- a/crates/squawk_ide/src/resolve.rs +++ b/crates/squawk_ide/src/resolve.rs @@ -535,6 +535,74 @@ fn resolve_type_name_ptr( None } +pub(crate) fn resolve_type_ptr_from_type( + binder: &Binder, + ty: &ast::Type, + position: TextSize, +) -> Option { + let (type_name, schema) = type_name_and_schema_from_type(ty)?; + resolve_type_name_ptr(binder, &type_name, &schema, position) +} + +fn type_name_and_schema_from_type(ty: &ast::Type) -> Option<(Name, Option)> { + match ty { + ast::Type::ArrayType(array_type) => { + let inner = array_type.ty()?; + type_name_and_schema_from_type(&inner) + } + ast::Type::BitType(bit_type) => { + let name = if bit_type.varying_token().is_some() { + "varbit" + } else { + "bit" + }; + Some((Name::from_string(name), None)) + } + ast::Type::IntervalType(_) => Some((Name::from_string("interval"), None)), + ast::Type::PathType(path_type) => { + let path = path_type.path()?; + let type_name = extract_table_name(&path)?; + let schema = extract_schema_name(&path); + Some((type_name, schema)) + } + ast::Type::ExprType(expr_type) => { + let expr = expr_type.expr()?; + if let ast::Expr::FieldExpr(field_expr) = expr + && let Some(field) = field_expr.field() + && let Some(ast::Expr::NameRef(schema_name_ref)) = field_expr.base() + { + let type_name = Name::from_node(&field); + let schema = Some(Schema(Name::from_node(&schema_name_ref))); + Some((type_name, schema)) + } else { + None + } + } + ast::Type::CharType(char_type) => { + let name = if char_type.varchar_token().is_some() || char_type.varying_token().is_some() + { + "varchar" + } else { + "bpchar" + }; + Some((Name::from_string(name), None)) + } + ast::Type::DoubleType(_) => Some((Name::from_string("float8"), None)), + ast::Type::TimeType(time_type) => { + let mut name = if time_type.timestamp_token().is_some() { + "timestamp".to_string() + } else { + "time".to_string() + }; + if let Some(ast::Timezone::WithTimezone(_)) = time_type.timezone() { + name.push_str("tz"); + } + Some((Name::from_string(name), None)) + } + ast::Type::PercentType(_) => None, + } +} + fn fallback_type_alias(type_name: &Name) -> Option { match type_name.0.as_str() { "bigint" | "bigserial" | "serial8" => Some(Name::from_string("int8")),