Skip to content

Commit 431d319

Browse files
authored
feat: support multiple when matched inc unique key (#3124)
1 parent aa41dee commit 431d319

File tree

9 files changed

+272
-34
lines changed

9 files changed

+272
-34
lines changed

docs/concepts/models/model_kinds.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,19 @@ MODEL (
259259

260260
The `source` and `target` aliases are required when using the `when_matched` expression in order to distinguish between the source and target columns.
261261

262+
Multiple `WHEN MATCHED` expressions can also be provided, separated by commas. For example:
263+
264+
```sql linenums="1" hl_lines="5-6"
265+
MODEL (
266+
name db.employees,
267+
kind INCREMENTAL_BY_UNIQUE_KEY (
268+
unique_key name,
269+
when_matched WHEN MATCHED AND source.value IS NULL THEN UPDATE SET target.salary = COALESCE(source.salary, target.salary),
270+
WHEN MATCHED THEN UPDATE SET target.title = COALESCE(source.title, target.title)
271+
)
272+
);
273+
```
274+
262275
**Note**: `when_matched` is only available on engines that support the `MERGE` statement. Currently supported engines include:
263276

264277
* BigQuery

sqlmesh/core/dialect.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,9 @@ def _parse_props(self: Parser) -> t.Optional[exp.Expression]:
398398

399399
name = key.name.lower()
400400
if name == "when_matched":
401-
value: t.Optional[exp.Expression] = self._parse_when_matched()[0]
401+
value: t.Optional[t.Union[exp.Expression, t.List[exp.Expression]]] = (
402+
self._parse_when_matched() # type: ignore
403+
)
402404
elif name == "time_data_type":
403405
# TODO: if we make *_data_type a convention to parse things into exp.DataType, we could make this more generic
404406
value = self._parse_types(schema=True)
@@ -410,7 +412,7 @@ def _parse_props(self: Parser) -> t.Optional[exp.Expression]:
410412

411413
if name == "path" and value:
412414
# Make sure if we get a windows path that it is converted to posix
413-
value = exp.Literal.string(value.this.replace("\\", "/"))
415+
value = exp.Literal.string(value.this.replace("\\", "/")) # type: ignore
414416

415417
return self.expression(exp.Property, this=name, value=value)
416418

sqlmesh/core/engine_adapter/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1726,7 +1726,7 @@ def merge(
17261726
source_table: QueryOrDF,
17271727
columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
17281728
unique_key: t.Sequence[exp.Expression],
1729-
when_matched: t.Optional[exp.When] = None,
1729+
when_matched: t.Optional[t.Union[exp.When, t.List[exp.When]]] = None,
17301730
) -> None:
17311731
source_queries, columns_to_types = self._get_source_queries_and_columns_to_types(
17321732
source_table, columns_to_types, target_table=target_table
@@ -1749,6 +1749,7 @@ def merge(
17491749
],
17501750
),
17511751
)
1752+
when_matched = ensure_list(when_matched)
17521753
when_not_matched = exp.When(
17531754
matched=False,
17541755
source=False,
@@ -1759,13 +1760,14 @@ def merge(
17591760
),
17601761
),
17611762
)
1763+
match_expressions = when_matched + [when_not_matched]
17621764
for source_query in source_queries:
17631765
with source_query as query:
17641766
self._merge(
17651767
target_table=target_table,
17661768
query=query,
17671769
on=on,
1768-
match_expressions=[when_matched, when_not_matched],
1770+
match_expressions=match_expressions,
17691771
)
17701772

