Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions sqlmesh/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,17 +457,6 @@ def run_merged_intervals(
audit_only=audit_only,
)

snapshots_to_create = {
s.snapshot_id
for s in self.snapshot_evaluator.get_snapshots_to_create(
selected_snapshots, deployability_index
)
}

dag = self._dag(
batched_intervals, snapshot_dag=snapshot_dag, snapshots_to_create=snapshots_to_create
)

if run_environment_statements:
environment_statements = self.state_sync.get_environment_statements(
environment_naming_info.name
Expand All @@ -484,6 +473,17 @@ def run_merged_intervals(
execution_time=execution_time,
)

snapshots_to_create = {
s.snapshot_id
for s in self.snapshot_evaluator.get_snapshots_to_create(
selected_snapshots, deployability_index
)
}

dag = self._dag(
batched_intervals, snapshot_dag=snapshot_dag, snapshots_to_create=snapshots_to_create
)

def run_node(node: SchedulingUnit) -> None:
if circuit_breaker and circuit_breaker():
raise CircuitBreakerError()
Expand Down
87 changes: 87 additions & 0 deletions tests/core/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from sqlmesh.core.context import Context, ExecutionContext
from sqlmesh.core.environment import EnvironmentNamingInfo
from sqlmesh.core.macros import RuntimeStage
from sqlmesh.core.model import load_sql_based_model
from sqlmesh.core.model.definition import AuditResult, SqlModel
from sqlmesh.core.model.kind import (
Expand Down Expand Up @@ -932,3 +933,89 @@ def test_scd_type_2_batch_size(

# Verify batches match expectations
assert batches == expected_batches


def test_before_all_environment_statements_called_first(mocker: MockerFixture, make_snapshot):
model = SqlModel(
name="test.model_items",
query=parse_one("SELECT id, ds FROM raw.items"),
kind=IncrementalByTimeRangeKind(time_column=TimeColumn(column="ds")),
)
snapshot = make_snapshot(model)

# to track the order of calls
call_order = []

mock_state_sync = mocker.MagicMock()
mock_state_sync.get_environment_statements.return_value = [
("CREATE TABLE IF NOT EXISTS test_table (id INT)", RuntimeStage.BEFORE_ALL)
]

def record_get_environment_statements(*args, **kwargs):
call_order.append("get_environment_statements")
return mock_state_sync.get_environment_statements.return_value

mock_state_sync.get_environment_statements.side_effect = record_get_environment_statements

mock_snapshot_evaluator = mocker.MagicMock()
mock_adapter = mocker.MagicMock()
mock_snapshot_evaluator.adapter = mock_adapter

def record_get_snapshots_to_create(*args, **kwargs):
call_order.append("get_snapshots_to_create")
return []

mock_snapshot_evaluator.get_snapshots_to_create.side_effect = record_get_snapshots_to_create

mock_execute_env_statements = mocker.patch(
"sqlmesh.core.scheduler.execute_environment_statements"
)

def record_execute_environment_statements(*args, **kwargs):
call_order.append("execute_environment_statements")

mock_execute_env_statements.side_effect = record_execute_environment_statements

scheduler = Scheduler(
snapshots=[snapshot],
snapshot_evaluator=mock_snapshot_evaluator,
state_sync=mock_state_sync,
default_catalog=None,
)
merged_intervals = {
snapshot: [
(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")),
],
}

deployability_index = DeployabilityIndex.create([snapshot])
environment_naming_info = EnvironmentNamingInfo(name="test_env")

scheduler.run_merged_intervals(
merged_intervals=merged_intervals,
deployability_index=deployability_index,
environment_naming_info=environment_naming_info,
run_environment_statements=True,
)

mock_state_sync.get_environment_statements.assert_called_once_with("test_env")
mock_snapshot_evaluator.get_snapshots_to_create.assert_called_once()

# execute_environment_statements is called twice
assert mock_execute_env_statements.call_count == 2

# first for before all and second for after all
first_call = mock_execute_env_statements.call_args_list[0]
assert first_call.kwargs["runtime_stage"] == RuntimeStage.BEFORE_ALL
second_call = mock_execute_env_statements.call_args_list[1]
assert second_call.kwargs["runtime_stage"] == RuntimeStage.AFTER_ALL

assert "get_environment_statements" in call_order
assert "execute_environment_statements" in call_order
assert "get_snapshots_to_create" in call_order

# Verify the before all environment statements are called first before get_snapshots_to_create
env_statements_idx = call_order.index("get_environment_statements")
execute_env_idx = call_order.index("execute_environment_statements")
snapshots_to_create_idx = call_order.index("get_snapshots_to_create")
assert env_statements_idx < execute_env_idx < snapshots_to_create_idx