Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 183 additions & 1 deletion sqlmesh/core/plan/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
from __future__ import annotations
import typing as t
import logging
from dataclasses import dataclass, field

from sqlmesh.core.snapshot import Snapshot
from sqlmesh.core.state_sync import StateReader
from sqlmesh.core.snapshot import Snapshot, SnapshotId, SnapshotIdAndVersion, SnapshotNameVersion
from sqlmesh.core.snapshot.definition import Interval
from sqlmesh.utils.dag import DAG
from sqlmesh.utils.date import now_timestamp

logger = logging.getLogger(__name__)


def should_force_rebuild(old: Snapshot, new: Snapshot) -> bool:
Expand All @@ -27,3 +36,176 @@ def is_breaking_kind_change(old: Snapshot, new: Snapshot) -> bool:
# If the partitioning hasn't changed, then we don't need to rebuild
return False
return True


@dataclass
class SnapshotIntervalClearRequest:
    """A request to clear a single interval from a single snapshot version.

    Produced by :func:`identify_restatement_intervals_across_snapshot_versions` to describe
    which snapshot needs an interval removed and which environments currently reference it.
    """

    # affected snapshot (id + version, regardless of which environment it is promoted in)
    snapshot: SnapshotIdAndVersion

    # which interval (start_ts, end_ts) to clear
    interval: Interval

    # which environments this snapshot is currently promoted in.
    # note that this can be empty if the snapshot exists because its ttl has not expired
    # but it is not part of any particular environment
    environment_names: t.Set[str] = field(default_factory=set)

    @property
    def snapshot_id(self) -> SnapshotId:
        """Identity of the affected snapshot, for keying lookups."""
        return self.snapshot.snapshot_id

    @property
    def sorted_environment_names(self) -> t.List[str]:
        """Environment names in deterministic (lexicographic) order, e.g. for display."""
        # sorted() already returns a new list; no extra list() wrapper needed
        return sorted(self.environment_names)


