Skip to content

Commit 42a113c

Browse files
committed
PR feedback
1 parent 2636f4f commit 42a113c

File tree

6 files changed

+70
-49
lines changed

6 files changed

+70
-49
lines changed

sqlmesh/core/snapshot/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
Node as Node,
55
QualifiedViewName as QualifiedViewName,
66
Snapshot as Snapshot,
7+
MinimalSnapshot as MinimalSnapshot,
78
SnapshotChangeCategory as SnapshotChangeCategory,
89
SnapshotDataVersion as SnapshotDataVersion,
910
SnapshotFingerprint as SnapshotFingerprint,

sqlmesh/core/snapshot/definition.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,30 @@ def name_version(self) -> SnapshotNameVersion:
588588
return SnapshotNameVersion(name=self.name, version=self.version)
589589

590590

591+
class MinimalSnapshot(PydanticModel):
592+
"""A stripped down version of a snapshot that is used in situations where we want to fetch the main fields of the snapshots table
593+
without the overhead of parsing the full snapshot payload and fetching intervals.
594+
"""
595+
596+
name: str
597+
version: str
598+
dev_version_: t.Optional[str] = Field(alias="dev_version")
599+
identifier: str
600+
fingerprint: SnapshotFingerprint
601+
602+
@property
603+
def snapshot_id(self) -> SnapshotId:
604+
return SnapshotId(name=self.name, identifier=self.identifier)
605+
606+
@property
607+
def name_version(self) -> SnapshotNameVersion:
608+
return SnapshotNameVersion(name=self.name, version=self.version)
609+
610+
@property
611+
def dev_version(self) -> str:
612+
return self.dev_version_ or self.fingerprint.to_version()
613+
614+
591615
class Snapshot(PydanticModel, SnapshotInfoMixin):
592616
"""A snapshot represents a node at a certain point in time.
593617

sqlmesh/core/state_sync/base.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
SnapshotTableCleanupTask,
2424
SnapshotTableInfo,
2525
SnapshotNameVersion,
26+
MinimalSnapshot,
2627
)
2728
from sqlmesh.core.snapshot.definition import Interval, SnapshotIntervals
2829
from sqlmesh.utils import major_minor
@@ -98,21 +99,21 @@ def get_snapshots(
9899
"""
99100

100101
@abc.abstractmethod
101-
def get_snapshot_ids_by_names(
102+
def get_snapshots_by_names(
102103
self,
103104
snapshot_names: t.Iterable[str],
104105
current_ts: t.Optional[int] = None,
105106
exclude_expired: bool = True,
106-
) -> t.Set[SnapshotId]:
107+
) -> t.Set[MinimalSnapshot]:
107108
"""Return the minimal snapshot records for all versions of the specified snapshot names.
108109
109110
Args:
110-
snapshot_names: Iterable of snapshot names to fetch all snapshot id's for
111+
snapshot_names: Iterable of snapshot names to fetch all snapshots for
111112
current_ts: Sets the current time for identifying which snapshots have expired so they can be excluded (only relevant if :exclude_expired=True)
112113
exclude_expired: Whether or not to exclude expired snapshots from the result
113114
114115
Returns:
115-
A dictionary mapping snapshot names to a list of relevant snapshot id's
116+
A set containing all the matched snapshot records. To fetch full snapshots, pass it into StateSync.get_snapshots()
116117
"""
117118

118119
@abc.abstractmethod

