Skip to content

Commit ac71bae

Browse files
authored
Feat: add support for arrays in unit tests (#2215)
* Feat: add support for arrays in unit tests
* Fix formatting
* Use apply instead of map
* Fix test
* Test itemized YAML format for array as well
1 parent 094f228 commit ac71bae

File tree

2 files changed

+97
-39
lines changed

2 files changed

+97
-39
lines changed

sqlmesh/core/test/definition.py

Lines changed: 22 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,8 @@
66

77
import numpy as np
88
import pandas as pd
9-
from sqlglot import exp, parse_one
9+
from sqlglot import exp
10+
from sqlglot.optimizer.annotate_types import annotate_types
1011
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
1112

1213
from sqlmesh.core import constants as c
@@ -82,7 +83,8 @@ def setUp(self) -> None:
8283
for i, v in rows[0].items():
8384
# convert ruamel into python
8485
v = v.real if hasattr(v, "real") else v
85-
columns_to_types[i] = parse_one(type(v).__name__, into=exp.DataType)
86+
v_type = annotate_types(exp.convert(v)).type or type(v).__name__
87+
columns_to_types[i] = exp.maybe_parse(v_type, into=exp.DataType)
8688

8789
test_fixture_table = _fully_qualified_test_fixture_table(table_name, self.dialect)
8890
if test_fixture_table.db:
@@ -112,21 +114,28 @@ def assert_equal(self, expected: pd.DataFrame, actual: pd.DataFrame, sort: bool)
112114
actual_types, errors="ignore"
113115
)
114116

115-
expected = expected.replace({None: np.nan})
116117
actual = actual.replace({None: np.nan})
118+
expected = expected.replace({None: np.nan})
119+
120+
def _to_hashable(x: t.Any) -> t.Any:
121+
return tuple(x) if isinstance(x, list) else x
117122

118123
try:
124+
if sort:
125+
actual = (
126+
actual.apply(_to_hashable)
127+
.sort_values(by=actual.columns.to_list())
128+
.reset_index(drop=True)
129+
)
130+
expected = (
131+
expected.apply(_to_hashable)
132+
.sort_values(by=expected.columns.to_list())
133+
.reset_index(drop=True)
134+
)
135+
119136
pd.testing.assert_frame_equal(
120-
(
121-
expected.sort_values(by=expected.columns.to_list()).reset_index(drop=True)
122-
if sort
123-
else expected
124-
),
125-
(
126-
actual.sort_values(by=actual.columns.to_list()).reset_index(drop=True)
127-
if sort
128-
else actual
129-
),
137+
expected,
138+
actual,
130139
check_dtype=False,
131140
check_datetimelike_compat=True,
132141
check_like=True, # ignore column order

tests/core/test_test.py

Lines changed: 75 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,9 @@
1616
from sqlmesh.utils.errors import ConfigError
1717
from sqlmesh.utils.yaml import load as load_yaml
1818

19+
if t.TYPE_CHECKING:
20+
from unittest import TestResult
21+
1922
pytestmark = pytest.mark.slow
2023

2124
SUSHI_FOO_META = "MODEL (name sushi.foo, kind FULL)"
@@ -52,6 +55,18 @@ def _create_model(
5255
)
5356

5457

58+
def _check_successful_or_raise(
59+
result: t.Optional[TestResult], expected_failure_msg: t.Optional[str] = None
60+
) -> None:
61+
assert result is not None
62+
if not result.wasSuccessful():
63+
error_or_failure_traceback = (result.errors or result.failures)[0][1]
64+
if result.failures and expected_failure_msg:
65+
assert expected_failure_msg in error_or_failure_traceback
66+
else:
67+
raise AssertionError(error_or_failure_traceback)
68+
69+
5570
@pytest.fixture
5671
def full_model_without_ctes(request) -> SqlModel:
5772
return _create_model(
@@ -110,7 +125,7 @@ def test_ctes(sushi_context: Context, full_model_with_two_ctes: SqlModel) -> Non
110125
"""
111126
)
112127
result = _create_test(body, "test_foo", model, sushi_context).run()
113-
assert result and result.wasSuccessful()
128+
_check_successful_or_raise(result)
114129

115130

116131
def test_ctes_only(sushi_context: Context, full_model_with_two_ctes: SqlModel) -> None:
@@ -134,7 +149,7 @@ def test_ctes_only(sushi_context: Context, full_model_with_two_ctes: SqlModel) -
134149
"""
135150
)
136151
result = _create_test(body, "test_foo", model, sushi_context).run()
137-
assert result and result.wasSuccessful()
152+
_check_successful_or_raise(result)
138153

139154

140155
def test_query_only(sushi_context: Context, full_model_with_two_ctes: SqlModel) -> None:
@@ -155,7 +170,7 @@ def test_query_only(sushi_context: Context, full_model_with_two_ctes: SqlModel)
155170
"""
156171
)
157172
result = _create_test(body, "test_foo", model, sushi_context).run()
158-
assert result and result.wasSuccessful()
173+
_check_successful_or_raise(result)
159174

160175

161176
def test_with_rows(sushi_context: Context, full_model_with_single_cte: SqlModel) -> None:
@@ -182,7 +197,7 @@ def test_with_rows(sushi_context: Context, full_model_with_single_cte: SqlModel)
182197
"""
183198
)
184199
result = _create_test(body, "test_foo", model, sushi_context).run()
185-
assert result and result.wasSuccessful()
200+
_check_successful_or_raise(result)
186201