def identify_restatement_intervals_across_snapshot_versions(
    state_reader: StateReader,
    prod_restatements: t.Dict[str, Interval],
    disable_restatement_models: t.Set[str],
    loaded_snapshots: t.Dict[SnapshotId, Snapshot],
    current_ts: t.Optional[int] = None,
) -> t.Dict[SnapshotId, SnapshotIntervalClearRequest]:
    """
    Given a map of snapshot names + intervals to restate in prod:
    - Look up matching snapshots (match based on name - regardless of version, to get all versions)
    - For each match, also match downstream snapshots in each dev environment while filtering out
      models that have restatement disabled
    - Return a map of all snapshots that are affected + the interval that needs to be cleared for each

    The goal here is to produce a list of intervals to invalidate across all dev snapshots so that a
    subsequent plan or cadence run in those environments causes the intervals to be repopulated.

    Args:
        state_reader: Used to enumerate environments and fetch snapshot records from state.
        prod_restatements: Map of snapshot name -> interval being restated in prod.
        disable_restatement_models: Names of models with restatement disabled; these (and only
            these) are excluded from the affected set.
        loaded_snapshots: Already-loaded snapshot records; avoids refetching them from state.
        current_ts: Timestamp used for snapshot TTL/expiry checks; defaults to now.

    Returns:
        Map of snapshot id -> SnapshotIntervalClearRequest for every affected snapshot version.
    """
    if not prod_restatements:
        return {}

    # Although :loaded_snapshots is sourced from RestatementStage.all_snapshots, since the only time we ever need
    # to clear intervals across all environments is for prod, the :loaded_snapshots here are always from prod
    prod_name_versions: t.Set[SnapshotNameVersion] = {
        s.name_version for s in loaded_snapshots.values()
    }

    snapshot_intervals_to_clear: t.Dict[SnapshotId, SnapshotIntervalClearRequest] = {}

    for env_summary in state_reader.get_environments_summary():
        # Fetch the full environment object one at a time to avoid loading all environments into memory at once
        env = state_reader.get_environment(env_summary.name)
        if not env:
            logger.warning("Environment %s not found", env_summary.name)
            continue

        snapshots_by_name = {s.name: s.table_info for s in env.snapshots}

        # We dont just restate matching snapshots, we also have to restate anything downstream of them
        # so that if A gets restated in prod and dev has A <- B <- C, B and C get restated in dev
        env_dag = DAG({s.name: {p.name for p in s.parents} for s in env.snapshots})

        for restate_snapshot_name, interval in prod_restatements.items():
            if restate_snapshot_name not in snapshots_by_name:
                # snapshot is not promoted in this environment
                continue

            affected_snapshot_names = [
                x
                for x in ([restate_snapshot_name] + env_dag.downstream(restate_snapshot_name))
                if x not in disable_restatement_models
            ]

            for affected_snapshot_name in affected_snapshot_names:
                affected_snapshot = snapshots_by_name[affected_snapshot_name]

                # Don't clear intervals for a dev snapshot if it shares the same physical version with prod.
                # Otherwise, prod will be affected by what should be a dev operation
                if affected_snapshot.name_version in prod_name_versions:
                    continue

                clear_request = snapshot_intervals_to_clear.get(affected_snapshot.snapshot_id)
                if not clear_request:
                    clear_request = SnapshotIntervalClearRequest(
                        snapshot=affected_snapshot.id_and_version, interval=interval
                    )
                    snapshot_intervals_to_clear[affected_snapshot.snapshot_id] = clear_request

                # record that this environment references the affected snapshot
                clear_request.environment_names.add(env.name)

    # snapshot_intervals_to_clear now contains the entire hierarchy of affected snapshots based
    # on building the DAG for each environment and including downstream snapshots
    # but, what if there are affected snapshots that arent part of any environment?
    #
    # NOTE(review): dependency resolution here is based only on promoted snapshots, so an
    # unpromoted snapshot whose downstream model was removed from every environment can still
    # be missed - this narrows rather than fully eliminates that gap. TODO: confirm acceptable.
    unique_snapshot_names = {snapshot_id.name for snapshot_id in snapshot_intervals_to_clear}

    current_ts = current_ts or now_timestamp()
    all_matching_non_prod_snapshots = {
        s.snapshot_id: s
        for s in state_reader.get_snapshots_by_names(
            snapshot_names=unique_snapshot_names, current_ts=current_ts, exclude_expired=True
        )
        # Don't clear intervals for a snapshot if it shares the same physical version with prod.
        # Otherwise, prod will be affected by what should be a dev operation
        if s.name_version not in prod_name_versions
    }

    # identify the ones that we havent picked up yet, which are the ones that dont exist in any environment
    if remaining_snapshot_ids := set(all_matching_non_prod_snapshots).difference(
        snapshot_intervals_to_clear
    ):
        # these snapshot id's exist in isolation and may be related to a downstream dependency of the :prod_restatements,
        # rather than directly related, so we can't simply look up the interval to clear based on :prod_restatements.
        # To figure out the interval that should be cleared, we can match to the existing list based on name
        # and conservatively take the widest interval that shows up
        snapshot_name_to_widest_interval: t.Dict[str, Interval] = {}
        for s_id, clear_request in snapshot_intervals_to_clear.items():
            current_start, current_end = snapshot_name_to_widest_interval.get(
                s_id.name, clear_request.interval
            )
            next_start, next_end = clear_request.interval

            # widen: earliest start, latest end seen so far for this name
            next_start = min(current_start, next_start)
            next_end = max(current_end, next_end)

            snapshot_name_to_widest_interval[s_id.name] = (next_start, next_end)

        for remaining_snapshot_id in remaining_snapshot_ids:
            remaining_snapshot = all_matching_non_prod_snapshots[remaining_snapshot_id]
            snapshot_intervals_to_clear[remaining_snapshot_id] = SnapshotIntervalClearRequest(
                snapshot=remaining_snapshot,
                interval=snapshot_name_to_widest_interval[remaining_snapshot_id.name],
            )

    # for any affected full_history_restatement_only snapshots, we need to widen the intervals being restated to
    # include the whole time range for that snapshot. This requires a call to state to load the full snapshot record,
    # so we only do it if necessary
    full_history_restatement_snapshot_ids = [
        # FIXME: full_history_restatement_only is just one indicator that the snapshot can only be fully refreshed, the other one is Model.depends_on_self
        # however, to figure out depends_on_self, we have to render all the model queries which, alongside having to fetch full snapshots from state,
        # is problematic in secure environments that are deliberately isolated from arbitrary user code (since rendering a query may require user macros to be present)
        # So for now, these are not considered
        s_id
        for s_id, s in snapshot_intervals_to_clear.items()
        if s.snapshot.full_history_restatement_only
    ]
    if full_history_restatement_snapshot_ids:
        # only load full snapshot records that we havent already loaded
        additional_snapshots = state_reader.get_snapshots(
            [
                s.snapshot_id
                for s in full_history_restatement_snapshot_ids
                if s.snapshot_id not in loaded_snapshots
            ]
        )

        all_snapshots = loaded_snapshots | additional_snapshots

        for full_snapshot_id in full_history_restatement_snapshot_ids:
            full_snapshot = all_snapshots[full_snapshot_id]
            intervals_to_clear = snapshot_intervals_to_clear[full_snapshot_id]

            original_start, original_end = intervals_to_clear.interval

            # get_removal_interval() widens intervals if necessary
            new_interval = full_snapshot.get_removal_interval(
                start=original_start, end=original_end
            )

            intervals_to_clear.interval = new_interval

    return snapshot_intervals_to_clear