sqlmesh/core/state_sync/db/facade.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from sqlmesh.core.environment import Environment, EnvironmentStatements, EnvironmentSummary
3030
from sqlmesh.core.snapshot import (
3131
Snapshot,
32+
MinimalSnapshot,
3233
SnapshotId,
3334
SnapshotIdLike,
3435
SnapshotInfoLike,
@@ -366,13 +367,13 @@ def get_snapshots(
366367
Snapshot.hydrate_with_intervals_by_version(snapshots.values(), intervals)
367368
return snapshots
368369

369-
def get_snapshot_ids_by_names(
370+
def get_snapshots_by_names(
370371
self,
371372
snapshot_names: t.Iterable[str],
372373
current_ts: t.Optional[int] = None,
373374
exclude_expired: bool = True,
374-
) -> t.Set[SnapshotId]:
375-
return self.snapshot_state.get_snapshot_ids_by_names(
375+
) -> t.Set[MinimalSnapshot]:
376+
return self.snapshot_state.get_snapshots_by_names(
376377
snapshot_names=snapshot_names, current_ts=current_ts, exclude_expired=exclude_expired
377378
)
378379

sqlmesh/core/state_sync/db/snapshot.py

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from pathlib import Path
77
from collections import defaultdict
88
from sqlglot import exp
9-
from pydantic import Field
109

1110
from sqlmesh.core.engine_adapter import EngineAdapter
1211
from sqlmesh.core.state_sync.db.utils import (
@@ -27,12 +26,12 @@
2726
SnapshotNameVersion,
2827
SnapshotInfoLike,
2928
Snapshot,
29+
MinimalSnapshot,
3030
SnapshotId,
3131
SnapshotFingerprint,
3232
)
3333
from sqlmesh.utils.migration import index_text_type, blob_text_type
3434
from sqlmesh.utils.date import now_timestamp, TimeLike, to_timestamp
35-
from sqlmesh.utils.pydantic import PydanticModel
3635
from sqlmesh.utils import unique
3736

3837
if t.TYPE_CHECKING:
@@ -215,7 +214,7 @@ def _get_expired_snapshots(
215214
for snapshot in environment.snapshots
216215
}
217216

218-
def _is_snapshot_used(snapshot: SharedVersionSnapshot) -> bool:
217+
def _is_snapshot_used(snapshot: MinimalSnapshot) -> bool:
219218
return (
220219
snapshot.snapshot_id in promoted_snapshot_ids
221220
or snapshot.snapshot_id not in expired_candidates
@@ -308,12 +307,12 @@ def get_snapshots(
308307
"""
309308
return self._get_snapshots(snapshot_ids)
310309

311-
def get_snapshot_ids_by_names(
310+
def get_snapshots_by_names(
312311
self,
313312
snapshot_names: t.Iterable[str],
314313
current_ts: t.Optional[int] = None,
315314
exclude_expired: bool = True,
316-
) -> t.Set[SnapshotId]:
315+
) -> t.Set[MinimalSnapshot]:
317316
"""Return the minimal snapshot records for all versions of the specified snapshot names.
318317
319318
Args:
@@ -334,14 +333,20 @@ def get_snapshot_ids_by_names(
334333
unexpired_expr = None
335334

336335
return {
337-
SnapshotId(name=name, identifier=identifier)
336+
MinimalSnapshot(
337+
name=name,
338+
identifier=identifier,
339+
version=version,
340+
dev_version=dev_version,
341+
fingerprint=SnapshotFingerprint.parse_raw(fingerprint),
342+
)
338343
for where in snapshot_name_filter(
339344
snapshot_names=snapshot_names,
340345
batch_size=self.SNAPSHOT_BATCH_SIZE,
341346
)
342-
for name, identifier in fetchall(
347+
for name, identifier, version, dev_version, fingerprint in fetchall(
343348
self.engine_adapter,
344-
exp.select("name", "identifier")
349+
exp.select("name", "identifier", "version", "dev_version", "fingerprint")
345350
.from_(self.snapshots_table)
346351
.where(where)
347352
.and_(unexpired_expr),
@@ -631,7 +636,7 @@ def _get_snapshots_with_same_version(
631636
self,
632637
snapshots: t.Collection[SnapshotNameVersionLike],
633638
lock_for_update: bool = False,
634-
) -> t.List[SharedVersionSnapshot]:
639+
) -> t.List[MinimalSnapshot]:
635640
"""Fetches all snapshots that share the same version as the snapshots.
636641
637642
The output includes the snapshots with the specified identifiers.
@@ -668,7 +673,7 @@ def _get_snapshots_with_same_version(
668673
snapshot_rows.extend(fetchall(self.engine_adapter, query))
669674

670675
return [
671-
SharedVersionSnapshot(
676+
MinimalSnapshot(
672677
name=name,
673678
identifier=identifier,
674679
version=version,
@@ -751,23 +756,3 @@ def _auto_restatements_to_df(auto_restatements: t.Dict[SnapshotNameVersion, int]
751756
for name_version, ts in auto_restatements.items()
752757
]
753758
)
754-
755-
756-
class SharedVersionSnapshot(PydanticModel):
757-
"""A stripped down version of a snapshot that is used for fetching snapshots that share the same version
758-
with a significantly reduced parsing overhead.
759-
"""
760-
761-
name: str
762-
version: str
763-
dev_version_: t.Optional[str] = Field(alias="dev_version")
764-
identifier: str
765-
fingerprint: SnapshotFingerprint
766-
767-
@property
768-
def snapshot_id(self) -> SnapshotId:
769-
return SnapshotId(name=self.name, identifier=self.identifier)
770-
771-
@property
772-
def dev_version(self) -> str:
773-
return self.dev_version_ or self.fingerprint.to_version()

tests/core/state_sync/test_state_sync.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3571,10 +3571,10 @@ def test_update_environment_statements(state_sync: EngineAdapterStateSync):
35713571
]
35723572

35733573

3574-
def test_get_snapshot_ids_by_names(
3574+
def test_get_snapshots_by_names(
35753575
state_sync: EngineAdapterStateSync, make_snapshot: t.Callable[..., Snapshot]
35763576
):
3577-
assert state_sync.get_snapshot_ids_by_names(snapshot_names=[]) == set()
3577+
assert state_sync.get_snapshots_by_names(snapshot_names=[]) == set()
35783578

35793579
snap_a_v1, snap_a_v2 = (
35803580
make_snapshot(
@@ -3597,18 +3597,20 @@ def test_get_snapshot_ids_by_names(
35973597

35983598
state_sync.push_snapshots([snap_a_v1, snap_a_v2, snap_b])
35993599

3600-
assert state_sync.get_snapshot_ids_by_names(snapshot_names=['"a"']) == {
3600+
assert {s.snapshot_id for s in state_sync.get_snapshots_by_names(snapshot_names=['"a"'])} == {
36013601
snap_a_v1.snapshot_id,
36023602
snap_a_v2.snapshot_id,
36033603
}
3604-
assert state_sync.get_snapshot_ids_by_names(snapshot_names=['"a"', '"b"']) == {
3604+
assert {
3605+
s.snapshot_id for s in state_sync.get_snapshots_by_names(snapshot_names=['"a"', '"b"'])
3606+
} == {
36053607
snap_a_v1.snapshot_id,
36063608
snap_a_v2.snapshot_id,
36073609
snap_b.snapshot_id,
36083610
}
36093611

36103612

3611-
def test_get_snapshot_ids_by_names_include_expired(
3613+
def test_get_snapshots_by_names_include_expired(
36123614
state_sync: EngineAdapterStateSync, make_snapshot: t.Callable[..., Snapshot]
36133615
):
36143616
now_ts = now_timestamp()
@@ -3635,15 +3637,22 @@ def test_get_snapshot_ids_by_names_include_expired(
36353637

36363638
state_sync.push_snapshots([normal_a, expired_a])
36373639

3638-
assert state_sync.get_snapshot_ids_by_names(snapshot_names=['"a"'], current_ts=now_ts) == {
3639-
normal_a.snapshot_id
3640-
}
3641-
assert state_sync.get_snapshot_ids_by_names(snapshot_names=['"a"'], exclude_expired=False) == {
3640+
assert {
3641+
s.snapshot_id
3642+
for s in state_sync.get_snapshots_by_names(snapshot_names=['"a"'], current_ts=now_ts)
3643+
} == {normal_a.snapshot_id}
3644+
assert {
3645+
s.snapshot_id
3646+
for s in state_sync.get_snapshots_by_names(snapshot_names=['"a"'], exclude_expired=False)
3647+
} == {
36423648
normal_a.snapshot_id,
36433649
expired_a.snapshot_id,
36443650
}
36453651

36463652
# wind back time to 10 seconds ago (before the expired snapshot expired - it expired 5 seconds ago) to test it still shows in a normal query
3647-
assert state_sync.get_snapshot_ids_by_names(
3648-
snapshot_names=['"a"'], current_ts=(now_ts - (10 * 1000))
3649-
) == {normal_a.snapshot_id, expired_a.snapshot_id}
3653+
assert {
3654+
s.snapshot_id
3655+
for s in state_sync.get_snapshots_by_names(
3656+
snapshot_names=['"a"'], current_ts=(now_ts - (10 * 1000))
3657+
)
3658+
} == {normal_a.snapshot_id, expired_a.snapshot_id}

0 commit comments

Comments
 (0)