Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@
from sqlmesh import CustomMaterialization
from sqlmesh.core.model import Model
from sqlmesh.core.model.kind import TimeColumn
import sqlmesh.core.dialect as d
from sqlglot import exp
from sqlmesh.utils.date import make_inclusive
from sqlmesh.utils.errors import ConfigError, SQLMeshError
from pydantic import model_validator
from sqlmesh.utils.pydantic import list_of_fields_validator, bool_validator
from sqlmesh.utils.pydantic import (
bool_validator,
list_of_fields_validator,
validate_expression,
)
from sqlmesh.utils.date import TimeLike
from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS
from sqlmesh import CustomKind
Expand All @@ -22,6 +27,26 @@ class NonIdempotentIncrementalByTimeRangeKind(CustomKind):
_primary_key: t.List[exp.Expression]

_partition_by_time_column: bool
_when_matched: t.Optional[exp.Whens]

def _parse_when_matched(self, value: t.Any) -> t.Optional[exp.Whens]:
    """Normalize the ``when_matched`` materialization property into ``exp.Whens``.

    Accepts:
      - ``None`` — the property is unset; returns ``None``.
      - a list of string fragments — joined with spaces and treated as one SQL string.
      - a raw SQL string, optionally wrapped in one pair of parentheses.
      - an already-parsed expression — validated as-is.

    The result has its MERGE source/target table aliases rewritten via
    ``d.replace_merge_table_aliases`` so clauses written against
    ``source``/``target`` resolve to the engine adapter's merge aliases.
    """
    if value is None:
        return None

    # Multi-line model DDL may surface the clause as a list of fragments.
    if isinstance(value, list):
        value = " ".join(value)

    if isinstance(value, str):
        value = value.strip()
        # Strip only a *balanced* outer wrapper. The previous check tested
        # startswith("(") alone but sliced value[1:-1] unconditionally, which
        # silently dropped the final character of unbalanced input such as
        # "(when matched ..." and produced a confusing parse error.
        if value.startswith("(") and value.endswith(")"):
            value = value[1:-1]
        value = t.cast(exp.Whens, d.parse_one(value, into=exp.Whens, dialect=self.dialect))

    value = validate_expression(value, dialect=self.dialect)
    return t.cast(
        exp.Whens,
        value.transform(d.replace_merge_table_aliases, dialect=self.dialect),
    )

@model_validator(mode="after")
def _validate_model(self):
Expand Down Expand Up @@ -49,6 +74,10 @@ def _validate_model(self):
self.materialization_properties.get("partition_by_time_column", True)
)

self._when_matched = self._parse_when_matched(
self.materialization_properties.get("when_matched")
)

return self

@property
Expand All @@ -63,6 +92,10 @@ def primary_key(self) -> t.List[exp.Expression]:
def partition_by_time_column(self) -> bool:
return self._partition_by_time_column

@property
def when_matched(self) -> t.Optional[exp.Whens]:
    """Parsed WHEN MATCHED clauses for the MERGE, or None when not configured."""
    # Populated by the model validator from the "when_matched" materialization
    # property; exposed read-only so callers cannot bypass parsing/validation.
    return self._when_matched


class NonIdempotentIncrementalByTimeRangeMaterialization(
CustomMaterialization[NonIdempotentIncrementalByTimeRangeKind]
Expand Down Expand Up @@ -130,6 +163,7 @@ def _inject_alias(node: exp.Expression, alias: str) -> exp.Expression:
source_table=query_or_df,
target_columns_to_types=columns_to_types,
unique_key=model.kind.primary_key,
when_matched=model.kind.when_matched,
merge_filter=exp.and_(*betweens),
source_columns=source_columns,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,16 @@ def test_kind(make_model: ModelMaker):
exp.to_column("id", quoted=True),
exp.to_column("ds", quoted=True),
]
assert model.kind.when_matched is None

model = make_model(
[
"time_column = ds",
"primary_key = (id, ds)",
"when_matched = 'when matched then update set target.name = source.name'",
]
)
assert model.kind.when_matched is not None

