Skip to content

Commit ec840bf

Browse files
authored
Feat: Introduce INCREMENTAL_BY_PARTITION model kind (#2687)
1 parent 999a202 commit ec840bf

File tree

15 files changed

+355
-34
lines changed

15 files changed

+355
-34
lines changed

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@
6565
"cryptography~=42.0.4",
6666
"dbt-core",
6767
"dbt-duckdb>=1.7.1",
68+
"dbt-snowflake",
69+
"dbt-bigquery",
6870
"Faker",
6971
"google-auth",
7072
"google-cloud-bigquery",
@@ -93,10 +95,8 @@
9395
"typing-extensions",
9496
],
9597
"cicdtest": [
96-
"dbt-bigquery",
9798
"dbt-databricks",
9899
"dbt-redshift",
99-
"dbt-snowflake",
100100
"dbt-sqlserver>=1.7.0",
101101
"dbt-trino",
102102
],

sqlmesh/core/dialect.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,7 @@ def parse(self: Parser) -> t.Optional[exp.Expression]:
491491
if kind in (
492492
ModelKindName.INCREMENTAL_BY_TIME_RANGE,
493493
ModelKindName.INCREMENTAL_BY_UNIQUE_KEY,
494+
ModelKindName.INCREMENTAL_BY_PARTITION,
494495
ModelKindName.SEED,
495496
ModelKindName.VIEW,
496497
ModelKindName.SCD_TYPE_2,

sqlmesh/core/engine_adapter/base.py

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,9 +1041,18 @@ def insert_overwrite_by_partition(
10411041
partitioned_by: t.List[exp.Expression],
10421042
columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
10431043
) -> None:
1044-
raise NotImplementedError(
1045-
"Insert Overwrite by Partition (not time) is not supported by this engine"
1046-
)
1044+
if self.INSERT_OVERWRITE_STRATEGY.is_insert_overwrite:
1045+
target_table = exp.to_table(table_name)
1046+
source_queries, columns_to_types = self._get_source_queries_and_columns_to_types(
1047+
query_or_df, columns_to_types, target_table=target_table
1048+
)
1049+
self._insert_overwrite_by_condition(
1050+
table_name, source_queries, columns_to_types=columns_to_types
1051+
)
1052+
else:
1053+
self._replace_by_key(
1054+
table_name, query_or_df, columns_to_types, partitioned_by, is_unique_key=False
1055+
)
10471056

10481057
def insert_overwrite_by_time_partition(
10491058
self,
@@ -2006,6 +2015,49 @@ def _truncate_table(self, table_name: TableName) -> None:
20062015
table = exp.to_table(table_name)
20072016
self.execute(f"TRUNCATE TABLE {table.sql(dialect=self.dialect, identify=True)}")
20082017

2018+
def _replace_by_key(
    self,
    target_table: TableName,
    source_table: QueryOrDF,
    columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
    key: t.Sequence[exp.Expression],
    is_unique_key: bool,
) -> None:
    """Replace rows in `target_table` whose key values appear in `source_table`,
    then insert the source rows.

    The source query/DataFrame is first materialized into a temp table via CTAS
    so it is only evaluated once. Rows are matched on a single scalar expression,
    CONCAT_WS('__SQLMESH_DELIM__', *key), so multi-column keys can be compared
    with a plain IN filter.

    Args:
        target_table: Table whose matching rows are replaced.
        source_table: Query or DataFrame providing the replacement rows.
        columns_to_types: Column name -> type mapping; fetched from the target
            table when not provided.
        key: Expressions forming the (possibly non-unique) replacement key.
        is_unique_key: True when `key` uniquely identifies a row. Controls where
            deduplication happens: on the delete-side key list when the key is
            non-unique, on the inserted rows themselves when it is unique.
    """
    if columns_to_types is None:
        columns_to_types = self.columns(target_table)

    temp_table = self._get_temp_table(target_table)
    # Single scalar key expression used by both the delete filter and its subquery.
    key_exp = exp.func("CONCAT_WS", "'__SQLMESH_DELIM__'", *key)
    column_names = list(columns_to_types or [])

    with self.transaction():
        # Materialize the source once; every statement below reads the temp table.
        self.ctas(temp_table, source_table, columns_to_types=columns_to_types, exists=False)

        try:
            delete_query = exp.select(key_exp).from_(temp_table)
            insert_query = self._select_columns(columns_to_types).from_(temp_table)
            if not is_unique_key:
                # Non-unique key: many source rows may share a key value, so only
                # the key list fed to the delete filter is deduplicated.
                delete_query = delete_query.distinct()
            else:
                # Unique key: deduplicate the inserted rows by key instead.
                insert_query = insert_query.distinct(*key)

            insert_statement = exp.insert(
                insert_query,
                target_table,
                columns=column_names,
            )
            delete_filter = key_exp.isin(query=delete_query)

            if not self.INSERT_OVERWRITE_STRATEGY.is_replace_where:
                # Two statements: DELETE the matching rows, then a plain INSERT below.
                self.execute(exp.delete(target_table).where(delete_filter))
            else:
                # REPLACE WHERE engines do the delete + insert in one INSERT statement.
                insert_statement.set("where", delete_filter)
                insert_statement.set("this", exp.to_table(target_table))

            self.execute(insert_statement)
        finally:
            # Always drop the temp table, even when one of the statements failed.
            self.drop_table(temp_table)
2060+
20092061
def _build_create_comment_table_exp(
20102062
self, table: exp.Table, table_comment: str, table_kind: str
20112063
) -> exp.Comment | str:

sqlmesh/core/engine_adapter/mixins.py

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -41,31 +41,9 @@ def merge(
4141
raise SQLMeshError(
4242
"This engine does not support MERGE expressions and therefore `when_matched` is not supported."
4343
)
44-
if columns_to_types is None:
45-
columns_to_types = self.columns(target_table)
46-
47-
temp_table = self._get_temp_table(target_table)
48-
unique_exp = exp.func("CONCAT_WS", "'__SQLMESH_DELIM__'", *unique_key)
49-
column_names = list(columns_to_types or [])
50-
51-
with self.transaction():
52-
self.ctas(temp_table, source_table, columns_to_types=columns_to_types, exists=False)
53-
self.execute(
54-
exp.delete(target_table).where(
55-
unique_exp.isin(query=exp.select(unique_exp).from_(temp_table))
56-
)
57-
)
58-
self.execute(
59-
exp.insert(
60-
self._select_columns(columns_to_types)
61-
.distinct(*unique_key)
62-
.from_(temp_table)
63-
.subquery(),
64-
target_table,
65-
columns=column_names,
66-
)
67-
)
68-
self.drop_table(temp_table)
44+
self._replace_by_key(
45+
target_table, source_table, columns_to_types, unique_key, is_unique_key=True
46+
)
6947

7048

7149
class PandasNativeFetchDFSupportMixin(EngineAdapter):

sqlmesh/core/model/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
IncrementalByTimeRangeKind as IncrementalByTimeRangeKind,
2222
IncrementalByUniqueKeyKind as IncrementalByUniqueKeyKind,
2323
IncrementalUnmanagedKind as IncrementalUnmanagedKind,
24+
IncrementalByPartitionKind as IncrementalByPartitionKind,
2425
ModelKind as ModelKind,
2526
ModelKindMixin as ModelKindMixin,
2627
ModelKindName as ModelKindName,

sqlmesh/core/model/kind.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ def is_incremental_by_time_range(self) -> bool:
5353
def is_incremental_by_unique_key(self) -> bool:
5454
return self.model_kind_name == ModelKindName.INCREMENTAL_BY_UNIQUE_KEY
5555

56+
@property
def is_incremental_by_partition(self) -> bool:
    """Whether this kind is INCREMENTAL_BY_PARTITION."""
    return self.model_kind_name == ModelKindName.INCREMENTAL_BY_PARTITION
59+
5660
@property
5761
def is_incremental_unmanaged(self) -> bool:
5862
return self.model_kind_name == ModelKindName.INCREMENTAL_UNMANAGED
@@ -62,6 +66,7 @@ def is_incremental(self) -> bool:
6266
return (
6367
self.is_incremental_by_time_range
6468
or self.is_incremental_by_unique_key
69+
or self.is_incremental_by_partition
6570
or self.is_incremental_unmanaged
6671
or self.is_scd_type_2
6772
)
@@ -129,6 +134,7 @@ class ModelKindName(str, ModelKindMixin, Enum):
129134

130135
INCREMENTAL_BY_TIME_RANGE = "INCREMENTAL_BY_TIME_RANGE"
131136
INCREMENTAL_BY_UNIQUE_KEY = "INCREMENTAL_BY_UNIQUE_KEY"
137+
INCREMENTAL_BY_PARTITION = "INCREMENTAL_BY_PARTITION"
132138
INCREMENTAL_UNMANAGED = "INCREMENTAL_UNMANAGED"
133139
FULL = "FULL"
134140
# Legacy alias to SCD Type 2 By Time
@@ -399,6 +405,20 @@ def data_hash_values(self) -> t.List[t.Optional[str]]:
399405
]
400406

401407

408+
class IncrementalByPartitionKind(_Incremental):
    """Kind for models loaded incrementally by partition.

    Always forward-only; restatement is disabled by default.
    """

    name: Literal[ModelKindName.INCREMENTAL_BY_PARTITION] = ModelKindName.INCREMENTAL_BY_PARTITION
    forward_only: Literal[True] = True
    disable_restatement: SQLGlotBool = True

    @property
    def metadata_hash_values(self) -> t.List[t.Optional[str]]:
        """Values folded into the metadata hash for this kind."""
        inherited = super().metadata_hash_values
        extras = [str(self.forward_only), str(self.disable_restatement)]
        return [*inherited, *extras]
420+
421+
402422
class IncrementalUnmanagedKind(_Incremental):
403423
name: Literal[ModelKindName.INCREMENTAL_UNMANAGED] = ModelKindName.INCREMENTAL_UNMANAGED
404424
insert_overwrite: SQLGlotBool = False
@@ -411,7 +431,11 @@ def data_hash_values(self) -> t.List[t.Optional[str]]:
411431

412432
@property
413433
def metadata_hash_values(self) -> t.List[t.Optional[str]]:
414-
return [*super().metadata_hash_values, str(self.forward_only)]
434+
return [
435+
*super().metadata_hash_values,
436+
str(self.forward_only),
437+
str(self.disable_restatement),
438+
]
415439

416440

417441
class ViewKind(_ModelKind):
@@ -581,6 +605,7 @@ class ExternalKind(_ModelKind):
581605
FullKind,
582606
IncrementalByTimeRangeKind,
583607
IncrementalByUniqueKeyKind,
608+
IncrementalByPartitionKind,
584609
IncrementalUnmanagedKind,
585610
SeedKind,
586611
ViewKind,
@@ -596,6 +621,7 @@ class ExternalKind(_ModelKind):
596621
ModelKindName.FULL: FullKind,
597622
ModelKindName.INCREMENTAL_BY_TIME_RANGE: IncrementalByTimeRangeKind,
598623
ModelKindName.INCREMENTAL_BY_UNIQUE_KEY: IncrementalByUniqueKeyKind,
624+
ModelKindName.INCREMENTAL_BY_PARTITION: IncrementalByPartitionKind,
599625
ModelKindName.INCREMENTAL_UNMANAGED: IncrementalUnmanagedKind,
600626
ModelKindName.SEED: SeedKind,
601627
ModelKindName.VIEW: ViewKind,

sqlmesh/core/model/meta.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,8 @@ def _kind_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
300300
for field in ("partitioned_by_", "clustered_by"):
301301
if values.get(field) and not kind.is_materialized:
302302
raise ValueError(f"{field} field cannot be set for {kind} models")
303+
if kind.is_incremental_by_partition and not values.get("partitioned_by_"):
304+
raise ValueError(f"partitioned_by field is required for {kind.name} models")
303305
return values
304306

305307
@property

sqlmesh/core/snapshot/evaluator.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -865,6 +865,8 @@ def _evaluation_strategy(snapshot: SnapshotInfoLike, adapter: EngineAdapter) ->
865865
klass = IncrementalByTimeRangeStrategy
866866
elif snapshot.is_incremental_by_unique_key:
867867
klass = IncrementalByUniqueKeyStrategy
868+
elif snapshot.is_incremental_by_partition:
869+
klass = IncrementalByPartitionStrategy
868870
elif snapshot.is_incremental_unmanaged:
869871
klass = IncrementalUnmanagedStrategy
870872
elif snapshot.is_view:
@@ -1209,6 +1211,26 @@ def delete(self, table_name: str) -> None:
12091211
logger.info("Dropped table '%s'", table_name)
12101212

12111213

1214+
class IncrementalByPartitionStrategy(MaterializableStrategy):
    """Evaluation strategy for INCREMENTAL_BY_PARTITION model kinds.

    Delegates the write to the adapter's partition-based insert-overwrite.
    """

    def insert(
        self,
        snapshot: Snapshot,
        name: str,
        query_or_df: QueryOrDF,
        snapshots: t.Dict[str, Snapshot],
        deployability_index: DeployabilityIndex,
        batch_index: int,
        **kwargs: t.Any,
    ) -> None:
        """Overwrite the partitions of `name` produced by `query_or_df`."""
        self.adapter.insert_overwrite_by_partition(
            name,
            query_or_df,
            partitioned_by=snapshot.model.partitioned_by,
            columns_to_types=snapshot.model.columns_to_types,
        )
)
1232+
1233+
12121234
class IncrementalByTimeRangeStrategy(MaterializableStrategy):
12131235
def insert(
12141236
self,

tests/core/engine_adapter/test_base.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from sqlglot import parse_one
1111
from sqlglot.helper import ensure_list
1212

13+
from sqlmesh.core import dialect as d
1314
from sqlmesh.core.dialect import normalize_model_name
1415
from sqlmesh.core.engine_adapter import EngineAdapter, EngineAdapterWithIndexSupport
1516
from sqlmesh.core.engine_adapter.shared import InsertOverwriteStrategy
@@ -2674,3 +2675,67 @@ def test_pre_ping(mocker: MockerFixture, make_mocked_engine_adapter: t.Callable)
26742675
]
26752676

26762677
adapter._connection_pool.get().close.assert_called_once()
2678+
2679+
2680+
def test_insert_overwrite_by_partition_query(
    make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name: t.Callable
):
    """Default (DELETE + INSERT) strategy: partitions present in the source query
    are deleted from the target, then re-inserted from a CTAS temp table."""
    adapter = make_mocked_engine_adapter(EngineAdapter)

    table_name = "test_schema.test_table"
    mock_temp_table = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table")
    mock_temp_table.return_value = make_temp_table_name(table_name, "abcdefgh")

    adapter.insert_overwrite_by_partition(
        table_name,
        parse_one("SELECT a, ds, b FROM tbl"),
        partitioned_by=[d.parse_one("DATETIME_TRUNC(ds, MONTH)"), d.parse_one("b")],
        columns_to_types={
            "a": exp.DataType.build("int"),
            "ds": exp.DataType.build("DATETIME"),
            "b": exp.DataType.build("boolean"),
        },
    )

    expected = [
        'CREATE TABLE "test_schema"."__temp_test_table_abcdefgh" AS SELECT "a", "ds", "b" FROM "tbl"',
        'DELETE FROM "test_schema"."test_table" WHERE CONCAT_WS(\'__SQLMESH_DELIM__\', DATETIME_TRUNC("ds", MONTH), "b") IN (SELECT DISTINCT CONCAT_WS(\'__SQLMESH_DELIM__\', DATETIME_TRUNC("ds", MONTH), "b") FROM "test_schema"."__temp_test_table_abcdefgh")',
        'INSERT INTO "test_schema"."test_table" ("a", "ds", "b") SELECT "a", "ds", "b" FROM "test_schema"."__temp_test_table_abcdefgh"',
        'DROP TABLE IF EXISTS "test_schema"."__temp_test_table_abcdefgh"',
    ]
    assert to_sql_calls(adapter) == expected
2711+
2712+
2713+
def test_insert_overwrite_by_partition_query_insert_overwrite_strategy(
    make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name: t.Callable
):
    """INSERT_OVERWRITE strategy: the source query is written with a single
    INSERT OVERWRITE statement — no temp table, no DELETE."""
    adapter = make_mocked_engine_adapter(EngineAdapter)
    adapter.INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.INSERT_OVERWRITE

    table_name = "test_schema.test_table"
    mock_temp_table = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table")
    mock_temp_table.return_value = make_temp_table_name(table_name, "abcdefgh")

    adapter.insert_overwrite_by_partition(
        table_name,
        parse_one("SELECT a, ds, b FROM tbl"),
        partitioned_by=[d.parse_one("DATETIME_TRUNC(ds, MONTH)"), d.parse_one("b")],
        columns_to_types={
            "a": exp.DataType.build("int"),
            "ds": exp.DataType.build("DATETIME"),
            "b": exp.DataType.build("boolean"),
        },
    )

    assert to_sql_calls(adapter) == [
        'INSERT OVERWRITE TABLE "test_schema"."test_table" ("a", "ds", "b") SELECT "a", "ds", "b" FROM "tbl"'
    ]

tests/core/engine_adapter/test_databricks.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from pytest_mock import MockFixture
77
from sqlglot import exp, parse_one
88

9+
from sqlmesh.core import dialect as d
910
from sqlmesh.core.engine_adapter import DatabricksEngineAdapter
1011
from tests.core.engine_adapter import to_sql_calls
1112

@@ -103,3 +104,35 @@ def test_get_current_database(make_mocked_engine_adapter: t.Callable):
103104

104105
assert adapter.get_current_database() == "test_database"
105106
assert to_sql_calls(adapter) == ["SELECT CURRENT_DATABASE()"]
107+
108+
109+
def test_insert_overwrite_by_partition_query(
    make_mocked_engine_adapter: t.Callable, mocker: MockFixture, make_temp_table_name: t.Callable
):
    """Databricks path: delete + insert collapse into one INSERT ... REPLACE WHERE
    statement reading from the CTAS temp table."""
    adapter = make_mocked_engine_adapter(DatabricksEngineAdapter)

    table_name = "test_schema.test_table"
    mock_temp_table = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table")
    mock_temp_table.return_value = make_temp_table_name(table_name, "abcdefgh")

    adapter.insert_overwrite_by_partition(
        table_name,
        parse_one("SELECT a, ds, b FROM tbl"),
        partitioned_by=[d.parse_one("DATETIME_TRUNC(ds, MONTH)"), d.parse_one("b")],
        columns_to_types={
            "a": exp.DataType.build("int"),
            "ds": exp.DataType.build("DATETIME"),
            "b": exp.DataType.build("boolean"),
        },
    )

    assert to_sql_calls(adapter) == [
        "CREATE TABLE `test_schema`.`temp_test_table_abcdefgh` AS SELECT `a`, `ds`, `b` FROM `tbl`",
        "INSERT INTO `test_schema`.`test_table` REPLACE WHERE CONCAT_WS('__SQLMESH_DELIM__', DATE_TRUNC('MONTH', `ds`), `b`) IN (SELECT DISTINCT CONCAT_WS('__SQLMESH_DELIM__', DATE_TRUNC('MONTH', `ds`), `b`) FROM `test_schema`.`temp_test_table_abcdefgh`) SELECT `a`, `ds`, `b` FROM `test_schema`.`temp_test_table_abcdefgh`",
        "DROP TABLE IF EXISTS `test_schema`.`temp_test_table_abcdefgh`",
    ]

0 commit comments

Comments
 (0)