Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sqlmesh/core/state_sync/db/facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def finalize(self, environment: Environment) -> None:
def unpause_snapshots(
self, snapshots: t.Collection[SnapshotInfoLike], unpaused_dt: TimeLike
) -> None:
self.snapshot_state.unpause_snapshots(snapshots, unpaused_dt, self.interval_state)
self.snapshot_state.unpause_snapshots(snapshots, unpaused_dt)

def invalidate_environment(self, name: str, protect_prod: bool = True) -> None:
self.environment_state.invalidate_environment(name, protect_prod)
Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/core/state_sync/db/migrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def _migrate_environment_rows(
if updated_prod_environment:
try:
self.snapshot_state.unpause_snapshots(
updated_prod_environment.snapshots, now_timestamp(), self.interval_state
updated_prod_environment.snapshots, now_timestamp()
)
except Exception:
logger.warning("Failed to unpause migrated snapshots", exc_info=True)
Expand Down
147 changes: 60 additions & 87 deletions sqlmesh/core/state_sync/db/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from sqlmesh.core.engine_adapter import EngineAdapter
from sqlmesh.core.state_sync.db.utils import (
snapshot_name_filter,
snapshot_name_version_filter,
snapshot_id_filter,
fetchone,
Expand All @@ -32,15 +33,13 @@
SnapshotChangeCategory,
)
from sqlmesh.utils.migration import index_text_type, blob_text_type
from sqlmesh.utils.date import now_timestamp, TimeLike, now, to_timestamp
from sqlmesh.utils.date import now_timestamp, TimeLike, to_timestamp
from sqlmesh.utils.pydantic import PydanticModel
from sqlmesh.utils import unique

if t.TYPE_CHECKING:
import pandas as pd

from sqlmesh.core.state_sync.db.interval import IntervalState


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -70,6 +69,7 @@ def __init__(
"unpaused_ts": exp.DataType.build("bigint"),
"ttl_ms": exp.DataType.build("bigint"),
"unrestorable": exp.DataType.build("boolean"),
"forward_only": exp.DataType.build("boolean"),
}

self._auto_restatement_columns_to_types = {
Expand Down Expand Up @@ -112,84 +112,52 @@ def unpause_snapshots(
self,
snapshots: t.Collection[SnapshotInfoLike],
unpaused_dt: TimeLike,
interval_state: IntervalState,
) -> None:
"""Unpauses given snapshots while pausing all other snapshots that share the same version.

Args:
snapshots: The snapshots to unpause.
unpaused_dt: The timestamp to unpause the snapshots at.
interval_state: The interval state to use to remove intervals when needed.
"""
current_ts = now()

target_snapshot_ids = {s.snapshot_id for s in snapshots}
same_version_snapshots = self._get_snapshots_with_same_version(
snapshots, lock_for_update=True
unrestorable_snapshots_by_forward_only: t.Dict[bool, t.List[SnapshotNameVersion]] = (
defaultdict(list)
)
target_snapshots_by_version = {
(s.name, s.version): s
for s in same_version_snapshots
if s.snapshot_id in target_snapshot_ids
}

unpaused_snapshots: t.Dict[int, t.List[SnapshotId]] = defaultdict(list)
paused_snapshots: t.List[SnapshotId] = []
unrestorable_snapshots: t.List[SnapshotId] = []

for snapshot in same_version_snapshots:
is_target_snapshot = snapshot.snapshot_id in target_snapshot_ids
if is_target_snapshot and not snapshot.unpaused_ts:
logger.info("Unpausing snapshot %s", snapshot.snapshot_id)
snapshot.set_unpaused_ts(unpaused_dt)
assert snapshot.unpaused_ts is not None
unpaused_snapshots[snapshot.unpaused_ts].append(snapshot.snapshot_id)
elif not is_target_snapshot:
target_snapshot = target_snapshots_by_version[(snapshot.name, snapshot.version)]
if (
target_snapshot.normalized_effective_from_ts
and not target_snapshot.disable_restatement
):
# Making sure that there are no overlapping intervals.
effective_from_ts = target_snapshot.normalized_effective_from_ts
logger.info(
"Removing all intervals after '%s' for snapshot %s, superseded by snapshot %s",
target_snapshot.effective_from,
snapshot.snapshot_id,
target_snapshot.snapshot_id,
)
full_snapshot = snapshot.full_snapshot
interval_state.remove_intervals(
[
(
full_snapshot,
full_snapshot.get_removal_interval(effective_from_ts, current_ts),
)
]
)

if snapshot.unpaused_ts:
logger.info("Pausing snapshot %s", snapshot.snapshot_id)
snapshot.set_unpaused_ts(None)
paused_snapshots.append(snapshot.snapshot_id)
for snapshot in snapshots:
# We need to mark all other snapshots that have forward-only opposite to the target snapshot as unrestorable
unrestorable_snapshots_by_forward_only[not snapshot.is_forward_only].append(
snapshot.name_version
)

if not snapshot.unrestorable and (
(target_snapshot.is_forward_only and not snapshot.is_forward_only)
or (snapshot.is_forward_only and not target_snapshot.is_forward_only)
):
logger.info("Marking snapshot %s as unrestorable", snapshot.snapshot_id)
snapshot.unrestorable = True
unrestorable_snapshots.append(snapshot.snapshot_id)
updated_ts = now_timestamp()
unpaused_ts = to_timestamp(unpaused_dt)

if unpaused_snapshots:
for unpaused_ts, snapshot_ids in unpaused_snapshots.items():
self._update_snapshots(snapshot_ids, unpaused_ts=unpaused_ts)
# Pause all snapshots with target names first
for where in snapshot_name_filter(
[s.name for s in snapshots],
batch_size=self.SNAPSHOT_BATCH_SIZE,
):
self.engine_adapter.update_table(
self.snapshots_table,
{"unpaused_ts": None, "updated_ts": updated_ts},
where=where,
)

if paused_snapshots:
self._update_snapshots(paused_snapshots, unpaused_ts=None)
# Now unpause the target snapshots
self._update_snapshots(
[s.snapshot_id for s in snapshots],
unpaused_ts=unpaused_ts,
updated_ts=updated_ts,
)

if unrestorable_snapshots:
self._update_snapshots(unrestorable_snapshots, unrestorable=True)
# Mark unrestorable snapshots
for forward_only, snapshot_name_versions in unrestorable_snapshots_by_forward_only.items():
forward_only_exp = exp.column("forward_only").is_(exp.convert(forward_only))
for where in snapshot_name_version_filter(
self.engine_adapter,
snapshot_name_versions,
batch_size=self.SNAPSHOT_BATCH_SIZE,
alias=None,
):
self.engine_adapter.update_table(
self.snapshots_table,
{"unrestorable": True, "updated_ts": updated_ts},
where=forward_only_exp.and_(where),
)

def get_expired_snapshots(
self,
Expand Down Expand Up @@ -414,7 +382,8 @@ def _update_snapshots(
**kwargs: t.Any,
) -> None:
properties = kwargs
properties["updated_ts"] = now_timestamp()
if "updated_ts" not in properties:
properties["updated_ts"] = now_timestamp()

for where in snapshot_id_filter(
self.engine_adapter, snapshots, batch_size=self.SNAPSHOT_BATCH_SIZE
Expand Down Expand Up @@ -466,13 +435,15 @@ def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]:
updated_ts,
unpaused_ts,
unrestorable,
forward_only,
next_auto_restatement_ts,
) in fetchall(self.engine_adapter, query):
snapshot = parse_snapshot(
serialized_snapshot=serialized_snapshot,
updated_ts=updated_ts,
unpaused_ts=unpaused_ts,
unrestorable=unrestorable,
forward_only=forward_only,
next_auto_restatement_ts=next_auto_restatement_ts,
)
snapshot_id = snapshot.snapshot_id
Expand Down Expand Up @@ -502,6 +473,7 @@ def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]:
"updated_ts",
"unpaused_ts",
"unrestorable",
"forward_only",
"next_auto_restatement_ts",
)
.from_(exp.to_table(self.snapshots_table).as_("snapshots"))
Expand All @@ -528,13 +500,15 @@ def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]:
updated_ts,
unpaused_ts,
unrestorable,
forward_only,
next_auto_restatement_ts,
) in fetchall(self.engine_adapter, query):
snapshot_id = SnapshotId(name=name, identifier=identifier)
snapshot = snapshots[snapshot_id]
snapshot.updated_ts = updated_ts
snapshot.unpaused_ts = unpaused_ts
snapshot.unrestorable = unrestorable
snapshot.forward_only = forward_only
snapshot.next_auto_restatement_ts = next_auto_restatement_ts
cached_snapshots_in_state.add(snapshot_id)