# required fields
with pytest.raises(ConfigError, match=r"Invalid time_column"):
Expand Down Expand Up @@ -165,6 +175,83 @@ def test_append(make_model: ModelMaker, make_mocked_engine_adapter: MockedEngine
]


def test_insert_with_when_matched(
    make_model: ModelMaker, make_mocked_engine_adapter: MockedEngineAdapterMaker
):
    """A configured when_matched clause is forwarded into the generated MERGE."""
    model: Model = make_model(
        [
            "time_column = ds",
            "primary_key = name",
            "when_matched = 'when matched then update set target.name = source.name'",
        ],
        dialect="trino",
    )
    adapter = make_mocked_engine_adapter(TrinoEngineAdapter)
    materialization = NonIdempotentIncrementalByTimeRangeMaterialization(adapter)

    interval_start = to_timestamp("2020-01-01")
    interval_end = to_timestamp("2020-01-03")

    rendered_query = model.render_query(
        start=interval_start,
        end=interval_end,
        execution_time=now(),
        runtime_stage=RuntimeStage.EVALUATING,
    )
    materialization.insert(
        "test.snapshot_table",
        query_or_df=rendered_query,
        model=model,
        is_first_insert=False,
        start=interval_start,
        end=interval_end,
        render_kwargs={},
    )

    # Round-trip the expected SQL through sqlglot so formatting is normalized
    # the same way as the captured calls.
    expected_merge = """
        MERGE INTO "test"."snapshot_table" AS "__merge_target__"
        USING (
        SELECT
        CAST("name" AS VARCHAR) AS "name",
        CAST("ds" AS TIMESTAMP) AS "ds"
        FROM "upstream"."table" AS "table"
        WHERE
        "ds" BETWEEN '2020-01-01 00:00:00' AND '2020-01-02 23:59:59.999999'
        ) AS "__MERGE_SOURCE__"
        ON (
        "__MERGE_SOURCE__"."ds" BETWEEN CAST('2020-01-01 00:00:00' AS TIMESTAMP) AND CAST('2020-01-02 23:59:59.999999' AS TIMESTAMP)
        AND "__MERGE_TARGET__"."ds" BETWEEN CAST('2020-01-01 00:00:00' AS TIMESTAMP) AND CAST('2020-01-02 23:59:59.999999' AS TIMESTAMP)
        )
        AND "__MERGE_TARGET__"."name" = "__MERGE_SOURCE__"."name"
        WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."name" = "__MERGE_SOURCE__"."name"
        WHEN NOT MATCHED THEN INSERT ("name", "ds") VALUES ("__MERGE_SOURCE__"."name", "__MERGE_SOURCE__"."ds")
    """
    expected_sql = parse_one(expected_merge, dialect=adapter.dialect).sql(
        dialect=adapter.dialect
    )
    assert to_sql_calls(adapter) == [expected_sql]


def test_when_matched_multiple_clauses(make_model: ModelMaker):
    """Two chained WHEN MATCHED branches parse into two clause expressions."""
    model_properties = [
        "time_column = ds",
        "primary_key = (id, ds)",
        # Adjacent-string concatenation: the runtime value is identical to the
        # single-literal form.
        (
            "when_matched = 'when matched and source.name is null then delete"
            " when matched then update set target.name = source.name'"
        ),
    ]
    model = make_model(model_properties)
    whens = model.kind.when_matched
    assert whens is not None
    assert len(whens.expressions) == 2


def test_when_matched_invalid_syntax(make_model: ModelMaker):
    """Unparseable when_matched SQL fails model creation rather than being ignored."""
    bad_properties = [
        "time_column = ds",
        "primary_key = (id, ds)",
        "when_matched = 'this is not valid sql'",
    ]
    # NOTE(review): Exception is broad — presumably a sqlglot parse error;
    # consider pinning the concrete type once confirmed.
    with pytest.raises(Exception):
        make_model(bad_properties)


def test_partition_by_time_column_opt_out(make_model: ModelMaker):
model = make_model(
["time_column = ds", "primary_key = name", "partition_by_time_column = false"]
Expand Down