187202

188203
def test_without_rows(sushi_context: Context, full_model_with_single_cte: SqlModel) -> None:
@@ -206,7 +221,7 @@ def test_without_rows(sushi_context: Context, full_model_with_single_cte: SqlMod
206221
"""
207222
)
208223
result = _create_test(body, "test_foo", model, sushi_context).run()
209-
assert result and result.wasSuccessful()
224+
_check_successful_or_raise(result)
210225

211226

212227
def test_column_order(sushi_context: Context, full_model_without_ctes: SqlModel) -> None:
@@ -231,7 +246,7 @@ def test_column_order(sushi_context: Context, full_model_without_ctes: SqlModel)
231246
"""
232247
)
233248
result = _create_test(body, "test_foo", model, sushi_context).run()
234-
assert result and result.wasSuccessful()
249+
_check_successful_or_raise(result)
235250

236251

237252
def test_row_order(sushi_context: Context, full_model_without_ctes: SqlModel) -> None:
@@ -266,17 +281,24 @@ def test_row_order(sushi_context: Context, full_model_without_ctes: SqlModel) ->
266281

267282
# model query without ORDER BY should pass unit test
268283
result = _create_test(body, "test_foo", model, sushi_context).run()
269-
assert result and result.wasSuccessful()
284+
_check_successful_or_raise(result)
270285

271286
# model query with ORDER BY should fail unit test
272287
full_model_without_ctes_dict = full_model_without_ctes.dict()
273288
full_model_without_ctes_dict["query"] = full_model_without_ctes.query.order_by("id") # type: ignore
274289
full_model_without_ctes_orderby = SqlModel(**full_model_without_ctes_dict)
275290

276291
model = t.cast(SqlModel, sushi_context.upsert_model(full_model_without_ctes_orderby))
277-
278292
result = _create_test(body, "test_foo", model, sushi_context).run()
279-
assert result and not result.wasSuccessful()
293+
294+
expected_failure_msg = """AssertionError: Data differs (exp: expected, act: actual)
295+
296+
id value ds
297+
exp act exp act exp act
298+
0 2 1 3 2 4 3
299+
1 1 2 2 3 3 4"""
300+
301+
_check_successful_or_raise(result, expected_failure_msg=expected_failure_msg)
280302

281303

282304
def test_partial_data(sushi_context: Context) -> None:
@@ -316,7 +338,7 @@ def test_partial_data(sushi_context: Context) -> None:
316338
"""
317339
)
318340
result = _create_test(body, "test_foo", model, sushi_context).run()
319-
assert result and result.wasSuccessful()
341+
_check_successful_or_raise(result)
320342

321343

322344
def test_partial_data_column_order(sushi_context: Context) -> None:
@@ -347,7 +369,7 @@ def test_partial_data_column_order(sushi_context: Context) -> None:
347369
"""
348370
)
349371
result = _create_test(body, "test_foo", model, sushi_context).run()
350-
assert result and result.wasSuccessful()
372+
_check_successful_or_raise(result)
351373

352374

