Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 66 additions & 19 deletions sqlmesh/core/state_sync/db/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
SnapshotId,
SnapshotFingerprint,
SnapshotChangeCategory,
snapshots_to_dag,
)
from sqlmesh.utils.migration import index_text_type, blob_text_type
from sqlmesh.utils.date import now_timestamp, TimeLike, now, to_timestamp
Expand Down Expand Up @@ -221,21 +222,65 @@ def _get_expired_snapshots(
ignore_ttl: bool = False,
) -> t.Tuple[t.Set[SnapshotId], t.List[SnapshotTableCleanupTask]]:
expired_query = exp.select("name", "identifier", "version").from_(self.snapshots_table)
expired_record_count = 0

if not ignore_ttl:
expired_query = expired_query.where(
(exp.column("updated_ts") + exp.column("ttl_ms")) <= current_ts
)

expired_candidates = {
SnapshotId(name=name, identifier=identifier): SnapshotNameVersion(
name=name, version=version
)
for name, identifier, version in fetchall(self.engine_adapter, expired_query)
}
if not expired_candidates:
if result := fetchone(
self.engine_adapter,
exp.select("count(*)").from_(expired_query.subquery().as_("expired_snapshots")),
):
expired_record_count = result[0]

# we need to include views even if they haven't expired in case one depends on a table or view that /has/ expired
# but we only need to do this if there are expired objects to begin with
if expired_record_count > 0:
expired_query = t.cast(
exp.Select, expired_query.or_(exp.column("kind_name").eq(ModelKindName.VIEW))
)

candidates = {}
if ignore_ttl or expired_record_count > 0:
candidates = {
SnapshotId(name=name, identifier=identifier): SnapshotNameVersion(
name=name, version=version
)
for name, identifier, version in fetchall(self.engine_adapter, expired_query)
}

if not candidates:
return set(), []

expired_candidates: t.Dict[SnapshotId, SnapshotNameVersion] = {}

if ignore_ttl:
expired_candidates = candidates
else:
# Fetch full snapshots because we need the dependency relationship
full_candidates = self.get_snapshots(candidates.keys())

dag = snapshots_to_dag(full_candidates.values())

# Include any non-expired views that depend on expired tables
for snapshot_id in dag:
snapshot = full_candidates.get(snapshot_id, None)

if not snapshot:
continue

if snapshot.expiration_ts <= current_ts:
# All expired snapshots should be included regardless
expired_candidates[snapshot.snapshot_id] = snapshot.name_version
elif snapshot.model_kind_name == ModelKindName.VIEW:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be filtered by a database rather than the application layer?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do you mean? Line 242 adjusts the query to the database to include the views:

            if expired_record_count > 0:
                expired_query = t.cast(
                    exp.Select, expired_query.or_(exp.column("kind_name").eq(ModelKindName.VIEW))
                )

The DAG can only be built in the application layer, right? Since we have no way of knowing at the database level what views are affected, don't we have to select all views and filter them at the application layer?

We can't filter to expired views in the DB query because that's what caused this problem to begin with.

# Check if any of our parents are in the expired list
# This works because we traverse the dag in topological order, so if our parent either directly
# expired or indirectly expired because /its/ parent expired, it will still be in the expired_snapshots list
if any(parent.snapshot_id in expired_candidates for parent in snapshot.parents):
expired_candidates[snapshot.snapshot_id] = snapshot.name_version

promoted_snapshot_ids = {
snapshot.snapshot_id
for environment in environments
Expand All @@ -253,37 +298,39 @@ def _is_snapshot_used(snapshot: SharedVersionSnapshot) -> bool:
unique_expired_versions, batch_size=self.SNAPSHOT_BATCH_SIZE
)
cleanup_targets = []
expired_snapshot_ids = set()
expired_sv_snapshot_ids = set()
for versions_batch in version_batches:
snapshots = self._get_snapshots_with_same_version(versions_batch)
sv_snapshots = self._get_snapshots_with_same_version(versions_batch)

snapshots_by_version = defaultdict(set)
snapshots_by_dev_version = defaultdict(set)
for s in snapshots:
for s in sv_snapshots:
Copy link
Collaborator Author

@erindru erindru Aug 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These variable renames were just to placate mypy, because snapshot is used above to refer to objects of type Snapshot which does not match the objects of type SharedVersionSnapshot used here