118 changes: 16 additions & 102 deletions sqlmesh/core/plan/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from sqlmesh.core.console import Console, get_console
from sqlmesh.core.environment import EnvironmentNamingInfo, execute_environment_statements
from sqlmesh.core.macros import RuntimeStage
from sqlmesh.core.snapshot.definition import Interval, to_view_mapping
from sqlmesh.core.snapshot.definition import to_view_mapping
from sqlmesh.core.plan import stages
from sqlmesh.core.plan.definition import EvaluatablePlan
from sqlmesh.core.scheduler import Scheduler
Expand All @@ -33,17 +33,15 @@
SnapshotIntervals,
SnapshotId,
SnapshotInfoLike,
SnapshotTableInfo,
SnapshotCreationFailedError,
SnapshotNameVersion,
)
from sqlmesh.utils import to_snake_case
from sqlmesh.core.state_sync import StateSync
from sqlmesh.core.plan.common import identify_restatement_intervals_across_snapshot_versions
from sqlmesh.utils import CorrelationId
from sqlmesh.utils.concurrency import NodeExecutionFailedError
from sqlmesh.utils.errors import PlanError, SQLMeshError
from sqlmesh.utils.dag import DAG
from sqlmesh.utils.date import now
from sqlmesh.utils.date import now, to_timestamp

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -289,7 +287,9 @@ def visit_audit_only_run_stage(
def visit_restatement_stage(
self, stage: stages.RestatementStage, plan: EvaluatablePlan
) -> None:
snapshot_intervals_to_restate = {(s, i) for s, i in stage.snapshot_intervals.items()}
snapshot_intervals_to_restate = {
(s.id_and_version, i) for s, i in stage.snapshot_intervals.items()
}

# Restating intervals on prod plans should mean that the intervals are cleared across
# all environments, not just the version currently in prod
Expand All @@ -298,11 +298,16 @@ def visit_restatement_stage(
#
# Without this rule, its possible that promoting a dev table to prod will introduce old data to prod
snapshot_intervals_to_restate.update(
self._restatement_intervals_across_all_environments(
prod_restatements=plan.restatements,
disable_restatement_models=plan.disabled_restatement_models,
loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()},
)
{
(s.snapshot, s.interval)
for s in identify_restatement_intervals_across_snapshot_versions(
state_reader=self.state_sync,
prod_restatements=plan.restatements,
disable_restatement_models=plan.disabled_restatement_models,
loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()},
current_ts=to_timestamp(plan.execution_time or now()),
).values()
}
)

