Skip to content

Commit 51b7ef9

Browse files
authored
Fix: Concat Dataframes if insert overwrite (#890)
* concat dataframes if insert overwrite
* add comment
* restructure
1 parent 2fc0949 commit 51b7ef9

File tree

4 files changed

+129
-12
lines changed

4 files changed

+129
-12
lines changed

sqlmesh/core/engine_adapter/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class EngineAdapter:
6464
DEFAULT_SQL_GEN_KWARGS: t.Dict[str, str | bool | int] = {}
6565
ESCAPE_JSON = False
6666
SUPPORTS_INDEXES = False
67+
SUPPORTS_INSERT_OVERWRITE = False
6768
SCHEMA_DIFFER = SchemaDiffer()
6869

6970
def __init__(

sqlmesh/core/engine_adapter/spark.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
class SparkEngineAdapter(EngineAdapter):
3131
DIALECT = "spark"
3232
ESCAPE_JSON = True
33+
SUPPORTS_INSERT_OVERWRITE = True
3334

3435
@property
3536
def spark(self) -> PySparkSession:

sqlmesh/core/snapshot/evaluator.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424
import logging
2525
import typing as t
2626
from contextlib import contextmanager
27+
from functools import reduce
2728

29+
import pandas as pd
2830
from sqlglot import exp, select
2931
from sqlglot.executor import execute
3032

@@ -165,18 +167,34 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None:
165167
if model.kind.is_view or model.kind.is_full
166168
else TransactionType.DML
167169
):
168-
for index, query_or_df in enumerate(queries_or_dfs):
169-
if limit and limit > 0:
170-
if isinstance(query_or_df, exp.Select):
171-
existing_limit = query_or_df.args.get("limit")
172-
if existing_limit:
173-
limit = min(
174-
limit,
175-
execute(exp.select(existing_limit.expression)).rows[0][0],
176-
)
177-
return query_or_df.head(limit) if hasattr(query_or_df, "head") else self.adapter._fetch_native_df(query_or_df.limit(limit)) # type: ignore
178-
179-
apply(query_or_df, index)
170+
if limit and limit > 0:
171+
query_or_df = next(queries_or_dfs)
172+
if isinstance(query_or_df, exp.Select):
173+
existing_limit = query_or_df.args.get("limit")
174+
if existing_limit:
175+
limit = min(
176+
limit,
177+
execute(exp.select(existing_limit.expression)).rows[0][0],
178+
)
179+
return query_or_df.head(limit) if hasattr(query_or_df, "head") else self.adapter._fetch_native_df(query_or_df.limit(limit)) # type: ignore
180+
# DataFrames, unlike SQL expressions, can provide partial results by yielding dataframes. As a result,
181+
# if the engine supports INSERT OVERWRITE and the snapshot is incremental by time range, we risk
182+
# having a partial result since each dataframe write can re-truncate partitions. To avoid this, we
183+
# union all the dataframes together before writing. For pandas this could result in OOM and a potential
184+
# workaround for that would be to serialize pandas to disk and then read it back with Spark.
185+
# Note: We assume that if multiple things are yielded from `queries_or_dfs` that they are dataframes
186+
# and not SQL expressions.
187+
elif self.adapter.SUPPORTS_INSERT_OVERWRITE and snapshot.is_incremental_by_time_range:
188+
query_or_df = reduce(
189+
lambda a, b: a.union_all(b) # type: ignore
190+
if self.adapter.is_pyspark_df(a)
191+
else pd.concat([a, b], ignore_index=True), # type: ignore
192+
queries_or_dfs,
193+
)
194+
apply(query_or_df, index=0)
195+
else:
196+
for index, query_or_df in enumerate(queries_or_dfs):
197+
apply(query_or_df, index)
180198

181199
model.run_post_hooks(
182200
context=context,

tests/core/test_snapshot_evaluator.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
IncrementalByTimeRangeKind,
1515
ModelKind,
1616
ModelKindName,
17+
PythonModel,
1718
SqlModel,
19+
TimeColumn,
1820
load_model,
1921
)
2022
from sqlmesh.core.model.meta import IntervalUnit
@@ -26,6 +28,7 @@
2628
SnapshotTableInfo,
2729
)
2830
from sqlmesh.utils.errors import ConfigError, SQLMeshError
31+
from sqlmesh.utils.metaprogramming import Executable
2932