snapshots_by_version[(s.name, s.version)].add(s.snapshot_id)
snapshots_by_dev_version[(s.name, s.dev_version)].add(s.snapshot_id)

expired_snapshots = [s for s in snapshots if not _is_snapshot_used(s)]
expired_snapshot_ids.update([s.snapshot_id for s in expired_snapshots])
expired_sv_snapshots = [s for s in sv_snapshots if not _is_snapshot_used(s)]
expired_sv_snapshot_ids.update([s.snapshot_id for s in expired_sv_snapshots])

for snapshot in expired_snapshots:
shared_version_snapshots = snapshots_by_version[(snapshot.name, snapshot.version)]
shared_version_snapshots.discard(snapshot.snapshot_id)
for sv_snapshot in expired_sv_snapshots:
shared_version_snapshots = snapshots_by_version[
(sv_snapshot.name, sv_snapshot.version)
]
shared_version_snapshots.discard(sv_snapshot.snapshot_id)

shared_dev_version_snapshots = snapshots_by_dev_version[
(snapshot.name, snapshot.dev_version)
(sv_snapshot.name, sv_snapshot.dev_version)
]
shared_dev_version_snapshots.discard(snapshot.snapshot_id)
shared_dev_version_snapshots.discard(sv_snapshot.snapshot_id)

if not shared_dev_version_snapshots:
cleanup_targets.append(
SnapshotTableCleanupTask(
snapshot=snapshot.full_snapshot.table_info,
snapshot=sv_snapshot.full_snapshot.table_info,
dev_table_only=bool(shared_version_snapshots),
)
)

return expired_snapshot_ids, cleanup_targets
return expired_sv_snapshot_ids, cleanup_targets

def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None:
"""Deletes snapshots.
Expand Down
229 changes: 226 additions & 3 deletions tests/core/engine_adapter/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,25 @@
import pandas as pd # noqa: TID253
import pytest
import pytz
import time_machine
from sqlglot import exp, parse_one
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

from sqlmesh import Config, Context
from sqlmesh.cli.project_init import init_example_project
from sqlmesh.core.config import load_config_from_paths
from sqlmesh.core.config.connection import ConnectionConfig
from sqlmesh.core.config.connection import ConnectionConfig, DuckDBConnectionConfig
import sqlmesh.core.dialect as d
from sqlmesh.core.dialect import select_from_values
from sqlmesh.core.model import Model, load_sql_based_model
from sqlmesh.core.engine_adapter.shared import DataObject, DataObjectType
from sqlmesh.core.engine_adapter.mixins import RowDiffMixin, LogicalMergeMixin
from sqlmesh.core.model.definition import create_sql_model
from sqlmesh.core.plan import Plan
from sqlmesh.core.state_sync.cache import CachingStateSync
from sqlmesh.core.state_sync.db import EngineAdapterStateSync
from sqlmesh.core.snapshot import Snapshot, SnapshotChangeCategory
from sqlmesh.utils.date import now, to_date, to_time_column
from sqlmesh.core.snapshot import Snapshot, SnapshotChangeCategory, SnapshotId
from sqlmesh.utils.date import now, to_date, to_time_column, to_ds
from sqlmesh.core.table_diff import TableDiff
from sqlmesh.utils.errors import SQLMeshError
from sqlmesh.utils.pydantic import PydanticModel
Expand Down Expand Up @@ -2799,3 +2801,224 @@ def test_identifier_length_limit(ctx: TestContext):
match=re.escape(match),
):
adapter.create_table(long_table_name, {"col": exp.DataType.build("int")})


def test_janitor_drops_downstream_unexpired_hard_dependencies(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you also add a test for a transitive dependency? For example:
Table A is expired ← View B (not expired) ← View C (not expired)
I believe this case was handled with the reversed dag before and that View C got picked up, but I’m not entirely certain without the dag construction now, so it’d be good to have a test to confirm

Copy link
Collaborator Author

@erindru erindru Aug 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah you're right, the drops weren't cascading through transitive dependencies. I've re-implemented this to use a DAG again, which simplified the implementation because it can be traversed in topological order and for any given node we can:

  • check if it's expired (gets put directly in the expired list)
  • if it's a view, check if any of its parents are in the expired list. If they are, it should be expired too

I also updated the tests to test transitive dependencies too

ctx: TestContext, tmp_path: pathlib.Path
):
"""
Scenario:

Ensure that cleaning up expired table snapshots also cleans up any unexpired view snapshots that depend on them

- We create a A (table) <- B (view)
- In dev, we modify A - triggers new version of A and a dev preview of B that both expire in 7 days
- We advance time by 3 days
- In dev, we modify B - triggers a new version of B that depends on A but expires 3 days after A
- In dev, we create B(view) <- C(view) and B(view) <- D(table)
- We advance time by 5 days so that A has reached its expiry but B, C and D have not
- We expire dev so that none of these snapshots are promoted and are thus targets for cleanup
- We run the janitor

Expected outcome:
- All the dev versions of A and B should be dropped
- C should be dropped as well because it's a view that depends on B which was dropped
- D should not be dropped because while it depends on B which was dropped, it's a table so is still valid after B is dropped
- We should not get a 'ERROR: cannot drop table x because other objects depend on it' on engines that do schema binding
"""

def _all_snapshot_ids(context: Context) -> t.List[SnapshotId]:
assert isinstance(context.state_sync, CachingStateSync)
assert isinstance(context.state_sync.state_sync, EngineAdapterStateSync)

return [
SnapshotId(name=name, identifier=identifier)
for name, identifier in context.state_sync.state_sync.engine_adapter.fetchall(
"select name, identifier from sqlmesh._snapshots"
)
]

models_dir = tmp_path / "models"
models_dir.mkdir()
schema = exp.to_table(ctx.schema(TEST_SCHEMA)).this

(models_dir / "model_a.sql").write_text(f"""
MODEL (
name {schema}.model_a,
kind FULL
);

SELECT 1 as a, 2 as b;
""")

(models_dir / "model_b.sql").write_text(f"""
MODEL (
name {schema}.model_b,
kind VIEW
);

SELECT a from {schema}.model_a;
""")

def _mutate_config(gateway: str, config: Config):
config.gateways[gateway].state_connection = DuckDBConnectionConfig(
database=str(tmp_path / "state.db")
)

with time_machine.travel("2020-01-01 00:00:00"):
sqlmesh = ctx.create_context(
path=tmp_path, config_mutator=_mutate_config, ephemeral_state_connection=False
)
sqlmesh.plan(auto_apply=True)

model_a_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_a" in n)
# expiry is last updated + ttl
assert timedelta(milliseconds=model_a_snapshot.ttl_ms) == timedelta(weeks=1)
assert to_ds(model_a_snapshot.updated_ts) == "2020-01-01"
assert to_ds(model_a_snapshot.expiration_ts) == "2020-01-08"

model_b_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_b" in n)
assert timedelta(milliseconds=model_b_snapshot.ttl_ms) == timedelta(weeks=1)
assert to_ds(model_b_snapshot.updated_ts) == "2020-01-01"
assert to_ds(model_b_snapshot.expiration_ts) == "2020-01-08"

model_a_prod_snapshot = model_a_snapshot
model_b_prod_snapshot = model_b_snapshot

# move forward 1 days
# new dev environment - touch models to create new snapshots
# model a / b expiry in prod should remain unmodified
# model a / b expiry in dev should be as at today
with time_machine.travel("2020-01-02 00:00:00"):
(models_dir / "model_a.sql").write_text(f"""
MODEL (
name {schema}.model_a,
kind FULL
);

SELECT 1 as a, 2 as b, 3 as c;
""")

sqlmesh = ctx.create_context(
path=tmp_path, config_mutator=_mutate_config, ephemeral_state_connection=False
)
sqlmesh.plan(environment="dev", auto_apply=True)

# should now have 4 snapshots in state - 2x model a and 2x model b
# the new model b is a dev preview because its upstream model changed
all_snapshot_ids = _all_snapshot_ids(sqlmesh)
assert len(all_snapshot_ids) == 4
assert len([s for s in all_snapshot_ids if "model_a" in s.name]) == 2
assert len([s for s in all_snapshot_ids if "model_b" in s.name]) == 2

# context just has the two latest
assert len(sqlmesh.snapshots) == 2

# these expire 1 day later than what's in prod
model_a_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_a" in n)
assert timedelta(milliseconds=model_a_snapshot.ttl_ms) == timedelta(weeks=1)
assert to_ds(model_a_snapshot.updated_ts) == "2020-01-02"
assert to_ds(model_a_snapshot.expiration_ts) == "2020-01-09"

model_b_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_b" in n)
assert timedelta(milliseconds=model_b_snapshot.ttl_ms) == timedelta(weeks=1)
assert to_ds(model_b_snapshot.updated_ts) == "2020-01-02"
assert to_ds(model_b_snapshot.expiration_ts) == "2020-01-09"

# move forward 3 days
# touch model b in dev but leave model a
# this bumps the model b expiry but model a remains unchanged, so will expire before model b even though model b depends on it
with time_machine.travel("2020-01-05 00:00:00"):
(models_dir / "model_b.sql").write_text(f"""
MODEL (
name {schema}.model_b,
kind VIEW
);

