Skip to content

Commit 282cfe3

Browse files
committed
fix model kind change edge case
1 parent 2670119 commit 282cfe3

File tree

5 files changed

+179
-46
lines changed

5 files changed

+179
-46
lines changed

sqlmesh/core/engine_adapter/base.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -371,9 +371,10 @@ def replace_query(
371371
"""
372372
target_table = exp.to_table(table_name)
373373

374-
target_data_object = self._get_data_object(target_table)
374+
target_data_object = self.get_data_object(target_table)
375375
table_exists = target_data_object is not None
376-
self._drop_data_object_on_type_mismatch(target_data_object, DataObjectType.TABLE)
376+
if self.drop_data_object_on_type_mismatch(target_data_object, DataObjectType.TABLE):
377+
table_exists = False
377378

378379
source_queries, columns_to_types = self._get_source_queries_and_columns_to_types(
379380
query_or_df, columns_to_types, target_table=target_table
@@ -1147,8 +1148,8 @@ def create_view(
11471148
create_kwargs["properties"] = properties
11481149

11491150
if replace:
1150-
self._drop_data_object_on_type_mismatch(
1151-
self._get_data_object(view_name),
1151+
self.drop_data_object_on_type_mismatch(
1152+
self.get_data_object(view_name),
11521153
DataObjectType.VIEW if not materialized else DataObjectType.MATERIALIZED_VIEW,
11531154
)
11541155

@@ -2056,6 +2057,15 @@ def rename_table(
20562057
)
20572058
self._rename_table(old_table_name, new_table_name)
20582059

2060+
def get_data_object(self, target_name: TableName) -> t.Optional[DataObject]:
2061+
target_table = exp.to_table(target_name)
2062+
existing_data_objects = self.get_data_objects(
2063+
schema_(target_table.db, target_table.catalog), {target_table.name}
2064+
)
2065+
if existing_data_objects:
2066+
return existing_data_objects[0]
2067+
return None
2068+
20592069
def get_data_objects(
20602070
self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
20612071
) -> t.List[DataObject]:
@@ -2517,26 +2527,20 @@ def _truncate_table(self, table_name: TableName) -> None:
25172527
table = exp.to_table(table_name)
25182528
self.execute(f"TRUNCATE TABLE {table.sql(dialect=self.dialect, identify=True)}")
25192529

2520-
def _get_data_object(self, target_name: TableName) -> t.Optional[DataObject]:
2521-
target_table = exp.to_table(target_name)
2522-
existing_data_objects = self.get_data_objects(
2523-
schema_(target_table.db, target_table.catalog), {target_table.name}
2524-
)
2525-
if existing_data_objects:
2526-
return existing_data_objects[0]
2527-
return None
2528-
2529-
def _drop_data_object_on_type_mismatch(
2530+
def drop_data_object_on_type_mismatch(
25302531
self, data_object: t.Optional[DataObject], expected_type: DataObjectType
2531-
) -> None:
2532+
) -> bool:
25322533
"""Drops a data object if it exists and is not of the expected type.
25332534
25342535
Args:
25352536
data_object: The data object to check.
25362537
expected_type: The expected type of the data object.
2538+
2539+
Returns:
2540+
True if the data object was dropped, False otherwise.
25372541
"""
25382542
if data_object is None or data_object.type == expected_type:
2539-
return
2543+
return False
25402544

25412545
logger.warning(
25422546
"Target data object '%s' is a %s and not a %s, dropping it",
@@ -2545,6 +2549,7 @@ def _drop_data_object_on_type_mismatch(
25452549
expected_type.value,
25462550
)
25472551
self.drop_data_object(data_object)
2552+
return True
25482553

25492554
def _replace_by_key(
25502555
self,

sqlmesh/core/engine_adapter/redshift.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,10 @@ def replace_query(
262262
"""
263263
import pandas as pd
264264

265-
target_data_object = self._get_data_object(table_name)
265+
target_data_object = self.get_data_object(table_name)
266266
table_exists = target_data_object is not None
267-
self._drop_data_object_on_type_mismatch(target_data_object, DataObjectType.TABLE)
267+
if self.drop_data_object_on_type_mismatch(target_data_object, DataObjectType.TABLE):
268+
table_exists = False
268269

269270
if not isinstance(query_or_df, pd.DataFrame) or not table_exists:
270271
return super().replace_query(

sqlmesh/core/snapshot/evaluator.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from sqlmesh.core.audit import Audit, StandaloneAudit
3939
from sqlmesh.core.dialect import schema_
4040
from sqlmesh.core.engine_adapter import EngineAdapter
41-
from sqlmesh.core.engine_adapter.shared import InsertOverwriteStrategy
41+
from sqlmesh.core.engine_adapter.shared import InsertOverwriteStrategy, DataObjectType
4242
from sqlmesh.core.macros import RuntimeStage
4343
from sqlmesh.core.model import (
4444
AuditResult,
@@ -934,7 +934,14 @@ def _migrate_snapshot(
934934
adapter.transaction(),
935935
adapter.session(snapshot.model.render_session_properties(**render_kwargs)),
936936
):
937-
if adapter.table_exists(target_table_name):
937+
target_data_object = adapter.get_data_object(target_table_name)
938+
table_exists = target_data_object is not None
939+
if adapter.drop_data_object_on_type_mismatch(
940+
target_data_object, _snapshot_to_data_object_type(snapshot)
941+
):
942+
table_exists = False
943+
944+
if table_exists:
938945
evaluation_strategy = _evaluation_strategy(snapshot, adapter)
939946
tmp_table_name = snapshot.table_name(is_deployable=False)
940947
logger.info(
@@ -2307,3 +2314,15 @@ def _check_table_db_is_physical_schema(table_name: str, physical_schema: str) ->
23072314
raise SQLMeshError(
23082315
f"Table '{table_name}' is not a part of the physical schema '{physical_schema}' and so can't be dropped."
23092316
)
2317+
2318+
2319+
def _snapshot_to_data_object_type(snapshot: Snapshot) -> DataObjectType:
2320+
if snapshot.is_managed:
2321+
return DataObjectType.MANAGED_TABLE
2322+
if snapshot.is_materialized_view:
2323+
return DataObjectType.MATERIALIZED_VIEW
2324+
if snapshot.is_view:
2325+
return DataObjectType.VIEW
2326+
if snapshot.is_materialized:
2327+
return DataObjectType.TABLE
2328+
return DataObjectType.UNKNOWN

tests/core/test_integration.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2603,6 +2603,48 @@ def test_virtual_environment_mode_dev_only_model_kind_change(init_and_plan_conte
26032603
assert data_objects[0].type == "table"
26042604

26052605

2606+
@time_machine.travel("2023-01-08 15:00:00 UTC")
2607+
def test_virtual_environment_mode_dev_only_model_kind_change_with_follow_up_changes_in_dev(
2608+
init_and_plan_context: t.Callable,
2609+
):
2610+
context, plan = init_and_plan_context(
2611+
"examples/sushi", config="test_config_virtual_environment_mode_dev_only"
2612+
)
2613+
context.apply(plan)
2614+
2615+
# Make sure the initial state is a view
2616+
data_objects = context.engine_adapter.get_data_objects("sushi", {"top_waiters"})
2617+
assert len(data_objects) == 1
2618+
assert data_objects[0].type == "view"
2619+
2620+
# Change to incremental unmanaged kind
2621+
model = context.get_model("sushi.top_waiters")
2622+
model = model.copy(update={"kind": IncrementalUnmanagedKind()})
2623+
context.upsert_model(model)
2624+
dev_plan = context.plan_builder("dev", skip_tests=True).build()
2625+
assert dev_plan.missing_intervals
2626+
assert dev_plan.requires_backfill
2627+
context.apply(dev_plan)
2628+
2629+
# Make a follow-up forward-only change
2630+
model = add_projection_to_model(t.cast(SqlModel, model))
2631+
context.upsert_model(model)
2632+
dev_plan = context.plan_builder("dev", skip_tests=True, forward_only=True).build()
2633+
context.apply(dev_plan)
2634+
2635+
# Deploy to prod
2636+
prod_plan = context.plan_builder("prod", skip_tests=True).build()
2637+
assert prod_plan.requires_backfill
2638+
assert prod_plan.missing_intervals
2639+
assert not prod_plan.context_diff.snapshots[
2640+
context.get_snapshot(model.name).snapshot_id
2641+
].intervals
2642+
context.apply(prod_plan)
2643+
data_objects = context.engine_adapter.get_data_objects("sushi", {"top_waiters"})
2644+
assert len(data_objects) == 1
2645+
assert data_objects[0].type == "table"
2646+
2647+
26062648
@time_machine.travel("2023-01-08 15:00:00 UTC")
26072649
def test_virtual_environment_mode_dev_only_model_kind_change_manual_categorization(
26082650
init_and_plan_context: t.Callable,

0 commit comments

Comments
 (0)