Skip to content

Commit d184b2a

Browse files
committed
Fix: Regression in WAP support
1 parent d53a58e commit d184b2a

File tree

4 files changed

+298
-20
lines changed

4 files changed

+298
-20
lines changed

sqlmesh/core/engine_adapter/spark.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,15 @@ def get_current_database(self) -> str:
402402
return self.spark.catalog.currentDatabase()
403403
return self.fetchone(exp.select(exp.func("current_database")))[0] # type: ignore
404404

405+
def get_data_object(self, target_name: TableName) -> t.Optional[DataObject]:
406+
target_table = exp.to_table(target_name)
407+
if isinstance(target_table.this, exp.Dot) and target_table.this.expression.name.startswith(
408+
f"{self.BRANCH_PREFIX}{self.WAP_PREFIX}"
409+
):
410+
# Exclude the branch name
411+
target_table.set("this", target_table.this.this)
412+
return super().get_data_object(target_table)
413+
405414
def create_state_table(
406415
self,
407416
table_name: str,

sqlmesh/core/snapshot/evaluator.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,7 @@ def audit(
651651
snapshot.snapshot_id,
652652
wap_id,
653653
)
654-
self._wap_publish_snapshot(snapshot, wap_id, deployability_index)
654+
self.wap_publish_snapshot(snapshot, wap_id, deployability_index)
655655

656656
return results
657657

@@ -806,8 +806,10 @@ def _evaluate_snapshot(
806806
}
807807

808808
wap_id: t.Optional[str] = None
809-
if snapshot.is_materialized and (
810-
model.wap_supported or adapter.wap_supported(target_table_name)
809+
if (
810+
snapshot.is_materialized
811+
and target_table_exists
812+
and (model.wap_supported or adapter.wap_supported(target_table_name))
811813
):
812814
wap_id = random_id()[0:8]
813815
logger.info("Using WAP ID '%s' for snapshot %s", wap_id, snapshot.snapshot_id)
@@ -823,6 +825,7 @@ def _evaluate_snapshot(
823825
create_render_kwargs=create_render_kwargs,
824826
rendered_physical_properties=rendered_physical_properties,
825827
deployability_index=deployability_index,
828+
target_table_name=target_table_name,
826829
is_first_insert=is_first_insert,
827830
batch_index=batch_index,
828831
)
@@ -896,6 +899,17 @@ def create_snapshot(
896899
if on_complete is not None:
897900
on_complete(snapshot)
898901

902+
def wap_publish_snapshot(
903+
self,
904+
snapshot: Snapshot,
905+
wap_id: str,
906+
deployability_index: t.Optional[DeployabilityIndex],
907+
) -> None:
908+
deployability_index = deployability_index or DeployabilityIndex.all_deployable()
909+
table_name = snapshot.table_name(is_deployable=deployability_index.is_deployable(snapshot))
910+
adapter = self.get_adapter(snapshot.model_gateway)
911+
adapter.wap_publish(table_name, wap_id)
912+
899913
def _render_and_insert_snapshot(
900914
self,
901915
start: TimeLike,
@@ -907,6 +921,7 @@ def _render_and_insert_snapshot(
907921
create_render_kwargs: t.Dict[str, t.Any],
908922
rendered_physical_properties: t.Dict[str, exp.Expression],
909923
deployability_index: DeployabilityIndex,
924+
target_table_name: str,
910925
is_first_insert: bool,
911926
batch_index: int,
912927
) -> None:
@@ -916,7 +931,6 @@ def _render_and_insert_snapshot(
916931
logger.info("Inserting data for snapshot %s", snapshot.snapshot_id)
917932

918933
model = snapshot.model
919-
table_name = snapshot.table_name(is_deployable=deployability_index.is_deployable(snapshot))
920934
adapter = self.get_adapter(model.gateway)
921935
evaluation_strategy = _evaluation_strategy(snapshot, adapter)
922936

@@ -930,7 +944,7 @@ def _render_and_insert_snapshot(
930944
def apply(query_or_df: QueryOrDF, index: int = 0) -> None:
931945
if index > 0:
932946
evaluation_strategy.append(
933-
table_name=table_name,
947+
table_name=target_table_name,
934948
query_or_df=query_or_df,
935949
model=snapshot.model,
936950
snapshot=snapshot,
@@ -948,10 +962,10 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None:
948962
"Inserting batch (%s, %s) into %s'",
949963
time_like_to_str(start),
950964
time_like_to_str(end),
951-
table_name,
965+
target_table_name,
952966
)
953967
evaluation_strategy.insert(
954-
table_name=table_name,
968+
table_name=target_table_name,
955969
query_or_df=query_or_df,
956970
is_first_insert=is_first_insert,
957971
model=snapshot.model,
@@ -1278,17 +1292,6 @@ def _cleanup_snapshot(
12781292
if on_complete is not None:
12791293
on_complete(table_name)
12801294

1281-
def _wap_publish_snapshot(
1282-
self,
1283-
snapshot: Snapshot,
1284-
wap_id: str,
1285-
deployability_index: t.Optional[DeployabilityIndex],
1286-
) -> None:
1287-
deployability_index = deployability_index or DeployabilityIndex.all_deployable()
1288-
table_name = snapshot.table_name(is_deployable=deployability_index.is_deployable(snapshot))
1289-
adapter = self.get_adapter(snapshot.model_gateway)
1290-
adapter.wap_publish(table_name, wap_id)
1291-
12921295
def _audit(
12931296
self,
12941297
audit: Audit,

tests/core/engine_adapter/test_spark.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,3 +1080,19 @@ def test_table_format(adapter: SparkEngineAdapter, mocker: MockerFixture):
10801080
"CREATE TABLE IF NOT EXISTS `test_table` (`cola` TIMESTAMP, `colb` STRING, `colc` STRING) USING ICEBERG",
10811081
"CREATE TABLE IF NOT EXISTS `test_table` USING ICEBERG TBLPROPERTIES ('write.format.default'='orc') AS SELECT CAST(`cola` AS TIMESTAMP) AS `cola`, CAST(`colb` AS STRING) AS `colb`, CAST(`colc` AS STRING) AS `colc` FROM (SELECT CAST(1 AS TIMESTAMP) AS `cola`, CAST(2 AS STRING) AS `colb`, 'foo' AS `colc`) AS `_subquery`",
10821082
]
1083+
1084+
1085+
def test_get_data_object_wap_branch(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture):
1086+
adapter = make_mocked_engine_adapter(SparkEngineAdapter, patch_get_data_objects=False)
1087+
mocker.patch.object(adapter, "_get_data_objects", return_value=[])
1088+
1089+
table = exp.to_table(
1090+
"`catalog`.`sqlmesh__test`.`test__test_view__630979748`.`branch_wap_472234d7`",
1091+
dialect="spark",
1092+
)
1093+
adapter.get_data_object(table)
1094+
1095+
adapter._get_data_objects.assert_called_once_with(
1096+
d.schema_("sqlmesh__test", "catalog"),
1097+
{"test__test_view__630979748"},
1098+
)

0 commit comments

Comments
 (0)