SELECT a, 'b' as b from {schema}.model_a;
""")

(models_dir / "model_c.sql").write_text(f"""
MODEL (
name {schema}.model_c,
kind VIEW
);

SELECT a, 'c' as c from {schema}.model_b;
""")

(models_dir / "model_d.sql").write_text(f"""
MODEL (
name {schema}.model_d,
kind FULL
);

SELECT a, 'd' as d from {schema}.model_b;
""")

sqlmesh = ctx.create_context(
path=tmp_path, config_mutator=_mutate_config, ephemeral_state_connection=False
)
sqlmesh.plan(environment="dev", auto_apply=True)

# should now have 7 snapshots in state - 2x model a, 3x model b, 1x model c and 1x model d
all_snapshot_ids = _all_snapshot_ids(sqlmesh)
assert len(all_snapshot_ids) == 7
assert len([s for s in all_snapshot_ids if "model_a" in s.name]) == 2
assert len([s for s in all_snapshot_ids if "model_b" in s.name]) == 3
assert len([s for s in all_snapshot_ids if "model_c" in s.name]) == 1
assert len([s for s in all_snapshot_ids if "model_d" in s.name]) == 1

# context just has the 4 latest
assert len(sqlmesh.snapshots) == 4

# model a expiry should not have changed
model_a_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_a" in n)
assert timedelta(milliseconds=model_a_snapshot.ttl_ms) == timedelta(weeks=1)
assert to_ds(model_a_snapshot.updated_ts) == "2020-01-02"
assert to_ds(model_a_snapshot.expiration_ts) == "2020-01-09"

# model b should now expire well after model a
model_b_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_b" in n)
assert timedelta(milliseconds=model_b_snapshot.ttl_ms) == timedelta(weeks=1)
assert to_ds(model_b_snapshot.updated_ts) == "2020-01-05"
assert to_ds(model_b_snapshot.expiration_ts) == "2020-01-12"

# model c should expire at the same time as model b
model_c_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_c" in n)
assert to_ds(model_c_snapshot.updated_ts) == to_ds(model_b_snapshot.updated_ts)
assert to_ds(model_c_snapshot.expiration_ts) == to_ds(model_b_snapshot.expiration_ts)

# model d should expire at the same time as model b
model_d_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_d" in n)
assert to_ds(model_d_snapshot.updated_ts) == to_ds(model_b_snapshot.updated_ts)
assert to_ds(model_d_snapshot.expiration_ts) == to_ds(model_b_snapshot.expiration_ts)

# move forward to date where after model a has expired but before model b has expired
# invalidate dev to trigger cleanups
# run janitor
# - table model a is expired so will be cleaned up and this will cascade to view model b
# - view model b is not expired, but because it got cascaded to, this will cascade again to view model c
    # - table model d is not a view, so even though its parent view model b got dropped, it doesn't need to be dropped
with time_machine.travel("2020-01-10 00:00:00"):
sqlmesh = ctx.create_context(
path=tmp_path, config_mutator=_mutate_config, ephemeral_state_connection=False
)

before_snapshot_ids = _all_snapshot_ids(sqlmesh)

sqlmesh.invalidate_environment("dev")
sqlmesh.run_janitor(ignore_ttl=False)

after_snapshot_ids = _all_snapshot_ids(sqlmesh)

assert len(before_snapshot_ids) != len(after_snapshot_ids)

# all that's left should be the two original snapshots that were in prod and model d
assert set(after_snapshot_ids) == set(
[
model_a_prod_snapshot.snapshot_id,
model_b_prod_snapshot.snapshot_id,
model_d_snapshot.snapshot_id,
]
)
Loading