diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index e3feb1e14b..e31a04fe81 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -274,6 +274,7 @@ def __init__( deployability_index: t.Optional[DeployabilityIndex] = None, default_dialect: t.Optional[str] = None, default_catalog: t.Optional[str] = None, + is_restatement: t.Optional[bool] = None, variables: t.Optional[t.Dict[str, t.Any]] = None, blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, ): @@ -284,6 +285,7 @@ def __init__( self._default_dialect = default_dialect self._variables = variables or {} self._blueprint_variables = blueprint_variables or {} + self._is_restatement = is_restatement @property def default_dialect(self) -> t.Optional[str]: @@ -308,6 +310,10 @@ def gateway(self) -> t.Optional[str]: """Returns the gateway name.""" return self.var(c.GATEWAY) + @property + def is_restatement(self) -> t.Optional[bool]: + return self._is_restatement + def var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]: """Returns a variable value.""" return self._variables.get(var_name.lower(), default) @@ -328,6 +334,7 @@ def with_variables( self.deployability_index, self._default_dialect, self._default_catalog, + self._is_restatement, variables=variables, blueprint_variables=blueprint_variables, ) diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index 47e6a4260c..68c6404081 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -119,6 +119,7 @@ class EngineAdapter: MAX_IDENTIFIER_LENGTH: t.Optional[int] = None ATTACH_CORRELATION_ID = True SUPPORTS_QUERY_EXECUTION_TRACKING = False + SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS = False def __init__( self, @@ -2927,6 +2928,9 @@ def _check_identifier_length(self, expression: exp.Expression) -> None: f"Identifier name '{name}' (length {name_length}) exceeds {self.dialect.capitalize()}'s max identifier limit of 
{self.MAX_IDENTIFIER_LENGTH} characters" ) + def get_table_last_modified_ts(self, table_names: t.List[TableName]) -> t.List[int]: + raise NotImplementedError() + class EngineAdapterWithIndexSupport(EngineAdapter): SUPPORTS_INDEXES = True diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py index 0dfa2325e8..26abad9ebc 100644 --- a/sqlmesh/core/engine_adapter/bigquery.py +++ b/sqlmesh/core/engine_adapter/bigquery.py @@ -755,6 +755,28 @@ def table_exists(self, table_name: TableName) -> bool: except NotFound: return False + def get_table_last_modified_ts(self, table_names: t.List[TableName]) -> t.List[int]: + from sqlmesh.utils.date import to_timestamp + + datasets_to_tables: t.DefaultDict[str, t.List[str]] = defaultdict(list) + for table_name in table_names: + table = exp.to_table(table_name) + datasets_to_tables[table.db].append(table.name) + + results = [] + + for dataset, tables in datasets_to_tables.items(): + query = ( + f"SELECT TIMESTAMP_MILLIS(last_modified_time) FROM `{dataset}.__TABLES__` WHERE " + ) + for i, table_name in enumerate(tables): + query += f"TABLE_ID = '{table_name}'" + if i < len(tables) - 1: + query += " OR " + results.extend(self.fetchall(query)) + + return [to_timestamp(row[0]) for row in results] + def _get_table(self, table_name: TableName) -> BigQueryTable: """ Returns a BigQueryTable object for the given table name. 
diff --git a/sqlmesh/core/engine_adapter/snowflake.py b/sqlmesh/core/engine_adapter/snowflake.py index 9c27b45115..1554589779 100644 --- a/sqlmesh/core/engine_adapter/snowflake.py +++ b/sqlmesh/core/engine_adapter/snowflake.py @@ -54,6 +54,7 @@ class SnowflakeEngineAdapter(GetCurrentCatalogFromFunctionMixin, ClusteredByMixi SUPPORTS_MANAGED_MODELS = True CURRENT_CATALOG_EXPRESSION = exp.func("current_database") SUPPORTS_CREATE_DROP_CATALOG = True + SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS = True SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["DATABASE", "SCHEMA", "TABLE"] SCHEMA_DIFFER_KWARGS = { "parameterized_type_defaults": { @@ -669,3 +670,18 @@ def close(self) -> t.Any: self._connection_pool.set_attribute(self.SNOWPARK, None) return super().close() + + def get_table_last_modified_ts(self, table_names: t.List[TableName]) -> t.List[int]: + from sqlmesh.utils.date import to_timestamp + + num_tables = len(table_names) + + query = "SELECT LAST_ALTERED FROM INFORMATION_SCHEMA.TABLES WHERE" + for i, table_name in enumerate(table_names): + table = exp.to_table(table_name) + query += f"""(TABLE_NAME = '{table.name}' AND TABLE_SCHEMA = '{table.db}' AND TABLE_CATALOG = '{table.catalog}')""" + if i < num_tables - 1: + query += " OR " + + result = self.fetchall(query) + return [to_timestamp(row[0]) for row in result] diff --git a/sqlmesh/core/plan/evaluator.py b/sqlmesh/core/plan/evaluator.py index 03ecb770bf..f2f432a97e 100644 --- a/sqlmesh/core/plan/evaluator.py +++ b/sqlmesh/core/plan/evaluator.py @@ -258,6 +258,7 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla allow_additive_snapshots=plan.allow_additive_models, selected_snapshot_ids=stage.selected_snapshot_ids, selected_models=plan.selected_models, + is_restatement=bool(plan.restatements), ) if errors: raise PlanError("Plan application failed.") diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index af4d72b165..7e27205fc6 100644 --- a/sqlmesh/core/scheduler.py +++ 
b/sqlmesh/core/scheduler.py @@ -251,7 +251,9 @@ def evaluate( **kwargs, ) - self.state_sync.add_interval(snapshot, start, end, is_dev=not is_deployable) + self.state_sync.add_interval( + snapshot, start, end, is_dev=not is_deployable, last_altered_ts=now_timestamp() + ) return audit_results def run( @@ -335,6 +337,7 @@ def batch_intervals( deployability_index: t.Optional[DeployabilityIndex], environment_naming_info: EnvironmentNamingInfo, dag: t.Optional[DAG[SnapshotId]] = None, + is_restatement: bool = False, ) -> t.Dict[Snapshot, Intervals]: dag = dag or snapshots_to_dag(merged_intervals) @@ -367,6 +370,7 @@ def batch_intervals( deployability_index, default_dialect=adapter.dialect, default_catalog=self.default_catalog, + is_restatement=is_restatement, ) intervals = self._check_ready_intervals( @@ -422,6 +426,7 @@ def run_merged_intervals( run_environment_statements: bool = False, audit_only: bool = False, auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}, + is_restatement: bool = False, ) -> t.Tuple[t.List[NodeExecutionFailedError[SchedulingUnit]], t.List[SchedulingUnit]]: """Runs precomputed batches of missing intervals. 
@@ -455,9 +460,12 @@ def run_merged_intervals( snapshot_dag = full_dag.subdag(*selected_snapshot_ids_set) batched_intervals = self.batch_intervals( - merged_intervals, deployability_index, environment_naming_info, dag=snapshot_dag + merged_intervals, + deployability_index, + environment_naming_info, + dag=snapshot_dag, + is_restatement=is_restatement, ) - self.console.start_evaluation_progress( batched_intervals, environment_naming_info, @@ -956,6 +964,7 @@ def _check_ready_intervals( python_env=signals.python_env, dialect=snapshot.model.dialect, path=snapshot.model._path, + snapshot=snapshot, kwargs=kwargs, ) except SQLMeshError as e: diff --git a/sqlmesh/core/signal.py b/sqlmesh/core/signal.py index d9ee670922..52e6c59c8d 100644 --- a/sqlmesh/core/signal.py +++ b/sqlmesh/core/signal.py @@ -1,8 +1,14 @@ from __future__ import annotations - +import typing as t from sqlmesh.utils import UniqueKeyDict, registry_decorator +if t.TYPE_CHECKING: + from sqlmesh.core.context import ExecutionContext + from sqlmesh.core.snapshot.definition import Snapshot + from sqlmesh.utils.date import DatetimeRanges + from sqlmesh.core.snapshot.definition import DeployabilityIndex + class signal(registry_decorator): """Specifies a function which intervals are ready from a list of scheduled intervals. 
@@ -33,3 +39,39 @@ class signal(registry_decorator): SignalRegistry = UniqueKeyDict[str, signal] + + +@signal() +def freshness(batch: DatetimeRanges, snapshot: Snapshot, context: ExecutionContext) -> bool: + adapter = context.engine_adapter + if context.is_restatement or not adapter.SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS: + return True + + deployability_index = context.deployability_index or DeployabilityIndex.all_deployable() + + last_altered_ts = ( + snapshot.last_altered_ts + if deployability_index.is_deployable(snapshot) + else snapshot.dev_last_altered_ts + ) + if not last_altered_ts: + return True + + parent_snapshots = {context.snapshots[p.name] for p in snapshot.parents} + if len(parent_snapshots) != len(snapshot.node.depends_on) or not all( + p.is_external for p in parent_snapshots + ): + # The mismatch can happen if e.g. an external model is not registered in the project + return True + + # Finding new data means that the upstream dependencies have been altered + # since the last time the model was evaluated + upstream_dep_has_new_data = any( + upstream_last_altered_ts > last_altered_ts + for upstream_last_altered_ts in adapter.get_table_last_modified_ts( + [p.name for p in parent_snapshots] + ) + ) + + # Returning True is a no-op, returning False nullifies the batch so the model will not be evaluated.
+ return upstream_dep_has_new_data diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index 600d84fe83..23ab0b21db 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -185,6 +185,8 @@ class SnapshotIntervals(PydanticModel): intervals: Intervals = [] dev_intervals: Intervals = [] pending_restatement_intervals: Intervals = [] + last_altered_ts: t.Optional[int] = None + dev_last_altered_ts: t.Optional[int] = None @property def snapshot_id(self) -> t.Optional[SnapshotId]: @@ -205,6 +207,12 @@ def add_dev_interval(self, start: int, end: int) -> None: def add_pending_restatement_interval(self, start: int, end: int) -> None: self._add_interval(start, end, "pending_restatement_intervals") + def update_last_altered_ts(self, last_altered_ts: t.Optional[int]) -> None: + self._update_last_altered_ts(last_altered_ts, "last_altered_ts") + + def update_dev_last_altered_ts(self, last_altered_ts: t.Optional[int]) -> None: + self._update_last_altered_ts(last_altered_ts, "dev_last_altered_ts") + def remove_interval(self, start: int, end: int) -> None: self._remove_interval(start, end, "intervals") @@ -224,6 +232,13 @@ def _add_interval(self, start: int, end: int, interval_attr: str) -> None: target_intervals = merge_intervals([*target_intervals, (start, end)]) setattr(self, interval_attr, target_intervals) + def _update_last_altered_ts( + self, last_altered_ts: t.Optional[int], last_altered_attr: str + ) -> None: + if last_altered_ts: + existing_last_altered_ts = getattr(self, last_altered_attr) + setattr(self, last_altered_attr, max(existing_last_altered_ts or 0, last_altered_ts)) + def _remove_interval(self, start: int, end: int, interval_attr: str) -> None: target_intervals = getattr(self, interval_attr) target_intervals = remove_interval(target_intervals, start, end) @@ -713,6 +728,10 @@ class Snapshot(PydanticModel, SnapshotInfoMixin): dev_table_suffix: str = "dev" table_naming_convention: 
TableNamingConvention = TableNamingConvention.default forward_only: bool = False + # Physical table last modified timestamp, not to be confused with the "updated_ts" field + # which is for the snapshot record itself + last_altered_ts: t.Optional[int] = None + dev_last_altered_ts: t.Optional[int] = None @field_validator("ttl") @classmethod @@ -751,6 +770,7 @@ def hydrate_with_intervals_by_version( ) for interval in snapshot_intervals: snapshot.merge_intervals(interval) + result.append(snapshot) return result @@ -957,12 +977,20 @@ def merge_intervals(self, other: t.Union[Snapshot, SnapshotIntervals]) -> None: if not apply_effective_from or end <= effective_from_ts: self.add_interval(start, end) + if other.last_altered_ts: + self.last_altered_ts = max(self.last_altered_ts or 0, other.last_altered_ts) + if self.dev_version == other.dev_version: # Merge dev intervals if the dev versions match which would mean # that this and the other snapshot are pointing to the same dev table. for start, end in other.dev_intervals: self.add_interval(start, end, is_dev=True) + if other.dev_last_altered_ts: + self.dev_last_altered_ts = max( + self.dev_last_altered_ts or 0, other.dev_last_altered_ts + ) + self.pending_restatement_intervals = merge_intervals( [*self.pending_restatement_intervals, *other.pending_restatement_intervals] ) @@ -1081,6 +1109,7 @@ def check_ready_intervals( python_env=signals.python_env, dialect=self.model.dialect, path=self.model._path, + snapshot=self, kwargs=kwargs, ) except SQLMeshError as e: @@ -2421,6 +2450,7 @@ def check_ready_intervals( python_env: t.Dict[str, Executable], dialect: DialectType = None, path: t.Optional[Path] = None, + snapshot: t.Optional[Snapshot] = None, kwargs: t.Optional[t.Dict] = None, ) -> Intervals: checked_intervals: Intervals = [] @@ -2436,6 +2466,7 @@ def check_ready_intervals( provided_args=(batch,), provided_kwargs=(kwargs or {}), context=context, + snapshot=snapshot, ) except Exception as ex: raise 
SignalEvalError(format_evaluated_code_exception(ex, python_env)) diff --git a/sqlmesh/core/state_sync/base.py b/sqlmesh/core/state_sync/base.py index 450d6f7408..2f8a68dd4a 100644 --- a/sqlmesh/core/state_sync/base.py +++ b/sqlmesh/core/state_sync/base.py @@ -496,6 +496,7 @@ def add_interval( start: TimeLike, end: TimeLike, is_dev: bool = False, + last_altered_ts: t.Optional[int] = None, ) -> None: """Add an interval to a snapshot and sync it to the store. @@ -504,6 +505,7 @@ def add_interval( start: The start of the interval to add. end: The end of the interval to add. is_dev: Indicates whether the given interval is being added while in development mode + last_altered_ts: The timestamp of the last modification of the physical table """ start_ts, end_ts = snapshot.inclusive_exclusive(start, end, strict=False, expand=False) if not snapshot.version: @@ -516,6 +518,8 @@ def add_interval( dev_version=snapshot.dev_version, intervals=intervals if not is_dev else [], dev_intervals=intervals if is_dev else [], + last_altered_ts=last_altered_ts if not is_dev else None, + dev_last_altered_ts=last_altered_ts if is_dev else None, ) self.add_snapshots_intervals([snapshot_intervals]) diff --git a/sqlmesh/core/state_sync/db/facade.py b/sqlmesh/core/state_sync/db/facade.py index 29fc9f1740..3c23ef339c 100644 --- a/sqlmesh/core/state_sync/db/facade.py +++ b/sqlmesh/core/state_sync/db/facade.py @@ -381,8 +381,9 @@ def add_interval( start: TimeLike, end: TimeLike, is_dev: bool = False, + last_altered_ts: t.Optional[int] = None, ) -> None: - super().add_interval(snapshot, start, end, is_dev) + super().add_interval(snapshot, start, end, is_dev, last_altered_ts) @transactional() def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotIntervals]) -> None: diff --git a/sqlmesh/core/state_sync/db/interval.py b/sqlmesh/core/state_sync/db/interval.py index b15ad2d57b..8ccdc58fa0 100644 --- a/sqlmesh/core/state_sync/db/interval.py +++ b/sqlmesh/core/state_sync/db/interval.py 
@@ -60,6 +60,7 @@ def __init__( "is_removed": exp.DataType.build("boolean"), "is_compacted": exp.DataType.build("boolean"), "is_pending_restatement": exp.DataType.build("boolean"), + "last_altered_ts": exp.DataType.build("bigint"), } def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotIntervals]) -> None: @@ -215,13 +216,23 @@ def _push_snapshot_intervals( for start_ts, end_ts in snapshot.intervals: new_intervals.append( _interval_to_df( - snapshot, start_ts, end_ts, is_dev=False, is_compacted=is_compacted + snapshot, + start_ts, + end_ts, + is_dev=False, + is_compacted=is_compacted, + last_altered_ts=snapshot.last_altered_ts, ) ) for start_ts, end_ts in snapshot.dev_intervals: new_intervals.append( _interval_to_df( - snapshot, start_ts, end_ts, is_dev=True, is_compacted=is_compacted + snapshot, + start_ts, + end_ts, + is_dev=True, + is_compacted=is_compacted, + last_altered_ts=snapshot.dev_last_altered_ts, ) ) @@ -236,6 +247,7 @@ def _push_snapshot_intervals( is_dev=False, is_compacted=is_compacted, is_pending_restatement=True, + last_altered_ts=snapshot.last_altered_ts, ) ) @@ -284,6 +296,7 @@ def _get_snapshot_intervals( is_dev, is_removed, is_pending_restatement, + last_altered_ts, ) in rows: interval_ids.add(interval_id) merge_key = (name, version, dev_version, identifier) @@ -318,8 +331,10 @@ def _get_snapshot_intervals( else: if is_dev: intervals[merge_key].add_dev_interval(start, end) + intervals[merge_key].update_dev_last_altered_ts(last_altered_ts) else: intervals[merge_key].add_interval(start, end) + intervals[merge_key].update_last_altered_ts(last_altered_ts) # Remove all pending restatement intervals recorded before the current interval has been added intervals[ pending_restatement_interval_merge_key @@ -340,6 +355,7 @@ def _get_snapshot_intervals_query(self, uncompacted_only: bool) -> exp.Select: "is_dev", "is_removed", "is_pending_restatement", + "last_altered_ts", ) .from_(exp.to_table(self.intervals_table).as_("intervals")) 
.order_by( @@ -460,6 +476,7 @@ def _interval_to_df( is_removed: bool = False, is_compacted: bool = False, is_pending_restatement: bool = False, + last_altered_ts: t.Optional[int] = None, ) -> t.Dict[str, t.Any]: return { "id": random_id(), @@ -474,4 +491,5 @@ def _interval_to_df( "is_removed": is_removed, "is_compacted": is_compacted, "is_pending_restatement": is_pending_restatement, + "last_altered_ts": last_altered_ts, } diff --git a/sqlmesh/migrations/v0099_add_last_altered_to_intervals.py b/sqlmesh/migrations/v0099_add_last_altered_to_intervals.py new file mode 100644 index 0000000000..1a119a338d --- /dev/null +++ b/sqlmesh/migrations/v0099_add_last_altered_to_intervals.py @@ -0,0 +1,27 @@ +"""Add the last_altered_ts column to the intervals table.""" + +from sqlglot import exp + + +def migrate_schemas(state_sync, **kwargs):  # type: ignore + engine_adapter = state_sync.engine_adapter + schema = state_sync.schema + intervals_table = "_intervals" + if schema: + intervals_table = f"{schema}.{intervals_table}" + + alter_table_exp = exp.Alter( + this=exp.to_table(intervals_table), + kind="TABLE", + actions=[ + exp.ColumnDef( + this=exp.to_column("last_altered_ts"), + kind=exp.DataType.build("BIGINT", dialect=engine_adapter.dialect), + ) + ], + ) + engine_adapter.execute(alter_table_exp) + + +def migrate_rows(state_sync, **kwargs):  # type: ignore + pass diff --git a/tests/core/engine_adapter/integration/test_integration.py b/tests/core/engine_adapter/integration/test_integration.py index 5a708e1e4c..5190d26e98 100644 --- a/tests/core/engine_adapter/integration/test_integration.py +++ b/tests/core/engine_adapter/integration/test_integration.py @@ -10,6 +10,11 @@ from unittest import mock from unittest.mock import patch import logging +from IPython.utils.capture import capture_output + + +import time_machine +from pytest_mock.plugin import MockerFixture import numpy as np  # noqa: TID253 import pandas as pd  # noqa: TID253 @@ -45,6 +50,7 @@ TEST_SCHEMA, wait_until, ) +from 
tests.utils.test_helpers import use_terminal_console DATA_TYPE = exp.DataType.Type VARCHAR_100 = exp.DataType.build("varchar(100)") @@ -3774,7 +3780,7 @@ def _set_config(gateway: str, config: Config) -> None: ] -def test_materialized_view_evaluation(ctx: TestContext, mocker: MockerFixture): +def test_materialized_view_evaluation(ctx: TestContext): adapter = ctx.engine_adapter dialect = ctx.dialect @@ -3834,3 +3840,153 @@ def _assert_mview_value(value: int): assert any("Replacing view" in call[0][0] for call in mock_logger.call_args_list) _assert_mview_value(value=2) + + +@use_terminal_console +def test_external_model_freshness(ctx: TestContext, mocker: MockerFixture, tmp_path: pathlib.Path): + adapter = ctx.engine_adapter + if not adapter.SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS: + pytest.skip("This test only runs for engines that support metadata-based freshness") + + def _assert_snapshot_last_altered_ts( + context: Context, + snapshot_id: str, + last_altered_ts: datetime, + dev_last_altered_ts: t.Optional[datetime] = None, + ): + from sqlmesh.utils.date import to_datetime + + snapshot = context.state_sync.get_snapshots([snapshot_id])[snapshot_id] + + assert to_datetime(snapshot.last_altered_ts).replace( + microsecond=0 + ) == last_altered_ts.replace(microsecond=0) + + if dev_last_altered_ts: + assert to_datetime(snapshot.dev_last_altered_ts).replace( + microsecond=0 + ) == dev_last_altered_ts.replace(microsecond=0) + + import sqlmesh + + spy = mocker.spy(sqlmesh.core.snapshot.evaluator.SnapshotEvaluator, "evaluate") + + def _assert_model_evaluation(lambda_func, was_evaluated, day_delta=0): + spy.reset_mock() + timestamp = now(minute_floor=False) + timedelta(days=day_delta) + with time_machine.travel(timestamp, tick=False): + with capture_output() as output: + plan_or_run_result = lambda_func() + + evaluate_function_called = spy.call_count == 1 + signal_was_checked = "Checking signals for" in output.stdout + + assert signal_was_checked + if was_evaluated: + assert 
"All ready" in output.stdout + assert evaluate_function_called + else: + assert "None ready" in output.stdout + assert not evaluate_function_called + + return timestamp, plan_or_run_result + + # Create & initialize schema + schema = ctx.add_test_suffix(TEST_SCHEMA) + ctx._schemas.append(schema) + adapter.create_schema(schema) + + # Create & initialize external models + external_table1 = f"{schema}.external_table1" + external_table2 = f"{schema}.external_table2" + + external_models_yaml = tmp_path / "external_models.yaml" + external_models_yaml.write_text(f""" +- name: {external_table1} + columns: + col1: int + +- name: {external_table2} + columns: + col2: int +""") + + adapter.execute( + f"CREATE TABLE {external_table1} AS (SELECT 1 AS col1)", quote_identifiers=False + ) + adapter.execute( + f"CREATE TABLE {external_table2} AS (SELECT 2 AS col2)", quote_identifiers=False + ) + + # Create model that depends on external models + model_name = f"{schema}.new_model" + model_path = tmp_path / "models" / "new_model.sql" + (tmp_path / "models").mkdir(parents=True, exist_ok=True) + model_path.write_text(f""" + MODEL ( + name {model_name}, + start '2024-01-01', + kind FULL, + signals ( + freshness(), + ) + ); + + SELECT col1 * col2 AS col FROM {external_table1}, {external_table2}; + """) + + # Initialize context + def _set_config(gateway: str, config: Config) -> None: + config.model_defaults.dialect = ctx.dialect + + context = ctx.create_context(path=tmp_path, config_mutator=_set_config) + + # Case 1: Model is evaluated for the first plan + prod_plan_ts, prod_plan = _assert_model_evaluation( + lambda: context.plan(auto_apply=True, no_prompts=True), was_evaluated=True + ) + + prod_snapshot_id = next(iter(prod_plan.context_diff.new_snapshots)) + _assert_snapshot_last_altered_ts(context, prod_snapshot_id, last_altered_ts=prod_plan_ts) + + # Case 2: Model is NOT evaluated on run if external models are not fresh + _assert_model_evaluation(lambda: context.run(), 
was_evaluated=False, day_delta=1) + + # Case 3: Differentiate last_altered_ts between snapshots with shared version + # For instance, creating a FORWARD_ONLY change in dev (reusing the version but creating a dev preview) should not cause + # any side effects to the prod snapshot's last_altered_ts hydration + model_path.write_text(model_path.read_text().replace("col1 * col2", "col1 + col2")) + context.load() + dev_plan_ts = now(minute_floor=False) + timedelta(days=2) + with time_machine.travel(dev_plan_ts, tick=False): + dev_plan = context.plan( + environment="dev", forward_only=True, auto_apply=True, no_prompts=True + ) + + context.state_sync.clear_cache() + dev_snapshot_id = next(iter(dev_plan.context_diff.new_snapshots)) + _assert_snapshot_last_altered_ts( + context, + dev_snapshot_id, + last_altered_ts=prod_plan_ts, + dev_last_altered_ts=dev_plan_ts, + ) + _assert_snapshot_last_altered_ts(context, prod_snapshot_id, last_altered_ts=prod_plan_ts) + + # Case 4: Model is evaluated on run if any external model is fresh + adapter.execute(f"INSERT INTO {external_table2} (col2) VALUES (3)", quote_identifiers=False) + _assert_model_evaluation(lambda: context.run(), was_evaluated=True, day_delta=2) + + # Case 5: Model is evaluated if changed (case 3) even if the external model is not fresh + model_path.write_text(model_path.read_text().replace("col1 + col2", "col1 * col2 * 5")) + context.load() + _assert_model_evaluation( + lambda: context.plan(auto_apply=True, no_prompts=True), was_evaluated=True, day_delta=3 + ) + + # Case 6: Model is evaluated on a restatement plan even if the external model is not fresh + _assert_model_evaluation( + lambda: context.plan(restate_models=[model_name], auto_apply=True, no_prompts=True), + was_evaluated=True, + day_delta=4, + )