Skip to content
Merged
7 changes: 7 additions & 0 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ def __init__(
deployability_index: t.Optional[DeployabilityIndex] = None,
default_dialect: t.Optional[str] = None,
default_catalog: t.Optional[str] = None,
is_restatement: t.Optional[bool] = None,
variables: t.Optional[t.Dict[str, t.Any]] = None,
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
):
Expand All @@ -284,6 +285,7 @@ def __init__(
self._default_dialect = default_dialect
self._variables = variables or {}
self._blueprint_variables = blueprint_variables or {}
self._is_restatement = is_restatement

@property
def default_dialect(self) -> t.Optional[str]:
Expand All @@ -308,6 +310,10 @@ def gateway(self) -> t.Optional[str]:
"""Returns the gateway name."""
return self.var(c.GATEWAY)

@property
def is_restatement(self) -> t.Optional[bool]:
    """Returns whether the current evaluation is part of a restatement plan, if known."""
    return self._is_restatement

def var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
"""Returns a variable value."""
return self._variables.get(var_name.lower(), default)
Expand All @@ -328,6 +334,7 @@ def with_variables(
self.deployability_index,
self._default_dialect,
self._default_catalog,
self._is_restatement,
variables=variables,
blueprint_variables=blueprint_variables,
)
Expand Down
4 changes: 4 additions & 0 deletions sqlmesh/core/engine_adapter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ class EngineAdapter:
MAX_IDENTIFIER_LENGTH: t.Optional[int] = None
ATTACH_CORRELATION_ID = True
SUPPORTS_QUERY_EXECUTION_TRACKING = False
SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS = False

def __init__(
self,
Expand Down Expand Up @@ -2927,6 +2928,9 @@ def _check_identifier_length(self, expression: exp.Expression) -> None:
f"Identifier name '{name}' (length {name_length}) exceeds {self.dialect.capitalize()}'s max identifier limit of {self.MAX_IDENTIFIER_LENGTH} characters"
)

def get_table_last_modified_ts(self, table_names: t.List[TableName]) -> t.List[int]:
    """Returns the last-modified timestamps (epoch ms) of the given physical tables.

    Engine adapters that set ``SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS`` to True
    must override this with an engine-specific metadata query; the base adapter
    intentionally provides no implementation.

    Args:
        table_names: The names of the tables to inspect.

    Raises:
        NotImplementedError: Always, in the base adapter.
    """
    raise NotImplementedError()


class EngineAdapterWithIndexSupport(EngineAdapter):
SUPPORTS_INDEXES = True
Expand Down
22 changes: 22 additions & 0 deletions sqlmesh/core/engine_adapter/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,28 @@ def table_exists(self, table_name: TableName) -> bool:
except NotFound:
return False

def get_table_last_modified_ts(self, table_names: t.List[TableName]) -> t.List[int]:
    """Returns the last-modified timestamps (epoch ms) of the given physical tables.

    Queries BigQuery's per-dataset ``__TABLES__`` metadata view, grouping the
    requested tables by dataset so that a single query is issued per dataset.

    Args:
        table_names: The (possibly qualified) names of the tables to inspect.

    Returns:
        A list of epoch-millisecond timestamps. NOTE: results are not guaranteed
        to be in the same order as ``table_names``, and tables absent from the
        metadata view are silently omitted.
    """
    if not table_names:
        # Avoid issuing any metadata queries for an empty request.
        return []

    from sqlmesh.utils.date import to_timestamp

    # Group tables by dataset since each dataset exposes its own __TABLES__ view.
    datasets_to_tables: t.DefaultDict[str, t.List[str]] = defaultdict(list)
    for table_name in table_names:
        table = exp.to_table(table_name)
        datasets_to_tables[table.db].append(table.name)

    results = []
    for dataset, tables in datasets_to_tables.items():
        # A single IN predicate instead of a hand-built chain of ORs.
        table_filter = ", ".join(f"'{name}'" for name in tables)
        query = (
            f"SELECT TIMESTAMP_MILLIS(last_modified_time) FROM `{dataset}.__TABLES__` "
            f"WHERE TABLE_ID IN ({table_filter})"
        )
        results.extend(self.fetchall(query))

    return [to_timestamp(row[0]) for row in results]

def _get_table(self, table_name: TableName) -> BigQueryTable:
"""
Returns a BigQueryTable object for the given table name.
Expand Down
16 changes: 16 additions & 0 deletions sqlmesh/core/engine_adapter/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class SnowflakeEngineAdapter(GetCurrentCatalogFromFunctionMixin, ClusteredByMixi
SUPPORTS_MANAGED_MODELS = True
CURRENT_CATALOG_EXPRESSION = exp.func("current_database")
SUPPORTS_CREATE_DROP_CATALOG = True
SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS = True
SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["DATABASE", "SCHEMA", "TABLE"]
SCHEMA_DIFFER_KWARGS = {
"parameterized_type_defaults": {
Expand Down Expand Up @@ -669,3 +670,18 @@ def close(self) -> t.Any:
self._connection_pool.set_attribute(self.SNOWPARK, None)

return super().close()

def get_table_last_modified_ts(self, table_names: t.List[TableName]) -> t.List[int]:
    """Returns the LAST_ALTERED timestamps (epoch ms) of the given physical tables.

    Queries Snowflake's ``INFORMATION_SCHEMA.TABLES`` view, matching each table
    on its name, schema and catalog.

    Args:
        table_names: The (possibly qualified) names of the tables to inspect.

    Returns:
        A list of epoch-millisecond timestamps. NOTE: results are not guaranteed
        to be in the same order as ``table_names``, and tables absent from the
        information schema are silently omitted.
    """
    if not table_names:
        # Without this guard an empty input produced the malformed query
        # "SELECT ... WHERE" with no predicates.
        return []

    from sqlmesh.utils.date import to_timestamp

    predicates = []
    for table_name in table_names:
        table = exp.to_table(table_name)
        predicates.append(
            f"(TABLE_NAME = '{table.name}' AND TABLE_SCHEMA = '{table.db}' AND TABLE_CATALOG = '{table.catalog}')"
        )
    query = "SELECT LAST_ALTERED FROM INFORMATION_SCHEMA.TABLES WHERE " + " OR ".join(
        predicates
    )

    return [to_timestamp(row[0]) for row in self.fetchall(query)]
1 change: 1 addition & 0 deletions sqlmesh/core/plan/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla
allow_additive_snapshots=plan.allow_additive_models,
selected_snapshot_ids=stage.selected_snapshot_ids,
selected_models=plan.selected_models,
is_restatement=bool(plan.restatements),
)
if errors:
raise PlanError("Plan application failed.")
Expand Down
15 changes: 12 additions & 3 deletions sqlmesh/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,9 @@ def evaluate(
**kwargs,
)

self.state_sync.add_interval(snapshot, start, end, is_dev=not is_deployable)
self.state_sync.add_interval(
snapshot, start, end, is_dev=not is_deployable, last_altered_ts=now_timestamp()
)
return audit_results

def run(
Expand Down Expand Up @@ -335,6 +337,7 @@ def batch_intervals(
deployability_index: t.Optional[DeployabilityIndex],
environment_naming_info: EnvironmentNamingInfo,
dag: t.Optional[DAG[SnapshotId]] = None,
is_restatement: bool = False,
) -> t.Dict[Snapshot, Intervals]:
dag = dag or snapshots_to_dag(merged_intervals)

Expand Down Expand Up @@ -367,6 +370,7 @@ def batch_intervals(
deployability_index,
default_dialect=adapter.dialect,
default_catalog=self.default_catalog,
is_restatement=is_restatement,
)

intervals = self._check_ready_intervals(
Expand Down Expand Up @@ -422,6 +426,7 @@ def run_merged_intervals(
run_environment_statements: bool = False,
audit_only: bool = False,
auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {},
is_restatement: bool = False,
) -> t.Tuple[t.List[NodeExecutionFailedError[SchedulingUnit]], t.List[SchedulingUnit]]:
"""Runs precomputed batches of missing intervals.

Expand Down Expand Up @@ -455,9 +460,12 @@ def run_merged_intervals(
snapshot_dag = full_dag.subdag(*selected_snapshot_ids_set)

batched_intervals = self.batch_intervals(
merged_intervals, deployability_index, environment_naming_info, dag=snapshot_dag
merged_intervals,
deployability_index,
environment_naming_info,
dag=snapshot_dag,
is_restatement=is_restatement,
)

self.console.start_evaluation_progress(
batched_intervals,
environment_naming_info,
Expand Down Expand Up @@ -956,6 +964,7 @@ def _check_ready_intervals(
python_env=signals.python_env,
dialect=snapshot.model.dialect,
path=snapshot.model._path,
snapshot=snapshot,
kwargs=kwargs,
)
except SQLMeshError as e:
Expand Down
44 changes: 43 additions & 1 deletion sqlmesh/core/signal.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from __future__ import annotations


import typing as t
from sqlmesh.utils import UniqueKeyDict, registry_decorator

if t.TYPE_CHECKING:
from sqlmesh.core.context import ExecutionContext
from sqlmesh.core.snapshot.definition import Snapshot
from sqlmesh.utils.date import DatetimeRanges
from sqlmesh.core.snapshot.definition import DeployabilityIndex


class signal(registry_decorator):
"""Specifies a function which intervals are ready from a list of scheduled intervals.
Expand Down Expand Up @@ -33,3 +39,39 @@ class signal(registry_decorator):


SignalRegistry = UniqueKeyDict[str, signal]


@signal()
def freshness(batch: DatetimeRanges, snapshot: Snapshot, context: ExecutionContext) -> bool:
    """Built-in signal that skips evaluation when no upstream table has new data.

    Compares the snapshot's physical-table last-altered timestamp against the
    last-modified timestamps of its (external) upstream tables, as reported by
    the engine's metadata. Returns True (evaluate the batch) whenever freshness
    cannot be determined; returns False only when it is certain no upstream
    dependency has been modified since the model was last evaluated.
    """
    adapter = context.engine_adapter
    # Restatements must always run; likewise, if the engine can't report table
    # last-modified timestamps there is nothing to compare against.
    if context.is_restatement or not adapter.SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS:
        return True

    deployability_index = context.deployability_index or DeployabilityIndex.all_deployable()

    last_altered_ts = (
        snapshot.last_altered_ts
        if deployability_index.is_deployable(snapshot)
        else snapshot.dev_last_altered_ts
    )
    if not last_altered_ts:
        # The model's table timestamp is unknown (e.g. never evaluated) - run it.
        return True

    # Use .get() so a parent that is missing from the context doesn't raise a
    # KeyError before the mismatch check below has a chance to handle it.
    parent_snapshots = {
        parent_snapshot
        for p in snapshot.parents
        if (parent_snapshot := context.snapshots.get(p.name)) is not None
    }
    if len(parent_snapshots) != len(snapshot.node.depends_on) or not all(
        p.is_external for p in parent_snapshots
    ):
        # The mismatch can happen if e.g. an external model is not registered in the project
        return True

    # Finding new data means that the upstream dependencies have been altered
    # since the last time the model was evaluated
    upstream_dep_has_new_data = any(
        upstream_last_altered_ts > last_altered_ts
        for upstream_last_altered_ts in adapter.get_table_last_modified_ts(
            [p.name for p in parent_snapshots]
        )
    )

    # Returning true is a no-op, returning False nullifies the batch so the model will not be evaluated.
    return upstream_dep_has_new_data
31 changes: 31 additions & 0 deletions sqlmesh/core/snapshot/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ class SnapshotIntervals(PydanticModel):
intervals: Intervals = []
dev_intervals: Intervals = []
pending_restatement_intervals: Intervals = []
last_altered_ts: t.Optional[int] = None
dev_last_altered_ts: t.Optional[int] = None

@property
def snapshot_id(self) -> t.Optional[SnapshotId]:
Expand All @@ -205,6 +207,12 @@ def add_dev_interval(self, start: int, end: int) -> None:
def add_pending_restatement_interval(self, start: int, end: int) -> None:
self._add_interval(start, end, "pending_restatement_intervals")

def update_last_altered_ts(self, last_altered_ts: t.Optional[int]) -> None:
    """Records the physical table's last-modified timestamp (epoch ms), keeping the max seen."""
    self._update_last_altered_ts(last_altered_ts, "last_altered_ts")

def update_dev_last_altered_ts(self, last_altered_ts: t.Optional[int]) -> None:
    """Records the dev table's last-modified timestamp (epoch ms), keeping the max seen."""
    self._update_last_altered_ts(last_altered_ts, "dev_last_altered_ts")

def remove_interval(self, start: int, end: int) -> None:
self._remove_interval(start, end, "intervals")

Expand All @@ -224,6 +232,13 @@ def _add_interval(self, start: int, end: int, interval_attr: str) -> None:
target_intervals = merge_intervals([*target_intervals, (start, end)])
setattr(self, interval_attr, target_intervals)

def _update_last_altered_ts(
    self, last_altered_ts: t.Optional[int], last_altered_attr: str
) -> None:
    """Advances the given last-altered attribute to ``last_altered_ts`` if it is newer.

    A falsy timestamp (``None`` or ``0``) is ignored and leaves the stored
    value untouched; an unset stored value is treated as ``0``.
    """
    if not last_altered_ts:
        return
    baseline = getattr(self, last_altered_attr) or 0
    setattr(
        self,
        last_altered_attr,
        last_altered_ts if last_altered_ts > baseline else baseline,
    )

def _remove_interval(self, start: int, end: int, interval_attr: str) -> None:
target_intervals = getattr(self, interval_attr)
target_intervals = remove_interval(target_intervals, start, end)
Expand Down Expand Up @@ -713,6 +728,10 @@ class Snapshot(PydanticModel, SnapshotInfoMixin):
dev_table_suffix: str = "dev"
table_naming_convention: TableNamingConvention = TableNamingConvention.default
forward_only: bool = False
# Physical table last modified timestamp, not to be confused with the "updated_ts" field
# which is for the snapshot record itself
last_altered_ts: t.Optional[int] = None
dev_last_altered_ts: t.Optional[int] = None

@field_validator("ttl")
@classmethod
Expand Down Expand Up @@ -751,6 +770,7 @@ def hydrate_with_intervals_by_version(
)
for interval in snapshot_intervals:
snapshot.merge_intervals(interval)

result.append(snapshot)

return result
Expand Down Expand Up @@ -957,12 +977,20 @@ def merge_intervals(self, other: t.Union[Snapshot, SnapshotIntervals]) -> None:
if not apply_effective_from or end <= effective_from_ts:
self.add_interval(start, end)

if other.last_altered_ts:
self.last_altered_ts = max(self.last_altered_ts or 0, other.last_altered_ts)

if self.dev_version == other.dev_version:
# Merge dev intervals if the dev versions match which would mean
# that this and the other snapshot are pointing to the same dev table.
for start, end in other.dev_intervals:
self.add_interval(start, end, is_dev=True)

if other.dev_last_altered_ts:
self.dev_last_altered_ts = max(
self.dev_last_altered_ts or 0, other.dev_last_altered_ts
)

self.pending_restatement_intervals = merge_intervals(
[*self.pending_restatement_intervals, *other.pending_restatement_intervals]
)
Expand Down Expand Up @@ -1081,6 +1109,7 @@ def check_ready_intervals(
python_env=signals.python_env,
dialect=self.model.dialect,
path=self.model._path,
snapshot=self,
kwargs=kwargs,
)
except SQLMeshError as e:
Expand Down Expand Up @@ -2421,6 +2450,7 @@ def check_ready_intervals(
python_env: t.Dict[str, Executable],
dialect: DialectType = None,
path: t.Optional[Path] = None,
snapshot: t.Optional[Snapshot] = None,
kwargs: t.Optional[t.Dict] = None,
) -> Intervals:
checked_intervals: Intervals = []
Expand All @@ -2436,6 +2466,7 @@ def check_ready_intervals(
provided_args=(batch,),
provided_kwargs=(kwargs or {}),
context=context,
snapshot=snapshot,
)
except Exception as ex:
raise SignalEvalError(format_evaluated_code_exception(ex, python_env))
Expand Down
4 changes: 4 additions & 0 deletions sqlmesh/core/state_sync/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@ def add_interval(
start: TimeLike,
end: TimeLike,
is_dev: bool = False,
last_altered_ts: t.Optional[int] = None,
) -> None:
"""Add an interval to a snapshot and sync it to the store.

Expand All @@ -504,6 +505,7 @@ def add_interval(
start: The start of the interval to add.
end: The end of the interval to add.
is_dev: Indicates whether the given interval is being added while in development mode
last_altered_ts: The timestamp of the last modification of the physical table
"""
start_ts, end_ts = snapshot.inclusive_exclusive(start, end, strict=False, expand=False)
if not snapshot.version:
Expand All @@ -516,6 +518,8 @@ def add_interval(
dev_version=snapshot.dev_version,
intervals=intervals if not is_dev else [],
dev_intervals=intervals if is_dev else [],
last_altered_ts=last_altered_ts if not is_dev else None,
dev_last_altered_ts=last_altered_ts if is_dev else None,
)
self.add_snapshots_intervals([snapshot_intervals])

Expand Down
3 changes: 2 additions & 1 deletion sqlmesh/core/state_sync/db/facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,8 +381,9 @@ def add_interval(
start: TimeLike,
end: TimeLike,
is_dev: bool = False,
last_altered_ts: t.Optional[int] = None,
) -> None:
super().add_interval(snapshot, start, end, is_dev)
super().add_interval(snapshot, start, end, is_dev, last_altered_ts)

@transactional()
def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotIntervals]) -> None:
Expand Down
Loading