-
Notifications
You must be signed in to change notification settings - Fork 358
Fix: Include unexpired downstream views when cleaning up expired tables #5098
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ebb5933
3751c4b
532d1c5
47ea790
a8742bd
3234a78
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,6 +30,7 @@ | |
| SnapshotId, | ||
| SnapshotFingerprint, | ||
| SnapshotChangeCategory, | ||
| snapshots_to_dag, | ||
| ) | ||
| from sqlmesh.utils.migration import index_text_type, blob_text_type | ||
| from sqlmesh.utils.date import now_timestamp, TimeLike, now, to_timestamp | ||
|
|
@@ -221,21 +222,65 @@ def _get_expired_snapshots( | |
| ignore_ttl: bool = False, | ||
| ) -> t.Tuple[t.Set[SnapshotId], t.List[SnapshotTableCleanupTask]]: | ||
| expired_query = exp.select("name", "identifier", "version").from_(self.snapshots_table) | ||
| expired_record_count = 0 | ||
|
|
||
| if not ignore_ttl: | ||
| expired_query = expired_query.where( | ||
| (exp.column("updated_ts") + exp.column("ttl_ms")) <= current_ts | ||
| ) | ||
|
|
||
| expired_candidates = { | ||
| SnapshotId(name=name, identifier=identifier): SnapshotNameVersion( | ||
| name=name, version=version | ||
| ) | ||
| for name, identifier, version in fetchall(self.engine_adapter, expired_query) | ||
| } | ||
| if not expired_candidates: | ||
| if result := fetchone( | ||
| self.engine_adapter, | ||
| exp.select("count(*)").from_(expired_query.subquery().as_("expired_snapshots")), | ||
| ): | ||
| expired_record_count = result[0] | ||
|
|
||
| # we need to include views even if they haven't expired in case one depends on a table or view that /has/ expired | ||
| # but we only need to do this if there are expired objects to begin with | ||
| if expired_record_count > 0: | ||
| expired_query = t.cast( | ||
| exp.Select, expired_query.or_(exp.column("kind_name").eq(ModelKindName.VIEW)) | ||
| ) | ||
|
|
||
| candidates = {} | ||
| if ignore_ttl or expired_record_count > 0: | ||
| candidates = { | ||
| SnapshotId(name=name, identifier=identifier): SnapshotNameVersion( | ||
| name=name, version=version | ||
| ) | ||
| for name, identifier, version in fetchall(self.engine_adapter, expired_query) | ||
| } | ||
|
|
||
| if not candidates: | ||
| return set(), [] | ||
|
|
||
| expired_candidates: t.Dict[SnapshotId, SnapshotNameVersion] = {} | ||
|
|
||
| if ignore_ttl: | ||
| expired_candidates = candidates | ||
| else: | ||
| # Fetch full snapshots because we need the dependency relationship | ||
| full_candidates = self.get_snapshots(candidates.keys()) | ||
|
|
||
| dag = snapshots_to_dag(full_candidates.values()) | ||
|
|
||
| # Include any non-expired views that depend on expired tables | ||
| for snapshot_id in dag: | ||
| snapshot = full_candidates.get(snapshot_id, None) | ||
|
|
||
| if not snapshot: | ||
| continue | ||
|
|
||
| if snapshot.expiration_ts <= current_ts: | ||
| # All expired snapshots should be included regardless | ||
| expired_candidates[snapshot.snapshot_id] = snapshot.name_version | ||
| elif snapshot.model_kind_name == ModelKindName.VIEW: | ||
| # Check if any of our parents are in the expired list | ||
| # This works because we traverse the dag in topological order, so if our parent either directly | ||
| # expired or indirectly expired because /its/ parent expired, it will still be in the expired_snapshots list | ||
| if any(parent.snapshot_id in expired_candidates for parent in snapshot.parents): | ||
| expired_candidates[snapshot.snapshot_id] = snapshot.name_version | ||
|
|
||
| promoted_snapshot_ids = { | ||
| snapshot.snapshot_id | ||
| for environment in environments | ||
|
|
@@ -253,37 +298,39 @@ def _is_snapshot_used(snapshot: SharedVersionSnapshot) -> bool: | |
| unique_expired_versions, batch_size=self.SNAPSHOT_BATCH_SIZE | ||
| ) | ||
| cleanup_targets = [] | ||
| expired_snapshot_ids = set() | ||
| expired_sv_snapshot_ids = set() | ||
| for versions_batch in version_batches: | ||
| snapshots = self._get_snapshots_with_same_version(versions_batch) | ||
| sv_snapshots = self._get_snapshots_with_same_version(versions_batch) | ||
|
|
||
| snapshots_by_version = defaultdict(set) | ||
| snapshots_by_dev_version = defaultdict(set) | ||
| for s in snapshots: | ||
| for s in sv_snapshots: | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These variable renames were just to placate mypy, because |
||
| snapshots_by_version[(s.name, s.version)].add(s.snapshot_id) | ||
| snapshots_by_dev_version[(s.name, s.dev_version)].add(s.snapshot_id) | ||
|
|
||
| expired_snapshots = [s for s in snapshots if not _is_snapshot_used(s)] | ||
| expired_snapshot_ids.update([s.snapshot_id for s in expired_snapshots]) | ||
| expired_sv_snapshots = [s for s in sv_snapshots if not _is_snapshot_used(s)] | ||
| expired_sv_snapshot_ids.update([s.snapshot_id for s in expired_sv_snapshots]) | ||
|
|
||
| for snapshot in expired_snapshots: | ||
| shared_version_snapshots = snapshots_by_version[(snapshot.name, snapshot.version)] | ||
| shared_version_snapshots.discard(snapshot.snapshot_id) | ||
| for sv_snapshot in expired_sv_snapshots: | ||
| shared_version_snapshots = snapshots_by_version[ | ||
| (sv_snapshot.name, sv_snapshot.version) | ||
| ] | ||
| shared_version_snapshots.discard(sv_snapshot.snapshot_id) | ||
|
|
||
| shared_dev_version_snapshots = snapshots_by_dev_version[ | ||
| (snapshot.name, snapshot.dev_version) | ||
| (sv_snapshot.name, sv_snapshot.dev_version) | ||
| ] | ||
| shared_dev_version_snapshots.discard(snapshot.snapshot_id) | ||
| shared_dev_version_snapshots.discard(sv_snapshot.snapshot_id) | ||
|
|
||
| if not shared_dev_version_snapshots: | ||
| cleanup_targets.append( | ||
| SnapshotTableCleanupTask( | ||
| snapshot=snapshot.full_snapshot.table_info, | ||
| snapshot=sv_snapshot.full_snapshot.table_info, | ||
| dev_table_only=bool(shared_version_snapshots), | ||
| ) | ||
| ) | ||
|
|
||
| return expired_snapshot_ids, cleanup_targets | ||
| return expired_sv_snapshot_ids, cleanup_targets | ||
|
|
||
| def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None: | ||
| """Deletes snapshots. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,23 +13,25 @@ | |
| import pandas as pd # noqa: TID253 | ||
| import pytest | ||
| import pytz | ||
| import time_machine | ||
| from sqlglot import exp, parse_one | ||
| from sqlglot.optimizer.normalize_identifiers import normalize_identifiers | ||
|
|
||
| from sqlmesh import Config, Context | ||
| from sqlmesh.cli.project_init import init_example_project | ||
| from sqlmesh.core.config import load_config_from_paths | ||
| from sqlmesh.core.config.connection import ConnectionConfig | ||
| from sqlmesh.core.config.connection import ConnectionConfig, DuckDBConnectionConfig | ||
| import sqlmesh.core.dialect as d | ||
| from sqlmesh.core.dialect import select_from_values | ||
| from sqlmesh.core.model import Model, load_sql_based_model | ||
| from sqlmesh.core.engine_adapter.shared import DataObject, DataObjectType | ||
| from sqlmesh.core.engine_adapter.mixins import RowDiffMixin, LogicalMergeMixin | ||
| from sqlmesh.core.model.definition import create_sql_model | ||
| from sqlmesh.core.plan import Plan | ||
| from sqlmesh.core.state_sync.cache import CachingStateSync | ||
| from sqlmesh.core.state_sync.db import EngineAdapterStateSync | ||
| from sqlmesh.core.snapshot import Snapshot, SnapshotChangeCategory | ||
| from sqlmesh.utils.date import now, to_date, to_time_column | ||
| from sqlmesh.core.snapshot import Snapshot, SnapshotChangeCategory, SnapshotId | ||
| from sqlmesh.utils.date import now, to_date, to_time_column, to_ds | ||
| from sqlmesh.core.table_diff import TableDiff | ||
| from sqlmesh.utils.errors import SQLMeshError | ||
| from sqlmesh.utils.pydantic import PydanticModel | ||
|
|
@@ -2799,3 +2801,224 @@ def test_identifier_length_limit(ctx: TestContext): | |
| match=re.escape(match), | ||
| ): | ||
| adapter.create_table(long_table_name, {"col": exp.DataType.build("int")}) | ||
|
|
||
|
|
||
| def test_janitor_drops_downstream_unexpired_hard_dependencies( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could you also add a test for a transitive dependency? For example:
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah you're right, the drops weren't cascading through transitive dependencies. I've re-implemented this to use a DAG again, which simplified the implementation because it can be traversed in topological order and for any given node we can:
I also updated the tests to test transitive dependencies too |
||
| ctx: TestContext, tmp_path: pathlib.Path | ||
| ): | ||
| """ | ||
| Scenario: | ||
|
|
||
| Ensure that cleaning up expired table snapshots also cleans up any unexpired view snapshots that depend on them | ||
|
|
||
| - We create a A (table) <- B (view) | ||
| - In dev, we modify A - triggers new version of A and a dev preview of B that both expire in 7 days | ||
| - We advance time by 3 days | ||
| - In dev, we modify B - triggers a new version of B that depends on A but expires 3 days after A | ||
| - In dev, we create B(view) <- C(view) and B(view) <- D(table) | ||
| - We advance time by 5 days so that A has reached its expiry but B, C and D have not | ||
| - We expire dev so that none of these snapshots are promoted and are thus targets for cleanup | ||
| - We run the janitor | ||
|
|
||
| Expected outcome: | ||
| - All the dev versions of A and B should be dropped | ||
| - C should be dropped as well because it's a view that depends on B which was dropped | ||
| - D should not be dropped because while it depends on B which was dropped, it's a table so is still valid after B is dropped | ||
| - We should not get a 'ERROR: cannot drop table x because other objects depend on it' on engines that do schema binding | ||
| """ | ||
|
|
||
| def _all_snapshot_ids(context: Context) -> t.List[SnapshotId]: | ||
| assert isinstance(context.state_sync, CachingStateSync) | ||
| assert isinstance(context.state_sync.state_sync, EngineAdapterStateSync) | ||
|
|
||
| return [ | ||
| SnapshotId(name=name, identifier=identifier) | ||
| for name, identifier in context.state_sync.state_sync.engine_adapter.fetchall( | ||
| "select name, identifier from sqlmesh._snapshots" | ||
| ) | ||
| ] | ||
|
|
||
| models_dir = tmp_path / "models" | ||
| models_dir.mkdir() | ||
| schema = exp.to_table(ctx.schema(TEST_SCHEMA)).this | ||
|
|
||
| (models_dir / "model_a.sql").write_text(f""" | ||
| MODEL ( | ||
| name {schema}.model_a, | ||
| kind FULL | ||
| ); | ||
|
|
||
| SELECT 1 as a, 2 as b; | ||
| """) | ||
|
|
||
| (models_dir / "model_b.sql").write_text(f""" | ||
| MODEL ( | ||
| name {schema}.model_b, | ||
| kind VIEW | ||
| ); | ||
|
|
||
| SELECT a from {schema}.model_a; | ||
| """) | ||
|
|
||
| def _mutate_config(gateway: str, config: Config): | ||
| config.gateways[gateway].state_connection = DuckDBConnectionConfig( | ||
| database=str(tmp_path / "state.db") | ||
| ) | ||
|
|
||
| with time_machine.travel("2020-01-01 00:00:00"): | ||
| sqlmesh = ctx.create_context( | ||
| path=tmp_path, config_mutator=_mutate_config, ephemeral_state_connection=False | ||
| ) | ||
| sqlmesh.plan(auto_apply=True) | ||
|
|
||
| model_a_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_a" in n) | ||
| # expiry is last updated + ttl | ||
| assert timedelta(milliseconds=model_a_snapshot.ttl_ms) == timedelta(weeks=1) | ||
| assert to_ds(model_a_snapshot.updated_ts) == "2020-01-01" | ||
| assert to_ds(model_a_snapshot.expiration_ts) == "2020-01-08" | ||
|
|
||
| model_b_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_b" in n) | ||
| assert timedelta(milliseconds=model_b_snapshot.ttl_ms) == timedelta(weeks=1) | ||
| assert to_ds(model_b_snapshot.updated_ts) == "2020-01-01" | ||
| assert to_ds(model_b_snapshot.expiration_ts) == "2020-01-08" | ||
|
|
||
| model_a_prod_snapshot = model_a_snapshot | ||
| model_b_prod_snapshot = model_b_snapshot | ||
|
|
||
| # move forward 1 day | ||
| # new dev environment - touch models to create new snapshots | ||
| # model a / b expiry in prod should remain unmodified | ||
| # model a / b expiry in dev should be as at today | ||
| with time_machine.travel("2020-01-02 00:00:00"): | ||
| (models_dir / "model_a.sql").write_text(f""" | ||
| MODEL ( | ||
| name {schema}.model_a, | ||
| kind FULL | ||
| ); | ||
|
|
||
| SELECT 1 as a, 2 as b, 3 as c; | ||
| """) | ||
|
|
||
| sqlmesh = ctx.create_context( | ||
| path=tmp_path, config_mutator=_mutate_config, ephemeral_state_connection=False | ||
| ) | ||
| sqlmesh.plan(environment="dev", auto_apply=True) | ||
|
|
||
| # should now have 4 snapshots in state - 2x model a and 2x model b | ||
| # the new model b is a dev preview because its upstream model changed | ||
| all_snapshot_ids = _all_snapshot_ids(sqlmesh) | ||
| assert len(all_snapshot_ids) == 4 | ||
| assert len([s for s in all_snapshot_ids if "model_a" in s.name]) == 2 | ||
| assert len([s for s in all_snapshot_ids if "model_b" in s.name]) == 2 | ||
|
|
||
| # context just has the two latest | ||
| assert len(sqlmesh.snapshots) == 2 | ||
|
|
||
| # these expire 1 day later than what's in prod | ||
| model_a_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_a" in n) | ||
| assert timedelta(milliseconds=model_a_snapshot.ttl_ms) == timedelta(weeks=1) | ||
| assert to_ds(model_a_snapshot.updated_ts) == "2020-01-02" | ||
| assert to_ds(model_a_snapshot.expiration_ts) == "2020-01-09" | ||
|
|
||
| model_b_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_b" in n) | ||
| assert timedelta(milliseconds=model_b_snapshot.ttl_ms) == timedelta(weeks=1) | ||
| assert to_ds(model_b_snapshot.updated_ts) == "2020-01-02" | ||
| assert to_ds(model_b_snapshot.expiration_ts) == "2020-01-09" | ||
|
|
||
| # move forward 3 days | ||
| # touch model b in dev but leave model a | ||
| # this bumps the model b expiry but model a remains unchanged, so will expire before model b even though model b depends on it | ||
| with time_machine.travel("2020-01-05 00:00:00"): | ||
| (models_dir / "model_b.sql").write_text(f""" | ||
| MODEL ( | ||
| name {schema}.model_b, | ||
| kind VIEW | ||
| ); | ||
|
|
||
| SELECT a, 'b' as b from {schema}.model_a; | ||
| """) | ||
|
|
||
| (models_dir / "model_c.sql").write_text(f""" | ||
| MODEL ( | ||
| name {schema}.model_c, | ||
| kind VIEW | ||
| ); | ||
|
|
||
| SELECT a, 'c' as c from {schema}.model_b; | ||
| """) | ||
|
|
||
| (models_dir / "model_d.sql").write_text(f""" | ||
| MODEL ( | ||
| name {schema}.model_d, | ||
| kind FULL | ||
| ); | ||
|
|
||
| SELECT a, 'd' as d from {schema}.model_b; | ||
| """) | ||
|
|
||
| sqlmesh = ctx.create_context( | ||
| path=tmp_path, config_mutator=_mutate_config, ephemeral_state_connection=False | ||
| ) | ||
| sqlmesh.plan(environment="dev", auto_apply=True) | ||
|
|
||
| # should now have 7 snapshots in state - 2x model a, 3x model b, 1x model c and 1x model d | ||
| all_snapshot_ids = _all_snapshot_ids(sqlmesh) | ||
| assert len(all_snapshot_ids) == 7 | ||
| assert len([s for s in all_snapshot_ids if "model_a" in s.name]) == 2 | ||
| assert len([s for s in all_snapshot_ids if "model_b" in s.name]) == 3 | ||
| assert len([s for s in all_snapshot_ids if "model_c" in s.name]) == 1 | ||
| assert len([s for s in all_snapshot_ids if "model_d" in s.name]) == 1 | ||
|
|
||
| # context just has the 4 latest | ||
| assert len(sqlmesh.snapshots) == 4 | ||
|
|
||
| # model a expiry should not have changed | ||
| model_a_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_a" in n) | ||
| assert timedelta(milliseconds=model_a_snapshot.ttl_ms) == timedelta(weeks=1) | ||
| assert to_ds(model_a_snapshot.updated_ts) == "2020-01-02" | ||
| assert to_ds(model_a_snapshot.expiration_ts) == "2020-01-09" | ||
|
|
||
| # model b should now expire well after model a | ||
| model_b_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_b" in n) | ||
| assert timedelta(milliseconds=model_b_snapshot.ttl_ms) == timedelta(weeks=1) | ||
| assert to_ds(model_b_snapshot.updated_ts) == "2020-01-05" | ||
| assert to_ds(model_b_snapshot.expiration_ts) == "2020-01-12" | ||
|
|
||
| # model c should expire at the same time as model b | ||
| model_c_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_c" in n) | ||
| assert to_ds(model_c_snapshot.updated_ts) == to_ds(model_b_snapshot.updated_ts) | ||
| assert to_ds(model_c_snapshot.expiration_ts) == to_ds(model_b_snapshot.expiration_ts) | ||
|
|
||
| # model d should expire at the same time as model b | ||
| model_d_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_d" in n) | ||
| assert to_ds(model_d_snapshot.updated_ts) == to_ds(model_b_snapshot.updated_ts) | ||
| assert to_ds(model_d_snapshot.expiration_ts) == to_ds(model_b_snapshot.expiration_ts) | ||
|
|
||
| # move forward to a date after model a has expired but before model b has expired | ||
| # invalidate dev to trigger cleanups | ||
| # run janitor | ||
| # - table model a is expired so will be cleaned up and this will cascade to view model b | ||
| # - view model b is not expired, but because it got cascaded to, this will cascade again to view model c | ||
| # - table model d is not a view, so even though its parent view model b got dropped, it doesn't need to be dropped | ||
| with time_machine.travel("2020-01-10 00:00:00"): | ||
| sqlmesh = ctx.create_context( | ||
| path=tmp_path, config_mutator=_mutate_config, ephemeral_state_connection=False | ||
| ) | ||
|
|
||
| before_snapshot_ids = _all_snapshot_ids(sqlmesh) | ||
|
|
||
| sqlmesh.invalidate_environment("dev") | ||
| sqlmesh.run_janitor(ignore_ttl=False) | ||
|
|
||
| after_snapshot_ids = _all_snapshot_ids(sqlmesh) | ||
|
|
||
| assert len(before_snapshot_ids) != len(after_snapshot_ids) | ||
|
|
||
| # all that's left should be the two original snapshots that were in prod and model d | ||
| assert set(after_snapshot_ids) == set( | ||
| [ | ||
| model_a_prod_snapshot.snapshot_id, | ||
| model_b_prod_snapshot.snapshot_id, | ||
| model_d_snapshot.snapshot_id, | ||
| ] | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this be filtered by a database rather than the application layer?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How do you mean? Line 242 adjusts the query to the database to include the views:
The DAG can only be built in the application layer, right? Since we have no way of knowing at the database level what views are affected, don't we have to select all views and filter them at the application layer?
We can't filter to expired views in the DB query because that's what caused this problem to begin with.