Skip to content

Commit c512e63

Browse files
authored
Refactor!: make when_matched syntax compatible with merge syntax (#3497)
1 parent a637248 commit c512e63

File tree

13 files changed

+323
-183
lines changed

13 files changed

+323
-183
lines changed

docs/concepts/models/model_kinds.md

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,9 @@ MODEL (
320320
name db.employees,
321321
kind INCREMENTAL_BY_UNIQUE_KEY (
322322
unique_key name,
323-
when_matched WHEN MATCHED THEN UPDATE SET target.salary = COALESCE(source.salary, target.salary)
323+
when_matched (
324+
WHEN MATCHED THEN UPDATE SET target.salary = COALESCE(source.salary, target.salary)
325+
)
324326
)
325327
);
326328
```
@@ -334,8 +336,10 @@ MODEL (
334336
name db.employees,
335337
kind INCREMENTAL_BY_UNIQUE_KEY (
336338
unique_key name,
337-
when_matched WHEN MATCHED AND source.value IS NULL THEN UPDATE SET target.salary = COALESCE(source.salary, target.salary),
338-
WHEN MATCHED THEN UPDATE SET target.title = COALESCE(source.title, target.title)
339+
when_matched (
340+
WHEN MATCHED AND source.value IS NULL THEN UPDATE SET target.salary = COALESCE(source.salary, target.salary)
341+
WHEN MATCHED THEN UPDATE SET target.title = COALESCE(source.title, target.title)
342+
)
339343
)
340344
);
341345
```

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
"rich[jupyter]",
4949
"ruamel.yaml",
5050
"setuptools; python_version>='3.12'",
51-
"sqlglot[rs]~=25.34.1",
51+
"sqlglot[rs]~=26.0.0",
5252
"tenacity",
5353
],
5454
extras_require={

sqlmesh/core/dialect.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -409,13 +409,12 @@ def _parse_props(self: Parser) -> t.Optional[exp.Expression]:
409409
return None
410410

411411
name = key.name.lower()
412-
if name == "when_matched":
413-
value: t.Optional[t.Union[exp.Expression, t.List[exp.Expression]]] = (
414-
self._parse_when_matched() # type: ignore
415-
)
416-
elif name == "time_data_type":
412+
if name == "time_data_type":
417413
# TODO: if we make *_data_type a convention to parse things into exp.DataType, we could make this more generic
418414
value = self._parse_types(schema=True)
415+
elif name == "when_matched":
416+
# Parentheses around the WHEN clauses can be used to disambiguate them from other properties
417+
value = self._parse_wrapped(self._parse_when_matched, optional=True)
419418
elif self._match(TokenType.L_PAREN):
420419
value = self.expression(exp.Tuple, expressions=self._parse_csv(self._parse_equality))
421420
self._match_r_paren()
@@ -605,15 +604,11 @@ def _props_sql(self: Generator, expressions: t.List[exp.Expression]) -> str:
605604
size = len(expressions)
606605

607606
for i, prop in enumerate(expressions):
608-
value = prop.args.get("value")
609-
if prop.name == "when_matched" and isinstance(value, list):
610-
output_value = ", ".join(self.sql(v) for v in value)
611-
else:
612-
output_value = self.sql(prop, "value")
613-
sql = self.indent(f"{prop.name} {output_value}")
607+
sql = self.indent(f"{prop.name} {self.sql(prop, 'value')}")
614608

615609
if i < size - 1:
616610
sql += ","
611+
617612
props.append(self.maybe_comment(sql, expression=prop))
618613

619614
return "\n".join(props)
@@ -648,6 +643,15 @@ def _macro_func_sql(self: Generator, expression: MacroFunc) -> str:
648643
return self.maybe_comment(sql, expression)
649644

650645

646+
def _whens_sql(self: Generator, expression: exp.Whens) -> str:
647+
if isinstance(expression.parent, exp.Merge):
648+
return self.whens_sql(expression)
649+
650+
# If the `WHEN` clauses aren't part of a MERGE statement (e.g. they
651+
# appear in the `MODEL` DDL), then we will wrap them with parentheses.
652+
return self.wrap(self.expressions(expression, sep=" ", indent=False))
653+
654+
651655
def _override(klass: t.Type[Tokenizer | Parser], func: t.Callable) -> None:
652656
name = func.__name__
653657
setattr(klass, f"_{name}", getattr(klass, name))
@@ -901,6 +905,7 @@ def extend_sqlglot() -> None:
901905
ModelKind: _model_kind_sql,
902906
PythonCode: lambda self, e: self.expressions(e, sep="\n", indent=False),
903907
StagedFilePath: lambda self, e: self.table_sql(e),
908+
exp.Whens: _whens_sql,
904909
}
905910
)
906911

