Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 42 additions & 15 deletions sqlmesh/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,13 @@ def run_merged_intervals(
if not selected_snapshots:
selected_snapshots = list(merged_intervals)

snapshot_dag = snapshots_to_dag(selected_snapshots)
# Build the full DAG from all snapshots to preserve transitive dependencies
full_dag = snapshots_to_dag(self.snapshots.values())

# Create a subdag that includes the selected snapshots and all their upstream dependencies
# This ensures that transitive dependencies are preserved even when intermediate nodes are not selected
selected_snapshot_ids_set = {s.snapshot_id for s in selected_snapshots}
snapshot_dag = full_dag.subdag(*selected_snapshot_ids_set)

batched_intervals = self.batch_intervals(
merged_intervals, deployability_index, environment_naming_info, dag=snapshot_dag
Expand Down Expand Up @@ -642,20 +648,11 @@ def _dag(
upstream_dependencies: t.List[SchedulingUnit] = []

for p_sid in snapshot.parents:
if p_sid in self.snapshots:
p_intervals = intervals_per_snapshot.get(p_sid.name, [])

if not p_intervals and p_sid in original_snapshots_to_create:
upstream_dependencies.append(CreateNode(snapshot_name=p_sid.name))
elif len(p_intervals) > 1:
upstream_dependencies.append(DummyNode(snapshot_name=p_sid.name))
else:
for i, interval in enumerate(p_intervals):
upstream_dependencies.append(
EvaluateNode(
snapshot_name=p_sid.name, interval=interval, batch_index=i
)
)
upstream_dependencies.extend(
self._find_upstream_dependencies(
p_sid, intervals_per_snapshot, original_snapshots_to_create
)
)

batch_concurrency = snapshot.node.batch_concurrency
batch_size = snapshot.node.batch_size
Expand Down Expand Up @@ -699,6 +696,36 @@ def _dag(
)
return dag

def _find_upstream_dependencies(
    self,
    parent_sid: SnapshotId,
    intervals_per_snapshot: t.Dict[str, Intervals],
    snapshots_to_create: t.Set[SnapshotId],
) -> t.List[SchedulingUnit]:
    """Resolve the scheduling node(s) that an upstream snapshot contributes.

    A parent with pending intervals yields an evaluate (or dummy) node; a
    parent that only needs its table created yields a create node; a parent
    with neither is treated as a pass-through link, and its own parents are
    resolved recursively so transitive dependencies are preserved.
    """
    # Parents outside this scheduler's snapshot set contribute nothing.
    if parent_sid not in self.snapshots:
        return []

    pending_intervals = intervals_per_snapshot.get(parent_sid.name, [])

    if not pending_intervals:
        if parent_sid in snapshots_to_create:
            return [CreateNode(snapshot_name=parent_sid.name)]
        # No intervals and no creation needed: this snapshot is only a
        # transitive link, so recurse into its parents instead.
        return [
            node
            for grandparent_sid in self.snapshots[parent_sid].parents
            for node in self._find_upstream_dependencies(
                grandparent_sid, intervals_per_snapshot, snapshots_to_create
            )
        ]

    if len(pending_intervals) > 1:
        # Several upstream batches: depend on a single placeholder that
        # gates on all of them.
        return [DummyNode(snapshot_name=parent_sid.name)]

    # Exactly one interval: depend directly on that evaluation.
    return [
        EvaluateNode(
            snapshot_name=parent_sid.name,
            interval=pending_intervals[0],
            batch_index=0,
        )
    ]

def _run_or_audit(
self,
environment: str | EnvironmentNamingInfo,
Expand Down
77 changes: 77 additions & 0 deletions tests/core/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1672,6 +1672,83 @@ def test_plan_ignore_cron(
)


@time_machine.travel("2023-01-08 15:00:00 UTC")
def test_run_respects_excluded_transitive_dependencies(init_and_plan_context: t.Callable):
    """Running with only A and C selected must still evaluate A after C.

    Graph: C <- B <- A (B reads from C, A reads from B), so B is a transitive
    dependency linking A and C. The alphabetical ordering of the model names is
    intentional and helps surface the problem.
    """
    context, _ = init_and_plan_context("examples/sushi")

    # Fix: was misleadingly named `expressions_a` and later shadowed by the
    # real model-A expressions; also dropped `f` prefixes on strings that
    # contain no placeholders.
    expressions_c = d.parse(
        """
        MODEL (
            name memory.sushi.test_model_c,
            kind FULL,
            allow_partials true,
            cron '@hourly',
        );

        SELECT @execution_ts AS execution_ts
        """
    )
    model_c = load_sql_based_model(expressions_c)
    context.upsert_model(model_c)

    # A VIEW model with no partials allowed and a daily cron instead of hourly.
    expressions_b = d.parse(
        """
        MODEL (
            name memory.sushi.test_model_b,
            kind VIEW,
            allow_partials false,
            cron '@daily',
        );

        SELECT * FROM memory.sushi.test_model_c
        """
    )
    model_b = load_sql_based_model(expressions_b)
    context.upsert_model(model_b)

    expressions_a = d.parse(
        """
        MODEL (
            name memory.sushi.test_model_a,
            kind FULL,
            allow_partials true,
            cron '@hourly',
        );

        SELECT * FROM memory.sushi.test_model_b
        """
    )
    model_a = load_sql_based_model(expressions_a)
    context.upsert_model(model_a)

    context.plan("prod", skip_tests=True, auto_apply=True, no_prompts=True)
    assert (
        context.fetchdf("SELECT execution_ts FROM memory.sushi.test_model_c")["execution_ts"].iloc[
            0
        ]
        == "2023-01-08 15:00:00"
    )

    with time_machine.travel("2023-01-08 17:00:00 UTC", tick=False):
        # Select only A and C; B must still be traversed as a transitive
        # dependency so A runs after C even though B itself is excluded.
        context.run(
            "prod",
            select_models=["*test_model_c", "*test_model_a"],
            no_auto_upstream=True,
            ignore_cron=True,
        )
        assert (
            context.fetchdf("SELECT execution_ts FROM memory.sushi.test_model_a")[
                "execution_ts"
            ].iloc[0]
            == "2023-01-08 17:00:00"
        )


@time_machine.travel("2023-01-08 15:00:00 UTC")
def test_run_with_select_models_no_auto_upstream(
init_and_plan_context: t.Callable,
Expand Down
107 changes: 107 additions & 0 deletions tests/core/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
SnapshotEvaluator,
SnapshotChangeCategory,
DeployabilityIndex,
snapshots_to_dag,
)
from sqlmesh.utils.date import to_datetime, to_timestamp, DatetimeRanges, TimeLike
from sqlmesh.utils.errors import CircuitBreakerError, NodeAuditsErrors
Expand Down Expand Up @@ -1019,3 +1020,109 @@ def record_execute_environment_statements(*args, **kwargs):
execute_env_idx = call_order.index("execute_environment_statements")
snapshots_to_create_idx = call_order.index("get_snapshots_to_create")
assert env_statements_idx < execute_env_idx < snapshots_to_create_idx


def test_dag_transitive_deps(mocker: MockerFixture, make_snapshot):
    """Selecting only A and C (skipping B) must still order C after A.

    Chain: A <- B <- C. B has no intervals, so C's dependency on A has to
    flow through B transitively.
    """
    # Fix: removed the unused `deployability_index` local; hoisted the
    # repeated interval tuple into one variable.
    interval = (to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))

    snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("SELECT 1 as id")))
    snapshot_b = make_snapshot(SqlModel(name="b", query=parse_one("SELECT * FROM a")))
    snapshot_c = make_snapshot(SqlModel(name="c", query=parse_one("SELECT * FROM b")))

    snapshot_b = snapshot_b.model_copy(update={"parents": (snapshot_a.snapshot_id,)})
    snapshot_c = snapshot_c.model_copy(update={"parents": (snapshot_b.snapshot_id,)})

    for snapshot in (snapshot_a, snapshot_b, snapshot_c):
        snapshot.categorize_as(SnapshotChangeCategory.BREAKING)

    scheduler = Scheduler(
        snapshots=[snapshot_a, snapshot_b, snapshot_c],
        snapshot_evaluator=mocker.Mock(),
        state_sync=mocker.Mock(),
        default_catalog=None,
    )

    # Test scenario: select only A and C (skip B).
    merged_intervals = {
        snapshot_a: [interval],
        snapshot_c: [interval],
    }

    full_dag = snapshots_to_dag([snapshot_a, snapshot_b, snapshot_c])

    dag = scheduler._dag(merged_intervals, snapshot_dag=full_dag)
    node_a = EvaluateNode(snapshot_name='"a"', interval=interval, batch_index=0)
    node_c = EvaluateNode(snapshot_name='"c"', interval=interval, batch_index=0)
    # C depends on A even though the intermediate node B was not selected.
    assert dag.graph == {node_a: set(), node_c: {node_a}}


def test_dag_multiple_chain_transitive_deps(mocker: MockerFixture, make_snapshot):
    """Diamond graph A <- {B, C} <- D <- E with only A and E selected.

    E must end up depending on A through the unselected intermediate
    snapshots B, C and D.
    """
    interval = (to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))

    snapshot_by_name = {}
    for model_name in ("a", "b", "c", "d", "e"):
        snapshot = make_snapshot(
            SqlModel(name=model_name, query=parse_one("SELECT 1 as id"))
        )
        snapshot.categorize_as(SnapshotChangeCategory.BREAKING)
        snapshot_by_name[model_name] = snapshot

    def _link(child: str, *parent_names: str) -> None:
        # Rewire `child` so that its parents are the named snapshots.
        parent_ids = tuple(snapshot_by_name[p].snapshot_id for p in parent_names)
        snapshot_by_name[child] = snapshot_by_name[child].model_copy(
            update={"parents": parent_ids}
        )

    _link("b", "a")
    _link("c", "a")
    _link("d", "b", "c")
    _link("e", "d")

    scheduler = Scheduler(
        snapshots=list(snapshot_by_name.values()),
        snapshot_evaluator=mocker.Mock(),
        state_sync=mocker.Mock(),
        default_catalog=None,
    )

    # Intervals exist only for the two endpoints of the chain.
    batched_intervals = {
        snapshot_by_name["a"]: [interval],
        snapshot_by_name["e"]: [interval],
    }

    # The full DAG is needed so the subchain through B/C/D stays traversable.
    full_dag = snapshots_to_dag(snapshot_by_name.values())

    dag = scheduler._dag(batched_intervals, snapshot_dag=full_dag)
    node_a = EvaluateNode(snapshot_name='"a"', interval=interval, batch_index=0)
    node_e = EvaluateNode(snapshot_name='"e"', interval=interval, batch_index=0)
    assert dag.graph == {node_a: set(), node_e: {node_a}}