Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions sqlmesh/core/engine_adapter/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,15 @@ def get_current_database(self) -> str:
return self.spark.catalog.currentDatabase()
return self.fetchone(exp.select(exp.func("current_database")))[0] # type: ignore

def get_data_object(self, target_name: TableName) -> t.Optional[DataObject]:
target_table = exp.to_table(target_name)
if isinstance(target_table.this, exp.Dot) and target_table.this.expression.name.startswith(
f"{self.BRANCH_PREFIX}{self.WAP_PREFIX}"
):
# Exclude the branch name
target_table.set("this", target_table.this.this)
return super().get_data_object(target_table)

def create_state_table(
self,
table_name: str,
Expand Down
39 changes: 21 additions & 18 deletions sqlmesh/core/snapshot/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@ def audit(
snapshot.snapshot_id,
wap_id,
)
self._wap_publish_snapshot(snapshot, wap_id, deployability_index)
self.wap_publish_snapshot(snapshot, wap_id, deployability_index)

return results

Expand Down Expand Up @@ -806,8 +806,10 @@ def _evaluate_snapshot(
}

wap_id: t.Optional[str] = None
if snapshot.is_materialized and (
model.wap_supported or adapter.wap_supported(target_table_name)
if (
snapshot.is_materialized
and target_table_exists
and (model.wap_supported or adapter.wap_supported(target_table_name))
):
wap_id = random_id()[0:8]
logger.info("Using WAP ID '%s' for snapshot %s", wap_id, snapshot.snapshot_id)
Expand All @@ -823,6 +825,7 @@ def _evaluate_snapshot(
create_render_kwargs=create_render_kwargs,
rendered_physical_properties=rendered_physical_properties,
deployability_index=deployability_index,
target_table_name=target_table_name,
is_first_insert=is_first_insert,
batch_index=batch_index,
)
Expand Down Expand Up @@ -896,6 +899,17 @@ def create_snapshot(
if on_complete is not None:
on_complete(snapshot)

def wap_publish_snapshot(
    self,
    snapshot: Snapshot,
    wap_id: str,
    deployability_index: t.Optional[DeployabilityIndex],
) -> None:
    """Publish the WAP (write-audit-publish) branch for a snapshot's table.

    Args:
        snapshot: The snapshot whose physical table should be published.
        wap_id: The WAP identifier to publish.
        deployability_index: Controls which table name (deployable or not)
            is targeted; defaults to treating everything as deployable.
    """
    if deployability_index is None:
        deployability_index = DeployabilityIndex.all_deployable()
    is_deployable = deployability_index.is_deployable(snapshot)
    target_table = snapshot.table_name(is_deployable=is_deployable)
    self.get_adapter(snapshot.model_gateway).wap_publish(target_table, wap_id)

def _render_and_insert_snapshot(
self,
start: TimeLike,
Expand All @@ -907,6 +921,7 @@ def _render_and_insert_snapshot(
create_render_kwargs: t.Dict[str, t.Any],
rendered_physical_properties: t.Dict[str, exp.Expression],
deployability_index: DeployabilityIndex,
target_table_name: str,
is_first_insert: bool,
batch_index: int,
) -> None:
Expand All @@ -916,7 +931,6 @@ def _render_and_insert_snapshot(
logger.info("Inserting data for snapshot %s", snapshot.snapshot_id)

model = snapshot.model
table_name = snapshot.table_name(is_deployable=deployability_index.is_deployable(snapshot))
adapter = self.get_adapter(model.gateway)
evaluation_strategy = _evaluation_strategy(snapshot, adapter)

Expand All @@ -930,7 +944,7 @@ def _render_and_insert_snapshot(
def apply(query_or_df: QueryOrDF, index: int = 0) -> None:
if index > 0:
evaluation_strategy.append(
table_name=table_name,
table_name=target_table_name,
query_or_df=query_or_df,
model=snapshot.model,
snapshot=snapshot,
Expand All @@ -948,10 +962,10 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None:
"Inserting batch (%s, %s) into %s'",
time_like_to_str(start),
time_like_to_str(end),
table_name,
target_table_name,
)
evaluation_strategy.insert(
table_name=table_name,
table_name=target_table_name,
query_or_df=query_or_df,
is_first_insert=is_first_insert,
model=snapshot.model,
Expand Down Expand Up @@ -1278,17 +1292,6 @@ def _cleanup_snapshot(
if on_complete is not None:
on_complete(table_name)

def _wap_publish_snapshot(
    self,
    snapshot: Snapshot,
    wap_id: str,
    deployability_index: t.Optional[DeployabilityIndex],
) -> None:
    """Publish the WAP (write-audit-publish) changes for a snapshot's table.

    Args:
        snapshot: The snapshot whose physical table should be published.
        wap_id: The WAP identifier to publish.
        deployability_index: Determines whether the deployable or the
            non-deployable table name is targeted; ``None`` defaults to
            treating all snapshots as deployable.
    """
    # Falsy index (None) falls back to "everything deployable".
    deployability_index = deployability_index or DeployabilityIndex.all_deployable()
    table_name = snapshot.table_name(is_deployable=deployability_index.is_deployable(snapshot))
    # NOTE(review): uses the snapshot's model gateway to pick the adapter —
    # presumably so multi-gateway projects publish on the right engine; verify.
    adapter = self.get_adapter(snapshot.model_gateway)
    adapter.wap_publish(table_name, wap_id)

def _audit(
self,
audit: Audit,
Expand Down
16 changes: 16 additions & 0 deletions tests/core/engine_adapter/test_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,3 +1080,19 @@ def test_table_format(adapter: SparkEngineAdapter, mocker: MockerFixture):
"CREATE TABLE IF NOT EXISTS `test_table` (`cola` TIMESTAMP, `colb` STRING, `colc` STRING) USING ICEBERG",
"CREATE TABLE IF NOT EXISTS `test_table` USING ICEBERG TBLPROPERTIES ('write.format.default'='orc') AS SELECT CAST(`cola` AS TIMESTAMP) AS `cola`, CAST(`colb` AS STRING) AS `colb`, CAST(`colc` AS STRING) AS `colc` FROM (SELECT CAST(1 AS TIMESTAMP) AS `cola`, CAST(2 AS STRING) AS `colb`, 'foo' AS `colc`) AS `_subquery`",
]


def test_get_data_object_wap_branch(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture):
    """The WAP branch suffix must be stripped before the metadata lookup."""
    adapter = make_mocked_engine_adapter(SparkEngineAdapter, patch_get_data_objects=False)
    mocker.patch.object(adapter, "_get_data_objects", return_value=[])

    branched_table = exp.to_table(
        "`catalog`.`sqlmesh__test`.`test__test_view__630979748`.`branch_wap_472234d7`",
        dialect="spark",
    )
    adapter.get_data_object(branched_table)

    # The branch component should not appear in the schema/table lookup.
    adapter._get_data_objects.assert_called_once_with(
        d.schema_("sqlmesh__test", "catalog"),
        {"test__test_view__630979748"},
    )
Loading