sqlmesh/core/engine_adapter/base.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1337,20 +1337,13 @@ def _merge(
13371337
target_table: TableName,
13381338
query: Query,
13391339
on: exp.Expression,
1340-
match_expressions: t.List[exp.When],
1340+
whens: exp.Whens,
13411341
) -> None:
13421342
this = exp.alias_(exp.to_table(target_table), alias=MERGE_TARGET_ALIAS, table=True)
13431343
using = exp.alias_(
13441344
exp.Subquery(this=query), alias=MERGE_SOURCE_ALIAS, copy=False, table=True
13451345
)
1346-
self.execute(
1347-
exp.Merge(
1348-
this=this,
1349-
using=using,
1350-
on=on,
1351-
expressions=match_expressions,
1352-
)
1353-
)
1346+
self.execute(exp.Merge(this=this, using=using, on=on, whens=whens))
13541347

13551348
def scd_type_2_by_time(
13561349
self,
@@ -1807,7 +1800,7 @@ def merge(
18071800
source_table: QueryOrDF,
18081801
columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
18091802
unique_key: t.Sequence[exp.Expression],
1810-
when_matched: t.Optional[t.Union[exp.When, t.List[exp.When]]] = None,
1803+
when_matched: t.Optional[exp.Whens] = None,
18111804
) -> None:
18121805
source_queries, columns_to_types = self._get_source_queries_and_columns_to_types(
18131806
source_table, columns_to_types, target_table=target_table
@@ -1820,17 +1813,23 @@ def merge(
18201813
)
18211814
)
18221815
if not when_matched:
1823-
when_matched = exp.When(
1824-
matched=True,
1825-
source=False,
1826-
then=exp.Update(
1827-
expressions=[
1828-
exp.column(col, MERGE_TARGET_ALIAS).eq(exp.column(col, MERGE_SOURCE_ALIAS))
1829-
for col in columns_to_types
1830-
],
1816+
when_matched = exp.Whens()
1817+
when_matched.append(
1818+
"expressions",
1819+
exp.When(
1820+
matched=True,
1821+
source=False,
1822+
then=exp.Update(
1823+
expressions=[
1824+
exp.column(col, MERGE_TARGET_ALIAS).eq(
1825+
exp.column(col, MERGE_SOURCE_ALIAS)
1826+
)
1827+
for col in columns_to_types
1828+
],
1829+
),
18311830
),
18321831
)
1833-
when_matched = ensure_list(when_matched)
1832+
18341833
when_not_matched = exp.When(
18351834
matched=False,
18361835
source=False,
@@ -1841,14 +1840,15 @@ def merge(
18411840
),
18421841
),
18431842
)
1844-
match_expressions = when_matched + [when_not_matched]
1843+
when_matched.append("expressions", when_not_matched)
1844+
18451845
for source_query in source_queries:
18461846
with source_query as query:
18471847
self._merge(
18481848
target_table=target_table,
18491849
query=query,
18501850
on=on,
1851-
match_expressions=match_expressions,
1851+
whens=when_matched,
18521852
)
18531853

18541854
def rename_table(

sqlmesh/core/engine_adapter/mixins.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def merge(
3030
source_table: QueryOrDF,
3131
columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
3232
unique_key: t.Sequence[exp.Expression],
33-
when_matched: t.Optional[t.Union[exp.When, t.List[exp.When]]] = None,
33+
when_matched: t.Optional[exp.Whens] = None,
3434
) -> None:
3535
logical_merge(
3636
self,
@@ -105,7 +105,9 @@ def _insert_overwrite_by_condition(
105105
target_table=table_name,
106106
query=query,
107107
on=exp.false(),
108-
match_expressions=[when_not_matched_by_source, when_not_matched_by_target],
108+
whens=exp.Whens(
109+
expressions=[when_not_matched_by_source, when_not_matched_by_target]
110+
),
109111
)
110112

111113

@@ -406,7 +408,7 @@ def logical_merge(
406408
source_table: QueryOrDF,
407409
columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
408410
unique_key: t.Sequence[exp.Expression],
409-
when_matched: t.Optional[t.Union[exp.When, t.List[exp.When]]] = None,
411+
when_matched: t.Optional[exp.Whens] = None,
410412
) -> None:
411413
"""
412414
Merge implementation for engine adapters that do not support merge natively.

sqlmesh/core/engine_adapter/postgres.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def merge(
106106
source_table: QueryOrDF,
107107
columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
108108
unique_key: t.Sequence[exp.Expression],
109-
when_matched: t.Optional[t.Union[exp.When, t.List[exp.When]]] = None,
109+
when_matched: t.Optional[exp.Whens] = None,
110110
) -> None:
111111
# Merge isn't supported until Postgres 15
112112
merge_impl = (

sqlmesh/core/model/kind.py

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from pydantic import Field
77
from sqlglot import exp
8-
from sqlglot.helper import ensure_list
98
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
109
from sqlglot.optimizer.qualify_columns import quote_identifiers
1110
from sqlglot.optimizer.simplify import gen
@@ -423,48 +422,38 @@ class IncrementalByUniqueKeyKind(_IncrementalBy):
423422
ModelKindName.INCREMENTAL_BY_UNIQUE_KEY
424423
)
425424
unique_key: SQLGlotListOfFields
426-
when_matched: t.Optional[t.List[exp.When]] = None
425+
when_matched: t.Optional[exp.Whens] = None
427426
batch_concurrency: t.Literal[1] = 1
428427

429428
@field_validator("when_matched", mode="before")
430429
@field_validator_v1_args
431430
def _when_matched_validator(
432431
cls,
433-
v: t.Optional[t.Union[exp.When, str, t.List[exp.When], t.List[str]]],
432+
v: t.Optional[t.Union[str, exp.Whens]],
434433
values: t.Dict[str, t.Any],
435-
) -> t.Optional[t.List[exp.When]]:
434+
) -> t.Optional[exp.Whens]:
436435
def replace_table_references(expression: exp.Expression) -> exp.Expression:
437-
from sqlmesh.core.engine_adapter.base import (
438-
MERGE_SOURCE_ALIAS,
439-
MERGE_TARGET_ALIAS,
440-
)
436+
from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS
441437

442438
if isinstance(expression, exp.Column):
443439
if expression.table.lower() == "target":
444-
expression.set(
445-
"table",
446-
exp.to_identifier(MERGE_TARGET_ALIAS),
447-
)
440+
expression.set("table", exp.to_identifier(MERGE_TARGET_ALIAS))
448441
elif expression.table.lower() == "source":
449-
expression.set(
450-
"table",
451-
exp.to_identifier(MERGE_SOURCE_ALIAS),
452-
)
442+
expression.set("table", exp.to_identifier(MERGE_SOURCE_ALIAS))
443+
453444
return expression
454445

455-
if not v:
456-
return v # type: ignore
457-
458-
result = []
459-
list_v = ensure_list(v)
460-
for value in ensure_list(list_v):
461-
if isinstance(value, str):
462-
result.append(
463-
t.cast(exp.When, d.parse_one(value, into=exp.When, dialect=get_dialect(values)))
464-
)
465-
else:
466-
result.append(t.cast(exp.When, value.transform(replace_table_references))) # type: ignore
467-
return result
446+
if v is None:
447+
return v
448+
if isinstance(v, str):
449+
# Whens wrap the WHEN clauses, but the parentheses aren't parsed by sqlglot
450+
v = v.strip()
451+
if v.startswith("("):
452+
v = v[1:-1]
453+
454+
return t.cast(exp.Whens, d.parse_one(v, into=exp.Whens, dialect=get_dialect(values)))
455+
456+
return t.cast(exp.Whens, v.transform(replace_table_references))
468457

469458
@property
470459
def data_hash_values(self) -> t.List[t.Optional[str]]:

sqlmesh/core/model/meta.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ def managed_columns(self) -> t.Dict[str, exp.DataType]:
430430
return getattr(self.kind, "managed_columns", {})
431431

432432
@property
433-
def when_matched(self) -> t.Optional[t.List[exp.When]]:
433+
def when_matched(self) -> t.Optional[exp.Whens]:
434434
if isinstance(self.kind, IncrementalByUniqueKeyKind):
435435
return self.kind.when_matched
436436
return None
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
"""Join list of `WHEN [NOT] MATCHED` strings into a single string."""
2+
3+
import json
4+
5+
import pandas as pd
6+
from sqlglot import exp
7+
8+
from sqlmesh.utils.migration import index_text_type, blob_text_type
9+
10+
11+
def migrate(state_sync, **kwargs): # type: ignore
12+
engine_adapter = state_sync.engine_adapter
13+
schema = state_sync.schema
14+
snapshots_table = "_snapshots"
15+
index_type = index_text_type(engine_adapter.dialect)
16+
if schema:
17+
snapshots_table = f"{schema}.{snapshots_table}"
18+
19+
new_snapshots = []
20+
21+
for (
22+
name,
23+
identifier,
24+
version,
25+
snapshot,
26+
kind_name,
27+
updated_ts,
28+
unpaused_ts,
29+
ttl_ms,
30+
unrestorable,
31+
) in engine_adapter.fetchall(
32+
exp.select(
33+
"name",
34+
"identifier",
35+
"version",
36+
"snapshot",
37+
"kind_name",
38+
"updated_ts",
39+
"unpaused_ts",
40+
"ttl_ms",
41+
"unrestorable",
42+
).from_(snapshots_table),
43+
quote_identifiers=True,
44+
):
45+
parsed_snapshot = json.loads(snapshot)
46+
node = parsed_snapshot["node"]
47+
48+
if "kind" in node:
49+
kind = node["kind"]
50+
if isinstance(when_matched := kind.get("when_matched"), list):
51+
kind["when_matched"] = " ".join(when_matched)
52+
53+
new_snapshots.append(
54+
{
55+
"name": name,
56+
"identifier": identifier,
57+
"version": version,
58+
"snapshot": json.dumps(parsed_snapshot),
59+
"kind_name": kind_name,
60+
"updated_ts": updated_ts,
61+
"unpaused_ts": unpaused_ts,
62+
"ttl_ms": ttl_ms,
63+
"unrestorable": unrestorable,
64+
}
65+
)
66+
67+
if new_snapshots:
68+
engine_adapter.delete_from(snapshots_table, "TRUE")
69+
blob_type = blob_text_type(engine_adapter.dialect)
70+
71+
engine_adapter.insert_append(
72+
snapshots_table,
73+
pd.DataFrame(new_snapshots),
74+
columns_to_types={
75+
"name": exp.DataType.build(index_type),
76+
"identifier": exp.DataType.build(index_type),
77+
"version": exp.DataType.build(index_type),
78+
"snapshot": exp.DataType.build(blob_type),
79+
"kind_name": exp.DataType.build(index_type),
80+
"updated_ts": exp.DataType.build("bigint"),
81+
"unpaused_ts": exp.DataType.build("bigint"),
82+
"ttl_ms": exp.DataType.build("bigint"),
83+
"unrestorable": exp.DataType.build("boolean"),
84+
},
85+
)

tests/core/engine_adapter/integration/test_integration.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2334,6 +2334,8 @@ def test_value_normalization(
23342334
)
23352335
],
23362336
)
2337+
if ctx.dialect == "tsql" and column_type == exp.DataType.Type.DATETIME:
2338+
full_column_type = exp.DataType.build("DATETIME2", dialect="tsql")
23372339

23382340
columns_to_types = {
23392341
"_idx": exp.DataType.build(DATA_TYPE.INT),

0 commit comments

Comments
 (0)