|
7 | 7 |
|
8 | 8 | from sqlmesh.core.context import Context, ExecutionContext |
9 | 9 | from sqlmesh.core.environment import EnvironmentNamingInfo |
| 10 | +from sqlmesh.core.macros import RuntimeStage |
10 | 11 | from sqlmesh.core.model import load_sql_based_model |
11 | 12 | from sqlmesh.core.model.definition import AuditResult, SqlModel |
12 | 13 | from sqlmesh.core.model.kind import ( |
@@ -932,3 +933,89 @@ def test_scd_type_2_batch_size( |
932 | 933 |
|
933 | 934 | # Verify batches match expectations |
934 | 935 | assert batches == expected_batches |
| 936 | + |
| 937 | + |
| 938 | +def test_before_all_environment_statements_called_first(mocker: MockerFixture, make_snapshot): |
| 939 | + model = SqlModel( |
| 940 | + name="test.model_items", |
| 941 | + query=parse_one("SELECT id, ds FROM raw.items"), |
| 942 | + kind=IncrementalByTimeRangeKind(time_column=TimeColumn(column="ds")), |
| 943 | + ) |
| 944 | + snapshot = make_snapshot(model) |
| 945 | + |
| 946 | + # to track the order of calls |
| 947 | + call_order = [] |
| 948 | + |
| 949 | + mock_state_sync = mocker.MagicMock() |
| 950 | + mock_state_sync.get_environment_statements.return_value = [ |
| 951 | + ("CREATE TABLE IF NOT EXISTS test_table (id INT)", RuntimeStage.BEFORE_ALL) |
| 952 | + ] |
| 953 | + |
| 954 | + def record_get_environment_statements(*args, **kwargs): |
| 955 | + call_order.append("get_environment_statements") |
| 956 | + return mock_state_sync.get_environment_statements.return_value |
| 957 | + |
| 958 | + mock_state_sync.get_environment_statements.side_effect = record_get_environment_statements |
| 959 | + |
| 960 | + mock_snapshot_evaluator = mocker.MagicMock() |
| 961 | + mock_adapter = mocker.MagicMock() |
| 962 | + mock_snapshot_evaluator.adapter = mock_adapter |
| 963 | + |
| 964 | + def record_get_snapshots_to_create(*args, **kwargs): |
| 965 | + call_order.append("get_snapshots_to_create") |
| 966 | + return [] |
| 967 | + |
| 968 | + mock_snapshot_evaluator.get_snapshots_to_create.side_effect = record_get_snapshots_to_create |
| 969 | + |
| 970 | + mock_execute_env_statements = mocker.patch( |
| 971 | + "sqlmesh.core.scheduler.execute_environment_statements" |
| 972 | + ) |
| 973 | + |
| 974 | + def record_execute_environment_statements(*args, **kwargs): |
| 975 | + call_order.append("execute_environment_statements") |
| 976 | + |
| 977 | + mock_execute_env_statements.side_effect = record_execute_environment_statements |
| 978 | + |
| 979 | + scheduler = Scheduler( |
| 980 | + snapshots=[snapshot], |
| 981 | + snapshot_evaluator=mock_snapshot_evaluator, |
| 982 | + state_sync=mock_state_sync, |
| 983 | + default_catalog=None, |
| 984 | + ) |
| 985 | + merged_intervals = { |
| 986 | + snapshot: [ |
| 987 | + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), |
| 988 | + ], |
| 989 | + } |
| 990 | + |
| 991 | + deployability_index = DeployabilityIndex.create([snapshot]) |
| 992 | + environment_naming_info = EnvironmentNamingInfo(name="test_env") |
| 993 | + |
| 994 | + scheduler.run_merged_intervals( |
| 995 | + merged_intervals=merged_intervals, |
| 996 | + deployability_index=deployability_index, |
| 997 | + environment_naming_info=environment_naming_info, |
| 998 | + run_environment_statements=True, |
| 999 | + ) |
| 1000 | + |
| 1001 | + mock_state_sync.get_environment_statements.assert_called_once_with("test_env") |
| 1002 | + mock_snapshot_evaluator.get_snapshots_to_create.assert_called_once() |
| 1003 | + |
| 1004 | + # execute_environment_statements is called twice |
| 1005 | + assert mock_execute_env_statements.call_count == 2 |
| 1006 | + |
| 1007 | + # first for before all and second for after all |
| 1008 | + first_call = mock_execute_env_statements.call_args_list[0] |
| 1009 | + assert first_call.kwargs["runtime_stage"] == RuntimeStage.BEFORE_ALL |
| 1010 | + second_call = mock_execute_env_statements.call_args_list[1] |
| 1011 | + assert second_call.kwargs["runtime_stage"] == RuntimeStage.AFTER_ALL |
| 1012 | + |
| 1013 | + assert "get_environment_statements" in call_order |
| 1014 | + assert "execute_environment_statements" in call_order |
| 1015 | + assert "get_snapshots_to_create" in call_order |
| 1016 | + |
| 1017 | + # Verify the before all environment statements are called first before get_snapshots_to_create |
| 1018 | + env_statements_idx = call_order.index("get_environment_statements") |
| 1019 | + execute_env_idx = call_order.index("execute_environment_statements") |
| 1020 | + snapshots_to_create_idx = call_order.index("get_snapshots_to_create") |
| 1021 | + assert env_statements_idx < execute_env_idx < snapshots_to_create_idx |
0 commit comments