Expand Down Expand Up @@ -568,6 +542,7 @@ def _get_snapshots_expressions(
"snapshots.updated_ts",
"snapshots.unpaused_ts",
"snapshots.unrestorable",
"snapshots.forward_only",
"auto_restatements.next_auto_restatement_ts",
)
.from_(exp.to_table(self.snapshots_table).as_("snapshots"))
Expand Down Expand Up @@ -623,6 +598,7 @@ def _get_snapshots_with_same_version(
"updated_ts",
"unpaused_ts",
"unrestorable",
"forward_only",
)
.from_(exp.to_table(self.snapshots_table).as_("snapshots"))
.where(where)
Expand All @@ -640,9 +616,10 @@ def _get_snapshots_with_same_version(
updated_ts=updated_ts,
unpaused_ts=unpaused_ts,
unrestorable=unrestorable,
forward_only=forward_only,
snapshot=snapshot,
)
for snapshot, name, identifier, version, updated_ts, unpaused_ts, unrestorable in snapshot_rows
for snapshot, name, identifier, version, updated_ts, unpaused_ts, unrestorable, forward_only in snapshot_rows
]


Expand All @@ -651,6 +628,7 @@ def parse_snapshot(
updated_ts: int,
unpaused_ts: t.Optional[int],
unrestorable: bool,
forward_only: bool,
next_auto_restatement_ts: t.Optional[int],
) -> Snapshot:
return Snapshot(
Expand All @@ -659,6 +637,7 @@ def parse_snapshot(
"updated_ts": updated_ts,
"unpaused_ts": unpaused_ts,
"unrestorable": unrestorable,
"forward_only": forward_only,
"next_auto_restatement_ts": next_auto_restatement_ts,
}
)
Expand All @@ -673,6 +652,7 @@ def _snapshot_to_json(snapshot: Snapshot) -> str:
"updated_ts",
"unpaused_ts",
"unrestorable",
"forward_only",
"next_auto_restatement_ts",
}
)
Expand All @@ -693,6 +673,7 @@ def _snapshots_to_df(snapshots: t.Iterable[Snapshot]) -> pd.DataFrame:
"unpaused_ts": snapshot.unpaused_ts,
"ttl_ms": snapshot.ttl_ms,
"unrestorable": snapshot.unrestorable,
"forward_only": snapshot.forward_only,
}
for snapshot in snapshots
]
Expand Down Expand Up @@ -762,19 +743,10 @@ def full_snapshot(self) -> Snapshot:
"updated_ts": self.updated_ts,
"unpaused_ts": self.unpaused_ts,
"unrestorable": self.unrestorable,
"forward_only": self.forward_only,
}
)

