Skip to content

Commit 1438970

Browse files
committed
PR feedback
1 parent 06c11df commit 1438970

File tree

10 files changed

+81
-28
lines changed

10 files changed

+81
-28
lines changed

sqlmesh/core/plan/common.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from dataclasses import dataclass, field
55

66
from sqlmesh.core.state_sync import StateReader
7-
from sqlmesh.core.snapshot import Snapshot, SnapshotId, SnapshotTableInfo, SnapshotNameVersion
7+
from sqlmesh.core.snapshot import Snapshot, SnapshotId, SnapshotIdAndVersion, SnapshotNameVersion
88
from sqlmesh.core.snapshot.definition import Interval
99
from sqlmesh.utils.dag import DAG
1010
from sqlmesh.utils.date import now_timestamp
@@ -41,7 +41,7 @@ def is_breaking_kind_change(old: Snapshot, new: Snapshot) -> bool:
4141
@dataclass
4242
class SnapshotIntervalClearRequest:
4343
# affected snapshot
44-
table_info: SnapshotTableInfo
44+
snapshot: SnapshotIdAndVersion
4545

4646
# which interval to clear
4747
interval: Interval
@@ -53,7 +53,7 @@ class SnapshotIntervalClearRequest:
5353

5454
@property
5555
def snapshot_id(self) -> SnapshotId:
56-
return self.table_info.snapshot_id
56+
return self.snapshot.snapshot_id
5757

