Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sqlmesh/core/config/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class ConnectionConfig(abc.ABC, BaseConfig):
register_comments: bool
pre_ping: bool
pretty_sql: bool = False
schema_differ_overrides: t.Optional[t.Dict[str, t.Any]] = None

# Whether to share a single connection across threads or create a new connection per thread.
shared_connection: t.ClassVar[bool] = False
Expand Down Expand Up @@ -174,6 +175,7 @@ def create_engine_adapter(
pre_ping=self.pre_ping,
pretty_sql=self.pretty_sql,
shared_connection=self.shared_connection,
schema_differ_overrides=self.schema_differ_overrides,
**self._extra_engine_config,
)

Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/core/engine_adapter/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class AthenaEngineAdapter(PandasNativeFetchDFSupportMixin, RowDiffMixin):
# CTAS, Views: No comment support at all
COMMENT_CREATION_TABLE = CommentCreationTable.UNSUPPORTED
COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED
SCHEMA_DIFFER = TrinoEngineAdapter.SCHEMA_DIFFER
SCHEMA_DIFFER_KWARGS = TrinoEngineAdapter.SCHEMA_DIFFER_KWARGS
MAX_TIMESTAMP_PRECISION = 3 # copied from Trino
# Athena does not deal with comments well, e.g:
# >>> self._execute('/* test */ DESCRIBE foo')
Expand Down
17 changes: 14 additions & 3 deletions sqlmesh/core/engine_adapter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import logging
import sys
import typing as t
from functools import partial
from functools import cached_property, partial