def set_unpaused_ts(self, unpaused_dt: t.Optional[TimeLike]) -> None:
    """Record when this snapshot was unpaused.

    Args:
        unpaused_dt: The moment the snapshot was unpaused, or a falsy value
            to clear the unpaused timestamp entirely.
    """
    if unpaused_dt:
        # Floor to the snapshot's interval cron boundary so the stored
        # timestamp lands on a clean interval tick.
        self.unpaused_ts = to_timestamp(self.interval_unit.cron_floor(unpaused_dt))
    else:
        self.unpaused_ts = None

@classmethod
def from_snapshot_record(
cls,
Expand All @@ -785,6 +757,7 @@ def from_snapshot_record(
updated_ts: int,
unpaused_ts: t.Optional[int],
unrestorable: bool,
forward_only: bool,
snapshot: str,
) -> SharedVersionSnapshot:
raw_snapshot = json.loads(snapshot)
Expand All @@ -803,5 +776,5 @@ def from_snapshot_record(
disable_restatement=raw_node.get("kind", {}).get("disable_restatement", False),
effective_from=raw_snapshot.get("effective_from"),
raw_snapshot=raw_snapshot,
forward_only=raw_snapshot.get("forward_only", False),
forward_only=forward_only,
)
15 changes: 15 additions & 0 deletions sqlmesh/core/state_sync/db/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,21 @@
T = t.TypeVar("T")


def snapshot_name_filter(
    snapshot_names: t.Iterable[str],
    batch_size: int,
    alias: t.Optional[str] = None,
) -> t.Iterator[exp.Condition]:
    """Yield SQL conditions matching rows whose ``name`` is in the given set.

    Names are sorted for deterministic SQL output and split into batches so
    each generated ``IN`` list stays below the engine's expression limits.

    Args:
        snapshot_names: The snapshot names to filter on.
        batch_size: Maximum number of names per generated condition.
        alias: Optional table alias to qualify the ``name`` column with.

    Yields:
        One ``name IN (...)`` condition per batch, or a single ``FALSE``
        condition when no names were provided (matches no rows).
    """
    sorted_names = sorted(snapshot_names)

    if not sorted_names:
        yield exp.false()
    else:
        # Avoid shadowing the full list with the per-batch variable.
        for name_batch in create_batches(sorted_names, batch_size=batch_size):
            yield exp.column("name", table=alias).isin(*name_batch)


def snapshot_id_filter(
engine_adapter: EngineAdapter,
snapshot_ids: t.Iterable[SnapshotIdLike],
Expand Down
Loading
Loading