From f0d3cc99a76cbcb38ed3dbaeb45bc1c6b239b091 Mon Sep 17 00:00:00 2001 From: Erin Drummond Date: Thu, 24 Jul 2025 04:56:29 +0000 Subject: [PATCH 1/2] Fix(table_diff): Correctly handle joins with composite keys where one or more of the key fields are null --- sqlmesh/core/table_diff.py | 21 +++--- .../integration/test_integration.py | 16 +++-- tests/core/test_table_diff.py | 71 ++++++++++++++++--- 3 files changed, 79 insertions(+), 29 deletions(-) diff --git a/sqlmesh/core/table_diff.py b/sqlmesh/core/table_diff.py index 6a91b22dfb..126fa64b1e 100644 --- a/sqlmesh/core/table_diff.py +++ b/sqlmesh/core/table_diff.py @@ -421,21 +421,16 @@ def _column_expr(name: str, table: str) -> exp.Expression: exp.select( *s_selects.values(), *t_selects.values(), - exp.func("IF", exp.or_(*(c.is_(exp.Null()).not_() for c in s_index)), 1, 0).as_( - "s_exists" - ), - exp.func("IF", exp.or_(*(c.is_(exp.Null()).not_() for c in t_index)), 1, 0).as_( - "t_exists" - ), + exp.func( + "IF", exp.column(SQLMESH_JOIN_KEY_COL, "s").is_(exp.Null()).not_(), 1, 0 + ).as_("s_exists"), + exp.func( + "IF", exp.column(SQLMESH_JOIN_KEY_COL, "t").is_(exp.Null()).not_(), 1, 0 + ).as_("t_exists"), exp.func( "IF", - exp.and_( - exp.column(SQLMESH_JOIN_KEY_COL, "s").eq( - exp.column(SQLMESH_JOIN_KEY_COL, "t") - ), - exp.and_( - *(c.is_(exp.Null()).not_() for c in s_index + t_index), - ), + exp.column(SQLMESH_JOIN_KEY_COL, "s").eq( + exp.column(SQLMESH_JOIN_KEY_COL, "t") ), 1, 0, diff --git a/tests/core/engine_adapter/integration/test_integration.py b/tests/core/engine_adapter/integration/test_integration.py index ee839d7593..cb09d20537 100644 --- a/tests/core/engine_adapter/integration/test_integration.py +++ b/tests/core/engine_adapter/integration/test_integration.py @@ -2353,15 +2353,17 @@ def test_table_diff_grain_check_multiple_keys(ctx: TestContext): row_diff = table_diff.row_diff() assert row_diff.full_match_count == 7 - assert row_diff.full_match_pct == 93.33 - assert row_diff.s_only_count == 2 - assert row_diff.t_only_count == 5 - assert row_diff.stats["join_count"] == 4 - assert row_diff.stats["null_grain_count"] == 4 - assert row_diff.stats["s_count"] != row_diff.stats["distinct_count_s"] + assert row_diff.full_match_pct == 82.35 + assert row_diff.s_only_count == 0 + assert row_diff.t_only_count == 3 + assert row_diff.stats["join_count"] == 7 + assert ( + row_diff.stats["null_grain_count"] == 4 + ) # null grain currently (2025-07-24) means "any key column is null" as opposed to "all key columns are null" assert row_diff.stats["distinct_count_s"] == 7 - assert row_diff.stats["t_count"] != row_diff.stats["distinct_count_t"] + assert row_diff.stats["s_count"] == row_diff.stats["distinct_count_s"] assert row_diff.stats["distinct_count_t"] == 10 + assert row_diff.stats["t_count"] == row_diff.stats["distinct_count_t"] assert row_diff.s_sample.shape == (row_diff.s_only_count, 3) assert row_diff.t_sample.shape == (row_diff.t_only_count, 3) diff --git a/tests/core/test_table_diff.py b/tests/core/test_table_diff.py index 64096a6637..ffd20de6ca 100644 --- a/tests/core/test_table_diff.py +++ b/tests/core/test_table_diff.py @@ -291,15 +291,17 @@ def test_grain_check(sushi_context_fixed_date): row_diff = diff.row_diff() assert row_diff.full_match_count == 7 - assert row_diff.full_match_pct == 93.33 - assert row_diff.s_only_count == 2 - assert row_diff.t_only_count == 5 - assert row_diff.stats["join_count"] == 4 - assert row_diff.stats["null_grain_count"] == 4 - assert row_diff.stats["s_count"] != row_diff.stats["distinct_count_s"] + assert row_diff.full_match_pct == 82.35 + assert row_diff.s_only_count == 0 + assert row_diff.t_only_count == 3 + assert row_diff.stats["join_count"] == 7 + assert ( + row_diff.stats["null_grain_count"] == 4 + ) # null grain currently (2025-07-24) means "any key column is null" as opposed to "all key columns are null" assert row_diff.stats["distinct_count_s"] == 7 - assert row_diff.stats["t_count"] != row_diff.stats["distinct_count_t"] assert row_diff.stats["distinct_count_t"] == 10 + assert row_diff.stats["s_count"] == row_diff.stats["distinct_count_s"] + assert row_diff.stats["t_count"] == row_diff.stats["distinct_count_t"] assert row_diff.s_sample.shape == (row_diff.s_only_count, 3) assert row_diff.t_sample.shape == (row_diff.t_only_count, 3) @@ -329,7 +331,7 @@ def test_generated_sql(sushi_context_fixed_date: Context, mocker: MockerFixture) ), ) - query_sql = 'CREATE TABLE IF NOT EXISTS "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" AS WITH "__source" AS (SELECT "s"."key", "s"."value", "s"."key" AS "__sqlmesh_join_key" FROM "table_diff_source" AS "s"), "__target" AS (SELECT "t"."key", "t"."value", "t"."key" AS "__sqlmesh_join_key" FROM "table_diff_target" AS "t"), "__stats" AS (SELECT "s"."key" AS "s__key", "s"."value" AS "s__value", "s"."__sqlmesh_join_key" AS "s____sqlmesh_join_key", "t"."key" AS "t__key", "t"."value" AS "t__value", "t"."__sqlmesh_join_key" AS "t____sqlmesh_join_key", CASE WHEN NOT "s"."key" IS NULL THEN 1 ELSE 0 END AS "s_exists", CASE WHEN NOT "t"."key" IS NULL THEN 1 ELSE 0 END AS "t_exists", CASE WHEN "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key" AND (NOT "s"."key" IS NULL AND NOT "t"."key" IS NULL) THEN 1 ELSE 0 END AS "row_joined", CASE WHEN "s"."key" IS NULL AND "t"."key" IS NULL THEN 1 ELSE 0 END AS "null_grain", CASE WHEN "s"."key" = "t"."key" THEN 1 WHEN ("s"."key" IS NULL) AND ("t"."key" IS NULL) THEN 1 WHEN ("s"."key" IS NULL) OR ("t"."key" IS NULL) THEN 0 ELSE 0 END AS "key_matches", CASE WHEN ROUND("s"."value", 3) = ROUND("t"."value", 3) THEN 1 WHEN ("s"."value" IS NULL) AND ("t"."value" IS NULL) THEN 1 WHEN ("s"."value" IS NULL) OR ("t"."value" IS NULL) THEN 0 ELSE 0 END AS "value_matches" FROM "__source" AS "s" FULL JOIN "__target" AS "t" ON "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key") SELECT *, CASE WHEN "key_matches" = 1 AND "value_matches" = 1 THEN 1 ELSE 0 END AS "row_full_match" FROM "__stats"' + query_sql = 'CREATE TABLE IF NOT EXISTS "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" AS WITH "__source" AS (SELECT "s"."key", "s"."value", "s"."key" AS "__sqlmesh_join_key" FROM "table_diff_source" AS "s"), "__target" AS (SELECT "t"."key", "t"."value", "t"."key" AS "__sqlmesh_join_key" FROM "table_diff_target" AS "t"), "__stats" AS (SELECT "s"."key" AS "s__key", "s"."value" AS "s__value", "s"."__sqlmesh_join_key" AS "s____sqlmesh_join_key", "t"."key" AS "t__key", "t"."value" AS "t__value", "t"."__sqlmesh_join_key" AS "t____sqlmesh_join_key", CASE WHEN NOT "s"."__sqlmesh_join_key" IS NULL THEN 1 ELSE 0 END AS "s_exists", CASE WHEN NOT "t"."__sqlmesh_join_key" IS NULL THEN 1 ELSE 0 END AS "t_exists", CASE WHEN "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key" THEN 1 ELSE 0 END AS "row_joined", CASE WHEN "s"."key" IS NULL AND "t"."key" IS NULL THEN 1 ELSE 0 END AS "null_grain", CASE WHEN "s"."key" = "t"."key" THEN 1 WHEN ("s"."key" IS NULL) AND ("t"."key" IS NULL) THEN 1 WHEN ("s"."key" IS NULL) OR ("t"."key" IS NULL) THEN 0 ELSE 0 END AS "key_matches", CASE WHEN ROUND("s"."value", 3) = ROUND("t"."value", 3) THEN 1 WHEN ("s"."value" IS NULL) AND ("t"."value" IS NULL) THEN 1 WHEN ("s"."value" IS NULL) OR ("t"."value" IS NULL) THEN 0 ELSE 0 END AS "value_matches" FROM "__source" AS "s" FULL JOIN "__target" AS "t" ON "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key") SELECT *, CASE WHEN "key_matches" = 1 AND "value_matches" = 1 THEN 1 ELSE 0 END AS "row_full_match" FROM "__stats"' summary_query_sql = 'SELECT SUM("s_exists") AS "s_count", SUM("t_exists") AS "t_count", SUM("row_joined") AS "join_count", SUM("null_grain") AS "null_grain_count", SUM("row_full_match") AS "full_match_count", SUM("key_matches") AS "key_matches", SUM("value_matches") AS "value_matches", COUNT(DISTINCT ("s____sqlmesh_join_key")) AS "distinct_count_s", COUNT(DISTINCT ("t____sqlmesh_join_key")) AS "distinct_count_t" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh"' compare_sql = 'SELECT ROUND(100 * (CAST(SUM("key_matches") AS DECIMAL) / COUNT("key_matches")), 9) AS "key_matches", ROUND(100 * (CAST(SUM("value_matches") AS DECIMAL) / COUNT("value_matches")), 9) AS "value_matches" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "row_joined" = 1' sample_query_sql = 'WITH "source_only" AS (SELECT \'source_only\' AS "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "s_exists" = 1 AND "row_joined" = 0 ORDER BY "s__key" NULLS FIRST LIMIT 20), "target_only" AS (SELECT \'target_only\' AS "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "t_exists" = 1 AND "row_joined" = 0 ORDER BY "t__key" NULLS FIRST LIMIT 20), "common_rows" AS (SELECT \'common_rows\' AS "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "row_joined" = 1 AND "row_full_match" = 0 ORDER BY "s__key" NULLS FIRST, "t__key" NULLS FIRST LIMIT 20) SELECT "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "source_only" UNION ALL SELECT "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "target_only" UNION ALL SELECT "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "common_rows"' @@ -369,7 +371,7 @@ def test_generated_sql(sushi_context_fixed_date: Context, mocker: MockerFixture) where="key = 2", ) - query_sql_where = 'CREATE TABLE IF NOT EXISTS "memory"."sqlmesh_temp"."__temp_diff_abcdefgh" AS WITH "__source" AS (SELECT "s"."key", "s"."value", "s"."key" AS "__sqlmesh_join_key" FROM "table_diff_source" AS "s" WHERE "s"."key" = 2), "__target" AS (SELECT "t"."key", "t"."value", "t"."key" AS "__sqlmesh_join_key" FROM "table_diff_target" AS "t" WHERE "t"."key" = 2), "__stats" AS (SELECT "s"."key" AS "s__key", "s"."value" AS "s__value", "s"."__sqlmesh_join_key" AS "s____sqlmesh_join_key", "t"."key" AS "t__key", "t"."value" AS "t__value", "t"."__sqlmesh_join_key" AS "t____sqlmesh_join_key", CASE WHEN NOT "s"."key" IS NULL THEN 1 ELSE 0 END AS "s_exists", CASE WHEN NOT "t"."key" IS NULL THEN 1 ELSE 0 END AS "t_exists", CASE WHEN "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key" AND (NOT "s"."key" IS NULL AND NOT "t"."key" IS NULL) THEN 1 ELSE 0 END AS "row_joined", CASE WHEN "s"."key" IS NULL AND "t"."key" IS NULL THEN 1 ELSE 0 END AS "null_grain", CASE WHEN "s"."key" = "t"."key" THEN 1 WHEN ("s"."key" IS NULL) AND ("t"."key" IS NULL) THEN 1 WHEN ("s"."key" IS NULL) OR ("t"."key" IS NULL) THEN 0 ELSE 0 END AS "key_matches", CASE WHEN ROUND("s"."value", 3) = ROUND("t"."value", 3) THEN 1 WHEN ("s"."value" IS NULL) AND ("t"."value" IS NULL) THEN 1 WHEN ("s"."value" IS NULL) OR ("t"."value" IS NULL) THEN 0 ELSE 0 END AS "value_matches" FROM "__source" AS "s" FULL JOIN "__target" AS "t" ON "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key") SELECT *, CASE WHEN "key_matches" = 1 AND "value_matches" = 1 THEN 1 ELSE 0 END AS "row_full_match" FROM "__stats"' + query_sql_where = 'CREATE TABLE IF NOT EXISTS "memory"."sqlmesh_temp"."__temp_diff_abcdefgh" AS WITH "__source" AS (SELECT "s"."key", "s"."value", "s"."key" AS "__sqlmesh_join_key" FROM "table_diff_source" AS "s" WHERE "s"."key" = 2), "__target" AS (SELECT "t"."key", "t"."value", "t"."key" AS "__sqlmesh_join_key" FROM "table_diff_target" AS "t" WHERE "t"."key" = 2), "__stats" AS (SELECT "s"."key" AS "s__key", "s"."value" AS "s__value", "s"."__sqlmesh_join_key" AS "s____sqlmesh_join_key", "t"."key" AS "t__key", "t"."value" AS "t__value", "t"."__sqlmesh_join_key" AS "t____sqlmesh_join_key", CASE WHEN NOT "s"."__sqlmesh_join_key" IS NULL THEN 1 ELSE 0 END AS "s_exists", CASE WHEN NOT "t"."__sqlmesh_join_key" IS NULL THEN 1 ELSE 0 END AS "t_exists", CASE WHEN "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key" THEN 1 ELSE 0 END AS "row_joined", CASE WHEN "s"."key" IS NULL AND "t"."key" IS NULL THEN 1 ELSE 0 END AS "null_grain", CASE WHEN "s"."key" = "t"."key" THEN 1 WHEN ("s"."key" IS NULL) AND ("t"."key" IS NULL) THEN 1 WHEN ("s"."key" IS NULL) OR ("t"."key" IS NULL) THEN 0 ELSE 0 END AS "key_matches", CASE WHEN ROUND("s"."value", 3) = ROUND("t"."value", 3) THEN 1 WHEN ("s"."value" IS NULL) AND ("t"."value" IS NULL) THEN 1 WHEN ("s"."value" IS NULL) OR ("t"."value" IS NULL) THEN 0 ELSE 0 END AS "value_matches" FROM "__source" AS "s" FULL JOIN "__target" AS "t" ON "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key") SELECT *, CASE WHEN "key_matches" = 1 AND "value_matches" = 1 THEN 1 ELSE 0 END AS "row_full_match" FROM "__stats"' spy_execute.assert_any_call(query_sql_where) @@ -1137,3 +1139,54 @@ def test_data_diff_sample_limit(): assert len(diff.s_sample) == 3 assert len(diff.t_sample) == 3 assert len(diff.joined_sample) == 3 + + +def test_data_diff_nulls_in_some_grain_columns(): + engine_adapter = DuckDBConnectionConfig().create_engine_adapter() + + columns_to_types = { + "key1": exp.DataType.build("int"), + "key2": exp.DataType.build("varchar"), + "key3": exp.DataType.build("int"), + "value": exp.DataType.build("varchar"), + } + + engine_adapter.create_table("src", columns_to_types) + engine_adapter.create_table("target", columns_to_types) + + src_records = [ + (1, None, 1, "value"), # full match + (None, None, None, "null value"), # join, partial match + (2, None, None, "source only"), # source only + ] + + target_records = [ + (1, None, 1, "value"), # full match + (None, None, None, "null value modified"), # join, partial match + (None, "three", 2, "target only"), # target only + ] + + src_df = pd.DataFrame(data=src_records, columns=columns_to_types.keys()) + target_df = pd.DataFrame(data=target_records, columns=columns_to_types.keys()) + + engine_adapter.insert_append("src", src_df) + engine_adapter.insert_append("target", target_df) + + table_diff = TableDiff( + adapter=engine_adapter, source="src", target="target", on=["key1", "key2", "key3"] + ) + + diff = table_diff.row_diff() + + assert diff.join_count == 2 + assert diff.s_only_count == 1 + assert diff.t_only_count == 1 + assert diff.full_match_count == 1 + assert diff.partial_match_count == 1 + + assert diff.s_sample["value"].tolist() == ["source only"] + assert diff.t_sample["value"].tolist() == ["target only"] + assert diff.joined_sample[["s__value", "t__value"]].values.flatten().tolist() == [ + "null value", + "null value modified", + ] From 9cc88127509f2063aa10bd8b3ec78863189f1574 Mon Sep 17 00:00:00 2001 From: Erin Drummond Date: Thu, 24 Jul 2025 21:20:02 +0000 Subject: [PATCH 2/2] Add extra assertions --- tests/core/test_table_diff.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/core/test_table_diff.py b/tests/core/test_table_diff.py index ffd20de6ca..b2848676b4 100644 --- a/tests/core/test_table_diff.py +++ b/tests/core/test_table_diff.py @@ -290,10 +290,16 @@ def test_grain_check(sushi_context_fixed_date): )[0] row_diff = diff.row_diff() + assert row_diff.source_count == 7 + assert row_diff.target_count == 10 assert row_diff.full_match_count == 7 - assert row_diff.full_match_pct == 82.35 + assert row_diff.partial_match_count == 0 assert row_diff.s_only_count == 0 assert row_diff.t_only_count == 3 + assert row_diff.full_match_pct == 82.35 + assert row_diff.partial_match_pct == 0 + assert row_diff.s_only_pct == 0 + assert row_diff.t_only_pct == 17.65 assert row_diff.stats["join_count"] == 7 assert ( row_diff.stats["null_grain_count"] == 4