Skip to content

Commit 3b9352b

Browse files
committed
feat: mimic dbt nuanced on_schema_change behavior
1 parent 1e2760f commit 3b9352b

23 files changed

+318
-143
lines changed

sqlmesh/core/config/connection.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -100,6 +100,7 @@ class ConnectionConfig(abc.ABC, BaseConfig):
100100
register_comments: bool
101101
pre_ping: bool
102102
pretty_sql: bool = False
103+
schema_differ_overrides: t.Optional[t.Dict[str, t.Any]] = None
103104

104105
# Whether to share a single connection across threads or create a new connection per thread.
105106
shared_connection: t.ClassVar[bool] = False
@@ -174,6 +175,7 @@ def create_engine_adapter(
174175
pre_ping=self.pre_ping,
175176
pretty_sql=self.pretty_sql,
176177
shared_connection=self.shared_connection,
178+
schema_differ_overrides=self.schema_differ_overrides,
177179
**self._extra_engine_config,
178180
)
179181

sqlmesh/core/engine_adapter/athena.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,7 @@ class AthenaEngineAdapter(PandasNativeFetchDFSupportMixin, RowDiffMixin):
3939
# CTAS, Views: No comment support at all
4040
COMMENT_CREATION_TABLE = CommentCreationTable.UNSUPPORTED
4141
COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED
42-
SCHEMA_DIFFER = TrinoEngineAdapter.SCHEMA_DIFFER
42+
SCHEMA_DIFFER_KWARGS = TrinoEngineAdapter.SCHEMA_DIFFER_KWARGS
4343
MAX_TIMESTAMP_PRECISION = 3 # copied from Trino
4444
# Athena does not deal with comments well, e.g:
4545
# >>> self._execute('/* test */ DESCRIBE foo')

sqlmesh/core/engine_adapter/base.py

Lines changed: 14 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
import logging
1515
import sys
1616
import typing as t
17-
from functools import partial
17+
from functools import cached_property, partial
1818