3033

3134
@pytest.fixture
@@ -337,3 +340,97 @@ def test_audit_unversioned(mocker: MockerFixture, adapter_mock, make_snapshot):
337340
match="Cannot audit 'db.model' because it has not been versioned yet. Apply a plan first.",
338341
):
339342
evaluator.audit(snapshot=snapshot, snapshots={})
343+
344+
345+
@pytest.mark.parametrize(
346+
"input_dfs, output_dict",
347+
[
348+
(
349+
"""pd.DataFrame({"a": [1, 2, 3], "ds": ["2023-01-01", "2023-01-02", "2023-01-03"]}),
350+
pd.DataFrame({"a": [4, 5, 6], "ds": ["2023-01-04", "2023-01-05", "2023-01-06"]}),
351+
pd.DataFrame({"a": [7, 8, 9], "ds": ["2023-01-07", "2023-01-08", "2023-01-09"]})""",
352+
{
353+
"a": {
354+
0: 1,
355+
1: 2,
356+
2: 3,
357+
3: 4,
358+
4: 5,
359+
5: 6,
360+
6: 7,
361+
7: 8,
362+
8: 9,
363+
},
364+
"ds": {
365+
0: "2023-01-01",
366+
1: "2023-01-02",
367+
2: "2023-01-03",
368+
3: "2023-01-04",
369+
4: "2023-01-05",
370+
5: "2023-01-06",
371+
6: "2023-01-07",
372+
7: "2023-01-08",
373+
8: "2023-01-09",
374+
},
375+
},
376+
),
377+
(
378+
"""pd.DataFrame({"a": [1, 2, 3], "ds": ["2023-01-01", "2023-01-02", "2023-01-03"]})""",
379+
{
380+
"a": {
381+
0: 1,
382+
1: 2,
383+
2: 3,
384+
},
385+
"ds": {
386+
0: "2023-01-01",
387+
1: "2023-01-02",
388+
2: "2023-01-03",
389+
},
390+
},
391+
),
392+
],
393+
)
394+
def test_snapshot_evaluator_yield_pd(adapter_mock, make_snapshot, input_dfs, output_dict):
395+
adapter_mock.is_pyspark_df.return_value = False
396+
adapter_mock.SUPPORTS_INSERT_OVERWRITE = True
397+
adapter_mock.try_get_df = lambda x: x
398+
evaluator = SnapshotEvaluator(adapter_mock)
399+
400+
snapshot = make_snapshot(
401+
PythonModel(
402+
name="db.model",
403+
entrypoint="python_func",
404+
kind=IncrementalByTimeRangeKind(time_column=TimeColumn(column="ds", format="%Y-%m-%d")),
405+
columns={
406+
"a": "INT",
407+
"ds": "STRING",
408+
},
409+
python_env={
410+
"python_func": Executable(
411+
name="python_func",
412+
alias="python_func",
413+
path="test_snapshot_evaluator.py",
414+
payload=f"""import pandas as pd
415+
def python_func(**kwargs):
416+
for df in [
417+
{input_dfs}
418+
]:
419+
yield df""",
420+
)
421+
},
422+
)
423+
)
424+
425+
snapshot.categorize_as(SnapshotChangeCategory.BREAKING)
426+
evaluator.create([snapshot], {})
427+
428+
evaluator.evaluate(
429+
snapshot,
430+
"2023-01-01",
431+
"2023-01-09",
432+
"2023-01-09",
433+
snapshots={},
434+
)
435+
436+
assert adapter_mock.insert_overwrite_by_time_partition.call_args[0][1].to_dict() == output_dict

0 commit comments

Comments (0)