17711773
def rename_table(

sqlmesh/core/engine_adapter/mixins.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def merge(
2525
source_table: QueryOrDF,
2626
columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
2727
unique_key: t.Sequence[exp.Expression],
28-
when_matched: t.Optional[exp.When] = None,
28+
when_matched: t.Optional[t.Union[exp.When, t.List[exp.When]]] = None,
2929
) -> None:
3030
"""
3131
Merge implementation for engine adapters that do not support merge natively.

sqlmesh/core/model/kind.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from pydantic import Field
88
from sqlglot import exp
9+
from sqlglot.helper import ensure_list
910
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
1011
from sqlglot.optimizer.qualify_columns import quote_identifiers
1112
from sqlglot.optimizer.simplify import gen
@@ -424,14 +425,16 @@ def data_hash_values(self) -> t.List[t.Optional[str]]:
424425
class IncrementalByUniqueKeyKind(_IncrementalBy):
425426
name: Literal[ModelKindName.INCREMENTAL_BY_UNIQUE_KEY] = ModelKindName.INCREMENTAL_BY_UNIQUE_KEY
426427
unique_key: SQLGlotListOfFields
427-
when_matched: t.Optional[exp.When] = None
428+
when_matched: t.Optional[t.List[exp.When]] = None
428429
batch_concurrency: Literal[1] = 1
429430

430431
@field_validator("when_matched", mode="before")
431432
@field_validator_v1_args
432433
def _when_matched_validator(
433-
cls, v: t.Optional[t.Union[exp.When, str]], values: t.Dict[str, t.Any]
434-
) -> t.Optional[exp.When]:
434+
cls,
435+
v: t.Optional[t.Union[exp.When, str, t.List[exp.When], t.List[str]]],
436+
values: t.Dict[str, t.Any],
437+
) -> t.Optional[t.List[exp.When]]:
435438
def replace_table_references(expression: exp.Expression) -> exp.Expression:
436439
from sqlmesh.core.engine_adapter.base import (
437440
MERGE_SOURCE_ALIAS,
@@ -451,13 +454,19 @@ def replace_table_references(expression: exp.Expression) -> exp.Expression:
451454
)
452455
return expression
453456

454-
if isinstance(v, str):
455-
return t.cast(exp.When, d.parse_one(v, into=exp.When, dialect=get_dialect(values)))
456-
457457
if not v:
458-
return v
459-
460-
return t.cast(exp.When, v.transform(replace_table_references))
458+
return v # type: ignore
459+
460+
result = []
461+
list_v = ensure_list(v)
462+
for value in ensure_list(list_v):
463+
if isinstance(value, str):
464+
result.append(
465+
t.cast(exp.When, d.parse_one(value, into=exp.When, dialect=get_dialect(values)))
466+
)
467+
else:
468+
result.append(t.cast(exp.When, value.transform(replace_table_references))) # type: ignore
469+
return result
461470

462471
@property
463472
def data_hash_values(self) -> t.List[t.Optional[str]]:

sqlmesh/core/model/meta.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ def managed_columns(self) -> t.Dict[str, exp.DataType]:
439439
return getattr(self.kind, "managed_columns", {})
440440

441441
@property
442-
def when_matched(self) -> t.Optional[exp.When]:
442+
def when_matched(self) -> t.Optional[t.List[exp.When]]:
443443
if isinstance(self.kind, IncrementalByUniqueKeyKind):
444444
return self.kind.when_matched
445445
return None

tests/core/engine_adapter/test_base.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -974,6 +974,75 @@ def test_merge_when_matched(make_mocked_engine_adapter: t.Callable, assert_exp_e
974974
)
975975

976976

977+
def test_merge_when_matched_multiple(make_mocked_engine_adapter: t.Callable, assert_exp_eq):
978+
adapter = make_mocked_engine_adapter(EngineAdapter)
979+
980+
adapter.merge(
981+
target_table="target",
982+
source_table=t.cast(exp.Select, parse_one('SELECT "ID", ts, val FROM source')),
983+
columns_to_types={
984+
"ID": exp.DataType.build("int"),
985+
"ts": exp.DataType.build("timestamp"),
986+
"val": exp.DataType.build("int"),
987+
},
988+
unique_key=[exp.to_identifier("ID", quoted=True)],
989+
when_matched=[
990+
exp.When(
991+
matched=True,
992+
condition=exp.column("ID", "__MERGE_SOURCE__").eq(exp.Literal.number(1)),
993+
then=exp.Update(
994+
expressions=[
995+
exp.column("val", "__MERGE_TARGET__").eq(
996+
exp.column("val", "__MERGE_SOURCE__")
997+
),
998+
exp.column("ts", "__MERGE_TARGET__").eq(
999+
exp.Coalesce(
1000+
this=exp.column("ts", "__MERGE_SOURCE__"),
1001+
expressions=[exp.column("ts", "__MERGE_TARGET__")],
1002+
)
1003+
),
1004+
],
1005+
),
1006+
),
1007+
exp.When(
1008+
matched=True,
1009+
source=False,
1010+
then=exp.Update(
1011+
expressions=[
1012+
exp.column("val", "__MERGE_TARGET__").eq(
1013+
exp.column("val", "__MERGE_SOURCE__")
1014+
),
1015+
exp.column("ts", "__MERGE_TARGET__").eq(
1016+
exp.Coalesce(
1017+
this=exp.column("ts", "__MERGE_SOURCE__"),
1018+
expressions=[exp.column("ts", "__MERGE_TARGET__")],
1019+
)
1020+
),
1021+
],
1022+
),
1023+
),
1024+
],
1025+
)
1026+
1027+
assert_exp_eq(
1028+
adapter.cursor.execute.call_args[0][0],
1029+
"""
1030+
MERGE INTO "target" AS "__MERGE_TARGET__" USING (
1031+
SELECT
1032+
"ID",
1033+
"ts",
1034+
"val"
1035+
FROM "source"
1036+
) AS "__MERGE_SOURCE__"
1037+
ON "__MERGE_TARGET__"."ID" = "__MERGE_SOURCE__"."ID"
1038+
WHEN MATCHED AND "__MERGE_SOURCE__"."ID" = 1 THEN UPDATE SET "__MERGE_TARGET__"."val" = "__MERGE_SOURCE__"."val", "__MERGE_TARGET__"."ts" = COALESCE("__MERGE_SOURCE__"."ts", "__MERGE_TARGET__"."ts"),
1039+
WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."val" = "__MERGE_SOURCE__"."val", "__MERGE_TARGET__"."ts" = COALESCE("__MERGE_SOURCE__"."ts", "__MERGE_TARGET__"."ts")
1040+
WHEN NOT MATCHED THEN INSERT ("ID", "ts", "val")
1041+
VALUES ("__MERGE_SOURCE__"."ID", "__MERGE_SOURCE__"."ts", "__MERGE_SOURCE__"."val")
1042+
""",
1043+
)
1044+
1045+
9771046
def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable):
9781047
adapter = make_mocked_engine_adapter(EngineAdapter)
9791048

tests/core/test_model.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3904,10 +3904,44 @@ def test_when_matched():
39043904
expected_when_matched = "WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.salary = COALESCE(__MERGE_SOURCE__.salary, __MERGE_TARGET__.salary)"
39053905

39063906
model = load_sql_based_model(expressions, dialect="hive")
3907-
assert model.kind.when_matched.sql() == expected_when_matched
3907+
assert len(model.kind.when_matched) == 1
3908+
assert model.kind.when_matched[0].sql() == expected_when_matched
39083909

39093910
model = SqlModel.parse_raw(model.json())
3910-
assert model.kind.when_matched.sql() == expected_when_matched
3911+
assert len(model.kind.when_matched) == 1
3912+
assert model.kind.when_matched[0].sql() == expected_when_matched
3913+
3914+
3915+
def test_when_matched_multiple():
3916+
expressions = d.parse(
3917+
"""
3918+
MODEL (
3919+
name db.employees,
3920+
kind INCREMENTAL_BY_UNIQUE_KEY (
3921+
unique_key name,
3922+
when_matched WHEN MATCHED AND source.x = 1 THEN UPDATE SET target.salary = COALESCE(source.salary, target.salary),
3923+
WHEN MATCHED THEN UPDATE SET target.salary = COALESCE(source.salary, target.salary)
3924+
3925+
)
3926+
);
3927+
SELECT 'name' AS name, 1 AS salary;
3928+
"""
3929+
)
3930+
3931+
expected_when_matched = [
3932+
"WHEN MATCHED AND __MERGE_SOURCE__.x = 1 THEN UPDATE SET __MERGE_TARGET__.salary = COALESCE(__MERGE_SOURCE__.salary, __MERGE_TARGET__.salary)",
3933+
"WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.salary = COALESCE(__MERGE_SOURCE__.salary, __MERGE_TARGET__.salary)",
3934+
]
3935+
3936+
model = load_sql_based_model(expressions, dialect="hive")
3937+
assert len(model.kind.when_matched) == 2
3938+
assert model.kind.when_matched[0].sql() == expected_when_matched[0]
3939+
assert model.kind.when_matched[1].sql() == expected_when_matched[1]
3940+
3941+
model = SqlModel.parse_raw(model.json())
3942+
assert len(model.kind.when_matched) == 2
3943+
assert model.kind.when_matched[0].sql() == expected_when_matched[0]
3944+
assert model.kind.when_matched[1].sql() == expected_when_matched[1]
39113945

39123946

39133947
def test_default_catalog_sql(assert_exp_eq):
@@ -5438,7 +5472,35 @@ def test_model_kind_to_expression():
54385472
.sql()
54395473
== """INCREMENTAL_BY_UNIQUE_KEY (
54405474
unique_key ("a"),
5441-
when_matched WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.b = COALESCE(__MERGE_SOURCE__.b, __MERGE_TARGET__.b),
5475+
when_matched ARRAY(WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.b = COALESCE(__MERGE_SOURCE__.b, __MERGE_TARGET__.b)),
5476+
batch_concurrency 1,
5477+
forward_only FALSE,
5478+
disable_restatement FALSE,
5479+
on_destructive_change 'ERROR'
5480+
)"""
5481+
)
5482+
5483+
assert (
5484+
load_sql_based_model(
5485+
d.parse(
5486+
"""
5487+
MODEL (
5488+
name db.table,
5489+
kind INCREMENTAL_BY_UNIQUE_KEY(
5490+
unique_key a,
5491+
when_matched WHEN MATCHED AND source.x = 1 THEN UPDATE SET target.b = COALESCE(source.b, target.b),
5492+
WHEN MATCHED THEN UPDATE SET target.b = COALESCE(source.b, target.b)
5493+
),
5494+
);
5495+
SELECT a, b
5496+
"""
5497+
)
5498+
)
5499+
.kind.to_expression()
5500+
.sql()
5501+
== """INCREMENTAL_BY_UNIQUE_KEY (
5502+
unique_key ("a"),
5503+
when_matched ARRAY(WHEN MATCHED AND __MERGE_SOURCE__.x = 1 THEN UPDATE SET __MERGE_TARGET__.b = COALESCE(__MERGE_SOURCE__.b, __MERGE_TARGET__.b), WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.b = COALESCE(__MERGE_SOURCE__.b, __MERGE_TARGET__.b)),
54425504
batch_concurrency 1,
54435505
forward_only FALSE,
54445506
disable_restatement FALSE,

0 commit comments

Comments
 (0)