diff --git a/sqlmesh/core/state_sync/db/snapshot.py b/sqlmesh/core/state_sync/db/snapshot.py index 5b6d96d970..7aaf902216 100644 --- a/sqlmesh/core/state_sync/db/snapshot.py +++ b/sqlmesh/core/state_sync/db/snapshot.py @@ -373,25 +373,31 @@ def update_auto_restatements( Args: next_auto_restatement_ts: A dictionary of snapshot name version to the next auto restatement timestamp. """ + next_auto_restatement_ts_deleted = [] + next_auto_restatement_ts_filtered = {} + for k, v in next_auto_restatement_ts.items(): + if v is None: + next_auto_restatement_ts_deleted.append(k) + else: + next_auto_restatement_ts_filtered[k] = v + for where in snapshot_name_version_filter( self.engine_adapter, - next_auto_restatement_ts, + next_auto_restatement_ts_deleted, column_prefix="snapshot", alias=None, batch_size=self.SNAPSHOT_BATCH_SIZE, ): self.engine_adapter.delete_from(self.auto_restatements_table, where=where) - next_auto_restatement_ts_filtered = { - k: v for k, v in next_auto_restatement_ts.items() if v is not None - } if not next_auto_restatement_ts_filtered: return - self.engine_adapter.insert_append( + self.engine_adapter.merge( self.auto_restatements_table, _auto_restatements_to_df(next_auto_restatement_ts_filtered), columns_to_types=self._auto_restatement_columns_to_types, + unique_key=(exp.column("snapshot_name"), exp.column("snapshot_version")), ) def count(self) -> int: