Skip to content

Commit 2636f4f

Browse files
committed
Feat(state_sync): Add the ability to fetch all versions of a snapshot by name
1 parent cef3859 commit 2636f4f

File tree

4 files changed

+146
-0
lines changed

4 files changed

+146
-0
lines changed

sqlmesh/core/state_sync/base.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,24 @@ def get_snapshots(
9797
A dictionary of snapshot ids to snapshots for ones that could be found.
9898
"""
9999

100+
@abc.abstractmethod
101+
def get_snapshot_ids_by_names(
102+
self,
103+
snapshot_names: t.Iterable[str],
104+
current_ts: t.Optional[int] = None,
105+
exclude_expired: bool = True,
106+
) -> t.Set[SnapshotId]:
107+
"""Return the snapshot id's for all versions of the specified snapshot names.
108+
109+
Args:
110+
snapshot_names: Iterable of snapshot names to fetch all snapshot id's for
111+
current_ts: Sets the current time for identifying which snapshots have expired so they can be excluded (only relevant if :exclude_expired=True)
112+
exclude_expired: Whether or not to return the snapshot id's of expired snapshots in the result
113+
114+
Returns:
115+
A dictionary mapping snapshot names to a list of relevant snapshot id's
116+
"""
117+
100118
@abc.abstractmethod
101119
def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[SnapshotId]:
102120
"""Checks if multiple snapshots exist in the state sync.

sqlmesh/core/state_sync/db/facade.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,16 @@ def get_snapshots(
366366
Snapshot.hydrate_with_intervals_by_version(snapshots.values(), intervals)
367367
return snapshots
368368

369+
def get_snapshot_ids_by_names(
370+
self,
371+
snapshot_names: t.Iterable[str],
372+
current_ts: t.Optional[int] = None,
373+
exclude_expired: bool = True,
374+
) -> t.Set[SnapshotId]:
375+
return self.snapshot_state.get_snapshot_ids_by_names(
376+
snapshot_names=snapshot_names, current_ts=current_ts, exclude_expired=exclude_expired
377+
)
378+
369379
@transactional()
370380
def add_interval(
371381
self,

sqlmesh/core/state_sync/db/snapshot.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,46 @@ def get_snapshots(
308308
"""
309309
return self._get_snapshots(snapshot_ids)
310310

311+
def get_snapshot_ids_by_names(
312+
self,
313+
snapshot_names: t.Iterable[str],
314+
current_ts: t.Optional[int] = None,
315+
exclude_expired: bool = True,
316+
) -> t.Set[SnapshotId]:
317+
"""Return the snapshot id's for all versions of the specified snapshot names.
318+
319+
Args:
320+
snapshot_names: Iterable of snapshot names to fetch all snapshot id's for
321+
current_ts: Sets the current time for identifying which snapshots have expired so they can be excluded (only relevant if :exclude_expired=True)
322+
exclude_expired: Whether or not to return the snapshot id's of expired snapshots in the result
323+
324+
Returns:
325+
A dictionary mapping snapshot names to a list of relevant snapshot id's
326+
"""
327+
if not snapshot_names:
328+
return set()
329+
330+
if exclude_expired:
331+
current_ts = current_ts or now_timestamp()
332+
unexpired_expr = (exp.column("updated_ts") + exp.column("ttl_ms")) > current_ts
333+
else:
334+
unexpired_expr = None
335+
336+
return {
337+
SnapshotId(name=name, identifier=identifier)
338+
for where in snapshot_name_filter(
339+
snapshot_names=snapshot_names,
340+
batch_size=self.SNAPSHOT_BATCH_SIZE,
341+
)
342+
for name, identifier in fetchall(
343+
self.engine_adapter,
344+
exp.select("name", "identifier")
345+
.from_(self.snapshots_table)
346+
.where(where)
347+
.and_(unexpired_expr),
348+
)
349+
}
350+
311351
def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[SnapshotId]:
312352
"""Checks if snapshots exist.
313353

tests/core/state_sync/test_state_sync.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3569,3 +3569,81 @@ def test_update_environment_statements(state_sync: EngineAdapterStateSync):
35693569
"@grant_schema_usage()",
35703570
"@grant_select_privileges()",
35713571
]
3572+
3573+
3574+
def test_get_snapshot_ids_by_names(
3575+
state_sync: EngineAdapterStateSync, make_snapshot: t.Callable[..., Snapshot]
3576+
):
3577+
assert state_sync.get_snapshot_ids_by_names(snapshot_names=[]) == set()
3578+
3579+
snap_a_v1, snap_a_v2 = (
3580+
make_snapshot(
3581+
SqlModel(
3582+
name="a",
3583+
query=parse_one(f"select {i}, ds"),
3584+
),
3585+
version="a",
3586+
)
3587+
for i in range(2)
3588+
)
3589+
3590+
snap_b = make_snapshot(
3591+
SqlModel(
3592+
name="b",
3593+
query=parse_one(f"select 'b' as b, ds"),
3594+
),
3595+
version="b",
3596+
)
3597+
3598+
state_sync.push_snapshots([snap_a_v1, snap_a_v2, snap_b])
3599+
3600+
assert state_sync.get_snapshot_ids_by_names(snapshot_names=['"a"']) == {
3601+
snap_a_v1.snapshot_id,
3602+
snap_a_v2.snapshot_id,
3603+
}
3604+
assert state_sync.get_snapshot_ids_by_names(snapshot_names=['"a"', '"b"']) == {
3605+
snap_a_v1.snapshot_id,
3606+
snap_a_v2.snapshot_id,
3607+
snap_b.snapshot_id,
3608+
}
3609+
3610+
3611+
def test_get_snapshot_ids_by_names_include_expired(
3612+
state_sync: EngineAdapterStateSync, make_snapshot: t.Callable[..., Snapshot]
3613+
):
3614+
now_ts = now_timestamp()
3615+
3616+
normal_a = make_snapshot(
3617+
SqlModel(
3618+
name="a",
3619+
query=parse_one(f"select 1, ds"),
3620+
),
3621+
version="a",
3622+
)
3623+
3624+
expired_a = make_snapshot(
3625+
SqlModel(
3626+
name="a",
3627+
query=parse_one(f"select 2, ds"),
3628+
),
3629+
version="a",
3630+
ttl="in 10 seconds",
3631+
)
3632+
expired_a.updated_ts = now_ts - (
3633+
1000 * 15
3634+
) # last updated 15 seconds ago, expired 10 seconds from last updated = expired 5 seconds ago
3635+
3636+
state_sync.push_snapshots([normal_a, expired_a])
3637+
3638+
assert state_sync.get_snapshot_ids_by_names(snapshot_names=['"a"'], current_ts=now_ts) == {
3639+
normal_a.snapshot_id
3640+
}
3641+
assert state_sync.get_snapshot_ids_by_names(snapshot_names=['"a"'], exclude_expired=False) == {
3642+
normal_a.snapshot_id,
3643+
expired_a.snapshot_id,
3644+
}
3645+
3646+
# wind back time to 10 seconds ago (before the expired snapshot is expired - it expired 5 seconds ago) to test it stil shows in a normal query
3647+
assert state_sync.get_snapshot_ids_by_names(
3648+
snapshot_names=['"a"'], current_ts=(now_ts - (10 * 1000))
3649+
) == {normal_a.snapshot_id, expired_a.snapshot_id}

0 commit comments

Comments
 (0)