353375
def test_partial_data_missing_schemas(sushi_context: Context) -> None:
@@ -371,7 +393,7 @@ def test_partial_data_missing_schemas(sushi_context: Context) -> None:
371393
"""
372394
)
373395
result = _create_test(body, "test_foo", model, sushi_context).run()
374-
assert result and result.wasSuccessful()
396+
_check_successful_or_raise(result)
375397

376398
model = _create_model(
377399
"SELECT *, DATE_TRUNC('month', date)::DATE AS month, NULL::DATE AS null_date, FROM unknown"
@@ -401,7 +423,7 @@ def test_partial_data_missing_schemas(sushi_context: Context) -> None:
401423
"""
402424
)
403425
result = _create_test(body, "test_foo", model, sushi_context).run()
404-
assert result and result.wasSuccessful()
426+
_check_successful_or_raise(result)
405427

406428

407429
def test_missing_column_failure(sushi_context: Context, full_model_without_ctes: SqlModel) -> None:
@@ -423,9 +445,8 @@ def test_missing_column_failure(sushi_context: Context, full_model_without_ctes:
423445
"""
424446
)
425447
result = _create_test(body, "test_foo", model, sushi_context).run()
426-
assert result and not result.wasSuccessful()
427448

428-
expected_msg = """AssertionError: Data differs (exp: expected, act: actual)
449+
expected_failure_msg = """AssertionError: Data differs (exp: expected, act: actual)
429450
430451
value ds
431452
exp act exp act
@@ -434,7 +455,7 @@ def test_missing_column_failure(sushi_context: Context, full_model_without_ctes:
434455
435456
Test description: sushi.foo's output has a missing column (fails intentionally)
436457
"""
437-
assert expected_msg in result.failures[0][1]
458+
_check_successful_or_raise(result, expected_failure_msg=expected_failure_msg)
438459

439460

440461
def test_empty_rows(sushi_context: Context) -> None:
@@ -454,7 +475,7 @@ def test_empty_rows(sushi_context: Context) -> None:
454475
"""
455476
)
456477
result = _create_test(body, "test_foo", model, sushi_context).run()
457-
assert result and result.wasSuccessful()
478+
_check_successful_or_raise(result)
458479

459480

460481
@pytest.mark.parametrize("full_model_without_ctes", ["snowflake"], indirect=True)
@@ -541,7 +562,7 @@ def test_test_generation(tmp_path: Path) -> None:
541562
assert test["test_full_model"]["vars"] == {"start": "2020-01-01", "end": "2024-01-01"}
542563

543564
result = context.test()
544-
assert result and result.wasSuccessful()
565+
_check_successful_or_raise(result)
545566

546567
context.create_test(
547568
"sqlmesh_example.full_model", input_queries=input_queries, name="new_name", path="foo/bar"
@@ -557,13 +578,13 @@ def test_source_func() -> None:
557578
body=load_yaml(
558579
"""
559580
test_foo:
560-
model: xyz
561-
outputs:
562-
query:
563-
- month: 2023-01-01
564-
- month: 2023-02-01
565-
- month: 2023-03-01
566-
"""
581+
model: xyz
582+
outputs:
583+
query:
584+
- month: 2023-01-01
585+
- month: 2023-02-01
586+
- month: 2023-03-01
587+
"""
567588
),
568589
test_name="test_foo",
569590
model=_create_model(
@@ -575,4 +596,32 @@ def test_source_func() -> None:
575596
context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))),
576597
).run()
577598

578-
assert result and result.wasSuccessful()
599+
_check_successful_or_raise(result)
600+
601+
602+
def test_nested_data_types() -> None:
603+
result = _create_test(
604+
body=load_yaml(
605+
"""
606+
test_foo:
607+
model: sushi.foo
608+
inputs:
609+
raw:
610+
- value: [1, 2, 3]
611+
- value:
612+
- 2
613+
- 3
614+
- value: [0, 4, 1]
615+
outputs:
616+
query:
617+
- value: [0, 4, 1]
618+
- value: [1, 2, 3]
619+
- value: [2, 3]
620+
"""
621+
),
622+
test_name="test_foo",
623+
model=_create_model("SELECT value FROM raw"),
624+
context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))),
625+
).run()
626+
627+
_check_successful_or_raise(result)

0 commit comments

Comments
 (0)