diff --git a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py
index 553ffd58a5..dbda66614e 100644
--- a/sqlmesh/core/config/connection.py
+++ b/sqlmesh/core/config/connection.py
@@ -1755,6 +1755,7 @@ class SparkConnectionConfig(ConnectionConfig):
     config_dir: t.Optional[str] = None
     catalog: t.Optional[str] = None
     config: t.Dict[str, t.Any] = {}
+    wap_enabled: bool = False
 
     concurrent_tasks: int = 4
     register_comments: bool = True
@@ -1801,6 +1802,10 @@ def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
             .getOrCreate(),
         }
 
+    @property
+    def _extra_engine_config(self) -> t.Dict[str, t.Any]:
+        return {"wap_enabled": self.wap_enabled}
+
 
 class TrinoAuthenticationMethod(str, Enum):
     NO_AUTH = "no-auth"
diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py
index d8747c979d..47e6a4260c 100644
--- a/sqlmesh/core/engine_adapter/base.py
+++ b/sqlmesh/core/engine_adapter/base.py
@@ -2357,6 +2357,11 @@ def fetch_pyspark_df(
         """Fetches a PySpark DataFrame from the cursor"""
         raise NotImplementedError(f"Engine does not support PySpark DataFrames: {type(self)}")
 
+    @property
+    def wap_enabled(self) -> bool:
+        """Returns whether WAP is enabled for this engine."""
+        return self._extra_config.get("wap_enabled", False)
+
     def wap_supported(self, table_name: TableName) -> bool:
         """Returns whether WAP for the target table is supported."""
         return False
diff --git a/sqlmesh/core/engine_adapter/spark.py b/sqlmesh/core/engine_adapter/spark.py
index 7d6a4d969b..18ba6ea106 100644
--- a/sqlmesh/core/engine_adapter/spark.py
+++ b/sqlmesh/core/engine_adapter/spark.py
@@ -457,12 +457,14 @@ def _create_table(
         if wap_id.startswith(f"{self.BRANCH_PREFIX}{self.WAP_PREFIX}"):
             table_name.set("this", table_name.this.this)
 
-        wap_supported = (
-            kwargs.get("storage_format") or ""
-        ).lower() == "iceberg" or self.wap_supported(table_name)
-        do_dummy_insert = (
-            False if not wap_supported or not exists else not self.table_exists(table_name)
-        )
+        do_dummy_insert = False
+        if self.wap_enabled:
+            wap_supported = (
+                kwargs.get("storage_format") or ""
+            ).lower() == "iceberg" or self.wap_supported(table_name)
+            do_dummy_insert = (
+                False if not wap_supported or not exists else not self.table_exists(table_name)
+            )
         super()._create_table(
             table_name_or_schema,
             expression,
diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py
index 658bb1c400..0e4a440a29 100644
--- a/sqlmesh/core/snapshot/evaluator.py
+++ b/sqlmesh/core/snapshot/evaluator.py
@@ -793,6 +793,7 @@ def _evaluate_snapshot(
         if (
             snapshot.is_materialized
             and target_table_exists
+            and adapter.wap_enabled
             and (model.wap_supported or adapter.wap_supported(target_table_name))
         ):
             wap_id = random_id()[0:8]
diff --git a/tests/core/engine_adapter/test_spark.py b/tests/core/engine_adapter/test_spark.py
index f1929639a2..bc4e352bd7 100644
--- a/tests/core/engine_adapter/test_spark.py
+++ b/tests/core/engine_adapter/test_spark.py
@@ -66,14 +66,15 @@ def test_create_table_properties(make_mocked_engine_adapter: t.Callable):
     )
 
 
+@pytest.mark.parametrize("wap_enabled", [True, False])
 def test_replace_query_table_properties_not_exists(
-    mocker: MockerFixture, make_mocked_engine_adapter: t.Callable
+    mocker: MockerFixture, make_mocked_engine_adapter: t.Callable, wap_enabled: bool
 ):
     mocker.patch(
         "sqlmesh.core.engine_adapter.spark.SparkEngineAdapter.table_exists",
         return_value=False,
     )
-    adapter = make_mocked_engine_adapter(SparkEngineAdapter)
+    adapter = make_mocked_engine_adapter(SparkEngineAdapter, wap_enabled=wap_enabled)
 
     columns_to_types = {
         "cola": exp.DataType.build("INT"),
@@ -89,10 +90,13 @@ def test_replace_query_table_properties_not_exists(
         table_properties={"a": exp.convert(1)},
     )
 
-    assert to_sql_calls(adapter) == [
+    expected_sql_calls = [
         "CREATE TABLE IF NOT EXISTS `test_table` USING ICEBERG PARTITIONED BY (`colb`) TBLPROPERTIES ('a'=1) AS SELECT CAST(`cola` AS INT) AS `cola`, CAST(`colb` AS STRING) AS `colb`, CAST(`colc` AS STRING) AS `colc` FROM (SELECT 1 AS `cola`, '2' AS `colb`, '3' AS `colc`) AS `_subquery`",
-        "INSERT INTO `test_table` SELECT * FROM `test_table`",
     ]
+    if wap_enabled:
+        expected_sql_calls.append("INSERT INTO `test_table` SELECT * FROM `test_table`")
+
+    assert to_sql_calls(adapter) == expected_sql_calls
 
 
 def test_replace_query_table_properties_exists(
@@ -825,13 +829,16 @@ def test_wap_publish(make_mocked_engine_adapter: t.Callable, mocker: MockerFixtu
     )
 
 
-def test_create_table_iceberg(mocker: MockerFixture, make_mocked_engine_adapter: t.Callable):
+@pytest.mark.parametrize("wap_enabled", [True, False])
+def test_create_table_iceberg(
+    mocker: MockerFixture, make_mocked_engine_adapter: t.Callable, wap_enabled: bool
+):
     mocker.patch(
         "sqlmesh.core.engine_adapter.spark.SparkEngineAdapter.table_exists",
         return_value=False,
     )
-    adapter = make_mocked_engine_adapter(SparkEngineAdapter)
+    adapter = make_mocked_engine_adapter(SparkEngineAdapter, wap_enabled=wap_enabled)
 
     columns_to_types = {
         "cola": exp.DataType.build("INT"),
         "colb": exp.DataType.build("STRING"),
@@ -846,10 +853,13 @@ def test_create_table_iceberg(mocker: MockerFixture, make_mocked_engine_adapter:
         storage_format="ICEBERG",
     )
 
-    assert to_sql_calls(adapter) == [
+    expected_sql_calls = [
         "CREATE TABLE IF NOT EXISTS `test_table` (`cola` INT, `colb` STRING, `colc` STRING) USING ICEBERG PARTITIONED BY (`colb`)",
-        "INSERT INTO `test_table` SELECT * FROM `test_table`",
     ]
+    if wap_enabled:
+        expected_sql_calls.append("INSERT INTO `test_table` SELECT * FROM `test_table`")
+
+    assert to_sql_calls(adapter) == expected_sql_calls
 
 
 def test_comments_hive(mocker: MockerFixture, make_mocked_engine_adapter: t.Callable):
@@ -973,7 +983,7 @@ def test_create_table_with_wap(make_mocked_engine_adapter: t.Callable, mocker: M
         "sqlmesh.core.engine_adapter.spark.SparkEngineAdapter.table_exists",
         return_value=False,
     )
-    adapter = make_mocked_engine_adapter(SparkEngineAdapter)
+    adapter = make_mocked_engine_adapter(SparkEngineAdapter, wap_enabled=True)
 
     adapter.create_table(
         "catalog.schema.table.branch_wap_12345",