From 00a03c2900a144b25bd8d22100d67fa3b18a52d2 Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+themisvaltinos@users.noreply.github.com> Date: Tue, 26 Aug 2025 18:02:13 +0300 Subject: [PATCH] Fix: Move before all statements execution before snapshot creation logic --- sqlmesh/core/scheduler.py | 22 ++++----- tests/core/test_scheduler.py | 87 ++++++++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 11 deletions(-) diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 2cbf769ea2..7a653877ae 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -457,17 +457,6 @@ def run_merged_intervals( audit_only=audit_only, ) - snapshots_to_create = { - s.snapshot_id - for s in self.snapshot_evaluator.get_snapshots_to_create( - selected_snapshots, deployability_index - ) - } - - dag = self._dag( - batched_intervals, snapshot_dag=snapshot_dag, snapshots_to_create=snapshots_to_create - ) - if run_environment_statements: environment_statements = self.state_sync.get_environment_statements( environment_naming_info.name @@ -484,6 +473,17 @@ def run_merged_intervals( execution_time=execution_time, ) + snapshots_to_create = { + s.snapshot_id + for s in self.snapshot_evaluator.get_snapshots_to_create( + selected_snapshots, deployability_index + ) + } + + dag = self._dag( + batched_intervals, snapshot_dag=snapshot_dag, snapshots_to_create=snapshots_to_create + ) + def run_node(node: SchedulingUnit) -> None: if circuit_breaker and circuit_breaker(): raise CircuitBreakerError() diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index b74aa3480e..b894f60f58 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -7,6 +7,7 @@ from sqlmesh.core.context import Context, ExecutionContext from sqlmesh.core.environment import EnvironmentNamingInfo +from sqlmesh.core.macros import RuntimeStage from sqlmesh.core.model import load_sql_based_model from sqlmesh.core.model.definition import AuditResult, SqlModel from sqlmesh.core.model.kind import ( @@ -932,3 +933,89 @@ def test_scd_type_2_batch_size( # Verify batches match expectations assert batches == expected_batches + + +def test_before_all_environment_statements_called_first(mocker: MockerFixture, make_snapshot): + model = SqlModel( + name="test.model_items", + query=parse_one("SELECT id, ds FROM raw.items"), + kind=IncrementalByTimeRangeKind(time_column=TimeColumn(column="ds")), + ) + snapshot = make_snapshot(model) + + # to track the order of calls + call_order = [] + + mock_state_sync = mocker.MagicMock() + mock_state_sync.get_environment_statements.return_value = [ + ("CREATE TABLE IF NOT EXISTS test_table (id INT)", RuntimeStage.BEFORE_ALL) + ] + + def record_get_environment_statements(*args, **kwargs): + call_order.append("get_environment_statements") + return mock_state_sync.get_environment_statements.return_value + + mock_state_sync.get_environment_statements.side_effect = record_get_environment_statements + + mock_snapshot_evaluator = mocker.MagicMock() + mock_adapter = mocker.MagicMock() + mock_snapshot_evaluator.adapter = mock_adapter + + def record_get_snapshots_to_create(*args, **kwargs): + call_order.append("get_snapshots_to_create") + return [] + + mock_snapshot_evaluator.get_snapshots_to_create.side_effect = record_get_snapshots_to_create + + mock_execute_env_statements = mocker.patch( + "sqlmesh.core.scheduler.execute_environment_statements" + ) + + def record_execute_environment_statements(*args, **kwargs): + call_order.append("execute_environment_statements") + + mock_execute_env_statements.side_effect = record_execute_environment_statements + + scheduler = Scheduler( + snapshots=[snapshot], + snapshot_evaluator=mock_snapshot_evaluator, + state_sync=mock_state_sync, + default_catalog=None, + ) + merged_intervals = { + snapshot: [ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + ], + } + + deployability_index = DeployabilityIndex.create([snapshot]) + environment_naming_info = EnvironmentNamingInfo(name="test_env") + + scheduler.run_merged_intervals( + merged_intervals=merged_intervals, + deployability_index=deployability_index, + environment_naming_info=environment_naming_info, + run_environment_statements=True, + ) + + mock_state_sync.get_environment_statements.assert_called_once_with("test_env") + mock_snapshot_evaluator.get_snapshots_to_create.assert_called_once() + + # execute_environment_statements is called twice + assert mock_execute_env_statements.call_count == 2 + + # first for before all and second for after all + first_call = mock_execute_env_statements.call_args_list[0] + assert first_call.kwargs["runtime_stage"] == RuntimeStage.BEFORE_ALL + second_call = mock_execute_env_statements.call_args_list[1] + assert second_call.kwargs["runtime_stage"] == RuntimeStage.AFTER_ALL + + assert "get_environment_statements" in call_order + assert "execute_environment_statements" in call_order + assert "get_snapshots_to_create" in call_order + + # Verify the before all environment statements are called first before get_snapshots_to_create + env_statements_idx = call_order.index("get_environment_statements") + execute_env_idx = call_order.index("execute_environment_statements") + snapshots_to_create_idx = call_order.index("get_snapshots_to_create") + assert env_statements_idx < execute_env_idx < snapshots_to_create_idx