Skip to content

Commit 9076edb

Browse files
authored
fix: add support for untyped snowflake arrays (#2608)
* fix: add support for untyped snowflake arrays * remove pandasnamedtuple
1 parent c932559 commit 9076edb

File tree

5 files changed

+87
-44
lines changed

5 files changed

+87
-44
lines changed

sqlmesh/core/dialect.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
if t.TYPE_CHECKING:
2828
from sqlglot._typing import E
2929

30-
from sqlmesh.utils.pandas import PandasNamedTuple
3130

3231
SQLMESH_MACRO_PREFIX = "@"
3332

@@ -837,7 +836,7 @@ def extend_sqlglot() -> None:
837836

838837

839838
def select_from_values(
840-
values: t.List[PandasNamedTuple],
839+
values: t.List[t.Tuple[t.Any, ...]],
841840
columns_to_types: t.Dict[str, exp.DataType],
842841
batch_size: int = 0,
843842
alias: str = "t",
@@ -867,7 +866,7 @@ def select_from_values(
867866

868867

869868
def select_from_values_for_batch_range(
870-
values: t.List[PandasNamedTuple],
869+
values: t.List[t.Tuple[t.Any, ...]],
871870
columns_to_types: t.Dict[str, exp.DataType],
872871
batch_start: int,
873872
batch_end: int,

sqlmesh/core/engine_adapter/base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@
5555
QueryOrDF,
5656
)
5757
from sqlmesh.core.node import IntervalUnit
58-
from sqlmesh.utils.pandas import PandasNamedTuple
5958

6059
logger = logging.getLogger(__name__)
6160

@@ -1061,7 +1060,7 @@ def insert_overwrite_by_time_partition(
10611060

10621061
def _values_to_sql(
10631062
self,
1064-
values: t.List[PandasNamedTuple],
1063+
values: t.List[t.Tuple[t.Any, ...]],
10651064
columns_to_types: t.Dict[str, exp.DataType],
10661065
batch_start: int,
10671066
batch_end: int,

sqlmesh/core/schema_diff.py

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ class TableAlterColumn(PydanticModel):
4343
quoted: bool = False
4444

4545
@classmethod
46-
def primitive(self, name: str, quoted: bool = False) -> TableAlterColumn:
47-
return self(
46+
def primitive(cls, name: str, quoted: bool = False) -> TableAlterColumn:
47+
return cls(
4848
name=name,
4949
is_struct=False,
5050
is_array_of_struct=False,
@@ -53,8 +53,8 @@ def primitive(self, name: str, quoted: bool = False) -> TableAlterColumn:
5353
)
5454

5555
@classmethod
56-
def struct(self, name: str, quoted: bool = False) -> TableAlterColumn:
57-
return self(
56+
def struct(cls, name: str, quoted: bool = False) -> TableAlterColumn:
57+
return cls(
5858
name=name,
5959
is_struct=True,
6060
is_array_of_struct=False,
@@ -63,8 +63,8 @@ def struct(self, name: str, quoted: bool = False) -> TableAlterColumn:
6363
)
6464

6565
@classmethod
66-
def array_of_struct(self, name: str, quoted: bool = False) -> TableAlterColumn:
67-
return self(
66+
def array_of_struct(cls, name: str, quoted: bool = False) -> TableAlterColumn:
67+
return cls(
6868
name=name,
6969
is_struct=False,
7070
is_array_of_struct=True,
@@ -73,8 +73,8 @@ def array_of_struct(self, name: str, quoted: bool = False) -> TableAlterColumn:
7373
)
7474

7575
@classmethod
76-
def array_of_primitive(self, name: str, quoted: bool = False) -> TableAlterColumn:
77-
return self(
76+
def array_of_primitive(cls, name: str, quoted: bool = False) -> TableAlterColumn:
77+
return cls(
7878
name=name,
7979
is_struct=False,
8080
is_array_of_struct=False,
@@ -83,20 +83,22 @@ def array_of_primitive(self, name: str, quoted: bool = False) -> TableAlterColum
8383
)
8484

8585
@classmethod
86-
def from_struct_kwarg(self, struct: exp.ColumnDef) -> TableAlterColumn:
86+
def from_struct_kwarg(cls, struct: exp.ColumnDef) -> TableAlterColumn:
8787
name = struct.alias_or_name
8888
quoted = struct.this.quoted
8989
kwarg_type = struct.args["kind"]
9090

9191
if kwarg_type.is_type(exp.DataType.Type.STRUCT):
92-
return self.struct(name, quoted=quoted)
92+
return cls.struct(name, quoted=quoted)
9393
elif kwarg_type.is_type(exp.DataType.Type.ARRAY):
94-
if kwarg_type.expressions[0].is_type(exp.DataType.Type.STRUCT):
95-
return self.array_of_struct(name, quoted=quoted)
94+
if kwarg_type.expressions and kwarg_type.expressions[0].is_type(
95+
exp.DataType.Type.STRUCT
96+
):
97+
return cls.array_of_struct(name, quoted=quoted)
9698
else:
97-
return self.array_of_primitive(name, quoted=quoted)
99+
return cls.array_of_primitive(name, quoted=quoted)
98100
else:
99-
return self.primitive(name, quoted=quoted)
101+
return cls.primitive(name, quoted=quoted)
100102

101103
@property
102104
def is_array(self) -> bool:
@@ -121,22 +123,22 @@ class TableAlterColumnPosition(PydanticModel):
121123
after: t.Optional[exp.Identifier] = None
122124

123125
@classmethod
124-
def first(self) -> TableAlterColumnPosition:
125-
return self(is_first=True, is_last=False, after=None)
126+
def first(cls) -> TableAlterColumnPosition:
127+
return cls(is_first=True, is_last=False, after=None)
126128

127129
@classmethod
128130
def last(
129-
self, after: t.Optional[t.Union[str, exp.Identifier]] = None
131+
cls, after: t.Optional[t.Union[str, exp.Identifier]] = None
130132
) -> TableAlterColumnPosition:
131-
return self(is_first=False, is_last=True, after=exp.to_identifier(after) if after else None)
133+
return cls(is_first=False, is_last=True, after=exp.to_identifier(after) if after else None)
132134

133135
@classmethod
134-
def middle(self, after: t.Union[str, exp.Identifier]) -> TableAlterColumnPosition:
135-
return self(is_first=False, is_last=False, after=exp.to_identifier(after))
136+
def middle(cls, after: t.Union[str, exp.Identifier]) -> TableAlterColumnPosition:
137+
return cls(is_first=False, is_last=False, after=exp.to_identifier(after))
136138

137139
@classmethod
138140
def create(
139-
self,
141+
cls,
140142
pos: int,
141143
current_kwargs: t.List[exp.ColumnDef],
142144
replacing_col: bool = False,
@@ -147,7 +149,7 @@ def create(
147149
if not is_first:
148150
prior_kwarg = current_kwargs[pos - 1]
149151
after, _ = _get_name_and_type(prior_kwarg)
150-
return self(is_first=is_first, is_last=is_last, after=after)
152+
return cls(is_first=is_first, is_last=is_last, after=after)
151153

152154
@property
153155
def column_position_node(self) -> t.Optional[exp.ColumnPosition]:
@@ -170,13 +172,13 @@ class TableAlterOperation(PydanticModel):
170172

171173
@classmethod
172174
def add(
173-
self,
175+
cls,
174176
columns: t.Union[TableAlterColumn, t.List[TableAlterColumn]],
175177
column_type: t.Union[str, exp.DataType],
176178
expected_table_struct: t.Union[str, exp.DataType],
177179
position: t.Optional[TableAlterColumnPosition] = None,
178180
) -> TableAlterOperation:
179-
return self(
181+
return cls(
180182
op=TableAlterOperationType.ADD,
181183
columns=ensure_list(columns),
182184
column_type=exp.DataType.build(column_type),
@@ -186,13 +188,13 @@ def add(
186188

187189
@classmethod
188190
def drop(
189-
self,
191+
cls,
190192
columns: t.Union[TableAlterColumn, t.List[TableAlterColumn]],
191193
expected_table_struct: t.Union[str, exp.DataType],
192194
column_type: t.Optional[t.Union[str, exp.DataType]] = None,
193195
) -> TableAlterOperation:
194196
column_type = exp.DataType.build(column_type) if column_type else exp.DataType.build("INT")
195-
return self(
197+
return cls(
196198
op=TableAlterOperationType.DROP,
197199
columns=ensure_list(columns),
198200
column_type=column_type,
@@ -201,14 +203,14 @@ def drop(
201203

202204
@classmethod
203205
def alter_type(
204-
self,
206+
cls,
205207
columns: t.Union[TableAlterColumn, t.List[TableAlterColumn]],
206208
column_type: t.Union[str, exp.DataType],
207209
current_type: t.Union[str, exp.DataType],
208210
expected_table_struct: t.Union[str, exp.DataType],
209211
position: t.Optional[TableAlterColumnPosition] = None,
210212
) -> TableAlterOperation:
211-
return self(
213+
return cls(
212214
op=TableAlterOperationType.ALTER_TYPE,
213215
columns=ensure_list(columns),
214216
column_type=exp.DataType.build(column_type),
@@ -456,6 +458,9 @@ def _alter_operation(
456458
root_struct,
457459
)
458460
if new_type.this == current_type.this == exp.DataType.Type.ARRAY:
461+
# Some engines (e.g. Snowflake) don't support defining types on arrays
462+
if not new_type.expressions or not current_type.expressions:
463+
return []
459464
new_array_type = new_type.expressions[0]
460465
current_array_type = current_type.expressions[0]
461466
if new_array_type.this == current_array_type.this == exp.DataType.Type.STRUCT:

sqlmesh/utils/pandas.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,11 @@
11
from __future__ import annotations
22

3-
import sys
43
import typing as t
54

65
import numpy as np
76
import pandas as pd
87
from sqlglot import exp
98

10-
if t.TYPE_CHECKING:
11-
# https://github.com/python/mypy/issues/1153
12-
if sys.version_info >= (3, 9):
13-
try:
14-
from pandas.core.frame import _PandasNamedTuple as PandasNamedTuple
15-
except ImportError:
16-
PandasNamedTuple = t.Tuple[t.Any, ...] # type: ignore
17-
else:
18-
PandasNamedTuple = t.Tuple[t.Any, ...]
19-
209

2110
PANDAS_TYPE_MAPPINGS = {
2211
np.dtype("int8"): exp.DataType.build("tinyint"),

tests/core/test_schema_diff.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,57 @@ def test_schema_diff_calculate_type_transitions():
659659
],
660660
dict(support_positional_add=True, support_nested_operations=True),
661661
),
662+
# Untyped array to untyped array is a no-op (untyped arrays are needed to support Snowflake)
663+
(
664+
"STRUCT<id INT, ids ARRAY>",
665+
"STRUCT<id INT, ids ARRAY>",
666+
[],
667+
{},
668+
),
669+
# Primitive to untyped array
670+
(
671+
"STRUCT<id INT, ids INT>",
672+
"STRUCT<id INT, ids ARRAY>",
673+
[
674+
TableAlterOperation.drop(
675+
[
676+
TableAlterColumn.primitive("ids"),
677+
],
678+
"STRUCT<id INT>",
679+
"INT",
680+
),
681+
TableAlterOperation.add(
682+
[
683+
TableAlterColumn.primitive("ids"),
684+
],
685+
"ARRAY",
686+
expected_table_struct="STRUCT<id INT, ids ARRAY>",
687+
),
688+
],
689+
{},
690+
),
691+
# Untyped array to primitive
692+
(
693+
"STRUCT<id INT, ids ARRAY>",
694+
"STRUCT<id INT, ids INT>",
695+
[
696+
TableAlterOperation.drop(
697+
[
698+
TableAlterColumn.array_of_primitive("ids"),
699+
],
700+
"STRUCT<id INT>",
701+
"ARRAY",
702+
),
703+
TableAlterOperation.add(
704+
[
705+
TableAlterColumn.array_of_primitive("ids"),
706+
],
707+
"INT",
708+
expected_table_struct="STRUCT<id INT, ids INT>",
709+
),
710+
],
711+
{},
712+
),
662713
# Precision VARCHAR is a no-op with no changes
663714
(
664715
"STRUCT<id INT, address VARCHAR(120)>",

0 commit comments

Comments
 (0)