From bd005ed866e77d54fd8f16edee1da100ed0f21ef Mon Sep 17 00:00:00 2001 From: 0xRob <83790096+0xRobin@users.noreply.github.com> Date: Mon, 16 Mar 2026 10:52:20 +0100 Subject: [PATCH] add when_matched support to custom time-range materialization Parse and validate optional materialization_properties.when_matched and pass it through to adapter.merge so update clauses can preserve immutable columns like _inserted_at. --- ...on_idempotent_incremental_by_time_range.py | 36 +++++++- ...on_idempotent_incremental_by_time_range.py | 87 +++++++++++++++++++ 2 files changed, 122 insertions(+), 1 deletion(-) diff --git a/sqlmesh_utils/materializations/non_idempotent_incremental_by_time_range.py b/sqlmesh_utils/materializations/non_idempotent_incremental_by_time_range.py index 515ea68..04d5df9 100644 --- a/sqlmesh_utils/materializations/non_idempotent_incremental_by_time_range.py +++ b/sqlmesh_utils/materializations/non_idempotent_incremental_by_time_range.py @@ -3,11 +3,16 @@ from sqlmesh import CustomMaterialization from sqlmesh.core.model import Model from sqlmesh.core.model.kind import TimeColumn +import sqlmesh.core.dialect as d from sqlglot import exp from sqlmesh.utils.date import make_inclusive from sqlmesh.utils.errors import ConfigError, SQLMeshError from pydantic import model_validator -from sqlmesh.utils.pydantic import list_of_fields_validator, bool_validator +from sqlmesh.utils.pydantic import ( + bool_validator, + list_of_fields_validator, + validate_expression, +) from sqlmesh.utils.date import TimeLike from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS from sqlmesh import CustomKind @@ -22,6 +27,26 @@ class NonIdempotentIncrementalByTimeRangeKind(CustomKind): _primary_key: t.List[exp.Expression] _partition_by_time_column: bool + _when_matched: t.Optional[exp.Whens] + + def _parse_when_matched(self, value: t.Any) -> t.Optional[exp.Whens]: + if value is None: + return None + + if isinstance(value, list): + value = " ".join(value) + + if isinstance(value, str): + value = value.strip() + if value.startswith("("): + value = value[1:-1] + value = t.cast(exp.Whens, d.parse_one(value, into=exp.Whens, dialect=self.dialect)) + + value = validate_expression(value, dialect=self.dialect) + return t.cast( + exp.Whens, + value.transform(d.replace_merge_table_aliases, dialect=self.dialect), + ) @model_validator(mode="after") def _validate_model(self): @@ -49,6 +74,10 @@ def _validate_model(self): self.materialization_properties.get("partition_by_time_column", True) ) + self._when_matched = self._parse_when_matched( + self.materialization_properties.get("when_matched") + ) + return self @property @@ -63,6 +92,10 @@ def primary_key(self) -> t.List[exp.Expression]: def partition_by_time_column(self) -> bool: return self._partition_by_time_column + @property + def when_matched(self) -> t.Optional[exp.Whens]: + return self._when_matched + class NonIdempotentIncrementalByTimeRangeMaterialization( CustomMaterialization[NonIdempotentIncrementalByTimeRangeKind] @@ -130,6 +163,7 @@ def _inject_alias(node: exp.Expression, alias: str) -> exp.Expression: source_table=query_or_df, target_columns_to_types=columns_to_types, unique_key=model.kind.primary_key, + when_matched=model.kind.when_matched, merge_filter=exp.and_(*betweens), source_columns=source_columns, ) diff --git a/tests/materializations/test_non_idempotent_incremental_by_time_range.py b/tests/materializations/test_non_idempotent_incremental_by_time_range.py index 3586524..8d7e3ff 100644 --- a/tests/materializations/test_non_idempotent_incremental_by_time_range.py +++ b/tests/materializations/test_non_idempotent_incremental_by_time_range.py @@ -59,6 +59,16 @@ def test_kind(make_model: ModelMaker): exp.to_column("id", quoted=True), exp.to_column("ds", quoted=True), ] + assert model.kind.when_matched is None + + model = make_model( + [ + "time_column = ds", + "primary_key = (id, ds)", + "when_matched = 'when matched then update set target.name = source.name'", + ] + ) + assert model.kind.when_matched is not None # required fields with pytest.raises(ConfigError, match=r"Invalid time_column"): @@ -165,6 +175,83 @@ def test_append(make_model: ModelMaker, make_mocked_engine_adapter: MockedEngine ] +def test_insert_with_when_matched( + make_model: ModelMaker, make_mocked_engine_adapter: MockedEngineAdapterMaker +): + model: Model = make_model( + [ + "time_column = ds", + "primary_key = name", + "when_matched = 'when matched then update set target.name = source.name'", + ], + dialect="trino", + ) + adapter = make_mocked_engine_adapter(TrinoEngineAdapter) + strategy = NonIdempotentIncrementalByTimeRangeMaterialization(adapter) + + start = to_timestamp("2020-01-01") + end = to_timestamp("2020-01-03") + + strategy.insert( + "test.snapshot_table", + query_or_df=model.render_query( + start=start, end=end, execution_time=now(), runtime_stage=RuntimeStage.EVALUATING + ), + model=model, + is_first_insert=False, + start=start, + end=end, + render_kwargs={}, + ) + + assert to_sql_calls(adapter) == [ + parse_one( + """ + MERGE INTO "test"."snapshot_table" AS "__merge_target__" + USING ( + SELECT + CAST("name" AS VARCHAR) AS "name", + CAST("ds" AS TIMESTAMP) AS "ds" + FROM "upstream"."table" AS "table" + WHERE + "ds" BETWEEN '2020-01-01 00:00:00' AND '2020-01-02 23:59:59.999999' + ) AS "__MERGE_SOURCE__" + ON ( + "__MERGE_SOURCE__"."ds" BETWEEN CAST('2020-01-01 00:00:00' AS TIMESTAMP) AND CAST('2020-01-02 23:59:59.999999' AS TIMESTAMP) + AND "__MERGE_TARGET__"."ds" BETWEEN CAST('2020-01-01 00:00:00' AS TIMESTAMP) AND CAST('2020-01-02 23:59:59.999999' AS TIMESTAMP) + ) + AND "__MERGE_TARGET__"."name" = "__MERGE_SOURCE__"."name" + WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."name" = "__MERGE_SOURCE__"."name" + WHEN NOT MATCHED THEN INSERT ("name", "ds") VALUES ("__MERGE_SOURCE__"."name", "__MERGE_SOURCE__"."ds") + """, + dialect=adapter.dialect, + ).sql(dialect=adapter.dialect), + ] + + +def test_when_matched_multiple_clauses(make_model: ModelMaker): + model = make_model( + [ + "time_column = ds", + "primary_key = (id, ds)", + "when_matched = 'when matched and source.name is null then delete when matched then update set target.name = source.name'", + ] + ) + assert model.kind.when_matched is not None + assert len(model.kind.when_matched.expressions) == 2 + + +def test_when_matched_invalid_syntax(make_model: ModelMaker): + with pytest.raises(Exception): + make_model( + [ + "time_column = ds", + "primary_key = (id, ds)", + "when_matched = 'this is not valid sql'", + ] + ) + + def test_partition_by_time_column_opt_out(make_model: ModelMaker): model = make_model( ["time_column = ds", "primary_key = name", "partition_by_time_column = false"]