1919
from sqlglot import Dialect, exp
2020
from sqlglot.errors import ErrorLevel
@@ -108,7 +108,7 @@ class EngineAdapter:
108108
SUPPORTS_CLONING = False
109109
SUPPORTS_MANAGED_MODELS = False
110110
SUPPORTS_CREATE_DROP_CATALOG = False
111-
SCHEMA_DIFFER = SchemaDiffer()
111+
SCHEMA_DIFFER_KWARGS: t.Dict[str, t.Any] = {}
112112
SUPPORTS_TUPLE_IN = True
113113
HAS_VIEW_BINDING = False
114114
SUPPORTS_REPLACE_TABLE = True
@@ -131,6 +131,7 @@ def __init__(
131131
pretty_sql: bool = False,
132132
shared_connection: bool = False,
133133
correlation_id: t.Optional[CorrelationId] = None,
134+
schema_differ_overrides: t.Optional[t.Dict[str, t.Any]] = None,
134135
**kwargs: t.Any,
135136
):
136137
self.dialect = dialect.lower() or self.DIALECT
@@ -153,6 +154,7 @@ def __init__(
153154
self._pretty_sql = pretty_sql
154155
self._multithreaded = multithreaded
155156
self.correlation_id = correlation_id
157+
self._schema_differ_overrides = schema_differ_overrides
156158

157159
def with_settings(self, **kwargs: t.Any) -> EngineAdapter:
158160
extra_kwargs = {
@@ -203,6 +205,15 @@ def comments_enabled(self) -> bool:
203205
def catalog_support(self) -> CatalogSupport:
204206
return CatalogSupport.UNSUPPORTED
205207

208+
@cached_property
209+
def schema_differ(self) -> SchemaDiffer:
210+
return SchemaDiffer(
211+
**{
212+
**self.SCHEMA_DIFFER_KWARGS,
213+
**(self._schema_differ_overrides or {}),
214+
}
215+
)
216+
206217
@classmethod
207218
def _casted_columns(
208219
cls,
@@ -1094,7 +1105,7 @@ def get_alter_operations(
10941105
"""
10951106
return t.cast(
10961107
t.List[TableAlterOperation],
1097-
self.SCHEMA_DIFFER.compare_columns(
1108+
self.schema_differ.compare_columns(
10981109
current_table_name,
10991110
self.columns(current_table_name),
11001111
self.columns(target_table_name),

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 9 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@
2222
set_catalog,
2323
)
2424
from sqlmesh.core.node import IntervalUnit
25-
from sqlmesh.core.schema_diff import SchemaDiffer, TableAlterOperation
25+
from sqlmesh.core.schema_diff import TableAlterOperation, NestedSupport
2626
from sqlmesh.utils import optional_import, get_source_columns_to_types
2727
from sqlmesh.utils.date import to_datetime
2828
from sqlmesh.utils.errors import SQLMeshError
@@ -67,8 +67,8 @@ class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin, ClusteredByMixin, Row
6767
MAX_TABLE_COMMENT_LENGTH = 1024
6868
MAX_COLUMN_COMMENT_LENGTH = 1024
6969

70-
SCHEMA_DIFFER = SchemaDiffer(
71-
compatible_types={
70+
SCHEMA_DIFFER_KWARGS = {
71+
"compatible_types": {
7272
exp.DataType.build("INT64", dialect=DIALECT): {
7373
exp.DataType.build("NUMERIC", dialect=DIALECT),
7474
exp.DataType.build("FLOAT64", dialect=DIALECT),
@@ -82,17 +82,17 @@ class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin, ClusteredByMixin, Row
8282
exp.DataType.build("DATETIME", dialect=DIALECT),
8383
},
8484
},
85-
coerceable_types={
85+
"coerceable_types": {
8686
exp.DataType.build("FLOAT64", dialect=DIALECT): {
8787
exp.DataType.build("BIGNUMERIC", dialect=DIALECT),
8888
},
8989
},
90-
support_coercing_compatible_types=True,
91-
parameterized_type_defaults={
90+
"support_coercing_compatible_types": True,
91+
"parameterized_type_defaults": {
9292
exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(38, 9), (0,)],
9393
exp.DataType.build("BIGDECIMAL", dialect=DIALECT).this: [(76.76, 38), (0,)],
9494
},
95-
types_with_unlimited_length={
95+
"types_with_unlimited_length": {
9696
# parameterized `STRING(n)` can ALTER to unparameterized `STRING`
9797
exp.DataType.build("STRING", dialect=DIALECT).this: {
9898
exp.DataType.build("STRING", dialect=DIALECT).this,
@@ -102,9 +102,8 @@ class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin, ClusteredByMixin, Row
102102
exp.DataType.build("BYTES", dialect=DIALECT).this,
103103
},
104104
},
105-
support_nested_operations=True,
106-
support_nested_drop=False,
107-
)
105+
"nested_support": NestedSupport.ALL_BUT_DROP,
106+
}
108107

109108
@property
110109
def client(self) -> BigQueryClient:

sqlmesh/core/engine_adapter/clickhouse.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515
CommentCreationView,
1616
InsertOverwriteStrategy,
1717
)
18-
from sqlmesh.core.schema_diff import SchemaDiffer, TableAlterOperation
18+
from sqlmesh.core.schema_diff import TableAlterOperation
1919
from sqlmesh.utils import get_source_columns_to_types
2020

2121
if t.TYPE_CHECKING:
@@ -37,7 +37,7 @@ class ClickhouseEngineAdapter(EngineAdapterWithIndexSupport, LogicalMergeMixin):
3737
SUPPORTS_REPLACE_TABLE = False
3838
COMMENT_CREATION_VIEW = CommentCreationView.COMMENT_COMMAND_ONLY
3939

40-
SCHEMA_DIFFER = SchemaDiffer()
40+
SCHEMA_DIFFER_KWARGS = {}
4141

4242
DEFAULT_TABLE_ENGINE = "MergeTree"
4343
ORDER_BY_TABLE_ENGINE_REGEX = "^.*?MergeTree.*$"

sqlmesh/core/engine_adapter/databricks.py

Lines changed: 7 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515
)
1616
from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter
1717
from sqlmesh.core.node import IntervalUnit
18-
from sqlmesh.core.schema_diff import SchemaDiffer
18+
from sqlmesh.core.schema_diff import NestedSupport
1919
from sqlmesh.engines.spark.db_api.spark_session import connection, SparkSessionConnection
2020
from sqlmesh.utils.errors import SQLMeshError, MissingDefaultCatalogError
2121

@@ -34,15 +34,14 @@ class DatabricksEngineAdapter(SparkEngineAdapter):
3434
SUPPORTS_CLONING = True
3535
SUPPORTS_MATERIALIZED_VIEWS = True
3636
SUPPORTS_MATERIALIZED_VIEW_SCHEMA = True
37-
SCHEMA_DIFFER = SchemaDiffer(
38-
support_positional_add=True,
39-
support_nested_operations=True,
40-
support_nested_drop=True,
41-
array_element_selector="element",
42-
parameterized_type_defaults={
37+
SCHEMA_DIFFER_KWARGS = {
38+
"support_positional_add": True,
39+
"nested_support": NestedSupport.ALL,
40+
"array_element_selector": "element",
41+
"parameterized_type_defaults": {
4342
exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(10, 0), (0,)],
4443
},
45-
)
44+
}
4645

4746
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
4847
super().__init__(*args, **kwargs)

sqlmesh/core/engine_adapter/duckdb.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,6 @@
1818
SourceQuery,
1919
set_catalog,
2020
)
21-
from sqlmesh.core.schema_diff import SchemaDiffer
2221

2322
if t.TYPE_CHECKING:
2423
from sqlmesh.core._typing import SchemaName, TableName
@@ -29,11 +28,11 @@
2928
class DuckDBEngineAdapter(LogicalMergeMixin, GetCurrentCatalogFromFunctionMixin, RowDiffMixin):
3029
DIALECT = "duckdb"
3130
SUPPORTS_TRANSACTIONS = False
32-
SCHEMA_DIFFER = SchemaDiffer(
33-
parameterized_type_defaults={
31+
SCHEMA_DIFFER_KWARGS = {
32+
"parameterized_type_defaults": {
3433
exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(18, 3), (0,)],
3534
},
36-
)
35+
}
3736
COMMENT_CREATION_TABLE = CommentCreationTable.COMMENT_COMMAND_ONLY
3837
COMMENT_CREATION_VIEW = CommentCreationView.COMMENT_COMMAND_ONLY
3938
SUPPORTS_CREATE_DROP_CATALOG = True

sqlmesh/core/engine_adapter/mixins.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -259,9 +259,9 @@ def _default_precision_to_max(
259259
) -> t.Dict[str, exp.DataType]:
260260
# get default lengths for types that support "max" length
261261
types_with_max_default_param = {
262-
k: [self.SCHEMA_DIFFER.parameterized_type_defaults[k][0][0]]
263-
for k in self.SCHEMA_DIFFER.max_parameter_length
264-
if k in self.SCHEMA_DIFFER.parameterized_type_defaults
262+
k: [self.schema_differ.parameterized_type_defaults[k][0][0]]
263+
for k in self.schema_differ.max_parameter_length
264+
if k in self.schema_differ.parameterized_type_defaults
265265
}
266266

267267
# Redshift and MSSQL have a bug where CTAS statements have non-deterministic types. If a LIMIT
@@ -270,7 +270,7 @@ def _default_precision_to_max(
270270
# and supports "max" length, we convert it to "max" length to prevent inadvertent data truncation.
271271
for col_name, col_type in columns_to_types.items():
272272
if col_type.this in types_with_max_default_param and col_type.expressions:
273-
parameter = self.SCHEMA_DIFFER.get_type_parameters(col_type)
273+
parameter = self.schema_differ.get_type_parameters(col_type)
274274
type_default = types_with_max_default_param[col_type.this]
275275
if parameter == type_default:
276276
col_type.set("expressions", [exp.DataTypeParam(this=exp.var("max"))])

sqlmesh/core/engine_adapter/mssql.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,6 @@
3030
SourceQuery,
3131
set_catalog,
3232
)
33-
from sqlmesh.core.schema_diff import SchemaDiffer
3433
from sqlmesh.utils import get_source_columns_to_types
3534

3635
if t.TYPE_CHECKING:
@@ -54,8 +53,8 @@ class MSSQLEngineAdapter(
5453
COMMENT_CREATION_TABLE = CommentCreationTable.UNSUPPORTED
5554
COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED
5655
SUPPORTS_REPLACE_TABLE = False
57-
SCHEMA_DIFFER = SchemaDiffer(
58-
parameterized_type_defaults={
56+
SCHEMA_DIFFER_KWARGS = {
57+
"parameterized_type_defaults": {
5958
exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(18, 0), (0,)],
6059
exp.DataType.build("BINARY", dialect=DIALECT).this: [(1,)],
6160
exp.DataType.build("VARBINARY", dialect=DIALECT).this: [(1,)],
@@ -67,12 +66,12 @@ class MSSQLEngineAdapter(
6766
exp.DataType.build("DATETIME2", dialect=DIALECT).this: [(7,)],
6867
exp.DataType.build("DATETIMEOFFSET", dialect=DIALECT).this: [(7,)],
6968
},
70-
max_parameter_length={
69+
"max_parameter_length": {
7170
exp.DataType.build("VARBINARY", dialect=DIALECT).this: 2147483647, # 2 GB
7271
exp.DataType.build("VARCHAR", dialect=DIALECT).this: 2147483647,
7372
exp.DataType.build("NVARCHAR", dialect=DIALECT).this: 2147483647,
7473
},
75-
)
74+
}
7675
VARIABLE_LENGTH_DATA_TYPES = {"binary", "varbinary", "char", "varchar", "nchar", "nvarchar"}
7776

7877
@property

sqlmesh/core/engine_adapter/mysql.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,6 @@
1919
DataObjectType,
2020
set_catalog,
2121
)
22-
from sqlmesh.core.schema_diff import SchemaDiffer
2322

2423
if t.TYPE_CHECKING:
2524
from sqlmesh.core._typing import SchemaName, TableName
@@ -40,8 +39,8 @@ class MySQLEngineAdapter(
4039
MAX_COLUMN_COMMENT_LENGTH = 1024
4140
SUPPORTS_REPLACE_TABLE = False
4241
MAX_IDENTIFIER_LENGTH = 64
43-
SCHEMA_DIFFER = SchemaDiffer(
44-
parameterized_type_defaults={
42+
SCHEMA_DIFFER_KWARGS = {
43+
"parameterized_type_defaults": {
4544
exp.DataType.build("BIT", dialect=DIALECT).this: [(1,)],
4645
exp.DataType.build("BINARY", dialect=DIALECT).this: [(1,)],
4746
exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(10, 0), (0,)],
@@ -52,7 +51,7 @@ class MySQLEngineAdapter(
5251
exp.DataType.build("DATETIME", dialect=DIALECT).this: [(0,)],
5352
exp.DataType.build("TIMESTAMP", dialect=DIALECT).this: [(0,)],
5453
},
55-
)
54+
}
5655

5756
def get_current_catalog(self) -> t.Optional[str]:
5857
"""Returns the catalog name of the current connection."""

0 commit comments

Comments
 (0)