from sqlglot import Dialect, exp
from sqlglot.errors import ErrorLevel
Expand Down Expand Up @@ -109,7 +109,7 @@ class EngineAdapter:
SUPPORTS_MANAGED_MODELS = False
SUPPORTS_CREATE_DROP_CATALOG = False
SUPPORTED_DROP_CASCADE_OBJECT_KINDS: t.List[str] = []
SCHEMA_DIFFER = SchemaDiffer()
SCHEMA_DIFFER_KWARGS: t.Dict[str, t.Any] = {}
SUPPORTS_TUPLE_IN = True
HAS_VIEW_BINDING = False
SUPPORTS_REPLACE_TABLE = True
Expand All @@ -132,6 +132,7 @@ def __init__(
pretty_sql: bool = False,
shared_connection: bool = False,
correlation_id: t.Optional[CorrelationId] = None,
schema_differ_overrides: t.Optional[t.Dict[str, t.Any]] = None,
**kwargs: t.Any,
):
self.dialect = dialect.lower() or self.DIALECT
Expand All @@ -154,6 +155,7 @@ def __init__(
self._pretty_sql = pretty_sql
self._multithreaded = multithreaded
self.correlation_id = correlation_id
self._schema_differ_overrides = schema_differ_overrides

def with_settings(self, **kwargs: t.Any) -> EngineAdapter:
extra_kwargs = {
Expand Down Expand Up @@ -204,6 +206,15 @@ def comments_enabled(self) -> bool:
def catalog_support(self) -> CatalogSupport:
return CatalogSupport.UNSUPPORTED

# Builds the SchemaDiffer used by this adapter. Because this is a
# cached_property, the result is stored on the instance after first access.
@cached_property
def schema_differ(self) -> SchemaDiffer:
return SchemaDiffer(
**{
# Class-level keyword arguments declared by the engine adapter.
**self.SCHEMA_DIFFER_KWARGS,
# Overrides supplied to __init__ are unpacked last, so they take
# precedence on key collisions; None is treated as "no overrides".
**(self._schema_differ_overrides or {}),
}
)

@classmethod
def _casted_columns(
cls,
Expand Down Expand Up @@ -1101,7 +1112,7 @@ def get_alter_operations(
"""
return t.cast(
t.List[TableAlterOperation],
self.SCHEMA_DIFFER.compare_columns(
self.schema_differ.compare_columns(
current_table_name,
self.columns(current_table_name),
self.columns(target_table_name),
Expand Down
19 changes: 9 additions & 10 deletions sqlmesh/core/engine_adapter/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
set_catalog,
)
from sqlmesh.core.node import IntervalUnit
from sqlmesh.core.schema_diff import SchemaDiffer, TableAlterOperation
from sqlmesh.core.schema_diff import TableAlterOperation, NestedSupport
from sqlmesh.utils import optional_import, get_source_columns_to_types
from sqlmesh.utils.date import to_datetime
from sqlmesh.utils.errors import SQLMeshError
Expand Down Expand Up @@ -68,8 +68,8 @@ class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin, ClusteredByMixin, Row
MAX_COLUMN_COMMENT_LENGTH = 1024
SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA"]

SCHEMA_DIFFER = SchemaDiffer(
compatible_types={
SCHEMA_DIFFER_KWARGS = {
"compatible_types": {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit (and I guess this is also a matter of opinion): did you consider using the dict() constructor so you can keep treating the kwarg names as identifiers instead of strings?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have a preference either way. I'm just leaving it as it is to save some time.

exp.DataType.build("INT64", dialect=DIALECT): {
exp.DataType.build("NUMERIC", dialect=DIALECT),
exp.DataType.build("FLOAT64", dialect=DIALECT),
Expand All @@ -83,17 +83,17 @@ class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin, ClusteredByMixin, Row
exp.DataType.build("DATETIME", dialect=DIALECT),
},
},
coerceable_types={
"coerceable_types": {
exp.DataType.build("FLOAT64", dialect=DIALECT): {
exp.DataType.build("BIGNUMERIC", dialect=DIALECT),
},
},
support_coercing_compatible_types=True,
parameterized_type_defaults={
"support_coercing_compatible_types": True,
"parameterized_type_defaults": {
exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(38, 9), (0,)],
exp.DataType.build("BIGDECIMAL", dialect=DIALECT).this: [(76.76, 38), (0,)],
},
types_with_unlimited_length={
"types_with_unlimited_length": {
# parameterized `STRING(n)` can ALTER to unparameterized `STRING`
exp.DataType.build("STRING", dialect=DIALECT).this: {
exp.DataType.build("STRING", dialect=DIALECT).this,
Expand All @@ -103,9 +103,8 @@ class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin, ClusteredByMixin, Row
exp.DataType.build("BYTES", dialect=DIALECT).this,
},
},
support_nested_operations=True,
support_nested_drop=False,
)
"nested_support": NestedSupport.ALL_BUT_DROP,
}

@property
def client(self) -> BigQueryClient:
Expand Down
4 changes: 2 additions & 2 deletions sqlmesh/core/engine_adapter/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
CommentCreationView,
InsertOverwriteStrategy,
)
from sqlmesh.core.schema_diff import SchemaDiffer, TableAlterOperation
from sqlmesh.core.schema_diff import TableAlterOperation
from sqlmesh.utils import get_source_columns_to_types

if t.TYPE_CHECKING:
Expand All @@ -37,7 +37,7 @@ class ClickhouseEngineAdapter(EngineAdapterWithIndexSupport, LogicalMergeMixin):
SUPPORTS_REPLACE_TABLE = False
COMMENT_CREATION_VIEW = CommentCreationView.COMMENT_COMMAND_ONLY

SCHEMA_DIFFER = SchemaDiffer()
SCHEMA_DIFFER_KWARGS = {}

DEFAULT_TABLE_ENGINE = "MergeTree"
ORDER_BY_TABLE_ENGINE_REGEX = "^.*?MergeTree.*$"
Expand Down
15 changes: 7 additions & 8 deletions sqlmesh/core/engine_adapter/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter
from sqlmesh.core.node import IntervalUnit
from sqlmesh.core.schema_diff import SchemaDiffer
from sqlmesh.core.schema_diff import NestedSupport
from sqlmesh.engines.spark.db_api.spark_session import connection, SparkSessionConnection
from sqlmesh.utils.errors import SQLMeshError, MissingDefaultCatalogError

Expand All @@ -34,15 +34,14 @@ class DatabricksEngineAdapter(SparkEngineAdapter):
SUPPORTS_CLONING = True
SUPPORTS_MATERIALIZED_VIEWS = True
SUPPORTS_MATERIALIZED_VIEW_SCHEMA = True
SCHEMA_DIFFER = SchemaDiffer(
support_positional_add=True,
support_nested_operations=True,
support_nested_drop=True,
array_element_selector="element",
parameterized_type_defaults={
SCHEMA_DIFFER_KWARGS = {
"support_positional_add": True,
"nested_support": NestedSupport.ALL,
"array_element_selector": "element",
"parameterized_type_defaults": {
exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(10, 0), (0,)],
},
)
}

def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(*args, **kwargs)
Expand Down
7 changes: 3 additions & 4 deletions sqlmesh/core/engine_adapter/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
SourceQuery,
set_catalog,
)
from sqlmesh.core.schema_diff import SchemaDiffer

if t.TYPE_CHECKING:
from sqlmesh.core._typing import SchemaName, TableName
Expand All @@ -29,11 +28,11 @@
class DuckDBEngineAdapter(LogicalMergeMixin, GetCurrentCatalogFromFunctionMixin, RowDiffMixin):
DIALECT = "duckdb"
SUPPORTS_TRANSACTIONS = False
SCHEMA_DIFFER = SchemaDiffer(
parameterized_type_defaults={
SCHEMA_DIFFER_KWARGS = {
"parameterized_type_defaults": {
exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(18, 3), (0,)],
},
)
}
COMMENT_CREATION_TABLE = CommentCreationTable.COMMENT_COMMAND_ONLY
COMMENT_CREATION_VIEW = CommentCreationView.COMMENT_COMMAND_ONLY
SUPPORTS_CREATE_DROP_CATALOG = True
Expand Down
8 changes: 4 additions & 4 deletions sqlmesh/core/engine_adapter/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,9 @@ def _default_precision_to_max(
) -> t.Dict[str, exp.DataType]:
# get default lengths for types that support "max" length
types_with_max_default_param = {
k: [self.SCHEMA_DIFFER.parameterized_type_defaults[k][0][0]]
for k in self.SCHEMA_DIFFER.max_parameter_length
if k in self.SCHEMA_DIFFER.parameterized_type_defaults
k: [self.schema_differ.parameterized_type_defaults[k][0][0]]
for k in self.schema_differ.max_parameter_length
if k in self.schema_differ.parameterized_type_defaults
}

# Redshift and MSSQL have a bug where CTAS statements have non-deterministic types. If a LIMIT
Expand All @@ -270,7 +270,7 @@ def _default_precision_to_max(
# and supports "max" length, we convert it to "max" length to prevent inadvertent data truncation.
for col_name, col_type in columns_to_types.items():
if col_type.this in types_with_max_default_param and col_type.expressions:
parameter = self.SCHEMA_DIFFER.get_type_parameters(col_type)
parameter = self.schema_differ.get_type_parameters(col_type)
type_default = types_with_max_default_param[col_type.this]
if parameter == type_default:
col_type.set("expressions", [exp.DataTypeParam(this=exp.var("max"))])
Expand Down
9 changes: 4 additions & 5 deletions sqlmesh/core/engine_adapter/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
SourceQuery,
set_catalog,
)
from sqlmesh.core.schema_diff import SchemaDiffer
from sqlmesh.utils import get_source_columns_to_types

if t.TYPE_CHECKING:
Expand All @@ -54,8 +53,8 @@ class MSSQLEngineAdapter(
COMMENT_CREATION_TABLE = CommentCreationTable.UNSUPPORTED
COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED
SUPPORTS_REPLACE_TABLE = False
SCHEMA_DIFFER = SchemaDiffer(
parameterized_type_defaults={
SCHEMA_DIFFER_KWARGS = {
"parameterized_type_defaults": {
exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(18, 0), (0,)],
exp.DataType.build("BINARY", dialect=DIALECT).this: [(1,)],
exp.DataType.build("VARBINARY", dialect=DIALECT).this: [(1,)],
Expand All @@ -67,12 +66,12 @@ class MSSQLEngineAdapter(
exp.DataType.build("DATETIME2", dialect=DIALECT).this: [(7,)],
exp.DataType.build("DATETIMEOFFSET", dialect=DIALECT).this: [(7,)],
},
max_parameter_length={
"max_parameter_length": {
exp.DataType.build("VARBINARY", dialect=DIALECT).this: 2147483647, # 2 GB
exp.DataType.build("VARCHAR", dialect=DIALECT).this: 2147483647,
exp.DataType.build("NVARCHAR", dialect=DIALECT).this: 2147483647,
},
)
}
VARIABLE_LENGTH_DATA_TYPES = {"binary", "varbinary", "char", "varchar", "nchar", "nvarchar"}

@property
Expand Down
7 changes: 3 additions & 4 deletions sqlmesh/core/engine_adapter/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
DataObjectType,
set_catalog,
)
from sqlmesh.core.schema_diff import SchemaDiffer

if t.TYPE_CHECKING:
from sqlmesh.core._typing import SchemaName, TableName
Expand All @@ -40,8 +39,8 @@ class MySQLEngineAdapter(
MAX_COLUMN_COMMENT_LENGTH = 1024
SUPPORTS_REPLACE_TABLE = False
MAX_IDENTIFIER_LENGTH = 64
SCHEMA_DIFFER = SchemaDiffer(
parameterized_type_defaults={
SCHEMA_DIFFER_KWARGS = {
"parameterized_type_defaults": {
exp.DataType.build("BIT", dialect=DIALECT).this: [(1,)],
exp.DataType.build("BINARY", dialect=DIALECT).this: [(1,)],
exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(10, 0), (0,)],
Expand All @@ -52,7 +51,7 @@ class MySQLEngineAdapter(
exp.DataType.build("DATETIME", dialect=DIALECT).this: [(0,)],
exp.DataType.build("TIMESTAMP", dialect=DIALECT).this: [(0,)],
},
)
}

def get_current_catalog(self) -> t.Optional[str]:
"""Returns the catalog name of the current connection."""
Expand Down
11 changes: 5 additions & 6 deletions sqlmesh/core/engine_adapter/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
logical_merge,
)
from sqlmesh.core.engine_adapter.shared import set_catalog
from sqlmesh.core.schema_diff import SchemaDiffer

if t.TYPE_CHECKING:
from sqlmesh.core._typing import TableName
Expand All @@ -36,15 +35,15 @@ class PostgresEngineAdapter(
CURRENT_CATALOG_EXPRESSION = exp.column("current_catalog")
SUPPORTS_REPLACE_TABLE = False
MAX_IDENTIFIER_LENGTH = 63
SCHEMA_DIFFER = SchemaDiffer(
parameterized_type_defaults={
SCHEMA_DIFFER_KWARGS = {
"parameterized_type_defaults": {
# DECIMAL without precision is "up to 131072 digits before the decimal point; up to 16383 digits after the decimal point"
exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(131072 + 16383, 16383), (0,)],
exp.DataType.build("CHAR", dialect=DIALECT).this: [(1,)],
exp.DataType.build("TIME", dialect=DIALECT).this: [(6,)],
exp.DataType.build("TIMESTAMP", dialect=DIALECT).this: [(6,)],
},
types_with_unlimited_length={
"types_with_unlimited_length": {
# all can ALTER to `TEXT`
exp.DataType.build("TEXT", dialect=DIALECT).this: {
exp.DataType.build("VARCHAR", dialect=DIALECT).this,
Expand All @@ -63,8 +62,8 @@ class PostgresEngineAdapter(
exp.DataType.build("BPCHAR", dialect=DIALECT).this
},
},
drop_cascade=True,
)
"drop_cascade": True,
}

def _fetch_native_df(
self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False
Expand Down
13 changes: 6 additions & 7 deletions sqlmesh/core/engine_adapter/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
SourceQuery,
set_catalog,
)
from sqlmesh.core.schema_diff import SchemaDiffer
from sqlmesh.utils.errors import SQLMeshError

if t.TYPE_CHECKING:
Expand All @@ -48,22 +47,22 @@ class RedshiftEngineAdapter(
COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED
SUPPORTS_REPLACE_TABLE = False

SCHEMA_DIFFER = SchemaDiffer(
parameterized_type_defaults={
SCHEMA_DIFFER_KWARGS = {
"parameterized_type_defaults": {
exp.DataType.build("VARBYTE", dialect=DIALECT).this: [(64000,)],
exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(18, 0), (0,)],
exp.DataType.build("CHAR", dialect=DIALECT).this: [(1,)],
exp.DataType.build("VARCHAR", dialect=DIALECT).this: [(256,)],
exp.DataType.build("NCHAR", dialect=DIALECT).this: [(1,)],
exp.DataType.build("NVARCHAR", dialect=DIALECT).this: [(256,)],
},
max_parameter_length={
"max_parameter_length": {
exp.DataType.build("CHAR", dialect=DIALECT).this: 4096,
exp.DataType.build("VARCHAR", dialect=DIALECT).this: 65535,
},
precision_increase_allowed_types={exp.DataType.build("VARCHAR", dialect=DIALECT).this},
drop_cascade=True,
)
"precision_increase_allowed_types": {exp.DataType.build("VARCHAR", dialect=DIALECT).this},
"drop_cascade": True,
}
VARIABLE_LENGTH_DATA_TYPES = {
"char",
"character",
Expand Down
7 changes: 3 additions & 4 deletions sqlmesh/core/engine_adapter/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
SourceQuery,
set_catalog,
)
from sqlmesh.core.schema_diff import SchemaDiffer
from sqlmesh.utils import optional_import, get_source_columns_to_types
from sqlmesh.utils.errors import SQLMeshError
from sqlmesh.utils.pandas import columns_to_types_from_dtypes
Expand Down Expand Up @@ -56,8 +55,8 @@ class SnowflakeEngineAdapter(GetCurrentCatalogFromFunctionMixin, ClusteredByMixi
CURRENT_CATALOG_EXPRESSION = exp.func("current_database")
SUPPORTS_CREATE_DROP_CATALOG = True
SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["DATABASE", "SCHEMA", "TABLE"]
SCHEMA_DIFFER = SchemaDiffer(
parameterized_type_defaults={
SCHEMA_DIFFER_KWARGS = {
"parameterized_type_defaults": {
exp.DataType.build("BINARY", dialect=DIALECT).this: [(8388608,)],
exp.DataType.build("VARBINARY", dialect=DIALECT).this: [(8388608,)],
exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(38, 0), (0,)],
Expand All @@ -70,7 +69,7 @@ class SnowflakeEngineAdapter(GetCurrentCatalogFromFunctionMixin, ClusteredByMixi
exp.DataType.build("TIMESTAMP_NTZ", dialect=DIALECT).this: [(9,)],
exp.DataType.build("TIMESTAMP_TZ", dialect=DIALECT).this: [(9,)],
},
)
}
MANAGED_TABLE_KIND = "DYNAMIC TABLE"
SNOWPARK = "snowpark"

Expand Down
Loading