Skip to content

Commit 65b4c2e

Browse files
committed
Fix: Include unexpired downstream views when cleaning up expired tables
1 parent eb4c0b4 commit 65b4c2e

File tree

3 files changed

+317
-18
lines changed

3 files changed

+317
-18
lines changed

sqlmesh/core/state_sync/db/snapshot.py

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
SnapshotId,
3131
SnapshotFingerprint,
3232
SnapshotChangeCategory,
33+
snapshots_to_dag,
3334
)
3435
from sqlmesh.utils.migration import index_text_type, blob_text_type
3536
from sqlmesh.utils.date import now_timestamp, TimeLike, now, to_timestamp
@@ -224,18 +225,52 @@ def _get_expired_snapshots(
224225

225226
if not ignore_ttl:
226227
expired_query = expired_query.where(
227-
(exp.column("updated_ts") + exp.column("ttl_ms")) <= current_ts
228+
((exp.column("updated_ts") + exp.column("ttl_ms")) <= current_ts)
229+
# we need to include views even if they haven't expired in case one depends on a table that /has/ expired
230+
.or_(exp.column("kind_name").eq(ModelKindName.VIEW))
228231
)
229232

230-
expired_candidates = {
233+
candidates = {
231234
SnapshotId(name=name, identifier=identifier): SnapshotNameVersion(
232235
name=name, version=version
233236
)
234237
for name, identifier, version in fetchall(self.engine_adapter, expired_query)
235238
}
236-
if not expired_candidates:
239+
240+
if not candidates:
237241
return set(), []
238242

243+
expired_candidates: t.Dict[SnapshotId, SnapshotNameVersion] = {}
244+
245+
if ignore_ttl:
246+
expired_candidates = candidates
247+
else:
248+
# Fetch full snapshots because we need to build a dependency tree
249+
full_candidates = self.get_snapshots(candidates.keys())
250+
251+
# Build DAG so we can check if any views that are not expired depend on a table snapshot that is expired
252+
dag = snapshots_to_dag(full_candidates.values())
253+
254+
# remove any non-expired views that don't depend on expired tables
255+
for snapshot_id in dag.reversed:
256+
snapshot = full_candidates.get(snapshot_id, None)
257+
if not snapshot:
258+
continue
259+
260+
if snapshot.expiration_ts <= current_ts:
261+
# All expired snapshots should be included
262+
expired_candidates[snapshot.snapshot_id] = snapshot.name_version
263+
elif snapshot.model_kind_name == ModelKindName.VIEW:
264+
# With non-expired views, check if they have any expired upstream tables
265+
if any(
266+
full_candidates[parent_id].expiration_ts <= current_ts
267+
for parent_id in dag.upstream(snapshot_id)
268+
if parent_id in full_candidates
269+
and full_candidates[parent_id].model_kind_name != ModelKindName.VIEW
270+
):
271+
# an upstream table has expired, therefore this view is no longer valid and needs to be cleaned up as well
272+
expired_candidates[snapshot.snapshot_id] = snapshot.name_version
273+
239274
promoted_snapshot_ids = {
240275
snapshot.snapshot_id
241276
for environment in environments
@@ -253,37 +288,39 @@ def _is_snapshot_used(snapshot: SharedVersionSnapshot) -> bool:
253288
unique_expired_versions, batch_size=self.SNAPSHOT_BATCH_SIZE
254289
)
255290
cleanup_targets = []
256-
expired_snapshot_ids = set()
291+
expired_sv_snapshot_ids = set()
257292
for versions_batch in version_batches:
258-
snapshots = self._get_snapshots_with_same_version(versions_batch)
293+
sv_snapshots = self._get_snapshots_with_same_version(versions_batch)
259294

260295
snapshots_by_version = defaultdict(set)
261296
snapshots_by_dev_version = defaultdict(set)
262-
for s in snapshots:
297+
for s in sv_snapshots:
263298
snapshots_by_version[(s.name, s.version)].add(s.snapshot_id)
264299
snapshots_by_dev_version[(s.name, s.dev_version)].add(s.snapshot_id)
265300

266-
expired_snapshots = [s for s in snapshots if not _is_snapshot_used(s)]
267-
expired_snapshot_ids.update([s.snapshot_id for s in expired_snapshots])
301+
expired_sv_snapshots = [s for s in sv_snapshots if not _is_snapshot_used(s)]
302+
expired_sv_snapshot_ids.update([s.snapshot_id for s in expired_sv_snapshots])
268303

269-
for snapshot in expired_snapshots:
270-
shared_version_snapshots = snapshots_by_version[(snapshot.name, snapshot.version)]
271-
shared_version_snapshots.discard(snapshot.snapshot_id)
304+
for sv_snapshot in expired_sv_snapshots:
305+
shared_version_snapshots = snapshots_by_version[
306+
(sv_snapshot.name, sv_snapshot.version)
307+
]
308+
shared_version_snapshots.discard(sv_snapshot.snapshot_id)
272309

273310
shared_dev_version_snapshots = snapshots_by_dev_version[
274-
(snapshot.name, snapshot.dev_version)
311+
(sv_snapshot.name, sv_snapshot.dev_version)
275312
]
276-
shared_dev_version_snapshots.discard(snapshot.snapshot_id)
313+
shared_dev_version_snapshots.discard(sv_snapshot.snapshot_id)
277314

278315
if not shared_dev_version_snapshots:
279316
cleanup_targets.append(
280317
SnapshotTableCleanupTask(
281-
snapshot=snapshot.full_snapshot.table_info,
318+
snapshot=sv_snapshot.full_snapshot.table_info,
282319
dev_table_only=bool(shared_version_snapshots),
283320
)
284321
)
285322

286-
return expired_snapshot_ids, cleanup_targets
323+
return expired_sv_snapshot_ids, cleanup_targets
287324

288325
def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None:
289326
"""Deletes snapshots.

tests/core/engine_adapter/integration/test_integration.py

Lines changed: 181 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,26 @@
1313
import pandas as pd # noqa: TID253
1414
import pytest
1515
import pytz
16+
import time_machine
1617
from sqlglot import exp, parse_one
1718
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
1819

1920
from sqlmesh import Config, Context
2021
from sqlmesh.cli.project_init import init_example_project
2122
from sqlmesh.core.config import load_config_from_paths
22-
from sqlmesh.core.config.connection import ConnectionConfig
23+
from sqlmesh.core.config.connection import ConnectionConfig, DuckDBConnectionConfig
2324
import sqlmesh.core.dialect as d
2425
from sqlmesh.core.dialect import select_from_values
2526
from sqlmesh.core.model import Model, load_sql_based_model
27+
from sqlmesh.core.engine_adapter import EngineAdapter
2628
from sqlmesh.core.engine_adapter.shared import DataObject, DataObjectType
2729
from sqlmesh.core.engine_adapter.mixins import RowDiffMixin
2830
from sqlmesh.core.model.definition import create_sql_model
2931
from sqlmesh.core.plan import Plan
32+
from sqlmesh.core.state_sync.cache import CachingStateSync
3033
from sqlmesh.core.state_sync.db import EngineAdapterStateSync
31-
from sqlmesh.core.snapshot import Snapshot, SnapshotChangeCategory
32-
from sqlmesh.utils.date import now, to_date, to_time_column
34+
from sqlmesh.core.snapshot import Snapshot, SnapshotChangeCategory, SnapshotId
35+
from sqlmesh.utils.date import now, to_date, to_time_column, to_ds
3336
from sqlmesh.core.table_diff import TableDiff
3437
from sqlmesh.utils.errors import SQLMeshError
3538
from sqlmesh.utils.pydantic import PydanticModel
@@ -2672,3 +2675,178 @@ def test_identifier_length_limit(ctx: TestContext):
26722675
match=re.escape(match),
26732676
):
26742677
adapter.create_table(long_table_name, {"col": exp.DataType.build("int")})
2678+
2679+
2680+
def test_janitor_out_of_order_drop(ctx: TestContext, tmp_path: pathlib.Path):
2681+
"""
2682+
Scenario:
2683+
2684+
Ensure that cleaning up expired table snapshots also cleans up any unexpired view snapshots that depend on them
2685+
2686+
- We create a A (table) <- B (view)
2687+
- In dev, we modify A - triggers new version of A and a dev preview of B that both expire in 7 days
2688+
- We advance time by 3 days
2689+
- In dev, we modify B - triggers a new version of B that depends on A but expires 3 days after A
2690+
- We advance time by 5 days so that A has reached its expiry but B has not
2691+
- We expire dev so that none of these snapshots are promoted and are thus targets for cleanup
2692+
- We run the janitor
2693+
2694+
Expected outcome:
2695+
- All the dev versions of A and B should be dropped
2696+
- We should not get an 'ERROR: cannot drop table x because other objects depend on it' on engines that do schema binding
2697+
"""
2698+
2699+
def _state_sync_engine_adapter(context: Context) -> EngineAdapter:
2700+
assert isinstance(context.state_sync, CachingStateSync)
2701+
assert isinstance(context.state_sync.state_sync, EngineAdapterStateSync)
2702+
return context.state_sync.state_sync.engine_adapter
2703+
2704+
models_dir = tmp_path / "models"
2705+
models_dir.mkdir()
2706+
schema = exp.to_table(ctx.schema(TEST_SCHEMA)).this
2707+
2708+
(models_dir / "model_a.sql").write_text(f"""
2709+
MODEL (
2710+
name {schema}.model_a,
2711+
kind FULL
2712+
);
2713+
2714+
SELECT 1 as a, 2 as b;
2715+
""")
2716+
2717+
(models_dir / "model_b.sql").write_text(f"""
2718+
MODEL (
2719+
name {schema}.model_b,
2720+
kind VIEW
2721+
);
2722+
2723+
SELECT a from {schema}.model_a;
2724+
""")
2725+
2726+
def _mutate_config(gateway: str, config: Config):
2727+
config.gateways[gateway].state_connection = DuckDBConnectionConfig(
2728+
database=str(tmp_path / "state.db")
2729+
)
2730+
2731+
with time_machine.travel("2020-01-01 00:00:00"):
2732+
sqlmesh = ctx.create_context(
2733+
path=tmp_path, config_mutator=_mutate_config, ephemeral_state_connection=False
2734+
)
2735+
sqlmesh.plan(auto_apply=True)
2736+
2737+
model_a_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_a" in n)
2738+
# expiry is last updated + ttl
2739+
assert timedelta(milliseconds=model_a_snapshot.ttl_ms) == timedelta(weeks=1)
2740+
assert to_ds(model_a_snapshot.updated_ts) == "2020-01-01"
2741+
assert to_ds(model_a_snapshot.expiration_ts) == "2020-01-08"
2742+
2743+
model_b_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_b" in n)
2744+
assert timedelta(milliseconds=model_b_snapshot.ttl_ms) == timedelta(weeks=1)
2745+
assert to_ds(model_b_snapshot.updated_ts) == "2020-01-01"
2746+
assert to_ds(model_b_snapshot.expiration_ts) == "2020-01-08"
2747+
2748+
model_a_prod_snapshot = model_a_snapshot
2749+
model_b_prod_snapshot = model_b_snapshot
2750+
2751+
# move forward 1 day
2752+
# new dev environment - touch models to create new snapshots
2753+
# model a / b expiry in prod should remain unmodified
2754+
# model a / b expiry in dev should be as at today
2755+
with time_machine.travel("2020-01-02 00:00:00"):
2756+
(models_dir / "model_a.sql").write_text(f"""
2757+
MODEL (
2758+
name {schema}.model_a,
2759+
kind FULL
2760+
);
2761+
2762+
SELECT 1 as a, 2 as b, 3 as c;
2763+
""")
2764+
2765+
sqlmesh = ctx.create_context(
2766+
path=tmp_path, config_mutator=_mutate_config, ephemeral_state_connection=False
2767+
)
2768+
sqlmesh.plan(environment="dev", auto_apply=True)
2769+
2770+
# should now have 4 snapshots in state - 2x model a and 2x model b
2771+
# the new model b is a dev preview because its upstream model changed
2772+
assert (
2773+
len(_state_sync_engine_adapter(sqlmesh).fetchall(f"select * from sqlmesh._snapshots"))
2774+
== 4
2775+
)
2776+
2777+
# context just has the two latest
2778+
assert len(sqlmesh.snapshots) == 2
2779+
2780+
# these expire 1 day later than what's in prod
2781+
model_a_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_a" in n)
2782+
assert timedelta(milliseconds=model_a_snapshot.ttl_ms) == timedelta(weeks=1)
2783+
assert to_ds(model_a_snapshot.updated_ts) == "2020-01-02"
2784+
assert to_ds(model_a_snapshot.expiration_ts) == "2020-01-09"
2785+
2786+
model_b_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_b" in n)
2787+
assert timedelta(milliseconds=model_b_snapshot.ttl_ms) == timedelta(weeks=1)
2788+
assert to_ds(model_b_snapshot.updated_ts) == "2020-01-02"
2789+
assert to_ds(model_b_snapshot.expiration_ts) == "2020-01-09"
2790+
2791+
# move forward 3 days
2792+
# touch model b in dev but leave model a
2793+
# this bumps the model b expiry but model a remains unchanged, so will expire before model b even though model b depends on it
2794+
with time_machine.travel("2020-01-05 00:00:00"):
2795+
(models_dir / "model_b.sql").write_text(f"""
2796+
MODEL (
2797+
name {schema}.model_b,
2798+
kind VIEW
2799+
);
2800+
2801+
SELECT a, 'b' as b from {schema}.model_a;
2802+
""")
2803+
2804+
sqlmesh = ctx.create_context(
2805+
path=tmp_path, config_mutator=_mutate_config, ephemeral_state_connection=False
2806+
)
2807+
sqlmesh.plan(environment="dev", auto_apply=True)
2808+
2809+
# should now have 5 snapshots in state - 2x model a and 3x model b
2810+
assert (
2811+
len(_state_sync_engine_adapter(sqlmesh).fetchall(f"select * from sqlmesh._snapshots"))
2812+
== 5
2813+
)
2814+
2815+
# context just has the two latest
2816+
assert len(sqlmesh.snapshots) == 2
2817+
2818+
# model a expiry should not have changed
2819+
model_a_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_a" in n)
2820+
assert timedelta(milliseconds=model_a_snapshot.ttl_ms) == timedelta(weeks=1)
2821+
assert to_ds(model_a_snapshot.updated_ts) == "2020-01-02"
2822+
assert to_ds(model_a_snapshot.expiration_ts) == "2020-01-09"
2823+
2824+
# model b should now expire well after model a
2825+
model_b_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_b" in n)
2826+
assert timedelta(milliseconds=model_b_snapshot.ttl_ms) == timedelta(weeks=1)
2827+
assert to_ds(model_b_snapshot.updated_ts) == "2020-01-05"
2828+
assert to_ds(model_b_snapshot.expiration_ts) == "2020-01-12"
2829+
2830+
# move forward to date where after model a has expired but before model b has expired
2831+
# invalidate dev to trigger cleanups
2832+
# run janitor. model a is expired so will be cleaned up and this will cascade to model b.
2833+
with time_machine.travel("2020-01-10 00:00:00"):
2834+
sqlmesh = ctx.create_context(
2835+
path=tmp_path, config_mutator=_mutate_config, ephemeral_state_connection=False
2836+
)
2837+
2838+
before_snapshots = _state_sync_engine_adapter(sqlmesh).fetchall(
2839+
f"select name, identifier from sqlmesh._snapshots"
2840+
)
2841+
sqlmesh.invalidate_environment("dev")
2842+
sqlmesh.run_janitor(ignore_ttl=False)
2843+
after_snapshots = _state_sync_engine_adapter(sqlmesh).fetchall(
2844+
f"select name, identifier from sqlmesh._snapshots"
2845+
)
2846+
2847+
assert len(before_snapshots) != len(after_snapshots)
2848+
2849+
# all that's left should be the two snapshots that were in prod
2850+
assert set(
2851+
[SnapshotId(name=name, identifier=identifier) for name, identifier in after_snapshots]
2852+
) == set([model_a_prod_snapshot.snapshot_id, model_b_prod_snapshot.snapshot_id])

0 commit comments

Comments
 (0)