From 05ad6ef465177d60248471d8597a7e8d3897a9e7 Mon Sep 17 00:00:00 2001 From: eakmanrq <6326532+eakmanrq@users.noreply.github.com> Date: Sat, 9 Aug 2025 13:24:34 -0700 Subject: [PATCH 1/4] feat: add ignore destructive support --- .circleci/continue_config.yml | 8 +- docs/concepts/models/overview.md | 8 +- docs/guides/custom_materializations.md | 4 + docs/guides/incremental_time.md | 7 +- .../custom_materializations/custom_kind.py | 3 +- .../custom_materializations/full.py | 3 +- sqlmesh/core/dialect.py | 29 +- sqlmesh/core/engine_adapter/_typing.py | 2 +- sqlmesh/core/engine_adapter/athena.py | 2 + sqlmesh/core/engine_adapter/base.py | 220 ++++- sqlmesh/core/engine_adapter/base_postgres.py | 2 + sqlmesh/core/engine_adapter/bigquery.py | 47 +- sqlmesh/core/engine_adapter/clickhouse.py | 18 +- sqlmesh/core/engine_adapter/databricks.py | 13 +- sqlmesh/core/engine_adapter/duckdb.py | 3 +- sqlmesh/core/engine_adapter/mixins.py | 20 +- sqlmesh/core/engine_adapter/mssql.py | 23 +- sqlmesh/core/engine_adapter/postgres.py | 2 + sqlmesh/core/engine_adapter/redshift.py | 12 +- sqlmesh/core/engine_adapter/snowflake.py | 50 +- sqlmesh/core/engine_adapter/spark.py | 82 +- sqlmesh/core/engine_adapter/trino.py | 11 +- sqlmesh/core/model/kind.py | 5 + sqlmesh/core/plan/builder.py | 1 + sqlmesh/core/schema_diff.py | 83 +- sqlmesh/core/snapshot/evaluator.py | 251 +++-- sqlmesh/core/table_diff.py | 2 +- sqlmesh/utils/__init__.py | 7 + .../engine_adapter/integration/conftest.py | 2 +- .../integration/test_integration.py | 761 +++++++++++++- tests/core/engine_adapter/test_base.py | 539 +++++++++- tests/core/engine_adapter/test_bigquery.py | 6 +- tests/core/test_integration.py | 929 +++++++++++++++++- tests/core/test_schema_diff.py | 235 +++++ tests/core/test_snapshot_evaluator.py | 26 +- 35 files changed, 3159 insertions(+), 257 deletions(-) diff --git a/.circleci/continue_config.yml b/.circleci/continue_config.yml index 04135574a9..e651a1e80b 100644 --- a/.circleci/continue_config.yml 
+++ b/.circleci/continue_config.yml @@ -305,10 +305,10 @@ workflows: - clickhouse-cloud - athena - gcp-postgres - filters: - branches: - only: - - main +# filters: +# branches: +# only: +# - main - ui_style - ui_test - vscode_test diff --git a/docs/concepts/models/overview.md b/docs/concepts/models/overview.md index dd9fd0d767..cf57678607 100644 --- a/docs/concepts/models/overview.md +++ b/docs/concepts/models/overview.md @@ -507,11 +507,15 @@ Some properties are only available in specific model kinds - see the [model conf : Set this to true to indicate that all changes to this model should be [forward-only](../plans.md#forward-only-plans). ### on_destructive_change -: What should happen when a change to a [forward-only model](../../guides/incremental_time.md#forward-only-models) or incremental model in a [forward-only plan](../plans.md#forward-only-plans) causes a destructive modification to the table schema (i.e., requires dropping an existing column). +: What should happen when a change to a [forward-only model](../../guides/incremental_time.md#forward-only-models) or incremental model in a [forward-only plan](../plans.md#forward-only-plans) causes a destructive modification to the table schema (i.e., requires dropping an existing column or modifying column constraints in ways that could cause data loss). SQLMesh checks for destructive changes at plan time based on the model definition and run time based on the model's underlying physical tables. - Must be one of the following values: `allow`, `warn`, or `error` (default). + Must be one of the following values: `allow`, `warn`, `error` (default), or `ignore`. + +!!! warning "Ignore is Dangerous" + + `ignore` is dangerous since it can result in error or data loss. It likely should never be used but could be useful as an "escape-hatch" or a way to workaround unexpected behavior. ### disable_restatement : Set this to true to indicate that [data restatement](../plans.md#restatement-plans) is disabled for this model. 
diff --git a/docs/guides/custom_materializations.md b/docs/guides/custom_materializations.md index b11d9004a9..58eb64026d 100644 --- a/docs/guides/custom_materializations.md +++ b/docs/guides/custom_materializations.md @@ -64,6 +64,7 @@ class CustomFullMaterialization(CustomMaterialization): query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: self.adapter.replace_query(table_name, query_or_df) @@ -78,6 +79,7 @@ Let's unpack this materialization: * `query_or_df` - a query (of SQLGlot expression type) or DataFrame (Pandas, PySpark, or Snowpark) instance to be inserted * `model` - the model definition object used to access model parameters and user-specified materialization arguments * `is_first_insert` - whether this is the first insert for the current version of the model (used with batched or multi-step inserts) + * `render_kwargs` - a dictionary of arguments used to render the model query * `kwargs` - additional and future arguments * The `self.adapter` instance is used to interact with the target engine. It comes with a set of useful high-level APIs like `replace_query`, `columns`, and `table_exists`, but also supports executing arbitrary SQL expressions with its `execute` method. 
@@ -150,6 +152,7 @@ class CustomFullMaterialization(CustomMaterialization): query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: config_value = model.custom_materialization_properties["config_key"] @@ -232,6 +235,7 @@ class CustomFullMaterialization(CustomMaterialization[MyCustomKind]): query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: assert isinstance(model.kind, MyCustomKind) diff --git a/docs/guides/incremental_time.md b/docs/guides/incremental_time.md index 7c773f7edc..2f54516ec4 100644 --- a/docs/guides/incremental_time.md +++ b/docs/guides/incremental_time.md @@ -171,7 +171,12 @@ The check is performed at plan time based on the model definition. SQLMesh may n A model's `on_destructive_change` [configuration setting](../reference/model_configuration.md#incremental-models) determines what happens when SQLMesh detects a destructive change. -By default, SQLMesh will error so no data is lost. You can set `on_destructive_change` to `warn` or `allow` in the model's `MODEL` block to allow destructive changes. +By default, SQLMesh will error so no data is lost. You can set `on_destructive_change` to `warn` or `allow` in the model's `MODEL` block to allow destructive changes. +`ignore` can be used to not perform the schema change and allow the table's definition to diverge from the model definition. + +!!! warning "Ignore is Dangerous" + + `ignore` is dangerous since it can result in error or data loss. It likely should never be used but could be useful as an "escape-hatch" or a way to workaround unexpected behavior. 
This example configures a model to silently `allow` destructive changes: diff --git a/examples/custom_materializations/custom_materializations/custom_kind.py b/examples/custom_materializations/custom_materializations/custom_kind.py index a8330febad..8a0eabcfa7 100644 --- a/examples/custom_materializations/custom_materializations/custom_kind.py +++ b/examples/custom_materializations/custom_materializations/custom_kind.py @@ -24,8 +24,9 @@ def insert( query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: assert type(model.kind).__name__ == "ExtendedCustomKind" - self._replace_query_for_model(model, table_name, query_or_df) + self._replace_query_for_model(model, table_name, query_or_df, render_kwargs) diff --git a/examples/custom_materializations/custom_materializations/full.py b/examples/custom_materializations/custom_materializations/full.py index 79aa50232a..d2a7c64993 100644 --- a/examples/custom_materializations/custom_materializations/full.py +++ b/examples/custom_materializations/custom_materializations/full.py @@ -17,6 +17,7 @@ def insert( query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: - self._replace_query_for_model(model, table_name, query_or_df) + self._replace_query_for_model(model, table_name, query_or_df, render_kwargs) diff --git a/sqlmesh/core/dialect.py b/sqlmesh/core/dialect.py index 568d9f5f73..33ec55b7a7 100644 --- a/sqlmesh/core/dialect.py +++ b/sqlmesh/core/dialect.py @@ -23,6 +23,7 @@ from sqlglot.tokens import Token from sqlmesh.core.constants import MAX_MODEL_DEFINITION_SIZE +from sqlmesh.utils import get_source_columns_to_types from sqlmesh.utils.errors import SQLMeshError, ConfigError from sqlmesh.utils.pandas import columns_to_types_from_df @@ -1134,31 +1135,43 @@ def select_from_values_for_batch_range( batch_start: int, batch_end: int, alias: str = "t", + source_columns: 
t.Optional[t.List[str]] = None, ) -> exp.Select: - casted_columns = [ - exp.alias_(exp.cast(exp.column(column), to=kind), column, copy=False) - for column, kind in columns_to_types.items() - ] + source_columns = source_columns or list(columns_to_types) + source_columns_to_types = get_source_columns_to_types(columns_to_types, source_columns) if not values: # Ensures we don't generate an empty VALUES clause & forces a zero-row output where = exp.false() - expressions = [tuple(exp.cast(exp.null(), to=kind) for kind in columns_to_types.values())] + expressions = [ + tuple(exp.cast(exp.null(), to=kind) for kind in source_columns_to_types.values()) + ] else: where = None expressions = [ - tuple(transform_values(v, columns_to_types)) for v in values[batch_start:batch_end] + tuple(transform_values(v, source_columns_to_types)) + for v in values[batch_start:batch_end] ] - values_exp = exp.values(expressions, alias=alias, columns=columns_to_types) + values_exp = exp.values(expressions, alias=alias, columns=source_columns_to_types) if values: # BigQuery crashes on `SELECT CAST(x AS TIMESTAMP) FROM UNNEST([NULL]) AS x`, but not # on `SELECT CAST(x AS TIMESTAMP) FROM UNNEST([CAST(NULL AS TIMESTAMP)]) AS x`. This # ensures nulls under the `Values` expression are cast to avoid similar issues. 
- for value, kind in zip(values_exp.expressions[0].expressions, columns_to_types.values()): + for value, kind in zip( + values_exp.expressions[0].expressions, source_columns_to_types.values() + ): if isinstance(value, exp.Null): value.replace(exp.cast(value, to=kind)) + casted_columns = [ + exp.alias_( + exp.cast(exp.column(column) if column in source_columns else exp.Null(), to=kind), + column, + copy=False, + ) + for column, kind in columns_to_types.items() + ] return exp.select(*casted_columns).from_(values_exp, copy=False).where(where, copy=False) diff --git a/sqlmesh/core/engine_adapter/_typing.py b/sqlmesh/core/engine_adapter/_typing.py index 143fcf6ab6..98821bb2d4 100644 --- a/sqlmesh/core/engine_adapter/_typing.py +++ b/sqlmesh/core/engine_adapter/_typing.py @@ -13,7 +13,7 @@ snowpark = optional_import("snowflake.snowpark") - Query = t.Union[exp.Query, exp.DerivedTable] + Query = exp.Query PySparkSession = t.Union[pyspark.sql.SparkSession, pyspark.sql.connect.dataframe.SparkSession] PySparkDataFrame = t.Union[pyspark.sql.DataFrame, pyspark.sql.connect.dataframe.DataFrame] diff --git a/sqlmesh/core/engine_adapter/athena.py b/sqlmesh/core/engine_adapter/athena.py index abaf7ba281..59642b6e16 100644 --- a/sqlmesh/core/engine_adapter/athena.py +++ b/sqlmesh/core/engine_adapter/athena.py @@ -434,6 +434,7 @@ def replace_query( columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, + source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: table = exp.to_table(table_name) @@ -447,6 +448,7 @@ def replace_query( columns_to_types=columns_to_types, table_description=table_description, column_descriptions=column_descriptions, + source_columns=source_columns, **kwargs, ) diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index 4651caa6ec..327a6fbee7 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ 
b/sqlmesh/core/engine_adapter/base.py @@ -40,7 +40,12 @@ ) from sqlmesh.core.model.kind import TimeColumn from sqlmesh.core.schema_diff import SchemaDiffer -from sqlmesh.utils import CorrelationId, columns_to_types_all_known, random_id +from sqlmesh.utils import ( + CorrelationId, + columns_to_types_all_known, + random_id, + get_source_columns_to_types, +) from sqlmesh.utils.connection_pool import ConnectionPool, create_connection_pool from sqlmesh.utils.date import TimeLike, make_inclusive, to_time_column from sqlmesh.utils.errors import ( @@ -199,9 +204,22 @@ def catalog_support(self) -> CatalogSupport: return CatalogSupport.UNSUPPORTED @classmethod - def _casted_columns(cls, columns_to_types: t.Dict[str, exp.DataType]) -> t.List[exp.Alias]: + def _casted_columns( + cls, + columns_to_types: t.Dict[str, exp.DataType], + source_columns: t.Optional[t.List[str]] = None, + ) -> t.List[exp.Alias]: + source_columns = source_columns or list(columns_to_types) return [ - exp.alias_(exp.cast(exp.column(column), to=kind), column, copy=False) + exp.alias_( + exp.cast( + exp.column(column, quoted=True) if column in source_columns else exp.Null(), + to=kind, + ), + column, + copy=False, + quoted=True, + ) for column, kind in columns_to_types.items() ] @@ -227,12 +245,31 @@ def _get_source_queries( target_table: TableName, *, batch_size: t.Optional[int] = None, + source_columns: t.Optional[t.List[str]] = None, ) -> t.List[SourceQuery]: import pandas as pd batch_size = self.DEFAULT_BATCH_SIZE if batch_size is None else batch_size - if isinstance(query_or_df, (exp.Query, exp.DerivedTable)): - return [SourceQuery(query_factory=lambda: query_or_df)] # type: ignore + if isinstance(query_or_df, exp.Query): + query_factory = lambda: query_or_df + if source_columns: + if not columns_to_types: + raise SQLMeshError("columns_to_types must be set if source_columns is set") + if not set(columns_to_types).issubset(set(source_columns)): + select_columns = [ + exp.column(c, quoted=True) + if c 
in source_columns + else exp.cast(exp.Null(), columns_to_types[c], copy=False).as_( + c, copy=False, quoted=True + ) + for c in columns_to_types + ] + query_factory = ( + lambda: exp.Select() + .select(*select_columns) + .from_(query_or_df.subquery("select_source_columns")) + ) + return [SourceQuery(query_factory=query_factory)] # type: ignore if not columns_to_types: raise SQLMeshError( @@ -247,7 +284,11 @@ def _get_source_queries( ) return self._df_to_source_queries( - query_or_df, columns_to_types, batch_size, target_table=target_table + query_or_df, + columns_to_types, + batch_size, + target_table=target_table, + source_columns=source_columns, ) def _df_to_source_queries( @@ -256,6 +297,7 @@ def _df_to_source_queries( columns_to_types: t.Dict[str, exp.DataType], batch_size: int, target_table: TableName, + source_columns: t.Optional[t.List[str]] = None, ) -> t.List[SourceQuery]: import pandas as pd @@ -265,7 +307,7 @@ def _df_to_source_queries( # we need to ensure that the order of the columns in columns_to_types columns matches the order of the values # they can differ if a user specifies columns() on a python model in a different order than what's in the DataFrame's emitted by that model - df = df[list(columns_to_types)] + df = df[list(source_columns or columns_to_types)] values = list(df.itertuples(index=False, name=None)) return [ @@ -276,6 +318,7 @@ def _df_to_source_queries( columns_to_types=columns_to_types, batch_start=i, batch_end=min(i + batch_size, num_rows), + source_columns=source_columns, ), ) for i in range(0, num_rows, batch_size) @@ -288,35 +331,49 @@ def _get_source_queries_and_columns_to_types( target_table: TableName, *, batch_size: t.Optional[int] = None, + source_columns: t.Optional[t.List[str]] = None, ) -> t.Tuple[t.List[SourceQuery], t.Optional[t.Dict[str, exp.DataType]]]: - columns_to_types = self._columns_to_types(query_or_df, columns_to_types) - return ( - self._get_source_queries( - query_or_df, columns_to_types, 
target_table=target_table, batch_size=batch_size - ), + columns_to_types, source_columns = self._columns_to_types( + query_or_df, columns_to_types, source_columns + ) + source_queries = self._get_source_queries( + query_or_df, columns_to_types, + target_table=target_table, + batch_size=batch_size, + source_columns=source_columns, ) + return source_queries, columns_to_types @t.overload def _columns_to_types( - self, query_or_df: DF, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None - ) -> t.Dict[str, exp.DataType]: ... + self, + query_or_df: DF, + columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, + ) -> t.Tuple[t.Dict[str, exp.DataType], t.List[str]]: ... @t.overload def _columns_to_types( - self, query_or_df: Query, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None - ) -> t.Optional[t.Dict[str, exp.DataType]]: ... + self, + query_or_df: Query, + columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, + ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: ... 
def _columns_to_types( - self, query_or_df: QueryOrDF, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None - ) -> t.Optional[t.Dict[str, exp.DataType]]: + self, + query_or_df: QueryOrDF, + columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, + ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: import pandas as pd - if columns_to_types: - return columns_to_types - if isinstance(query_or_df, pd.DataFrame): - return columns_to_types_from_df(t.cast(pd.DataFrame, query_or_df)) - return columns_to_types + if not columns_to_types and isinstance(query_or_df, pd.DataFrame): + columns_to_types = columns_to_types_from_df(t.cast(pd.DataFrame, query_or_df)) + if not source_columns and columns_to_types: + source_columns = list(columns_to_types) + return columns_to_types, source_columns def recycle(self) -> None: """Closes all open connections and releases all allocated resources associated with any thread @@ -356,6 +413,7 @@ def replace_query( columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, + source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: """Replaces an existing table with a query. 
@@ -377,10 +435,13 @@ def replace_query( table_exists = False source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - query_or_df, columns_to_types, target_table=target_table + query_or_df, + columns_to_types, + target_table=target_table, + source_columns=source_columns, ) - columns_to_types = columns_to_types or self.columns(target_table) query = source_queries[0].query_factory() + columns_to_types = columns_to_types or self.columns(target_table) self_referencing = any( quote_identifiers(table) == quote_identifiers(target_table) for table in query.find_all(exp.Table) @@ -531,6 +592,7 @@ def create_managed_table( table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, + source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: """Create a managed table using a query. @@ -558,6 +620,7 @@ def ctas( exists: bool = True, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, + source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: """Create a table using a CTAS statement @@ -572,7 +635,7 @@ def ctas( kwargs: Optional create table properties. """ source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - query_or_df, columns_to_types, target_table=table_name + query_or_df, columns_to_types, target_table=table_name, source_columns=source_columns ) return self._create_table_from_source_queries( table_name, @@ -1000,6 +1063,8 @@ def get_alter_expressions( self, current_table_name: TableName, target_table_name: TableName, + *, + ignore_destructive: bool = False, ) -> t.List[exp.Alter]: """ Determines the alter statements needed to change the current table into the structure of the target table. 
@@ -1008,6 +1073,7 @@ def get_alter_expressions( current_table_name, self.columns(current_table_name), self.columns(target_table_name), + ignore_destructive=ignore_destructive, ) def alter_table( @@ -1032,6 +1098,7 @@ def create_view( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, view_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + source_columns: t.Optional[t.List[str]] = None, **create_kwargs: t.Any, ) -> None: """Create a view with a query or dataframe. @@ -1062,18 +1129,25 @@ def create_view( values: t.List[t.Tuple[t.Any, ...]] = list( query_or_df.itertuples(index=False, name=None) ) - columns_to_types = columns_to_types or self._columns_to_types(query_or_df) + columns_to_types, source_columns = self._columns_to_types( + query_or_df, columns_to_types, source_columns + ) if not columns_to_types: raise SQLMeshError("columns_to_types must be provided for dataframes") + source_columns_to_types = get_source_columns_to_types(columns_to_types, source_columns) query_or_df = self._values_to_sql( values, - columns_to_types, + source_columns_to_types, batch_start=0, batch_end=len(values), ) source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - query_or_df, columns_to_types, batch_size=0, target_table=view_name + query_or_df, + columns_to_types, + batch_size=0, + target_table=view_name, + source_columns=source_columns, ) if len(source_queries) != 1: raise SQLMeshError("Only one source query is supported for creating views") @@ -1308,9 +1382,10 @@ def insert_append( table_name: TableName, query_or_df: QueryOrDF, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, ) -> None: source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - query_or_df, columns_to_types, target_table=table_name + query_or_df, columns_to_types, target_table=table_name, source_columns=source_columns ) 
self._insert_append_source_queries(table_name, source_queries, columns_to_types) @@ -1343,18 +1418,27 @@ def insert_overwrite_by_partition( query_or_df: QueryOrDF, partitioned_by: t.List[exp.Expression], columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, ) -> None: if self.INSERT_OVERWRITE_STRATEGY.is_insert_overwrite: target_table = exp.to_table(table_name) source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - query_or_df, columns_to_types, target_table=target_table + query_or_df, + columns_to_types, + target_table=target_table, + source_columns=source_columns, ) self._insert_overwrite_by_condition( table_name, source_queries, columns_to_types=columns_to_types ) else: self._replace_by_key( - table_name, query_or_df, columns_to_types, partitioned_by, is_unique_key=False + table_name, + query_or_df, + columns_to_types, + partitioned_by, + is_unique_key=False, + source_columns=source_columns, ) def insert_overwrite_by_time_partition( @@ -1368,10 +1452,11 @@ def insert_overwrite_by_time_partition( ], time_column: TimeColumn | exp.Expression | str, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - query_or_df, columns_to_types, target_table=table_name + query_or_df, columns_to_types, target_table=table_name, source_columns=source_columns ) if not columns_to_types or not columns_to_types_all_known(columns_to_types): columns_to_types = self.columns(table_name) @@ -1408,6 +1493,7 @@ def _values_to_sql( batch_start: int, batch_end: int, alias: str = "t", + source_columns: t.Optional[t.List[str]] = None, ) -> Query: return select_from_values_for_batch_range( values=values, @@ -1415,6 +1501,7 @@ def _values_to_sql( batch_start=batch_start, batch_end=batch_end, alias=alias, + source_columns=source_columns, ) def 
_insert_overwrite_by_condition( @@ -1497,6 +1584,7 @@ def scd_type_2_by_time( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, truncate: bool = False, + source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: self._scd_type_2( @@ -1513,6 +1601,7 @@ def scd_type_2_by_time( table_description=table_description, column_descriptions=column_descriptions, truncate=truncate, + source_columns=source_columns, **kwargs, ) @@ -1531,6 +1620,7 @@ def scd_type_2_by_column( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, truncate: bool = False, + source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: self._scd_type_2( @@ -1547,6 +1637,7 @@ def scd_type_2_by_column( table_description=table_description, column_descriptions=column_descriptions, truncate=truncate, + source_columns=source_columns, **kwargs, ) @@ -1567,6 +1658,7 @@ def _scd_type_2( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, truncate: bool = False, + source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: def remove_managed_columns( @@ -1578,20 +1670,24 @@ def remove_managed_columns( valid_from_name = valid_from_col.name valid_to_name = valid_to_col.name - unmanaged_columns_to_types = ( - remove_managed_columns(columns_to_types) if columns_to_types else None - ) - source_queries, unmanaged_columns_to_types = self._get_source_queries_and_columns_to_types( - source_table, unmanaged_columns_to_types, target_table=target_table, batch_size=0 - ) columns_to_types = columns_to_types or self.columns(target_table) - updated_at_name = updated_at_col.name if updated_at_col else None if ( valid_from_name not in columns_to_types or valid_to_name not in columns_to_types or not columns_to_types_all_known(columns_to_types) ): columns_to_types = self.columns(target_table) + unmanaged_columns_to_types = ( + 
remove_managed_columns(columns_to_types) if columns_to_types else None + ) + source_queries, unmanaged_columns_to_types = self._get_source_queries_and_columns_to_types( + source_table, + unmanaged_columns_to_types, + target_table=target_table, + batch_size=0, + source_columns=source_columns, + ) + updated_at_name = updated_at_col.name if updated_at_col else None if not columns_to_types: raise SQLMeshError(f"Could not get columns_to_types. Does {target_table} exist?") unmanaged_columns_to_types = unmanaged_columns_to_types or remove_managed_columns( @@ -1943,10 +2039,11 @@ def merge( unique_key: t.Sequence[exp.Expression], when_matched: t.Optional[exp.Whens] = None, merge_filter: t.Optional[exp.Expression] = None, + source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - source_table, columns_to_types, target_table=target_table + source_table, columns_to_types, target_table=target_table, source_columns=source_columns ) columns_to_types = columns_to_types or self.columns(target_table) on = exp.and_( @@ -2093,7 +2190,7 @@ def _native_df_to_pandas_df( """ import pandas as pd - if isinstance(query_or_df, (exp.Query, exp.DerivedTable, pd.DataFrame)): + if isinstance(query_or_df, (exp.Query, pd.DataFrame)): return query_or_df # EngineAdapter subclasses that have native DataFrame types should override this @@ -2269,6 +2366,7 @@ def temp_table( query_or_df: QueryOrDF, name: TableName = "diff", columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> t.Iterator[exp.Table]: """A context manager for working a temp table. 
@@ -2289,7 +2387,10 @@ def temp_table( name.set("catalog", exp.parse_identifier(self.default_catalog)) source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - query_or_df, columns_to_types=columns_to_types, target_table=name + query_or_df, + columns_to_types=columns_to_types, + target_table=name, + source_columns=source_columns, ) with self.transaction(): @@ -2515,6 +2616,7 @@ def _replace_by_key( columns_to_types: t.Optional[t.Dict[str, exp.DataType]], key: t.Sequence[exp.Expression], is_unique_key: bool, + source_columns: t.Optional[t.List[str]] = None, ) -> None: if columns_to_types is None: columns_to_types = self.columns(target_table) @@ -2524,7 +2626,13 @@ def _replace_by_key( column_names = list(columns_to_types or []) with self.transaction(): - self.ctas(temp_table, source_table, columns_to_types=columns_to_types, exists=False) + self.ctas( + temp_table, + source_table, + columns_to_types=columns_to_types, + exists=False, + source_columns=source_columns, + ) try: delete_query = exp.select(key_exp).from_(temp_table) @@ -2626,8 +2734,17 @@ def ping(self) -> None: self._connection_pool.close_cursor() @classmethod - def _select_columns(cls, columns: t.Iterable[str]) -> exp.Select: - return exp.select(*(exp.column(c, quoted=True) for c in columns)) + def _select_columns( + cls, columns: t.Iterable[str], source_columns: t.Optional[t.List[str]] = None + ) -> exp.Select: + return exp.select( + *( + exp.column(c, quoted=True) + if c in (source_columns or columns) + else exp.alias_(exp.Null(), c, quoted=True) + for c in columns + ) + ) def _check_identifier_length(self, expression: exp.Expression) -> None: if self.MAX_IDENTIFIER_LENGTH is None or not isinstance(expression, exp.DDL): @@ -2641,6 +2758,17 @@ def _check_identifier_length(self, expression: exp.Expression) -> None: f"Identifier name '{name}' (length {name_length}) exceeds {self.dialect.capitalize()}'s max identifier limit of {self.MAX_IDENTIFIER_LENGTH} characters" ) + 
@classmethod + def get_source_columns_to_types( + cls, + columns_to_types: t.Dict[str, exp.DataType], + source_columns: t.Optional[t.List[str]], + ) -> t.Dict[str, exp.DataType]: + """Returns the source columns to types mapping.""" + return { + k: v for k, v in columns_to_types.items() if not source_columns or k in source_columns + } + class EngineAdapterWithIndexSupport(EngineAdapter): SUPPORTS_INDEXES = True diff --git a/sqlmesh/core/engine_adapter/base_postgres.py b/sqlmesh/core/engine_adapter/base_postgres.py index c1026c91df..10083dcb91 100644 --- a/sqlmesh/core/engine_adapter/base_postgres.py +++ b/sqlmesh/core/engine_adapter/base_postgres.py @@ -98,6 +98,7 @@ def create_view( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, view_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + source_columns: t.Optional[t.List[str]] = None, **create_kwargs: t.Any, ) -> None: """ @@ -120,6 +121,7 @@ def create_view( table_description=table_description, column_descriptions=column_descriptions, view_properties=view_properties, + source_columns=source_columns, **create_kwargs, ) diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py index 2de3f18c99..fb88a707fd 100644 --- a/sqlmesh/core/engine_adapter/bigquery.py +++ b/sqlmesh/core/engine_adapter/bigquery.py @@ -22,7 +22,7 @@ ) from sqlmesh.core.node import IntervalUnit from sqlmesh.core.schema_diff import SchemaDiffer -from sqlmesh.utils import optional_import +from sqlmesh.utils import optional_import, get_source_columns_to_types from sqlmesh.utils.date import to_datetime from sqlmesh.utils.errors import SQLMeshError from sqlmesh.utils.pandas import columns_to_types_from_dtypes @@ -151,11 +151,14 @@ def _df_to_source_queries( columns_to_types: t.Dict[str, exp.DataType], batch_size: int, target_table: TableName, + source_columns: t.Optional[t.List[str]] = None, ) -> t.List[SourceQuery]: import pandas as pd + 
source_columns_to_types = get_source_columns_to_types(columns_to_types, source_columns) + temp_bq_table = self.__get_temp_bq_table( - self._get_temp_table(target_table or "pandas"), columns_to_types + self._get_temp_table(target_table or "pandas"), source_columns_to_types ) temp_table = exp.table_( temp_bq_table.table_id, @@ -174,11 +177,13 @@ def query_factory() -> Query: assert isinstance(df, pd.DataFrame) self._db_call(self.client.create_table, table=temp_bq_table, exists_ok=False) result = self.__load_pandas_to_table( - temp_bq_table, df, columns_to_types, replace=False + temp_bq_table, df, source_columns_to_types, replace=False ) if result.errors: raise SQLMeshError(result.errors) - return self._select_columns(columns_to_types).from_(temp_table) + return exp.select( + *self._casted_columns(columns_to_types, source_columns=source_columns) + ).from_(temp_table) return [ SourceQuery( @@ -674,6 +679,7 @@ def insert_overwrite_by_partition( query_or_df: QueryOrDF, partitioned_by: t.List[exp.Expression], columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, ) -> None: if len(partitioned_by) != 1: raise SQLMeshError( @@ -695,7 +701,10 @@ def insert_overwrite_by_partition( with ( self.session({}), self.temp_table( - query_or_df, name=table_name, partitioned_by=partitioned_by + query_or_df, + name=table_name, + partitioned_by=partitioned_by, + source_columns=source_columns, ) as temp_table_name, ): if columns_to_types is None or columns_to_types[ @@ -1204,17 +1213,26 @@ def _normalize_nested_value(self, col: exp.Expression) -> exp.Expression: @t.overload def _columns_to_types( - self, query_or_df: DF, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None - ) -> t.Dict[str, exp.DataType]: ... + self, + query_or_df: DF, + columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, + ) -> t.Tuple[t.Dict[str, exp.DataType], t.List[str]]: ... 
@t.overload def _columns_to_types( - self, query_or_df: Query, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None - ) -> t.Optional[t.Dict[str, exp.DataType]]: ... + self, + query_or_df: Query, + columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, + ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: ... def _columns_to_types( - self, query_or_df: QueryOrDF, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None - ) -> t.Optional[t.Dict[str, exp.DataType]]: + self, + query_or_df: QueryOrDF, + columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, + ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: if ( not columns_to_types and bigframes @@ -1222,9 +1240,12 @@ def _columns_to_types( ): # using dry_run=True attempts to prevent the DataFrame from being materialized just to read the column types from it dtypes = query_or_df.to_pandas(dry_run=True).columnDtypes - return columns_to_types_from_dtypes(dtypes.items()) + columns_to_types = columns_to_types_from_dtypes(dtypes.items()) + return columns_to_types, list(source_columns or columns_to_types) - return super()._columns_to_types(query_or_df, columns_to_types) + return super()._columns_to_types( + query_or_df, columns_to_types, source_columns=source_columns + ) def _native_df_to_pandas_df( self, diff --git a/sqlmesh/core/engine_adapter/clickhouse.py b/sqlmesh/core/engine_adapter/clickhouse.py index fb515b7291..26be50aba0 100644 --- a/sqlmesh/core/engine_adapter/clickhouse.py +++ b/sqlmesh/core/engine_adapter/clickhouse.py @@ -16,6 +16,7 @@ InsertOverwriteStrategy, ) from sqlmesh.core.schema_diff import SchemaDiffer +from sqlmesh.utils import get_source_columns_to_types if t.TYPE_CHECKING: import pandas as pd @@ -92,9 +93,11 @@ def _df_to_source_queries( columns_to_types: t.Dict[str, exp.DataType], batch_size: int, target_table: 
TableName, + source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> t.List[SourceQuery]: temp_table = self._get_temp_table(target_table, **kwargs) + source_columns_to_types = get_source_columns_to_types(columns_to_types, source_columns) def query_factory() -> Query: # It is possible for the factory to be called multiple times and if so then the temp table will already @@ -102,12 +105,17 @@ def query_factory() -> Query: # as later calls. if not self.table_exists(temp_table): self.create_table( - temp_table, columns_to_types, storage_format=exp.var("MergeTree"), **kwargs + temp_table, + source_columns_to_types, + storage_format=exp.var("MergeTree"), + **kwargs, ) self.cursor.client.insert_df(temp_table.sql(dialect=self.dialect), df=df) - return exp.select(*self._casted_columns(columns_to_types)).from_(temp_table) + return exp.select(*self._casted_columns(columns_to_types, source_columns)).from_( + temp_table + ) return [ SourceQuery( @@ -403,9 +411,10 @@ def _replace_by_key( columns_to_types: t.Optional[t.Dict[str, exp.DataType]], key: t.Sequence[exp.Expression], is_unique_key: bool, + source_columns: t.Optional[t.List[str]] = None, ) -> None: source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - source_table, columns_to_types, target_table=target_table + source_table, columns_to_types, target_table=target_table, source_columns=source_columns ) key_exp = exp.func("CONCAT_WS", "'__SQLMESH_DELIM__'", *key) if len(key) > 1 else key[0] @@ -425,9 +434,10 @@ def insert_overwrite_by_partition( query_or_df: QueryOrDF, partitioned_by: t.List[exp.Expression], columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, ) -> None: source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - query_or_df, columns_to_types, target_table=table_name + query_or_df, columns_to_types, target_table=table_name, source_columns=source_columns ) 
self._insert_overwrite_by_condition( diff --git a/sqlmesh/core/engine_adapter/databricks.py b/sqlmesh/core/engine_adapter/databricks.py index 9d35726a32..dba7e58834 100644 --- a/sqlmesh/core/engine_adapter/databricks.py +++ b/sqlmesh/core/engine_adapter/databricks.py @@ -164,22 +164,21 @@ def _df_to_source_queries( columns_to_types: t.Dict[str, exp.DataType], batch_size: int, target_table: TableName, + source_columns: t.Optional[t.List[str]] = None, ) -> t.List[SourceQuery]: if not self._use_spark_session: return super(SparkEngineAdapter, self)._df_to_source_queries( - df, columns_to_types, batch_size, target_table + df, columns_to_types, batch_size, target_table, source_columns=source_columns ) - df = self._ensure_pyspark_df(df, columns_to_types) + pyspark_df = self._ensure_pyspark_df(df, columns_to_types, source_columns=source_columns) def query_factory() -> Query: temp_table = self._get_temp_table(target_table or "spark", table_only=True) - df.createOrReplaceTempView(temp_table.sql(dialect=self.dialect)) + pyspark_df.createOrReplaceTempView(temp_table.sql(dialect=self.dialect)) self._connection_pool.set_attribute("use_spark_engine_adapter", True) - return exp.select(*self._casted_columns(columns_to_types)).from_(temp_table) + return exp.select(*self._select_columns(columns_to_types)).from_(temp_table) - if self._use_spark_session: - return [SourceQuery(query_factory=query_factory)] - return super()._df_to_source_queries(df, columns_to_types, batch_size, target_table) + return [SourceQuery(query_factory=query_factory)] def _fetch_native_df( self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False diff --git a/sqlmesh/core/engine_adapter/duckdb.py b/sqlmesh/core/engine_adapter/duckdb.py index 00be5f426a..49231fcf87 100644 --- a/sqlmesh/core/engine_adapter/duckdb.py +++ b/sqlmesh/core/engine_adapter/duckdb.py @@ -64,10 +64,11 @@ def _df_to_source_queries( columns_to_types: t.Dict[str, exp.DataType], batch_size: int, target_table: TableName, + 
source_columns: t.Optional[t.List[str]] = None, ) -> t.List[SourceQuery]: temp_table = self._get_temp_table(target_table) temp_table_sql = ( - exp.select(*self._casted_columns(columns_to_types)) + exp.select(*self._casted_columns(columns_to_types, source_columns)) .from_("df") .sql(dialect=self.dialect) ) diff --git a/sqlmesh/core/engine_adapter/mixins.py b/sqlmesh/core/engine_adapter/mixins.py index 5ca1f200d9..81c092c517 100644 --- a/sqlmesh/core/engine_adapter/mixins.py +++ b/sqlmesh/core/engine_adapter/mixins.py @@ -32,6 +32,7 @@ def merge( unique_key: t.Sequence[exp.Expression], when_matched: t.Optional[exp.Whens] = None, merge_filter: t.Optional[exp.Expression] = None, + source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: logical_merge( @@ -42,6 +43,7 @@ def merge( unique_key, when_matched=when_matched, merge_filter=merge_filter, + source_columns=source_columns, ) @@ -357,9 +359,15 @@ def _parse_clustering_key(self, clustering_key: t.Optional[str]) -> t.List[exp.E return parsed_cluster_key.expressions or [parsed_cluster_key.this] def get_alter_expressions( - self, current_table_name: TableName, target_table_name: TableName + self, + current_table_name: TableName, + target_table_name: TableName, + *, + ignore_destructive: bool = False, ) -> t.List[exp.Alter]: - expressions = super().get_alter_expressions(current_table_name, target_table_name) + expressions = super().get_alter_expressions( + current_table_name, target_table_name, ignore_destructive=ignore_destructive + ) # check for a change in clustering current_table = exp.to_table(current_table_name) @@ -416,6 +424,7 @@ def logical_merge( unique_key: t.Sequence[exp.Expression], when_matched: t.Optional[exp.Whens] = None, merge_filter: t.Optional[exp.Expression] = None, + source_columns: t.Optional[t.List[str]] = None, ) -> None: """ Merge implementation for engine adapters that do not support merge natively. 
@@ -434,7 +443,12 @@ def logical_merge( ) engine_adapter._replace_by_key( - target_table, source_table, columns_to_types, unique_key, is_unique_key=True + target_table, + source_table, + columns_to_types, + unique_key, + is_unique_key=True, + source_columns=source_columns, ) diff --git a/sqlmesh/core/engine_adapter/mssql.py b/sqlmesh/core/engine_adapter/mssql.py index 112193073d..0edfb4f1f2 100644 --- a/sqlmesh/core/engine_adapter/mssql.py +++ b/sqlmesh/core/engine_adapter/mssql.py @@ -31,6 +31,7 @@ set_catalog, ) from sqlmesh.core.schema_diff import SchemaDiffer +from sqlmesh.utils import get_source_columns_to_types if t.TYPE_CHECKING: from sqlmesh.core._typing import SchemaName, TableName @@ -198,12 +199,13 @@ def merge( unique_key: t.Sequence[exp.Expression], when_matched: t.Optional[exp.Whens] = None, merge_filter: t.Optional[exp.Expression] = None, + source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: mssql_merge_exists = kwargs.get("physical_properties", {}).get("mssql_merge_exists") source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - source_table, columns_to_types, target_table=target_table + source_table, columns_to_types, target_table=target_table, source_columns=source_columns ) columns_to_types = columns_to_types or self.columns(target_table) on = exp.and_( @@ -306,6 +308,7 @@ def _df_to_source_queries( columns_to_types: t.Dict[str, exp.DataType], batch_size: int, target_table: TableName, + source_columns: t.Optional[t.List[str]] = None, ) -> t.List[SourceQuery]: import pandas as pd import numpy as np @@ -315,25 +318,31 @@ def _df_to_source_queries( # Return the superclass implementation if the connection pool doesn't support bulk_copy if not hasattr(self._connection_pool.get(), "bulk_copy"): - return super()._df_to_source_queries(df, columns_to_types, batch_size, target_table) + return super()._df_to_source_queries( + df, columns_to_types, batch_size, target_table, source_columns=source_columns + ) 
def query_factory() -> Query: # It is possible for the factory to be called multiple times and if so then the temp table will already # be created so we skip creating again. This means we are assuming the first call is the same result # as later calls. if not self.table_exists(temp_table): - columns_to_types_create = columns_to_types.copy() + source_columns_to_types = get_source_columns_to_types( + columns_to_types, source_columns + ) ordered_df = df[ - list(columns_to_types_create) + list(source_columns_to_types) ] # reorder DataFrame so it matches columns_to_types - self._convert_df_datetime(ordered_df, columns_to_types_create) - self.create_table(temp_table, columns_to_types_create) + self._convert_df_datetime(ordered_df, source_columns_to_types) + self.create_table(temp_table, source_columns_to_types) rows: t.List[t.Tuple[t.Any, ...]] = list( ordered_df.replace({np.nan: None}).itertuples(index=False, name=None) # type: ignore ) conn = self._connection_pool.get() conn.bulk_copy(temp_table.sql(dialect=self.dialect), rows) - return exp.select(*self._casted_columns(columns_to_types)).from_(temp_table) # type: ignore + return exp.select( + *self._casted_columns(columns_to_types, source_columns=source_columns) + ).from_(temp_table) # type: ignore return [ SourceQuery( diff --git a/sqlmesh/core/engine_adapter/postgres.py b/sqlmesh/core/engine_adapter/postgres.py index a736f5553b..dd58b4949b 100644 --- a/sqlmesh/core/engine_adapter/postgres.py +++ b/sqlmesh/core/engine_adapter/postgres.py @@ -110,6 +110,7 @@ def merge( unique_key: t.Sequence[exp.Expression], when_matched: t.Optional[exp.Whens] = None, merge_filter: t.Optional[exp.Expression] = None, + source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: # Merge isn't supported until Postgres 15 @@ -122,6 +123,7 @@ def merge( unique_key, when_matched=when_matched, merge_filter=merge_filter, + source_columns=source_columns, ) @cached_property diff --git a/sqlmesh/core/engine_adapter/redshift.py 
b/sqlmesh/core/engine_adapter/redshift.py index 829cdf3686..7b6b477d60 100644 --- a/sqlmesh/core/engine_adapter/redshift.py +++ b/sqlmesh/core/engine_adapter/redshift.py @@ -214,6 +214,7 @@ def create_view( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, view_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + source_columns: t.Optional[t.List[str]] = None, **create_kwargs: t.Any, ) -> None: """ @@ -240,6 +241,7 @@ def create_view( column_descriptions=column_descriptions, no_schema_binding=no_schema_binding, view_properties=view_properties, + source_columns=source_columns, **create_kwargs, ) @@ -250,6 +252,7 @@ def replace_query( columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, + source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: """ @@ -274,10 +277,14 @@ def replace_query( columns_to_types, table_description, column_descriptions, + source_columns=source_columns, **kwargs, ) source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - query_or_df, columns_to_types, target_table=table_name + query_or_df, + columns_to_types, + target_table=table_name, + source_columns=source_columns, ) columns_to_types = columns_to_types or self.columns(table_name) target_table = exp.to_table(table_name) @@ -358,6 +365,7 @@ def merge( unique_key: t.Sequence[exp.Expression], when_matched: t.Optional[exp.Whens] = None, merge_filter: t.Optional[exp.Expression] = None, + source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: if self.enable_merge: @@ -369,6 +377,7 @@ def merge( unique_key=unique_key, when_matched=when_matched, merge_filter=merge_filter, + source_columns=source_columns, ) else: logical_merge( @@ -379,6 +388,7 @@ def merge( unique_key, when_matched=when_matched, merge_filter=merge_filter, + source_columns=source_columns, ) def 
_merge( diff --git a/sqlmesh/core/engine_adapter/snowflake.py b/sqlmesh/core/engine_adapter/snowflake.py index 71ffc10f48..7bf6f3d303 100644 --- a/sqlmesh/core/engine_adapter/snowflake.py +++ b/sqlmesh/core/engine_adapter/snowflake.py @@ -24,7 +24,7 @@ set_catalog, ) from sqlmesh.core.schema_diff import SchemaDiffer -from sqlmesh.utils import optional_import +from sqlmesh.utils import optional_import, get_source_columns_to_types from sqlmesh.utils.errors import SQLMeshError from sqlmesh.utils.pandas import columns_to_types_from_dtypes @@ -198,6 +198,7 @@ def create_managed_table( table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, + source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: target_table = exp.to_table(table_name) @@ -218,7 +219,7 @@ def create_managed_table( ) source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - query, columns_to_types, target_table=target_table + query, columns_to_types, target_table=target_table, source_columns=source_columns ) self._create_table_from_source_queries( @@ -246,6 +247,7 @@ def create_view( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, view_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + source_columns: t.Optional[t.List[str]] = None, **create_kwargs: t.Any, ) -> None: properties = create_kwargs.pop("properties", None) @@ -265,6 +267,7 @@ def create_view( column_descriptions=column_descriptions, view_properties=view_properties, properties=properties, + source_columns=source_columns, **create_kwargs, ) @@ -324,10 +327,13 @@ def _df_to_source_queries( columns_to_types: t.Dict[str, exp.DataType], batch_size: int, target_table: TableName, + source_columns: t.Optional[t.List[str]] = None, ) -> t.List[SourceQuery]: import pandas as pd from pandas.api.types import is_datetime64_any_dtype + 
source_columns_to_types = get_source_columns_to_types(columns_to_types, source_columns) + temp_table = self._get_temp_table( target_table or "pandas", quoted=False ) # write_pandas() re-quotes everything without checking if its already quoted @@ -358,7 +364,7 @@ def query_factory() -> Query: local_df = df.rename( { col: exp.to_identifier(col).sql(dialect=self.dialect, identify=True) - for col in columns_to_types + for col in source_columns_to_types } ) # type: ignore local_df.createOrReplaceTempView( @@ -376,7 +382,7 @@ def query_factory() -> Query: self.set_current_schema(schema) # See: https://stackoverflow.com/a/75627721 - for column, kind in columns_to_types.items(): + for column, kind in source_columns_to_types.items(): if is_datetime64_any_dtype(df.dtypes[column]): if kind.is_type("date"): # type: ignore df[column] = pd.to_datetime(df[column]).dt.date # type: ignore @@ -392,7 +398,7 @@ def query_factory() -> Query: # create the table first using our usual method ensure the column datatypes match what we parsed with sqlglot # otherwise we would be trusting `write_pandas()` from the snowflake lib to do this correctly - self.create_table(temp_table, columns_to_types, table_kind="TEMPORARY TABLE") + self.create_table(temp_table, source_columns_to_types, table_kind="TEMPORARY TABLE") write_pandas( self._connection_pool.get(), @@ -409,7 +415,9 @@ def query_factory() -> Query: f"Unknown dataframe type: {type(df)} for {target_table}. Expecting pandas or snowpark." ) - return exp.select(*self._casted_columns(columns_to_types)).from_(temp_table) + return exp.select( + *self._casted_columns(columns_to_types, source_columns=source_columns) + ).from_(temp_table) def cleanup() -> None: if is_snowpark_dataframe: @@ -616,21 +624,35 @@ def clone_table( @t.overload def _columns_to_types( - self, query_or_df: DF, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None - ) -> t.Dict[str, exp.DataType]: ... 
+ self, + query_or_df: DF, + columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, + ) -> t.Tuple[t.Dict[str, exp.DataType], t.List[str]]: ... @t.overload def _columns_to_types( - self, query_or_df: Query, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None - ) -> t.Optional[t.Dict[str, exp.DataType]]: ... + self, + query_or_df: Query, + columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, + ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: ... def _columns_to_types( - self, query_or_df: QueryOrDF, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None - ) -> t.Optional[t.Dict[str, exp.DataType]]: + self, + query_or_df: QueryOrDF, + columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, + ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: if not columns_to_types and snowpark and isinstance(query_or_df, snowpark.DataFrame): - return columns_to_types_from_dtypes(query_or_df.sample(n=1).to_pandas().dtypes.items()) + columns_to_types = columns_to_types_from_dtypes( + query_or_df.sample(n=1).to_pandas().dtypes.items() + ) + return columns_to_types, list(source_columns or columns_to_types) - return super()._columns_to_types(query_or_df, columns_to_types) + return super()._columns_to_types( + query_or_df, columns_to_types, source_columns=source_columns + ) def close(self) -> t.Any: if snowpark_session := self._connection_pool.get_attribute(self.SNOWPARK): diff --git a/sqlmesh/core/engine_adapter/spark.py b/sqlmesh/core/engine_adapter/spark.py index 799d46a9c5..f015c0f158 100644 --- a/sqlmesh/core/engine_adapter/spark.py +++ b/sqlmesh/core/engine_adapter/spark.py @@ -240,24 +240,36 @@ def try_get_pandas_df(cls, value: t.Any) -> t.Optional[pd.DataFrame]: @t.overload def _columns_to_types( - self, query_or_df: DF, 
columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None - ) -> t.Dict[str, exp.DataType]: ... + self, + query_or_df: DF, + columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, + ) -> t.Tuple[t.Dict[str, exp.DataType], t.List[str]]: ... @t.overload def _columns_to_types( - self, query_or_df: Query, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None - ) -> t.Optional[t.Dict[str, exp.DataType]]: ... + self, + query_or_df: Query, + columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, + ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: ... def _columns_to_types( - self, query_or_df: QueryOrDF, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None - ) -> t.Optional[t.Dict[str, exp.DataType]]: + self, + query_or_df: QueryOrDF, + columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, + ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: if columns_to_types: - return columns_to_types + return columns_to_types, list(source_columns or columns_to_types) if self.is_pyspark_df(query_or_df): from pyspark.sql import DataFrame - return self.spark_to_sqlglot_types(t.cast(DataFrame, query_or_df).schema) - return super()._columns_to_types(query_or_df, columns_to_types) + columns_to_types = self.spark_to_sqlglot_types(t.cast(DataFrame, query_or_df).schema) + return columns_to_types, list(source_columns or columns_to_types) + return super()._columns_to_types( + query_or_df, columns_to_types, source_columns=source_columns + ) def _df_to_source_queries( self, @@ -265,36 +277,54 @@ def _df_to_source_queries( columns_to_types: t.Dict[str, exp.DataType], batch_size: int, target_table: TableName, + source_columns: t.Optional[t.List[str]] = None, ) -> t.List[SourceQuery]: - df = self._ensure_pyspark_df(df, columns_to_types) + df = 
self._ensure_pyspark_df(df, columns_to_types, source_columns=source_columns) def query_factory() -> Query: temp_table = self._get_temp_table(target_table or "spark", table_only=True) df.createOrReplaceGlobalTempView(temp_table.sql(dialect=self.dialect)) # type: ignore temp_table.set("db", "global_temp") - return exp.select(*self._casted_columns(columns_to_types)).from_(temp_table) + return exp.select(*self._select_columns(columns_to_types)).from_(temp_table) return [SourceQuery(query_factory=query_factory)] def _ensure_pyspark_df( - self, generic_df: DF, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None + self, + generic_df: DF, + columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, ) -> PySparkDataFrame: - pyspark_df = self.try_get_pyspark_df(generic_df) - if pyspark_df: + def _get_pyspark_df() -> PySparkDataFrame: + pyspark_df = self.try_get_pyspark_df(generic_df) + if pyspark_df: + return pyspark_df + df = self.try_get_pandas_df(generic_df) + if df is None: + raise SQLMeshError( + "Ensure PySpark DF can only be run on a PySpark or Pandas DataFrame" + ) + if columns_to_types: - # ensure Spark dataframe column order matches columns_to_types - pyspark_df = pyspark_df.select(*columns_to_types) - return pyspark_df - df = self.try_get_pandas_df(generic_df) - if df is None: - raise SQLMeshError("Ensure PySpark DF can only be run on a PySpark or Pandas DataFrame") + source_columns_to_types = self.get_source_columns_to_types( + columns_to_types, source_columns + ) + # ensure Pandas dataframe column order matches columns_to_types + df = df[list(source_columns_to_types)] + else: + source_columns_to_types = None + kwargs = ( + dict(schema=self.sqlglot_to_spark_types(source_columns_to_types)) + if source_columns_to_types + else {} + ) + return self.spark.createDataFrame(df, **kwargs) # type: ignore + + df_result = _get_pyspark_df() if columns_to_types: - # ensure Pandas dataframe column order matches 
columns_to_types - df = df[list(columns_to_types)] - kwargs = ( - dict(schema=self.sqlglot_to_spark_types(columns_to_types)) if columns_to_types else {} - ) - return self.spark.createDataFrame(df, **kwargs) # type: ignore + select_columns = self._casted_columns(columns_to_types, source_columns=source_columns) + df_result = df_result.selectExpr(*[x.sql(self.dialect) for x in select_columns]) # type: ignore + return df_result def _get_temp_table( self, table: TableName, table_only: bool = False, quoted: bool = True diff --git a/sqlmesh/core/engine_adapter/trino.py b/sqlmesh/core/engine_adapter/trino.py index df8e45b520..592fe41109 100644 --- a/sqlmesh/core/engine_adapter/trino.py +++ b/sqlmesh/core/engine_adapter/trino.py @@ -27,6 +27,7 @@ set_catalog, ) from sqlmesh.core.schema_diff import SchemaDiffer +from sqlmesh.utils import get_source_columns_to_types from sqlmesh.utils.errors import SQLMeshError from sqlmesh.utils.date import TimeLike @@ -218,21 +219,25 @@ def _df_to_source_queries( columns_to_types: t.Dict[str, exp.DataType], batch_size: int, target_table: TableName, + source_columns: t.Optional[t.List[str]] = None, ) -> t.List[SourceQuery]: import pandas as pd from pandas.api.types import is_datetime64_any_dtype # type: ignore assert isinstance(df, pd.DataFrame) + source_columns_to_types = get_source_columns_to_types(columns_to_types, source_columns) # Trino does not accept timestamps in ISOFORMAT that include the "T". `execution_time` is stored in # Pandas with that format, so we convert the column to a string with the proper format and CAST to # timestamp in Trino. 
- for column, kind in (columns_to_types or {}).items(): + for column, kind in source_columns_to_types.items(): dtype = df.dtypes[column] if is_datetime64_any_dtype(dtype) and getattr(dtype, "tz", None) is not None: df[column] = pd.to_datetime(df[column]).map(lambda x: x.isoformat(" ")) - return super()._df_to_source_queries(df, columns_to_types, batch_size, target_table) + return super()._df_to_source_queries( + df, columns_to_types, batch_size, target_table, source_columns=source_columns + ) def _build_schema_exp( self, @@ -266,6 +271,7 @@ def _scd_type_2( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, truncate: bool = False, + source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: if columns_to_types and self.current_catalog_type == "delta_lake": @@ -287,6 +293,7 @@ def _scd_type_2( table_description, column_descriptions, truncate, + source_columns, **kwargs, ) diff --git a/sqlmesh/core/model/kind.py b/sqlmesh/core/model/kind.py index 185556fc8f..470ce92c20 100644 --- a/sqlmesh/core/model/kind.py +++ b/sqlmesh/core/model/kind.py @@ -189,6 +189,7 @@ class OnDestructiveChange(str, Enum): ERROR = "ERROR" WARN = "WARN" ALLOW = "ALLOW" + IGNORE = "IGNORE" @property def is_error(self) -> bool: @@ -202,6 +203,10 @@ def is_warn(self) -> bool: def is_allow(self) -> bool: return self == OnDestructiveChange.ALLOW + @property + def is_ignore(self) -> bool: + return self == OnDestructiveChange.IGNORE + def _on_destructive_change_validator( cls: t.Type, v: t.Union[OnDestructiveChange, str, exp.Identifier] diff --git a/sqlmesh/core/plan/builder.py b/sqlmesh/core/plan/builder.py index db2b43345a..a1adca56fb 100644 --- a/sqlmesh/core/plan/builder.py +++ b/sqlmesh/core/plan/builder.py @@ -558,6 +558,7 @@ def _check_destructive_changes(self, directly_modified: t.Set[SnapshotId]) -> No new.name, old_columns_to_types, new_columns_to_types, + ignore_destructive=new.model.on_destructive_change.is_ignore, ) if 
has_drop_alteration(schema_diff): diff --git a/sqlmesh/core/schema_diff.py b/sqlmesh/core/schema_diff.py index 70b4f72163..1bf2f76672 100644 --- a/sqlmesh/core/schema_diff.py +++ b/sqlmesh/core/schema_diff.py @@ -556,6 +556,8 @@ def _alter_operation( current_type: t.Union[str, exp.DataType], root_struct: exp.DataType, new_kwarg: exp.ColumnDef, + *, + ignore_destructive: bool = False, ) -> t.List[TableAlterOperation]: # We don't copy on purpose here because current_type may need to be mutated inside # _get_operations (struct.expressions.pop and struct.expressions.insert) @@ -570,6 +572,7 @@ def _alter_operation( current_type, new_type, root_struct, + ignore_destructive=ignore_destructive, ) if new_type.this == current_type.this == exp.DataType.Type.ARRAY: @@ -587,6 +590,7 @@ def _alter_operation( current_array_type, new_array_type, root_struct, + ignore_destructive=ignore_destructive, ) if self._is_coerceable_type(current_type, new_type): return [] @@ -607,6 +611,8 @@ def _alter_operation( col_pos, ) ] + if ignore_destructive: + return [] return self._drop_operation(columns, root_struct, pos, root_struct) + self._add_operation( columns, pos, new_kwarg, struct, root_struct ) @@ -617,11 +623,16 @@ def _resolve_alter_operations( current_struct: exp.DataType, new_struct: exp.DataType, root_struct: exp.DataType, + *, + ignore_destructive: bool = False, ) -> t.List[TableAlterOperation]: operations = [] for current_pos, current_kwarg in enumerate(current_struct.expressions.copy()): _, new_kwarg = self._get_matching_kwarg(current_kwarg, new_struct, current_pos) - assert new_kwarg + if new_kwarg is None: + if ignore_destructive: + continue + raise ValueError("Cannot alter a column that is being dropped") _, new_type = _get_name_and_type(new_kwarg) _, current_type = _get_name_and_type(current_kwarg) columns = parent_columns + [TableAlterColumn.from_struct_kwarg(current_kwarg)] @@ -636,6 +647,7 @@ def _resolve_alter_operations( current_type, root_struct, new_kwarg, + 
ignore_destructive=ignore_destructive, ) ) return operations @@ -646,42 +658,54 @@ def _get_operations( current_struct: exp.DataType, new_struct: exp.DataType, root_struct: exp.DataType, + *, + ignore_destructive: bool = False, ) -> t.List[TableAlterOperation]: root_struct = root_struct or current_struct parent_columns = parent_columns or [] operations = [] - operations.extend( - self._resolve_drop_operation(parent_columns, current_struct, new_struct, root_struct) - ) + if not ignore_destructive: + operations.extend( + self._resolve_drop_operation( + parent_columns, current_struct, new_struct, root_struct + ) + ) operations.extend( self._resolve_add_operations(parent_columns, current_struct, new_struct, root_struct) ) operations.extend( - self._resolve_alter_operations(parent_columns, current_struct, new_struct, root_struct) + self._resolve_alter_operations( + parent_columns, + current_struct, + new_struct, + root_struct, + ignore_destructive=ignore_destructive, + ) ) return operations def _from_structs( - self, current_struct: exp.DataType, new_struct: exp.DataType + self, + current_struct: exp.DataType, + new_struct: exp.DataType, + *, + ignore_destructive: bool = False, ) -> t.List[TableAlterOperation]: - return self._get_operations([], current_struct, new_struct, current_struct) + return self._get_operations( + [], current_struct, new_struct, current_struct, ignore_destructive=ignore_destructive + ) - def compare_structs( - self, table_name: t.Union[str, exp.Table], current: exp.DataType, new: exp.DataType + def _compare_structs( + self, + table_name: t.Union[str, exp.Table], + current: exp.DataType, + new: exp.DataType, + *, + ignore_destructive: bool = False, ) -> t.List[exp.Alter]: - """ - Compares two schemas represented as structs. - - Args: - current: The current schema. - new: The new schema. - - Returns: - The list of table alter operations. 
- """ return [ op.expression(table_name, self.array_element_selector) - for op in self._from_structs(current, new) + for op in self._from_structs(current, new, ignore_destructive=ignore_destructive) ] def compare_columns( @@ -689,19 +713,14 @@ def compare_columns( table_name: TableName, current: t.Dict[str, exp.DataType], new: t.Dict[str, exp.DataType], + *, + ignore_destructive: bool = False, ) -> t.List[exp.Alter]: - """ - Compares two schemas represented as dictionaries of column names and types. - - Args: - current: The current schema. - new: The new schema. - - Returns: - The list of schema deltas. - """ - return self.compare_structs( - table_name, columns_to_types_to_struct(current), columns_to_types_to_struct(new) + return self._compare_structs( + table_name, + columns_to_types_to_struct(current), + columns_to_types_to_struct(new), + ignore_destructive=ignore_destructive, ) diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 3937f37fba..0c15cdf26d 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -694,6 +694,7 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None: end=end, execution_time=execution_time, physical_properties=rendered_physical_properties, + render_kwargs=render_statements_kwargs, ) else: logger.info( @@ -715,6 +716,7 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None: end=end, execution_time=execution_time, physical_properties=rendered_physical_properties, + render_kwargs=render_statements_kwargs, ) with ( @@ -865,7 +867,9 @@ def _create_snapshot( rendered_physical_properties=rendered_physical_properties, ) alter_expressions = adapter.get_alter_expressions( - target_table_name, tmp_table_name + target_table_name, + tmp_table_name, + ignore_destructive=snapshot.model.on_destructive_change.is_ignore, ) _check_destructive_schema_change( snapshot, alter_expressions, allow_destructive_snapshots @@ -940,9 +944,9 @@ def _migrate_snapshot( 
evaluation_strategy = _evaluation_strategy(snapshot, adapter) tmp_table_name = snapshot.table_name(is_deployable=False) logger.info( - "Migrating table schema from '%s' to '%s'", - tmp_table_name, + "Migrating table schema '%s' to match '%s'", target_table_name, + tmp_table_name, ) evaluation_strategy.migrate( target_table_name=target_table_name, @@ -950,6 +954,7 @@ def _migrate_snapshot( snapshot=snapshot, snapshots=parent_snapshots_by_name(snapshot, snapshots), allow_destructive_snapshots=allow_destructive_snapshots, + ignore_destructive=snapshot.model.on_destructive_change.is_ignore, ) else: logger.info( @@ -1291,6 +1296,7 @@ def insert( query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: """Inserts the given query or a DataFrame into the target table or a view. @@ -1303,6 +1309,7 @@ def insert( if no data has been previously inserted into the target table, or when the entire history of the target model has been restated. Note that in the latter case, the table might contain data from previous executions, and it is the responsibility of a specific evaluation strategy to handle the truncation of the table if necessary. + render_kwargs: Additional key-value arguments to pass when rendering the model's query. """ @abc.abstractmethod @@ -1311,6 +1318,7 @@ def append( table_name: str, query_or_df: QueryOrDF, model: Model, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: """Appends the given query or a DataFrame to the existing table. @@ -1319,6 +1327,7 @@ def append( table_name: The target table name. query_or_df: A query or a DataFrame to insert. model: The target model. + render_kwargs: Additional key-value arguments to pass when rendering the model's query. 
""" @abc.abstractmethod @@ -1349,6 +1358,8 @@ def migrate( target_table_name: str, source_table_name: str, snapshot: Snapshot, + *, + ignore_destructive: bool, **kwargs: t.Any, ) -> None: """Migrates the target table schema so that it corresponds to the source table schema. @@ -1357,6 +1368,8 @@ def migrate( target_table_name: The target table name. source_table_name: The source table name. snapshot: The target snapshot. + ignore_destructive: If True, destructive changes are not created when migrating. + This is used for forward-only models that are being migrated to a new version. """ @abc.abstractmethod @@ -1393,36 +1406,6 @@ def demote(self, view_name: str, **kwargs: t.Any) -> None: view_name: The name of the target view in the virtual layer. """ - def _replace_query_for_model( - self, model: Model, name: str, query_or_df: QueryOrDF, **kwargs: t.Any - ) -> None: - """Replaces the table for the given model. - - Args: - model: The target model. - name: The name of the target table. - query_or_df: The query or DataFrame to replace the target table with. - """ - # Source columns from the underlying table to prevent unintentional table schema changes during restatement of incremental models. 
- columns_to_types = ( - model.columns_to_types - if (model.is_seed or model.kind.is_full) and model.annotated - else self.adapter.columns(name) - ) - self.adapter.replace_query( - name, - query_or_df, - table_format=model.table_format, - storage_format=model.storage_format, - partitioned_by=model.partitioned_by, - partition_interval_unit=model.partition_interval_unit, - clustered_by=model.clustered_by, - table_properties=kwargs.get("physical_properties", model.physical_properties), - table_description=model.description, - column_descriptions=model.column_descriptions, - columns_to_types=columns_to_types, - ) - class SymbolicStrategy(EvaluationStrategy): def insert( @@ -1431,6 +1414,7 @@ def insert( query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: pass @@ -1440,6 +1424,7 @@ def append( table_name: str, query_or_df: QueryOrDF, model: Model, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: pass @@ -1459,6 +1444,8 @@ def migrate( target_table_name: str, source_table_name: str, snapshot: Snapshot, + *, + ignore_destructive: bool, **kwarg: t.Any, ) -> None: pass @@ -1493,7 +1480,7 @@ def promote( self.adapter.drop_view(view_name, cascade=False) -class PromotableStrategy(EvaluationStrategy): +class PromotableStrategy(EvaluationStrategy, abc.ABC): def promote( self, table_name: str, @@ -1527,15 +1514,23 @@ def demote(self, view_name: str, **kwargs: t.Any) -> None: self.adapter.drop_view(view_name, cascade=False) -class MaterializableStrategy(PromotableStrategy): +class MaterializableStrategy(PromotableStrategy, abc.ABC): def append( self, table_name: str, query_or_df: QueryOrDF, model: Model, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: - self.adapter.insert_append(table_name, query_or_df, columns_to_types=model.columns_to_types) + columns_to_types = kwargs.pop("columns_to_types", model.columns_to_types) + source_columns = kwargs.pop("source_columns", None) + 
self.adapter.insert_append( + table_name, + query_or_df, + columns_to_types=columns_to_types, + source_columns=source_columns, + ) def create( self, @@ -1592,10 +1587,14 @@ def migrate( target_table_name: str, source_table_name: str, snapshot: Snapshot, + *, + ignore_destructive: bool, **kwargs: t.Any, ) -> None: logger.info(f"Altering table '{target_table_name}'") - alter_expressions = self.adapter.get_alter_expressions(target_table_name, source_table_name) + alter_expressions = self.adapter.get_alter_expressions( + target_table_name, source_table_name, ignore_destructive=ignore_destructive + ) _check_destructive_schema_change( snapshot, alter_expressions, kwargs["allow_destructive_snapshots"] ) @@ -1606,6 +1605,66 @@ def delete(self, name: str, **kwargs: t.Any) -> None: self.adapter.drop_table(name) logger.info("Dropped table '%s'", name) + def _replace_query_for_model( + self, + model: Model, + name: str, + query_or_df: QueryOrDF, + render_kwargs: t.Dict[str, t.Any], + **kwargs: t.Any, + ) -> None: + """Replaces the table for the given model. + + Args: + model: The target model. + name: The name of the target table. + query_or_df: The query or DataFrame to replace the target table with. + """ + if (model.is_seed or model.kind.is_full) and model.annotated: + columns_to_types = model.columns_to_types_or_raise + source_columns: t.Optional[t.List[str]] = list(columns_to_types) + else: + # Source columns from the underlying table to prevent unintentional table schema changes during restatement of incremental models. 
+ columns_to_types, source_columns = self._get_target_and_source_columns( + model, name, render_kwargs, columns_to_types=self.adapter.columns(name) + ) + + self.adapter.replace_query( + name, + query_or_df, + table_format=model.table_format, + storage_format=model.storage_format, + partitioned_by=model.partitioned_by, + partition_interval_unit=model.partition_interval_unit, + clustered_by=model.clustered_by, + table_properties=kwargs.get("physical_properties", model.physical_properties), + table_description=model.description, + column_descriptions=model.column_descriptions, + columns_to_types=columns_to_types, + source_columns=source_columns, + ) + + def _get_target_and_source_columns( + self, + model: Model, + table_name: str, + render_kwargs: t.Dict[str, t.Any], + columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + ) -> t.Tuple[t.Dict[str, exp.DataType], t.Optional[t.List[str]]]: + if not columns_to_types: + columns_to_types = ( + model.columns_to_types if model.annotated else self.adapter.columns(table_name) + ) + assert columns_to_types is not None + if model.on_destructive_change.is_ignore: + # We need to identify the columns that are only in the source so we create an empty table with + # the user query to determine that + with self.adapter.temp_table(model.ctas_query(**render_kwargs)) as temp_table: + source_columns = list(self.adapter.columns(temp_table)) + else: + source_columns = None + return columns_to_types, source_columns + class IncrementalByPartitionStrategy(MaterializableStrategy): def insert( @@ -1614,16 +1673,21 @@ def insert( query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: if is_first_insert: - self._replace_query_for_model(model, table_name, query_or_df, **kwargs) + self._replace_query_for_model(model, table_name, query_or_df, render_kwargs, **kwargs) else: + columns_to_types, source_columns = self._get_target_and_source_columns( + model, table_name, 
render_kwargs=render_kwargs + ) self.adapter.insert_overwrite_by_partition( table_name, query_or_df, partitioned_by=model.partitioned_by, - columns_to_types=model.columns_to_types, + columns_to_types=columns_to_types, + source_columns=source_columns, ) @@ -1634,15 +1698,20 @@ def insert( query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: assert model.time_column + columns_to_types, source_columns = self._get_target_and_source_columns( + model, table_name, render_kwargs=render_kwargs + ) self.adapter.insert_overwrite_by_time_partition( table_name, query_or_df, time_formatter=model.convert_to_time_column, time_column=model.time_column, - columns_to_types=model.columns_to_types, + columns_to_types=columns_to_types, + source_columns=source_columns, **kwargs, ) @@ -1654,15 +1723,21 @@ def insert( query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: if is_first_insert: - self._replace_query_for_model(model, table_name, query_or_df, **kwargs) + self._replace_query_for_model(model, table_name, query_or_df, render_kwargs, **kwargs) else: + columns_to_types, source_columns = self._get_target_and_source_columns( + model, + table_name, + render_kwargs=render_kwargs, + ) self.adapter.merge( table_name, query_or_df, - columns_to_types=model.columns_to_types, + columns_to_types=columns_to_types, unique_key=model.unique_key, when_matched=model.when_matched, merge_filter=model.render_merge_filter( @@ -1671,6 +1746,7 @@ def insert( execution_time=kwargs.get("execution_time"), ), physical_properties=kwargs.get("physical_properties", model.physical_properties), + source_columns=source_columns, ) def append( @@ -1678,12 +1754,16 @@ def append( table_name: str, query_or_df: QueryOrDF, model: Model, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: + columns_to_types, source_columns = self._get_target_and_source_columns( + 
model, table_name, render_kwargs=render_kwargs + ) self.adapter.merge( table_name, query_or_df, - columns_to_types=model.columns_to_types, + columns_to_types=columns_to_types, unique_key=model.unique_key, when_matched=model.when_matched, merge_filter=model.render_merge_filter( @@ -1692,6 +1772,7 @@ def append( execution_time=kwargs.get("execution_time"), ), physical_properties=kwargs.get("physical_properties", model.physical_properties), + source_columns=source_columns, ) @@ -1702,24 +1783,36 @@ def insert( query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: if is_first_insert: - self._replace_query_for_model(model, table_name, query_or_df, **kwargs) - elif isinstance(model.kind, IncrementalUnmanagedKind) and model.kind.insert_overwrite: - self.adapter.insert_overwrite_by_partition( - table_name, - query_or_df, - model.partitioned_by, - columns_to_types=model.columns_to_types, + return self._replace_query_for_model( + model, table_name, query_or_df, render_kwargs, **kwargs ) - else: - self.append( + columns_to_types, source_columns = self._get_target_and_source_columns( + model, + table_name, + render_kwargs=render_kwargs, + columns_to_types=kwargs.pop("columns_to_types", None), + ) + if isinstance(model.kind, IncrementalUnmanagedKind) and model.kind.insert_overwrite: + return self.adapter.insert_overwrite_by_partition( table_name, query_or_df, - model, - **kwargs, + model.partitioned_by, + columns_to_types=columns_to_types, + source_columns=source_columns, ) + return self.append( + table_name, + query_or_df, + model, + render_kwargs=render_kwargs, + columns_to_types=columns_to_types, + source_columns=source_columns, + **kwargs, + ) class FullRefreshStrategy(MaterializableStrategy): @@ -1729,9 +1822,10 @@ def insert( query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: - self._replace_query_for_model(model, table_name, 
query_or_df, **kwargs) + self._replace_query_for_model(model, table_name, query_or_df, render_kwargs, **kwargs) class SeedStrategy(MaterializableStrategy): @@ -1760,7 +1854,9 @@ def create( try: for index, df in enumerate(model.render_seed()): if index == 0: - self._replace_query_for_model(model, table_name, df, **kwargs) + self._replace_query_for_model( + model, table_name, df, render_kwargs, **kwargs + ) else: self.adapter.insert_append( table_name, df, columns_to_types=model.columns_to_types @@ -1775,6 +1871,7 @@ def insert( query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: # Data has already been inserted at the time of table creation. @@ -1826,10 +1923,16 @@ def insert( query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: - # Source columns from the underlying table to prevent unintentional table schema changes during the insert. - columns_to_types = self.adapter.columns(table_name) + # Source columns from the underlying table to prevent unintentional table schema changes during restatement of incremental models. 
+ columns_to_types, source_columns = self._get_target_and_source_columns( + model, + table_name, + render_kwargs=render_kwargs, + columns_to_types=self.adapter.columns(table_name), + ) if isinstance(model.kind, SCDType2ByTimeKind): self.adapter.scd_type_2_by_time( target_table=table_name, @@ -1846,6 +1949,7 @@ def insert( table_description=model.description, column_descriptions=model.column_descriptions, truncate=is_first_insert, + source_columns=source_columns, ) elif isinstance(model.kind, SCDType2ByColumnKind): self.adapter.scd_type_2_by_column( @@ -1863,6 +1967,7 @@ def insert( table_description=model.description, column_descriptions=model.column_descriptions, truncate=is_first_insert, + source_columns=source_columns, ) else: raise SQLMeshError( @@ -1874,10 +1979,16 @@ def append( table_name: str, query_or_df: QueryOrDF, model: Model, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: - # Source columns from the underlying table to prevent unintentional table schema changes during the insert. - columns_to_types = self.adapter.columns(table_name) + # Source columns from the underlying table to prevent unintentional table schema changes during restatement of incremental models. 
+ columns_to_types, source_columns = self._get_target_and_source_columns( + model, + table_name, + render_kwargs=render_kwargs, + columns_to_types=self.adapter.columns(table_name), + ) if isinstance(model.kind, SCDType2ByTimeKind): self.adapter.scd_type_2_by_time( target_table=table_name, @@ -1892,6 +2003,7 @@ def append( table_format=model.table_format, table_description=model.description, column_descriptions=model.column_descriptions, + source_columns=source_columns, **kwargs, ) elif isinstance(model.kind, SCDType2ByColumnKind): @@ -1908,6 +2020,7 @@ def append( execution_time_as_valid_from=model.kind.execution_time_as_valid_from, table_description=model.description, column_descriptions=model.column_descriptions, + source_columns=source_columns, **kwargs, ) else: @@ -1923,6 +2036,7 @@ def insert( query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: deployability_index = ( @@ -1956,6 +2070,7 @@ def append( table_name: str, query_or_df: QueryOrDF, model: Model, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: raise ConfigError(f"Cannot append to a view '{table_name}'.") @@ -2011,6 +2126,8 @@ def migrate( target_table_name: str, source_table_name: str, snapshot: Snapshot, + *, + ignore_destructive: bool, **kwargs: t.Any, ) -> None: logger.info("Migrating view '%s'", target_table_name) @@ -2057,6 +2174,7 @@ def insert( query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: """Inserts the given query or a DataFrame into the target table or a view. @@ -2069,6 +2187,7 @@ def insert( if no data has been previously inserted into the target table, or when the entire history of the target model has been restated. Note that in the latter case, the table might contain data from previous executions, and it is the responsibility of a specific evaluation strategy to handle the truncation of the table if necessary. 
+ render_kwargs: Additional key-value arguments to pass when rendering the model's query. """ raise NotImplementedError( "Custom materialization strategies must implement the 'insert' method." @@ -2204,6 +2323,7 @@ def insert( query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: deployability_index: DeployabilityIndex = kwargs["deployability_index"] @@ -2231,7 +2351,11 @@ def insert( model.name, ) self._replace_query_for_model( - model=model, name=table_name, query_or_df=query_or_df, **kwargs + model=model, + name=table_name, + query_or_df=query_or_df, + render_kwargs=render_kwargs, + **kwargs, ) def append( @@ -2239,6 +2363,7 @@ def append( table_name: str, query_or_df: QueryOrDF, model: Model, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: raise ConfigError(f"Cannot append to a managed table '{table_name}'.") @@ -2248,10 +2373,12 @@ def migrate( target_table_name: str, source_table_name: str, snapshot: Snapshot, + *, + ignore_destructive: bool, **kwargs: t.Any, ) -> None: potential_alter_expressions = self.adapter.get_alter_expressions( - target_table_name, source_table_name + target_table_name, source_table_name, ignore_destructive=ignore_destructive ) if len(potential_alter_expressions) > 0: # this can happen if a user changes a managed model and deliberately overrides a plan to be forward only, eg `sqlmesh plan --forward-only` diff --git a/sqlmesh/core/table_diff.py b/sqlmesh/core/table_diff.py index 126fa64b1e..b48b852f8a 100644 --- a/sqlmesh/core/table_diff.py +++ b/sqlmesh/core/table_diff.py @@ -494,7 +494,7 @@ def _column_expr(name: str, table: str) -> exp.Expression: schema = to_schema(temp_schema, dialect=self.dialect) temp_table = exp.table_("diff", db=schema.db, catalog=schema.catalog, quoted=True) - temp_table_kwargs = {} + temp_table_kwargs: t.Dict[str, t.Any] = {} if isinstance(self.adapter, AthenaEngineAdapter): # Athena has two table formats: Hive (the 
default) and Iceberg. TableDiff requires that # the formats be the same for the source, target, and temp tables. diff --git a/sqlmesh/utils/__init__.py b/sqlmesh/utils/__init__.py index 80e4fa5934..b0a3b566d5 100644 --- a/sqlmesh/utils/__init__.py +++ b/sqlmesh/utils/__init__.py @@ -403,3 +403,10 @@ def __str__(self) -> str: @classmethod def from_plan_id(cls, plan_id: str) -> CorrelationId: return CorrelationId(JobType.PLAN, plan_id) + + +def get_source_columns_to_types( + columns_to_types: t.Dict[str, exp.DataType], + source_columns: t.Optional[t.List[str]], +) -> t.Dict[str, exp.DataType]: + return {k: v for k, v in columns_to_types.items() if not source_columns or k in source_columns} diff --git a/tests/core/engine_adapter/integration/conftest.py b/tests/core/engine_adapter/integration/conftest.py index f072ca77f5..4d374cfdbc 100644 --- a/tests/core/engine_adapter/integration/conftest.py +++ b/tests/core/engine_adapter/integration/conftest.py @@ -145,7 +145,7 @@ def ctx_df( yield from create_test_context(*request.param) -@pytest.fixture(params=list(generate_pytest_params(ENGINES, query=True, df=True))) +@pytest.fixture(params=list(generate_pytest_params(ENGINES, query=True, df=False))) def ctx_query_and_df( request: FixtureRequest, create_test_context: t.Callable[[IntegrationTestEngine, str], t.Iterable[TestContext]], diff --git a/tests/core/engine_adapter/integration/test_integration.py b/tests/core/engine_adapter/integration/test_integration.py index 039159825b..f97298cf2d 100644 --- a/tests/core/engine_adapter/integration/test_integration.py +++ b/tests/core/engine_adapter/integration/test_integration.py @@ -263,6 +263,52 @@ def test_ctas(ctx_query_and_df: TestContext): ctx.engine_adapter.ctas(table, exp.select("1").limit(0)) +def test_ctas_source_columns(ctx_query_and_df: TestContext): + ctx = ctx_query_and_df + table = ctx.table("test_table") + + columns_to_types = ctx.columns_to_types.copy() + columns_to_types["ignored_column"] = exp.DataType.build("int") 
+ + input_data = pd.DataFrame( + [ + {"id": 1, "ds": "2022-01-01"}, + {"id": 2, "ds": "2022-01-02"}, + {"id": 3, "ds": "2022-01-03"}, + ] + ) + ctx.engine_adapter.ctas( + table, + ctx.input_data(input_data), + table_description="test table description", + column_descriptions={"id": "test id column description"}, + table_format=ctx.default_table_format, + columns_to_types=columns_to_types, + source_columns=["id", "ds"], + ) + + expected_data = input_data.copy() + expected_data["ignored_column"] = pd.Series() + + results = ctx.get_metadata_results(schema=table.db) + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current(table, expected_data) + + if ctx.engine_adapter.COMMENT_CREATION_TABLE.is_supported: + table_description = ctx.get_table_comment(table.db, table.name) + column_comments = ctx.get_column_comments(table.db, table.name) + + assert table_description == "test table description" + assert column_comments == {"id": "test id column description"} + + # ensure we don't hit clickhouse INSERT with LIMIT 0 bug on CTAS + if ctx.dialect == "clickhouse": + ctx.engine_adapter.ctas(table, exp.select("1").limit(0)) + + def test_create_view(ctx_query_and_df: TestContext): ctx = ctx_query_and_df input_data = pd.DataFrame( @@ -306,6 +352,47 @@ def test_create_view(ctx_query_and_df: TestContext): ) +def test_create_view_source_columns(ctx_query_and_df: TestContext): + ctx = ctx_query_and_df + + columns_to_types = ctx.columns_to_types.copy() + columns_to_types["ignored_column"] = exp.DataType.build("int") + + input_data = pd.DataFrame( + [ + {"id": 1, "ds": "2022-01-01"}, + {"id": 2, "ds": "2022-01-02"}, + {"id": 3, "ds": "2022-01-03"}, + ] + ) + view = ctx.table("test_view") + ctx.engine_adapter.create_view( + view, + ctx.input_data(input_data), + table_description="test view description", + 
column_descriptions={"id": "test id column description"}, + source_columns=["id", "ds"], + columns_to_types=columns_to_types, + ) + + expected_data = input_data.copy() + expected_data["ignored_column"] = pd.Series() + + results = ctx.get_metadata_results() + assert len(results.tables) == 0 + assert len(results.views) == 1 + assert len(results.materialized_views) == 0 + assert results.views[0] == view.name + ctx.compare_with_current(view, expected_data) + + if ctx.engine_adapter.COMMENT_CREATION_VIEW.is_supported: + table_description = ctx.get_table_comment(view.db, "test_view", table_kind="VIEW") + column_comments = ctx.get_column_comments(view.db, "test_view", table_kind="VIEW") + + assert table_description == "test view description" + assert column_comments == {"id": "test id column description"} + + def test_materialized_view(ctx_query_and_df: TestContext): ctx = ctx_query_and_df if not ctx.engine_adapter.SUPPORTS_MATERIALIZED_VIEWS: @@ -450,6 +537,67 @@ def test_replace_query(ctx_query_and_df: TestContext): ctx.compare_with_current(table, replace_data) +def test_replace_query_source_columns(ctx_query_and_df: TestContext): + ctx = ctx_query_and_df + ctx.engine_adapter.DEFAULT_BATCH_SIZE = sys.maxsize + table = ctx.table("test_table") + + columns_to_types = ctx.columns_to_types.copy() + columns_to_types["ignored_column"] = exp.DataType.build("int") + + # Initial Load + input_data = pd.DataFrame( + [ + {"id": 1, "ds": "2022-01-01"}, + {"id": 2, "ds": "2022-01-02"}, + {"id": 3, "ds": "2022-01-03"}, + ] + ) + ctx.engine_adapter.create_table(table, columns_to_types, table_format=ctx.default_table_format) + ctx.engine_adapter.replace_query( + table, + ctx.input_data(input_data), + table_format=ctx.default_table_format, + source_columns=["id", "ds"], + columns_to_types=columns_to_types, + ) + expected_data = input_data.copy() + expected_data["ignored_column"] = pd.Series() + + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert 
len(results.materialized_views) == 0
+    assert len(results.tables) == len(results.non_temp_tables) == 1
+    assert results.non_temp_tables[0] == table.name
+    ctx.compare_with_current(table, expected_data)
+
+    # Replace that we only need to run once
+    if ctx.test_type == "df":
+        replace_data = pd.DataFrame(
+            [
+                {"id": 4, "ds": "2022-01-04"},
+                {"id": 5, "ds": "2022-01-05"},
+                {"id": 6, "ds": "2022-01-06"},
+            ]
+        )
+        ctx.engine_adapter.replace_query(
+            table,
+            ctx.input_data(replace_data),
+            table_format=ctx.default_table_format,
+            source_columns=["id", "ds"],
+            columns_to_types=columns_to_types,
+        )
+        expected_data = replace_data.copy()
+        expected_data["ignored_column"] = pd.Series()
+
+        results = ctx.get_metadata_results()
+        assert len(results.views) == 0
+        assert len(results.materialized_views) == 0
+        assert len(results.tables) == len(results.non_temp_tables) == 1
+        assert results.non_temp_tables[0] == table.name
+        ctx.compare_with_current(table, expected_data)
+
+
 def test_replace_query_batched(ctx_query_and_df: TestContext):
     ctx = ctx_query_and_df
     ctx.engine_adapter.DEFAULT_BATCH_SIZE = 1
@@ -548,6 +696,61 @@ def test_insert_append(ctx_query_and_df: TestContext):
     ctx.compare_with_current(table, pd.concat([input_data, append_data]))
 
 
+def test_insert_append_source_columns(ctx_query_and_df: TestContext):
+    ctx = ctx_query_and_df
+    table = ctx.table("test_table")
+    columns_to_types = ctx.columns_to_types.copy()
+    columns_to_types["ignored_column"] = exp.DataType.build("int")
+    ctx.engine_adapter.create_table(table, columns_to_types, table_format=ctx.default_table_format)
+    # Initial Load
+    input_data = pd.DataFrame(
+        [
+            {"id": 1, "ds": "2022-01-01"},
+            {"id": 2, "ds": "2022-01-02"},
+            {"id": 3, "ds": "2022-01-03"},
+        ]
+    )
+    ctx.engine_adapter.insert_append(
+        table,
+        ctx.input_data(input_data),
+        source_columns=["id", "ds"],
+        columns_to_types=columns_to_types,
+    )
+    expected_data = input_data.copy()
+    expected_data["ignored_column"] = pd.Series()
+    results = 
ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current(table, expected_data) + + # Replace that we only need to run once + if ctx.test_type == "df": + append_data = pd.DataFrame( + [ + {"id": 4, "ds": "2022-01-04"}, + {"id": 5, "ds": "2022-01-05"}, + {"id": 6, "ds": "2022-01-06"}, + ] + ) + ctx.engine_adapter.insert_append( + table, + ctx.input_data(append_data), + source_columns=["id", "ds"], + columns_to_types=columns_to_types, + ) + append_expected_data = append_data.copy() + append_expected_data["ignored_column"] = pd.Series() + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) in [1, 2, 3] + assert len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current(table, pd.concat([expected_data, append_expected_data])) + + def test_insert_overwrite_by_time_partition(ctx_query_and_df: TestContext): ctx = ctx_query_and_df ds_type = "string" @@ -636,6 +839,105 @@ def test_insert_overwrite_by_time_partition(ctx_query_and_df: TestContext): ) +def test_insert_overwrite_by_time_partition_source_columns(ctx_query_and_df: TestContext): + ctx = ctx_query_and_df + ds_type = "string" + if ctx.dialect == "bigquery": + ds_type = "datetime" + if ctx.dialect == "tsql": + ds_type = "varchar(max)" + + ctx.columns_to_types = {"id": "int", "ds": ds_type} + columns_to_types = { + "id": exp.DataType.build("int"), + "ignored_column": exp.DataType.build("int"), + "ds": exp.DataType.build(ds_type), + } + table = ctx.table("test_table") + if ctx.dialect == "bigquery": + partitioned_by = ["DATE(ds)"] + else: + partitioned_by = ctx.partitioned_by # type: ignore + ctx.engine_adapter.create_table( + table, + columns_to_types, + partitioned_by=partitioned_by, + 
partition_interval_unit="DAY", + table_format=ctx.default_table_format, + ) + input_data = pd.DataFrame( + [ + {"id": 1, ctx.time_column: "2022-01-01"}, + {"id": 2, ctx.time_column: "2022-01-02"}, + {"id": 3, ctx.time_column: "2022-01-03"}, + ] + ) + ctx.engine_adapter.insert_overwrite_by_time_partition( + table, + ctx.input_data(input_data), + start="2022-01-02", + end="2022-01-03", + time_formatter=ctx.time_formatter, + time_column=ctx.time_column, + columns_to_types=columns_to_types, + source_columns=["id", "ds"], + ) + + expected_data = input_data.copy() + expected_data.insert(len(expected_data.columns) - 1, "ignored_column", pd.Series()) + + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + + if ctx.dialect == "trino": + # trino has some lag between partitions being registered and data showing up + wait_until(lambda: len(ctx.get_current_data(table)) > 0) + + ctx.compare_with_current(table, expected_data.iloc[1:]) + + if ctx.test_type == "df": + overwrite_data = pd.DataFrame( + [ + {"id": 10, ctx.time_column: "2022-01-03"}, + {"id": 4, ctx.time_column: "2022-01-04"}, + {"id": 5, ctx.time_column: "2022-01-05"}, + ] + ) + ctx.engine_adapter.insert_overwrite_by_time_partition( + table, + ctx.input_data(overwrite_data), + start="2022-01-03", + end="2022-01-05", + time_formatter=ctx.time_formatter, + time_column=ctx.time_column, + columns_to_types=columns_to_types, + source_columns=["id", "ds"], + ) + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + + if ctx.dialect == "trino": + wait_until(lambda: len(ctx.get_current_data(table)) > 2) + + ctx.compare_with_current( + table, + pd.DataFrame( + [ + {"id": 
2, "ignored_column": None, ctx.time_column: "2022-01-02"}, + {"id": 10, "ignored_column": None, ctx.time_column: "2022-01-03"}, + {"id": 4, "ignored_column": None, ctx.time_column: "2022-01-04"}, + {"id": 5, "ignored_column": None, ctx.time_column: "2022-01-05"}, + ] + ), + ) + + def test_merge(ctx_query_and_df: TestContext): ctx = ctx_query_and_df if not ctx.supports_merge: @@ -702,20 +1004,96 @@ def test_merge(ctx_query_and_df: TestContext): ) -def test_scd_type_2_by_time(ctx_query_and_df: TestContext): +def test_merge_source_columns(ctx_query_and_df: TestContext): ctx = ctx_query_and_df - # Athena only supports the operations required for SCD models on Iceberg tables - if ctx.mark == "athena_hive": - pytest.skip("SCD Type 2 is only supported on Athena / Iceberg") + if not ctx.supports_merge: + pytest.skip(f"{ctx.dialect} doesn't support merge") - time_type = exp.DataType.build("timestamp") + table = ctx.table("test_table") - ctx.columns_to_types = { - "id": "int", - "name": "string", - "updated_at": time_type, - "valid_from": time_type, - "valid_to": time_type, + # Athena only supports MERGE on Iceberg tables + # And it cant fall back to a logical merge on Hive tables because it cant delete records + table_format = "iceberg" if ctx.dialect == "athena" else None + + columns_to_types = ctx.columns_to_types.copy() + columns_to_types["ignored_column"] = exp.DataType.build("int") + + ctx.engine_adapter.create_table(table, columns_to_types, table_format=table_format) + input_data = pd.DataFrame( + [ + {"id": 1, "ds": "2022-01-01"}, + {"id": 2, "ds": "2022-01-02"}, + {"id": 3, "ds": "2022-01-03"}, + ] + ) + ctx.engine_adapter.merge( + table, + ctx.input_data(input_data), + unique_key=[exp.to_identifier("id")], + columns_to_types=columns_to_types, + source_columns=["id", "ds"], + ) + + expected_data = input_data.copy() + expected_data["ignored_column"] = pd.Series() + + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert 
len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current(table, expected_data) + + if ctx.test_type == "df": + merge_data = pd.DataFrame( + [ + {"id": 2, "ds": "2022-01-10"}, + {"id": 4, "ds": "2022-01-04"}, + {"id": 5, "ds": "2022-01-05"}, + ] + ) + ctx.engine_adapter.merge( + table, + ctx.input_data(merge_data), + unique_key=[exp.to_identifier("id")], + columns_to_types=columns_to_types, + source_columns=["id", "ds"], + ) + + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current( + table, + pd.DataFrame( + [ + {"id": 1, "ds": "2022-01-01", "ignored_column": None}, + {"id": 2, "ds": "2022-01-10", "ignored_column": None}, + {"id": 3, "ds": "2022-01-03", "ignored_column": None}, + {"id": 4, "ds": "2022-01-04", "ignored_column": None}, + {"id": 5, "ds": "2022-01-05", "ignored_column": None}, + ] + ), + ) + + +def test_scd_type_2_by_time(ctx_query_and_df: TestContext): + ctx = ctx_query_and_df + # Athena only supports the operations required for SCD models on Iceberg tables + if ctx.mark == "athena_hive": + pytest.skip("SCD Type 2 is only supported on Athena / Iceberg") + + time_type = exp.DataType.build("timestamp") + + ctx.columns_to_types = { + "id": "int", + "name": "string", + "updated_at": time_type, + "valid_from": time_type, + "valid_to": time_type, } table = ctx.table("test_table") input_schema = { @@ -857,6 +1235,174 @@ def test_scd_type_2_by_time(ctx_query_and_df: TestContext): ) +def test_scd_type_2_by_time_source_columns(ctx_query_and_df: TestContext): + ctx = ctx_query_and_df + # Athena only supports the operations required for SCD models on Iceberg tables + if ctx.mark == "athena_hive": + 
pytest.skip("SCD Type 2 is only supported on Athena / Iceberg") + + time_type = exp.DataType.build("timestamp") + + ctx.columns_to_types = { + "id": "int", + "name": "string", + "updated_at": time_type, + "valid_from": time_type, + "valid_to": time_type, + } + columns_to_types = ctx.columns_to_types.copy() + columns_to_types["ignored_column"] = exp.DataType.build("int") + + table = ctx.table("test_table") + input_schema = { + k: v for k, v in ctx.columns_to_types.items() if k not in ("valid_from", "valid_to") + } + + ctx.engine_adapter.create_table(table, columns_to_types, table_format=ctx.default_table_format) + input_data = pd.DataFrame( + [ + {"id": 1, "name": "a", "updated_at": "2022-01-01 00:00:00"}, + {"id": 2, "name": "b", "updated_at": "2022-01-02 00:00:00"}, + {"id": 3, "name": "c", "updated_at": "2022-01-03 00:00:00"}, + ] + ) + ctx.engine_adapter.scd_type_2_by_time( + table, + ctx.input_data(input_data, input_schema), + unique_key=[parse_one("COALESCE(id, -1)")], + valid_from_col=exp.column("valid_from", quoted=True), + valid_to_col=exp.column("valid_to", quoted=True), + updated_at_col=exp.column("updated_at", quoted=True), + execution_time="2023-01-01 00:00:00", + updated_at_as_valid_from=False, + table_format=ctx.default_table_format, + truncate=True, + start="2022-01-01 00:00:00", + columns_to_types=columns_to_types, + source_columns=["id", "name", "updated_at"], + ) + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current( + table, + pd.DataFrame( + [ + { + "id": 1, + "name": "a", + "updated_at": "2022-01-01 00:00:00", + "valid_from": "1970-01-01 00:00:00", + "valid_to": pd.NaT, + "ignored_column": None, + }, + { + "id": 2, + "name": "b", + "updated_at": "2022-01-02 00:00:00", + "valid_from": "1970-01-01 
00:00:00", + "valid_to": pd.NaT, + "ignored_column": None, + }, + { + "id": 3, + "name": "c", + "updated_at": "2022-01-03 00:00:00", + "valid_from": "1970-01-01 00:00:00", + "valid_to": pd.NaT, + "ignored_column": None, + }, + ] + ), + ) + + if ctx.test_type == "query": + return + + current_data = pd.DataFrame( + [ + # Change `a` to `x` + {"id": 1, "name": "x", "updated_at": "2022-01-04 00:00:00"}, + # Delete + # {"id": 2, "name": "b", "updated_at": "2022-01-02 00:00:00"}, + # No change + {"id": 3, "name": "c", "updated_at": "2022-01-03 00:00:00"}, + # Add + {"id": 4, "name": "d", "updated_at": "2022-01-04 00:00:00"}, + ] + ) + ctx.engine_adapter.scd_type_2_by_time( + table, + ctx.input_data(current_data, input_schema), + unique_key=[exp.to_column("id")], + valid_from_col=exp.column("valid_from", quoted=True), + valid_to_col=exp.column("valid_to", quoted=True), + updated_at_col=exp.column("updated_at", quoted=True), + execution_time="2023-01-05 00:00:00", + updated_at_as_valid_from=False, + table_format=ctx.default_table_format, + truncate=False, + start="2022-01-01 00:00:00", + columns_to_types=columns_to_types, + source_columns=["id", "name", "updated_at"], + ) + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current( + table, + pd.DataFrame( + [ + { + "id": 1, + "name": "a", + "updated_at": "2022-01-01 00:00:00", + "valid_from": "1970-01-01 00:00:00", + "valid_to": "2022-01-04 00:00:00", + "ignored_column": None, + }, + { + "id": 1, + "name": "x", + "updated_at": "2022-01-04 00:00:00", + "valid_from": "2022-01-04 00:00:00", + "valid_to": pd.NaT, + "ignored_column": None, + }, + { + "id": 2, + "name": "b", + "updated_at": "2022-01-02 00:00:00", + "valid_from": "1970-01-01 00:00:00", + "valid_to": "2023-01-05 00:00:00", + "ignored_column": None, + }, + { + "id": 
3, + "name": "c", + "updated_at": "2022-01-03 00:00:00", + "valid_from": "1970-01-01 00:00:00", + "valid_to": pd.NaT, + "ignored_column": None, + }, + { + "id": 4, + "name": "d", + "updated_at": "2022-01-04 00:00:00", + "valid_from": "1970-01-01 00:00:00", + "valid_to": pd.NaT, + "ignored_column": None, + }, + ] + ), + ) + + def test_scd_type_2_by_column(ctx_query_and_df: TestContext): ctx = ctx_query_and_df # Athena only supports the operations required for SCD models on Iceberg tables @@ -1034,6 +1580,199 @@ def test_scd_type_2_by_column(ctx_query_and_df: TestContext): ) +def test_scd_type_2_by_column_source_columns(ctx_query_and_df: TestContext): + ctx = ctx_query_and_df + # Athena only supports the operations required for SCD models on Iceberg tables + if ctx.mark == "athena_hive": + pytest.skip("SCD Type 2 is only supported on Athena / Iceberg") + + time_type = exp.DataType.build("timestamp") + + ctx.columns_to_types = { + "id": "int", + "name": "string", + "status": "string", + "valid_from": time_type, + "valid_to": time_type, + } + columns_to_types = ctx.columns_to_types.copy() + columns_to_types["ignored_column"] = exp.DataType.build("int") + + table = ctx.table("test_table") + input_schema = { + k: v for k, v in ctx.columns_to_types.items() if k not in ("valid_from", "valid_to") + } + + ctx.engine_adapter.create_table(table, columns_to_types, table_format=ctx.default_table_format) + input_data = pd.DataFrame( + [ + {"id": 1, "name": "a", "status": "active"}, + {"id": 2, "name": "b", "status": "inactive"}, + {"id": 3, "name": "c", "status": "active"}, + {"id": 4, "name": "d", "status": "active"}, + ] + ) + ctx.engine_adapter.scd_type_2_by_column( + table, + ctx.input_data(input_data, input_schema), + unique_key=[exp.to_column("id")], + check_columns=[exp.to_column("name"), exp.to_column("status")], + valid_from_col=exp.column("valid_from", quoted=True), + valid_to_col=exp.column("valid_to", quoted=True), + execution_time="2023-01-01", + 
execution_time_as_valid_from=False, + truncate=True, + start="2023-01-01", + columns_to_types=columns_to_types, + source_columns=["id", "name", "status"], + ) + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current( + table, + pd.DataFrame( + [ + { + "id": 1, + "name": "a", + "status": "active", + "valid_from": "1970-01-01 00:00:00", + "valid_to": pd.NaT, + "ignored_column": None, + }, + { + "id": 2, + "name": "b", + "status": "inactive", + "valid_from": "1970-01-01 00:00:00", + "valid_to": pd.NaT, + "ignored_column": None, + }, + { + "id": 3, + "name": "c", + "status": "active", + "valid_from": "1970-01-01 00:00:00", + "valid_to": pd.NaT, + "ignored_column": None, + }, + { + "id": 4, + "name": "d", + "status": "active", + "valid_from": "1970-01-01 00:00:00", + "valid_to": pd.NaT, + "ignored_column": None, + }, + ] + ), + ) + + if ctx.test_type == "query": + return + + current_data = pd.DataFrame( + [ + # Change `a` to `x` + {"id": 1, "name": "x", "status": "active"}, + # Delete + # {"id": 2, "name": "b", status: "inactive"}, + # No change + {"id": 3, "name": "c", "status": "active"}, + # Change status to inactive + {"id": 4, "name": "d", "status": "inactive"}, + # Add + {"id": 5, "name": "e", "status": "inactive"}, + ] + ) + ctx.engine_adapter.scd_type_2_by_column( + table, + ctx.input_data(current_data, input_schema), + unique_key=[exp.to_column("id")], + check_columns=[exp.to_column("name"), exp.to_column("status")], + valid_from_col=exp.column("valid_from", quoted=True), + valid_to_col=exp.column("valid_to", quoted=True), + execution_time="2023-01-05 00:00:00", + execution_time_as_valid_from=False, + truncate=False, + start="2023-01-01", + columns_to_types=columns_to_types, + source_columns=["id", "name", "status"], + 
) + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current( + table, + pd.DataFrame( + [ + { + "id": 1, + "name": "a", + "status": "active", + "valid_from": "1970-01-01 00:00:00", + "valid_to": "2023-01-05 00:00:00", + "ignored_column": None, + }, + { + "id": 1, + "name": "x", + "status": "active", + "valid_from": "2023-01-05 00:00:00", + "valid_to": pd.NaT, + "ignored_column": None, + }, + { + "id": 2, + "name": "b", + "status": "inactive", + "valid_from": "1970-01-01 00:00:00", + "valid_to": "2023-01-05 00:00:00", + "ignored_column": None, + }, + { + "id": 3, + "name": "c", + "status": "active", + "valid_from": "1970-01-01 00:00:00", + "valid_to": pd.NaT, + "ignored_column": None, + }, + { + "id": 4, + "name": "d", + "status": "active", + "valid_from": "1970-01-01 00:00:00", + "valid_to": "2023-01-05 00:00:00", + "ignored_column": None, + }, + { + "id": 4, + "name": "d", + "status": "inactive", + "valid_from": "2023-01-05 00:00:00", + "valid_to": pd.NaT, + "ignored_column": None, + }, + { + "id": 5, + "name": "e", + "status": "inactive", + "valid_from": "2023-01-05 00:00:00", + "valid_to": pd.NaT, + "ignored_column": None, + }, + ] + ), + ) + + def test_get_data_objects(ctx_query_and_df: TestContext): ctx = ctx_query_and_df table = ctx.table("test_table") diff --git a/tests/core/engine_adapter/test_base.py b/tests/core/engine_adapter/test_base.py index 618d89a445..afe506143f 100644 --- a/tests/core/engine_adapter/test_base.py +++ b/tests/core/engine_adapter/test_base.py @@ -76,6 +76,36 @@ def test_create_view_pandas(make_mocked_engine_adapter: t.Callable): ] +def test_create_view_pandas_source_columns(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(EngineAdapter) + bigint_dtype = exp.DataType.build("BIGINT") + adapter.create_view( 
+ "test_view", + pd.DataFrame({"a": [1, 2, 3]}), + columns_to_types={"a": bigint_dtype, "b": bigint_dtype}, + replace=False, + source_columns=["a"], + ) + + assert to_sql_calls(adapter) == [ + 'CREATE VIEW "test_view" ("a", "b") AS SELECT "a", CAST(NULL AS BIGINT) AS "b" FROM (SELECT CAST("a" AS BIGINT) AS "a" FROM (VALUES (1), (2), (3)) AS "t"("a")) AS "select_source_columns"', + ] + + +def test_create_view_query_source_columns(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.create_view( + "test_view", + parse_one("SELECT a FROM tbl"), + columns_to_types={"a": exp.DataType.build("BIGINT"), "b": exp.DataType.build("BIGINT")}, + replace=False, + source_columns=["a"], + ) + assert to_sql_calls(adapter) == [ + 'CREATE VIEW "test_view" ("a", "b") AS SELECT "a", CAST(NULL AS BIGINT) AS "b" FROM (SELECT "a" FROM "tbl") AS "select_source_columns"', + ] + + def test_create_materialized_view(make_mocked_engine_adapter: t.Callable): adapter = make_mocked_engine_adapter(EngineAdapter) adapter.SUPPORTS_MATERIALIZED_VIEWS = True @@ -274,6 +304,47 @@ def test_insert_overwrite_by_time_partition_supports_insert_overwrite_pandas( ] +def test_insert_overwrite_by_time_partition_supports_insert_overwrite_pandas_source_columns( + make_mocked_engine_adapter: t.Callable, +): + adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.INSERT_OVERWRITE + df = pd.DataFrame({"a": [1, 2]}) + adapter.insert_overwrite_by_time_partition( + "test_table", + df, + start="2022-01-01", + end="2022-01-02", + time_column="ds", + time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), + columns_to_types={"a": exp.DataType.build("INT"), "ds": exp.DataType.build("STRING")}, + source_columns=["a"], + ) + assert to_sql_calls(adapter) == [ + """INSERT OVERWRITE TABLE "test_table" ("a", "ds") SELECT "a", "ds" FROM (SELECT CAST("a" AS INT) AS "a", CAST(NULL AS TEXT) AS "ds" FROM (VALUES (1), (2)) 
AS "t"("a")) AS "_subquery" WHERE "ds" BETWEEN '2022-01-01' AND '2022-01-02'""" + ] + + +def test_insert_overwrite_by_time_partition_supports_insert_overwrite_query_source_columns( + make_mocked_engine_adapter: t.Callable, +): + adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.INSERT_OVERWRITE + adapter.insert_overwrite_by_time_partition( + "test_table", + parse_one("SELECT a FROM tbl"), + start="2022-01-01", + end="2022-01-02", + time_column="ds", + time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), + columns_to_types={"a": exp.DataType.build("INT"), "ds": exp.DataType.build("STRING")}, + source_columns=["a"], + ) + assert to_sql_calls(adapter) == [ + """INSERT OVERWRITE TABLE "test_table" ("a", "ds") SELECT "a", "ds" FROM (SELECT "a", CAST(NULL AS TEXT) AS "ds" FROM (SELECT "a" FROM "tbl") AS "select_source_columns") AS "_subquery" WHERE "ds" BETWEEN '2022-01-01' AND '2022-01-02'""" + ] + + def test_insert_overwrite_by_time_partition_replace_where(make_mocked_engine_adapter: t.Callable): adapter = make_mocked_engine_adapter(EngineAdapter) adapter.INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.REPLACE_WHERE @@ -316,6 +387,47 @@ def test_insert_overwrite_by_time_partition_replace_where_pandas( ] +def test_insert_overwrite_by_time_partition_replace_where_pandas_source_columns( + make_mocked_engine_adapter: t.Callable, +): + adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.REPLACE_WHERE + df = pd.DataFrame({"a": [1, 2]}) + adapter.insert_overwrite_by_time_partition( + "test_table", + df, + start="2022-01-01", + end="2022-01-02", + time_column="ds", + time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), + columns_to_types={"a": exp.DataType.build("INT"), "ds": exp.DataType.build("STRING")}, + source_columns=["a"], + ) + assert to_sql_calls(adapter) == [ + """INSERT INTO "test_table" REPLACE WHERE "ds" BETWEEN '2022-01-01' 
AND '2022-01-02' SELECT "a", "ds" FROM (SELECT CAST("a" AS INT) AS "a", CAST(NULL AS TEXT) AS "ds" FROM (VALUES (1), (2)) AS "t"("a")) AS "_subquery" WHERE "ds" BETWEEN '2022-01-01' AND '2022-01-02'""" + ] + + +def test_insert_overwrite_by_time_partition_replace_where_query_source_columns( + make_mocked_engine_adapter: t.Callable, +): + adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.REPLACE_WHERE + adapter.insert_overwrite_by_time_partition( + "test_table", + parse_one("SELECT a FROM tbl"), + start="2022-01-01", + end="2022-01-02", + time_column="ds", + time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), + columns_to_types={"a": exp.DataType.build("INT"), "ds": exp.DataType.build("STRING")}, + source_columns=["a"], + ) + assert to_sql_calls(adapter) == [ + """INSERT INTO "test_table" REPLACE WHERE "ds" BETWEEN '2022-01-01' AND '2022-01-02' SELECT "a", "ds" FROM (SELECT "a", CAST(NULL AS TEXT) AS "ds" FROM (SELECT "a" FROM "tbl") AS "select_source_columns") AS "_subquery" WHERE "ds" BETWEEN '2022-01-01' AND '2022-01-02'""" + ] + + def test_insert_overwrite_no_where(make_mocked_engine_adapter: t.Callable): adapter = make_mocked_engine_adapter(EngineAdapter) @@ -434,6 +546,39 @@ def test_insert_append_pandas_batches(make_mocked_engine_adapter: t.Callable): ] +def test_insert_append_pandas_source_columns(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(EngineAdapter) + df = pd.DataFrame({"a": [1, 2, 3]}) + adapter.insert_append( + "test_table", + df, + columns_to_types={ + "a": exp.DataType.build("INT"), + "b": exp.DataType.build("INT"), + }, + source_columns=["a"], + ) + assert to_sql_calls(adapter) == [ + 'INSERT INTO "test_table" ("a", "b") SELECT CAST("a" AS INT) AS "a", CAST(NULL AS INT) AS "b" FROM (VALUES (1), (2), (3)) AS "t"("a")', + ] + + +def test_insert_append_query_source_columns(make_mocked_engine_adapter: t.Callable): + adapter = 
make_mocked_engine_adapter(EngineAdapter) + adapter.insert_append( + "test_table", + parse_one("SELECT a FROM tbl"), + columns_to_types={ + "a": exp.DataType.build("INT"), + "b": exp.DataType.build("INT"), + }, + source_columns=["a"], + ) + assert to_sql_calls(adapter) == [ + 'INSERT INTO "test_table" ("a", "b") SELECT "a", CAST(NULL AS INT) AS "b" FROM (SELECT "a" FROM "tbl") AS "select_source_columns"', + ] + + def test_create_table(make_mocked_engine_adapter: t.Callable): adapter = make_mocked_engine_adapter(EngineAdapter) @@ -899,9 +1044,11 @@ def test_alter_table( original_from_structs = adapter.SCHEMA_DIFFER._from_structs def _from_structs( - current_struct: exp.DataType, new_struct: exp.DataType + current_struct: exp.DataType, new_struct: exp.DataType, *, ignore_destructive: bool = False ) -> t.List[TableAlterOperation]: - operations = original_from_structs(current_struct, new_struct) + operations = original_from_structs( + current_struct, new_struct, ignore_destructive=ignore_destructive + ) if not operations: return operations assert ( @@ -1018,6 +1165,47 @@ def test_merge_upsert_pandas(make_mocked_engine_adapter: t.Callable): ) +def test_merge_upsert_pandas_source_columns(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(EngineAdapter) + df = pd.DataFrame({"id": [1, 2, 3], "ts": [4, 5, 6]}) + adapter.merge( + target_table="target", + source_table=df, + columns_to_types={ + "id": exp.DataType.build("int"), + "ts": exp.DataType.build("timestamp"), + "val": exp.DataType.build("int"), + }, + unique_key=[exp.to_identifier("id")], + source_columns=["id", "ts"], + ) + adapter.cursor.execute.assert_called_once_with( + 'MERGE INTO "target" AS "__MERGE_TARGET__" USING (SELECT CAST("id" AS INT) AS "id", CAST("ts" AS TIMESTAMP) AS "ts", CAST(NULL AS INT) AS "val" FROM (VALUES (1, 4), (2, 5), (3, 6)) AS "t"("id", "ts")) AS "__MERGE_SOURCE__" ON "__MERGE_TARGET__"."id" = "__MERGE_SOURCE__"."id" ' + 'WHEN MATCHED THEN UPDATE SET 
"__MERGE_TARGET__"."id" = "__MERGE_SOURCE__"."id", "__MERGE_TARGET__"."ts" = "__MERGE_SOURCE__"."ts", "__MERGE_TARGET__"."val" = "__MERGE_SOURCE__"."val" ' + 'WHEN NOT MATCHED THEN INSERT ("id", "ts", "val") VALUES ("__MERGE_SOURCE__"."id", "__MERGE_SOURCE__"."ts", "__MERGE_SOURCE__"."val")' + ) + + +def test_merge_upsert_query_source_columns(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.merge( + target_table="target", + source_table=parse_one("SELECT id, ts FROM source"), + columns_to_types={ + "id": exp.DataType.build("int"), + "ts": exp.DataType.build("timestamp"), + "val": exp.DataType.build("int"), + }, + unique_key=[exp.to_identifier("id")], + source_columns=["id", "ts"], + ) + adapter.cursor.execute.assert_called_once_with( + 'MERGE INTO "target" AS "__MERGE_TARGET__" USING (SELECT "id", "ts", CAST(NULL AS INT) AS "val" FROM (SELECT "id", "ts" FROM "source") AS "select_source_columns") AS "__MERGE_SOURCE__" ON "__MERGE_TARGET__"."id" = "__MERGE_SOURCE__"."id" ' + 'WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."id" = "__MERGE_SOURCE__"."id", "__MERGE_TARGET__"."ts" = "__MERGE_SOURCE__"."ts", "__MERGE_TARGET__"."val" = "__MERGE_SOURCE__"."val" ' + 'WHEN NOT MATCHED THEN INSERT ("id", "ts", "val") VALUES ("__MERGE_SOURCE__"."id", "__MERGE_SOURCE__"."ts", "__MERGE_SOURCE__"."val")' + ) + + def test_merge_when_matched(make_mocked_engine_adapter: t.Callable, assert_exp_eq): adapter = make_mocked_engine_adapter(EngineAdapter) @@ -1416,6 +1604,219 @@ def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable): ) +def test_scd_type_2_by_time_source_columns(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(EngineAdapter) + df = pd.DataFrame( + { + "id": [1, 2, 3], + "name": ["a", "b", "c"], + "test_UPDATED_at": [ + "2020-01-01 10:00:00", + "2020-01-02 15:00:00", + "2020-01-03 12:00:00", + ], + } + ) + adapter.scd_type_2_by_time( + target_table="target", + 
source_table=df, + unique_key=[exp.column("id")], + valid_from_col=exp.column("test_valid_from", quoted=True), + valid_to_col=exp.column("test_valid_to", quoted=True), + updated_at_col=exp.column("test_UPDATED_at", quoted=True), + columns_to_types={ + "id": exp.DataType.build("INT"), + "name": exp.DataType.build("VARCHAR"), + "price": exp.DataType.build("DOUBLE"), + "test_UPDATED_at": exp.DataType.build("TIMESTAMP"), + "test_valid_from": exp.DataType.build("TIMESTAMP"), + "test_valid_to": exp.DataType.build("TIMESTAMP"), + }, + source_columns=["id", "name", "test_UPDATED_at"], + execution_time=datetime(2020, 1, 1, 0, 0, 0), + start=datetime(2020, 1, 1, 0, 0, 0), + is_restatement=True, + ) + sql_calls = to_sql_calls(adapter) + assert ( + parse_one(sql_calls[1]).sql() + == parse_one(""" +CREATE OR REPLACE TABLE "target" AS +WITH "source" AS ( + SELECT DISTINCT ON ("id") + TRUE AS "_exists", + "id", + "name", + "price", + CAST("test_UPDATED_at" AS TIMESTAMP) AS "test_UPDATED_at" + FROM ( + SELECT + CAST("id" AS INT) AS "id", + CAST("name" AS VARCHAR) AS "name", + CAST(NULL AS DOUBLE) AS "price", + CAST("test_UPDATED_at" AS TIMESTAMP) AS "test_UPDATED_at" + FROM (VALUES + (1, 'a', '2020-01-01 10:00:00'), + (2, 'b', '2020-01-02 15:00:00'), + (3, 'c', '2020-01-03 12:00:00')) AS "t"("id", "name", "test_UPDATED_at") + ) AS "raw_source" +), "static" AS ( + SELECT + "id", + "name", + "price", + "test_UPDATED_at", + "test_valid_from", + "test_valid_to", + TRUE AS "_exists" + FROM "target" + WHERE + NOT "test_valid_to" IS NULL +), "latest" AS ( + SELECT + "id", + "name", + "price", + "test_UPDATED_at", + "test_valid_from", + "test_valid_to", + TRUE AS "_exists" + FROM "target" + WHERE + "test_valid_to" IS NULL +), "deleted" AS ( + SELECT + "static"."id", + "static"."name", + "static"."price", + "static"."test_UPDATED_at", + "static"."test_valid_from", + "static"."test_valid_to" + FROM "static" + LEFT JOIN "latest" + ON "static"."id" = "latest"."id" + WHERE + 
"latest"."test_valid_to" IS NULL +), "latest_deleted" AS ( + SELECT + TRUE AS "_exists", + "id" AS "_key0", + MAX("test_valid_to") AS "test_valid_to" + FROM "deleted" + GROUP BY + "id" +), "joined" AS ( + SELECT + "source"."_exists" AS "_exists", + "latest"."id" AS "t_id", + "latest"."name" AS "t_name", + "latest"."price" AS "t_price", + "latest"."test_UPDATED_at" AS "t_test_UPDATED_at", + "latest"."test_valid_from" AS "t_test_valid_from", + "latest"."test_valid_to" AS "t_test_valid_to", + "source"."id" AS "id", + "source"."name" AS "name", + "source"."price" AS "price", + "source"."test_UPDATED_at" AS "test_UPDATED_at" + FROM "latest" + LEFT JOIN "source" + ON "latest"."id" = "source"."id" + UNION ALL + SELECT + "source"."_exists" AS "_exists", + "latest"."id" AS "t_id", + "latest"."name" AS "t_name", + "latest"."price" AS "t_price", + "latest"."test_UPDATED_at" AS "t_test_UPDATED_at", + "latest"."test_valid_from" AS "t_test_valid_from", + "latest"."test_valid_to" AS "t_test_valid_to", + "source"."id" AS "id", + "source"."name" AS "name", + "source"."price" AS "price", + "source"."test_UPDATED_at" AS "test_UPDATED_at" + FROM "latest" + RIGHT JOIN "source" + ON "latest"."id" = "source"."id" + WHERE + "latest"."_exists" IS NULL +), "updated_rows" AS ( + SELECT + COALESCE("joined"."t_id", "joined"."id") AS "id", + COALESCE("joined"."t_name", "joined"."name") AS "name", + COALESCE("joined"."t_price", "joined"."price") AS "price", + COALESCE("joined"."t_test_UPDATED_at", "joined"."test_UPDATED_at") AS "test_UPDATED_at", + CASE + WHEN "t_test_valid_from" IS NULL AND NOT "latest_deleted"."_exists" IS NULL + THEN CASE + WHEN "latest_deleted"."test_valid_to" > "test_UPDATED_at" + THEN "latest_deleted"."test_valid_to" + ELSE "test_UPDATED_at" + END + WHEN "t_test_valid_from" IS NULL + THEN CAST('1970-01-01 00:00:00' AS TIMESTAMP) + ELSE "t_test_valid_from" + END AS "test_valid_from", + CASE + WHEN "joined"."test_UPDATED_at" > "joined"."t_test_UPDATED_at" + THEN 
"joined"."test_UPDATED_at" + WHEN "joined"."_exists" IS NULL + THEN CAST('2020-01-01 00:00:00' AS TIMESTAMP) + ELSE "t_test_valid_to" + END AS "test_valid_to" + FROM "joined" + LEFT JOIN "latest_deleted" + ON "joined"."id" = "latest_deleted"."_key0" +), "inserted_rows" AS ( + SELECT + "id", + "name", + "price", + "test_UPDATED_at", + "test_UPDATED_at" AS "test_valid_from", + CAST(NULL AS TIMESTAMP) AS "test_valid_to" + FROM "joined" + WHERE + "joined"."test_UPDATED_at" > "joined"."t_test_UPDATED_at" +) +SELECT + CAST("id" AS INT) AS "id", + CAST("name" AS VARCHAR) AS "name", + CAST("price" AS DOUBLE) AS "price", + CAST("test_UPDATED_at" AS TIMESTAMP) AS "test_UPDATED_at", + CAST("test_valid_from" AS TIMESTAMP) AS "test_valid_from", + CAST("test_valid_to" AS TIMESTAMP) AS "test_valid_to" +FROM ( + SELECT + "id", + "name", + "price", + "test_UPDATED_at", + "test_valid_from", + "test_valid_to" + FROM "static" + UNION ALL + SELECT + "id", + "name", + "price", + "test_UPDATED_at", + "test_valid_from", + "test_valid_to" + FROM "updated_rows" + UNION ALL + SELECT + "id", + "name", + "price", + "test_UPDATED_at", + "test_valid_from", + "test_valid_to" + FROM "inserted_rows" +) AS "_subquery" + """).sql() + ) + + def test_scd_type_2_by_time_no_invalidate_hard_deletes(make_mocked_engine_adapter: t.Callable): adapter = make_mocked_engine_adapter(EngineAdapter) @@ -2774,6 +3175,39 @@ def test_replace_query_pandas(make_mocked_engine_adapter: t.Callable): ] +def test_replace_query_pandas_source_columns(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(EngineAdapter) + df = pd.DataFrame({"a": [1, 2, 3]}) + adapter.replace_query( + "test_table", + df, + columns_to_types={ + "a": exp.DataType.build("INT"), + "b": exp.DataType.build("INT"), + }, + source_columns=["a"], + ) + assert to_sql_calls(adapter) == [ + 'CREATE OR REPLACE TABLE "test_table" AS SELECT CAST("a" AS INT) AS "a", CAST("b" AS INT) AS "b" FROM (SELECT CAST("a" AS INT) AS "a", CAST(NULL 
AS INT) AS "b" FROM (VALUES (1), (2), (3)) AS "t"("a")) AS "_subquery"', + ] + + +def test_replace_query_query_source_columns(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.replace_query( + "test_table", + parse_one("SELECT a FROM tbl"), + columns_to_types={ + "a": exp.DataType.build("INT"), + "b": exp.DataType.build("INT"), + }, + source_columns=["a"], + ) + assert to_sql_calls(adapter) == [ + 'CREATE OR REPLACE TABLE "test_table" AS SELECT CAST("a" AS INT) AS "a", CAST("b" AS INT) AS "b" FROM (SELECT "a", CAST(NULL AS INT) AS "b" FROM (SELECT "a" FROM "tbl") AS "select_source_columns") AS "_subquery"', + ] + + def test_replace_query_self_referencing_not_exists_unknown( make_mocked_engine_adapter: t.Callable, mocker: MockerFixture ): @@ -2922,6 +3356,39 @@ def test_ctas_pandas(make_mocked_engine_adapter: t.Callable): ] +def test_ctas_pandas_source_columns(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(EngineAdapter) + df = pd.DataFrame({"a": [1, 2, 3]}) + adapter.ctas( + "test_table", + df, + columns_to_types={ + "a": exp.DataType.build("INT"), + "b": exp.DataType.build("INT"), + }, + source_columns=["a"], + ) + assert to_sql_calls(adapter) == [ + 'CREATE TABLE IF NOT EXISTS "test_table" AS SELECT CAST("a" AS INT) AS "a", CAST("b" AS INT) AS "b" FROM (SELECT CAST("a" AS INT) AS "a", CAST(NULL AS INT) AS "b" FROM (VALUES (1), (2), (3)) AS "t"("a")) AS "_subquery"', + ] + + +def test_ctas_query_source_columns(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.ctas( + "test_table", + parse_one("SELECT a FROM tbl"), + columns_to_types={ + "a": exp.DataType.build("INT"), + "b": exp.DataType.build("INT"), + }, + source_columns=["a"], + ) + assert to_sql_calls(adapter) == [ + 'CREATE TABLE IF NOT EXISTS "test_table" AS SELECT CAST("a" AS INT) AS "a", CAST("b" AS INT) AS "b" FROM (SELECT "a", CAST(NULL AS INT) AS "b" FROM (SELECT 
"a" FROM "tbl") AS "select_source_columns") AS "_subquery"', + ] + + def test_drop_view(make_mocked_engine_adapter: t.Callable): adapter = make_mocked_engine_adapter(EngineAdapter) @@ -3142,3 +3609,71 @@ def test_log_sql(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): mock_logger.log.call_args_list[4][0][2] == 'CREATE OR REPLACE TABLE "test" AS SELECT CAST("id" AS BIGINT) AS "id", CAST("value" AS TEXT) AS "value" FROM (SELECT CAST("id" AS BIGINT) AS "id", CAST("value" AS TEXT) AS "value" FROM (VALUES "") AS "t"("id", "value")) AS "_subquery"' ) + + +@pytest.mark.parametrize( + "columns, source_columns, expected", + [ + (["a", "b"], None, 'SELECT "a", "b"'), + (["a", "b"], ["a"], 'SELECT "a", NULL AS "b"'), + (["a", "b"], ["a", "b"], 'SELECT "a", "b"'), + (["a", "b"], ["c", "d"], 'SELECT NULL AS "a", NULL AS "b"'), + (["a", "b"], [], 'SELECT "a", "b"'), + ], +) +def test_select_columns( + columns: t.List[str], source_columns: t.Optional[t.List[str]], expected: str +) -> None: + assert ( + EngineAdapter._select_columns( + columns, + source_columns, + ).sql() + == expected + ) + + +@pytest.mark.parametrize( + "columns_to_types, source_columns, expected", + [ + ( + { + "a": exp.DataType.build("INT"), + "b": exp.DataType.build("TEXT"), + }, + None, + [ + 'CAST("a" AS INT) AS "a"', + 'CAST("b" AS TEXT) AS "b"', + ], + ), + ( + { + "a": exp.DataType.build("INT"), + "b": exp.DataType.build("TEXT"), + }, + ["a"], + [ + 'CAST("a" AS INT) AS "a"', + 'CAST(NULL AS TEXT) AS "b"', + ], + ), + ( + { + "a": exp.DataType.build("INT"), + "b": exp.DataType.build("TEXT"), + }, + ["b", "c"], + [ + 'CAST(NULL AS INT) AS "a"', + 'CAST("b" AS TEXT) AS "b"', + ], + ), + ], +) +def test_casted_columns( + columns_to_types: t.Dict[str, exp.DataType], source_columns: t.List[str], expected: t.List[str] +) -> None: + assert [ + x.sql() for x in EngineAdapter._casted_columns(columns_to_types, source_columns) + ] == expected diff --git a/tests/core/engine_adapter/test_bigquery.py 
b/tests/core/engine_adapter/test_bigquery.py index 32377ac1de..79d6fbf9db 100644 --- a/tests/core/engine_adapter/test_bigquery.py +++ b/tests/core/engine_adapter/test_bigquery.py @@ -205,7 +205,7 @@ def temp_table_exists(table: exp.Table) -> bool: assert load_temp_table.kwargs["job_config"].write_disposition is None assert ( merge_sql.sql(dialect="bigquery") - == "MERGE INTO test_table AS __MERGE_TARGET__ USING (SELECT `a`, `ds` FROM (SELECT `a`, `ds` FROM project.dataset.temp_table) AS _subquery WHERE ds BETWEEN '2022-01-01' AND '2022-01-05') AS __MERGE_SOURCE__ ON FALSE WHEN NOT MATCHED BY SOURCE AND ds BETWEEN '2022-01-01' AND '2022-01-05' THEN DELETE WHEN NOT MATCHED THEN INSERT (a, ds) VALUES (a, ds)" + == "MERGE INTO test_table AS __MERGE_TARGET__ USING (SELECT `a`, `ds` FROM (SELECT CAST(`a` AS INT64) AS `a`, CAST(`ds` AS STRING) AS `ds` FROM project.dataset.temp_table) AS _subquery WHERE ds BETWEEN '2022-01-01' AND '2022-01-05') AS __MERGE_SOURCE__ ON FALSE WHEN NOT MATCHED BY SOURCE AND ds BETWEEN '2022-01-01' AND '2022-01-05' THEN DELETE WHEN NOT MATCHED THEN INSERT (a, ds) VALUES (a, ds)" ) assert ( drop_temp_table_sql.sql(dialect="bigquery") @@ -295,7 +295,7 @@ def temp_table_exists(table: exp.Table) -> bool: ] sql_calls = _to_sql_calls(execute_mock) assert sql_calls == [ - "CREATE OR REPLACE TABLE `test_table` AS SELECT CAST(`a` AS INT64) AS `a`, CAST(`b` AS INT64) AS `b` FROM (SELECT `a`, `b` FROM `project`.`dataset`.`temp_table`) AS `_subquery`", + "CREATE OR REPLACE TABLE `test_table` AS SELECT CAST(`a` AS INT64) AS `a`, CAST(`b` AS INT64) AS `b` FROM (SELECT CAST(`a` AS INT64) AS `a`, CAST(`b` AS INT64) AS `b` FROM `project`.`dataset`.`temp_table`) AS `_subquery`", "DROP TABLE IF EXISTS `project`.`dataset`.`temp_table`", ] @@ -498,7 +498,7 @@ def temp_table_exists(table: exp.Table) -> bool: sql_calls = _to_sql_calls(execute_mock, identify=False) assert sql_calls == [ - "MERGE INTO target AS __MERGE_TARGET__ USING (SELECT `id`, `ts`, `val` FROM 
project.dataset.temp_table) AS __MERGE_SOURCE__ ON __MERGE_TARGET__.id = __MERGE_SOURCE__.id " + "MERGE INTO target AS __MERGE_TARGET__ USING (SELECT CAST(`id` AS INT64) AS `id`, CAST(`ts` AS DATETIME) AS `ts`, CAST(`val` AS INT64) AS `val` FROM project.dataset.temp_table) AS __MERGE_SOURCE__ ON __MERGE_TARGET__.id = __MERGE_SOURCE__.id " "WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.id = __MERGE_SOURCE__.id, __MERGE_TARGET__.ts = __MERGE_SOURCE__.ts, __MERGE_TARGET__.val = __MERGE_SOURCE__.val " "WHEN NOT MATCHED THEN INSERT (id, ts, val) VALUES (__MERGE_SOURCE__.id, __MERGE_SOURCE__.ts, __MERGE_SOURCE__.val)", "DROP TABLE IF EXISTS project.dataset.temp_table", diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index 7248b2a724..b21f77da39 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -2059,12 +2059,13 @@ def insert( query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: nonlocal custom_insert_called custom_insert_called = True - self._replace_query_for_model(model, table_name, query_or_df) + self._replace_query_for_model(model, table_name, query_or_df, render_kwargs) model = context.get_model("sushi.top_waiters") kwargs = { @@ -2104,6 +2105,7 @@ def insert( query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: assert isinstance(model.kind, TestCustomKind) @@ -2111,7 +2113,7 @@ def insert( nonlocal custom_insert_calls custom_insert_calls.append(model.kind.custom_property) - self._replace_query_for_model(model, table_name, query_or_df) + self._replace_query_for_model(model, table_name, query_or_df, render_kwargs) model = context.get_model("sushi.top_waiters") kwargs = { @@ -7625,3 +7627,926 @@ def test_default_audits_with_custom_audit_definitions(tmp_path: Path): if audit_name == "positive_amount": assert "column" in audit_args assert audit_args["column"].name 
== "amount" + + +def test_incremental_by_time_model_ignore_destructive_change(tmp_path: Path): + models_dir = tmp_path / "models" + models_dir.mkdir() + data_dir = tmp_path / "data" + data_dir.mkdir() + data_filepath = data_dir / "test.duckdb" + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_connection=DuckDBConnectionConfig(database=str(data_filepath)), + ) + + # Initial model with 3 columns + initial_model = f""" + MODEL ( + name test_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + forward_only true, + on_destructive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 1 as id, + 'test_name' as name, + @start_ds as ds + FROM + source_table; + """ + + # Write initial model + (models_dir / "test_model.sql").write_text(initial_model) + + with time_machine.travel("2023-01-08 00:00:00 UTC"): + # Create context and apply initial model + context = Context(paths=[tmp_path], config=config) + context.engine_adapter.execute("CREATE TABLE source_table (source_id INT)") + context.engine_adapter.execute("INSERT INTO source_table VALUES (1)") + + # Apply initial plan and load data + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify initial data was loaded + initial_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(initial_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in initial_df.columns + assert "name" in initial_df.columns + assert "ds" in initial_df.columns + + context.close() + + # remove `name` column and add new column + initial_model = """ + MODEL ( + name test_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + forward_only true, + on_destructive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 2 as id, + 3 as new_column, + @start_ds as ds + FROM + source_table; + """ + (models_dir / "test_model.sql").write_text(initial_model) + + context = Context(paths=[tmp_path], config=config) + 
context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify data loading continued to work + # The existing data should still be there and new data should be loaded + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + + assert len(updated_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + + context.close() + + with time_machine.travel("2023-01-10 00:00:00 UTC"): + context = Context(paths=[tmp_path], config=config) + context.run() + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(updated_df) == 2 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + assert updated_df["new_column"].dropna().tolist() == [3] + + with time_machine.travel("2023-01-11 00:00:00 UTC"): + updated_model = """ + MODEL ( + name test_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + forward_only true, + on_destructive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 2 as id, + CAST(4 AS STRING) as new_column, + @start_ds as ds + FROM + source_table; + """ + (models_dir / "test_model.sql").write_text(updated_model) + + context = Context(paths=[tmp_path], config=config) + context.plan("prod", auto_apply=True, no_prompts=True, run=True) + + # Verify data loading continued to work + # The existing data should still be there and new data should be loaded + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + + assert len(updated_df) == 3 + assert "source_id" 
in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + # The destructive change was ignored but this change is coercable and therefore we still return ints + assert updated_df["new_column"].dropna().tolist() == [3, 4] + + with time_machine.travel("2023-01-12 00:00:00 UTC"): + updated_model = """ + MODEL ( + name test_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + forward_only true, + on_destructive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 2 as id, + CAST(5 AS STRING) as new_column, + @start_ds as ds + FROM + source_table; + """ + (models_dir / "test_model.sql").write_text(updated_model) + + context = Context(paths=[tmp_path], config=config) + # Make the change compatible since that means we will attempt and alter now that is considered additive + context.engine_adapter.SCHEMA_DIFFER.compatible_types = { + exp.DataType.build("INT"): {exp.DataType.build("STRING")} + } + context.plan("prod", auto_apply=True, no_prompts=True, run=True) + + # Verify data loading continued to work + # The existing data should still be there and new data should be loaded + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + + assert len(updated_df) == 4 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + # The change is now reflected since an additive alter could be performed + assert updated_df["new_column"].dropna().tolist() == ["3", "4", "5"] + + context.close() + + +def 
test_incremental_by_unique_key_model_ignore_destructive_change(tmp_path: Path): + models_dir = tmp_path / "models" + models_dir.mkdir() + data_dir = tmp_path / "data" + data_dir.mkdir() + data_filepath = data_dir / "test.duckdb" + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_connection=DuckDBConnectionConfig(database=str(data_filepath)), + ) + + # Initial model with 3 columns + initial_model = f""" + MODEL ( + name test_model, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key id, + forward_only true, + on_destructive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 1 as id, + 'test_name' as name, + @start_ds as ds + FROM + source_table; + """ + + # Write initial model + (models_dir / "test_model.sql").write_text(initial_model) + + with time_machine.travel("2023-01-08 00:00:00 UTC"): + # Create context and apply initial model + context = Context(paths=[tmp_path], config=config) + context.engine_adapter.execute("CREATE TABLE source_table (source_id INT)") + context.engine_adapter.execute("INSERT INTO source_table VALUES (1)") + + # Apply initial plan and load data + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify initial data was loaded + initial_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(initial_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in initial_df.columns + assert "name" in initial_df.columns + assert "ds" in initial_df.columns + + context.close() + + # remove `name` column and add new column + initial_model = """ + MODEL ( + name test_model, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key id, + forward_only true, + on_destructive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 2 as id, + 3 as new_column, + @start_ds as ds + FROM + source_table; + """ + (models_dir / "test_model.sql").write_text(initial_model) + + context = Context(paths=[tmp_path], config=config) + context.plan("prod", 
auto_apply=True, no_prompts=True) + + # Verify data loading continued to work + # The existing data should still be there and new data should be loaded + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + + assert len(updated_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + + context.close() + + with time_machine.travel("2023-01-10 00:00:00 UTC"): + context = Context(paths=[tmp_path], config=config) + context.run() + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(updated_df) == 2 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + + context.close() + + +def test_incremental_unmanaged_model_ignore_destructive_change(tmp_path: Path): + models_dir = tmp_path / "models" + models_dir.mkdir() + data_dir = tmp_path / "data" + data_dir.mkdir() + data_filepath = data_dir / "test.duckdb" + set_console(TerminalConsole()) + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_connection=DuckDBConnectionConfig(database=str(data_filepath)), + ) + + # Initial model with 3 columns + initial_model = f""" + MODEL ( + name test_model, + kind INCREMENTAL_UNMANAGED( + on_destructive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 1 as id, + 'test_name' as name, + @start_ds as ds + FROM + source_table; + """ + + # Write initial model + (models_dir / "test_model.sql").write_text(initial_model) + + with time_machine.travel("2023-01-08 00:00:00 
UTC"): + # Create context and apply initial model + context = Context(paths=[tmp_path], config=config) + context.engine_adapter.execute("CREATE TABLE source_table (source_id INT)") + context.engine_adapter.execute("INSERT INTO source_table VALUES (1)") + + # Apply initial plan and load data + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify initial data was loaded + initial_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(initial_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in initial_df.columns + assert "name" in initial_df.columns + assert "ds" in initial_df.columns + + context.close() + + # remove `name` column and add new column + initial_model = """ + MODEL ( + name test_model, + kind INCREMENTAL_UNMANAGED( + on_destructive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 2 as id, + 3 as new_column, + @start_ds as ds + FROM + source_table; + """ + (models_dir / "test_model.sql").write_text(initial_model) + + context = Context(paths=[tmp_path], config=config) + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify data loading continued to work + # The existing data should still be there and new data should be loaded + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + + assert len(updated_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + + context.close() + + with time_machine.travel("2023-01-10 00:00:00 UTC"): + context = Context(paths=[tmp_path], config=config) + context.run() + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(updated_df) == 2 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert 
"ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + + context.close() + + +def test_scd_type_2_by_time_ignore_destructive_change(tmp_path: Path): + models_dir = tmp_path / "models" + models_dir.mkdir() + data_dir = tmp_path / "data" + data_dir.mkdir() + data_filepath = data_dir / "test.duckdb" + set_console(TerminalConsole()) + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_connection=DuckDBConnectionConfig(database=str(data_filepath)), + ) + + # Initial model with 3 columns + initial_model = f""" + MODEL ( + name test_model, + kind SCD_TYPE_2_BY_TIME ( + unique_key id, + updated_at_name ds, + on_destructive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 1 as id, + 'test_name' as name, + @start_dt as ds + FROM + source_table; + """ + + # Write initial model + (models_dir / "test_model.sql").write_text(initial_model) + + with time_machine.travel("2023-01-08 00:00:00 UTC"): + # Create context and apply initial model + context = Context(paths=[tmp_path], config=config) + context.engine_adapter.execute("CREATE TABLE source_table (source_id INT)") + context.engine_adapter.execute("INSERT INTO source_table VALUES (1)") + + # Apply initial plan and load data + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify initial data was loaded + initial_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(initial_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in initial_df.columns + assert "name" in initial_df.columns + assert "ds" in initial_df.columns + + context.close() + + # remove `name` column and add new column + initial_model = """ + MODEL ( + name test_model, + kind SCD_TYPE_2_BY_TIME ( + unique_key id, + updated_at_name ds, + on_destructive_change ignore + ), + start 
'2023-01-01', + cron '@daily' + ); + + SELECT + *, + 1 as id, + 3 as new_column, + @start_dt as ds + FROM + source_table; + """ + (models_dir / "test_model.sql").write_text(initial_model) + + context = Context(paths=[tmp_path], config=config) + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify data loading continued to work + # The existing data should still be there and new data should be loaded + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + + assert len(updated_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + + context.close() + + with time_machine.travel("2023-01-10 00:00:00 UTC"): + context = Context(paths=[tmp_path], config=config) + context.run() + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(updated_df) == 2 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + + context.close() + + +def test_scd_type_2_by_column_ignore_destructive_change(tmp_path: Path): + models_dir = tmp_path / "models" + models_dir.mkdir() + data_dir = tmp_path / "data" + data_dir.mkdir() + data_filepath = data_dir / "test.duckdb" + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_connection=DuckDBConnectionConfig(database=str(data_filepath)), + ) + + # Initial model with 3 columns + initial_model = f""" + MODEL ( + name test_model, + kind SCD_TYPE_2_BY_COLUMN ( + unique_key id, + columns [name], + on_destructive_change ignore + ), + start 
'2023-01-01', + cron '@daily' + ); + + SELECT + *, + 1 as id, + 'test_name' as name, + @start_ds as ds + FROM + source_table; + """ + + # Write initial model + (models_dir / "test_model.sql").write_text(initial_model) + + with time_machine.travel("2023-01-08 00:00:00 UTC"): + # Create context and apply initial model + context = Context(paths=[tmp_path], config=config) + context.engine_adapter.execute("CREATE TABLE source_table (source_id INT)") + context.engine_adapter.execute("INSERT INTO source_table VALUES (1)") + + # Apply initial plan and load data + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify initial data was loaded + initial_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(initial_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in initial_df.columns + assert "name" in initial_df.columns + assert "ds" in initial_df.columns + + context.close() + + # remove `name` column and add new column + initial_model = """ + MODEL ( + name test_model, + kind SCD_TYPE_2_BY_COLUMN ( + unique_key id, + columns [new_column], + on_destructive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 1 as id, + 3 as new_column, + @start_ds as ds + FROM + source_table; + """ + (models_dir / "test_model.sql").write_text(initial_model) + + context = Context(paths=[tmp_path], config=config) + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify data loading continued to work + # The existing data should still be there and new data should be loaded + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + + assert len(updated_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + + context.close() + + 
with time_machine.travel("2023-01-10 00:00:00 UTC"): + context = Context(paths=[tmp_path], config=config) + context.run() + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(updated_df) == 2 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + + context.close() + + +def test_incremental_partition_ignore_destructive_change(tmp_path: Path): + models_dir = tmp_path / "models" + models_dir.mkdir() + data_dir = tmp_path / "data" + data_dir.mkdir() + data_filepath = data_dir / "test.duckdb" + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_connection=DuckDBConnectionConfig(database=str(data_filepath)), + ) + + # Initial model with 3 columns + initial_model = f""" + MODEL ( + name test_model, + kind INCREMENTAL_BY_PARTITION ( + on_destructive_change ignore + ), + partitioned_by [ds], + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 1 as id, + 'test_name' as name, + @start_ds as ds + FROM + source_table; + """ + + # Write initial model + (models_dir / "test_model.sql").write_text(initial_model) + + with time_machine.travel("2023-01-08 00:00:00 UTC"): + # Create context and apply initial model + context = Context(paths=[tmp_path], config=config) + context.engine_adapter.execute("CREATE TABLE source_table (source_id INT)") + context.engine_adapter.execute("INSERT INTO source_table VALUES (1)") + + # Apply initial plan and load data + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify initial data was loaded + initial_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(initial_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in initial_df.columns + assert "name" in initial_df.columns + 
assert "ds" in initial_df.columns + + context.close() + + # remove `name` column and add new column + initial_model = """ + MODEL ( + name test_model, + kind INCREMENTAL_BY_PARTITION ( + on_destructive_change ignore + ), + partitioned_by [ds], + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 1 as id, + 3 as new_column, + @start_ds as ds + FROM + source_table; + """ + (models_dir / "test_model.sql").write_text(initial_model) + + context = Context(paths=[tmp_path], config=config) + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify data loading continued to work + # The existing data should still be there and new data should be loaded + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + + assert len(updated_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + + context.close() + + with time_machine.travel("2023-01-10 00:00:00 UTC"): + context = Context(paths=[tmp_path], config=config) + context.run() + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(updated_df) == 2 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + + context.close() + + +def test_incremental_by_time_model_ignore_destructive_change_unit_test(tmp_path: Path): + models_dir = tmp_path / "models" + models_dir.mkdir() + data_dir = tmp_path / "data" + data_dir.mkdir() + data_filepath = data_dir / "test.duckdb" + test_dir = tmp_path / "tests" + test_dir.mkdir() + test_filepath = test_dir / 
"test_test_model.yaml" + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_connection=DuckDBConnectionConfig(database=str(data_filepath)), + ) + + # Initial model with 3 columns + initial_model = f""" + MODEL ( + name test_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + forward_only true, + on_destructive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + id, + name, + ds + FROM + source_table; + """ + + # Write initial model + (models_dir / "test_model.sql").write_text(initial_model) + + initial_test = f""" + +test_test_model: + model: test_model + inputs: + source_table: + - id: 1 + name: 'test_name' + ds: '2025-01-01' + outputs: + query: + - id: 1 + name: 'test_name' + ds: '2025-01-01' +""" + + # Write initial test + test_filepath.write_text(initial_test) + + with time_machine.travel("2023-01-08 00:00:00 UTC"): + # Create context and apply initial model + context = Context(paths=[tmp_path], config=config) + context.engine_adapter.execute( + "CREATE TABLE source_table (id INT, name STRING, new_column INT, ds STRING)" + ) + context.engine_adapter.execute( + "INSERT INTO source_table VALUES (1, 'test_name', NULL, '2023-01-01')" + ) + + # Apply initial plan and load data + context.plan("prod", auto_apply=True, no_prompts=True, skip_tests=True) + test_result = context.test() + + # Verify initial data was loaded + initial_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(initial_df) == 1 + assert "id" in initial_df.columns + assert "name" in initial_df.columns + assert "ds" in initial_df.columns + assert len(test_result.successes) == 1 + assert test_result.testsRun == len(test_result.successes) + + context.close() + + # remove `name` column and add new column + initial_model = """ + MODEL ( + name test_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + forward_only true, + on_destructive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + 
id, + new_column, + ds + FROM + source_table; + """ + (models_dir / "test_model.sql").write_text(initial_model) + + updated_test = f""" + + test_test_model: + model: test_model + inputs: + source_table: + - id: 1 + new_column: 3 + ds: '2025-01-01' + outputs: + query: + - id: 1 + new_column: 3 + ds: '2025-01-01' + """ + + # Write initial test + test_filepath.write_text(updated_test) + + context = Context(paths=[tmp_path], config=config) + context.plan("prod", auto_apply=True, no_prompts=True, skip_tests=True) + test_result = context.test() + + # Verify data loading continued to work + # The existing data should still be there and new data should be loaded + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(updated_df) == 1 + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + assert len(test_result.successes) == 1 + assert test_result.testsRun == len(test_result.successes) + + context.close() + + with time_machine.travel("2023-01-10 00:00:00 UTC"): + context = Context(paths=[tmp_path], config=config) + context.engine_adapter.execute("INSERT INTO source_table VALUES (2, NULL, 3, '2023-01-09')") + context.run() + test_result = context.test() + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(updated_df) == 2 + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + assert len(test_result.successes) == 1 + assert test_result.testsRun == len(test_result.successes) + + context.close() diff --git a/tests/core/test_schema_diff.py b/tests/core/test_schema_diff.py index 
1e57cab57c..fd14b0b9b3 100644 --- a/tests/core/test_schema_diff.py +++ b/tests/core/test_schema_diff.py @@ -1331,6 +1331,201 @@ def test_schema_diff_alter_op_column(): ) +@pytest.mark.parametrize( + "current_struct, new_struct, expected_diff_with_destructive, expected_diff_ignore_destructive, config", + [ + # Simple DROP operation - should be ignored when ignore_destructive=True + ( + "STRUCT", + "STRUCT", + [ + TableAlterOperation.drop( + TableAlterColumn.primitive("name"), + "STRUCT", + "STRING", + ) + ], + [], # No operations when ignoring destructive + {}, + ), + # DROP + ADD operation (incompatible type change) - should be ignored when ignore_destructive=True + ( + "STRUCT", + "STRUCT", + [ + TableAlterOperation.drop( + TableAlterColumn.primitive("name"), + "STRUCT", + "STRING", + ), + TableAlterOperation.add( + TableAlterColumn.primitive("name"), + "BIGINT", + "STRUCT", + ), + ], + [], # No operations when ignoring destructive + {}, + ), + # Pure ADD operation - should work same way regardless of ignore_destructive + ( + "STRUCT", + "STRUCT", + [ + TableAlterOperation.add( + TableAlterColumn.primitive("new_col"), + "STRING", + "STRUCT", + ), + ], + [ + # Same operation when ignoring destructive + TableAlterOperation.add( + TableAlterColumn.primitive("new_col"), + "STRING", + "STRUCT", + ), + ], + {}, + ), + # Mix of destructive and non-destructive operations + ( + "STRUCT", + "STRUCT", + [ + TableAlterOperation.drop( + TableAlterColumn.primitive("name"), + "STRUCT", + "STRING", + ), + TableAlterOperation.add( + TableAlterColumn.primitive("address"), + "STRING", + "STRUCT", + ), + TableAlterOperation.alter_type( + TableAlterColumn.primitive("id"), + "STRING", + current_type="INT", + expected_table_struct="STRUCT", + ), + ], + [ + # Only non-destructive operations remain + TableAlterOperation.add( + TableAlterColumn.primitive("address"), + "STRING", + "STRUCT", + ), + TableAlterOperation.alter_type( + TableAlterColumn.primitive("id"), + "STRING", + 
current_type="INT", + expected_table_struct="STRUCT", + ), + ], + dict( + compatible_types={ + exp.DataType.build("INT"): {exp.DataType.build("STRING")}, + } + ), + ), + ], +) +def test_ignore_destructive_operations( + current_struct, + new_struct, + expected_diff_with_destructive: t.List[TableAlterOperation], + expected_diff_ignore_destructive: t.List[TableAlterOperation], + config: t.Dict[str, t.Any], +): + resolver = SchemaDiffer(**config) + + # Test with destructive operations allowed (default behavior) + operations_with_destructive = resolver._from_structs( + exp.DataType.build(current_struct), exp.DataType.build(new_struct), ignore_destructive=False + ) + assert operations_with_destructive == expected_diff_with_destructive + + # Test with destructive operations ignored + operations_ignore_destructive = resolver._from_structs( + exp.DataType.build(current_struct), exp.DataType.build(new_struct), ignore_destructive=True + ) + assert operations_ignore_destructive == expected_diff_ignore_destructive + + +def test_ignore_destructive_compare_columns(): + """Test ignore_destructive behavior in compare_columns method.""" + schema_differ = SchemaDiffer( + support_positional_add=True, + support_nested_operations=False, + compatible_types={ + exp.DataType.build("INT"): {exp.DataType.build("STRING")}, + }, + ) + + current = { + "id": exp.DataType.build("INT"), + "name": exp.DataType.build("STRING"), + "to_drop": exp.DataType.build("DOUBLE"), + "age": exp.DataType.build("INT"), + } + + new = { + "id": exp.DataType.build("STRING"), # Compatible type change + "name": exp.DataType.build("STRING"), + "age": exp.DataType.build("INT"), + "new_col": exp.DataType.build("DOUBLE"), # New column + } + + # With destructive operations allowed + alter_expressions_with_destructive = schema_differ.compare_columns( + "test_table", current, new, ignore_destructive=False + ) + assert len(alter_expressions_with_destructive) == 3 # DROP + ADD + ALTER + + # With destructive operations ignored 
+ alter_expressions_ignore_destructive = schema_differ.compare_columns( + "test_table", current, new, ignore_destructive=True + ) + assert len(alter_expressions_ignore_destructive) == 2 # Only ADD + ALTER + + # Verify the operations are correct + operations_sql = [expr.sql() for expr in alter_expressions_ignore_destructive] + add_column_found = any("ADD COLUMN new_col DOUBLE" in op for op in operations_sql) + alter_column_found = any("ALTER COLUMN id SET DATA TYPE" in op for op in operations_sql) + drop_column_found = any("DROP COLUMN to_drop" in op for op in operations_sql) + + assert add_column_found, f"ADD COLUMN not found in: {operations_sql}" + assert alter_column_found, f"ALTER COLUMN not found in: {operations_sql}" + assert not drop_column_found, f"DROP COLUMN should not be present in: {operations_sql}" + + +def test_ignore_destructive_nested_struct_without_support(): + """Test ignore_destructive with nested structs when nested_drop is not supported.""" + schema_differ = SchemaDiffer( + support_nested_operations=True, + support_nested_drop=False, # This forces DROP+ADD for nested changes + ) + + current_struct = "STRUCT>" + new_struct = "STRUCT>" # Removes col_b + + # With destructive operations allowed - should do DROP+ADD of entire struct + operations_with_destructive = schema_differ._from_structs( + exp.DataType.build(current_struct), exp.DataType.build(new_struct), ignore_destructive=False + ) + assert len(operations_with_destructive) == 2 # DROP struct + ADD struct + assert operations_with_destructive[0].is_drop + assert operations_with_destructive[1].is_add + + # With destructive operations ignored - should do nothing + operations_ignore_destructive = schema_differ._from_structs( + exp.DataType.build(current_struct), exp.DataType.build(new_struct), ignore_destructive=True + ) + assert len(operations_ignore_destructive) == 0 + + def test_get_schema_differ(): # Test that known dialects return SchemaDiffer instances for dialect in ["bigquery", 
"snowflake", "postgres", "databricks", "spark", "duckdb"]: @@ -1376,3 +1571,43 @@ def test_get_schema_differ(): schema_differ_upper.support_coercing_compatible_types == schema_differ_lower.support_coercing_compatible_types ) + + +def test_ignore_destructive_edge_cases(): + """Test edge cases for ignore_destructive behavior.""" + schema_differ = SchemaDiffer(support_positional_add=True) + + # Test when all operations are destructive - should result in empty list + current_struct = "STRUCT" + new_struct = "STRUCT<>" # Remove all columns + + operations_ignore_destructive = schema_differ._from_structs( + exp.DataType.build(current_struct), exp.DataType.build(new_struct), ignore_destructive=True + ) + assert len(operations_ignore_destructive) == 0 + + # Test when no operations are needed - should result in empty list regardless of ignore_destructive + same_struct = "STRUCT" + + operations_same_with_destructive = schema_differ._from_structs( + exp.DataType.build(same_struct), exp.DataType.build(same_struct), ignore_destructive=False + ) + operations_same_ignore_destructive = schema_differ._from_structs( + exp.DataType.build(same_struct), exp.DataType.build(same_struct), ignore_destructive=True + ) + assert len(operations_same_with_destructive) == 0 + assert len(operations_same_ignore_destructive) == 0 + + # Test when only ADD operations are needed - should be same regardless of ignore_destructive + current_struct = "STRUCT" + new_struct = "STRUCT" + + operations_add_with_destructive = schema_differ._from_structs( + exp.DataType.build(current_struct), exp.DataType.build(new_struct), ignore_destructive=False + ) + operations_add_ignore_destructive = schema_differ._from_structs( + exp.DataType.build(current_struct), exp.DataType.build(new_struct), ignore_destructive=True + ) + assert len(operations_add_with_destructive) == 2 # ADD name, ADD age + assert len(operations_add_ignore_destructive) == 2 # Same operations + assert operations_add_with_destructive == 
operations_add_ignore_destructive diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index d474159b7c..6eae6376f1 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -677,6 +677,8 @@ def test_evaluate_incremental_unmanaged_with_intervals( snapshot.categorize_as(SnapshotChangeCategory.BREAKING) snapshot.intervals = [(to_timestamp("2020-01-01"), to_timestamp("2020-01-02"))] + adapter_mock.columns.return_value = model.columns_to_types + evaluator = SnapshotEvaluator(adapter_mock) evaluator.evaluate( snapshot, @@ -692,12 +694,14 @@ def test_evaluate_incremental_unmanaged_with_intervals( model.render_query(), [exp.to_column("ds", quoted=True)], columns_to_types=model.columns_to_types, + source_columns=None, ) else: adapter_mock.insert_append.assert_called_once_with( snapshot.table_name(), model.render_query(), columns_to_types=model.columns_to_types, + source_columns=None, ) @@ -738,6 +742,7 @@ def test_evaluate_incremental_unmanaged_no_intervals( storage_format=None, table_description=None, table_properties={}, + source_columns=None, ) adapter_mock.columns.assert_called_once_with(snapshot.table_name()) @@ -1627,6 +1632,7 @@ def test_create_clone_in_dev(mocker: MockerFixture, adapter_mock, make_snapshot) adapter_mock.get_alter_expressions.assert_called_once_with( f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev", f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev__schema_migration_source", + ignore_destructive=False, ) adapter_mock.alter_table.assert_called_once_with([]) @@ -1727,6 +1733,7 @@ def test_drop_clone_in_dev_when_migration_fails(mocker: MockerFixture, adapter_m adapter_mock.get_alter_expressions.assert_called_once_with( f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev", f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev__schema_migration_source", + ignore_destructive=False, ) 
adapter_mock.alter_table.assert_called_once_with([]) @@ -2105,6 +2112,7 @@ def test_insert_into_scd_type_2_by_time( column_descriptions={}, updated_at_as_valid_from=False, truncate=truncate, + source_columns=None, ) adapter_mock.columns.assert_called_once_with(snapshot.table_name()) @@ -2277,6 +2285,7 @@ def test_insert_into_scd_type_2_by_column( table_description=None, column_descriptions={}, truncate=truncate, + source_columns=None, ) adapter_mock.columns.assert_called_once_with(snapshot.table_name()) @@ -2345,6 +2354,7 @@ def test_create_incremental_by_unique_key_updated_at_exp(adapter_mock, make_snap ] ), physical_properties={}, + source_columns=None, ) @@ -2434,6 +2444,7 @@ def test_create_incremental_by_unique_key_multiple_updated_at_exp(adapter_mock, ], ), physical_properties={}, + source_columns=None, ) @@ -2484,6 +2495,7 @@ def test_create_incremental_by_unique_no_intervals(adapter_mock, make_snapshot): storage_format=None, table_description=None, table_properties={}, + source_columns=None, ) adapter_mock.columns.assert_called_once_with(snapshot.table_name()) @@ -2582,6 +2594,7 @@ def test_create_incremental_by_unique_key_merge_filter(adapter_mock, make_snapsh ), ), physical_properties={}, + source_columns=None, ) @@ -2621,6 +2634,7 @@ def test_create_seed(mocker: MockerFixture, adapter_mock, make_snapshot): f"sqlmesh__db.db__seed__{snapshot.version}", mocker.ANY, column_descriptions={}, + source_columns=["id", "name"], **common_create_kwargs, ) @@ -2692,6 +2706,7 @@ def test_create_seed_on_error(mocker: MockerFixture, adapter_mock, make_snapshot clustered_by=[], table_properties={}, table_description=None, + source_columns=["id", "name"], ) adapter_mock.drop_table.assert_called_once_with(f"sqlmesh__db.db__seed__{snapshot.version}") @@ -2748,6 +2763,7 @@ def test_create_seed_no_intervals(mocker: MockerFixture, adapter_mock, make_snap clustered_by=[], table_properties={}, table_description=None, + source_columns=["id", "name"], ) @@ -3253,6 +3269,7 @@ def 
test_evaluate_incremental_by_partition(mocker: MockerFixture, make_snapshot, storage_format=None, table_description=None, table_format=None, + source_columns=None, ) adapter_mock.reset_mock() @@ -3275,6 +3292,7 @@ def test_evaluate_incremental_by_partition(mocker: MockerFixture, make_snapshot, exp.to_column("b", quoted=True), ], columns_to_types=model.columns_to_types, + source_columns=None, ) @@ -3291,6 +3309,7 @@ def insert( query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: nonlocal custom_insert_kind @@ -3365,6 +3384,7 @@ def insert( query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: nonlocal custom_insert_kind @@ -3583,6 +3603,7 @@ def test_evaluate_managed(adapter_mock, make_snapshot, mocker: MockerFixture): table_properties=model.physical_properties, table_description=model.description, column_descriptions=model.column_descriptions, + source_columns=None, ) adapter_mock.columns.assert_called_once_with(snapshot.table_name(is_deployable=False)) @@ -3851,6 +3872,7 @@ def test_migrate_snapshot(snapshot: Snapshot, mocker: MockerFixture, adapter_moc adapter_mock.get_alter_expressions.assert_called_once_with( snapshot.table_name(), new_snapshot.table_name(is_deployable=False), + ignore_destructive=False, ) @@ -4117,7 +4139,9 @@ def columns(table_name): # The second mock adapter has to be called only for the gateway-specific model adapter_mock.get_alter_expressions.assert_called_once_with( - snapshot_2.table_name(True), snapshot_2.table_name(False) + snapshot_2.table_name(True), + snapshot_2.table_name(False), + ignore_destructive=False, ) From 25822fe62b5c255bfc9072bc8ee9feb34c290ad7 Mon Sep 17 00:00:00 2001 From: eakmanrq <6326532+eakmanrq@users.noreply.github.com> Date: Thu, 14 Aug 2025 15:18:22 -0700 Subject: [PATCH 2/4] feedback --- sqlmesh/core/dialect.py | 14 +- sqlmesh/core/engine_adapter/athena.py | 24 +- 
sqlmesh/core/engine_adapter/base.py | 359 +++++++++--------- sqlmesh/core/engine_adapter/base_postgres.py | 4 +- sqlmesh/core/engine_adapter/bigquery.py | 46 ++- sqlmesh/core/engine_adapter/clickhouse.py | 56 +-- sqlmesh/core/engine_adapter/databricks.py | 14 +- sqlmesh/core/engine_adapter/duckdb.py | 12 +- sqlmesh/core/engine_adapter/mixins.py | 26 +- sqlmesh/core/engine_adapter/mssql.py | 35 +- sqlmesh/core/engine_adapter/postgres.py | 4 +- sqlmesh/core/engine_adapter/redshift.py | 28 +- sqlmesh/core/engine_adapter/snowflake.py | 40 +- sqlmesh/core/engine_adapter/spark.py | 62 +-- sqlmesh/core/engine_adapter/trino.py | 32 +- sqlmesh/core/snapshot/evaluator.py | 59 +-- sqlmesh/core/state_sync/db/environment.py | 4 +- sqlmesh/core/state_sync/db/interval.py | 4 +- sqlmesh/core/state_sync/db/snapshot.py | 6 +- sqlmesh/core/state_sync/db/version.py | 2 +- sqlmesh/core/table_diff.py | 2 +- .../v0007_env_table_info_to_kind.py | 2 +- .../migrations/v0009_remove_pre_post_hooks.py | 2 +- .../migrations/v0011_add_model_kind_name.py | 2 +- .../v0012_update_jinja_expressions.py | 2 +- .../v0013_serde_using_model_dialects.py | 2 +- sqlmesh/migrations/v0016_fix_windows_path.py | 2 +- .../migrations/v0017_fix_windows_seed_path.py | 2 +- .../v0018_rename_snapshot_model_to_node.py | 2 +- ...ve_redundant_attributes_from_dbt_models.py | 2 +- .../migrations/v0021_fix_table_properties.py | 2 +- .../migrations/v0022_move_project_to_model.py | 2 +- ...replace_model_kind_name_enum_with_value.py | 2 +- ...x_intervals_and_missing_change_category.py | 4 +- .../v0026_remove_dialect_from_seed.py | 2 +- .../v0027_minute_interval_to_five.py | 2 +- ...029_generate_schema_types_using_dialect.py | 2 +- .../v0030_update_unrestorable_snapshots.py | 2 +- .../v0031_remove_dbt_target_fields.py | 2 +- .../migrations/v0034_add_default_catalog.py | 8 +- .../v0037_remove_dbt_is_incremental_macro.py | 2 +- .../v0038_add_expiration_ts_to_snapshot.py | 2 +- ...39_include_environment_in_plan_dag_spec.py | 2 +- 
.../v0041_remove_hash_raw_query_attribute.py | 2 +- .../v0042_trim_indirect_versions.py | 2 +- ...remove_obsolete_attributes_in_plan_dags.py | 2 +- .../migrations/v0045_move_gateway_variable.py | 2 +- .../v0048_drop_indirect_versions.py | 2 +- .../v0051_rename_column_descriptions.py | 2 +- ...used_ts_ttl_ms_unrestorable_to_snapshot.py | 2 +- .../migrations/v0056_restore_table_indexes.py | 6 +- .../migrations/v0060_move_audits_to_model.py | 2 +- sqlmesh/migrations/v0063_change_signals.py | 2 +- .../v0064_join_when_matched_strings.py | 2 +- .../v0069_update_dev_table_suffix.py | 4 +- .../v0071_add_dev_version_to_intervals.py | 4 +- ...073_remove_symbolic_disable_restatement.py | 2 +- .../migrations/v0075_remove_validate_query.py | 2 +- .../migrations/v0081_update_partitioned_by.py | 2 +- .../migrations/v0085_deterministic_repr.py | 2 +- .../v0087_normalize_blueprint_variables.py | 2 +- .../v0090_add_forward_only_column.py | 2 +- sqlmesh/utils/__init__.py | 7 +- .../engine_adapter/integration/__init__.py | 2 +- .../integration/test_integration.py | 62 +-- .../integration/test_integration_athena.py | 22 +- tests/core/engine_adapter/test_athena.py | 24 +- tests/core/engine_adapter/test_base.py | 124 +++--- tests/core/engine_adapter/test_bigquery.py | 16 +- tests/core/engine_adapter/test_clickhouse.py | 4 +- tests/core/engine_adapter/test_databricks.py | 2 +- tests/core/engine_adapter/test_mixins.py | 4 +- tests/core/engine_adapter/test_mssql.py | 31 +- tests/core/engine_adapter/test_postgres.py | 4 +- tests/core/engine_adapter/test_redshift.py | 18 +- tests/core/engine_adapter/test_snowflake.py | 27 +- tests/core/engine_adapter/test_spark.py | 16 +- tests/core/engine_adapter/test_trino.py | 10 +- tests/core/state_sync/test_state_sync.py | 8 +- tests/core/test_context.py | 2 +- tests/core/test_integration.py | 34 +- tests/core/test_snapshot_evaluator.py | 71 ++-- tests/dbt/test_adapter.py | 12 +- tests/dbt/test_integration.py | 6 +- 84 files changed, 778 insertions(+), 
659 deletions(-) diff --git a/sqlmesh/core/dialect.py b/sqlmesh/core/dialect.py index 33ec55b7a7..ed904cc4b3 100644 --- a/sqlmesh/core/dialect.py +++ b/sqlmesh/core/dialect.py @@ -1122,7 +1122,7 @@ def select_from_values( for i in range(0, num_rows, batch_size): yield select_from_values_for_batch_range( values=values, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, batch_start=i, batch_end=min(i + batch_size, num_rows), alias=alias, @@ -1131,14 +1131,14 @@ def select_from_values( def select_from_values_for_batch_range( values: t.List[t.Tuple[t.Any, ...]], - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], batch_start: int, batch_end: int, alias: str = "t", source_columns: t.Optional[t.List[str]] = None, ) -> exp.Select: - source_columns = source_columns or list(columns_to_types) - source_columns_to_types = get_source_columns_to_types(columns_to_types, source_columns) + source_columns = source_columns or list(target_columns_to_types) + source_columns_to_types = get_source_columns_to_types(target_columns_to_types, source_columns) if not values: # Ensures we don't generate an empty VALUES clause & forces a zero-row output @@ -1166,11 +1166,13 @@ def select_from_values_for_batch_range( casted_columns = [ exp.alias_( - exp.cast(exp.column(column) if column in source_columns else exp.Null(), to=kind), + exp.cast( + exp.column(column) if column in source_columns_to_types else exp.Null(), to=kind + ), column, copy=False, ) - for column, kind in columns_to_types.items() + for column, kind in target_columns_to_types.items() ] return exp.select(*casted_columns).from_(values_exp, copy=False).where(where, copy=False) diff --git a/sqlmesh/core/engine_adapter/athena.py b/sqlmesh/core/engine_adapter/athena.py index 59642b6e16..d549de3f4c 100644 --- a/sqlmesh/core/engine_adapter/athena.py +++ b/sqlmesh/core/engine_adapter/athena.py @@ -84,12 +84,12 @@ def catalog_support(self) -> CatalogSupport: def 
create_state_table( self, table_name: str, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], primary_key: t.Optional[t.Tuple[str, ...]] = None, ) -> None: self.create_table( table_name, - columns_to_types, + target_columns_to_types, primary_key=primary_key, # it's painfully slow, but it works table_format="iceberg", @@ -178,7 +178,7 @@ def _build_create_table_exp( expression: t.Optional[exp.Expression], exists: bool = True, replace: bool = False, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, table_kind: t.Optional[str] = None, partitioned_by: t.Optional[t.List[exp.Expression]] = None, @@ -198,7 +198,7 @@ def _build_create_table_exp( properties = self._build_table_properties_exp( table=table, expression=expression, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, partitioned_by=partitioned_by, table_properties=table_properties, table_description=table_description, @@ -237,7 +237,7 @@ def _build_table_properties_exp( partition_interval_unit: t.Optional[IntervalUnit] = None, clustered_by: t.Optional[t.List[exp.Expression]] = None, table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, table_kind: t.Optional[str] = None, table: t.Optional[exp.Table] = None, @@ -265,12 +265,12 @@ def _build_table_properties_exp( if partitioned_by: schema_expressions: t.List[exp.Expression] = [] - if is_hive and columns_to_types: + if is_hive and target_columns_to_types: # For Hive-style tables, you cannot include the partitioned by columns in the main set of columns # In the PARTITIONED BY expression, you also cant just include the column names, you need to include the data 
type as well # ref: https://docs.aws.amazon.com/athena/latest/ug/partitions.html for match_name, match_dtype in self._find_matching_columns( - partitioned_by, columns_to_types + partitioned_by, target_columns_to_types ): column_def = exp.ColumnDef(this=exp.to_identifier(match_name), kind=match_dtype) schema_expressions.append(column_def) @@ -431,7 +431,7 @@ def replace_query( self, table_name: TableName, query_or_df: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, source_columns: t.Optional[t.List[str]] = None, @@ -445,7 +445,7 @@ def replace_query( return super().replace_query( table_name=table, query_or_df=query_or_df, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, table_description=table_description, column_descriptions=column_descriptions, source_columns=source_columns, @@ -456,7 +456,7 @@ def _insert_overwrite_by_time_partition( self, table_name: TableName, source_queries: t.List[SourceQuery], - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], where: exp.Condition, **kwargs: t.Any, ) -> None: @@ -467,7 +467,7 @@ def _insert_overwrite_by_time_partition( if table_type == "iceberg": # Iceberg tables work as expected, we can use the default behaviour return super()._insert_overwrite_by_time_partition( - table, source_queries, columns_to_types, where, **kwargs + table, source_queries, target_columns_to_types, where, **kwargs ) # For Hive tables, we need to drop all the partitions covered by the query and delete the data from S3 @@ -477,7 +477,7 @@ def _insert_overwrite_by_time_partition( return super()._insert_overwrite_by_time_partition( table, source_queries, - columns_to_types, + target_columns_to_types, where, 
insert_overwrite_strategy_override=InsertOverwriteStrategy.INTO_IS_OVERWRITE, # since we already cleared the data **kwargs, diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index 327a6fbee7..94ffbe81d2 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -206,21 +206,23 @@ def catalog_support(self) -> CatalogSupport: @classmethod def _casted_columns( cls, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], source_columns: t.Optional[t.List[str]] = None, ) -> t.List[exp.Alias]: - source_columns = source_columns or list(columns_to_types) + source_columns_lookup = set(source_columns or target_columns_to_types) return [ exp.alias_( exp.cast( - exp.column(column, quoted=True) if column in source_columns else exp.Null(), + exp.column(column, quoted=True) + if column in source_columns_lookup + else exp.Null(), to=kind, ), column, copy=False, quoted=True, ) - for column, kind in columns_to_types.items() + for column, kind in target_columns_to_types.items() ] @property @@ -241,7 +243,7 @@ def engine_run_mode(self) -> EngineRunMode: def _get_source_queries( self, query_or_df: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]], + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]], target_table: TableName, *, batch_size: t.Optional[int] = None, @@ -253,16 +255,17 @@ def _get_source_queries( if isinstance(query_or_df, exp.Query): query_factory = lambda: query_or_df if source_columns: - if not columns_to_types: + source_columns_lookup = set(source_columns) + if not target_columns_to_types: raise SQLMeshError("columns_to_types must be set if source_columns is set") - if not set(columns_to_types).issubset(set(source_columns)): + if not set(target_columns_to_types).issubset(source_columns_lookup): select_columns = [ exp.column(c, quoted=True) - if c in source_columns - else exp.cast(exp.Null(), columns_to_types[c], 
copy=False).as_( + if c in source_columns_lookup + else exp.cast(exp.Null(), target_columns_to_types[c], copy=False).as_( c, copy=False, quoted=True ) - for c in columns_to_types + for c in target_columns_to_types ] query_factory = ( lambda: exp.Select() @@ -271,7 +274,7 @@ def _get_source_queries( ) return [SourceQuery(query_factory=query_factory)] # type: ignore - if not columns_to_types: + if not target_columns_to_types: raise SQLMeshError( "It is expected that if a DataFrame is passed in then columns_to_types is set" ) @@ -285,7 +288,7 @@ def _get_source_queries( return self._df_to_source_queries( query_or_df, - columns_to_types, + target_columns_to_types, batch_size, target_table=target_table, source_columns=source_columns, @@ -294,7 +297,7 @@ def _get_source_queries( def _df_to_source_queries( self, df: DF, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], batch_size: int, target_table: TableName, source_columns: t.Optional[t.List[str]] = None, @@ -307,7 +310,7 @@ def _df_to_source_queries( # we need to ensure that the order of the columns in columns_to_types columns matches the order of the values # they can differ if a user specifies columns() on a python model in a different order than what's in the DataFrame's emitted by that model - df = df[list(source_columns or columns_to_types)] + df = df[list(source_columns or target_columns_to_types)] values = list(df.itertuples(index=False, name=None)) return [ @@ -315,7 +318,7 @@ def _df_to_source_queries( query_factory=partial( self._values_to_sql, values=values, # type: ignore - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, batch_start=i, batch_end=min(i + batch_size, num_rows), source_columns=source_columns, @@ -327,29 +330,29 @@ def _df_to_source_queries( def _get_source_queries_and_columns_to_types( self, query_or_df: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]], + target_columns_to_types: 
t.Optional[t.Dict[str, exp.DataType]], target_table: TableName, *, batch_size: t.Optional[int] = None, source_columns: t.Optional[t.List[str]] = None, ) -> t.Tuple[t.List[SourceQuery], t.Optional[t.Dict[str, exp.DataType]]]: - columns_to_types, source_columns = self._columns_to_types( - query_or_df, columns_to_types, source_columns + target_columns_to_types, source_columns = self._columns_to_types( + query_or_df, target_columns_to_types, source_columns ) source_queries = self._get_source_queries( query_or_df, - columns_to_types, + target_columns_to_types, target_table=target_table, batch_size=batch_size, source_columns=source_columns, ) - return source_queries, columns_to_types + return source_queries, target_columns_to_types @t.overload def _columns_to_types( self, query_or_df: DF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, source_columns: t.Optional[t.List[str]] = None, ) -> t.Tuple[t.Dict[str, exp.DataType], t.List[str]]: ... @@ -357,23 +360,23 @@ def _columns_to_types( def _columns_to_types( self, query_or_df: Query, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, source_columns: t.Optional[t.List[str]] = None, ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: ... 
def _columns_to_types( self, query_or_df: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, source_columns: t.Optional[t.List[str]] = None, ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: import pandas as pd - if not columns_to_types and isinstance(query_or_df, pd.DataFrame): - columns_to_types = columns_to_types_from_df(t.cast(pd.DataFrame, query_or_df)) - if not source_columns and columns_to_types: - source_columns = list(columns_to_types) - return columns_to_types, source_columns + if not target_columns_to_types and isinstance(query_or_df, pd.DataFrame): + target_columns_to_types = columns_to_types_from_df(t.cast(pd.DataFrame, query_or_df)) + if not source_columns and target_columns_to_types: + source_columns = list(target_columns_to_types) + return target_columns_to_types, source_columns def recycle(self) -> None: """Closes all open connections and releases all allocated resources associated with any thread @@ -410,7 +413,7 @@ def replace_query( self, table_name: TableName, query_or_df: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, source_columns: t.Optional[t.List[str]] = None, @@ -423,7 +426,7 @@ def replace_query( Args: table_name: The name of the table (eg. prod.table) query_or_df: The SQL query to run or a dataframe. - columns_to_types: Only used if a dataframe is provided. A mapping between the column name and its data type. + target_columns_to_types: Only used if a dataframe is provided. A mapping between the column name and its data type. Expected to be ordered to match the order of values in the dataframe. kwargs: Optional create table properties. 
""" @@ -434,14 +437,14 @@ def replace_query( if self.drop_data_object_on_type_mismatch(target_data_object, DataObjectType.TABLE): table_exists = False - source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( + source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( query_or_df, - columns_to_types, + target_columns_to_types, target_table=target_table, source_columns=source_columns, ) query = source_queries[0].query_factory() - columns_to_types = columns_to_types or self.columns(target_table) + target_columns_to_types = target_columns_to_types or self.columns(target_table) self_referencing = any( quote_identifiers(table) == quote_identifiers(target_table) for table in query.find_all(exp.Table) @@ -450,7 +453,7 @@ def replace_query( if self_referencing: self._create_table_from_columns( target_table, - columns_to_types, + target_columns_to_types, exists=True, table_description=table_description, column_descriptions=column_descriptions, @@ -462,7 +465,7 @@ def replace_query( return self._create_table_from_source_queries( target_table, source_queries, - columns_to_types, + target_columns_to_types, replace=self.SUPPORTS_REPLACE_TABLE, table_description=table_description, column_descriptions=column_descriptions, @@ -470,9 +473,9 @@ def replace_query( ) if self_referencing: with self.temp_table( - self._select_columns(columns_to_types).from_(target_table), + self._select_columns(target_columns_to_types).from_(target_table), name=target_table, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, **kwargs, ) as temp_table: for source_query in source_queries: @@ -487,12 +490,12 @@ def replace_query( return self._insert_overwrite_by_condition( target_table, source_queries, - columns_to_types, + target_columns_to_types, ) return self._insert_overwrite_by_condition( target_table, source_queries, - columns_to_types, + target_columns_to_types, ) def create_index( @@ -554,7 +557,7 @@ def 
_pop_creatable_type_from_properties( def create_table( self, table_name: TableName, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], primary_key: t.Optional[t.Tuple[str, ...]] = None, exists: bool = True, table_description: t.Optional[str] = None, @@ -565,7 +568,7 @@ def create_table( Args: table_name: The name of the table to create. Can be fully qualified or just table name. - columns_to_types: A mapping between the column name and its data type. + target_columns_to_types: A mapping between the column name and its data type. primary_key: Determines the table primary key. exists: Indicates whether to include the IF NOT EXISTS check. table_description: Optional table description from MODEL DDL. @@ -574,7 +577,7 @@ def create_table( """ self._create_table_from_columns( table_name, - columns_to_types, + target_columns_to_types, primary_key, exists, table_description, @@ -586,7 +589,7 @@ def create_managed_table( self, table_name: TableName, query: Query, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, partitioned_by: t.Optional[t.List[exp.Expression]] = None, clustered_by: t.Optional[t.List[exp.Expression]] = None, table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, @@ -602,7 +605,7 @@ def create_managed_table( Args: table_name: The name of the table to create. Can be fully qualified or just table name. query: The SQL query for the engine to base the managed table on - columns_to_types: A mapping between the column name and its data type. + target_columns_to_types: A mapping between the column name and its data type. partitioned_by: The partition columns or engine specific expressions, only applicable in certain engines. (eg. (ds, hour)) clustered_by: The cluster columns or engine specific expressions, only applicable in certain engines. (eg. 
(ds, hour)) table_properties: Optional mapping of engine-specific properties to be set on the managed table @@ -616,7 +619,7 @@ def ctas( self, table_name: TableName, query_or_df: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, exists: bool = True, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, @@ -628,19 +631,22 @@ def ctas( Args: table_name: The name of the table to create. Can be fully qualified or just table name. query_or_df: The SQL query to run or a dataframe for the CTAS. - columns_to_types: A mapping between the column name and its data type. Required if using a DataFrame. + target_columns_to_types: A mapping between the column name and its data type. Required if using a DataFrame. exists: Indicates whether to include the IF NOT EXISTS check. table_description: Optional table description from MODEL DDL. column_descriptions: Optional column descriptions from model query. kwargs: Optional create table properties. """ - source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - query_or_df, columns_to_types, target_table=table_name, source_columns=source_columns + source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( + query_or_df, + target_columns_to_types, + target_table=table_name, + source_columns=source_columns, ) return self._create_table_from_source_queries( table_name, source_queries, - columns_to_types, + target_columns_to_types, exists, table_description=table_description, column_descriptions=column_descriptions, @@ -650,26 +656,26 @@ def ctas( def create_state_table( self, table_name: str, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], primary_key: t.Optional[t.Tuple[str, ...]] = None, ) -> None: """Create a table to store SQLMesh internal state. 
Args: table_name: The name of the table to create. Can be fully qualified or just table name. - columns_to_types: A mapping between the column name and its data type. + target_columns_to_types: A mapping between the column name and its data type. primary_key: Determines the table primary key. """ self.create_table( table_name, - columns_to_types, + target_columns_to_types, primary_key=primary_key, ) def _create_table_from_columns( self, table_name: TableName, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], primary_key: t.Optional[t.Tuple[str, ...]] = None, exists: bool = True, table_description: t.Optional[str] = None, @@ -681,7 +687,7 @@ def _create_table_from_columns( Args: table_name: The name of the table to create. Can be fully qualified or just table name. - columns_to_types: Mapping between the column name and its data type. + target_columns_to_types: Mapping between the column name and its data type. primary_key: Determines the table primary key. exists: Indicates whether to include the IF NOT EXISTS check. table_description: Optional table description from MODEL DDL. @@ -690,14 +696,14 @@ def _create_table_from_columns( """ table = exp.to_table(table_name) - if not columns_to_types_all_known(columns_to_types): + if not columns_to_types_all_known(target_columns_to_types): # It is ok if the columns types are not known if the table already exists and IF NOT EXISTS is set if exists and self.table_exists(table_name): return raise SQLMeshError( "Cannot create a table without knowing the column types. " "Try casting the columns to an expected type or defining the columns in the model metadata. 
" - f"Columns to types: {columns_to_types}" + f"Columns to types: {target_columns_to_types}" ) primary_key_expression = ( @@ -708,7 +714,7 @@ def _create_table_from_columns( schema = self._build_schema_exp( table, - columns_to_types, + target_columns_to_types, column_descriptions, primary_key_expression, ) @@ -717,7 +723,7 @@ def _create_table_from_columns( schema, None, exists=exists, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, table_description=table_description, **kwargs, ) @@ -739,7 +745,7 @@ def _create_table_from_columns( def _build_schema_exp( self, table: exp.Table, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], column_descriptions: t.Optional[t.Dict[str, str]] = None, expressions: t.Optional[t.List[exp.PrimaryKey]] = None, is_view: bool = False, @@ -752,7 +758,7 @@ def _build_schema_exp( return exp.Schema( this=table, expressions=self._build_column_defs( - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, column_descriptions=column_descriptions, is_view=is_view, ) @@ -761,7 +767,7 @@ def _build_schema_exp( def _build_column_defs( self, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], column_descriptions: t.Optional[t.Dict[str, str]] = None, is_view: bool = False, ) -> t.List[exp.ColumnDef]: @@ -777,7 +783,7 @@ def _build_column_defs( engine_supports_schema_comments=engine_supports_schema_comments, col_type=None if is_view else kind, # don't include column data type for views ) - for column, kind in columns_to_types.items() + for column, kind in target_columns_to_types.items() ] def _build_column_def( @@ -816,7 +822,7 @@ def _create_table_from_source_queries( self, table_name: TableName, source_queries: t.List[SourceQuery], - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, exists: bool = 
True, replace: bool = False, table_description: t.Optional[str] = None, @@ -840,27 +846,29 @@ def _create_table_from_source_queries( # types, and for evaluation methods like `LogicalReplaceQueryMixin.replace_query()` # calls and SCD Type 2 model calls. schema = None - columns_to_types_known = columns_to_types and columns_to_types_all_known(columns_to_types) + target_columns_to_types_known = target_columns_to_types and columns_to_types_all_known( + target_columns_to_types + ) if ( column_descriptions - and columns_to_types_known + and target_columns_to_types_known and self.COMMENT_CREATION_TABLE.is_in_schema_def_ctas and self.comments_enabled ): - schema = self._build_schema_exp(table, columns_to_types, column_descriptions) # type: ignore + schema = self._build_schema_exp(table, target_columns_to_types, column_descriptions) # type: ignore with self.transaction(condition=len(source_queries) > 1): for i, source_query in enumerate(source_queries): with source_query as query: - if columns_to_types and columns_to_types_known: + if target_columns_to_types and target_columns_to_types_known: query = self._order_projections_and_filter( - query, columns_to_types, coerce_types=True + query, target_columns_to_types, coerce_types=True ) if i == 0: self._create_table( schema if schema else table, query, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, exists=exists, replace=replace, table_description=table_description, @@ -869,7 +877,7 @@ def _create_table_from_source_queries( ) else: self._insert_append_query( - table_name, query, columns_to_types or self.columns(table) + table_name, query, target_columns_to_types or self.columns(table) ) # Register comments with commands if the engine supports comments and we weren't able to @@ -889,7 +897,7 @@ def _create_table( expression: t.Optional[exp.Expression], exists: bool = True, replace: bool = False, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: 
t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, @@ -901,7 +909,7 @@ def _create_table( expression=expression, exists=exists, replace=replace, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, table_description=( table_description if self.COMMENT_CREATION_TABLE.supports_schema_def and self.comments_enabled @@ -918,7 +926,7 @@ def _build_create_table_exp( expression: t.Optional[exp.Expression], exists: bool = True, replace: bool = False, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, table_kind: t.Optional[str] = None, **kwargs: t.Any, @@ -936,7 +944,7 @@ def _build_create_table_exp( self._build_table_properties_exp( **kwargs, catalog_name=catalog_name, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, table_description=table_description, table_kind=table_kind, ) @@ -1091,7 +1099,7 @@ def create_view( self, view_name: TableName, query_or_df: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, replace: bool = True, materialized: bool = False, materialized_properties: t.Optional[t.Dict[str, t.Any]] = None, @@ -1109,7 +1117,7 @@ def create_view( Args: view_name: The view name. query_or_df: A query or dataframe. - columns_to_types: Columns to use in the view statement. + target_columns_to_types: Columns to use in the view statement. replace: Whether or not to replace an existing view defaults to True. materialized: Whether to create a a materialized view. Only used for engines that support this feature. materialized_properties: Optional materialized view properties to add to the view. 
@@ -1129,12 +1137,14 @@ def create_view( values: t.List[t.Tuple[t.Any, ...]] = list( query_or_df.itertuples(index=False, name=None) ) - columns_to_types, source_columns = self._columns_to_types( - query_or_df, columns_to_types, source_columns + target_columns_to_types, source_columns = self._columns_to_types( + query_or_df, target_columns_to_types, source_columns ) - if not columns_to_types: + if not target_columns_to_types: raise SQLMeshError("columns_to_types must be provided for dataframes") - source_columns_to_types = get_source_columns_to_types(columns_to_types, source_columns) + source_columns_to_types = get_source_columns_to_types( + target_columns_to_types, source_columns + ) query_or_df = self._values_to_sql( values, source_columns_to_types, @@ -1142,9 +1152,9 @@ def create_view( batch_end=len(values), ) - source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( + source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( query_or_df, - columns_to_types, + target_columns_to_types, batch_size=0, target_table=view_name, source_columns=source_columns, @@ -1153,9 +1163,9 @@ def create_view( raise SQLMeshError("Only one source query is supported for creating views") schema: t.Union[exp.Table, exp.Schema] = exp.to_table(view_name) - if columns_to_types: + if target_columns_to_types: schema = self._build_schema_exp( - exp.to_table(view_name), columns_to_types, column_descriptions, is_view=True + exp.to_table(view_name), target_columns_to_types, column_descriptions, is_view=True ) properties = create_kwargs.pop("properties", None) @@ -1255,7 +1265,7 @@ def create_view( self.COMMENT_CREATION_VIEW.is_comment_command_only or ( self.COMMENT_CREATION_VIEW.is_in_schema_def_and_commands - and not columns_to_types + and not target_columns_to_types ) ) and self.comments_enabled @@ -1381,61 +1391,64 @@ def insert_append( self, table_name: TableName, query_or_df: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, 
exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, source_columns: t.Optional[t.List[str]] = None, ) -> None: - source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - query_or_df, columns_to_types, target_table=table_name, source_columns=source_columns + source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( + query_or_df, + target_columns_to_types, + target_table=table_name, + source_columns=source_columns, ) - self._insert_append_source_queries(table_name, source_queries, columns_to_types) + self._insert_append_source_queries(table_name, source_queries, target_columns_to_types) def _insert_append_source_queries( self, table_name: TableName, source_queries: t.List[SourceQuery], - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, ) -> None: with self.transaction(condition=len(source_queries) > 0): - columns_to_types = columns_to_types or self.columns(table_name) + target_columns_to_types = target_columns_to_types or self.columns(table_name) for source_query in source_queries: with source_query as query: - self._insert_append_query(table_name, query, columns_to_types) + self._insert_append_query(table_name, query, target_columns_to_types) def _insert_append_query( self, table_name: TableName, query: Query, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], order_projections: bool = True, ) -> None: if order_projections: - query = self._order_projections_and_filter(query, columns_to_types) - self.execute(exp.insert(query, table_name, columns=list(columns_to_types))) + query = self._order_projections_and_filter(query, target_columns_to_types) + self.execute(exp.insert(query, table_name, columns=list(target_columns_to_types))) def insert_overwrite_by_partition( self, table_name: TableName, query_or_df: QueryOrDF, partitioned_by: 
t.List[exp.Expression], - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, source_columns: t.Optional[t.List[str]] = None, ) -> None: if self.INSERT_OVERWRITE_STRATEGY.is_insert_overwrite: target_table = exp.to_table(table_name) - source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( + source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( query_or_df, - columns_to_types, + target_columns_to_types, target_table=target_table, source_columns=source_columns, ) self._insert_overwrite_by_condition( - table_name, source_queries, columns_to_types=columns_to_types + table_name, source_queries, target_columns_to_types=target_columns_to_types ) else: self._replace_by_key( table_name, query_or_df, - columns_to_types, + target_columns_to_types, partitioned_by, is_unique_key=False, source_columns=source_columns, @@ -1451,17 +1464,21 @@ def insert_overwrite_by_time_partition( [TimeLike, t.Optional[t.Dict[str, exp.DataType]]], exp.Expression ], time_column: TimeColumn | exp.Expression | str, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: - source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - query_or_df, columns_to_types, target_table=table_name, source_columns=source_columns + source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( + query_or_df, + target_columns_to_types, + target_table=table_name, + source_columns=source_columns, ) - if not columns_to_types or not columns_to_types_all_known(columns_to_types): - columns_to_types = self.columns(table_name) + if not target_columns_to_types or not columns_to_types_all_known(target_columns_to_types): + target_columns_to_types = self.columns(table_name) low, 
high = [ - time_formatter(dt, columns_to_types) for dt in make_inclusive(start, end, self.dialect) + time_formatter(dt, target_columns_to_types) + for dt in make_inclusive(start, end, self.dialect) ] if isinstance(time_column, TimeColumn): time_column = time_column.column @@ -1471,25 +1488,25 @@ def insert_overwrite_by_time_partition( high=high, ) return self._insert_overwrite_by_time_partition( - table_name, source_queries, columns_to_types, where, **kwargs + table_name, source_queries, target_columns_to_types, where, **kwargs ) def _insert_overwrite_by_time_partition( self, table_name: TableName, source_queries: t.List[SourceQuery], - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], where: exp.Condition, **kwargs: t.Any, ) -> None: return self._insert_overwrite_by_condition( - table_name, source_queries, columns_to_types, where + table_name, source_queries, target_columns_to_types, where ) def _values_to_sql( self, values: t.List[t.Tuple[t.Any, ...]], - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], batch_start: int, batch_end: int, alias: str = "t", @@ -1497,7 +1514,7 @@ def _values_to_sql( ) -> Query: return select_from_values_for_batch_range( values=values, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, batch_start=batch_start, batch_end=batch_end, alias=alias, @@ -1508,7 +1525,7 @@ def _insert_overwrite_by_condition( self, table_name: TableName, source_queries: t.List[SourceQuery], - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, where: t.Optional[exp.Condition] = None, insert_overwrite_strategy_override: t.Optional[InsertOverwriteStrategy] = None, **kwargs: t.Any, @@ -1520,17 +1537,19 @@ def _insert_overwrite_by_condition( with self.transaction( condition=len(source_queries) > 0 or insert_overwrite_strategy.is_delete_insert ): - 
columns_to_types = columns_to_types or self.columns(table_name) + target_columns_to_types = target_columns_to_types or self.columns(table_name) for i, source_query in enumerate(source_queries): with source_query as query: - query = self._order_projections_and_filter(query, columns_to_types, where=where) + query = self._order_projections_and_filter( + query, target_columns_to_types, where=where + ) if i > 0 or insert_overwrite_strategy.is_delete_insert: if i == 0: self.delete_from(table_name, where=where or exp.true()) self._insert_append_query( table_name, query, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, order_projections=False, ) else: @@ -1538,7 +1557,7 @@ def _insert_overwrite_by_condition( query, table, columns=( - list(columns_to_types) + list(target_columns_to_types) if not insert_overwrite_strategy.is_replace_where else None ), @@ -1580,7 +1599,7 @@ def scd_type_2_by_time( updated_at_col: exp.Column, invalidate_hard_deletes: bool = True, updated_at_as_valid_from: bool = False, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, truncate: bool = False, @@ -1597,7 +1616,7 @@ def scd_type_2_by_time( updated_at_col=updated_at_col, invalidate_hard_deletes=invalidate_hard_deletes, updated_at_as_valid_from=updated_at_as_valid_from, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, table_description=table_description, column_descriptions=column_descriptions, truncate=truncate, @@ -1616,7 +1635,7 @@ def scd_type_2_by_column( check_columns: t.Union[exp.Star, t.Sequence[exp.Column]], invalidate_hard_deletes: bool = True, execution_time_as_valid_from: bool = False, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, 
table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, truncate: bool = False, @@ -1631,7 +1650,7 @@ def scd_type_2_by_column( valid_to_col=valid_to_col, execution_time=execution_time, check_columns=check_columns, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, invalidate_hard_deletes=invalidate_hard_deletes, execution_time_as_valid_from=execution_time_as_valid_from, table_description=table_description, @@ -1654,7 +1673,7 @@ def _scd_type_2( check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Column]]] = None, updated_at_as_valid_from: bool = False, execution_time_as_valid_from: bool = False, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, truncate: bool = False, @@ -1670,15 +1689,15 @@ def remove_managed_columns( valid_from_name = valid_from_col.name valid_to_name = valid_to_col.name - columns_to_types = columns_to_types or self.columns(target_table) + target_columns_to_types = target_columns_to_types or self.columns(target_table) if ( - valid_from_name not in columns_to_types - or valid_to_name not in columns_to_types - or not columns_to_types_all_known(columns_to_types) + valid_from_name not in target_columns_to_types + or valid_to_name not in target_columns_to_types + or not columns_to_types_all_known(target_columns_to_types) ): - columns_to_types = self.columns(target_table) + target_columns_to_types = self.columns(target_table) unmanaged_columns_to_types = ( - remove_managed_columns(columns_to_types) if columns_to_types else None + remove_managed_columns(target_columns_to_types) if target_columns_to_types else None ) source_queries, unmanaged_columns_to_types = self._get_source_queries_and_columns_to_types( source_table, @@ -1688,10 +1707,10 @@ def remove_managed_columns( 
source_columns=source_columns, ) updated_at_name = updated_at_col.name if updated_at_col else None - if not columns_to_types: + if not target_columns_to_types: raise SQLMeshError(f"Could not get columns_to_types. Does {target_table} exist?") unmanaged_columns_to_types = unmanaged_columns_to_types or remove_managed_columns( - columns_to_types + target_columns_to_types ) if not unique_key: raise SQLMeshError("unique_key must be provided for SCD Type 2") @@ -1707,15 +1726,15 @@ def remove_managed_columns( raise SQLMeshError( "Cannot use `execution_time_as_valid_from` without `check_columns` for SCD Type 2" ) - if updated_at_name and updated_at_name not in columns_to_types: + if updated_at_name and updated_at_name not in target_columns_to_types: raise SQLMeshError( f"Column {updated_at_name} not found in {target_table}. Table must contain an `updated_at` timestamp for SCD Type 2" ) - time_data_type = columns_to_types[valid_from_name] + time_data_type = target_columns_to_types[valid_from_name] select_source_columns: t.List[t.Union[str, exp.Alias]] = [ col for col in unmanaged_columns_to_types if col != updated_at_name ] - table_columns = [exp.column(c, quoted=True) for c in columns_to_types] + table_columns = [exp.column(c, quoted=True) for c in target_columns_to_types] if updated_at_name: select_source_columns.append( exp.cast(updated_at_col, time_data_type).as_(updated_at_col.this) # type: ignore @@ -1852,7 +1871,7 @@ def remove_managed_columns( with source_queries[0] as source_query: prefixed_columns_to_types = [] - for column in columns_to_types: + for column in target_columns_to_types: prefixed_col = exp.column(column).copy() prefixed_col.this.set("this", f"t_{prefixed_col.name}") prefixed_columns_to_types.append(prefixed_col) @@ -1896,7 +1915,7 @@ def remove_managed_columns( # Deleted records which can be used to determine `valid_from` for undeleted source records .with_( "deleted", - exp.select(*[exp.column(col, "static") for col in columns_to_types]) + 
exp.select(*[exp.column(col, "static") for col in target_columns_to_types]) .from_("static") .join( "latest", @@ -1931,7 +1950,7 @@ def remove_managed_columns( exp.column("_exists", table="source").as_("_exists"), *( exp.column(col, table="latest").as_(prefixed_columns_to_types[i].this) - for i, col in enumerate(columns_to_types) + for i, col in enumerate(target_columns_to_types) ), *( exp.column(col, table="source").as_(col) @@ -1956,7 +1975,7 @@ def remove_managed_columns( exp.column(col, table="latest").as_( prefixed_columns_to_types[i].this ) - for i, col in enumerate(columns_to_types) + for i, col in enumerate(target_columns_to_types) ), *( exp.column(col, table="source").as_(col) @@ -2025,7 +2044,7 @@ def remove_managed_columns( self.replace_query( target_table, self.ensure_nulls_for_unmatched_after_join(query), - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, table_description=table_description, column_descriptions=column_descriptions, **kwargs, @@ -2035,17 +2054,20 @@ def merge( self, target_table: TableName, source_table: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]], + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]], unique_key: t.Sequence[exp.Expression], when_matched: t.Optional[exp.Whens] = None, merge_filter: t.Optional[exp.Expression] = None, source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: - source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - source_table, columns_to_types, target_table=target_table, source_columns=source_columns + source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( + source_table, + target_columns_to_types, + target_table=target_table, + source_columns=source_columns, ) - columns_to_types = columns_to_types or self.columns(target_table) + target_columns_to_types = target_columns_to_types or self.columns(target_table) on = exp.and_( *( add_table(part, 
MERGE_TARGET_ALIAS).eq(add_table(part, MERGE_SOURCE_ALIAS)) @@ -2065,7 +2087,7 @@ def merge( exp.column(col, MERGE_TARGET_ALIAS).eq( exp.column(col, MERGE_SOURCE_ALIAS) ) - for col in columns_to_types + for col in target_columns_to_types ], ), ) @@ -2078,10 +2100,12 @@ def merge( matched=False, source=False, then=exp.Insert( - this=exp.Tuple(expressions=[exp.column(col) for col in columns_to_types]), + this=exp.Tuple( + expressions=[exp.column(col) for col in target_columns_to_types] + ), expression=exp.Tuple( expressions=[ - exp.column(col, MERGE_SOURCE_ALIAS) for col in columns_to_types + exp.column(col, MERGE_SOURCE_ALIAS) for col in target_columns_to_types ] ), ), @@ -2365,7 +2389,7 @@ def temp_table( self, query_or_df: QueryOrDF, name: TableName = "diff", - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> t.Iterator[exp.Table]: @@ -2376,7 +2400,7 @@ def temp_table( Args: query_or_df: The query or df to create a temp table for. name: The base name of the temp table. - columns_to_types: A mapping between the column name and its data type. + target_columns_to_types: A mapping between the column name and its data type. 
Yields: The table expression @@ -2386,9 +2410,9 @@ def temp_table( if isinstance(name, exp.Table) and not name.catalog and name.db and self.default_catalog: name.set("catalog", exp.parse_identifier(self.default_catalog)) - source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( + source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( query_or_df, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, target_table=name, source_columns=source_columns, ) @@ -2400,7 +2424,7 @@ def temp_table( self._create_table_from_source_queries( table, source_queries, - columns_to_types, + target_columns_to_types, exists=True, table_description=None, column_descriptions=None, @@ -2428,7 +2452,7 @@ def _build_partitioned_by_exp( partitioned_by: t.List[exp.Expression], *, partition_interval_unit: t.Optional[IntervalUnit] = None, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, catalog_name: t.Optional[str] = None, **kwargs: t.Any, ) -> t.Optional[t.Union[exp.PartitionedByProperty, exp.Property]]: @@ -2450,7 +2474,7 @@ def _build_table_properties_exp( partition_interval_unit: t.Optional[IntervalUnit] = None, clustered_by: t.Optional[t.List[exp.Expression]] = None, table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, table_kind: t.Optional[str] = None, **kwargs: t.Any, @@ -2551,12 +2575,12 @@ def _get_temp_table( def _order_projections_and_filter( self, query: Query, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], where: t.Optional[exp.Expression] = None, coerce_types: bool = False, ) -> Query: if not isinstance(query, exp.Query) or ( - not where and not 
coerce_types and query.named_selects == list(columns_to_types) + not where and not coerce_types and query.named_selects == list(target_columns_to_types) ): return query @@ -2564,12 +2588,12 @@ def _order_projections_and_filter( with_ = query.args.pop("with", None) select_exprs: t.List[exp.Expression] = [ - exp.column(c, quoted=True) for c in columns_to_types + exp.column(c, quoted=True) for c in target_columns_to_types ] - if coerce_types and columns_to_types_all_known(columns_to_types): + if coerce_types and columns_to_types_all_known(target_columns_to_types): select_exprs = [ exp.cast(select_exprs[i], col_tpe).as_(col, quoted=True) - for i, (col, col_tpe) in enumerate(columns_to_types.items()) + for i, (col, col_tpe) in enumerate(target_columns_to_types.items()) ] query = exp.select(*select_exprs).from_(query.subquery("_subquery", copy=False), copy=False) @@ -2613,30 +2637,30 @@ def _replace_by_key( self, target_table: TableName, source_table: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]], + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]], key: t.Sequence[exp.Expression], is_unique_key: bool, source_columns: t.Optional[t.List[str]] = None, ) -> None: - if columns_to_types is None: - columns_to_types = self.columns(target_table) + if target_columns_to_types is None: + target_columns_to_types = self.columns(target_table) temp_table = self._get_temp_table(target_table) key_exp = exp.func("CONCAT_WS", "'__SQLMESH_DELIM__'", *key) if len(key) > 1 else key[0] - column_names = list(columns_to_types or []) + column_names = list(target_columns_to_types or []) with self.transaction(): self.ctas( temp_table, source_table, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, exists=False, source_columns=source_columns, ) try: delete_query = exp.select(key_exp).from_(temp_table) - insert_query = self._select_columns(columns_to_types).from_(temp_table) + insert_query = 
self._select_columns(target_columns_to_types).from_(temp_table) if not is_unique_key: delete_query = delete_query.distinct() else: @@ -2758,17 +2782,6 @@ def _check_identifier_length(self, expression: exp.Expression) -> None: f"Identifier name '{name}' (length {name_length}) exceeds {self.dialect.capitalize()}'s max identifier limit of {self.MAX_IDENTIFIER_LENGTH} characters" ) - @classmethod - def get_source_columns_to_types( - cls, - columns_to_types: t.Dict[str, exp.DataType], - source_columns: t.Optional[t.List[str]], - ) -> t.Dict[str, exp.DataType]: - """Returns the source columns to types mapping.""" - return { - k: v for k, v in columns_to_types.items() if not source_columns or k in source_columns - } - class EngineAdapterWithIndexSupport(EngineAdapter): SUPPORTS_INDEXES = True diff --git a/sqlmesh/core/engine_adapter/base_postgres.py b/sqlmesh/core/engine_adapter/base_postgres.py index 10083dcb91..cc394efd9e 100644 --- a/sqlmesh/core/engine_adapter/base_postgres.py +++ b/sqlmesh/core/engine_adapter/base_postgres.py @@ -91,7 +91,7 @@ def create_view( self, view_name: TableName, query_or_df: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, replace: bool = True, materialized: bool = False, materialized_properties: t.Optional[t.Dict[str, t.Any]] = None, @@ -114,7 +114,7 @@ def create_view( super().create_view( view_name, query_or_df, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, replace=False, materialized=materialized, materialized_properties=materialized_properties, diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py index fb88a707fd..4fe50fdeef 100644 --- a/sqlmesh/core/engine_adapter/bigquery.py +++ b/sqlmesh/core/engine_adapter/bigquery.py @@ -148,14 +148,16 @@ def catalog_support(self) -> CatalogSupport: def _df_to_source_queries( self, df: DF, - columns_to_types: 
t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], batch_size: int, target_table: TableName, source_columns: t.Optional[t.List[str]] = None, ) -> t.List[SourceQuery]: import pandas as pd - source_columns_to_types = get_source_columns_to_types(columns_to_types, source_columns) + source_columns_to_types = get_source_columns_to_types( + target_columns_to_types, source_columns + ) temp_bq_table = self.__get_temp_bq_table( self._get_temp_table(target_table or "pandas"), source_columns_to_types @@ -182,7 +184,7 @@ def query_factory() -> Query: if result.errors: raise SQLMeshError(result.errors) return exp.select( - *self._casted_columns(columns_to_types, source_columns=source_columns) + *self._casted_columns(target_columns_to_types, source_columns=source_columns) ).from_(temp_table) return [ @@ -678,7 +680,7 @@ def insert_overwrite_by_partition( table_name: TableName, query_or_df: QueryOrDF, partitioned_by: t.List[exp.Expression], - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, source_columns: t.Optional[t.List[str]] = None, ) -> None: if len(partitioned_by) != 1: @@ -707,12 +709,14 @@ def insert_overwrite_by_partition( source_columns=source_columns, ) as temp_table_name, ): - if columns_to_types is None or columns_to_types[ + if target_columns_to_types is None or target_columns_to_types[ partition_column.name ] == exp.DataType.build("unknown"): - columns_to_types = self.columns(table_name) + target_columns_to_types = self.columns(table_name) - partition_type_sql = columns_to_types[partition_column.name].sql(dialect=self.dialect) + partition_type_sql = target_columns_to_types[partition_column.name].sql( + dialect=self.dialect + ) select_array_agg_partitions = select_partitions_expr( temp_table_name.db, @@ -732,7 +736,7 @@ def insert_overwrite_by_partition( self._insert_overwrite_by_condition( table_name, [SourceQuery(query_factory=lambda: 
exp.select("*").from_(temp_table_name))], - columns_to_types, + target_columns_to_types, where=where, ) @@ -824,7 +828,7 @@ def _build_partitioned_by_exp( partitioned_by: t.List[exp.Expression], *, partition_interval_unit: t.Optional[IntervalUnit] = None, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, **kwargs: t.Any, ) -> t.Optional[exp.PartitionedByProperty]: if len(partitioned_by) > 1: @@ -836,7 +840,7 @@ def _build_partitioned_by_exp( and partition_interval_unit is not None and not partition_interval_unit.is_minute ): - column_type: t.Optional[exp.DataType] = (columns_to_types or {}).get(this.name) + column_type: t.Optional[exp.DataType] = (target_columns_to_types or {}).get(this.name) if column_type == exp.DataType.build( "date", dialect=self.dialect @@ -871,7 +875,7 @@ def _build_table_properties_exp( partition_interval_unit: t.Optional[IntervalUnit] = None, clustered_by: t.Optional[t.List[exp.Expression]] = None, table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, table_kind: t.Optional[str] = None, **kwargs: t.Any, @@ -882,7 +886,7 @@ def _build_table_properties_exp( partitioned_by_prop := self._build_partitioned_by_exp( partitioned_by, partition_interval_unit=partition_interval_unit, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, ) ): properties.append(partitioned_by_prop) @@ -1027,12 +1031,12 @@ def _build_create_comment_column_exp( def create_state_table( self, table_name: str, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], primary_key: t.Optional[t.Tuple[str, ...]] = None, ) -> None: self.create_table( table_name, - columns_to_types, + target_columns_to_types, ) def 
_db_call(self, func: t.Callable[..., t.Any], *args: t.Any, **kwargs: t.Any) -> t.Any: @@ -1215,7 +1219,7 @@ def _normalize_nested_value(self, col: exp.Expression) -> exp.Expression: def _columns_to_types( self, query_or_df: DF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, source_columns: t.Optional[t.List[str]] = None, ) -> t.Tuple[t.Dict[str, exp.DataType], t.List[str]]: ... @@ -1223,28 +1227,28 @@ def _columns_to_types( def _columns_to_types( self, query_or_df: Query, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, source_columns: t.Optional[t.List[str]] = None, ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: ... def _columns_to_types( self, query_or_df: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, source_columns: t.Optional[t.List[str]] = None, ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: if ( - not columns_to_types + not target_columns_to_types and bigframes and isinstance(query_or_df, bigframes.dataframe.DataFrame) ): # using dry_run=True attempts to prevent the DataFrame from being materialized just to read the column types from it dtypes = query_or_df.to_pandas(dry_run=True).columnDtypes - columns_to_types = columns_to_types_from_dtypes(dtypes.items()) - return columns_to_types, list(source_columns or columns_to_types) + target_columns_to_types = columns_to_types_from_dtypes(dtypes.items()) + return target_columns_to_types, list(source_columns or target_columns_to_types) return super()._columns_to_types( - query_or_df, columns_to_types, source_columns=source_columns + query_or_df, target_columns_to_types, source_columns=source_columns ) def _native_df_to_pandas_df( diff --git 
a/sqlmesh/core/engine_adapter/clickhouse.py b/sqlmesh/core/engine_adapter/clickhouse.py index 26be50aba0..5ac4e9b152 100644 --- a/sqlmesh/core/engine_adapter/clickhouse.py +++ b/sqlmesh/core/engine_adapter/clickhouse.py @@ -90,14 +90,16 @@ def _fetch_native_df( def _df_to_source_queries( self, df: DF, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], batch_size: int, target_table: TableName, source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> t.List[SourceQuery]: temp_table = self._get_temp_table(target_table, **kwargs) - source_columns_to_types = get_source_columns_to_types(columns_to_types, source_columns) + source_columns_to_types = get_source_columns_to_types( + target_columns_to_types, source_columns + ) def query_factory() -> Query: # It is possible for the factory to be called multiple times and if so then the temp table will already @@ -113,7 +115,7 @@ def query_factory() -> Query: self.cursor.client.insert_df(temp_table.sql(dialect=self.dialect), df=df) - return exp.select(*self._casted_columns(columns_to_types, source_columns)).from_( + return exp.select(*self._casted_columns(target_columns_to_types, source_columns)).from_( temp_table ) @@ -189,7 +191,7 @@ def _insert_overwrite_by_condition( self, table_name: TableName, source_queries: t.List[SourceQuery], - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, where: t.Optional[exp.Condition] = None, insert_overwrite_strategy_override: t.Optional[InsertOverwriteStrategy] = None, **kwargs: t.Any, @@ -204,7 +206,7 @@ def _insert_overwrite_by_condition( Args: table_name: Name of target table source_queries: Source queries returning records to insert - columns_to_types: Column names and data types of target table + target_columns_to_types: Column names and data types of target table where: SQLGlot expression determining which target table rows should be 
overwritten insert_overwrite_strategy_override: Not used by Clickhouse kwargs: @@ -218,7 +220,7 @@ def _insert_overwrite_by_condition( Side effects only: execution of insert-overwrite operation. """ target_table = exp.to_table(table_name) - columns_to_types = columns_to_types or self.columns(target_table) + target_columns_to_types = target_columns_to_types or self.columns(target_table) temp_table = self._get_temp_table(target_table) self._create_table_like(temp_table, target_table) @@ -237,11 +239,13 @@ def _insert_overwrite_by_condition( if dynamic_key and dynamic_key_unique: query = query.distinct(*dynamic_key) # type: ignore - query = self._order_projections_and_filter(query, columns_to_types, where=where) + query = self._order_projections_and_filter( + query, target_columns_to_types, where=where + ) self._insert_append_query( temp_table, query, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, order_projections=False, ) @@ -267,7 +271,7 @@ def _insert_overwrite_by_condition( if where: # identify existing records to keep by inverting the delete `where` clause existing_records_insert_exp = exp.insert( - self._select_columns(columns_to_types) + self._select_columns(target_columns_to_types) .from_(target_table) .where(exp.paren(expression=where).not_()), temp_table, @@ -408,13 +412,16 @@ def _replace_by_key( self, target_table: TableName, source_table: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]], + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]], key: t.Sequence[exp.Expression], is_unique_key: bool, source_columns: t.Optional[t.List[str]] = None, ) -> None: - source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - source_table, columns_to_types, target_table=target_table, source_columns=source_columns + source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( + source_table, + target_columns_to_types, + target_table=target_table, 
+ source_columns=source_columns, ) key_exp = exp.func("CONCAT_WS", "'__SQLMESH_DELIM__'", *key) if len(key) > 1 else key[0] @@ -422,7 +429,7 @@ def _replace_by_key( self._insert_overwrite_by_condition( target_table, source_queries, - columns_to_types, + target_columns_to_types, dynamic_key=key, dynamic_key_exp=key_exp, dynamic_key_unique=is_unique_key, @@ -433,15 +440,18 @@ def insert_overwrite_by_partition( table_name: TableName, query_or_df: QueryOrDF, partitioned_by: t.List[exp.Expression], - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, source_columns: t.Optional[t.List[str]] = None, ) -> None: - source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - query_or_df, columns_to_types, target_table=table_name, source_columns=source_columns + source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( + query_or_df, + target_columns_to_types, + target_table=table_name, + source_columns=source_columns, ) self._insert_overwrite_by_condition( - table_name, source_queries, columns_to_types, keep_existing_partition_rows=False + table_name, source_queries, target_columns_to_types, keep_existing_partition_rows=False ) def _create_table_like( @@ -475,7 +485,7 @@ def _create_table( expression: t.Optional[exp.Expression], exists: bool = True, replace: bool = False, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, @@ -501,16 +511,16 @@ def _create_table( for coldef in table_name_or_schema.expressions: if coldef.name in partition_cols: coldef.kind.set("nullable", False) - if columns_to_types: + if target_columns_to_types: for col in partition_cols: - columns_to_types[col].set("nullable", False) + 
target_columns_to_types[col].set("nullable", False) super()._create_table( table_name_or_schema, expression, exists, replace, - columns_to_types, + target_columns_to_types, table_description, column_descriptions, table_kind, @@ -538,7 +548,7 @@ def _create_table( self._insert_append_query( table_name, expression, # type: ignore - columns_to_types or self.columns(table_name), + target_columns_to_types or self.columns(table_name), ) def _exchange_tables( @@ -718,7 +728,7 @@ def _build_table_properties_exp( partition_interval_unit: t.Optional[IntervalUnit] = None, clustered_by: t.Optional[t.List[exp.Expression]] = None, table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, table_kind: t.Optional[str] = None, empty_ctas: bool = False, diff --git a/sqlmesh/core/engine_adapter/databricks.py b/sqlmesh/core/engine_adapter/databricks.py index dba7e58834..4e352b27ef 100644 --- a/sqlmesh/core/engine_adapter/databricks.py +++ b/sqlmesh/core/engine_adapter/databricks.py @@ -161,22 +161,24 @@ def _end_session(self) -> None: def _df_to_source_queries( self, df: DF, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], batch_size: int, target_table: TableName, source_columns: t.Optional[t.List[str]] = None, ) -> t.List[SourceQuery]: if not self._use_spark_session: return super(SparkEngineAdapter, self)._df_to_source_queries( - df, columns_to_types, batch_size, target_table, source_columns=source_columns + df, target_columns_to_types, batch_size, target_table, source_columns=source_columns ) - pyspark_df = self._ensure_pyspark_df(df, columns_to_types, source_columns=source_columns) + pyspark_df = self._ensure_pyspark_df( + df, target_columns_to_types, source_columns=source_columns + ) def query_factory() -> Query: temp_table = 
self._get_temp_table(target_table or "spark", table_only=True) pyspark_df.createOrReplaceTempView(temp_table.sql(dialect=self.dialect)) self._connection_pool.set_attribute("use_spark_engine_adapter", True) - return exp.select(*self._select_columns(columns_to_types)).from_(temp_table) + return exp.select(*self._select_columns(target_columns_to_types)).from_(temp_table) return [SourceQuery(query_factory=query_factory)] @@ -336,7 +338,7 @@ def _build_table_properties_exp( partition_interval_unit: t.Optional[IntervalUnit] = None, clustered_by: t.Optional[t.List[exp.Expression]] = None, table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, table_kind: t.Optional[str] = None, **kwargs: t.Any, @@ -349,7 +351,7 @@ def _build_table_properties_exp( partition_interval_unit=partition_interval_unit, clustered_by=clustered_by, table_properties=table_properties, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, table_description=table_description, table_kind=table_kind, ) diff --git a/sqlmesh/core/engine_adapter/duckdb.py b/sqlmesh/core/engine_adapter/duckdb.py index 49231fcf87..d90a4ed736 100644 --- a/sqlmesh/core/engine_adapter/duckdb.py +++ b/sqlmesh/core/engine_adapter/duckdb.py @@ -61,21 +61,23 @@ def _drop_catalog(self, catalog_name: exp.Identifier) -> None: def _df_to_source_queries( self, df: DF, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], batch_size: int, target_table: TableName, source_columns: t.Optional[t.List[str]] = None, ) -> t.List[SourceQuery]: temp_table = self._get_temp_table(target_table) temp_table_sql = ( - exp.select(*self._casted_columns(columns_to_types, source_columns)) + exp.select(*self._casted_columns(target_columns_to_types, source_columns)) .from_("df") 
.sql(dialect=self.dialect) ) self.cursor.sql(f"CREATE TABLE {temp_table} AS {temp_table_sql}") return [ SourceQuery( - query_factory=lambda: self._select_columns(columns_to_types).from_(temp_table), # type: ignore + query_factory=lambda: self._select_columns(target_columns_to_types).from_( + temp_table + ), # type: ignore cleanup_func=lambda: self.drop_table(temp_table), ) ] @@ -150,7 +152,7 @@ def _create_table( expression: t.Optional[exp.Expression], exists: bool = True, replace: bool = False, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, @@ -173,7 +175,7 @@ def _create_table( expression, exists, replace, - columns_to_types, + target_columns_to_types, table_description, column_descriptions, table_kind, diff --git a/sqlmesh/core/engine_adapter/mixins.py b/sqlmesh/core/engine_adapter/mixins.py index 81c092c517..12c9bfc603 100644 --- a/sqlmesh/core/engine_adapter/mixins.py +++ b/sqlmesh/core/engine_adapter/mixins.py @@ -28,7 +28,7 @@ def merge( self, target_table: TableName, source_table: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]], + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]], unique_key: t.Sequence[exp.Expression], when_matched: t.Optional[exp.Whens] = None, merge_filter: t.Optional[exp.Expression] = None, @@ -39,7 +39,7 @@ def merge( self, target_table, source_table, - columns_to_types, + target_columns_to_types, unique_key, when_matched=when_matched, merge_filter=merge_filter, @@ -77,7 +77,7 @@ def _insert_overwrite_by_condition( self, table_name: TableName, source_queries: t.List[SourceQuery], - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, where: t.Optional[exp.Condition] = None, 
insert_overwrite_strategy_override: t.Optional[InsertOverwriteStrategy] = None, **kwargs: t.Any, @@ -87,11 +87,13 @@ def _insert_overwrite_by_condition( doing an "INSERT OVERWRITE" using a Merge expression but with the predicate being `False`. """ - columns_to_types = columns_to_types or self.columns(table_name) + target_columns_to_types = target_columns_to_types or self.columns(table_name) for source_query in source_queries: with source_query as query: - query = self._order_projections_and_filter(query, columns_to_types, where=where) - columns = [exp.column(col) for col in columns_to_types] + query = self._order_projections_and_filter( + query, target_columns_to_types, where=where + ) + columns = [exp.column(col) for col in target_columns_to_types] when_not_matched_by_source = exp.When( matched=False, source=True, @@ -159,7 +161,7 @@ def _build_table_properties_exp( partition_interval_unit: t.Optional[IntervalUnit] = None, clustered_by: t.Optional[t.List[exp.Expression]] = None, table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, table_kind: t.Optional[str] = None, **kwargs: t.Any, @@ -278,7 +280,7 @@ def _build_create_table_exp( expression: t.Optional[exp.Expression], exists: bool = True, replace: bool = False, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, table_kind: t.Optional[str] = None, **kwargs: t.Any, @@ -288,7 +290,7 @@ def _build_create_table_exp( expression=expression, exists=exists, replace=replace, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, table_description=table_description, table_kind=table_kind, **kwargs, @@ -328,7 +330,7 @@ def _build_create_table_exp( None, exists=exists, 
replace=replace, - columns_to_types=columns_to_types_from_view, + target_columns_to_types=columns_to_types_from_view, table_description=table_description, **kwargs, ) @@ -420,7 +422,7 @@ def logical_merge( engine_adapter: EngineAdapter, target_table: TableName, source_table: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]], + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]], unique_key: t.Sequence[exp.Expression], when_matched: t.Optional[exp.Whens] = None, merge_filter: t.Optional[exp.Expression] = None, @@ -445,7 +447,7 @@ def logical_merge( engine_adapter._replace_by_key( target_table, source_table, - columns_to_types, + target_columns_to_types, unique_key, is_unique_key=True, source_columns=source_columns, diff --git a/sqlmesh/core/engine_adapter/mssql.py b/sqlmesh/core/engine_adapter/mssql.py index 0edfb4f1f2..3a43d539a9 100644 --- a/sqlmesh/core/engine_adapter/mssql.py +++ b/sqlmesh/core/engine_adapter/mssql.py @@ -195,7 +195,7 @@ def merge( self, target_table: TableName, source_table: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]], + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]], unique_key: t.Sequence[exp.Expression], when_matched: t.Optional[exp.Whens] = None, merge_filter: t.Optional[exp.Expression] = None, @@ -204,10 +204,13 @@ def merge( ) -> None: mssql_merge_exists = kwargs.get("physical_properties", {}).get("mssql_merge_exists") - source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - source_table, columns_to_types, target_table=target_table, source_columns=source_columns + source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( + source_table, + target_columns_to_types, + target_table=target_table, + source_columns=source_columns, ) - columns_to_types = columns_to_types or self.columns(target_table) + target_columns_to_types = target_columns_to_types or self.columns(target_table) on = exp.and_( *( add_table(part, 
MERGE_TARGET_ALIAS).eq(add_table(part, MERGE_SOURCE_ALIAS)) @@ -220,7 +223,9 @@ def merge( match_expressions = [] if not when_matched: unique_key_names = [y.name for y in unique_key] - columns_to_types_no_keys = [c for c in columns_to_types if c not in unique_key_names] + columns_to_types_no_keys = [ + c for c in target_columns_to_types if c not in unique_key_names + ] target_columns_no_keys = [ exp.column(c, MERGE_TARGET_ALIAS) for c in columns_to_types_no_keys @@ -263,10 +268,12 @@ def merge( matched=False, source=False, then=exp.Insert( - this=exp.Tuple(expressions=[exp.column(col) for col in columns_to_types]), + this=exp.Tuple( + expressions=[exp.column(col) for col in target_columns_to_types] + ), expression=exp.Tuple( expressions=[ - exp.column(col, MERGE_SOURCE_ALIAS) for col in columns_to_types + exp.column(col, MERGE_SOURCE_ALIAS) for col in target_columns_to_types ] ), ), @@ -305,7 +312,7 @@ def _convert_df_datetime(self, df: DF, columns_to_types: t.Dict[str, exp.DataTyp def _df_to_source_queries( self, df: DF, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], batch_size: int, target_table: TableName, source_columns: t.Optional[t.List[str]] = None, @@ -319,7 +326,7 @@ def _df_to_source_queries( # Return the superclass implementation if the connection pool doesn't support bulk_copy if not hasattr(self._connection_pool.get(), "bulk_copy"): return super()._df_to_source_queries( - df, columns_to_types, batch_size, target_table, source_columns=source_columns + df, target_columns_to_types, batch_size, target_table, source_columns=source_columns ) def query_factory() -> Query: @@ -328,7 +335,7 @@ def query_factory() -> Query: # as later calls. 
if not self.table_exists(temp_table): source_columns_to_types = get_source_columns_to_types( - columns_to_types, source_columns + target_columns_to_types, source_columns ) ordered_df = df[ list(source_columns_to_types) @@ -341,7 +348,7 @@ def query_factory() -> Query: conn = self._connection_pool.get() conn.bulk_copy(temp_table.sql(dialect=self.dialect), rows) return exp.select( - *self._casted_columns(columns_to_types, source_columns=source_columns) + *self._casted_columns(target_columns_to_types, source_columns=source_columns) ).from_(temp_table) # type: ignore return [ @@ -402,7 +409,7 @@ def _insert_overwrite_by_condition( self, table_name: TableName, source_queries: t.List[SourceQuery], - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, where: t.Optional[exp.Condition] = None, insert_overwrite_strategy_override: t.Optional[InsertOverwriteStrategy] = None, **kwargs: t.Any, @@ -414,7 +421,7 @@ def _insert_overwrite_by_condition( self, table_name=table_name, source_queries=source_queries, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, where=where, insert_overwrite_strategy_override=InsertOverwriteStrategy.DELETE_INSERT, **kwargs, @@ -424,7 +431,7 @@ def _insert_overwrite_by_condition( return super()._insert_overwrite_by_condition( table_name=table_name, source_queries=source_queries, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, where=where, insert_overwrite_strategy_override=insert_overwrite_strategy_override, **kwargs, diff --git a/sqlmesh/core/engine_adapter/postgres.py b/sqlmesh/core/engine_adapter/postgres.py index dd58b4949b..a1ff46e9ad 100644 --- a/sqlmesh/core/engine_adapter/postgres.py +++ b/sqlmesh/core/engine_adapter/postgres.py @@ -106,7 +106,7 @@ def merge( self, target_table: TableName, source_table: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]], + 
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]], unique_key: t.Sequence[exp.Expression], when_matched: t.Optional[exp.Whens] = None, merge_filter: t.Optional[exp.Expression] = None, @@ -119,7 +119,7 @@ def merge( merge_impl( # type: ignore target_table, source_table, - columns_to_types, + target_columns_to_types, unique_key, when_matched=when_matched, merge_filter=merge_filter, diff --git a/sqlmesh/core/engine_adapter/redshift.py b/sqlmesh/core/engine_adapter/redshift.py index 7b6b477d60..2589ef960e 100644 --- a/sqlmesh/core/engine_adapter/redshift.py +++ b/sqlmesh/core/engine_adapter/redshift.py @@ -168,7 +168,7 @@ def _create_table_from_source_queries( self, table_name: TableName, source_queries: t.List[SourceQuery], - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, exists: bool = True, replace: bool = False, table_description: t.Optional[str] = None, @@ -186,7 +186,7 @@ def _create_table_from_source_queries( return super()._create_table_from_source_queries( table_name, source_queries, - columns_to_types, + target_columns_to_types, exists, table_description=table_description, column_descriptions=column_descriptions, @@ -207,7 +207,7 @@ def create_view( self, view_name: TableName, query_or_df: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, replace: bool = True, materialized: bool = False, materialized_properties: t.Optional[t.Dict[str, t.Any]] = None, @@ -233,7 +233,7 @@ def create_view( return super().create_view( view_name, query_or_df, - columns_to_types, + target_columns_to_types, replace, materialized, materialized_properties, @@ -249,7 +249,7 @@ def replace_query( self, table_name: TableName, query_or_df: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, 
table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, source_columns: t.Optional[t.List[str]] = None, @@ -274,32 +274,32 @@ def replace_query( return super().replace_query( table_name, query_or_df, - columns_to_types, + target_columns_to_types, table_description, column_descriptions, source_columns=source_columns, **kwargs, ) - source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( + source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( query_or_df, - columns_to_types, + target_columns_to_types, target_table=table_name, source_columns=source_columns, ) - columns_to_types = columns_to_types or self.columns(table_name) + target_columns_to_types = target_columns_to_types or self.columns(table_name) target_table = exp.to_table(table_name) with self.transaction(): temp_table = self._get_temp_table(target_table) old_table = self._get_temp_table(target_table) self.create_table( temp_table, - columns_to_types, + target_columns_to_types, exists=False, table_description=table_description, column_descriptions=column_descriptions, **kwargs, ) - self._insert_append_source_queries(temp_table, source_queries, columns_to_types) + self._insert_append_source_queries(temp_table, source_queries, target_columns_to_types) self.rename_table(target_table, old_table) self.rename_table(temp_table, target_table) self.drop_table(old_table) @@ -361,7 +361,7 @@ def merge( self, target_table: TableName, source_table: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]], + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]], unique_key: t.Sequence[exp.Expression], when_matched: t.Optional[exp.Whens] = None, merge_filter: t.Optional[exp.Expression] = None, @@ -373,7 +373,7 @@ def merge( super().merge( target_table=target_table, source_table=source_table, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, unique_key=unique_key, 
when_matched=when_matched, merge_filter=merge_filter, @@ -384,7 +384,7 @@ def merge( self, target_table, source_table, - columns_to_types, + target_columns_to_types, unique_key, when_matched=when_matched, merge_filter=merge_filter, diff --git a/sqlmesh/core/engine_adapter/snowflake.py b/sqlmesh/core/engine_adapter/snowflake.py index 7bf6f3d303..f6fc32cc0a 100644 --- a/sqlmesh/core/engine_adapter/snowflake.py +++ b/sqlmesh/core/engine_adapter/snowflake.py @@ -162,7 +162,7 @@ def _create_table( expression: t.Optional[exp.Expression], exists: bool = True, replace: bool = False, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, @@ -181,7 +181,7 @@ def _create_table( expression=expression, exists=exists, replace=replace, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, table_description=table_description, column_descriptions=column_descriptions, table_kind=table_kind, @@ -192,7 +192,7 @@ def create_managed_table( self, table_name: TableName, query: Query, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, partitioned_by: t.Optional[t.List[exp.Expression]] = None, clustered_by: t.Optional[t.List[exp.Expression]] = None, table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, @@ -218,14 +218,14 @@ def create_managed_table( "`target_lag` must be specified in the model physical_properties for a Snowflake Dynamic Table" ) - source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - query, columns_to_types, target_table=target_table, source_columns=source_columns + source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( + query, target_columns_to_types, 
target_table=target_table, source_columns=source_columns ) self._create_table_from_source_queries( target_table, source_queries, - columns_to_types, + target_columns_to_types, replace=self.SUPPORTS_REPLACE_TABLE, partitioned_by=partitioned_by, clustered_by=clustered_by, @@ -240,7 +240,7 @@ def create_view( self, view_name: TableName, query_or_df: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, replace: bool = True, materialized: bool = False, materialized_properties: t.Optional[t.Dict[str, t.Any]] = None, @@ -259,7 +259,7 @@ def create_view( super().create_view( view_name=view_name, query_or_df=query_or_df, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, replace=replace, materialized=materialized, materialized_properties=materialized_properties, @@ -283,7 +283,7 @@ def _build_table_properties_exp( partition_interval_unit: t.Optional[IntervalUnit] = None, clustered_by: t.Optional[t.List[exp.Expression]] = None, table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, table_kind: t.Optional[str] = None, **kwargs: t.Any, @@ -324,7 +324,7 @@ def _build_table_properties_exp( def _df_to_source_queries( self, df: DF, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], batch_size: int, target_table: TableName, source_columns: t.Optional[t.List[str]] = None, @@ -332,7 +332,9 @@ def _df_to_source_queries( import pandas as pd from pandas.api.types import is_datetime64_any_dtype - source_columns_to_types = get_source_columns_to_types(columns_to_types, source_columns) + source_columns_to_types = get_source_columns_to_types( + target_columns_to_types, source_columns + ) temp_table = self._get_temp_table( 
target_table or "pandas", quoted=False @@ -416,7 +418,7 @@ def query_factory() -> Query: ) return exp.select( - *self._casted_columns(columns_to_types, source_columns=source_columns) + *self._casted_columns(target_columns_to_types, source_columns=source_columns) ).from_(temp_table) def cleanup() -> None: @@ -626,7 +628,7 @@ def clone_table( def _columns_to_types( self, query_or_df: DF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, source_columns: t.Optional[t.List[str]] = None, ) -> t.Tuple[t.Dict[str, exp.DataType], t.List[str]]: ... @@ -634,24 +636,24 @@ def _columns_to_types( def _columns_to_types( self, query_or_df: Query, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, source_columns: t.Optional[t.List[str]] = None, ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: ... def _columns_to_types( self, query_or_df: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, source_columns: t.Optional[t.List[str]] = None, ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: - if not columns_to_types and snowpark and isinstance(query_or_df, snowpark.DataFrame): - columns_to_types = columns_to_types_from_dtypes( + if not target_columns_to_types and snowpark and isinstance(query_or_df, snowpark.DataFrame): + target_columns_to_types = columns_to_types_from_dtypes( query_or_df.sample(n=1).to_pandas().dtypes.items() ) - return columns_to_types, list(source_columns or columns_to_types) + return target_columns_to_types, list(source_columns or target_columns_to_types) return super()._columns_to_types( - query_or_df, columns_to_types, source_columns=source_columns + query_or_df, target_columns_to_types, source_columns=source_columns ) def close(self) -> 
t.Any: diff --git a/sqlmesh/core/engine_adapter/spark.py b/sqlmesh/core/engine_adapter/spark.py index f015c0f158..5e37ba075e 100644 --- a/sqlmesh/core/engine_adapter/spark.py +++ b/sqlmesh/core/engine_adapter/spark.py @@ -23,7 +23,7 @@ set_catalog, ) from sqlmesh.core.schema_diff import SchemaDiffer -from sqlmesh.utils import classproperty +from sqlmesh.utils import classproperty, get_source_columns_to_types from sqlmesh.utils.errors import SQLMeshError if t.TYPE_CHECKING: @@ -242,7 +242,7 @@ def try_get_pandas_df(cls, value: t.Any) -> t.Optional[pd.DataFrame]: def _columns_to_types( self, query_or_df: DF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, source_columns: t.Optional[t.List[str]] = None, ) -> t.Tuple[t.Dict[str, exp.DataType], t.List[str]]: ... @@ -250,64 +250,64 @@ def _columns_to_types( def _columns_to_types( self, query_or_df: Query, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, source_columns: t.Optional[t.List[str]] = None, ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: ... 
def _columns_to_types( self, query_or_df: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, source_columns: t.Optional[t.List[str]] = None, ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: - if columns_to_types: - return columns_to_types, list(source_columns or columns_to_types) + if target_columns_to_types: + return target_columns_to_types, list(source_columns or target_columns_to_types) if self.is_pyspark_df(query_or_df): from pyspark.sql import DataFrame - columns_to_types = self.spark_to_sqlglot_types(t.cast(DataFrame, query_or_df).schema) - return columns_to_types, list(source_columns or columns_to_types) + target_columns_to_types = self.spark_to_sqlglot_types( + t.cast(DataFrame, query_or_df).schema + ) + return target_columns_to_types, list(source_columns or target_columns_to_types) return super()._columns_to_types( - query_or_df, columns_to_types, source_columns=source_columns + query_or_df, target_columns_to_types, source_columns=source_columns ) def _df_to_source_queries( self, df: DF, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], batch_size: int, target_table: TableName, source_columns: t.Optional[t.List[str]] = None, ) -> t.List[SourceQuery]: - df = self._ensure_pyspark_df(df, columns_to_types, source_columns=source_columns) + df = self._ensure_pyspark_df(df, target_columns_to_types, source_columns=source_columns) def query_factory() -> Query: temp_table = self._get_temp_table(target_table or "spark", table_only=True) df.createOrReplaceGlobalTempView(temp_table.sql(dialect=self.dialect)) # type: ignore temp_table.set("db", "global_temp") - return exp.select(*self._select_columns(columns_to_types)).from_(temp_table) + return exp.select(*self._select_columns(target_columns_to_types)).from_(temp_table) return [SourceQuery(query_factory=query_factory)] def _ensure_pyspark_df( 
self, generic_df: DF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, source_columns: t.Optional[t.List[str]] = None, ) -> PySparkDataFrame: - def _get_pyspark_df() -> PySparkDataFrame: - pyspark_df = self.try_get_pyspark_df(generic_df) - if pyspark_df: - return pyspark_df + pyspark_df = self.try_get_pyspark_df(generic_df) + if not pyspark_df: df = self.try_get_pandas_df(generic_df) if df is None: raise SQLMeshError( "Ensure PySpark DF can only be run on a PySpark or Pandas DataFrame" ) - if columns_to_types: - source_columns_to_types = self.get_source_columns_to_types( - columns_to_types, source_columns + if target_columns_to_types: + source_columns_to_types = get_source_columns_to_types( + target_columns_to_types, source_columns ) # ensure Pandas dataframe column order matches columns_to_types df = df[list(source_columns_to_types)] @@ -318,13 +318,13 @@ def _get_pyspark_df() -> PySparkDataFrame: if source_columns_to_types else {} ) - return self.spark.createDataFrame(df, **kwargs) # type: ignore - - df_result = _get_pyspark_df() - if columns_to_types: - select_columns = self._casted_columns(columns_to_types, source_columns=source_columns) - df_result = df_result.selectExpr(*[x.sql(self.dialect) for x in select_columns]) # type: ignore - return df_result + pyspark_df = self.spark.createDataFrame(df, **kwargs) # type: ignore + if target_columns_to_types: + select_columns = self._casted_columns( + target_columns_to_types, source_columns=source_columns + ) + pyspark_df = pyspark_df.selectExpr(*[x.sql(self.dialect) for x in select_columns]) # type: ignore + return pyspark_df def _get_temp_table( self, table: TableName, table_only: bool = False, quoted: bool = True @@ -405,12 +405,12 @@ def get_current_database(self) -> str: def create_state_table( self, table_name: str, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], 
primary_key: t.Optional[t.Tuple[str, ...]] = None, ) -> None: self.create_table( table_name, - columns_to_types, + target_columns_to_types, partitioned_by=[exp.column(x) for x in primary_key] if primary_key else None, ) @@ -429,7 +429,7 @@ def _create_table( expression: t.Optional[exp.Expression], exists: bool = True, replace: bool = False, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, @@ -458,7 +458,7 @@ def _create_table( expression, exists=exists, replace=replace, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, table_description=table_description, column_descriptions=column_descriptions, **kwargs, diff --git a/sqlmesh/core/engine_adapter/trino.py b/sqlmesh/core/engine_adapter/trino.py index 592fe41109..e16cf2d76c 100644 --- a/sqlmesh/core/engine_adapter/trino.py +++ b/sqlmesh/core/engine_adapter/trino.py @@ -118,7 +118,7 @@ def _insert_overwrite_by_condition( self, table_name: TableName, source_queries: t.List[SourceQuery], - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, where: t.Optional[exp.Condition] = None, insert_overwrite_strategy_override: t.Optional[InsertOverwriteStrategy] = None, **kwargs: t.Any, @@ -131,14 +131,14 @@ def _insert_overwrite_by_condition( # "Session property 'catalog.insert_existing_partitions_behavior' does not exist" self.execute(f"SET SESSION {catalog}.insert_existing_partitions_behavior='OVERWRITE'") super()._insert_overwrite_by_condition( - table_name, source_queries, columns_to_types, where + table_name, source_queries, target_columns_to_types, where ) self.execute(f"SET SESSION {catalog}.insert_existing_partitions_behavior='APPEND'") else: 
super()._insert_overwrite_by_condition( table_name, source_queries, - columns_to_types, + target_columns_to_types, where, insert_overwrite_strategy_override=InsertOverwriteStrategy.DELETE_INSERT, ) @@ -216,7 +216,7 @@ def _get_data_objects( def _df_to_source_queries( self, df: DF, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], batch_size: int, target_table: TableName, source_columns: t.Optional[t.List[str]] = None, @@ -225,7 +225,9 @@ def _df_to_source_queries( from pandas.api.types import is_datetime64_any_dtype # type: ignore assert isinstance(df, pd.DataFrame) - source_columns_to_types = get_source_columns_to_types(columns_to_types, source_columns) + source_columns_to_types = get_source_columns_to_types( + target_columns_to_types, source_columns + ) # Trino does not accept timestamps in ISOFORMAT that include the "T". `execution_time` is stored in # Pandas with that format, so we convert the column to a string with the proper format and CAST to @@ -236,22 +238,22 @@ def _df_to_source_queries( df[column] = pd.to_datetime(df[column]).map(lambda x: x.isoformat(" ")) return super()._df_to_source_queries( - df, columns_to_types, batch_size, target_table, source_columns=source_columns + df, target_columns_to_types, batch_size, target_table, source_columns=source_columns ) def _build_schema_exp( self, table: exp.Table, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], column_descriptions: t.Optional[t.Dict[str, str]] = None, expressions: t.Optional[t.List[exp.PrimaryKey]] = None, is_view: bool = False, ) -> exp.Schema: if self.current_catalog_type == "delta_lake": - columns_to_types = self._to_delta_ts(columns_to_types) + target_columns_to_types = self._to_delta_ts(target_columns_to_types) return super()._build_schema_exp( - table, columns_to_types, column_descriptions, expressions, is_view + table, target_columns_to_types, column_descriptions, expressions, is_view 
) def _scd_type_2( @@ -267,15 +269,15 @@ def _scd_type_2( check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Column]]] = None, updated_at_as_valid_from: bool = False, execution_time_as_valid_from: bool = False, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, truncate: bool = False, source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: - if columns_to_types and self.current_catalog_type == "delta_lake": - columns_to_types = self._to_delta_ts(columns_to_types) + if target_columns_to_types and self.current_catalog_type == "delta_lake": + target_columns_to_types = self._to_delta_ts(target_columns_to_types) return super()._scd_type_2( target_table, @@ -289,7 +291,7 @@ def _scd_type_2( check_columns, updated_at_as_valid_from, execution_time_as_valid_from, - columns_to_types, + target_columns_to_types, table_description, column_descriptions, truncate, @@ -351,7 +353,7 @@ def _create_table( expression: t.Optional[exp.Expression], exists: bool = True, replace: bool = False, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, @@ -362,7 +364,7 @@ def _create_table( expression=expression, exists=exists, replace=replace, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, table_description=table_description, column_descriptions=column_descriptions, table_kind=table_kind, diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 0c15cdf26d..0cb45c1860 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -1528,7 +1528,7 @@ def append( 
self.adapter.insert_append( table_name, query_or_df, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, source_columns=source_columns, ) @@ -1547,7 +1547,7 @@ def create( if model.annotated: self.adapter.create_table( table_name, - columns_to_types=model.columns_to_types_or_raise, + target_columns_to_types=model.columns_to_types_or_raise, table_format=model.table_format, storage_format=model.storage_format, partitioned_by=model.partitioned_by, @@ -1626,7 +1626,7 @@ def _replace_query_for_model( else: # Source columns from the underlying table to prevent unintentional table schema changes during restatement of incremental models. columns_to_types, source_columns = self._get_target_and_source_columns( - model, name, render_kwargs, columns_to_types=self.adapter.columns(name) + model, name, render_kwargs, force_get_columns_from_target=True ) self.adapter.replace_query( @@ -1640,7 +1640,7 @@ def _replace_query_for_model( table_properties=kwargs.get("physical_properties", model.physical_properties), table_description=model.description, column_descriptions=model.column_descriptions, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, source_columns=source_columns, ) @@ -1649,13 +1649,20 @@ def _get_target_and_source_columns( model: Model, table_name: str, render_kwargs: t.Dict[str, t.Any], - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_column_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + force_get_columns_from_target: bool = False, ) -> t.Tuple[t.Dict[str, exp.DataType], t.Optional[t.List[str]]]: - if not columns_to_types: - columns_to_types = ( - model.columns_to_types if model.annotated else self.adapter.columns(table_name) + if force_get_columns_from_target: + target_column_to_types = self.adapter.columns(table_name) + elif target_column_to_types: + target_column_to_types = target_column_to_types + else: + target_column_to_types = ( + model.columns_to_types + if 
model.annotated and not model.on_destructive_change.is_ignore + else self.adapter.columns(table_name) ) - assert columns_to_types is not None + assert target_column_to_types is not None if model.on_destructive_change.is_ignore: # We need to identify the columns that are only in the source so we create an empty table with # the user query to determine that @@ -1663,7 +1670,7 @@ def _get_target_and_source_columns( source_columns = list(self.adapter.columns(temp_table)) else: source_columns = None - return columns_to_types, source_columns + return target_column_to_types, source_columns class IncrementalByPartitionStrategy(MaterializableStrategy): @@ -1686,7 +1693,7 @@ def insert( table_name, query_or_df, partitioned_by=model.partitioned_by, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, source_columns=source_columns, ) @@ -1710,7 +1717,7 @@ def insert( query_or_df, time_formatter=model.convert_to_time_column, time_column=model.time_column, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, source_columns=source_columns, **kwargs, ) @@ -1737,7 +1744,7 @@ def insert( self.adapter.merge( table_name, query_or_df, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, unique_key=model.unique_key, when_matched=model.when_matched, merge_filter=model.render_merge_filter( @@ -1763,7 +1770,7 @@ def append( self.adapter.merge( table_name, query_or_df, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, unique_key=model.unique_key, when_matched=model.when_matched, merge_filter=model.render_merge_filter( @@ -1794,14 +1801,14 @@ def insert( model, table_name, render_kwargs=render_kwargs, - columns_to_types=kwargs.pop("columns_to_types", None), + target_column_to_types=kwargs.pop("columns_to_types", None), ) if isinstance(model.kind, IncrementalUnmanagedKind) and model.kind.insert_overwrite: return self.adapter.insert_overwrite_by_partition( table_name, query_or_df, 
model.partitioned_by, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, source_columns=source_columns, ) return self.append( @@ -1859,7 +1866,7 @@ def create( ) else: self.adapter.insert_append( - table_name, df, columns_to_types=model.columns_to_types + table_name, df, target_columns_to_types=model.columns_to_types ) except Exception: self.adapter.drop_table(table_name) @@ -1895,7 +1902,7 @@ def create( columns_to_types[model.kind.updated_at_name.name] = model.kind.time_data_type self.adapter.create_table( table_name, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, table_format=model.table_format, storage_format=model.storage_format, partitioned_by=model.partitioned_by, @@ -1931,7 +1938,7 @@ def insert( model, table_name, render_kwargs=render_kwargs, - columns_to_types=self.adapter.columns(table_name), + force_get_columns_from_target=True, ) if isinstance(model.kind, SCDType2ByTimeKind): self.adapter.scd_type_2_by_time( @@ -1944,7 +1951,7 @@ def insert( updated_at_col=model.kind.updated_at_name, invalidate_hard_deletes=model.kind.invalidate_hard_deletes, updated_at_as_valid_from=model.kind.updated_at_as_valid_from, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, table_format=model.table_format, table_description=model.description, column_descriptions=model.column_descriptions, @@ -1962,7 +1969,7 @@ def insert( check_columns=model.kind.columns, invalidate_hard_deletes=model.kind.invalidate_hard_deletes, execution_time_as_valid_from=model.kind.execution_time_as_valid_from, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, table_format=model.table_format, table_description=model.description, column_descriptions=model.column_descriptions, @@ -1987,7 +1994,7 @@ def append( model, table_name, render_kwargs=render_kwargs, - columns_to_types=self.adapter.columns(table_name), + force_get_columns_from_target=True, ) if isinstance(model.kind, 
SCDType2ByTimeKind): self.adapter.scd_type_2_by_time( @@ -1999,7 +2006,7 @@ def append( updated_at_col=model.kind.updated_at_name, invalidate_hard_deletes=model.kind.invalidate_hard_deletes, updated_at_as_valid_from=model.kind.updated_at_as_valid_from, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, table_format=model.table_format, table_description=model.description, column_descriptions=model.column_descriptions, @@ -2014,7 +2021,7 @@ def append( valid_from_col=model.kind.valid_from_name, valid_to_col=model.kind.valid_to_name, check_columns=model.kind.columns, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, table_format=model.table_format, invalidate_hard_deletes=model.kind.invalidate_hard_deletes, execution_time_as_valid_from=model.kind.execution_time_as_valid_from, @@ -2296,7 +2303,7 @@ def create( self.adapter.create_managed_table( table_name=table_name, query=model.render_query_or_raise(**render_kwargs), - columns_to_types=model.columns_to_types, + target_columns_to_types=model.columns_to_types, partitioned_by=model.partitioned_by, clustered_by=model.clustered_by, table_properties=kwargs.get("physical_properties", model.physical_properties), @@ -2334,7 +2341,7 @@ def insert( self.adapter.create_managed_table( table_name=table_name, query=query_or_df, # type: ignore - columns_to_types=model.columns_to_types, + target_columns_to_types=model.columns_to_types, partitioned_by=model.partitioned_by, clustered_by=model.clustered_by, table_properties=kwargs.get("physical_properties", model.physical_properties), diff --git a/sqlmesh/core/state_sync/db/environment.py b/sqlmesh/core/state_sync/db/environment.py index b7e8128a93..3196d18078 100644 --- a/sqlmesh/core/state_sync/db/environment.py +++ b/sqlmesh/core/state_sync/db/environment.py @@ -77,7 +77,7 @@ def update_environment(self, environment: Environment) -> None: self.engine_adapter.insert_append( self.environments_table, 
_environment_to_df(environment), - columns_to_types=self._environment_columns_to_types, + target_columns_to_types=self._environment_columns_to_types, ) def update_environment_statements( @@ -107,7 +107,7 @@ def update_environment_statements( self.engine_adapter.insert_append( self.environment_statements_table, _environment_statements_to_df(environment_name, plan_id, environment_statements), - columns_to_types=self._environment_statements_columns_to_types, + target_columns_to_types=self._environment_statements_columns_to_types, ) def invalidate_environment(self, name: str, protect_prod: bool = True) -> None: diff --git a/sqlmesh/core/state_sync/db/interval.py b/sqlmesh/core/state_sync/db/interval.py index bebf41d453..bdfedace1e 100644 --- a/sqlmesh/core/state_sync/db/interval.py +++ b/sqlmesh/core/state_sync/db/interval.py @@ -114,7 +114,7 @@ def remove_intervals( self.engine_adapter.insert_append( self.intervals_table, _intervals_to_df(intervals_to_remove, is_dev=False, is_removed=True), - columns_to_types=self._interval_columns_to_types, + target_columns_to_types=self._interval_columns_to_types, ) def get_snapshot_intervals( @@ -242,7 +242,7 @@ def _push_snapshot_intervals( self.engine_adapter.insert_append( self.intervals_table, pd.DataFrame(new_intervals), - columns_to_types=self._interval_columns_to_types, + target_columns_to_types=self._interval_columns_to_types, ) def _get_snapshot_intervals( diff --git a/sqlmesh/core/state_sync/db/snapshot.py b/sqlmesh/core/state_sync/db/snapshot.py index 3745a27bb3..9cf4f2fbf5 100644 --- a/sqlmesh/core/state_sync/db/snapshot.py +++ b/sqlmesh/core/state_sync/db/snapshot.py @@ -102,7 +102,7 @@ def push_snapshots(self, snapshots: t.Iterable[Snapshot], overwrite: bool = Fals self.engine_adapter.insert_append( self.snapshots_table, _snapshots_to_df(snapshots_to_store), - columns_to_types=self._snapshot_columns_to_types, + target_columns_to_types=self._snapshot_columns_to_types, ) for snapshot in snapshots: @@ -363,7 +363,7 @@ 
def update_auto_restatements( self.engine_adapter.merge( self.auto_restatements_table, _auto_restatements_to_df(next_auto_restatement_ts_filtered), - columns_to_types=self._auto_restatement_columns_to_types, + target_columns_to_types=self._auto_restatement_columns_to_types, unique_key=(exp.column("snapshot_name"), exp.column("snapshot_version")), ) @@ -405,7 +405,7 @@ def _push_snapshots(self, snapshots: t.Iterable[Snapshot]) -> None: self.engine_adapter.insert_append( self.snapshots_table, _snapshots_to_df(snapshots_to_store), - columns_to_types=self._snapshot_columns_to_types, + target_columns_to_types=self._snapshot_columns_to_types, ) def _get_snapshots( diff --git a/sqlmesh/core/state_sync/db/version.py b/sqlmesh/core/state_sync/db/version.py index 873e1633df..492d74cc09 100644 --- a/sqlmesh/core/state_sync/db/version.py +++ b/sqlmesh/core/state_sync/db/version.py @@ -54,7 +54,7 @@ def update_versions( } ] ), - columns_to_types=self._version_columns_to_types, + target_columns_to_types=self._version_columns_to_types, ) def get_versions(self) -> Versions: diff --git a/sqlmesh/core/table_diff.py b/sqlmesh/core/table_diff.py index b48b852f8a..b9dfadc075 100644 --- a/sqlmesh/core/table_diff.py +++ b/sqlmesh/core/table_diff.py @@ -512,7 +512,7 @@ def _column_expr(name: str, table: str) -> exp.Expression: ) with self.adapter.temp_table( - query, name=temp_table, columns_to_types=None, **temp_table_kwargs + query, name=temp_table, target_columns_to_types=None, **temp_table_kwargs ) as table: summary_sums = [ exp.func("SUM", "s_exists").as_("s_count"), diff --git a/sqlmesh/migrations/v0007_env_table_info_to_kind.py b/sqlmesh/migrations/v0007_env_table_info_to_kind.py index 61335a0c51..f09f0d2b72 100644 --- a/sqlmesh/migrations/v0007_env_table_info_to_kind.py +++ b/sqlmesh/migrations/v0007_env_table_info_to_kind.py @@ -86,7 +86,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( environments_table, pd.DataFrame(new_environments), - 
columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "snapshots": exp.DataType.build("text"), "start_at": exp.DataType.build("text"), diff --git a/sqlmesh/migrations/v0009_remove_pre_post_hooks.py b/sqlmesh/migrations/v0009_remove_pre_post_hooks.py index 05d50c0932..3671f547d3 100644 --- a/sqlmesh/migrations/v0009_remove_pre_post_hooks.py +++ b/sqlmesh/migrations/v0009_remove_pre_post_hooks.py @@ -53,7 +53,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0011_add_model_kind_name.py b/sqlmesh/migrations/v0011_add_model_kind_name.py index 298d4b61ee..77aa68506a 100644 --- a/sqlmesh/migrations/v0011_add_model_kind_name.py +++ b/sqlmesh/migrations/v0011_add_model_kind_name.py @@ -53,7 +53,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0012_update_jinja_expressions.py b/sqlmesh/migrations/v0012_update_jinja_expressions.py index 4f6f04fba5..28bc4acdca 100644 --- a/sqlmesh/migrations/v0012_update_jinja_expressions.py +++ b/sqlmesh/migrations/v0012_update_jinja_expressions.py @@ -57,7 +57,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0013_serde_using_model_dialects.py 
b/sqlmesh/migrations/v0013_serde_using_model_dialects.py index 6f03767061..7e5e2cc217 100644 --- a/sqlmesh/migrations/v0013_serde_using_model_dialects.py +++ b/sqlmesh/migrations/v0013_serde_using_model_dialects.py @@ -55,7 +55,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0016_fix_windows_path.py b/sqlmesh/migrations/v0016_fix_windows_path.py index fb40d30076..e37c45afca 100644 --- a/sqlmesh/migrations/v0016_fix_windows_path.py +++ b/sqlmesh/migrations/v0016_fix_windows_path.py @@ -49,7 +49,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0017_fix_windows_seed_path.py b/sqlmesh/migrations/v0017_fix_windows_seed_path.py index ca693bab72..5d91443009 100644 --- a/sqlmesh/migrations/v0017_fix_windows_seed_path.py +++ b/sqlmesh/migrations/v0017_fix_windows_seed_path.py @@ -45,7 +45,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0018_rename_snapshot_model_to_node.py b/sqlmesh/migrations/v0018_rename_snapshot_model_to_node.py index de8f157ebb..5229c54f81 100644 --- a/sqlmesh/migrations/v0018_rename_snapshot_model_to_node.py +++ b/sqlmesh/migrations/v0018_rename_snapshot_model_to_node.py @@ -43,7 +43,7 @@ 
def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0020_remove_redundant_attributes_from_dbt_models.py b/sqlmesh/migrations/v0020_remove_redundant_attributes_from_dbt_models.py index c6beeb7d0a..d4c449ff34 100644 --- a/sqlmesh/migrations/v0020_remove_redundant_attributes_from_dbt_models.py +++ b/sqlmesh/migrations/v0020_remove_redundant_attributes_from_dbt_models.py @@ -48,7 +48,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0021_fix_table_properties.py b/sqlmesh/migrations/v0021_fix_table_properties.py index 36bcbdcc82..41429b5650 100644 --- a/sqlmesh/migrations/v0021_fix_table_properties.py +++ b/sqlmesh/migrations/v0021_fix_table_properties.py @@ -52,7 +52,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0022_move_project_to_model.py b/sqlmesh/migrations/v0022_move_project_to_model.py index 8da19049af..a5a529ef31 100644 --- a/sqlmesh/migrations/v0022_move_project_to_model.py +++ b/sqlmesh/migrations/v0022_move_project_to_model.py @@ -44,7 +44,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + 
target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0024_replace_model_kind_name_enum_with_value.py b/sqlmesh/migrations/v0024_replace_model_kind_name_enum_with_value.py index 2855ecebb2..abdbb716ea 100644 --- a/sqlmesh/migrations/v0024_replace_model_kind_name_enum_with_value.py +++ b/sqlmesh/migrations/v0024_replace_model_kind_name_enum_with_value.py @@ -45,7 +45,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0025_fix_intervals_and_missing_change_category.py b/sqlmesh/migrations/v0025_fix_intervals_and_missing_change_category.py index 7c794abdaa..b99e208806 100644 --- a/sqlmesh/migrations/v0025_fix_intervals_and_missing_change_category.py +++ b/sqlmesh/migrations/v0025_fix_intervals_and_missing_change_category.py @@ -85,7 +85,7 @@ def _add_interval(start_ts: int, end_ts: int, is_dev: bool) -> None: engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), @@ -98,7 +98,7 @@ def _add_interval(start_ts: int, end_ts: int, is_dev: bool) -> None: engine_adapter.insert_append( intervals_table, pd.DataFrame(new_intervals), - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build(index_type), "created_ts": exp.DataType.build("bigint"), "name": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0026_remove_dialect_from_seed.py b/sqlmesh/migrations/v0026_remove_dialect_from_seed.py index c06eeb4bca..73ec09aa76 100644 --- 
a/sqlmesh/migrations/v0026_remove_dialect_from_seed.py +++ b/sqlmesh/migrations/v0026_remove_dialect_from_seed.py @@ -45,7 +45,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0027_minute_interval_to_five.py b/sqlmesh/migrations/v0027_minute_interval_to_five.py index f92ffcb929..ce8b272734 100644 --- a/sqlmesh/migrations/v0027_minute_interval_to_five.py +++ b/sqlmesh/migrations/v0027_minute_interval_to_five.py @@ -47,7 +47,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0029_generate_schema_types_using_dialect.py b/sqlmesh/migrations/v0029_generate_schema_types_using_dialect.py index b7f58dc67f..1f2dda5f5f 100644 --- a/sqlmesh/migrations/v0029_generate_schema_types_using_dialect.py +++ b/sqlmesh/migrations/v0029_generate_schema_types_using_dialect.py @@ -46,7 +46,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0030_update_unrestorable_snapshots.py b/sqlmesh/migrations/v0030_update_unrestorable_snapshots.py index c2b6f545bc..3cd27d2ee2 100644 --- a/sqlmesh/migrations/v0030_update_unrestorable_snapshots.py +++ b/sqlmesh/migrations/v0030_update_unrestorable_snapshots.py @@ -55,7 +55,7 @@ def 
migrate(state_sync: t.Any, **kwargs: t.Any) -> None: # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0031_remove_dbt_target_fields.py b/sqlmesh/migrations/v0031_remove_dbt_target_fields.py index 92137a4973..d13ec92e0b 100644 --- a/sqlmesh/migrations/v0031_remove_dbt_target_fields.py +++ b/sqlmesh/migrations/v0031_remove_dbt_target_fields.py @@ -55,7 +55,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0034_add_default_catalog.py b/sqlmesh/migrations/v0034_add_default_catalog.py index 85a97b1134..d6469fa4b1 100644 --- a/sqlmesh/migrations/v0034_add_default_catalog.py +++ b/sqlmesh/migrations/v0034_add_default_catalog.py @@ -161,7 +161,7 @@ def migrate(state_sync, default_catalog: t.Optional[str], **kwargs): # type: ig engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), @@ -241,7 +241,7 @@ def migrate(state_sync, default_catalog: t.Optional[str], **kwargs): # type: ig engine_adapter.insert_append( environments_table, pd.DataFrame(new_environments), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "snapshots": exp.DataType.build(blob_type), "start_at": exp.DataType.build("text"), @@ -316,7 +316,7 @@ def migrate(state_sync, default_catalog: t.Optional[str], **kwargs): # type: ig 
engine_adapter.insert_append( intervals_table, pd.DataFrame(new_intervals), - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build(index_type), "created_ts": exp.DataType.build("bigint"), "name": exp.DataType.build(index_type), @@ -359,7 +359,7 @@ def migrate(state_sync, default_catalog: t.Optional[str], **kwargs): # type: ig engine_adapter.insert_append( seeds_table, pd.DataFrame(new_seeds), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "content": exp.DataType.build("text"), diff --git a/sqlmesh/migrations/v0037_remove_dbt_is_incremental_macro.py b/sqlmesh/migrations/v0037_remove_dbt_is_incremental_macro.py index 86fbc986ec..6ca7bef406 100644 --- a/sqlmesh/migrations/v0037_remove_dbt_is_incremental_macro.py +++ b/sqlmesh/migrations/v0037_remove_dbt_is_incremental_macro.py @@ -51,7 +51,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0038_add_expiration_ts_to_snapshot.py b/sqlmesh/migrations/v0038_add_expiration_ts_to_snapshot.py index 9f27239f41..54bb30a54b 100644 --- a/sqlmesh/migrations/v0038_add_expiration_ts_to_snapshot.py +++ b/sqlmesh/migrations/v0038_add_expiration_ts_to_snapshot.py @@ -62,7 +62,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0039_include_environment_in_plan_dag_spec.py b/sqlmesh/migrations/v0039_include_environment_in_plan_dag_spec.py index 
10da4e18e5..39fc6b6a0f 100644 --- a/sqlmesh/migrations/v0039_include_environment_in_plan_dag_spec.py +++ b/sqlmesh/migrations/v0039_include_environment_in_plan_dag_spec.py @@ -60,7 +60,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( plan_dags_table, pd.DataFrame(new_specs), - columns_to_types={ + target_columns_to_types={ "request_id": exp.DataType.build(index_type), "dag_id": exp.DataType.build(index_type), "dag_spec": exp.DataType.build(blob_type), diff --git a/sqlmesh/migrations/v0041_remove_hash_raw_query_attribute.py b/sqlmesh/migrations/v0041_remove_hash_raw_query_attribute.py index ad47b63724..fee9ac2955 100644 --- a/sqlmesh/migrations/v0041_remove_hash_raw_query_attribute.py +++ b/sqlmesh/migrations/v0041_remove_hash_raw_query_attribute.py @@ -48,7 +48,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0042_trim_indirect_versions.py b/sqlmesh/migrations/v0042_trim_indirect_versions.py index 37b6bef570..6759e8140d 100644 --- a/sqlmesh/migrations/v0042_trim_indirect_versions.py +++ b/sqlmesh/migrations/v0042_trim_indirect_versions.py @@ -55,7 +55,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0043_fix_remove_obsolete_attributes_in_plan_dags.py b/sqlmesh/migrations/v0043_fix_remove_obsolete_attributes_in_plan_dags.py index 4054f34f40..8b27e90963 100644 --- a/sqlmesh/migrations/v0043_fix_remove_obsolete_attributes_in_plan_dags.py +++ 
b/sqlmesh/migrations/v0043_fix_remove_obsolete_attributes_in_plan_dags.py @@ -53,7 +53,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( plan_dags_table, pd.DataFrame(new_dag_specs), - columns_to_types={ + target_columns_to_types={ "request_id": exp.DataType.build(index_type), "dag_id": exp.DataType.build(index_type), "dag_spec": exp.DataType.build(blob_type), diff --git a/sqlmesh/migrations/v0045_move_gateway_variable.py b/sqlmesh/migrations/v0045_move_gateway_variable.py index bd00e40404..12115e03e0 100644 --- a/sqlmesh/migrations/v0045_move_gateway_variable.py +++ b/sqlmesh/migrations/v0045_move_gateway_variable.py @@ -59,7 +59,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0048_drop_indirect_versions.py b/sqlmesh/migrations/v0048_drop_indirect_versions.py index e5fe9a28ab..991fb43827 100644 --- a/sqlmesh/migrations/v0048_drop_indirect_versions.py +++ b/sqlmesh/migrations/v0048_drop_indirect_versions.py @@ -48,7 +48,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0051_rename_column_descriptions.py b/sqlmesh/migrations/v0051_rename_column_descriptions.py index 627e58b4b9..a6b4b72577 100644 --- a/sqlmesh/migrations/v0051_rename_column_descriptions.py +++ b/sqlmesh/migrations/v0051_rename_column_descriptions.py @@ -54,7 +54,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), 
- columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0055_add_updated_ts_unpaused_ts_ttl_ms_unrestorable_to_snapshot.py b/sqlmesh/migrations/v0055_add_updated_ts_unpaused_ts_ttl_ms_unrestorable_to_snapshot.py index 1c127b496b..b323afa04f 100644 --- a/sqlmesh/migrations/v0055_add_updated_ts_unpaused_ts_ttl_ms_unrestorable_to_snapshot.py +++ b/sqlmesh/migrations/v0055_add_updated_ts_unpaused_ts_ttl_ms_unrestorable_to_snapshot.py @@ -118,7 +118,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0056_restore_table_indexes.py b/sqlmesh/migrations/v0056_restore_table_indexes.py index d6fab1669b..4ffec4e9cb 100644 --- a/sqlmesh/migrations/v0056_restore_table_indexes.py +++ b/sqlmesh/migrations/v0056_restore_table_indexes.py @@ -79,7 +79,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( new_snapshots_table, exp.select("*").from_(snapshots_table), - columns_to_types=snapshots_columns_to_types, + target_columns_to_types=snapshots_columns_to_types, ) # Recreate the environments table and its indexes. @@ -89,7 +89,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( new_environments_table, exp.select("*").from_(environments_table), - columns_to_types=environments_columns_to_types, + target_columns_to_types=environments_columns_to_types, ) # Recreate the intervals table and its indexes. 
@@ -105,7 +105,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( new_intervals_table, exp.select("*").from_(intervals_table), - columns_to_types=intervals_columns_to_types, + target_columns_to_types=intervals_columns_to_types, ) # Drop old tables. diff --git a/sqlmesh/migrations/v0060_move_audits_to_model.py b/sqlmesh/migrations/v0060_move_audits_to_model.py index 31da86999e..ca61055579 100644 --- a/sqlmesh/migrations/v0060_move_audits_to_model.py +++ b/sqlmesh/migrations/v0060_move_audits_to_model.py @@ -72,7 +72,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0063_change_signals.py b/sqlmesh/migrations/v0063_change_signals.py index 48a5bd1998..cf01bd2420 100644 --- a/sqlmesh/migrations/v0063_change_signals.py +++ b/sqlmesh/migrations/v0063_change_signals.py @@ -84,7 +84,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0064_join_when_matched_strings.py b/sqlmesh/migrations/v0064_join_when_matched_strings.py index 6ca187be30..455bf9e2c0 100644 --- a/sqlmesh/migrations/v0064_join_when_matched_strings.py +++ b/sqlmesh/migrations/v0064_join_when_matched_strings.py @@ -71,7 +71,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": 
exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0069_update_dev_table_suffix.py b/sqlmesh/migrations/v0069_update_dev_table_suffix.py index 57d0daaddd..1d714a5ba2 100644 --- a/sqlmesh/migrations/v0069_update_dev_table_suffix.py +++ b/sqlmesh/migrations/v0069_update_dev_table_suffix.py @@ -85,7 +85,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types=snapshots_columns_to_types, + target_columns_to_types=snapshots_columns_to_types, ) new_environments = [] @@ -144,7 +144,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( environments_table, pd.DataFrame(new_environments), - columns_to_types=environments_columns_to_types, + target_columns_to_types=environments_columns_to_types, ) diff --git a/sqlmesh/migrations/v0071_add_dev_version_to_intervals.py b/sqlmesh/migrations/v0071_add_dev_version_to_intervals.py index e1b7b32f37..7e14b2d4e1 100644 --- a/sqlmesh/migrations/v0071_add_dev_version_to_intervals.py +++ b/sqlmesh/migrations/v0071_add_dev_version_to_intervals.py @@ -137,7 +137,7 @@ def _migrate_intervals( engine_adapter.insert_append( intervals_table, pd.DataFrame(new_intervals), - columns_to_types=intervals_columns_to_types, + target_columns_to_types=intervals_columns_to_types, ) @@ -215,7 +215,7 @@ def _migrate_snapshots( engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types=snapshots_columns_to_types, + target_columns_to_types=snapshots_columns_to_types, ) diff --git a/sqlmesh/migrations/v0073_remove_symbolic_disable_restatement.py b/sqlmesh/migrations/v0073_remove_symbolic_disable_restatement.py index 98d9582bdc..a460399378 100644 --- a/sqlmesh/migrations/v0073_remove_symbolic_disable_restatement.py +++ b/sqlmesh/migrations/v0073_remove_symbolic_disable_restatement.py @@ -69,5 +69,5 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( 
snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types=snapshots_columns_to_types, + target_columns_to_types=snapshots_columns_to_types, ) diff --git a/sqlmesh/migrations/v0075_remove_validate_query.py b/sqlmesh/migrations/v0075_remove_validate_query.py index aa9c3fccb3..137430bec4 100644 --- a/sqlmesh/migrations/v0075_remove_validate_query.py +++ b/sqlmesh/migrations/v0075_remove_validate_query.py @@ -69,7 +69,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0081_update_partitioned_by.py b/sqlmesh/migrations/v0081_update_partitioned_by.py index d6fd2dd669..e5c98bd8e3 100644 --- a/sqlmesh/migrations/v0081_update_partitioned_by.py +++ b/sqlmesh/migrations/v0081_update_partitioned_by.py @@ -78,7 +78,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0085_deterministic_repr.py b/sqlmesh/migrations/v0085_deterministic_repr.py index 4c86969843..b5f0203c6d 100644 --- a/sqlmesh/migrations/v0085_deterministic_repr.py +++ b/sqlmesh/migrations/v0085_deterministic_repr.py @@ -117,7 +117,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0087_normalize_blueprint_variables.py 
b/sqlmesh/migrations/v0087_normalize_blueprint_variables.py index 8878bc8019..12648b5a2e 100644 --- a/sqlmesh/migrations/v0087_normalize_blueprint_variables.py +++ b/sqlmesh/migrations/v0087_normalize_blueprint_variables.py @@ -124,7 +124,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/migrations/v0090_add_forward_only_column.py b/sqlmesh/migrations/v0090_add_forward_only_column.py index 32efc14eed..cdc3fc857a 100644 --- a/sqlmesh/migrations/v0090_add_forward_only_column.py +++ b/sqlmesh/migrations/v0090_add_forward_only_column.py @@ -85,7 +85,7 @@ def migrate(state_sync, **kwargs): # type: ignore engine_adapter.insert_append( snapshots_table, pd.DataFrame(new_snapshots), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), diff --git a/sqlmesh/utils/__init__.py b/sqlmesh/utils/__init__.py index b0a3b566d5..c220de4847 100644 --- a/sqlmesh/utils/__init__.py +++ b/sqlmesh/utils/__init__.py @@ -409,4 +409,9 @@ def get_source_columns_to_types( columns_to_types: t.Dict[str, exp.DataType], source_columns: t.Optional[t.List[str]], ) -> t.Dict[str, exp.DataType]: - return {k: v for k, v in columns_to_types.items() if not source_columns or k in source_columns} + source_column_lookup = set(source_columns) if source_columns else None + return { + k: v + for k, v in columns_to_types.items() + if not source_column_lookup or k in source_column_lookup + } diff --git a/tests/core/engine_adapter/integration/__init__.py b/tests/core/engine_adapter/integration/__init__.py index 15339eeaa6..50437338ae 100644 --- a/tests/core/engine_adapter/integration/__init__.py +++ 
b/tests/core/engine_adapter/integration/__init__.py @@ -340,7 +340,7 @@ def input_data( list(data.itertuples(index=False, name=None)), batch_start=0, batch_end=sys.maxsize, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, ) if self.test_type == "df": formatted_df = self._format_df(data, to_datetime=self.dialect != "trino") diff --git a/tests/core/engine_adapter/integration/test_integration.py b/tests/core/engine_adapter/integration/test_integration.py index f97298cf2d..1b7d54a2d9 100644 --- a/tests/core/engine_adapter/integration/test_integration.py +++ b/tests/core/engine_adapter/integration/test_integration.py @@ -283,7 +283,7 @@ def test_ctas_source_columns(ctx_query_and_df: TestContext): table_description="test table description", column_descriptions={"id": "test id column description"}, table_format=ctx.default_table_format, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, source_columns=["id", "ds"], ) @@ -372,7 +372,7 @@ def test_create_view_source_columns(ctx_query_and_df: TestContext): table_description="test view description", column_descriptions={"id": "test id column description"}, source_columns=["id", "ds"], - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, ) expected_data = input_data.copy() @@ -470,7 +470,7 @@ def test_nan_roundtrip(ctx_df: TestContext): ctx.engine_adapter.replace_query( table, ctx.input_data(input_data), - columns_to_types=ctx.columns_to_types, + target_columns_to_types=ctx.columns_to_types, ) results = ctx.get_metadata_results() assert not results.views @@ -502,7 +502,9 @@ def test_replace_query(ctx_query_and_df: TestContext): # provided then it checks the table itself for types. This is fine within SQLMesh since we always know the tables # exist prior to evaluation but when running these tests that isn't the case. 
As a result we just pass in # columns_to_types for these two engines so we can still test inference on the other ones - columns_to_types=ctx.columns_to_types if ctx.dialect in ["spark", "databricks"] else None, + target_columns_to_types=ctx.columns_to_types + if ctx.dialect in ["spark", "databricks"] + else None, table_format=ctx.default_table_format, ) results = ctx.get_metadata_results() @@ -524,7 +526,7 @@ def test_replace_query(ctx_query_and_df: TestContext): ctx.engine_adapter.replace_query( table, ctx.input_data(replace_data), - columns_to_types=( + target_columns_to_types=( ctx.columns_to_types if ctx.dialect in ["spark", "databricks"] else None ), table_format=ctx.default_table_format, @@ -559,7 +561,7 @@ def test_replace_query_source_columns(ctx_query_and_df: TestContext): ctx.input_data(input_data), table_format=ctx.default_table_format, source_columns=["id", "ds"], - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, ) expected_data = input_data.copy() expected_data["ignored_column"] = pd.Series() @@ -585,7 +587,7 @@ def test_replace_query_source_columns(ctx_query_and_df: TestContext): ctx.input_data(replace_data), table_format=ctx.default_table_format, source_columns=["id", "ds"], - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, ) expected_data = replace_data.copy() expected_data["ignored_column"] = pd.Series() @@ -620,7 +622,9 @@ def test_replace_query_batched(ctx_query_and_df: TestContext): # provided then it checks the table itself for types. This is fine within SQLMesh since we always know the tables # exist prior to evaluation but when running these tests that isn't the case. 
As a result we just pass in # columns_to_types for these two engines so we can still test inference on the other ones - columns_to_types=ctx.columns_to_types if ctx.dialect in ["spark", "databricks"] else None, + target_columns_to_types=ctx.columns_to_types + if ctx.dialect in ["spark", "databricks"] + else None, table_format=ctx.default_table_format, ) results = ctx.get_metadata_results() @@ -642,7 +646,7 @@ def test_replace_query_batched(ctx_query_and_df: TestContext): ctx.engine_adapter.replace_query( table, ctx.input_data(replace_data), - columns_to_types=( + target_columns_to_types=( ctx.columns_to_types if ctx.dialect in ["spark", "databricks"] else None ), table_format=ctx.default_table_format, @@ -714,7 +718,7 @@ def test_insert_append_source_columns(ctx_query_and_df: TestContext): table, ctx.input_data(input_data), source_columns=["id", "ds"], - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, ) expected_data = input_data.copy() expected_data["ignored_column"] = pd.Series() @@ -738,7 +742,7 @@ def test_insert_append_source_columns(ctx_query_and_df: TestContext): table, ctx.input_data(append_data), source_columns=["id", "ds"], - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, ) append_expected_data = append_data.copy() append_expected_data["ignored_column"] = pd.Series() @@ -786,7 +790,7 @@ def test_insert_overwrite_by_time_partition(ctx_query_and_df: TestContext): end="2022-01-03", time_formatter=ctx.time_formatter, time_column=ctx.time_column, - columns_to_types=ctx.columns_to_types, + target_columns_to_types=ctx.columns_to_types, ) results = ctx.get_metadata_results() assert len(results.views) == 0 @@ -815,7 +819,7 @@ def test_insert_overwrite_by_time_partition(ctx_query_and_df: TestContext): end="2022-01-05", time_formatter=ctx.time_formatter, time_column=ctx.time_column, - columns_to_types=ctx.columns_to_types, + target_columns_to_types=ctx.columns_to_types, ) results = 
ctx.get_metadata_results() assert len(results.views) == 0 @@ -879,7 +883,7 @@ def test_insert_overwrite_by_time_partition_source_columns(ctx_query_and_df: Tes end="2022-01-03", time_formatter=ctx.time_formatter, time_column=ctx.time_column, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, source_columns=["id", "ds"], ) @@ -913,7 +917,7 @@ def test_insert_overwrite_by_time_partition_source_columns(ctx_query_and_df: Tes end="2022-01-05", time_formatter=ctx.time_formatter, time_column=ctx.time_column, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, source_columns=["id", "ds"], ) results = ctx.get_metadata_results() @@ -960,7 +964,7 @@ def test_merge(ctx_query_and_df: TestContext): ctx.engine_adapter.merge( table, ctx.input_data(input_data), - columns_to_types=None, + target_columns_to_types=None, unique_key=[exp.to_identifier("id")], ) results = ctx.get_metadata_results() @@ -982,7 +986,7 @@ def test_merge(ctx_query_and_df: TestContext): ctx.engine_adapter.merge( table, ctx.input_data(merge_data), - columns_to_types=None, + target_columns_to_types=None, unique_key=[exp.to_identifier("id")], ) results = ctx.get_metadata_results() @@ -1030,7 +1034,7 @@ def test_merge_source_columns(ctx_query_and_df: TestContext): table, ctx.input_data(input_data), unique_key=[exp.to_identifier("id")], - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, source_columns=["id", "ds"], ) @@ -1057,7 +1061,7 @@ def test_merge_source_columns(ctx_query_and_df: TestContext): table, ctx.input_data(merge_data), unique_key=[exp.to_identifier("id")], - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, source_columns=["id", "ds"], ) @@ -1119,7 +1123,7 @@ def test_scd_type_2_by_time(ctx_query_and_df: TestContext): updated_at_col=exp.column("updated_at", quoted=True), execution_time="2023-01-01 00:00:00", updated_at_as_valid_from=False, - columns_to_types=input_schema, + 
target_columns_to_types=input_schema, table_format=ctx.default_table_format, truncate=True, ) @@ -1182,7 +1186,7 @@ def test_scd_type_2_by_time(ctx_query_and_df: TestContext): updated_at_col=exp.column("updated_at", quoted=True), execution_time="2023-01-05 00:00:00", updated_at_as_valid_from=False, - columns_to_types=input_schema, + target_columns_to_types=input_schema, table_format=ctx.default_table_format, truncate=False, ) @@ -1278,7 +1282,7 @@ def test_scd_type_2_by_time_source_columns(ctx_query_and_df: TestContext): table_format=ctx.default_table_format, truncate=True, start="2022-01-01 00:00:00", - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, source_columns=["id", "name", "updated_at"], ) results = ctx.get_metadata_results() @@ -1346,7 +1350,7 @@ def test_scd_type_2_by_time_source_columns(ctx_query_and_df: TestContext): table_format=ctx.default_table_format, truncate=False, start="2022-01-01 00:00:00", - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, source_columns=["id", "name", "updated_at"], ) results = ctx.get_metadata_results() @@ -1443,7 +1447,7 @@ def test_scd_type_2_by_column(ctx_query_and_df: TestContext): valid_to_col=exp.column("valid_to", quoted=True), execution_time="2023-01-01", execution_time_as_valid_from=False, - columns_to_types=ctx.columns_to_types, + target_columns_to_types=ctx.columns_to_types, truncate=True, ) results = ctx.get_metadata_results() @@ -1514,7 +1518,7 @@ def test_scd_type_2_by_column(ctx_query_and_df: TestContext): valid_to_col=exp.column("valid_to", quoted=True), execution_time="2023-01-05 00:00:00", execution_time_as_valid_from=False, - columns_to_types=ctx.columns_to_types, + target_columns_to_types=ctx.columns_to_types, truncate=False, ) results = ctx.get_metadata_results() @@ -1623,7 +1627,7 @@ def test_scd_type_2_by_column_source_columns(ctx_query_and_df: TestContext): execution_time_as_valid_from=False, truncate=True, start="2023-01-01", - 
columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, source_columns=["id", "name", "status"], ) results = ctx.get_metadata_results() @@ -1700,7 +1704,7 @@ def test_scd_type_2_by_column_source_columns(ctx_query_and_df: TestContext): execution_time_as_valid_from=False, truncate=False, start="2023-01-01", - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, source_columns=["id", "name", "status"], ) results = ctx.get_metadata_results() @@ -3077,13 +3081,13 @@ def test_value_normalization( } ctx.engine_adapter.create_table( - table_name=test_table, columns_to_types=columns_to_types_normalized + table_name=test_table, target_columns_to_types=columns_to_types_normalized ) data_query = next(select_from_values(input_data_with_idx, columns_to_types_normalized)) ctx.engine_adapter.insert_append( table_name=test_table, query_or_df=data_query, - columns_to_types=columns_to_types_normalized, + target_columns_to_types=columns_to_types_normalized, ) query = ( diff --git a/tests/core/engine_adapter/integration/test_integration_athena.py b/tests/core/engine_adapter/integration/test_integration_athena.py index 33e76fc6e2..1c0ece6d78 100644 --- a/tests/core/engine_adapter/integration/test_integration_athena.py +++ b/tests/core/engine_adapter/integration/test_integration_athena.py @@ -284,10 +284,10 @@ def test_hive_drop_table_removes_data(ctx: TestContext, engine_adapter: AthenaEn columns_to_types = columns_to_types_from_df(data) engine_adapter.create_table( - table_name=seed_table, columns_to_types=columns_to_types, exists=False + table_name=seed_table, target_columns_to_types=columns_to_types, exists=False ) engine_adapter.insert_append( - table_name=seed_table, query_or_df=data, columns_to_types=columns_to_types + table_name=seed_table, query_or_df=data, target_columns_to_types=columns_to_types ) assert engine_adapter.fetchone(f"select count(*) from {seed_table}")[0] == 1 # type: ignore @@ -295,7 +295,7 @@ def 
test_hive_drop_table_removes_data(ctx: TestContext, engine_adapter: AthenaEn # This ensures that our drop table logic to delete the data from S3 is working engine_adapter.drop_table(seed_table, exists=False) engine_adapter.create_table( - table_name=seed_table, columns_to_types=columns_to_types, exists=False + table_name=seed_table, target_columns_to_types=columns_to_types, exists=False ) assert engine_adapter.fetchone(f"select count(*) from {seed_table}")[0] == 0 # type: ignore @@ -382,12 +382,14 @@ def time_formatter(time: TimeLike, _: t.Optional[t.Dict[str, exp.DataType]]) -> return exp.cast(exp.Literal.string(to_ds(time)), "date") engine_adapter.create_table( - table_name=table, columns_to_types=columns_to_types, partitioned_by=[exp.to_column("date")] + table_name=table, + target_columns_to_types=columns_to_types, + partitioned_by=[exp.to_column("date")], ) engine_adapter.insert_overwrite_by_time_partition( table_name=table, query_or_df=data, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, time_column=exp.to_identifier("date"), start="2023-01-01", end="2023-01-03", @@ -406,7 +408,7 @@ def time_formatter(time: TimeLike, _: t.Optional[t.Dict[str, exp.DataType]]) -> engine_adapter.insert_overwrite_by_time_partition( table_name=table, query_or_df=new_data, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, time_column=exp.to_identifier("date"), start="2023-01-03", end="2023-01-04", @@ -442,12 +444,14 @@ def time_formatter(time: TimeLike, _: t.Optional[t.Dict[str, exp.DataType]]) -> return exp.cast(exp.Literal.string(to_ts(time)), "datetime") engine_adapter.create_table( - table_name=table, columns_to_types=columns_to_types, partitioned_by=[exp.to_column("ts")] + table_name=table, + target_columns_to_types=columns_to_types, + partitioned_by=[exp.to_column("ts")], ) engine_adapter.insert_overwrite_by_time_partition( table_name=table, query_or_df=data, - columns_to_types=columns_to_types, + 
target_columns_to_types=columns_to_types, time_column=exp.to_identifier("ts"), start="2023-01-01 00:00:00", end="2023-01-01 04:00:00", @@ -469,7 +473,7 @@ def time_formatter(time: TimeLike, _: t.Optional[t.Dict[str, exp.DataType]]) -> engine_adapter.insert_overwrite_by_time_partition( table_name=table, query_or_df=new_data, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, time_column=exp.to_identifier("ts"), start="2023-01-01 03:00:00", end="2023-01-01 05:00:00", diff --git a/tests/core/engine_adapter/test_athena.py b/tests/core/engine_adapter/test_athena.py index 5ee07f52d5..4fe57baf34 100644 --- a/tests/core/engine_adapter/test_athena.py +++ b/tests/core/engine_adapter/test_athena.py @@ -133,7 +133,7 @@ def test_create_table_hive(adapter: AthenaEngineAdapter) -> None: adapter.create_table( model.name, - columns_to_types=model.columns_to_types_or_raise, + target_columns_to_types=model.columns_to_types_or_raise, table_properties=model.physical_properties, partitioned_by=model.partitioned_by, storage_format=model.storage_format, @@ -165,7 +165,7 @@ def test_create_table_iceberg(adapter: AthenaEngineAdapter) -> None: adapter.create_table( model.name, - columns_to_types=model.columns_to_types_or_raise, + target_columns_to_types=model.columns_to_types_or_raise, table_properties=model.physical_properties, partitioned_by=model.partitioned_by, table_format=model.table_format, @@ -193,14 +193,14 @@ def test_create_table_no_location(adapter: AthenaEngineAdapter) -> None: with pytest.raises(SQLMeshError, match=r"Cannot figure out location.*"): adapter.create_table( model.name, - columns_to_types=model.columns_to_types_or_raise, + target_columns_to_types=model.columns_to_types_or_raise, table_properties=model.physical_properties, ) adapter.s3_warehouse_location = "s3://bucket/prefix" adapter.create_table( model.name, - columns_to_types=model.columns_to_types_or_raise, + target_columns_to_types=model.columns_to_types_or_raise, 
table_properties=model.physical_properties, ) @@ -214,7 +214,7 @@ def test_ctas_hive(adapter: AthenaEngineAdapter): adapter.ctas( table_name="foo.bar", - columns_to_types={"a": exp.DataType.build("int")}, + target_columns_to_types={"a": exp.DataType.build("int")}, query_or_df=parse_one("select 1", into=exp.Select), ) @@ -228,7 +228,7 @@ def test_ctas_iceberg(adapter: AthenaEngineAdapter): adapter.ctas( table_name="foo.bar", - columns_to_types={"a": exp.DataType.build("int")}, + target_columns_to_types={"a": exp.DataType.build("int")}, query_or_df=parse_one("select 1", into=exp.Select), table_format="iceberg", ) @@ -242,7 +242,7 @@ def test_ctas_iceberg_no_specific_location(adapter: AthenaEngineAdapter): with pytest.raises(SQLMeshError, match=r"Cannot figure out location.*"): adapter.ctas( table_name="foo.bar", - columns_to_types={"a": exp.DataType.build("int")}, + target_columns_to_types={"a": exp.DataType.build("int")}, query_or_df=parse_one("select 1", into=exp.Select), table_properties={"table_type": exp.Literal.string("iceberg")}, ) @@ -270,7 +270,7 @@ def test_ctas_iceberg_partitioned(adapter: AthenaEngineAdapter): adapter.s3_warehouse_location = "s3://bucket/prefix/" adapter.ctas( table_name=model.name, - columns_to_types=model.columns_to_types, + target_columns_to_types=model.columns_to_types, partitioned_by=model.partitioned_by, query_or_df=model.ctas_query(), table_format=model.table_format, @@ -298,7 +298,7 @@ def test_replace_query(adapter: AthenaEngineAdapter, mocker: MockerFixture): adapter.replace_query( table_name="test", query_or_df=parse_one("select 1 as a", into=exp.Select), - columns_to_types={"a": exp.DataType.build("int")}, + target_columns_to_types={"a": exp.DataType.build("int")}, table_properties={}, ) @@ -317,7 +317,7 @@ def test_replace_query(adapter: AthenaEngineAdapter, mocker: MockerFixture): adapter.replace_query( table_name="test", query_or_df=parse_one("select 1 as a", into=exp.Select), - columns_to_types={"a": 
exp.DataType.build("int")}, + target_columns_to_types={"a": exp.DataType.build("int")}, table_properties={}, ) @@ -482,14 +482,14 @@ def test_iceberg_partition_transforms(adapter: AthenaEngineAdapter): adapter.create_table( table_name=model.name, - columns_to_types=model.columns_to_types_or_raise, + target_columns_to_types=model.columns_to_types_or_raise, partitioned_by=model.partitioned_by, table_format=model.table_format, ) adapter.ctas( table_name=model.name, - columns_to_types=model.columns_to_types_or_raise, + target_columns_to_types=model.columns_to_types_or_raise, partitioned_by=model.partitioned_by, query_or_df=model.ctas_query(), table_format=model.table_format, diff --git a/tests/core/engine_adapter/test_base.py b/tests/core/engine_adapter/test_base.py index afe506143f..02029ca6f8 100644 --- a/tests/core/engine_adapter/test_base.py +++ b/tests/core/engine_adapter/test_base.py @@ -82,7 +82,7 @@ def test_create_view_pandas_source_columns(make_mocked_engine_adapter: t.Callabl adapter.create_view( "test_view", pd.DataFrame({"a": [1, 2, 3]}), - columns_to_types={"a": bigint_dtype, "b": bigint_dtype}, + target_columns_to_types={"a": bigint_dtype, "b": bigint_dtype}, replace=False, source_columns=["a"], ) @@ -97,7 +97,10 @@ def test_create_view_query_source_columns(make_mocked_engine_adapter: t.Callable adapter.create_view( "test_view", parse_one("SELECT a FROM tbl"), - columns_to_types={"a": exp.DataType.build("BIGINT"), "b": exp.DataType.build("BIGINT")}, + target_columns_to_types={ + "a": exp.DataType.build("BIGINT"), + "b": exp.DataType.build("BIGINT"), + }, replace=False, source_columns=["a"], ) @@ -113,14 +116,14 @@ def test_create_materialized_view(make_mocked_engine_adapter: t.Callable): "test_view", parse_one("SELECT a FROM tbl"), materialized=True, - columns_to_types={"a": exp.DataType.build("INT")}, + target_columns_to_types={"a": exp.DataType.build("INT")}, ) adapter.create_view( "test_view", parse_one("SELECT a FROM tbl"), replace=False, 
materialized=True, - columns_to_types={"a": exp.DataType.build("INT")}, + target_columns_to_types={"a": exp.DataType.build("INT")}, ) adapter.cursor.execute.assert_has_calls( @@ -136,7 +139,7 @@ def test_create_materialized_view(make_mocked_engine_adapter: t.Callable): parse_one("SELECT a, b FROM tbl"), replace=False, materialized=True, - columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("INT")}, + target_columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("INT")}, ) adapter.create_view( "test_view", parse_one("SELECT a, b FROM tbl"), replace=False, materialized=True @@ -220,7 +223,7 @@ def test_insert_overwrite_by_time_partition(make_mocked_engine_adapter: t.Callab end="2022-01-02", time_column="b", time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), - columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("STRING")}, + target_columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("STRING")}, ) adapter.cursor.begin.assert_called_once() @@ -247,7 +250,10 @@ def test_insert_overwrite_by_time_partition_missing_time_column_type( end="2022-01-02", time_column="b", time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), - columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("UNKNOWN")}, + target_columns_to_types={ + "a": exp.DataType.build("INT"), + "b": exp.DataType.build("UNKNOWN"), + }, ) columns_mock.assert_called_once_with("test_table") @@ -274,7 +280,7 @@ def test_insert_overwrite_by_time_partition_supports_insert_overwrite( end="2022-01-02", time_column="b", time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), - columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("STRING")}, + target_columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("STRING")}, ) adapter.cursor.execute.assert_called_once_with( @@ -296,7 +302,10 @@ def test_insert_overwrite_by_time_partition_supports_insert_overwrite_pandas( 
end="2022-01-02", time_column="ds", time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), - columns_to_types={"a": exp.DataType.build("INT"), "ds": exp.DataType.build("STRING")}, + target_columns_to_types={ + "a": exp.DataType.build("INT"), + "ds": exp.DataType.build("STRING"), + }, ) assert to_sql_calls(adapter) == [ @@ -317,7 +326,10 @@ def test_insert_overwrite_by_time_partition_supports_insert_overwrite_pandas_sou end="2022-01-02", time_column="ds", time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), - columns_to_types={"a": exp.DataType.build("INT"), "ds": exp.DataType.build("STRING")}, + target_columns_to_types={ + "a": exp.DataType.build("INT"), + "ds": exp.DataType.build("STRING"), + }, source_columns=["a"], ) assert to_sql_calls(adapter) == [ @@ -337,7 +349,10 @@ def test_insert_overwrite_by_time_partition_supports_insert_overwrite_query_sour end="2022-01-02", time_column="ds", time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), - columns_to_types={"a": exp.DataType.build("INT"), "ds": exp.DataType.build("STRING")}, + target_columns_to_types={ + "a": exp.DataType.build("INT"), + "ds": exp.DataType.build("STRING"), + }, source_columns=["a"], ) assert to_sql_calls(adapter) == [ @@ -356,7 +371,7 @@ def test_insert_overwrite_by_time_partition_replace_where(make_mocked_engine_ada end="2022-01-02", time_column="b", time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), - columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("STRING")}, + target_columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("STRING")}, ) assert to_sql_calls(adapter) == [ @@ -379,7 +394,10 @@ def test_insert_overwrite_by_time_partition_replace_where_pandas( end="2022-01-02", time_column="ds", time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), - columns_to_types={"a": exp.DataType.build("INT"), "ds": exp.DataType.build("STRING")}, + target_columns_to_types={ + "a": exp.DataType.build("INT"), + "ds": 
exp.DataType.build("STRING"), + }, ) assert to_sql_calls(adapter) == [ @@ -400,7 +418,10 @@ def test_insert_overwrite_by_time_partition_replace_where_pandas_source_columns( end="2022-01-02", time_column="ds", time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), - columns_to_types={"a": exp.DataType.build("INT"), "ds": exp.DataType.build("STRING")}, + target_columns_to_types={ + "a": exp.DataType.build("INT"), + "ds": exp.DataType.build("STRING"), + }, source_columns=["a"], ) assert to_sql_calls(adapter) == [ @@ -420,7 +441,10 @@ def test_insert_overwrite_by_time_partition_replace_where_query_source_columns( end="2022-01-02", time_column="ds", time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), - columns_to_types={"a": exp.DataType.build("INT"), "ds": exp.DataType.build("STRING")}, + target_columns_to_types={ + "a": exp.DataType.build("INT"), + "ds": exp.DataType.build("STRING"), + }, source_columns=["a"], ) assert to_sql_calls(adapter) == [ @@ -440,7 +464,7 @@ def test_insert_overwrite_no_where(make_mocked_engine_adapter: t.Callable): adapter._insert_overwrite_by_condition( "test_table", source_queries, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, ) adapter.cursor.begin.assert_called_once() @@ -467,7 +491,7 @@ def test_insert_overwrite_by_condition_column_contains_unsafe_characters( adapter._insert_overwrite_by_condition( "test_table", source_queries, - columns_to_types=None, + target_columns_to_types=None, ) # The goal here is to assert that we don't parse `foo.bar.baz` into a qualified column @@ -482,7 +506,7 @@ def test_insert_append_query(make_mocked_engine_adapter: t.Callable): adapter.insert_append( "test_table", parse_one("SELECT a FROM tbl"), - columns_to_types={"a": exp.DataType.build("INT")}, + target_columns_to_types={"a": exp.DataType.build("INT")}, ) assert to_sql_calls(adapter) == [ @@ -496,7 +520,7 @@ def test_insert_append_query_select_star(make_mocked_engine_adapter: t.Callable) adapter.insert_append( 
"test_table", parse_one("SELECT 1 AS a, * FROM tbl"), - columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("INT")}, + target_columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("INT")}, ) assert to_sql_calls(adapter) == [ @@ -511,7 +535,7 @@ def test_insert_append_pandas(make_mocked_engine_adapter: t.Callable): adapter.insert_append( "test_table", df, - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("INT"), "b": exp.DataType.build("INT"), }, @@ -530,7 +554,7 @@ def test_insert_append_pandas_batches(make_mocked_engine_adapter: t.Callable): adapter.insert_append( "test_table", df, - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("INT"), "b": exp.DataType.build("INT"), }, @@ -552,7 +576,7 @@ def test_insert_append_pandas_source_columns(make_mocked_engine_adapter: t.Calla adapter.insert_append( "test_table", df, - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("INT"), "b": exp.DataType.build("INT"), }, @@ -568,7 +592,7 @@ def test_insert_append_query_source_columns(make_mocked_engine_adapter: t.Callab adapter.insert_append( "test_table", parse_one("SELECT a FROM tbl"), - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("INT"), "b": exp.DataType.build("INT"), }, @@ -1082,7 +1106,7 @@ def test_merge_upsert(make_mocked_engine_adapter: t.Callable, assert_exp_eq): adapter.merge( target_table="target", source_table=t.cast(exp.Select, parse_one('SELECT "ID", ts, val FROM source')), - columns_to_types={ + target_columns_to_types={ "ID": exp.DataType.build("int"), "ts": exp.DataType.build("timestamp"), "val": exp.DataType.build("int"), @@ -1113,7 +1137,7 @@ def test_merge_upsert(make_mocked_engine_adapter: t.Callable, assert_exp_eq): adapter.merge( target_table="target", source_table=parse_one("SELECT id, ts, val FROM source"), - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("int"), "ts": 
exp.DataType.build("timestamp"), "val": exp.DataType.build("int"), @@ -1134,7 +1158,7 @@ def test_merge_upsert_pandas(make_mocked_engine_adapter: t.Callable): adapter.merge( target_table="target", source_table=df, - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("int"), "ts": exp.DataType.build("timestamp"), "val": exp.DataType.build("int"), @@ -1151,7 +1175,7 @@ def test_merge_upsert_pandas(make_mocked_engine_adapter: t.Callable): adapter.merge( target_table="target", source_table=df, - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("int"), "ts": exp.DataType.build("timestamp"), "val": exp.DataType.build("int"), @@ -1171,7 +1195,7 @@ def test_merge_upsert_pandas_source_columns(make_mocked_engine_adapter: t.Callab adapter.merge( target_table="target", source_table=df, - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("int"), "ts": exp.DataType.build("timestamp"), "val": exp.DataType.build("int"), @@ -1191,7 +1215,7 @@ def test_merge_upsert_query_source_columns(make_mocked_engine_adapter: t.Callabl adapter.merge( target_table="target", source_table=parse_one("SELECT id, ts FROM source"), - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("int"), "ts": exp.DataType.build("timestamp"), "val": exp.DataType.build("int"), @@ -1212,7 +1236,7 @@ def test_merge_when_matched(make_mocked_engine_adapter: t.Callable, assert_exp_e adapter.merge( target_table="target", source_table=t.cast(exp.Select, parse_one('SELECT "ID", ts, val FROM source')), - columns_to_types={ + target_columns_to_types={ "ID": exp.DataType.build("int"), "ts": exp.DataType.build("timestamp"), "val": exp.DataType.build("int"), @@ -1265,7 +1289,7 @@ def test_merge_when_matched_multiple(make_mocked_engine_adapter: t.Callable, ass adapter.merge( target_table="target", source_table=t.cast(exp.Select, parse_one('SELECT "ID", ts, val FROM source')), - columns_to_types={ + target_columns_to_types={ "ID": 
exp.DataType.build("int"), "ts": exp.DataType.build("timestamp"), "val": exp.DataType.build("int"), @@ -1336,7 +1360,7 @@ def test_merge_filter(make_mocked_engine_adapter: t.Callable, assert_exp_eq): adapter.merge( target_table="target", source_table=t.cast(exp.Select, parse_one('SELECT "ID", ts, val FROM source')), - columns_to_types={ + target_columns_to_types={ "ID": exp.DataType.build("int"), "ts": exp.DataType.build("timestamp"), "val": exp.DataType.build("int"), @@ -1418,7 +1442,7 @@ def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable): valid_from_col=exp.column("test_valid_from", quoted=True), valid_to_col=exp.column("test_valid_to", quoted=True), updated_at_col=exp.column("test_UPDATED_at", quoted=True), - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("INT"), "name": exp.DataType.build("VARCHAR"), "price": exp.DataType.build("DOUBLE"), @@ -1624,7 +1648,7 @@ def test_scd_type_2_by_time_source_columns(make_mocked_engine_adapter: t.Callabl valid_from_col=exp.column("test_valid_from", quoted=True), valid_to_col=exp.column("test_valid_to", quoted=True), updated_at_col=exp.column("test_UPDATED_at", quoted=True), - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("INT"), "name": exp.DataType.build("VARCHAR"), "price": exp.DataType.build("DOUBLE"), @@ -1830,7 +1854,7 @@ def test_scd_type_2_by_time_no_invalidate_hard_deletes(make_mocked_engine_adapte valid_to_col=exp.column("test_valid_to", quoted=True), updated_at_col=exp.column("test_updated_at", quoted=True), invalidate_hard_deletes=False, - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("INT"), "name": exp.DataType.build("VARCHAR"), "price": exp.DataType.build("DOUBLE"), @@ -2017,7 +2041,7 @@ def test_merge_scd_type_2_pandas(make_mocked_engine_adapter: t.Callable): valid_from_col=exp.column("test_valid_from", quoted=True), valid_to_col=exp.column("test_valid_to", quoted=True), updated_at_col=exp.column("test_updated_at", 
quoted=True), - columns_to_types={ + target_columns_to_types={ "id1": exp.DataType.build("INT"), "id2": exp.DataType.build("INT"), "name": exp.DataType.build("VARCHAR"), @@ -2200,7 +2224,7 @@ def test_scd_type_2_by_column(make_mocked_engine_adapter: t.Callable): valid_from_col=exp.column("test_VALID_from", quoted=True), valid_to_col=exp.column("test_valid_to", quoted=True), check_columns=[exp.column("name"), exp.column("price")], - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("INT"), "name": exp.DataType.build("VARCHAR"), "price": exp.DataType.build("DOUBLE"), @@ -2381,7 +2405,7 @@ def test_scd_type_2_by_column_composite_key(make_mocked_engine_adapter: t.Callab valid_from_col=exp.column("test_VALID_from", quoted=True), valid_to_col=exp.column("test_valid_to", quoted=True), check_columns=[exp.column("name"), exp.column("price")], - columns_to_types={ + target_columns_to_types={ "id_a": exp.DataType.build("VARCHAR"), "id_b": exp.DataType.build("VARCHAR"), "name": exp.DataType.build("VARCHAR"), @@ -2573,7 +2597,7 @@ def test_scd_type_2_truncate(make_mocked_engine_adapter: t.Callable): valid_from_col=exp.column("test_valid_from", quoted=True), valid_to_col=exp.column("test_valid_to", quoted=True), check_columns=[exp.column("name"), exp.column("price")], - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("INT"), "name": exp.DataType.build("VARCHAR"), "price": exp.DataType.build("DOUBLE"), @@ -2756,7 +2780,7 @@ def test_scd_type_2_by_column_star_check(make_mocked_engine_adapter: t.Callable) valid_from_col=exp.column("test_valid_from", quoted=True), valid_to_col=exp.column("test_valid_to", quoted=True), check_columns=exp.Star(), - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("INT"), "name": exp.DataType.build("VARCHAR"), "price": exp.DataType.build("DOUBLE"), @@ -2951,7 +2975,7 @@ def test_scd_type_2_by_column_no_invalidate_hard_deletes(make_mocked_engine_adap 
valid_to_col=exp.column("test_valid_to", quoted=True), invalidate_hard_deletes=False, check_columns=[exp.column("name"), exp.column("price")], - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("INT"), "name": exp.DataType.build("VARCHAR"), "price": exp.DataType.build("DOUBLE"), @@ -3181,7 +3205,7 @@ def test_replace_query_pandas_source_columns(make_mocked_engine_adapter: t.Calla adapter.replace_query( "test_table", df, - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("INT"), "b": exp.DataType.build("INT"), }, @@ -3197,7 +3221,7 @@ def test_replace_query_query_source_columns(make_mocked_engine_adapter: t.Callab adapter.replace_query( "test_table", parse_one("SELECT a FROM tbl"), - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("INT"), "b": exp.DataType.build("INT"), }, @@ -3225,7 +3249,7 @@ def test_replace_query_self_referencing_not_exists_unknown( adapter.replace_query( "test", parse_one("SELECT a FROM test"), - columns_to_types={"a": exp.DataType.build("UNKNOWN")}, + target_columns_to_types={"a": exp.DataType.build("UNKNOWN")}, ) @@ -3242,7 +3266,7 @@ def test_replace_query_self_referencing_exists( adapter.replace_query( "test", parse_one("SELECT a FROM test"), - columns_to_types={"a": exp.DataType.build("UNKNOWN")}, + target_columns_to_types={"a": exp.DataType.build("UNKNOWN")}, ) assert to_sql_calls(adapter) == [ @@ -3263,7 +3287,7 @@ def test_replace_query_self_referencing_not_exists_known( adapter.replace_query( "test", parse_one("SELECT a FROM test"), - columns_to_types={"a": exp.DataType.build("INT")}, + target_columns_to_types={"a": exp.DataType.build("INT")}, ) assert to_sql_calls(adapter) == [ @@ -3362,7 +3386,7 @@ def test_ctas_pandas_source_columns(make_mocked_engine_adapter: t.Callable): adapter.ctas( "test_table", df, - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("INT"), "b": exp.DataType.build("INT"), }, @@ -3378,7 +3402,7 @@ def 
test_ctas_query_source_columns(make_mocked_engine_adapter: t.Callable): adapter.ctas( "test_table", parse_one("SELECT a FROM tbl"), - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("INT"), "b": exp.DataType.build("INT"), }, @@ -3528,7 +3552,7 @@ def test_insert_overwrite_by_partition_query( table_name, parse_one("SELECT a, ds, b FROM tbl"), partitioned_by=[d.parse_one(k) for k in partitioned_by], - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("int"), "ds": exp.DataType.build("DATETIME"), "b": exp.DataType.build("boolean"), @@ -3568,7 +3592,7 @@ def test_insert_overwrite_by_partition_query_insert_overwrite_strategy( d.parse_one("DATETIME_TRUNC(ds, MONTH)"), d.parse_one("b"), ], - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("int"), "ds": exp.DataType.build("DATETIME"), "b": exp.DataType.build("boolean"), diff --git a/tests/core/engine_adapter/test_bigquery.py b/tests/core/engine_adapter/test_bigquery.py index 79d6fbf9db..f5a287defb 100644 --- a/tests/core/engine_adapter/test_bigquery.py +++ b/tests/core/engine_adapter/test_bigquery.py @@ -38,7 +38,7 @@ def test_insert_overwrite_by_time_partition_query( end="2022-01-05", time_formatter=lambda x, _: exp.Literal.string(x.strftime("%Y-%m-%d")), time_column="ds", - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("int"), "ds": exp.DataType.build("string"), }, @@ -68,7 +68,7 @@ def test_insert_overwrite_by_partition_query( partitioned_by=[ d.parse_one("DATETIME_TRUNC(ds, MONTH)"), ], - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("int"), "ds": exp.DataType.build("DATETIME"), }, @@ -111,7 +111,7 @@ def test_insert_overwrite_by_partition_query_unknown_column_types( partitioned_by=[ d.parse_one("DATETIME_TRUNC(ds, MONTH)"), ], - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("unknown"), "ds": exp.DataType.build("UNKNOWN"), }, @@ -176,7 +176,7 @@ def temp_table_exists(table: 
exp.Table) -> bool: end="2022-01-05", time_formatter=lambda x, _: exp.Literal.string(x.strftime("%Y-%m-%d")), time_column="ds", - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("int"), "ds": exp.DataType.build("string"), }, @@ -431,7 +431,7 @@ def test_merge(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): adapter.merge( target_table="target", source_table=parse_one("SELECT id, ts, val FROM source"), - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.Type.INT, "ts": exp.DataType.Type.TIMESTAMP, "val": exp.DataType.Type.INT, @@ -488,7 +488,7 @@ def temp_table_exists(table: exp.Table) -> bool: adapter.merge( target_table="target", source_table=df, - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("INT"), "ts": exp.DataType.build("TIMESTAMP"), "val": exp.DataType.build("INT"), @@ -733,14 +733,14 @@ def test_nested_comments(make_mocked_engine_adapter: t.Callable, mocker: MockerF adapter.create_table( "test_table", - columns_to_types=nested_columns_to_types, + target_columns_to_types=nested_columns_to_types, column_descriptions=long_column_descriptions, ) adapter.ctas( "test_table", parse_one("SELECT * FROM source_table"), - columns_to_types=nested_columns_to_types, + target_columns_to_types=nested_columns_to_types, column_descriptions=long_column_descriptions, ) diff --git a/tests/core/engine_adapter/test_clickhouse.py b/tests/core/engine_adapter/test_clickhouse.py index 973e178820..3e92a8fe9b 100644 --- a/tests/core/engine_adapter/test_clickhouse.py +++ b/tests/core/engine_adapter/test_clickhouse.py @@ -603,7 +603,7 @@ def test_scd_type_2_by_time( valid_from_col=exp.column("test_valid_from", quoted=True), valid_to_col=exp.column("test_valid_to", quoted=True), updated_at_col=exp.column("test_UPDATED_at", quoted=True), - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("INT"), "name": exp.DataType.build("VARCHAR"), "price": exp.DataType.build("DOUBLE"), @@ -815,7 
+815,7 @@ def test_scd_type_2_by_column( valid_from_col=exp.column("test_VALID_from", quoted=True), valid_to_col=exp.column("test_valid_to", quoted=True), check_columns=[exp.column("name"), exp.column("price")], - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("INT"), "name": exp.DataType.build("VARCHAR"), "price": exp.DataType.build("DOUBLE"), diff --git a/tests/core/engine_adapter/test_databricks.py b/tests/core/engine_adapter/test_databricks.py index 5991f5b2b9..cd4c8c4074 100644 --- a/tests/core/engine_adapter/test_databricks.py +++ b/tests/core/engine_adapter/test_databricks.py @@ -159,7 +159,7 @@ def test_insert_overwrite_by_partition_query( d.parse_one("DATETIME_TRUNC(ds, MONTH)"), d.parse_one("b"), ], - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("int"), "ds": exp.DataType.build("DATETIME"), "b": exp.DataType.build("boolean"), diff --git a/tests/core/engine_adapter/test_mixins.py b/tests/core/engine_adapter/test_mixins.py index 57803427d4..50bef59d6e 100644 --- a/tests/core/engine_adapter/test_mixins.py +++ b/tests/core/engine_adapter/test_mixins.py @@ -23,7 +23,7 @@ def test_logical_merge(make_mocked_engine_adapter: t.Callable, mocker: MockerFix adapter.merge( target_table="target", source_table=t.cast(exp.Select, parse_one("SELECT id, ts, val FROM source")), - columns_to_types={ + target_columns_to_types={ "id": exp.DataType(this=exp.DataType.Type.INT), "ts": exp.DataType(this=exp.DataType.Type.TIMESTAMP), "val": exp.DataType(this=exp.DataType.Type.INT), @@ -48,7 +48,7 @@ def test_logical_merge(make_mocked_engine_adapter: t.Callable, mocker: MockerFix adapter.merge( target_table="target", source_table=t.cast(exp.Select, parse_one("SELECT id, ts, val FROM source")), - columns_to_types={ + target_columns_to_types={ "id": exp.DataType(this=exp.DataType.Type.INT), "ts": exp.DataType(this=exp.DataType.Type.TIMESTAMP), "val": exp.DataType(this=exp.DataType.Type.INT), diff --git 
a/tests/core/engine_adapter/test_mssql.py b/tests/core/engine_adapter/test_mssql.py index d8e5214be5..caa7843726 100644 --- a/tests/core/engine_adapter/test_mssql.py +++ b/tests/core/engine_adapter/test_mssql.py @@ -290,7 +290,10 @@ def test_insert_overwrite_by_time_partition_supports_insert_overwrite_pandas_not end="2022-01-02", time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), time_column="ds", - columns_to_types={"a": exp.DataType.build("INT"), "ds": exp.DataType.build("STRING")}, + target_columns_to_types={ + "a": exp.DataType.build("INT"), + "ds": exp.DataType.build("STRING"), + }, ) adapter._connection_pool.get().bulk_copy.assert_called_with( f"__temp_test_table_{temp_table_id}", [(1, "2022-01-01"), (2, "2022-01-02")] @@ -327,7 +330,10 @@ def test_insert_overwrite_by_time_partition_supports_insert_overwrite_pandas_exi end="2022-01-02", time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), time_column="ds", - columns_to_types={"a": exp.DataType.build("INT"), "ds": exp.DataType.build("STRING")}, + target_columns_to_types={ + "a": exp.DataType.build("INT"), + "ds": exp.DataType.build("STRING"), + }, ) assert to_sql_calls(adapter) == [ f"""MERGE INTO [test_table] AS [__MERGE_TARGET__] USING (SELECT [a] AS [a], [ds] AS [ds] FROM (SELECT CAST([a] AS INTEGER) AS [a], CAST([ds] AS VARCHAR(MAX)) AS [ds] FROM [__temp_test_table_{temp_table_id}]) AS [_subquery] WHERE [ds] BETWEEN '2022-01-01' AND '2022-01-02') AS [__MERGE_SOURCE__] ON (1 = 0) WHEN NOT MATCHED BY SOURCE AND [ds] BETWEEN '2022-01-01' AND '2022-01-02' THEN DELETE WHEN NOT MATCHED THEN INSERT ([a], [ds]) VALUES ([a], [ds]);""", @@ -359,7 +365,10 @@ def test_insert_overwrite_by_time_partition_replace_where_pandas( end="2022-01-02", time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), time_column="ds", - columns_to_types={"a": exp.DataType.build("INT"), "ds": exp.DataType.build("STRING")}, + target_columns_to_types={ + "a": exp.DataType.build("INT"), + "ds": exp.DataType.build("STRING"), + }, 
) adapter._connection_pool.get().bulk_copy.assert_called_with( f"__temp_test_table_{temp_table_id}", [(1, "2022-01-01"), (2, "2022-01-02")] @@ -391,7 +400,7 @@ def test_insert_append_pandas( adapter.insert_append( table_name, df, - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("INT"), "b": exp.DataType.build("INT"), }, @@ -461,7 +470,7 @@ def test_merge_pandas( adapter.merge( target_table=table_name, source_table=df, - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("int"), "ts": exp.DataType.build("TIMESTAMP"), "val": exp.DataType.build("int"), @@ -485,7 +494,7 @@ def test_merge_pandas( adapter.merge( target_table=table_name, source_table=df, - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("int"), "ts": exp.DataType.build("TIMESTAMP"), "val": exp.DataType.build("int"), @@ -524,7 +533,7 @@ def test_merge_exists( adapter.merge( target_table=table_name, source_table=df, - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("int"), "ts": exp.DataType.build("TIMESTAMP"), "val": exp.DataType.build("int"), @@ -545,7 +554,7 @@ def test_merge_exists( adapter.merge( target_table=table_name, source_table=df, - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("int"), "ts": exp.DataType.build("TIMESTAMP"), "val": exp.DataType.build("int"), @@ -567,7 +576,7 @@ def test_merge_exists( adapter.merge( target_table=table_name, source_table=df, - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("int"), "ts": exp.DataType.build("TIMESTAMP"), }, @@ -913,7 +922,7 @@ def test_replace_query_strategy(adapter: MSSQLEngineAdapter, mocker: MockerFixtu table_properties=model.physical_properties, table_description=model.description, column_descriptions=model.column_descriptions, - columns_to_types=model.columns_to_types_or_raise, + target_columns_to_types=model.columns_to_types_or_raise, ) # subsequent - table exists @@ -937,7 +946,7 @@ def 
test_replace_query_strategy(adapter: MSSQLEngineAdapter, mocker: MockerFixtu table_properties=model.physical_properties, table_description=model.description, column_descriptions=model.column_descriptions, - columns_to_types=model.columns_to_types_or_raise, + target_columns_to_types=model.columns_to_types_or_raise, ) assert to_sql_calls(adapter) == [ diff --git a/tests/core/engine_adapter/test_postgres.py b/tests/core/engine_adapter/test_postgres.py index fd6ce44994..5d05dd653c 100644 --- a/tests/core/engine_adapter/test_postgres.py +++ b/tests/core/engine_adapter/test_postgres.py @@ -99,7 +99,7 @@ def test_merge_version_gte_15(make_mocked_engine_adapter: t.Callable): adapter.merge( target_table="target", source_table=t.cast(exp.Select, parse_one('SELECT "ID", ts, val FROM source')), - columns_to_types={ + target_columns_to_types={ "ID": exp.DataType.build("int"), "ts": exp.DataType.build("timestamp"), "val": exp.DataType.build("int"), @@ -127,7 +127,7 @@ def test_merge_version_lt_15( adapter.merge( target_table="target", source_table=t.cast(exp.Select, parse_one('SELECT "ID", ts, val FROM source')), - columns_to_types={ + target_columns_to_types={ "ID": exp.DataType.build("int"), "ts": exp.DataType.build("timestamp"), "val": exp.DataType.build("int"), diff --git a/tests/core/engine_adapter/test_redshift.py b/tests/core/engine_adapter/test_redshift.py index 0db8e8d055..17c3dd1866 100644 --- a/tests/core/engine_adapter/test_redshift.py +++ b/tests/core/engine_adapter/test_redshift.py @@ -220,7 +220,7 @@ def test_values_to_sql(adapter: t.Callable, mocker: MockerFixture): df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) result = adapter._values_to_sql( values=list(df.itertuples(index=False, name=None)), - columns_to_types={"a": exp.DataType.build("int"), "b": exp.DataType.build("int")}, + target_columns_to_types={"a": exp.DataType.build("int"), "b": exp.DataType.build("int")}, batch_start=0, batch_end=2, ) @@ -272,7 +272,7 @@ def mock_table(*args, **kwargs): 
adapter.replace_query( table_name="test_table", query_or_df=df, - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("int"), "b": exp.DataType.build("int"), }, @@ -299,7 +299,7 @@ def test_replace_query_with_df_table_not_exists(adapter: t.Callable, mocker: Moc adapter.replace_query( table_name="test_table", query_or_df=df, - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("int"), "b": exp.DataType.build("int"), }, @@ -342,7 +342,7 @@ def test_create_view(adapter: t.Callable): adapter.create_view( view_name="test_view", query_or_df=parse_one("SELECT cola FROM table"), - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("int"), "b": exp.DataType.build("int"), }, @@ -428,7 +428,7 @@ def test_merge(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): adapter.merge( target_table=exp.to_table("target_table_name"), source_table=t.cast(exp.Select, parse_one('SELECT "ID", ts, val FROM source')), - columns_to_types={ + target_columns_to_types={ "ID": exp.DataType.build("int"), "ts": exp.DataType.build("timestamp"), "val": exp.DataType.build("int"), @@ -440,7 +440,7 @@ def test_merge(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): adapter.merge( target_table=exp.to_table("target_table_name"), source_table=t.cast(exp.Select, parse_one('SELECT "ID", ts, val FROM source')), - columns_to_types={ + target_columns_to_types={ "ID": exp.DataType.build("int"), "ts": exp.DataType.build("timestamp"), "val": exp.DataType.build("int"), @@ -473,7 +473,7 @@ def test_merge_when_matched_error(make_mocked_engine_adapter: t.Callable, mocker adapter.merge( target_table=exp.to_table("target_table_name"), source_table=t.cast(exp.Select, parse_one('SELECT "ID", val FROM source')), - columns_to_types={ + target_columns_to_types={ "ID": exp.DataType.build("int"), "val": exp.DataType.build("int"), }, @@ -521,7 +521,7 @@ def test_merge_logical_filter_error(make_mocked_engine_adapter: t.Callable, mock 
adapter.merge( target_table=exp.to_table("target_table_name_2"), source_table=t.cast(exp.Select, parse_one('SELECT "ID", ts FROM source')), - columns_to_types={ + target_columns_to_types={ "ID": exp.DataType.build("int"), "ts": exp.DataType.build("timestamp"), }, @@ -546,7 +546,7 @@ def test_merge_logical( adapter.merge( target_table=exp.to_table("target"), source_table=t.cast(exp.Select, parse_one('SELECT "ID", ts FROM source')), - columns_to_types={ + target_columns_to_types={ "ID": exp.DataType.build("int"), "ts": exp.DataType.build("timestamp"), }, diff --git a/tests/core/engine_adapter/test_snowflake.py b/tests/core/engine_adapter/test_snowflake.py index 4ca13ee8f9..9a1e068aa6 100644 --- a/tests/core/engine_adapter/test_snowflake.py +++ b/tests/core/engine_adapter/test_snowflake.py @@ -254,14 +254,14 @@ def test_create_managed_table(make_mocked_engine_adapter: t.Callable, mocker: Mo adapter.create_managed_table( table_name="test_table", query=query, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, ) # warehouse not specified, should default to current_warehouse() adapter.create_managed_table( table_name="test_table", query=query, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, table_properties={"target_lag": exp.Literal.string("20 minutes")}, ) @@ -269,7 +269,7 @@ def test_create_managed_table(make_mocked_engine_adapter: t.Callable, mocker: Mo adapter.create_managed_table( table_name="test_table", query=query, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, table_properties={ "target_lag": exp.Literal.string("20 minutes"), "warehouse": exp.to_identifier("foo"), @@ -280,7 +280,7 @@ def test_create_managed_table(make_mocked_engine_adapter: t.Callable, mocker: Mo adapter.create_managed_table( table_name="test_table", query=query, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, table_properties={ "target_lag": exp.Literal.string("20 
minutes"), }, @@ -292,7 +292,7 @@ def test_create_managed_table(make_mocked_engine_adapter: t.Callable, mocker: Mo adapter.create_managed_table( table_name="test_table", query=query, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, table_properties={ "target_lag": exp.Literal.string("20 minutes"), "refresh_mode": exp.Literal.string("auto"), @@ -304,7 +304,7 @@ def test_create_managed_table(make_mocked_engine_adapter: t.Callable, mocker: Mo adapter.create_managed_table( table_name="test_table", query=query, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, table_properties={ "target_lag": exp.Literal.string("20 minutes"), "catalog": exp.Literal.string("snowflake"), @@ -343,7 +343,7 @@ def test_ctas_skips_dynamic_table_properties(make_mocked_engine_adapter: t.Calla adapter.ctas( table_name="test_table", query_or_df=query, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, table_properties={ "warehouse": exp.to_identifier("foo"), "target_lag": exp.Literal.string("20 minutes"), @@ -463,7 +463,10 @@ def test_replace_query_snowpark_dataframe( adapter.replace_query( table_name="foo", query_or_df=df, - columns_to_types={"ID": exp.DataType.build("INT"), "NAME": exp.DataType.build("VARCHAR")}, + target_columns_to_types={ + "ID": exp.DataType.build("INT"), + "NAME": exp.DataType.build("VARCHAR"), + }, ) # verify that DROP VIEW is called instead of DROP TABLE @@ -622,7 +625,7 @@ def test_creatable_type_transient_type_from_model_definition( ) adapter.create_table( model.name, - columns_to_types=model.columns_to_types_or_raise, + target_columns_to_types=model.columns_to_types_or_raise, table_properties=model.physical_properties, ) @@ -657,7 +660,7 @@ def test_creatable_type_transient_type_from_model_definition_with_other_property ) adapter.create_table( model.name, - columns_to_types=model.columns_to_types_or_raise, + target_columns_to_types=model.columns_to_types_or_raise, 
table_properties=model.physical_properties, ) @@ -733,7 +736,7 @@ def test_table_format_iceberg(snowflake_mocked_engine_adapter: SnowflakeEngineAd adapter.create_table( table_name=model.name, - columns_to_types=model.columns_to_types_or_raise, + target_columns_to_types=model.columns_to_types_or_raise, table_format=model.table_format, table_properties=model.physical_properties, ) @@ -741,7 +744,7 @@ def test_table_format_iceberg(snowflake_mocked_engine_adapter: SnowflakeEngineAd adapter.ctas( table_name=model.name, query_or_df=model.render_query_or_raise(), - columns_to_types=model.columns_to_types_or_raise, + target_columns_to_types=model.columns_to_types_or_raise, table_format=model.table_format, table_properties=model.physical_properties, ) diff --git a/tests/core/engine_adapter/test_spark.py b/tests/core/engine_adapter/test_spark.py index 468de9f75a..55a925b995 100644 --- a/tests/core/engine_adapter/test_spark.py +++ b/tests/core/engine_adapter/test_spark.py @@ -83,7 +83,7 @@ def test_replace_query_table_properties_not_exists( adapter.replace_query( "test_table", parse_one("SELECT 1 AS cola, '2' AS colb, '3' AS colc"), - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, partitioned_by=[exp.to_column("colb")], storage_format="ICEBERG", table_properties={"a": exp.convert(1)}, @@ -117,7 +117,7 @@ def test_replace_query_table_properties_exists( adapter.replace_query( "test_table", parse_one("SELECT 1 AS cola, '2' AS colb, '3' AS colc"), - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, partitioned_by=[exp.to_column("colb")], storage_format="ICEBERG", table_properties={"a": exp.convert(1)}, @@ -582,7 +582,7 @@ def check_table_exists(table_name: exp.Table) -> bool: valid_from_col=exp.column("test_valid_from", quoted=True), valid_to_col=exp.column("test_valid_to", quoted=True), updated_at_col=exp.column("test_updated_at", quoted=True), - columns_to_types={ + target_columns_to_types={ "id": 
exp.DataType.build("INT"), "name": exp.DataType.build("VARCHAR"), "price": exp.DataType.build("DOUBLE"), @@ -1010,7 +1010,7 @@ def test_replace_query_with_wap_self_reference( adapter.replace_query( "catalog.schema.table.branch_wap_12345", parse_one("SELECT 1 as a FROM catalog.schema.table.branch_wap_12345"), - columns_to_types={"a": exp.DataType.build("INT")}, + target_columns_to_types={"a": exp.DataType.build("INT")}, storage_format="ICEBERG", ) @@ -1047,7 +1047,7 @@ def test_table_format(adapter: SparkEngineAdapter, mocker: MockerFixture): # both table_format and storage_format adapter.create_table( table_name=model.name, - columns_to_types=model.columns_to_types_or_raise, + target_columns_to_types=model.columns_to_types_or_raise, table_format=model.table_format, storage_format=model.storage_format, ) @@ -1055,21 +1055,21 @@ def test_table_format(adapter: SparkEngineAdapter, mocker: MockerFixture): # just table_format adapter.create_table( table_name=model.name, - columns_to_types=model.columns_to_types_or_raise, + target_columns_to_types=model.columns_to_types_or_raise, table_format=model.table_format, ) # just storage_format set to a table format (test for backwards compatibility) adapter.create_table( table_name=model.name, - columns_to_types=model.columns_to_types_or_raise, + target_columns_to_types=model.columns_to_types_or_raise, storage_format=model.table_format, ) adapter.ctas( table_name=model.name, query_or_df=model.query, - columns_to_types=model.columns_to_types_or_raise, + target_columns_to_types=model.columns_to_types_or_raise, table_format=model.table_format, storage_format=model.storage_format, ) diff --git a/tests/core/engine_adapter/test_trino.py b/tests/core/engine_adapter/test_trino.py index 4895bc5a31..745c2bbdfb 100644 --- a/tests/core/engine_adapter/test_trino.py +++ b/tests/core/engine_adapter/test_trino.py @@ -183,7 +183,7 @@ def test_partitioned_by_iceberg_transforms( adapter.create_table( table_name=model.view_name, - 
columns_to_types=model.columns_to_types_or_raise, + target_columns_to_types=model.columns_to_types_or_raise, partitioned_by=model.partitioned_by, ) @@ -426,7 +426,7 @@ def test_table_format(trino_mocked_engine_adapter: TrinoEngineAdapter, mocker: M adapter.create_table( table_name=model.name, - columns_to_types=model.columns_to_types_or_raise, + target_columns_to_types=model.columns_to_types_or_raise, table_format=model.table_format, storage_format=model.storage_format, ) @@ -434,7 +434,7 @@ def test_table_format(trino_mocked_engine_adapter: TrinoEngineAdapter, mocker: M adapter.ctas( table_name=model.name, query_or_df=t.cast(exp.Query, model.query), - columns_to_types=model.columns_to_types_or_raise, + target_columns_to_types=model.columns_to_types_or_raise, table_format=model.table_format, storage_format=model.storage_format, ) @@ -472,14 +472,14 @@ def test_table_location(trino_mocked_engine_adapter: TrinoEngineAdapter, mocker: adapter.create_table( table_name=model.name, - columns_to_types=model.columns_to_types_or_raise, + target_columns_to_types=model.columns_to_types_or_raise, table_properties=model.physical_properties, ) adapter.ctas( table_name=model.name, query_or_df=t.cast(exp.Query, model.query), - columns_to_types=model.columns_to_types_or_raise, + target_columns_to_types=model.columns_to_types_or_raise, table_properties=model.physical_properties, ) diff --git a/tests/core/state_sync/test_state_sync.py b/tests/core/state_sync/test_state_sync.py index d61907a5aa..e7046be13d 100644 --- a/tests/core/state_sync/test_state_sync.py +++ b/tests/core/state_sync/test_state_sync.py @@ -2205,7 +2205,7 @@ def test_migrate_rows(state_sync: EngineAdapterStateSync, mocker: MockerFixture) state_sync.engine_adapter.replace_query( "sqlmesh._snapshots", pd.read_json("tests/fixtures/migrations/snapshots.json"), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build("text"), "identifier": exp.DataType.build("text"), "version": 
exp.DataType.build("text"), @@ -2216,7 +2216,7 @@ def test_migrate_rows(state_sync: EngineAdapterStateSync, mocker: MockerFixture) state_sync.engine_adapter.replace_query( "sqlmesh._environments", pd.read_json("tests/fixtures/migrations/environments.json"), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build("text"), "snapshots": exp.DataType.build("text"), "start_at": exp.DataType.build("text"), @@ -2285,7 +2285,7 @@ def test_backup_state(state_sync: EngineAdapterStateSync, mocker: MockerFixture) state_sync.engine_adapter.replace_query( "sqlmesh._snapshots", pd.read_json("tests/fixtures/migrations/snapshots.json"), - columns_to_types={ + target_columns_to_types={ "name": exp.DataType.build("text"), "identifier": exp.DataType.build("text"), "version": exp.DataType.build("text"), @@ -2310,7 +2310,7 @@ def test_restore_snapshots_table(state_sync: EngineAdapterStateSync) -> None: state_sync.engine_adapter.replace_query( "sqlmesh._snapshots", pd.read_json("tests/fixtures/migrations/snapshots.json"), - columns_to_types=snapshot_columns_to_types, + target_columns_to_types=snapshot_columns_to_types, ) old_snapshots = state_sync.engine_adapter.fetchdf("select * from sqlmesh._snapshots") diff --git a/tests/core/test_context.py b/tests/core/test_context.py index 852f00e760..a94ba74a20 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -16,7 +16,7 @@ import sqlmesh.core.constants from sqlmesh.cli.project_init import init_example_project -from sqlmesh.core.console import get_console, TerminalConsole +from sqlmesh.core.console import TerminalConsole from sqlmesh.core import dialect as d, constants as c from sqlmesh.core.config import ( load_configs, diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index b21f77da39..5cbe22ab46 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -2886,9 +2886,9 @@ def 
test_restatement_plan_hourly_with_downstream_daily_restates_correct_interval "ts": exp.DataType.build("timestamp"), } external_table = exp.table_(table="external_table", db="test", quoted=True) - engine_adapter.create_table(table_name=external_table, columns_to_types=columns_to_types) + engine_adapter.create_table(table_name=external_table, target_columns_to_types=columns_to_types) engine_adapter.insert_append( - table_name=external_table, query_or_df=df, columns_to_types=columns_to_types + table_name=external_table, query_or_df=df, target_columns_to_types=columns_to_types ) # plan + apply @@ -2939,7 +2939,7 @@ def _dates_in_table(table_name: str) -> t.List[str]: } ) engine_adapter.replace_query( - table_name=external_table, query_or_df=df, columns_to_types=columns_to_types + table_name=external_table, query_or_df=df, target_columns_to_types=columns_to_types ) # Restate A across a day boundary with the expectation that two day intervals in B are affected @@ -3018,9 +3018,9 @@ def test_restatement_plan_respects_disable_restatements(tmp_path: Path): "ts": exp.DataType.build("timestamp"), } external_table = exp.table_(table="external_table", db="test", quoted=True) - engine_adapter.create_table(table_name=external_table, columns_to_types=columns_to_types) + engine_adapter.create_table(table_name=external_table, target_columns_to_types=columns_to_types) engine_adapter.insert_append( - table_name=external_table, query_or_df=df, columns_to_types=columns_to_types + table_name=external_table, query_or_df=df, target_columns_to_types=columns_to_types ) # plan + apply @@ -3124,9 +3124,9 @@ def test_restatement_plan_clears_correct_intervals_across_environments(tmp_path: "date": exp.DataType.build("date"), } external_table = exp.table_(table="external_table", db="test", quoted=True) - engine_adapter.create_table(table_name=external_table, columns_to_types=columns_to_types) + engine_adapter.create_table(table_name=external_table, target_columns_to_types=columns_to_types) 
engine_adapter.insert_append( - table_name=external_table, query_or_df=df, columns_to_types=columns_to_types + table_name=external_table, query_or_df=df, target_columns_to_types=columns_to_types ) # first, create the prod models @@ -3319,9 +3319,9 @@ def _derived_incremental_model_def(name: str, upstream: str) -> str: "ts": exp.DataType.build("timestamp"), } external_table = exp.table_(table="external_table", db="test", quoted=True) - engine_adapter.create_table(table_name=external_table, columns_to_types=columns_to_types) + engine_adapter.create_table(table_name=external_table, target_columns_to_types=columns_to_types) engine_adapter.insert_append( - table_name=external_table, query_or_df=df, columns_to_types=columns_to_types + table_name=external_table, query_or_df=df, target_columns_to_types=columns_to_types ) # plan + apply A, B, C in prod @@ -3469,9 +3469,9 @@ def test_prod_restatement_plan_clears_unaligned_intervals_in_derived_dev_tables( "ts": exp.DataType.build("timestamp"), } external_table = exp.table_(table="external_table", db="test", quoted=True) - engine_adapter.create_table(table_name=external_table, columns_to_types=columns_to_types) + engine_adapter.create_table(table_name=external_table, target_columns_to_types=columns_to_types) engine_adapter.insert_append( - table_name=external_table, query_or_df=df, columns_to_types=columns_to_types + table_name=external_table, query_or_df=df, target_columns_to_types=columns_to_types ) # plan + apply A[hourly] in prod @@ -3611,9 +3611,9 @@ def test_prod_restatement_plan_causes_dev_intervals_to_be_processed_in_next_dev_ "ts": exp.DataType.build("timestamp"), } external_table = exp.table_(table="external_table", db="test", quoted=True) - engine_adapter.create_table(table_name=external_table, columns_to_types=columns_to_types) + engine_adapter.create_table(table_name=external_table, target_columns_to_types=columns_to_types) engine_adapter.insert_append( - table_name=external_table, query_or_df=df, 
columns_to_types=columns_to_types + table_name=external_table, query_or_df=df, target_columns_to_types=columns_to_types ) # plan + apply A[hourly] in prod @@ -3746,9 +3746,9 @@ def test_prod_restatement_plan_causes_dev_intervals_to_be_widened_on_full_restat "ts": exp.DataType.build("timestamp"), } external_table = exp.table_(table="external_table", db="test", quoted=True) - engine_adapter.create_table(table_name=external_table, columns_to_types=columns_to_types) + engine_adapter.create_table(table_name=external_table, target_columns_to_types=columns_to_types) engine_adapter.insert_append( - table_name=external_table, query_or_df=df, columns_to_types=columns_to_types + table_name=external_table, query_or_df=df, target_columns_to_types=columns_to_types ) # plan + apply A[daily] in prod @@ -3881,9 +3881,9 @@ def test_prod_restatement_plan_missing_model_in_dev( "ts": exp.DataType.build("timestamp"), } external_table = exp.table_(table="external_table", db="test", quoted=True) - engine_adapter.create_table(table_name=external_table, columns_to_types=columns_to_types) + engine_adapter.create_table(table_name=external_table, target_columns_to_types=columns_to_types) engine_adapter.insert_append( - table_name=external_table, query_or_df=df, columns_to_types=columns_to_types + table_name=external_table, query_or_df=df, target_columns_to_types=columns_to_types ) # plan + apply A[hourly] in dev diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index 6eae6376f1..b05d567cd2 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -204,7 +204,7 @@ def x(evaluator, y=None) -> None: ) common_kwargs = dict( - columns_to_types={"a": exp.DataType.build("int")}, + target_columns_to_types={"a": exp.DataType.build("int")}, table_format=None, storage_format="parquet", partitioned_by=[exp.to_column("a", quoted=True)], @@ -693,14 +693,14 @@ def test_evaluate_incremental_unmanaged_with_intervals( 
snapshot.table_name(), model.render_query(), [exp.to_column("ds", quoted=True)], - columns_to_types=model.columns_to_types, + target_columns_to_types=model.columns_to_types, source_columns=None, ) else: adapter_mock.insert_append.assert_called_once_with( snapshot.table_name(), model.render_query(), - columns_to_types=model.columns_to_types, + target_columns_to_types=model.columns_to_types, source_columns=None, ) @@ -735,7 +735,7 @@ def test_evaluate_incremental_unmanaged_no_intervals( model.render_query(), clustered_by=[], column_descriptions={}, - columns_to_types=table_columns, + target_columns_to_types=table_columns, partition_interval_unit=model.partition_interval_unit, partitioned_by=model.partitioned_by, table_format=None, @@ -851,7 +851,10 @@ def test_create_new_forward_only_model(mocker: MockerFixture, adapter_mock, make # Only non-deployable table should be created adapter_mock.create_table.assert_called_once_with( f"sqlmesh__test_schema.test_schema__test_model__{snapshot.dev_version}__dev", - columns_to_types={"a": exp.DataType.build("int"), "ds": exp.DataType.build("varchar")}, + target_columns_to_types={ + "a": exp.DataType.build("int"), + "ds": exp.DataType.build("varchar"), + }, table_format=None, storage_format=None, partitioned_by=model.partitioned_by, @@ -1010,7 +1013,7 @@ def test_create_prod_table_exists_forward_only(mocker: MockerFixture, adapter_mo adapter_mock.create_schema.assert_called_once_with(to_schema("sqlmesh__test_schema")) adapter_mock.create_table.assert_called_once_with( f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev", - columns_to_types={"a": exp.DataType.build("int")}, + target_columns_to_types={"a": exp.DataType.build("int")}, table_format=None, storage_format=None, partitioned_by=[], @@ -1611,7 +1614,7 @@ def test_create_clone_in_dev(mocker: MockerFixture, adapter_mock, make_snapshot) adapter_mock.create_table.assert_called_once_with( 
f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev__schema_migration_source", - columns_to_types={"a": exp.DataType.build("int"), "ds": exp.DataType.build("date")}, + target_columns_to_types={"a": exp.DataType.build("int"), "ds": exp.DataType.build("date")}, table_format=None, storage_format=None, partitioned_by=[exp.to_column("ds", quoted=True)], @@ -1671,7 +1674,7 @@ def test_create_clone_in_dev_missing_table(mocker: MockerFixture, adapter_mock, adapter_mock.create_table.assert_called_once_with( f"sqlmesh__test_schema.test_schema__test_model__{snapshot.dev_version}__dev", - columns_to_types={"a": exp.DataType.build("int"), "ds": exp.DataType.build("date")}, + target_columns_to_types={"a": exp.DataType.build("int"), "ds": exp.DataType.build("date")}, table_format=None, storage_format=None, partitioned_by=[exp.to_column("ds", quoted=True)], @@ -1788,7 +1791,7 @@ def test_create_clone_in_dev_self_referencing( adapter_mock.create_table.assert_called_once_with( f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev__schema_migration_source", - columns_to_types={"a": exp.DataType.build("int"), "ds": exp.DataType.build("date")}, + target_columns_to_types={"a": exp.DataType.build("int"), "ds": exp.DataType.build("date")}, table_format=None, storage_format=None, partitioned_by=[exp.to_column("ds", quoted=True)], @@ -1919,7 +1922,7 @@ def test_forward_only_snapshot_for_added_model(mocker: MockerFixture, adapter_mo evaluator.create([snapshot], {}) common_create_args = dict( - columns_to_types={"a": exp.DataType.build("int"), "ds": exp.DataType.build("date")}, + target_columns_to_types={"a": exp.DataType.build("int"), "ds": exp.DataType.build("date")}, table_format=None, storage_format=None, partitioned_by=[exp.to_column("ds", quoted=True)], @@ -1963,7 +1966,7 @@ def test_create_scd_type_2_by_time(adapter_mock, make_snapshot): evaluator.create([snapshot], {}) common_kwargs = dict( - columns_to_types={ + target_columns_to_types={ "id": 
exp.DataType.build("INT"), "name": exp.DataType.build("STRING"), "updated_at": exp.DataType.build("TIMESTAMPTZ"), @@ -2100,7 +2103,7 @@ def test_insert_into_scd_type_2_by_time( adapter_mock.scd_type_2_by_time.assert_called_once_with( target_table=snapshot.table_name(), source_table=model.render_query(), - columns_to_types=table_columns, + target_columns_to_types=table_columns, table_format=None, unique_key=[exp.to_column("id", quoted=True)], valid_from_col=exp.column("valid_from", quoted=True), @@ -2142,7 +2145,7 @@ def test_create_scd_type_2_by_column(adapter_mock, make_snapshot): evaluator.create([snapshot], {}) common_kwargs = dict( - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("INT"), "name": exp.DataType.build("STRING"), # Make sure that the call includes these extra columns @@ -2273,7 +2276,7 @@ def test_insert_into_scd_type_2_by_column( adapter_mock.scd_type_2_by_column.assert_called_once_with( target_table=snapshot.table_name(), source_table=model.render_query(), - columns_to_types=table_columns, + target_columns_to_types=table_columns, table_format=None, unique_key=[exp.to_column("id", quoted=True)], check_columns=exp.Star(), @@ -2323,7 +2326,7 @@ def test_create_incremental_by_unique_key_updated_at_exp(adapter_mock, make_snap adapter_mock.merge.assert_called_once_with( snapshot.table_name(), model.render_query(), - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("INT"), "name": exp.DataType.build("STRING"), "updated_at": exp.DataType.build("TIMESTAMP"), @@ -2392,7 +2395,7 @@ def test_create_incremental_by_unique_key_multiple_updated_at_exp(adapter_mock, adapter_mock.merge.assert_called_once_with( snapshot.table_name(), model.render_query(), - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("INT"), "name": exp.DataType.build("STRING"), "updated_at": exp.DataType.build("TIMESTAMP"), @@ -2488,7 +2491,7 @@ def test_create_incremental_by_unique_no_intervals(adapter_mock, 
make_snapshot): model.render_query(), clustered_by=[], column_descriptions={}, - columns_to_types=table_columns, + target_columns_to_types=table_columns, partition_interval_unit=model.partition_interval_unit, partitioned_by=model.partitioned_by, table_format=None, @@ -2553,7 +2556,7 @@ def test_create_incremental_by_unique_key_merge_filter(adapter_mock, make_snapsh adapter_mock.merge.assert_called_once_with( snapshot.table_name(), model.render_query(), - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("INT"), "updated_at": exp.DataType.build("TIMESTAMP"), }, @@ -2620,7 +2623,10 @@ def test_create_seed(mocker: MockerFixture, adapter_mock, make_snapshot): evaluator.create([snapshot], {}) common_create_kwargs: t.Dict[str, t.Any] = dict( - columns_to_types={"id": exp.DataType.build("bigint"), "name": exp.DataType.build("text")}, + target_columns_to_types={ + "id": exp.DataType.build("bigint"), + "name": exp.DataType.build("text"), + }, table_format=None, storage_format=None, partitioned_by=[], @@ -2698,7 +2704,10 @@ def test_create_seed_on_error(mocker: MockerFixture, adapter_mock, make_snapshot f"sqlmesh__db.db__seed__{snapshot.version}", mocker.ANY, column_descriptions={}, - columns_to_types={"id": exp.DataType.build("bigint"), "name": exp.DataType.build("text")}, + target_columns_to_types={ + "id": exp.DataType.build("bigint"), + "name": exp.DataType.build("text"), + }, table_format=None, storage_format=None, partitioned_by=[], @@ -2755,7 +2764,10 @@ def test_create_seed_no_intervals(mocker: MockerFixture, adapter_mock, make_snap f"sqlmesh__db.db__seed__{snapshot.version}", mocker.ANY, column_descriptions={}, - columns_to_types={"id": exp.DataType.build("bigint"), "name": exp.DataType.build("text")}, + target_columns_to_types={ + "id": exp.DataType.build("bigint"), + "name": exp.DataType.build("text"), + }, table_format=None, storage_format=None, partitioned_by=[], @@ -3261,7 +3273,7 @@ def test_evaluate_incremental_by_partition(mocker: 
MockerFixture, make_snapshot, exp.to_column("ds", quoted=True), exp.to_column("b", quoted=True), ], - columns_to_types=model.columns_to_types, + target_columns_to_types=model.columns_to_types, clustered_by=[], table_properties={}, column_descriptions={}, @@ -3291,7 +3303,7 @@ def test_evaluate_incremental_by_partition(mocker: MockerFixture, make_snapshot, exp.to_column("ds", quoted=True), exp.to_column("b", quoted=True), ], - columns_to_types=model.columns_to_types, + target_columns_to_types=model.columns_to_types, source_columns=None, ) @@ -3522,7 +3534,7 @@ def test_create_managed(adapter_mock, make_snapshot, mocker: MockerFixture): adapter_mock.create_managed_table.assert_called_with( table_name=snapshot.table_name(), query=mocker.ANY, - columns_to_types=model.columns_to_types, + target_columns_to_types=model.columns_to_types, partitioned_by=model.partitioned_by, clustered_by=model.clustered_by, table_properties=model.physical_properties, @@ -3594,7 +3606,7 @@ def test_evaluate_managed(adapter_mock, make_snapshot, mocker: MockerFixture): adapter_mock.replace_query.assert_called_with( snapshot.table_name(is_deployable=False), mocker.ANY, - columns_to_types=table_colmns, + target_columns_to_types=table_colmns, table_format=model.table_format, storage_format=model.storage_format, partitioned_by=model.partitioned_by, @@ -3776,7 +3788,7 @@ def test_create_snapshot( ) common_kwargs: t.Dict[str, t.Any] = dict( - columns_to_types={"a": exp.DataType.build("int")}, + target_columns_to_types={"a": exp.DataType.build("int")}, table_format=None, storage_format=None, partitioned_by=[], @@ -3843,13 +3855,16 @@ def test_migrate_snapshot(snapshot: Snapshot, mocker: MockerFixture, adapter_moc [ call( new_snapshot.table_name(), - columns_to_types={"a": exp.DataType.build("int")}, + target_columns_to_types={"a": exp.DataType.build("int")}, column_descriptions={}, **common_kwargs, ), call( new_snapshot.table_name(is_deployable=False), - columns_to_types={"a": 
exp.DataType.build("int"), "b": exp.DataType.build("int")}, + target_columns_to_types={ + "a": exp.DataType.build("int"), + "b": exp.DataType.build("int"), + }, column_descriptions=None, **common_kwargs, ), diff --git a/tests/dbt/test_adapter.py b/tests/dbt/test_adapter.py index d4dbf62e74..944c4ce78d 100644 --- a/tests/dbt/test_adapter.py +++ b/tests/dbt/test_adapter.py @@ -31,16 +31,16 @@ def test_adapter_relation(sushi_test_project: Project, runtime_renderer: t.Calla engine_adapter.create_schema("foo") engine_adapter.create_schema("ignored") engine_adapter.create_table( - table_name="foo.bar", columns_to_types={"baz": exp.DataType.build("int")} + table_name="foo.bar", target_columns_to_types={"baz": exp.DataType.build("int")} ) engine_adapter.create_table( - table_name="foo.another", columns_to_types={"col": exp.DataType.build("int")} + table_name="foo.another", target_columns_to_types={"col": exp.DataType.build("int")} ) engine_adapter.create_view( view_name="foo.bar_view", query_or_df=parse_one("select * from foo.bar") ) engine_adapter.create_table( - table_name="ignored.ignore", columns_to_types={"col": exp.DataType.build("int")} + table_name="ignored.ignore", target_columns_to_types={"col": exp.DataType.build("int")} ) assert ( @@ -262,10 +262,10 @@ def test_adapter_map_snapshot_tables( engine_adapter.create_schema("sqlmesh") engine_adapter.create_table( table_name='"memory"."sqlmesh"."test_db__test_model"', - columns_to_types={"baz": exp.DataType.build("int")}, + target_columns_to_types={"baz": exp.DataType.build("int")}, ) engine_adapter.create_table( - table_name="foo.bar", columns_to_types={"col": exp.DataType.build("int")} + table_name="foo.bar", target_columns_to_types={"col": exp.DataType.build("int")} ) expected_test_model_table_name = parse_one('"memory"."sqlmesh"."test_db__test_model"').sql( @@ -324,7 +324,7 @@ def test_adapter_get_relation_normalization( engine_adapter.create_schema('"FOO"') engine_adapter.create_table( - table_name='"FOO"."BAR"', 
columns_to_types={"baz": exp.DataType.build("int")} + table_name='"FOO"."BAR"', target_columns_to_types={"baz": exp.DataType.build("int")} ) assert ( diff --git a/tests/dbt/test_integration.py b/tests/dbt/test_integration.py index 9cee4796fb..45c1422395 100644 --- a/tests/dbt/test_integration.py +++ b/tests/dbt/test_integration.py @@ -194,9 +194,11 @@ def _replace_source_table( columns_to_types = columns_to_types_from_df(df) if values: - adapter.replace_query("sushi.raw_marketing", df, columns_to_types=columns_to_types) + adapter.replace_query( + "sushi.raw_marketing", df, target_columns_to_types=columns_to_types + ) else: - adapter.create_table("sushi.raw_marketing", columns_to_types=columns_to_types) + adapter.create_table("sushi.raw_marketing", target_columns_to_types=columns_to_types) def _normalize_dbt_dataframe( self, From 6468afb1dc52a203a79a8f07a1fab53ba952f304 Mon Sep 17 00:00:00 2001 From: eakmanrq <6326532+eakmanrq@users.noreply.github.com> Date: Thu, 14 Aug 2025 16:01:57 -0700 Subject: [PATCH 3/4] feedback2 --- sqlmesh/core/snapshot/evaluator.py | 41 ++++++++++++++++++------------ 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 0cb45c1860..edbabec4b4 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -1523,13 +1523,9 @@ def append( render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: - columns_to_types = kwargs.pop("columns_to_types", model.columns_to_types) - source_columns = kwargs.pop("source_columns", None) self.adapter.insert_append( table_name, query_or_df, - target_columns_to_types=columns_to_types, - source_columns=source_columns, ) def create( @@ -1649,16 +1645,13 @@ def _get_target_and_source_columns( model: Model, table_name: str, render_kwargs: t.Dict[str, t.Any], - target_column_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, force_get_columns_from_target: bool = False, ) -> 
t.Tuple[t.Dict[str, exp.DataType], t.Optional[t.List[str]]]: if force_get_columns_from_target: target_column_to_types = self.adapter.columns(table_name) - elif target_column_to_types: - target_column_to_types = target_column_to_types else: target_column_to_types = ( - model.columns_to_types + model.columns_to_types # type: ignore if model.annotated and not model.on_destructive_change.is_ignore else self.adapter.columns(table_name) ) @@ -1784,6 +1777,24 @@ def append( class IncrementalUnmanagedStrategy(MaterializableStrategy): + def append( + self, + table_name: str, + query_or_df: QueryOrDF, + model: Model, + render_kwargs: t.Dict[str, t.Any], + **kwargs: t.Any, + ) -> None: + columns_to_types, source_columns = self._get_target_and_source_columns( + model, table_name, render_kwargs=render_kwargs + ) + self.adapter.insert_append( + table_name, + query_or_df, + target_columns_to_types=columns_to_types, + source_columns=source_columns, + ) + def insert( self, table_name: str, @@ -1797,13 +1808,13 @@ def insert( return self._replace_query_for_model( model, table_name, query_or_df, render_kwargs, **kwargs ) - columns_to_types, source_columns = self._get_target_and_source_columns( - model, - table_name, - render_kwargs=render_kwargs, - target_column_to_types=kwargs.pop("columns_to_types", None), - ) if isinstance(model.kind, IncrementalUnmanagedKind) and model.kind.insert_overwrite: + columns_to_types, source_columns = self._get_target_and_source_columns( + model, + table_name, + render_kwargs=render_kwargs, + ) + return self.adapter.insert_overwrite_by_partition( table_name, query_or_df, @@ -1816,8 +1827,6 @@ def insert( query_or_df, model, render_kwargs=render_kwargs, - columns_to_types=columns_to_types, - source_columns=source_columns, **kwargs, ) From 8b5c788554ab2a8b120bb86a574cf9342cd74f4c Mon Sep 17 00:00:00 2001 From: eakmanrq <6326532+eakmanrq@users.noreply.github.com> Date: Fri, 15 Aug 2025 08:34:10 -0700 Subject: [PATCH 4/4] fix incremental append --- 
.circleci/continue_config.yml | 8 ++-- sqlmesh/core/snapshot/evaluator.py | 70 ++++++++++++++++++++++-------- tests/core/test_integration.py | 2 - 3 files changed, 55 insertions(+), 25 deletions(-) diff --git a/.circleci/continue_config.yml b/.circleci/continue_config.yml index e651a1e80b..04135574a9 100644 --- a/.circleci/continue_config.yml +++ b/.circleci/continue_config.yml @@ -305,10 +305,10 @@ workflows: - clickhouse-cloud - athena - gcp-postgres -# filters: -# branches: -# only: -# - main + filters: + branches: + only: + - main - ui_style - ui_test - vscode_test diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index edbabec4b4..f2f7044ba7 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -1515,19 +1515,6 @@ def demote(self, view_name: str, **kwargs: t.Any) -> None: class MaterializableStrategy(PromotableStrategy, abc.ABC): - def append( - self, - table_name: str, - query_or_df: QueryOrDF, - model: Model, - render_kwargs: t.Dict[str, t.Any], - **kwargs: t.Any, - ) -> None: - self.adapter.insert_append( - table_name, - query_or_df, - ) - def create( self, table_name: str, @@ -1666,7 +1653,27 @@ def _get_target_and_source_columns( return target_column_to_types, source_columns -class IncrementalByPartitionStrategy(MaterializableStrategy): +class IncrementalStrategy(MaterializableStrategy, abc.ABC): + def append( + self, + table_name: str, + query_or_df: QueryOrDF, + model: Model, + render_kwargs: t.Dict[str, t.Any], + **kwargs: t.Any, + ) -> None: + columns_to_types, source_columns = self._get_target_and_source_columns( + model, table_name, render_kwargs=render_kwargs + ) + self.adapter.insert_append( + table_name, + query_or_df, + target_columns_to_types=columns_to_types, + source_columns=source_columns, + ) + + +class IncrementalByPartitionStrategy(IncrementalStrategy): def insert( self, table_name: str, @@ -1691,7 +1698,7 @@ def insert( ) -class 
IncrementalByTimeRangeStrategy(MaterializableStrategy): +class IncrementalByTimeRangeStrategy(IncrementalStrategy): def insert( self, table_name: str, @@ -1716,7 +1723,7 @@ def insert( ) -class IncrementalByUniqueKeyStrategy(MaterializableStrategy): +class IncrementalByUniqueKeyStrategy(IncrementalStrategy): def insert( self, table_name: str, @@ -1776,7 +1783,7 @@ def append( ) -class IncrementalUnmanagedStrategy(MaterializableStrategy): +class IncrementalUnmanagedStrategy(IncrementalStrategy): def append( self, table_name: str, @@ -1832,6 +1839,20 @@ def insert( class FullRefreshStrategy(MaterializableStrategy): + def append( + self, + table_name: str, + query_or_df: QueryOrDF, + model: Model, + render_kwargs: t.Dict[str, t.Any], + **kwargs: t.Any, + ) -> None: + self.adapter.insert_append( + table_name, + query_or_df, + target_columns_to_types=model.columns_to_types, + ) + def insert( self, table_name: str, @@ -1893,8 +1914,19 @@ def insert( # Data has already been inserted at the time of table creation. pass + def append( + self, + table_name: str, + query_or_df: QueryOrDF, + model: Model, + render_kwargs: t.Dict[str, t.Any], + **kwargs: t.Any, + ) -> None: + # Data has already been inserted at the time of table creation. 
+ pass + -class SCDType2Strategy(MaterializableStrategy): +class SCDType2Strategy(IncrementalStrategy): def create( self, table_name: str, @@ -2181,7 +2213,7 @@ def _is_materialized_view(self, model: Model) -> bool: C = t.TypeVar("C", bound=CustomKind) -class CustomMaterialization(MaterializableStrategy, t.Generic[C]): +class CustomMaterialization(IncrementalStrategy, t.Generic[C]): """Base class for custom materializations.""" def insert( diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index 5cbe22ab46..72d8964a71 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -7948,7 +7948,6 @@ def test_incremental_unmanaged_model_ignore_destructive_change(tmp_path: Path): data_dir = tmp_path / "data" data_dir.mkdir() data_filepath = data_dir / "test.duckdb" - set_console(TerminalConsole()) config = Config( model_defaults=ModelDefaultsConfig(dialect="duckdb"), @@ -8058,7 +8057,6 @@ def test_scd_type_2_by_time_ignore_destructive_change(tmp_path: Path): data_dir = tmp_path / "data" data_dir.mkdir() data_filepath = data_dir / "test.duckdb" - set_console(TerminalConsole()) config = Config( model_defaults=ModelDefaultsConfig(dialect="duckdb"),