diff --git a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py index b530af36da..b68b83a39a 100644 --- a/sqlmesh/core/config/connection.py +++ b/sqlmesh/core/config/connection.py @@ -100,6 +100,7 @@ class ConnectionConfig(abc.ABC, BaseConfig): register_comments: bool pre_ping: bool pretty_sql: bool = False + schema_differ_overrides: t.Optional[t.Dict[str, t.Any]] = None # Whether to share a single connection across threads or create a new connection per thread. shared_connection: t.ClassVar[bool] = False @@ -174,6 +175,7 @@ def create_engine_adapter( pre_ping=self.pre_ping, pretty_sql=self.pretty_sql, shared_connection=self.shared_connection, + schema_differ_overrides=self.schema_differ_overrides, **self._extra_engine_config, ) diff --git a/sqlmesh/core/engine_adapter/athena.py b/sqlmesh/core/engine_adapter/athena.py index 3ed34067d2..48b9e4ad4e 100644 --- a/sqlmesh/core/engine_adapter/athena.py +++ b/sqlmesh/core/engine_adapter/athena.py @@ -39,7 +39,7 @@ class AthenaEngineAdapter(PandasNativeFetchDFSupportMixin, RowDiffMixin): # CTAS, Views: No comment support at all COMMENT_CREATION_TABLE = CommentCreationTable.UNSUPPORTED COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED - SCHEMA_DIFFER = TrinoEngineAdapter.SCHEMA_DIFFER + SCHEMA_DIFFER_KWARGS = TrinoEngineAdapter.SCHEMA_DIFFER_KWARGS MAX_TIMESTAMP_PRECISION = 3 # copied from Trino # Athena does not deal with comments well, e.g: # >>> self._execute('/* test */ DESCRIBE foo') diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index 920a5aff3d..fe19f7df0f 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -14,7 +14,7 @@ import logging import sys import typing as t -from functools import partial +from functools import cached_property, partial from sqlglot import Dialect, exp from sqlglot.errors import ErrorLevel @@ -109,7 +109,7 @@ class EngineAdapter: SUPPORTS_MANAGED_MODELS = False SUPPORTS_CREATE_DROP_CATALOG = False SUPPORTED_DROP_CASCADE_OBJECT_KINDS: t.List[str] = [] - SCHEMA_DIFFER = SchemaDiffer() + SCHEMA_DIFFER_KWARGS: t.Dict[str, t.Any] = {} SUPPORTS_TUPLE_IN = True HAS_VIEW_BINDING = False SUPPORTS_REPLACE_TABLE = True @@ -132,6 +132,7 @@ def __init__( pretty_sql: bool = False, shared_connection: bool = False, correlation_id: t.Optional[CorrelationId] = None, + schema_differ_overrides: t.Optional[t.Dict[str, t.Any]] = None, **kwargs: t.Any, ): self.dialect = dialect.lower() or self.DIALECT @@ -154,6 +155,7 @@ def __init__( self._pretty_sql = pretty_sql self._multithreaded = multithreaded self.correlation_id = correlation_id + self._schema_differ_overrides = schema_differ_overrides def with_settings(self, **kwargs: t.Any) -> EngineAdapter: extra_kwargs = { @@ -204,6 +206,15 @@ def comments_enabled(self) -> bool: def catalog_support(self) -> CatalogSupport: return CatalogSupport.UNSUPPORTED + @cached_property + def schema_differ(self) -> SchemaDiffer: + return SchemaDiffer( + **{ + **self.SCHEMA_DIFFER_KWARGS, + **(self._schema_differ_overrides or {}), + } + ) + @classmethod def _casted_columns( cls, @@ -1101,7 +1112,7 @@ def get_alter_operations( """ return t.cast( t.List[TableAlterOperation], - self.SCHEMA_DIFFER.compare_columns( + self.schema_differ.compare_columns( current_table_name, self.columns(current_table_name), self.columns(target_table_name), diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py index f90506c5a1..4c8a125fa3 100644 --- a/sqlmesh/core/engine_adapter/bigquery.py +++ b/sqlmesh/core/engine_adapter/bigquery.py @@ -22,7 +22,7 @@ set_catalog, ) from sqlmesh.core.node import IntervalUnit -from sqlmesh.core.schema_diff import SchemaDiffer, TableAlterOperation +from sqlmesh.core.schema_diff import TableAlterOperation, NestedSupport from sqlmesh.utils import optional_import, get_source_columns_to_types from sqlmesh.utils.date import to_datetime from sqlmesh.utils.errors import SQLMeshError @@ -68,8 +68,8 @@ class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin, ClusteredByMixin, Row MAX_COLUMN_COMMENT_LENGTH = 1024 SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA"] - SCHEMA_DIFFER = SchemaDiffer( - compatible_types={ + SCHEMA_DIFFER_KWARGS = { + "compatible_types": { exp.DataType.build("INT64", dialect=DIALECT): { exp.DataType.build("NUMERIC", dialect=DIALECT), exp.DataType.build("FLOAT64", dialect=DIALECT), @@ -83,17 +83,17 @@ class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin, ClusteredByMixin, Row exp.DataType.build("DATETIME", dialect=DIALECT), }, }, - coerceable_types={ + "coerceable_types": { exp.DataType.build("FLOAT64", dialect=DIALECT): { exp.DataType.build("BIGNUMERIC", dialect=DIALECT), }, }, - support_coercing_compatible_types=True, - parameterized_type_defaults={ + "support_coercing_compatible_types": True, + "parameterized_type_defaults": { exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(38, 9), (0,)], exp.DataType.build("BIGDECIMAL", dialect=DIALECT).this: [(76.76, 38), (0,)], }, - types_with_unlimited_length={ + "types_with_unlimited_length": { # parameterized `STRING(n)` can ALTER to unparameterized `STRING` exp.DataType.build("STRING", dialect=DIALECT).this: { exp.DataType.build("STRING", dialect=DIALECT).this, @@ -103,9 +103,8 @@ class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin, ClusteredByMixin, Row exp.DataType.build("BYTES", dialect=DIALECT).this, }, }, - support_nested_operations=True, - support_nested_drop=False, - ) + "nested_support": NestedSupport.ALL_BUT_DROP, + } @property def client(self) -> BigQueryClient: diff --git a/sqlmesh/core/engine_adapter/clickhouse.py b/sqlmesh/core/engine_adapter/clickhouse.py index 37b1f20721..635e6f369b 100644 --- a/sqlmesh/core/engine_adapter/clickhouse.py +++ b/sqlmesh/core/engine_adapter/clickhouse.py @@ -15,7 +15,7 @@ CommentCreationView, InsertOverwriteStrategy, ) -from sqlmesh.core.schema_diff import SchemaDiffer, TableAlterOperation +from sqlmesh.core.schema_diff import TableAlterOperation from sqlmesh.utils import get_source_columns_to_types if t.TYPE_CHECKING: @@ -37,7 +37,7 @@ class ClickhouseEngineAdapter(EngineAdapterWithIndexSupport, LogicalMergeMixin): SUPPORTS_REPLACE_TABLE = False COMMENT_CREATION_VIEW = CommentCreationView.COMMENT_COMMAND_ONLY - SCHEMA_DIFFER = SchemaDiffer() + SCHEMA_DIFFER_KWARGS = {} DEFAULT_TABLE_ENGINE = "MergeTree" ORDER_BY_TABLE_ENGINE_REGEX = "^.*?MergeTree.*$" diff --git a/sqlmesh/core/engine_adapter/databricks.py b/sqlmesh/core/engine_adapter/databricks.py index 4e352b27ef..da70163db4 100644 --- a/sqlmesh/core/engine_adapter/databricks.py +++ b/sqlmesh/core/engine_adapter/databricks.py @@ -15,7 +15,7 @@ ) from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter from sqlmesh.core.node import IntervalUnit -from sqlmesh.core.schema_diff import SchemaDiffer +from sqlmesh.core.schema_diff import NestedSupport from sqlmesh.engines.spark.db_api.spark_session import connection, SparkSessionConnection from sqlmesh.utils.errors import SQLMeshError, MissingDefaultCatalogError @@ -34,15 +34,14 @@ class DatabricksEngineAdapter(SparkEngineAdapter): SUPPORTS_CLONING = True SUPPORTS_MATERIALIZED_VIEWS = True SUPPORTS_MATERIALIZED_VIEW_SCHEMA = True - SCHEMA_DIFFER = SchemaDiffer( - support_positional_add=True, - support_nested_operations=True, - support_nested_drop=True, - array_element_selector="element", - parameterized_type_defaults={ + SCHEMA_DIFFER_KWARGS = { + "support_positional_add": True, + "nested_support": NestedSupport.ALL, + "array_element_selector": "element", + "parameterized_type_defaults": { exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(10, 0), (0,)], }, - ) + } def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: super().__init__(*args, **kwargs) diff --git a/sqlmesh/core/engine_adapter/duckdb.py b/sqlmesh/core/engine_adapter/duckdb.py index a3bebadbe9..8fbe40a575 100644 --- a/sqlmesh/core/engine_adapter/duckdb.py +++ b/sqlmesh/core/engine_adapter/duckdb.py @@ -18,7 +18,6 @@ SourceQuery, set_catalog, ) -from sqlmesh.core.schema_diff import SchemaDiffer if t.TYPE_CHECKING: from sqlmesh.core._typing import SchemaName, TableName @@ -29,11 +28,11 @@ class DuckDBEngineAdapter(LogicalMergeMixin, GetCurrentCatalogFromFunctionMixin, RowDiffMixin): DIALECT = "duckdb" SUPPORTS_TRANSACTIONS = False - SCHEMA_DIFFER = SchemaDiffer( - parameterized_type_defaults={ + SCHEMA_DIFFER_KWARGS = { + "parameterized_type_defaults": { exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(18, 3), (0,)], }, - ) + } COMMENT_CREATION_TABLE = CommentCreationTable.COMMENT_COMMAND_ONLY COMMENT_CREATION_VIEW = CommentCreationView.COMMENT_COMMAND_ONLY SUPPORTS_CREATE_DROP_CATALOG = True diff --git a/sqlmesh/core/engine_adapter/mixins.py b/sqlmesh/core/engine_adapter/mixins.py index e2d7915cf3..bc83beb3d4 100644 --- a/sqlmesh/core/engine_adapter/mixins.py +++ b/sqlmesh/core/engine_adapter/mixins.py @@ -259,9 +259,9 @@ def _default_precision_to_max( ) -> t.Dict[str, exp.DataType]: # get default lengths for types that support "max" length types_with_max_default_param = { - k: [self.SCHEMA_DIFFER.parameterized_type_defaults[k][0][0]] - for k in self.SCHEMA_DIFFER.max_parameter_length - if k in self.SCHEMA_DIFFER.parameterized_type_defaults + k: [self.schema_differ.parameterized_type_defaults[k][0][0]] + for k in self.schema_differ.max_parameter_length + if k in self.schema_differ.parameterized_type_defaults } # Redshift and MSSQL have a bug where CTAS statements have non-deterministic types. If a LIMIT @@ -270,7 +270,7 @@ def _default_precision_to_max( # and supports "max" length, we convert it to "max" length to prevent inadvertent data truncation. for col_name, col_type in columns_to_types.items(): if col_type.this in types_with_max_default_param and col_type.expressions: - parameter = self.SCHEMA_DIFFER.get_type_parameters(col_type) + parameter = self.schema_differ.get_type_parameters(col_type) type_default = types_with_max_default_param[col_type.this] if parameter == type_default: col_type.set("expressions", [exp.DataTypeParam(this=exp.var("max"))]) diff --git a/sqlmesh/core/engine_adapter/mssql.py b/sqlmesh/core/engine_adapter/mssql.py index 3a43d539a9..6aefd51fc0 100644 --- a/sqlmesh/core/engine_adapter/mssql.py +++ b/sqlmesh/core/engine_adapter/mssql.py @@ -30,7 +30,6 @@ SourceQuery, set_catalog, ) -from sqlmesh.core.schema_diff import SchemaDiffer from sqlmesh.utils import get_source_columns_to_types if t.TYPE_CHECKING: @@ -54,8 +53,8 @@ class MSSQLEngineAdapter( COMMENT_CREATION_TABLE = CommentCreationTable.UNSUPPORTED COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED SUPPORTS_REPLACE_TABLE = False - SCHEMA_DIFFER = SchemaDiffer( - parameterized_type_defaults={ + SCHEMA_DIFFER_KWARGS = { + "parameterized_type_defaults": { exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(18, 0), (0,)], exp.DataType.build("BINARY", dialect=DIALECT).this: [(1,)], exp.DataType.build("VARBINARY", dialect=DIALECT).this: [(1,)], @@ -67,12 +66,12 @@ class MSSQLEngineAdapter( exp.DataType.build("DATETIME2", dialect=DIALECT).this: [(7,)], exp.DataType.build("DATETIMEOFFSET", dialect=DIALECT).this: [(7,)], }, - max_parameter_length={ + "max_parameter_length": { exp.DataType.build("VARBINARY", dialect=DIALECT).this: 2147483647, # 2 GB exp.DataType.build("VARCHAR", dialect=DIALECT).this: 2147483647, exp.DataType.build("NVARCHAR", dialect=DIALECT).this: 2147483647, }, - ) + } VARIABLE_LENGTH_DATA_TYPES = {"binary", "varbinary", "char", "varchar", "nchar", "nvarchar"} @property diff --git a/sqlmesh/core/engine_adapter/mysql.py b/sqlmesh/core/engine_adapter/mysql.py index 298dc18903..e81b30e25e 100644 --- a/sqlmesh/core/engine_adapter/mysql.py +++ b/sqlmesh/core/engine_adapter/mysql.py @@ -19,7 +19,6 @@ DataObjectType, set_catalog, ) -from sqlmesh.core.schema_diff import SchemaDiffer if t.TYPE_CHECKING: from sqlmesh.core._typing import SchemaName, TableName @@ -40,8 +39,8 @@ class MySQLEngineAdapter( MAX_COLUMN_COMMENT_LENGTH = 1024 SUPPORTS_REPLACE_TABLE = False MAX_IDENTIFIER_LENGTH = 64 - SCHEMA_DIFFER = SchemaDiffer( - parameterized_type_defaults={ + SCHEMA_DIFFER_KWARGS = { + "parameterized_type_defaults": { exp.DataType.build("BIT", dialect=DIALECT).this: [(1,)], exp.DataType.build("BINARY", dialect=DIALECT).this: [(1,)], exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(10, 0), (0,)], @@ -52,7 +51,7 @@ class MySQLEngineAdapter( exp.DataType.build("DATETIME", dialect=DIALECT).this: [(0,)], exp.DataType.build("TIMESTAMP", dialect=DIALECT).this: [(0,)], }, - ) + } def get_current_catalog(self) -> t.Optional[str]: """Returns the catalog name of the current connection.""" diff --git a/sqlmesh/core/engine_adapter/postgres.py b/sqlmesh/core/engine_adapter/postgres.py index a1ff46e9ad..faeb52b207 100644 --- a/sqlmesh/core/engine_adapter/postgres.py +++ b/sqlmesh/core/engine_adapter/postgres.py @@ -14,7 +14,6 @@ logical_merge, ) from sqlmesh.core.engine_adapter.shared import set_catalog -from sqlmesh.core.schema_diff import SchemaDiffer if t.TYPE_CHECKING: from sqlmesh.core._typing import TableName @@ -36,15 +35,15 @@ class PostgresEngineAdapter( CURRENT_CATALOG_EXPRESSION = exp.column("current_catalog") SUPPORTS_REPLACE_TABLE = False MAX_IDENTIFIER_LENGTH = 63 - SCHEMA_DIFFER = SchemaDiffer( - parameterized_type_defaults={ + SCHEMA_DIFFER_KWARGS = { + "parameterized_type_defaults": { # DECIMAL without precision is "up to 131072 digits before the decimal point; up to 16383 digits after the decimal point" exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(131072 + 16383, 16383), (0,)], exp.DataType.build("CHAR", dialect=DIALECT).this: [(1,)], exp.DataType.build("TIME", dialect=DIALECT).this: [(6,)], exp.DataType.build("TIMESTAMP", dialect=DIALECT).this: [(6,)], }, - types_with_unlimited_length={ + "types_with_unlimited_length": { # all can ALTER to `TEXT` exp.DataType.build("TEXT", dialect=DIALECT).this: { exp.DataType.build("VARCHAR", dialect=DIALECT).this, @@ -63,8 +62,8 @@ class PostgresEngineAdapter( exp.DataType.build("BPCHAR", dialect=DIALECT).this }, }, - drop_cascade=True, - ) + "drop_cascade": True, + } def _fetch_native_df( self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False diff --git a/sqlmesh/core/engine_adapter/redshift.py b/sqlmesh/core/engine_adapter/redshift.py index 2589ef960e..30ebc8e30d 100644 --- a/sqlmesh/core/engine_adapter/redshift.py +++ b/sqlmesh/core/engine_adapter/redshift.py @@ -22,7 +22,6 @@ SourceQuery, set_catalog, ) -from sqlmesh.core.schema_diff import SchemaDiffer from sqlmesh.utils.errors import SQLMeshError if t.TYPE_CHECKING: @@ -48,8 +47,8 @@ class RedshiftEngineAdapter( COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED SUPPORTS_REPLACE_TABLE = False - SCHEMA_DIFFER = SchemaDiffer( - parameterized_type_defaults={ + SCHEMA_DIFFER_KWARGS = { + "parameterized_type_defaults": { exp.DataType.build("VARBYTE", dialect=DIALECT).this: [(64000,)], exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(18, 0), (0,)], exp.DataType.build("CHAR", dialect=DIALECT).this: [(1,)], @@ -57,13 +56,13 @@ class RedshiftEngineAdapter( exp.DataType.build("NCHAR", dialect=DIALECT).this: [(1,)], exp.DataType.build("NVARCHAR", dialect=DIALECT).this: [(256,)], }, - max_parameter_length={ + "max_parameter_length": { exp.DataType.build("CHAR", dialect=DIALECT).this: 4096, exp.DataType.build("VARCHAR", dialect=DIALECT).this: 65535, }, - precision_increase_allowed_types={exp.DataType.build("VARCHAR", dialect=DIALECT).this}, - drop_cascade=True, - ) + "precision_increase_allowed_types": {exp.DataType.build("VARCHAR", dialect=DIALECT).this}, + "drop_cascade": True, + } VARIABLE_LENGTH_DATA_TYPES = { "char", "character", diff --git a/sqlmesh/core/engine_adapter/snowflake.py b/sqlmesh/core/engine_adapter/snowflake.py index 69ff33b5a8..c5fa8540b0 100644 --- a/sqlmesh/core/engine_adapter/snowflake.py +++ b/sqlmesh/core/engine_adapter/snowflake.py @@ -23,7 +23,6 @@ SourceQuery, set_catalog, ) -from sqlmesh.core.schema_diff import SchemaDiffer from sqlmesh.utils import optional_import, get_source_columns_to_types from sqlmesh.utils.errors import SQLMeshError from sqlmesh.utils.pandas import columns_to_types_from_dtypes @@ -56,8 +55,8 @@ class SnowflakeEngineAdapter(GetCurrentCatalogFromFunctionMixin, ClusteredByMixi CURRENT_CATALOG_EXPRESSION = exp.func("current_database") SUPPORTS_CREATE_DROP_CATALOG = True SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["DATABASE", "SCHEMA", "TABLE"] - SCHEMA_DIFFER = SchemaDiffer( - parameterized_type_defaults={ + SCHEMA_DIFFER_KWARGS = { + "parameterized_type_defaults": { exp.DataType.build("BINARY", dialect=DIALECT).this: [(8388608,)], exp.DataType.build("VARBINARY", dialect=DIALECT).this: [(8388608,)], exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(38, 0), (0,)], @@ -70,7 +69,7 @@ class SnowflakeEngineAdapter(GetCurrentCatalogFromFunctionMixin, ClusteredByMixi exp.DataType.build("TIMESTAMP_NTZ", dialect=DIALECT).this: [(9,)], exp.DataType.build("TIMESTAMP_TZ", dialect=DIALECT).this: [(9,)], }, - ) + } MANAGED_TABLE_KIND = "DYNAMIC TABLE" SNOWPARK = "snowpark" diff --git a/sqlmesh/core/engine_adapter/spark.py b/sqlmesh/core/engine_adapter/spark.py index 4f6e9a984f..8a529390c1 100644 --- a/sqlmesh/core/engine_adapter/spark.py +++ b/sqlmesh/core/engine_adapter/spark.py @@ -22,7 +22,6 @@ SourceQuery, set_catalog, ) -from sqlmesh.core.schema_diff import SchemaDiffer from sqlmesh.utils import classproperty, get_source_columns_to_types from sqlmesh.utils.errors import SQLMeshError @@ -61,12 +60,12 @@ class SparkEngineAdapter( WAP_PREFIX = "wap_" BRANCH_PREFIX = "branch_" - SCHEMA_DIFFER = SchemaDiffer( - parameterized_type_defaults={ + SCHEMA_DIFFER_KWARGS = { + "parameterized_type_defaults": { # default decimal precision varies across backends exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(), (0,)], }, - ) + } @property def connection(self) -> SparkSessionConnection: diff --git a/sqlmesh/core/engine_adapter/trino.py b/sqlmesh/core/engine_adapter/trino.py index c62f7bef45..fc08dd10af 100644 --- a/sqlmesh/core/engine_adapter/trino.py +++ b/sqlmesh/core/engine_adapter/trino.py @@ -26,7 +26,6 @@ SourceQuery, set_catalog, ) -from sqlmesh.core.schema_diff import SchemaDiffer from sqlmesh.utils import get_source_columns_to_types from sqlmesh.utils.errors import SQLMeshError from sqlmesh.utils.date import TimeLike @@ -56,14 +55,14 @@ class TrinoEngineAdapter( SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA"] DEFAULT_CATALOG_TYPE = "hive" QUOTE_IDENTIFIERS_IN_VIEWS = False - SCHEMA_DIFFER = SchemaDiffer( - parameterized_type_defaults={ + SCHEMA_DIFFER_KWARGS = { + "parameterized_type_defaults": { # default decimal precision varies across backends exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(), (0,)], exp.DataType.build("CHAR", dialect=DIALECT).this: [(1,)], exp.DataType.build("TIMESTAMP", dialect=DIALECT).this: [(3,)], }, - ) + } # some catalogs support microsecond (precision 6) but it has to be specifically enabled (Hive) or just isnt available (Delta / TIMESTAMP WITH TIME ZONE) # and even if you have a TIMESTAMP(6) the date formatting functions still only support millisecond precision MAX_TIMESTAMP_PRECISION = 3 diff --git a/sqlmesh/core/schema_diff.py b/sqlmesh/core/schema_diff.py index 0bbc146c17..7b8c7f16f7 100644 --- a/sqlmesh/core/schema_diff.py +++ b/sqlmesh/core/schema_diff.py @@ -5,6 +5,8 @@ import typing as t from dataclasses import dataclass from collections import defaultdict +from enum import Enum + from pydantic import Field from sqlglot import exp from sqlglot.helper import ensure_list, seq_get @@ -132,14 +134,15 @@ def _alter_actions(self) -> t.List[exp.Expression]: @dataclass(frozen=True) class TableAlterChangeColumnTypeOperation(TableAlterTypedColumnOperation): current_type: exp.DataType + is_part_of_destructive_change: bool = False @property def is_additive(self) -> bool: - return True + return not self.is_part_of_destructive_change @property def is_destructive(self) -> bool: - return False + return self.is_part_of_destructive_change @property def _alter_actions(self) -> t.List[exp.Expression]: @@ -278,6 +281,33 @@ def column_position_node(self) -> t.Optional[exp.ColumnPosition]: return exp.ColumnPosition(this=column, position=position) +class NestedSupport(str, Enum): + # Supports all nested data type operations + ALL = "ALL" + # Does not support any nested data type operations + NONE = "NONE" + # Supports nested data type operations except for those that require dropping a nested field + ALL_BUT_DROP = "ALL_BUT_DROP" + # Ignores all nested data type operations + IGNORE = "IGNORE" + + @property + def is_all(self) -> bool: + return self == NestedSupport.ALL + + @property + def is_none(self) -> bool: + return self == NestedSupport.NONE + + @property + def is_all_but_drop(self) -> bool: + return self == NestedSupport.ALL_BUT_DROP + + @property + def is_ignore(self) -> bool: + return self == NestedSupport.IGNORE + + class SchemaDiffer(PydanticModel): """ Compares a source schema against a target schema and returns a list of alter statements to have the source @@ -297,10 +327,7 @@ class SchemaDiffer(PydanticModel): Args: support_positional_add: Whether the engine for which the diff is being computed supports adding columns in a specific position in the set of existing columns. - support_nested_operations: Whether the engine for which the diff is being computed supports modifications to - nested data types like STRUCTs and ARRAYs. - support_nested_drop: Whether the engine for which the diff is being computed supports removing individual - columns of nested STRUCTs. + nested_support: How the engine for which the diff is being computed supports nested types. compatible_types: Types that are compatible and automatically coerced in actions like UNION ALL. Dict key is data type, and value is the set of types that are compatible with it. coerceable_types: The mapping from a current type to all types that can be safely coerced to the current one without @@ -323,11 +350,14 @@ class SchemaDiffer(PydanticModel): max_parameter_length: Numeric parameter values corresponding to "max". Example: `VARCHAR(max)` -> `VARCHAR(65535)`. types_with_unlimited_length: Data types that accept values of any length up to system limits. Any explicitly parameterized type can ALTER to its unlimited length version, along with different types in some engines. + treat_alter_data_type_as_destructive: The SchemaDiffer will only output change data type operations if it + concludes the change is compatible and won't result in data loss. If this flag is set to True, it will + flag these data type changes as destructive. This was added for dbt adapter support and likely shouldn't + be set outside of that context. """ support_positional_add: bool = False - support_nested_operations: bool = False - support_nested_drop: bool = False + nested_support: NestedSupport = NestedSupport.NONE array_element_selector: str = "" compatible_types: t.Dict[exp.DataType, t.Set[exp.DataType]] = {} coerceable_types_: t.Dict[exp.DataType, t.Set[exp.DataType]] = Field( @@ -341,6 +371,7 @@ class SchemaDiffer(PydanticModel): ] = {} max_parameter_length: t.Dict[exp.DataType.Type, t.Union[int, float]] = {} types_with_unlimited_length: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {} + treat_alter_data_type_as_destructive: bool = False _coerceable_types: t.Dict[exp.DataType, t.Set[exp.DataType]] = {} @@ -575,9 +606,11 @@ def _alter_operation( # We don't copy on purpose here because current_type may need to be mutated inside # _get_operations (struct.expressions.pop and struct.expressions.insert) current_type = exp.DataType.build(current_type, copy=False) - if self.support_nested_operations: + if not self.nested_support.is_none: if new_type.this == current_type.this == exp.DataType.Type.STRUCT: - if self.support_nested_drop or not self._requires_drop_alteration( + if self.nested_support.is_ignore: + return [] + if self.nested_support.is_all or not self._requires_drop_alteration( current_type, new_type ): return self._get_operations( @@ -597,7 +630,9 @@ def _alter_operation( new_array_type = new_type.expressions[0] current_array_type = current_type.expressions[0] if new_array_type.this == current_array_type.this == exp.DataType.Type.STRUCT: - if self.support_nested_drop or not self._requires_drop_alteration( + if self.nested_support.is_ignore: + return [] + if self.nested_support.is_all or not self._requires_drop_alteration( current_array_type, new_array_type ): return self._get_operations( @@ -624,6 +659,7 @@ def _alter_operation( current_type=current_type, expected_table_struct=root_struct.copy(), array_element_selector=self.array_element_selector, + is_part_of_destructive_change=self.treat_alter_data_type_as_destructive, ) ] if ignore_destructive: @@ -806,12 +842,15 @@ def get_additive_column_names(alter_expressions: t.List[TableAlterOperation]) -> ] -def get_schema_differ(dialect: str) -> SchemaDiffer: +def get_schema_differ( + dialect: str, overrides: t.Optional[t.Dict[str, t.Any]] = None +) -> SchemaDiffer: """ Returns the appropriate SchemaDiffer for a given dialect without initializing the engine adapter. Args: dialect: The dialect for which to get the schema differ. + overrides: Optional dictionary of overrides to apply to the SchemaDiffer instance. Returns: The SchemaDiffer instance configured for the given dialect. @@ -825,7 +864,12 @@ def get_schema_differ(dialect: str) -> SchemaDiffer: dialect = dialect.lower() dialect = DIALECT_ALIASES.get(dialect, dialect) engine_adapter_class = DIALECT_TO_ENGINE_ADAPTER.get(dialect, EngineAdapter) - return getattr(engine_adapter_class, "SCHEMA_DIFFER", SchemaDiffer()) + return SchemaDiffer( + **{ + **getattr(engine_adapter_class, "SCHEMA_DIFFER_KWARGS"), + **(overrides or {}), + } + ) def _get_name_and_type(struct: exp.ColumnDef) -> t.Tuple[exp.Identifier, exp.DataType]: diff --git a/sqlmesh/dbt/target.py b/sqlmesh/dbt/target.py index 30a8f35c50..035b5b9e93 100644 --- a/sqlmesh/dbt/target.py +++ b/sqlmesh/dbt/target.py @@ -29,6 +29,7 @@ IncrementalByUniqueKeyKind, IncrementalUnmanagedKind, ) +from sqlmesh.core.schema_diff import NestedSupport from sqlmesh.dbt.common import DbtConfig from sqlmesh.dbt.relation import Policy from sqlmesh.dbt.util import DBT_VERSION @@ -50,6 +51,25 @@ "schema_", } +SCHEMA_DIFFER_OVERRIDES = { + "schema_differ_overrides": { + "treat_alter_data_type_as_destructive": True, + "nested_support": NestedSupport.IGNORE, + } +} + + +def with_schema_differ_overrides( + func: t.Callable[..., ConnectionConfig], +) -> t.Callable[..., ConnectionConfig]: + """Decorator that merges default config with kwargs.""" + + def wrapper(self: TargetConfig, **kwargs: t.Any) -> ConnectionConfig: + merged_kwargs = {**SCHEMA_DIFFER_OVERRIDES, **kwargs} + return func(self, **merged_kwargs) + + return wrapper + class TargetConfig(abc.ABC, DbtConfig): """ @@ -92,6 +112,7 @@ def default_incremental_strategy(self, kind: IncrementalKind) -> str: """The default incremental strategy for the db""" raise NotImplementedError + @with_schema_differ_overrides def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: """Converts target config to SQLMesh connection config""" raise NotImplementedError @@ -177,6 +198,7 @@ def relation_class(cls) -> t.Type[BaseRelation]: return DuckDBRelation + @with_schema_differ_overrides def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: if self.extensions is not None: kwargs["extensions"] = self.extensions @@ -286,6 +308,7 @@ def column_class(cls) -> t.Type[Column]: return SnowflakeColumn + @with_schema_differ_overrides def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: return SnowflakeConnectionConfig( user=self.user, @@ -359,6 +382,7 @@ def _validate_port(cls, v: t.Union[int, str]) -> int: def default_incremental_strategy(self, kind: IncrementalKind) -> str: return "delete+insert" if kind is IncrementalByUniqueKeyKind else "append" + @with_schema_differ_overrides def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: return PostgresConnectionConfig( host=self.host, @@ -454,6 +478,7 @@ def column_class(cls) -> t.Type[Column]: return RedshiftColumn return super(RedshiftConfig, cls).column_class + @with_schema_differ_overrides def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: return RedshiftConnectionConfig( user=self.user, @@ -504,6 +529,7 @@ def column_class(cls) -> t.Type[Column]: return DatabricksColumn + @with_schema_differ_overrides def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: return DatabricksConnectionConfig( server_hostname=self.host, @@ -605,6 +631,7 @@ def column_class(cls) -> t.Type[Column]: return BigQueryColumn + @with_schema_differ_overrides def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: job_retries = self.job_retries if self.job_retries is not None else self.retries job_execution_timeout_seconds = ( @@ -778,6 +805,7 @@ def column_class(cls) -> t.Type[Column]: def dialect(self) -> str: return "tsql" + @with_schema_differ_overrides def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: return MSSQLConnectionConfig( host=self.host, @@ -892,6 +920,7 @@ def column_class(cls) -> t.Type[Column]: return TrinoColumn + @with_schema_differ_overrides def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: return TrinoConnectionConfig( method=self._method_to_auth_enum[self.method], @@ -1002,6 +1031,7 @@ def column_class(cls) -> t.Type[Column]: return ClickHouseColumn + @with_schema_differ_overrides def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: return ClickhouseConnectionConfig( host=self.host, @@ -1085,6 +1115,7 @@ def column_class(cls) -> t.Type[Column]: def default_incremental_strategy(self, kind: IncrementalKind) -> str: return "insert_overwrite" + @with_schema_differ_overrides def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: return AthenaConnectionConfig( type="athena", diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index f283680cfb..433e2165d8 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -960,6 +960,7 @@ def test_dlt_filesystem_pipeline(tmp_path): " # register_comments: False\n" " # pre_ping: False\n" " # pretty_sql: False\n" + " # schema_differ_overrides: \n" " # aws_access_key_id: \n" " # aws_secret_access_key: \n" " # role_arn: \n" @@ -1961,11 +1962,11 @@ def test_init_dbt_template(runner: CliRunner, tmp_path: Path): @time_machine.travel(FREEZE_TIME) def test_init_project_engine_configs(tmp_path): engine_type_to_config = { - "redshift": "# concurrent_tasks: 4\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # user: \n # password: \n # database: \n # host: \n # port: \n # source_address: \n # unix_sock: \n # ssl: \n # sslmode: \n # timeout: \n # tcp_keepalive: \n # application_name: \n # preferred_role: \n # principal_arn: \n # credentials_provider: \n # region: \n # cluster_identifier: \n # iam: \n # is_serverless: \n # serverless_acct_id: \n # serverless_work_group: \n # enable_merge: ", - "bigquery": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # method: oauth\n # project: \n # execution_project: \n # quota_project: \n # location: \n # keyfile: \n # keyfile_json: \n # token: \n # refresh_token: \n # client_id: \n # client_secret: \n # token_uri: \n # scopes: \n # impersonated_service_account: \n # job_creation_timeout_seconds: \n # job_execution_timeout_seconds: \n # job_retries: 1\n # job_retry_deadline_seconds: \n # priority: \n # maximum_bytes_billed: ", - "snowflake": "account: \n # concurrent_tasks: 4\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # user: \n # password: \n # warehouse: \n # database: \n # role: \n # authenticator: \n # token: \n # host: \n # port: \n # application: Tobiko_SQLMesh\n # private_key: \n # private_key_path: \n # private_key_passphrase: \n # session_parameters: ", - "databricks": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # server_hostname: \n # http_path: \n # access_token: \n # auth_type: \n # oauth_client_id: \n # oauth_client_secret: \n # catalog: \n # http_headers: \n # session_configuration: \n # databricks_connect_server_hostname: \n # databricks_connect_access_token: \n # databricks_connect_cluster_id: \n # databricks_connect_use_serverless: False\n # force_databricks_connect: False\n # disable_databricks_connect: False\n # disable_spark_session: False", - "postgres": "host: \n user: \n password: \n port: \n database: \n # concurrent_tasks: 4\n # register_comments: True\n # pre_ping: True\n # pretty_sql: False\n # keepalives_idle: \n # connect_timeout: 10\n # role: \n # sslmode: \n # application_name: ", + "redshift": "# concurrent_tasks: 4\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # user: \n # password: \n # database: \n # host: \n # port: \n # source_address: \n # unix_sock: \n # ssl: \n # sslmode: \n # timeout: \n # tcp_keepalive: \n # application_name: \n # preferred_role: \n # principal_arn: \n # credentials_provider: \n # region: \n # cluster_identifier: \n # iam: \n # is_serverless: \n # serverless_acct_id: \n # serverless_work_group: \n # enable_merge: ", + "bigquery": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # method: oauth\n # project: \n # execution_project: \n # quota_project: \n # location: \n # keyfile: \n # keyfile_json: \n # token: \n # refresh_token: \n # client_id: \n # client_secret: \n # token_uri: \n # scopes: \n # impersonated_service_account: \n # job_creation_timeout_seconds: \n # job_execution_timeout_seconds: \n # job_retries: 1\n # job_retry_deadline_seconds: \n # priority: \n # maximum_bytes_billed: ", + "snowflake": "account: \n # concurrent_tasks: 4\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # user: \n # password: \n # warehouse: \n # database: \n # role: \n # authenticator: \n # token: \n # host: \n # port: \n # application: Tobiko_SQLMesh\n # private_key: \n # private_key_path: \n # private_key_passphrase: \n # session_parameters: ", + "databricks": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # server_hostname: \n # http_path: \n # access_token: \n # auth_type: \n # oauth_client_id: \n # oauth_client_secret: \n # catalog: \n # http_headers: \n # session_configuration: \n # databricks_connect_server_hostname: \n # databricks_connect_access_token: \n # databricks_connect_cluster_id: \n # databricks_connect_use_serverless: False\n # force_databricks_connect: False\n # disable_databricks_connect: False\n # disable_spark_session: False", + "postgres": "host: \n user: \n password: \n port: \n database: \n # concurrent_tasks: 4\n # register_comments: True\n # pre_ping: True\n # pretty_sql: False\n # schema_differ_overrides: \n # keepalives_idle: \n # connect_timeout: 10\n # role: \n # sslmode: \n # application_name: ", } for engine_type, expected_config in engine_type_to_config.items(): diff --git a/tests/core/engine_adapter/test_base.py b/tests/core/engine_adapter/test_base.py index 3661df3a3b..b2dfcc7ccc 100644 --- a/tests/core/engine_adapter/test_base.py +++ b/tests/core/engine_adapter/test_base.py @@ -15,7 +15,7 @@ from sqlmesh.core.engine_adapter import EngineAdapter, EngineAdapterWithIndexSupport from sqlmesh.core.engine_adapter.mixins import InsertOverwriteWithMergeMixin from sqlmesh.core.engine_adapter.shared import InsertOverwriteStrategy, DataObject -from sqlmesh.core.schema_diff import SchemaDiffer, TableAlterOperation +from sqlmesh.core.schema_diff import SchemaDiffer, TableAlterOperation, NestedSupport from sqlmesh.utils import columns_to_types_to_struct from sqlmesh.utils.date import to_ds from sqlmesh.utils.errors import SQLMeshError, UnsupportedCatalogOperationError @@ -715,8 +715,7 @@ def test_comments(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture) ( { "support_positional_add": True, - "support_nested_operations": True, - "support_nested_drop": True, + "nested_support": NestedSupport.ALL, "array_element_selector": "element", }, { @@ -774,7 +773,7 @@ def test_comments(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture) ), ( { - "support_nested_operations": True, + "nested_support": NestedSupport.ALL_BUT_DROP, "array_element_selector": "element", }, { @@ -892,8 +891,7 @@ def test_comments(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture) ( { "support_positional_add": True, - "support_nested_operations": True, - "support_nested_drop": True, + "nested_support": NestedSupport.ALL, "array_element_selector": "element", }, { @@ -922,8 +920,7 @@ def test_comments(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture) ( { "support_positional_add": True, - "support_nested_operations": True, - "support_nested_drop": True, + "nested_support": NestedSupport.ALL, "array_element_selector": "element", }, { @@ -979,8 +976,7 @@ def test_comments(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture) # Test multiple operations on a column with no positional and nested features enabled ( { - "support_nested_operations": True, - "support_nested_drop": True, + "nested_support": NestedSupport.ALL, "array_element_selector": "element", }, { @@ -1037,8 +1033,7 @@ def test_comments(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture) # Test deeply nested structures ( { - "support_nested_operations": True, - "support_nested_drop": True, + "nested_support": NestedSupport.ALL, "array_element_selector": "element", }, { @@ -1067,8 +1062,8 @@ def test_alter_table( ): adapter = make_mocked_engine_adapter(EngineAdapter) - adapter.SCHEMA_DIFFER = SchemaDiffer(**schema_differ_config) - original_from_structs = adapter.SCHEMA_DIFFER._from_structs + adapter.SCHEMA_DIFFER_KWARGS = schema_differ_config + original_from_structs = adapter.schema_differ._from_structs def _from_structs(*args, **kwargs) -> t.List[TableAlterOperation]: operations = original_from_structs(*args, **kwargs) diff --git a/tests/core/engine_adapter/test_clickhouse.py b/tests/core/engine_adapter/test_clickhouse.py index b75609e759..39e317c7fa 100644 --- a/tests/core/engine_adapter/test_clickhouse.py +++ b/tests/core/engine_adapter/test_clickhouse.py @@ -7,7 +7,6 @@ from sqlmesh.core.dialect import parse from sqlglot import exp, parse_one import typing as t -from sqlmesh.core.schema_diff import SchemaDiffer from datetime import datetime from pytest_mock.plugin import MockerFixture from sqlmesh.core import dialect as d @@ -152,7 +151,7 @@ def test_alter_table( adapter: ClickhouseEngineAdapter, mocker, ): - adapter.SCHEMA_DIFFER = SchemaDiffer() + adapter.SCHEMA_DIFFER_KWARGS = {} current_table_name = "test_table" current_table = {"a": "Int8", "b": "String", "c": "Int8"} target_table_name = "target_table" diff --git a/tests/core/test_connection_config.py b/tests/core/test_connection_config.py index 22d21fcef7..907d1b70cc 100644 --- a/tests/core/test_connection_config.py +++ b/tests/core/test_connection_config.py @@ -1780,3 +1780,18 @@ def test_fabric_pyodbc_connection_string_generation(): # Check autocommit parameter, should default to True for Fabric assert call_args[1]["autocommit"] is True + + +def test_schema_differ_overrides(make_config) -> None: + default_config = make_config(type="duckdb") + assert default_config.schema_differ_overrides is None + default_adapter = default_config.create_engine_adapter() + assert default_adapter._schema_differ_overrides is None + assert default_adapter.schema_differ.parameterized_type_defaults != {} + + override: t.Dict[str, t.Any] = {"parameterized_type_defaults": {}} + config = make_config(type="duckdb", schema_differ_overrides=override) + assert config.schema_differ_overrides == override + adapter = config.create_engine_adapter() + assert adapter._schema_differ_overrides == override + assert adapter.schema_differ.parameterized_type_defaults == {} diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index 517d7c3ca1..dec7309591 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -7944,7 +7944,7 @@ def test_incremental_by_time_model_ignore_destructive_change(tmp_path: Path): context = Context(paths=[tmp_path], config=config) # Make the change compatible since that means we will attempt and alter now that is considered additive - context.engine_adapter.SCHEMA_DIFFER.compatible_types = { + context.engine_adapter.SCHEMA_DIFFER_KWARGS["compatible_types"] = { exp.DataType.build("INT"): {exp.DataType.build("STRING")} } context.plan("prod", auto_apply=True, no_prompts=True, run=True) @@ -8106,7 +8106,7 @@ def test_incremental_by_time_model_ignore_additive_change(tmp_path: Path): (models_dir / "test_model.sql").write_text(updated_model) context = Context(paths=[tmp_path], config=config) - context.engine_adapter.SCHEMA_DIFFER.compatible_types = { + context.engine_adapter.SCHEMA_DIFFER_KWARGS["compatible_types"] = { exp.DataType.build("INT"): {exp.DataType.build("STRING")} } context.plan("prod", auto_apply=True, no_prompts=True, run=True) @@ -8153,7 +8153,7 @@ def test_incremental_by_time_model_ignore_additive_change(tmp_path: Path): context = Context(paths=[tmp_path], config=config) # Make the change compatible since that means we will attempt and alter now that is considered additive - context.engine_adapter.SCHEMA_DIFFER.compatible_types = { + context.engine_adapter.SCHEMA_DIFFER_KWARGS["compatible_types"] = { exp.DataType.build("INT"): {exp.DataType.build("STRING")} } context.plan("prod", auto_apply=True, no_prompts=True, run=True) diff --git a/tests/core/test_schema_diff.py b/tests/core/test_schema_diff.py index 916bead3e6..e091dea539 100644 --- a/tests/core/test_schema_diff.py +++ b/tests/core/test_schema_diff.py @@ -13,6 +13,7 @@ TableAlterAddColumnOperation, TableAlterDropColumnOperation, TableAlterChangeColumnTypeOperation, + NestedSupport, ) @@ -20,7 +21,7 @@ def test_schema_diff_calculate(): alter_operations = SchemaDiffer( **{ "support_positional_add": False, - "support_nested_operations": False, + "nested_support": NestedSupport.NONE, "array_element_selector": "", "compatible_types": { exp.DataType.build("STRING"): {exp.DataType.build("INT")}, @@ -53,7 +54,7 @@ def test_schema_diff_drop_cascade(): alter_expressions = SchemaDiffer( **{ "support_positional_add": False, - "support_nested_operations": False, + "nested_support": NestedSupport.NONE, "array_element_selector": "", "drop_cascade": True, } @@ -79,7 +80,7 @@ def test_schema_diff_calculate_type_transitions(): alter_expressions = SchemaDiffer( **{ "support_positional_add": False, - "support_nested_operations": False, + "nested_support": NestedSupport.NONE, "array_element_selector": "", "compatible_types": { exp.DataType.build("STRING"): {exp.DataType.build("INT")}, @@ -426,7 +427,7 @@ def test_schema_diff_calculate_type_transitions(): position=TableAlterColumnPosition.first(), ), ], - dict(support_positional_add=True, support_nested_operations=True), + dict(support_positional_add=True, nested_support=NestedSupport.ALL_BUT_DROP), ), # Add a column to the end of a struct ( @@ -447,7 +448,7 @@ def test_schema_diff_calculate_type_transitions(): position=TableAlterColumnPosition.last(after="col_c"), ), ], - dict(support_positional_add=True, support_nested_operations=True), + dict(support_positional_add=True, nested_support=NestedSupport.ALL_BUT_DROP), ), # Add a column to the middle of a struct ( @@ -468,7 +469,7 @@ def test_schema_diff_calculate_type_transitions(): array_element_selector="", ), ], - dict(support_positional_add=True, support_nested_operations=True), + dict(support_positional_add=True, nested_support=NestedSupport.ALL_BUT_DROP), ), # Add two columns at the start of a struct ( @@ -502,7 +503,7 @@ def test_schema_diff_calculate_type_transitions(): array_element_selector="", ), ], - dict(support_positional_add=True, support_nested_operations=True), + dict(support_positional_add=True, nested_support=NestedSupport.ALL_BUT_DROP), ), # Add columns in different levels of nesting of structs ( @@ -533,7 +534,7 @@ def test_schema_diff_calculate_type_transitions(): array_element_selector="", ), ], - dict(support_positional_add=False, support_nested_operations=True), + dict(support_positional_add=False, nested_support=NestedSupport.ALL_BUT_DROP), ), # Remove a column from the start of a struct ( @@ -554,8 +555,7 @@ def test_schema_diff_calculate_type_transitions(): ], dict( support_positional_add=True, - support_nested_operations=True, - support_nested_drop=True, + nested_support=NestedSupport.ALL, ), ), # Remove a column from the end of a struct @@ -577,8 +577,7 @@ def test_schema_diff_calculate_type_transitions(): ], dict( support_positional_add=True, - support_nested_operations=True, - support_nested_drop=True, + nested_support=NestedSupport.ALL, ), ), # Remove a column from the middle of a struct @@ -600,8 +599,7 @@ def test_schema_diff_calculate_type_transitions(): ], dict( support_positional_add=True, - support_nested_operations=True, - support_nested_drop=True, + nested_support=NestedSupport.ALL, ), ), # Remove a column from a struct where nested drop is not supported @@ -631,8 +629,7 @@ def test_schema_diff_calculate_type_transitions(): ), ], dict( - support_nested_operations=True, - support_nested_drop=False, + nested_support=NestedSupport.ALL_BUT_DROP, ), ), # Remove two columns from the start of a struct @@ -665,8 +662,7 @@ def test_schema_diff_calculate_type_transitions(): ], dict( support_positional_add=True, - support_nested_operations=True, - support_nested_drop=True, + nested_support=NestedSupport.ALL, ), ), # Change a column type in a struct @@ -690,7 +686,7 @@ def test_schema_diff_calculate_type_transitions(): ], dict( support_positional_add=True, - support_nested_operations=True, + nested_support=NestedSupport.ALL_BUT_DROP, compatible_types={ exp.DataType.build("INT"): {exp.DataType.build("TEXT")}, }, @@ -754,8 +750,7 @@ def test_schema_diff_calculate_type_transitions(): ], dict( support_positional_add=True, - support_nested_operations=True, - support_nested_drop=True, + nested_support=NestedSupport.ALL, compatible_types={ exp.DataType.build("INT"): {exp.DataType.build("TEXT")}, }, @@ -788,8 +783,7 @@ def test_schema_diff_calculate_type_transitions(): ), ], dict( - support_nested_operations=True, - support_nested_drop=False, + nested_support=NestedSupport.ALL_BUT_DROP, compatible_types={ exp.DataType.build("INT"): {exp.DataType.build("TEXT")}, }, @@ -853,8 +847,7 @@ def test_schema_diff_calculate_type_transitions(): ], dict( support_positional_add=True, - support_nested_operations=True, - support_nested_drop=True, + nested_support=NestedSupport.ALL, ), ), # ##################### @@ -879,7 +872,7 @@ def test_schema_diff_calculate_type_transitions(): array_element_selector="", ), ], - dict(support_positional_add=True, support_nested_operations=True), + dict(support_positional_add=True, nested_support=NestedSupport.ALL_BUT_DROP), ), # Remove column from array of structs ( @@ -900,8 +893,7 @@ def test_schema_diff_calculate_type_transitions(): ], dict( support_positional_add=True, - support_nested_operations=True, - support_nested_drop=True, + nested_support=NestedSupport.ALL, ), ), # Alter column type in array of structs @@ -925,7 +917,7 @@ def test_schema_diff_calculate_type_transitions(): ], dict( support_positional_add=True, - support_nested_operations=True, + nested_support=NestedSupport.ALL_BUT_DROP, compatible_types={ exp.DataType.build("INT"): {exp.DataType.build("TEXT")}, }, @@ -958,7 +950,7 @@ def test_schema_diff_calculate_type_transitions(): array_element_selector="", ), ], - dict(support_positional_add=False, support_nested_operations=True), + dict(support_positional_add=False, nested_support=NestedSupport.ALL_BUT_DROP), ), # Add an array of primitives ( @@ -978,7 +970,7 @@ def test_schema_diff_calculate_type_transitions(): array_element_selector="", ), ], - dict(support_positional_add=True, support_nested_operations=True), + dict(support_positional_add=True, nested_support=NestedSupport.ALL_BUT_DROP), ), # untyped array to support Snowflake ( @@ -1134,8 +1126,7 @@ def test_schema_diff_calculate_type_transitions(): ], dict( support_positional_add=True, - support_nested_operations=True, - support_nested_drop=True, + nested_support=NestedSupport.ALL, ), ), # Type with precision to same type with no precision and no default is DROP/ADD @@ -1398,7 +1389,7 @@ def test_schema_diff_calculate_type_transitions(): [], dict( support_positional_add=True, - support_nested_operations=True, + nested_support=NestedSupport.ALL_BUT_DROP, support_coercing_compatible_types=True, compatible_types={ exp.DataType.build("INT"): {exp.DataType.build("FLOAT")}, @@ -1411,7 +1402,7 @@ def test_schema_diff_calculate_type_transitions(): [], dict( support_positional_add=True, - support_nested_operations=True, + nested_support=NestedSupport.ALL_BUT_DROP, coerceable_types={ exp.DataType.build("FLOAT"): {exp.DataType.build("INT")}, }, @@ -1423,7 +1414,7 @@ def test_schema_diff_calculate_type_transitions(): [], dict( support_positional_add=True, - support_nested_operations=True, + nested_support=NestedSupport.ALL_BUT_DROP, support_coercing_compatible_types=True, compatible_types={ exp.DataType.build("INT"): {exp.DataType.build("FLOAT")}, @@ -1453,13 +1444,108 @@ def test_schema_diff_calculate_type_transitions(): ], dict( support_positional_add=False, - support_nested_operations=True, + nested_support=NestedSupport.ALL_BUT_DROP, support_coercing_compatible_types=True, compatible_types={ exp.DataType.build("INT"): {exp.DataType.build("FLOAT")}, }, ), ), + # ################### + # Ignore Nested Tests + # ################### + # Remove nested col_c + ( + "STRUCT>", + "STRUCT>", + [], + dict(nested_support=NestedSupport.IGNORE), + ), + # Add nested col_d + ( + "STRUCT>", + "STRUCT>", + [], + dict(nested_support=NestedSupport.IGNORE), + ), + # Change nested col_c to incompatible type + ( + "STRUCT>", + "STRUCT>", + [], + dict(nested_support=NestedSupport.IGNORE), + ), + # Change nested col_c to compatible type + ( + "STRUCT>", + "STRUCT>", + [], + dict( + nested_support=NestedSupport.IGNORE, + compatible_types={ + exp.DataType.build("INT"): {exp.DataType.build("STRING")}, + }, + ), + ), + # Mix of ignored nested and non-nested changes + ( + "STRUCT, age INT>", + "STRUCT, age STRING, new_col INT>", + [ + # `col_c` change is ignored + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("new_col")], + column_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build( + "STRUCT, age INT, new_col INT>" + ), + position=TableAlterColumnPosition.last("age"), + array_element_selector="", + ), + TableAlterChangeColumnTypeOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("age")], + column_type=exp.DataType.build("STRING"), + current_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build( + "STRUCT, age STRING, new_col INT>" + ), + array_element_selector="", + ), + ], + dict( + nested_support=NestedSupport.IGNORE, + compatible_types={ + exp.DataType.build("INT"): {exp.DataType.build("STRING")}, + }, + support_positional_add=True, + ), + ), + # ############################ + # Change Data Type Destructive + # ############################ + ( + "STRUCT", + "STRUCT", + [ + TableAlterChangeColumnTypeOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("age")], + column_type=exp.DataType.build("STRING"), + current_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build("STRUCT"), + array_element_selector="", + is_part_of_destructive_change=True, + ), + ], + dict( + treat_alter_data_type_as_destructive=True, + compatible_types={ + exp.DataType.build("INT"): {exp.DataType.build("STRING")}, + }, + ), + ), ], ) def test_struct_diff( @@ -1750,7 +1836,7 @@ def test_ignore_destructive_compare_columns(): """Test ignore_destructive behavior in compare_columns method.""" schema_differ = SchemaDiffer( support_positional_add=True, - support_nested_operations=False, + nested_support=NestedSupport.NONE, compatible_types={ exp.DataType.build("INT"): {exp.DataType.build("STRING")}, }, @@ -1796,8 +1882,7 @@ def test_ignore_destructive_compare_columns(): def test_ignore_destructive_nested_struct_without_support(): """Test ignore_destructive with nested structs when nested_drop is not supported.""" schema_differ = SchemaDiffer( - support_nested_operations=True, - support_nested_drop=False, # This forces DROP+ADD for nested changes + nested_support=NestedSupport.ALL_BUT_DROP, # This forces DROP+ADD for nested changes ) current_struct = "STRUCT>" @@ -1834,8 +1919,7 @@ def test_get_schema_differ(): # Databricks should support positional add and nested operations databricks_differ = get_schema_differ("databricks") assert databricks_differ.support_positional_add is True - assert databricks_differ.support_nested_operations is True - assert databricks_differ.support_nested_drop is True + assert databricks_differ.nested_support == NestedSupport.ALL # BigQuery should have specific compatible types configured bigquery_differ = get_schema_differ("bigquery") @@ -1860,7 +1944,7 @@ def test_get_schema_differ(): schema_differ_unknown = get_schema_differ("unknown_dialect") assert isinstance(schema_differ_unknown, SchemaDiffer) assert schema_differ_unknown.support_positional_add is False - assert schema_differ_unknown.support_nested_operations is False + assert schema_differ_unknown.nested_support == NestedSupport.NONE # Test case insensitivity schema_differ_upper = get_schema_differ("BIGQUERY") @@ -1870,6 +1954,10 @@ def test_get_schema_differ(): == schema_differ_lower.support_coercing_compatible_types ) + # Test override + schema_differ_with_override = get_schema_differ("postgres", {"drop_cascade": False}) + assert schema_differ_with_override.drop_cascade is False + def test_ignore_destructive_edge_cases(): """Test edge cases for ignore_destructive behavior.""" @@ -2116,7 +2204,7 @@ def test_ignore_destructive_edge_cases(): ), ], [], # No operations when ignoring additive - dict(support_nested_operations=True), + dict(nested_support=NestedSupport.ALL_BUT_DROP), ), ], ) @@ -2228,7 +2316,7 @@ def test_ignore_both_destructive_and_additive(): def test_ignore_additive_array_operations(): """Test ignore_additive with array of struct operations.""" schema_differ = SchemaDiffer( - support_nested_operations=True, + nested_support=NestedSupport.ALL, support_positional_add=True, ) diff --git a/tests/dbt/test_config.py b/tests/dbt/test_config.py index 72994fe33c..bc6f878801 100644 --- a/tests/dbt/test_config.py +++ b/tests/dbt/test_config.py @@ -35,6 +35,7 @@ TrinoConfig, AthenaConfig, ClickhouseConfig, + SCHEMA_DIFFER_OVERRIDES, ) from sqlmesh.dbt.test import TestConfig from sqlmesh.utils.errors import ConfigError @@ -542,6 +543,9 @@ def test_snowflake_config(): ) sqlmesh_config = config.to_sqlmesh() assert sqlmesh_config.application == "Tobiko_SQLMesh" + assert ( + sqlmesh_config.schema_differ_overrides == SCHEMA_DIFFER_OVERRIDES["schema_differ_overrides"] + ) def test_snowflake_config_private_key_path(): @@ -771,6 +775,7 @@ def test_databricks_config_oauth(): assert as_sqlmesh.auth_type == "databricks-oauth" assert as_sqlmesh.oauth_client_id == "client-id" assert as_sqlmesh.oauth_client_secret == "client-secret" + assert as_sqlmesh.schema_differ_overrides == SCHEMA_DIFFER_OVERRIDES["schema_differ_overrides"] def test_bigquery_config():