Skip to content

Commit 25f23ce

Browse files
authored
Fix: always make dataframes hashable when comparing them in unit tests (#3354)
1 parent 855688d commit 25f23ce

File tree

3 files changed

+70
-6
lines changed

3 files changed

+70
-6
lines changed

sqlmesh/core/test/definition.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,10 +237,11 @@ def _to_hashable(x: t.Any) -> t.Any:
237237
return tuple((k, _to_hashable(v)) for k, v in x.items())
238238
return str(x) if not isinstance(x, t.Hashable) else x
239239

240+
actual = actual.apply(lambda col: col.map(_to_hashable))
241+
expected = expected.apply(lambda col: col.map(_to_hashable))
242+
240243
if sort:
241-
actual = actual.apply(lambda col: col.map(_to_hashable))
242244
actual = actual.sort_values(by=actual.columns.to_list()).reset_index(drop=True)
243-
expected = expected.apply(lambda col: col.map(_to_hashable))
244245
expected = expected.sort_values(by=expected.columns.to_list()).reset_index(drop=True)
245246

246247
try:

sqlmesh/core/test/result.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@
44
import typing as t
55
import unittest
66

7+
if t.TYPE_CHECKING:
8+
ErrorType = t.Union[
9+
t.Tuple[type[BaseException], BaseException, types.TracebackType],
10+
t.Tuple[None, None, None],
11+
]
12+
713

814
class ModelTextTestResult(unittest.TextTestResult):
915
successes: t.List[unittest.TestCase]
@@ -12,13 +18,28 @@ def __init__(self, *args: t.Any, **kwargs: t.Any):
1218
super().__init__(*args, **kwargs)
1319
self.successes = []
1420

15-
def addFailure(
21+
def addSubTest(
1622
self,
1723
test: unittest.TestCase,
18-
err: (
19-
tuple[type[BaseException], BaseException, types.TracebackType] | tuple[None, None, None]
20-
),
24+
subtest: unittest.TestCase,
25+
err: t.Optional[ErrorType],
2126
) -> None:
27+
"""Called at the end of a subtest.
28+
29+
The traceback is suppressed because it is redundant and not useful.
30+
31+
Args:
32+
test: The test case.
33+
subtest: The subtest instance.
34+
err: A tuple of the form returned by sys.exc_info(), i.e., (type, value, traceback).
35+
"""
36+
if err:
37+
exctype, value, tb = err
38+
err = (exctype, value, None) # type: ignore
39+
40+
super().addSubTest(test, subtest, err)
41+
42+
def addFailure(self, test: unittest.TestCase, err: ErrorType) -> None:
2243
"""Called when the test case test signals a failure.
2344
2445
The traceback is suppressed because it is redundant and not useful.

tests/core/test_test.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,48 @@ def test_row_order(sushi_context: Context, full_model_without_ctes: SqlModel) ->
345345
),
346346
)
347347

348+
model_sql = """
349+
SELECT
350+
ARRAY_AGG(DISTINCT id_contact_b ORDER BY id_contact_b) AS aggregated_duplicates
351+
FROM
352+
source
353+
GROUP BY
354+
id_contact_a
355+
ORDER BY
356+
id_contact_a
357+
"""
358+
359+
_check_successful_or_raise(
360+
_create_test(
361+
body=load_yaml(
362+
"""
363+
test_array_order:
364+
model: test
365+
inputs:
366+
source:
367+
- id_contact_a: a
368+
id_contact_b: b
369+
- id_contact_a: a
370+
id_contact_b: c
371+
outputs:
372+
query:
373+
- aggregated_duplicates:
374+
- c
375+
- b
376+
"""
377+
),
378+
test_name="test_array_order",
379+
model=_create_model(model_sql),
380+
context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))),
381+
).run(),
382+
expected_msg=(
383+
"""AssertionError: Data mismatch (exp: expected, act: actual)\n\n"""
384+
" aggregated_duplicates \n"
385+
" exp act\n"
386+
"0 (c, b) (b, c)\n"
387+
),
388+
)
389+
348390

349391
@pytest.mark.parametrize(
350392
"waiter_names_input",

0 commit comments

Comments
 (0)