Skip to content

Commit 723e2f0

Browse files
authored
Fix: convert pyspark/snowpark dataframes into pandas when testing (#2737)
1 parent 6706a1a commit 723e2f0

File tree

3 files changed: +45 −7 lines changed

sqlmesh/core/test/definition.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -618,10 +618,9 @@ def _execute_model(self) -> pd.DataFrame:
618618
time_ctx = freeze_time(self._execution_time) if self._execution_time else nullcontext()
619619
with patch.dict(self._test_adapter_dialect.generator_class.TRANSFORMS, self._transforms):
620620
with t.cast(AbstractContextManager, time_ctx):
621-
return t.cast(
622-
pd.DataFrame,
623-
next(self.model.render(context=self.context, **self.body.get("vars", {}))),
624-
)
621+
df = next(self.model.render(context=self.context, **self.body.get("vars", {})))
622+
assert not isinstance(df, exp.Expression)
623+
return df if isinstance(df, pd.DataFrame) else df.toPandas()
625624

626625

627626
def generate_test(

tests/core/test_test.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from sqlmesh.core.config import (
1616
Config,
1717
DuckDBConnectionConfig,
18+
SparkConnectionConfig,
1819
GatewayConfig,
1920
ModelDefaultsConfig,
2021
)
@@ -1410,6 +1411,44 @@ def test_generate_input_data_using_sql(mocker: MockerFixture, tmp_path: Path) ->
14101411
)
14111412

14121413

1414+
def test_pyspark_python_model() -> None:
1415+
spark_connection_config = SparkConnectionConfig(
1416+
config={
1417+
"spark.master": "local",
1418+
"spark.sql.warehouse.dir": "/tmp/data_dir",
1419+
"spark.driver.extraJavaOptions": "-Dderby.system.home=/tmp/derby_dir",
1420+
},
1421+
)
1422+
config = Config(
1423+
gateways=GatewayConfig(test_connection=spark_connection_config),
1424+
model_defaults=ModelDefaultsConfig(dialect="spark"),
1425+
)
1426+
context = Context(config=config)
1427+
1428+
@model("pyspark_model", columns={"col": "int"})
1429+
def execute(context, start, end, execution_time, **kwargs):
1430+
return context.spark.sql("SELECT 1 AS col")
1431+
1432+
_check_successful_or_raise(
1433+
_create_test(
1434+
body=load_yaml(
1435+
"""
1436+
test_pyspark_model:
1437+
model: pyspark_model
1438+
outputs:
1439+
query:
1440+
- col: 1
1441+
"""
1442+
),
1443+
test_name="test_pyspark_model",
1444+
model=model.get_registry()["pyspark_model"].model(
1445+
module_path=Path("."), path=Path(".")
1446+
),
1447+
context=context,
1448+
).run()
1449+
)
1450+
1451+
14131452
def test_test_generation(tmp_path: Path) -> None:
14141453
init_example_project(tmp_path, dialect="duckdb")
14151454

tests/integrations/jupyter/test_magics.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,9 +201,9 @@ def test_evaluate(notebook, loaded_sushi_context):
201201
def test_format(notebook, sushi_context):
202202
with capture_output():
203203
test_model_path = sushi_context.path / "models" / "test_model.sql"
204-
test_model_path.write_text("MODEL(name db.test); SELECT 1 AS foo FROM table")
204+
test_model_path.write_text("MODEL(name db.test); SELECT 1 AS foo FROM t")
205205
sushi_context.load()
206-
assert test_model_path.read_text() == "MODEL(name db.test); SELECT 1 AS foo FROM table"
206+
assert test_model_path.read_text() == "MODEL(name db.test); SELECT 1 AS foo FROM t"
207207
with capture_output() as output:
208208
notebook.run_line_magic(magic_name="format", line="")
209209

@@ -218,7 +218,7 @@ def test_format(notebook, sushi_context):
218218
219219
SELECT
220220
1 AS foo
221-
FROM table"""
221+
FROM t"""
222222
)
223223

224224

0 commit comments

Comments (0)