Commit f24d075

Fix: Use SUPPORTS_INSERT_OVERWRITE and cleanup (#893)
* use SUPPORTS_INSERT_OVERWRITE and cleanup
* feedback
* remove where clause
1 parent 51b7ef9 commit f24d075
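In short: EngineAdapter._insert_overwrite_by_condition in base.py now emits a native INSERT OVERWRITE for engines that declare SUPPORTS_INSERT_OVERWRITE, keeping the delete/insert transaction only as the fallback, and the Spark adapter drops its hand-rolled copy of that logic in favor of the shared path. A minimal sketch of how an engine opts in (the subclass name is hypothetical; EngineAdapter and the flag come from this commit):

from sqlmesh.core.engine_adapter.base import EngineAdapter

class MyEngineAdapter(EngineAdapter):  # hypothetical adapter, for illustration only
    # Opting in routes _insert_overwrite_by_condition through the new
    # INSERT OVERWRITE branch instead of the delete/insert fallback.
    SUPPORTS_INSERT_OVERWRITE = True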

File tree

  sqlmesh/core/engine_adapter/base.py
  sqlmesh/core/engine_adapter/spark.py
  tests/core/engine_adapter/test_base.py

3 files changed: +79, -44 lines
sqlmesh/core/engine_adapter/base.py

Lines changed: 22 additions & 6 deletions
@@ -596,13 +596,29 @@ def _insert_overwrite_by_condition(
         where: t.Optional[exp.Condition] = None,
         columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
     ) -> None:
-        if where is None:
-            raise SQLMeshError(
-                "Where condition is required when doing a delete/insert for insert/overwrite"
+        table = exp.to_table(table_name)
+        if self.SUPPORTS_INSERT_OVERWRITE:
+            df = self.try_get_pandas_df(query_or_df)
+            if df is not None:
+                query_or_df = next(
+                    pandas_to_sql(
+                        df,
+                        alias=table.alias_or_name,
+                        columns_to_types=columns_to_types,
+                    )
+                )
+            query = t.cast("Query", query_or_df)
+            self.execute(
+                exp.insert(query, table, columns=list(columns_to_types or []), overwrite=True)
             )
-        with self.transaction():
-            self.delete_from(table_name, where=where)
-            self.insert_append(table_name, query_or_df, columns_to_types=columns_to_types)
+        else:
+            if where is None:
+                raise SQLMeshError(
+                    "Where condition is required when doing a delete/insert for insert/overwrite"
+                )
+            with self.transaction():
+                self.delete_from(table_name, where=where)
+                self.insert_append(table_name, query_or_df, columns_to_types=columns_to_types)
 
     def update_table(
         self,
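To see what the new SUPPORTS_INSERT_OVERWRITE branch actually executes, here is a rough standalone sketch using sqlglot directly (not adapter code); the rendered statement matches the assertion in the new test further down:

from sqlglot import exp, parse_one

# Build the same expression the new branch hands to self.execute().
query = parse_one("SELECT a, b FROM tbl")
table = exp.to_table("test_table")
insert = exp.insert(query, table, columns=["a", "b"], overwrite=True)
print(insert.sql())
# INSERT OVERWRITE TABLE test_table (a, b) SELECT a, b FROM tbl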

sqlmesh/core/engine_adapter/spark.py

Lines changed: 18 additions & 38 deletions
@@ -5,7 +5,6 @@
 import pandas as pd
 from sqlglot import exp
 
-from sqlmesh.core.dialect import pandas_to_sql
 from sqlmesh.core.engine_adapter.base import EngineAdapter
 from sqlmesh.core.engine_adapter.shared import (
     DataObject,
@@ -21,7 +20,6 @@
     DF,
     PySparkDataFrame,
     PySparkSession,
-    Query,
     QueryOrDF,
 )
 from sqlmesh.core.model.meta import IntervalUnit
@@ -62,29 +60,11 @@ def _insert_overwrite_by_condition(
         where: t.Optional[exp.Condition] = None,
         columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
     ) -> None:
-        table = exp.to_table(table_name)
-        df = self.try_get_pandas_df(query_or_df)
-        pyspark_df = self.try_get_pyspark_df(query_or_df)
-        if self._use_spark_session and (df is not None or pyspark_df):
-            if df is not None:
-                pyspark_df = self._ensure_pyspark_df(df)
-            assert pyspark_df
-            self._insert_pyspark_df(table_name, pyspark_df, overwrite=True)
+        df = self.try_get_df(query_or_df)
+        if self._use_spark_session and df is not None:
+            self._insert_pyspark_df(table_name, self._ensure_pyspark_df(df), overwrite=True)
         else:
-            if df is not None:
-                query_or_df = next(
-                    pandas_to_sql(
-                        df,
-                        alias=table.alias_or_name,
-                        columns_to_types=columns_to_types,
-                    )
-                )
-            column_names = list(columns_to_types or [])
-            self.execute(
-                exp.insert(
-                    t.cast("Query", query_or_df), table, columns=column_names, overwrite=True
-                )
-            )
+            super()._insert_overwrite_by_condition(table_name, query_or_df, where, columns_to_types)
 
     def insert_append(
         self,
@@ -94,10 +74,10 @@ def insert_append(
         contains_json: bool = False,
     ) -> None:
         df = self.try_get_df(query_or_df)
-        if df is None or not self._use_spark_session:
-            super().insert_append(table_name, query_or_df, columns_to_types, contains_json)
-        else:
+        if self._use_spark_session and df is not None:
             self._insert_append_pyspark_df(table_name, self._ensure_pyspark_df(df))
+        else:
+            super().insert_append(table_name, query_or_df, columns_to_types, contains_json)
 
     def merge(
         self,
@@ -108,16 +88,16 @@
     ) -> None:
         column_names = columns_to_types.keys()
         df = self.try_get_df(source_table)
-        if df is None or not self._use_spark_session:
-            super().merge(target_table, source_table, columns_to_types, unique_key)
-        else:
-            df = self._ensure_pyspark_df(df)
+        if self._use_spark_session and df is not None:
+            pyspark_df = self._ensure_pyspark_df(df)
             temp_view_name = self._get_temp_table(target_table, table_only=True).sql(
                 dialect=self.dialect
             )
-            df.createOrReplaceTempView(temp_view_name)
+            pyspark_df.createOrReplaceTempView(temp_view_name)
             query = exp.select(*column_names).from_(temp_view_name)
             super().merge(target_table, query, columns_to_types, unique_key)
+        else:
+            super().merge(target_table, source_table, columns_to_types, unique_key)
 
     def _insert_append_pandas_df(
         self,
@@ -126,10 +106,10 @@ def _insert_append_pandas_df(
         columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
         contains_json: bool = False,
     ) -> None:
-        if not self._use_spark_session:
-            super()._insert_append_pandas_df(table_name, df, columns_to_types, contains_json)
-        else:
+        if self._use_spark_session:
             self._insert_pyspark_df(table_name, self._ensure_pyspark_df(df), overwrite=False)
+        else:
+            super()._insert_append_pandas_df(table_name, df, columns_to_types, contains_json)
 
     def _insert_append_pyspark_df(
         self,
@@ -160,13 +140,13 @@ def _create_table_from_df(
         replace: bool = True,
         **kwargs: t.Any,
     ) -> None:
-        if not self._use_spark_session:
-            super()._create_table_from_df(table_name, df, columns_to_types, exists, replace)
-        else:
+        if self._use_spark_session:
             df = self._ensure_pyspark_df(df)
             if isinstance(table_name, exp.Table):
                 table_name = table_name.sql(dialect=self.dialect)
             df.write.saveAsTable(table_name, mode="overwrite")
+        else:
+            super()._create_table_from_df(table_name, df, columns_to_types, exists, replace)
 
     def _get_data_objects(
         self, schema_name: str, catalog_name: t.Optional[str] = None
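After this cleanup every Spark override follows one dispatch shape: take the PySpark path when a session is active and the input is a DataFrame, otherwise defer to the SQL-based base implementation. A distilled sketch of that shape (op is a stand-in for any of the overrides above, not a real method):

def op(self, table_name, query_or_df, *args, **kwargs):
    df = self.try_get_df(query_or_df)  # pandas or PySpark DataFrame, else None
    if self._use_spark_session and df is not None:
        ...  # PySpark fast path, e.g. self._insert_pyspark_df(...)
    else:
        ...  # SQL fallback, e.g. super().op(...)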

tests/core/engine_adapter/test_base.py

Lines changed: 39 additions & 0 deletions
@@ -118,6 +118,45 @@ def test_insert_overwrite_by_time_partition(mocker: MockerFixture):
     )
 
 
+def test_insert_overwrite_by_time_partition_supports_insert_overwrite(mocker: MockerFixture):
+    connection_mock = mocker.NonCallableMock()
+    cursor_mock = mocker.Mock()
+    connection_mock.cursor.return_value = cursor_mock
+
+    adapter = EngineAdapter(lambda: connection_mock, "")  # type: ignore
+    adapter.SUPPORTS_INSERT_OVERWRITE = True
+    adapter._insert_overwrite_by_condition(
+        "test_table",
+        parse_one("SELECT a, b FROM tbl"),
+        where=parse_one("b BETWEEN '2022-01-01' and '2022-01-02'"),
+        columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("STRING")},
+    )
+
+    cursor_mock.execute.assert_called_once_with(
+        "INSERT OVERWRITE TABLE test_table (a, b) SELECT a, b FROM tbl"
+    )
+
+
+def test_insert_overwrite_by_time_partition_supports_insert_overwrite_pandas(mocker: MockerFixture):
+    connection_mock = mocker.NonCallableMock()
+    cursor_mock = mocker.Mock()
+    connection_mock.cursor.return_value = cursor_mock
+
+    adapter = EngineAdapter(lambda: connection_mock, "")  # type: ignore
+    adapter.SUPPORTS_INSERT_OVERWRITE = True
+    df = pd.DataFrame({"a": [1, 2], "ds": ["2022-01-01", "2022-01-02"]})
+    adapter._insert_overwrite_by_condition(
+        "test_table",
+        df,
+        where=parse_one("ds BETWEEN '2022-01-01' and '2022-01-02'"),
+        columns_to_types={"a": exp.DataType.build("INT"), "ds": exp.DataType.build("STRING")},
+    )
+
+    cursor_mock.execute.assert_called_once_with(
+        "INSERT OVERWRITE TABLE test_table (a, ds) SELECT CAST(a AS INT) AS a, CAST(ds AS TEXT) AS ds FROM (VALUES (1, '2022-01-01'), (2, '2022-01-02')) AS test_table(a, ds)"
+    )
+
+
 def test_insert_append_query(mocker: MockerFixture):
     connection_mock = mocker.NonCallableMock()
     cursor_mock = mocker.Mock()
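The expected SQL in the pandas test comes from pandas_to_sql rendering the DataFrame as a casted VALUES subquery, the same call the base adapter now makes. A rough standalone sketch of that step (assuming the import path shown in the spark.py removal above):

import pandas as pd
from sqlglot import exp
from sqlmesh.core.dialect import pandas_to_sql

df = pd.DataFrame({"a": [1, 2], "ds": ["2022-01-01", "2022-01-02"]})
# pandas_to_sql yields query expressions; the adapter consumes the first one.
query = next(
    pandas_to_sql(
        df,
        alias="test_table",
        columns_to_types={"a": exp.DataType.build("INT"), "ds": exp.DataType.build("STRING")},
    )
)
print(query.sql())
# SELECT CAST(a AS INT) AS a, CAST(ds AS TEXT) AS ds
#   FROM (VALUES (1, '2022-01-01'), (2, '2022-01-02')) AS test_table(a, ds)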