5858
@property
5959
def sorted_environment_names(self) -> t.List[str]:
@@ -122,7 +122,7 @@ def identify_restatement_intervals_across_snapshot_versions(
122122
clear_request = snapshot_intervals_to_clear.get(affected_snapshot.snapshot_id)
123123
if not clear_request:
124124
clear_request = SnapshotIntervalClearRequest(
125-
table_info=affected_snapshot, interval=interval
125+
snapshot=affected_snapshot.id_and_version, interval=interval
126126
)
127127
snapshot_intervals_to_clear[affected_snapshot.snapshot_id] = clear_request
128128

@@ -164,19 +164,13 @@ def identify_restatement_intervals_across_snapshot_versions(
164164

165165
snapshot_name_to_widest_interval[s_id.name] = (next_start, next_end)
166166

167-
# we need to fetch full Snapshot's to get access to the SnapshotTableInfo objects
168-
# required by StateSync.remove_intervals()
169-
# but at this point we have minimized the list by excluding the ones that are already present in prod
170-
# and also excluding the ones we have already matched earlier while traversing the environment DAGs
171-
remaining_snapshots = state_reader.get_snapshots(snapshot_ids=remaining_snapshot_ids)
172-
for remaining_snapshot_id, remaining_snapshot in remaining_snapshots.items():
167+
for remaining_snapshot_id in remaining_snapshot_ids:
168+
remaining_snapshot = all_matching_non_prod_snapshots[remaining_snapshot_id]
173169
snapshot_intervals_to_clear[remaining_snapshot_id] = SnapshotIntervalClearRequest(
174-
table_info=remaining_snapshot.table_info,
170+
snapshot=remaining_snapshot,
175171
interval=snapshot_name_to_widest_interval[remaining_snapshot_id.name],
176172
)
177173

178-
loaded_snapshots.update(remaining_snapshots)
179-
180174
# for any affected full_history_restatement_only snapshots, we need to widen the intervals being restated to
181175
# include the whole time range for that snapshot. This requires a call to state to load the full snapshot record,
182176
# so we only do it if necessary
@@ -187,7 +181,7 @@ def identify_restatement_intervals_across_snapshot_versions(
187181
# So for now, these are not considered
188182
s_id
189183
for s_id, s in snapshot_intervals_to_clear.items()
190-
if s.table_info.full_history_restatement_only
184+
if s.snapshot.kind_name and s.snapshot.kind_name.full_history_restatement_only
191185
]
192186
if full_history_restatement_snapshot_ids:
193187
# only load full snapshot records that we haven't already loaded

sqlmesh/core/plan/evaluator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,9 @@ def visit_audit_only_run_stage(
284284
def visit_restatement_stage(
285285
self, stage: stages.RestatementStage, plan: EvaluatablePlan
286286
) -> None:
287-
snapshot_intervals_to_restate = {(s, i) for s, i in stage.snapshot_intervals.items()}
287+
snapshot_intervals_to_restate = {
288+
(s.id_and_version, i) for s, i in stage.snapshot_intervals.items()
289+
}
288290

289291
# Restating intervals on prod plans should mean that the intervals are cleared across
290292
# all environments, not just the version currently in prod
@@ -294,7 +296,7 @@ def visit_restatement_stage(
294296
# Without this rule, it's possible that promoting a dev table to prod will introduce old data to prod
295297
snapshot_intervals_to_restate.update(
296298
{
297-
(s.table_info, s.interval)
299+
(s.snapshot, s.interval)
298300
for s in identify_restatement_intervals_across_snapshot_versions(
299301
state_reader=self.state_sync,
300302
prod_restatements=plan.restatements,

sqlmesh/core/snapshot/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
SnapshotId as SnapshotId,
1212
SnapshotIdBatch as SnapshotIdBatch,
1313
SnapshotIdLike as SnapshotIdLike,
14+
SnapshotIdAndVersionLike as SnapshotIdAndVersionLike,
1415
SnapshotInfoLike as SnapshotInfoLike,
1516
SnapshotIntervals as SnapshotIntervals,
1617
SnapshotNameVersion as SnapshotNameVersion,

sqlmesh/core/snapshot/definition.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,17 @@ def name_version(self) -> SnapshotNameVersion:
587587
"""Returns the name and version of the snapshot."""
588588
return SnapshotNameVersion(name=self.name, version=self.version)
589589

590+
@property
591+
def id_and_version(self) -> SnapshotIdAndVersion:
592+
return SnapshotIdAndVersion(
593+
name=self.name,
594+
kind_name=self.kind_name,
595+
identifier=self.identifier,
596+
version=self.version,
597+
dev_version=self.dev_version,
598+
fingerprint=self.fingerprint,
599+
)
600+
590601

591602
class SnapshotIdAndVersion(PydanticModel):
592603
"""A stripped down version of a snapshot that is used in situations where we want to fetch the main fields of the snapshots table
@@ -595,6 +606,7 @@ class SnapshotIdAndVersion(PydanticModel):
595606

596607
name: str
597608
version: str
609+
kind_name: t.Optional[ModelKindName] = None
598610
dev_version_: t.Optional[str] = Field(alias="dev_version")
599611
identifier: str
600612
fingerprint_: t.Union[str, SnapshotFingerprint] = Field(alias="fingerprint")
@@ -603,6 +615,10 @@ class SnapshotIdAndVersion(PydanticModel):
603615
def snapshot_id(self) -> SnapshotId:
604616
return SnapshotId(name=self.name, identifier=self.identifier)
605617

618+
@property
619+
def id_and_version(self) -> SnapshotIdAndVersion:
620+
return self
621+
606622
@property
607623
def name_version(self) -> SnapshotNameVersion:
608624
return SnapshotNameVersion(name=self.name, version=self.version)
@@ -1424,6 +1440,10 @@ def name_version(self) -> SnapshotNameVersion:
14241440
"""Returns the name and version of the snapshot."""
14251441
return SnapshotNameVersion(name=self.name, version=self.version)
14261442

1443+
@property
1444+
def id_and_version(self) -> SnapshotIdAndVersion:
1445+
return self.table_info.id_and_version
1446+
14271447
@property
14281448
def disable_restatement(self) -> bool:
14291449
"""Is restatement disabled for the node"""
@@ -1494,7 +1514,8 @@ class SnapshotTableCleanupTask(PydanticModel):
14941514
dev_table_only: bool
14951515

14961516

1497-
SnapshotIdLike = t.Union[SnapshotId, SnapshotTableInfo, SnapshotIdAndVersion, Snapshot]
1517+
SnapshotIdLike = t.Union[SnapshotId, SnapshotIdAndVersion, SnapshotTableInfo, Snapshot]
1518+
SnapshotIdAndVersionLike = t.Union[SnapshotIdAndVersion, SnapshotTableInfo, Snapshot]
14981519
SnapshotInfoLike = t.Union[SnapshotTableInfo, Snapshot]
14991520
SnapshotNameVersionLike = t.Union[
15001521
SnapshotNameVersion, SnapshotTableInfo, SnapshotIdAndVersion, Snapshot

sqlmesh/core/state_sync/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
Snapshot,
2020
SnapshotId,
2121
SnapshotIdLike,
22+
SnapshotIdAndVersionLike,
2223
SnapshotInfoLike,
2324
SnapshotTableCleanupTask,
2425
SnapshotTableInfo,
@@ -390,7 +391,7 @@ def remove_state(self, including_backup: bool = False) -> None:
390391
@abc.abstractmethod
391392
def remove_intervals(
392393
self,
393-
snapshot_intervals: t.Sequence[t.Tuple[SnapshotInfoLike, Interval]],
394+
snapshot_intervals: t.Sequence[t.Tuple[SnapshotIdAndVersionLike, Interval]],
394395
remove_shared_versions: bool = False,
395396
) -> None:
396397
"""Remove an interval from a list of snapshots and sync it to the store.

sqlmesh/core/state_sync/cache.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
Snapshot,
88
SnapshotId,
99
SnapshotIdLike,
10+
SnapshotIdAndVersionLike,
1011
SnapshotInfoLike,
1112
)
1213
from sqlmesh.core.snapshot.definition import Interval, SnapshotIntervals
@@ -128,7 +129,7 @@ def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotInterv
128129

129130
def remove_intervals(
130131
self,
131-
snapshot_intervals: t.Sequence[t.Tuple[SnapshotInfoLike, Interval]],
132+
snapshot_intervals: t.Sequence[t.Tuple[SnapshotIdAndVersionLike, Interval]],
132133
remove_shared_versions: bool = False,
133134
) -> None:
134135
for s, _ in snapshot_intervals:

sqlmesh/core/state_sync/db/facade.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
SnapshotIdAndVersion,
3232
SnapshotId,
3333
SnapshotIdLike,
34+
SnapshotIdAndVersionLike,
3435
SnapshotInfoLike,
3536
SnapshotIntervals,
3637
SnapshotNameVersion,
@@ -407,7 +408,7 @@ def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotInterv
407408
@transactional()
408409
def remove_intervals(
409410
self,
410-
snapshot_intervals: t.Sequence[t.Tuple[SnapshotInfoLike, Interval]],
411+
snapshot_intervals: t.Sequence[t.Tuple[SnapshotIdAndVersionLike, Interval]],
411412
remove_shared_versions: bool = False,
412413
) -> None:
413414
self.interval_state.remove_intervals(snapshot_intervals, remove_shared_versions)

sqlmesh/core/state_sync/db/interval.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
from sqlmesh.core.snapshot import (
1616
SnapshotIntervals,
1717
SnapshotIdLike,
18+
SnapshotIdAndVersionLike,
1819
SnapshotNameVersionLike,
1920
SnapshotTableCleanupTask,
2021
SnapshotNameVersion,
21-
SnapshotInfoLike,
2222
Snapshot,
2323
)
2424
from sqlmesh.core.snapshot.definition import Interval
@@ -68,11 +68,11 @@ def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotInterv
6868

6969
def remove_intervals(
7070
self,
71-
snapshot_intervals: t.Sequence[t.Tuple[SnapshotInfoLike, Interval]],
71+
snapshot_intervals: t.Sequence[t.Tuple[SnapshotIdAndVersionLike, Interval]],
7272
remove_shared_versions: bool = False,
7373
) -> None:
7474
intervals_to_remove: t.Sequence[
75-
t.Tuple[t.Union[SnapshotInfoLike, SnapshotIntervals], Interval]
75+
t.Tuple[t.Union[SnapshotIdAndVersionLike, SnapshotIntervals], Interval]
7676
] = snapshot_intervals
7777
if remove_shared_versions:
7878
name_version_mapping = {s.name_version: interval for s, interval in snapshot_intervals}
@@ -431,7 +431,9 @@ def _delete_intervals_by_version(self, targets: t.List[SnapshotTableCleanupTask]
431431

432432

433433
def _intervals_to_df(
434-
snapshot_intervals: t.Sequence[t.Tuple[t.Union[SnapshotInfoLike, SnapshotIntervals], Interval]],
434+
snapshot_intervals: t.Sequence[
435+
t.Tuple[t.Union[SnapshotIdAndVersionLike, SnapshotIntervals], Interval]
436+
],
435437
is_dev: bool,
436438
is_removed: bool,
437439
) -> pd.DataFrame:
@@ -451,7 +453,7 @@ def _intervals_to_df(
451453

452454

453455
def _interval_to_df(
454-
snapshot: t.Union[SnapshotInfoLike, SnapshotIntervals],
456+
snapshot: t.Union[SnapshotIdAndVersionLike, SnapshotIntervals],
455457
start_ts: int,
456458
end_ts: int,
457459
is_dev: bool = False,

sqlmesh/core/state_sync/db/snapshot.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -337,16 +337,19 @@ def get_snapshots_by_names(
337337
name=name,
338338
identifier=identifier,
339339
version=version,
340+
kind_name=kind_name,
340341
dev_version=dev_version,
341342
fingerprint=fingerprint,
342343
)
343344
for where in snapshot_name_filter(
344345
snapshot_names=snapshot_names,
345346
batch_size=self.SNAPSHOT_BATCH_SIZE,
346347
)
347-
for name, identifier, version, dev_version, fingerprint in fetchall(
348+
for name, identifier, version, kind_name, dev_version, fingerprint in fetchall(
348349
self.engine_adapter,
349-
exp.select("name", "identifier", "version", "dev_version", "fingerprint")
350+
exp.select(
351+
"name", "identifier", "version", "kind_name", "dev_version", "fingerprint"
352+
)
350353
.from_(self.snapshots_table)
351354
.where(where)
352355
.and_(unexpired_expr),
@@ -661,6 +664,7 @@ def _get_snapshots_with_same_version(
661664
"name",
662665
"identifier",
663666
"version",
667+
"kind_name",
664668
"dev_version",
665669
"fingerprint",
666670
)
@@ -677,10 +681,11 @@ def _get_snapshots_with_same_version(
677681
name=name,
678682
identifier=identifier,
679683
version=version,
684+
kind_name=kind_name,
680685
dev_version=dev_version,
681686
fingerprint=SnapshotFingerprint.parse_raw(fingerprint),
682687
)
683-
for name, identifier, version, dev_version, fingerprint in snapshot_rows
688+
for name, identifier, version, kind_name, dev_version, fingerprint in snapshot_rows
684689
]
685690

686691

tests/core/test_snapshot.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3567,3 +3567,28 @@ def test_snapshot_id_and_version_fingerprint_lazy_init():
35673567

35683568
assert isinstance(snapshot.fingerprint_, SnapshotFingerprint)
35693569
assert snapshot.fingerprint == fingerprint
3570+
3571+
3572+
def test_snapshot_id_and_version_optional_kind_name():
3573+
snapshot = SnapshotIdAndVersion(
3574+
name="a",
3575+
identifier="1234",
3576+
version="2345",
3577+
dev_version=None,
3578+
fingerprint="",
3579+
)
3580+
3581+
assert snapshot.kind_name is None
3582+
3583+
snapshot = SnapshotIdAndVersion(
3584+
name="a",
3585+
identifier="1234",
3586+
version="2345",
3587+
kind_name="INCREMENTAL_UNMANAGED",
3588+
dev_version=None,
3589+
fingerprint="",
3590+
)
3591+
3592+
assert snapshot.kind_name
3593+
assert snapshot.kind_name.is_incremental_unmanaged
3594+
assert snapshot.kind_name.full_history_restatement_only

0 commit comments

Comments
 (0)