self.state_sync.remove_intervals(
Expand Down Expand Up @@ -422,97 +427,6 @@ def _demote_snapshots(
on_complete=on_complete,
)

def _restatement_intervals_across_all_environments(
    self,
    prod_restatements: t.Dict[str, Interval],
    disable_restatement_models: t.Set[str],
    loaded_snapshots: t.Dict[SnapshotId, Snapshot],
) -> t.Set[t.Tuple[SnapshotTableInfo, Interval]]:
    """
    Given a map of snapshot names + intervals to restate in prod:
    - Look up matching snapshots across all environments (match based on name - regardless of version)
    - For each match, also match downstream snapshots while filtering out models that have restatement disabled
    - Return all matches mapped to the intervals of the prod snapshot being restated

    The goal here is to produce a list of intervals to invalidate across all environments so that a cadence
    run in those environments causes the intervals to be repopulated

    Args:
        prod_restatements: Map of snapshot name -> interval being restated in prod.
        disable_restatement_models: Names of models with restatement disabled; excluded from the result.
        loaded_snapshots: Already-loaded snapshot records (from prod); avoids refetching from state.

    Returns:
        Set of (snapshot table info, interval) pairs to clear across all environments.
    """
    # Nothing to restate -> nothing to clear anywhere.
    if not prod_restatements:
        return set()

    # Versions currently in prod; used below to avoid clearing intervals on snapshots
    # that share a physical table with prod.
    prod_name_versions: t.Set[SnapshotNameVersion] = {
        s.name_version for s in loaded_snapshots.values()
    }

    # Keyed by snapshot id so the same snapshot appearing in multiple environments
    # is only recorded once.
    snapshots_to_restate: t.Dict[SnapshotId, t.Tuple[SnapshotTableInfo, Interval]] = {}

    for env_summary in self.state_sync.get_environments_summary():
        # Fetch the full environment object one at a time to avoid loading all environments into memory at once
        env = self.state_sync.get_environment(env_summary.name)
        if not env:
            # Environment disappeared between the summary listing and this fetch; skip it.
            logger.warning("Environment %s not found", env_summary.name)
            continue

        keyed_snapshots = {s.name: s.table_info for s in env.snapshots}

        # We dont just restate matching snapshots, we also have to restate anything downstream of them
        # so that if A gets restated in prod and dev has A <- B <- C, B and C get restated in dev
        env_dag = DAG({s.name: {p.name for p in s.parents} for s in env.snapshots})

        for restatement, intervals in prod_restatements.items():
            if restatement not in keyed_snapshots:
                # Snapshot is not promoted in this environment.
                continue
            # The restated snapshot itself plus everything downstream of it in this
            # environment, minus models with restatement disabled.
            affected_snapshot_names = [
                x
                for x in ([restatement] + env_dag.downstream(restatement))
                if x not in disable_restatement_models
            ]
            snapshots_to_restate.update(
                {
                    keyed_snapshots[a].snapshot_id: (keyed_snapshots[a], intervals)
                    for a in affected_snapshot_names
                    # Don't restate a snapshot if it shares the version with a snapshot in prod
                    if keyed_snapshots[a].name_version not in prod_name_versions
                }
            )

    # for any affected full_history_restatement_only snapshots, we need to widen the intervals being restated to
    # include the whole time range for that snapshot. This requires a call to state to load the full snapshot record,
    # so we only do it if necessary
    full_history_restatement_snapshot_ids = [
        # FIXME: full_history_restatement_only is just one indicator that the snapshot can only be fully refreshed, the other one is Model.depends_on_self
        # however, to figure out depends_on_self, we have to render all the model queries which, alongside having to fetch full snapshots from state,
        # is problematic in secure environments that are deliberately isolated from arbitrary user code (since rendering a query may require user macros to be present)
        # So for now, these are not considered
        s_id
        for s_id, s in snapshots_to_restate.items()
        if s[0].full_history_restatement_only
    ]
    if full_history_restatement_snapshot_ids:
        # only load full snapshot records that we havent already loaded
        additional_snapshots = self.state_sync.get_snapshots(
            [
                s.snapshot_id
                for s in full_history_restatement_snapshot_ids
                if s.snapshot_id not in loaded_snapshots
            ]
        )

        all_snapshots = loaded_snapshots | additional_snapshots

        for full_snapshot_id in full_history_restatement_snapshot_ids:
            full_snapshot = all_snapshots[full_snapshot_id]
            _, original_intervals = snapshots_to_restate[full_snapshot_id]
            original_start, original_end = original_intervals

            # get_removal_interval() widens intervals if necessary
            new_intervals = full_snapshot.get_removal_interval(
                start=original_start, end=original_end
            )
            snapshots_to_restate[full_snapshot_id] = (full_snapshot.table_info, new_intervals)

    return set(snapshots_to_restate.values())

def _update_intervals_for_new_snapshots(self, snapshots: t.Collection[Snapshot]) -> None:
snapshots_intervals: t.List[SnapshotIntervals] = []
for snapshot in snapshots:
Expand Down
1 change: 1 addition & 0 deletions sqlmesh/core/snapshot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
SnapshotId as SnapshotId,
SnapshotIdBatch as SnapshotIdBatch,
SnapshotIdLike as SnapshotIdLike,
SnapshotIdAndVersionLike as SnapshotIdAndVersionLike,
SnapshotInfoLike as SnapshotInfoLike,
SnapshotIntervals as SnapshotIntervals,
SnapshotNameVersion as SnapshotNameVersion,
Expand Down
Loading