Skip to content

Commit f0d3cc9

Browse files
committed
Fix(table_diff): Correctly handle joins with composite keys where one or more of the key fields are null
1 parent 969c9cc commit f0d3cc9

File tree

3 files changed

+79
-29
lines changed

3 files changed

+79
-29
lines changed

sqlmesh/core/table_diff.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -421,21 +421,16 @@ def _column_expr(name: str, table: str) -> exp.Expression:
421421
exp.select(
422422
*s_selects.values(),
423423
*t_selects.values(),
424-
exp.func("IF", exp.or_(*(c.is_(exp.Null()).not_() for c in s_index)), 1, 0).as_(
425-
"s_exists"
426-
),
427-
exp.func("IF", exp.or_(*(c.is_(exp.Null()).not_() for c in t_index)), 1, 0).as_(
428-
"t_exists"
429-
),
424+
exp.func(
425+
"IF", exp.column(SQLMESH_JOIN_KEY_COL, "s").is_(exp.Null()).not_(), 1, 0
426+
).as_("s_exists"),
427+
exp.func(
428+
"IF", exp.column(SQLMESH_JOIN_KEY_COL, "t").is_(exp.Null()).not_(), 1, 0
429+
).as_("t_exists"),
430430
exp.func(
431431
"IF",
432-
exp.and_(
433-
exp.column(SQLMESH_JOIN_KEY_COL, "s").eq(
434-
exp.column(SQLMESH_JOIN_KEY_COL, "t")
435-
),
436-
exp.and_(
437-
*(c.is_(exp.Null()).not_() for c in s_index + t_index),
438-
),
432+
exp.column(SQLMESH_JOIN_KEY_COL, "s").eq(
433+
exp.column(SQLMESH_JOIN_KEY_COL, "t")
439434
),
440435
1,
441436
0,

tests/core/engine_adapter/integration/test_integration.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2353,15 +2353,17 @@ def test_table_diff_grain_check_multiple_keys(ctx: TestContext):
23532353
row_diff = table_diff.row_diff()
23542354

23552355
assert row_diff.full_match_count == 7
2356-
assert row_diff.full_match_pct == 93.33
2357-
assert row_diff.s_only_count == 2
2358-
assert row_diff.t_only_count == 5
2359-
assert row_diff.stats["join_count"] == 4
2360-
assert row_diff.stats["null_grain_count"] == 4
2361-
assert row_diff.stats["s_count"] != row_diff.stats["distinct_count_s"]
2356+
assert row_diff.full_match_pct == 82.35
2357+
assert row_diff.s_only_count == 0
2358+
assert row_diff.t_only_count == 3
2359+
assert row_diff.stats["join_count"] == 7
2360+
assert (
2361+
row_diff.stats["null_grain_count"] == 4
2362+
) # null grain currently (2025-07-24) means "any key column is null" as opposed to "all key columns are null"
23622363
assert row_diff.stats["distinct_count_s"] == 7
2363-
assert row_diff.stats["t_count"] != row_diff.stats["distinct_count_t"]
2364+
assert row_diff.stats["s_count"] == row_diff.stats["distinct_count_s"]
23642365
assert row_diff.stats["distinct_count_t"] == 10
2366+
assert row_diff.stats["t_count"] == row_diff.stats["distinct_count_t"]
23652367
assert row_diff.s_sample.shape == (row_diff.s_only_count, 3)
23662368
assert row_diff.t_sample.shape == (row_diff.t_only_count, 3)
23672369

tests/core/test_table_diff.py

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -291,15 +291,17 @@ def test_grain_check(sushi_context_fixed_date):
291291

292292
row_diff = diff.row_diff()
293293
assert row_diff.full_match_count == 7
294-
assert row_diff.full_match_pct == 93.33
295-
assert row_diff.s_only_count == 2
296-
assert row_diff.t_only_count == 5
297-
assert row_diff.stats["join_count"] == 4
298-
assert row_diff.stats["null_grain_count"] == 4
299-
assert row_diff.stats["s_count"] != row_diff.stats["distinct_count_s"]
294+
assert row_diff.full_match_pct == 82.35
295+
assert row_diff.s_only_count == 0
296+
assert row_diff.t_only_count == 3
297+
assert row_diff.stats["join_count"] == 7
298+
assert (
299+
row_diff.stats["null_grain_count"] == 4
300+
) # null grain currently (2025-07-24) means "any key column is null" as opposed to "all key columns are null"
300301
assert row_diff.stats["distinct_count_s"] == 7
301-
assert row_diff.stats["t_count"] != row_diff.stats["distinct_count_t"]
302302
assert row_diff.stats["distinct_count_t"] == 10
303+
assert row_diff.stats["s_count"] == row_diff.stats["distinct_count_s"]
304+
assert row_diff.stats["t_count"] == row_diff.stats["distinct_count_t"]
303305
assert row_diff.s_sample.shape == (row_diff.s_only_count, 3)
304306
assert row_diff.t_sample.shape == (row_diff.t_only_count, 3)
305307

@@ -329,7 +331,7 @@ def test_generated_sql(sushi_context_fixed_date: Context, mocker: MockerFixture)
329331
),
330332
)
331333

332-
query_sql = 'CREATE TABLE IF NOT EXISTS "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" AS WITH "__source" AS (SELECT "s"."key", "s"."value", "s"."key" AS "__sqlmesh_join_key" FROM "table_diff_source" AS "s"), "__target" AS (SELECT "t"."key", "t"."value", "t"."key" AS "__sqlmesh_join_key" FROM "table_diff_target" AS "t"), "__stats" AS (SELECT "s"."key" AS "s__key", "s"."value" AS "s__value", "s"."__sqlmesh_join_key" AS "s____sqlmesh_join_key", "t"."key" AS "t__key", "t"."value" AS "t__value", "t"."__sqlmesh_join_key" AS "t____sqlmesh_join_key", CASE WHEN NOT "s"."key" IS NULL THEN 1 ELSE 0 END AS "s_exists", CASE WHEN NOT "t"."key" IS NULL THEN 1 ELSE 0 END AS "t_exists", CASE WHEN "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key" AND (NOT "s"."key" IS NULL AND NOT "t"."key" IS NULL) THEN 1 ELSE 0 END AS "row_joined", CASE WHEN "s"."key" IS NULL AND "t"."key" IS NULL THEN 1 ELSE 0 END AS "null_grain", CASE WHEN "s"."key" = "t"."key" THEN 1 WHEN ("s"."key" IS NULL) AND ("t"."key" IS NULL) THEN 1 WHEN ("s"."key" IS NULL) OR ("t"."key" IS NULL) THEN 0 ELSE 0 END AS "key_matches", CASE WHEN ROUND("s"."value", 3) = ROUND("t"."value", 3) THEN 1 WHEN ("s"."value" IS NULL) AND ("t"."value" IS NULL) THEN 1 WHEN ("s"."value" IS NULL) OR ("t"."value" IS NULL) THEN 0 ELSE 0 END AS "value_matches" FROM "__source" AS "s" FULL JOIN "__target" AS "t" ON "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key") SELECT *, CASE WHEN "key_matches" = 1 AND "value_matches" = 1 THEN 1 ELSE 0 END AS "row_full_match" FROM "__stats"'
334+
query_sql = 'CREATE TABLE IF NOT EXISTS "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" AS WITH "__source" AS (SELECT "s"."key", "s"."value", "s"."key" AS "__sqlmesh_join_key" FROM "table_diff_source" AS "s"), "__target" AS (SELECT "t"."key", "t"."value", "t"."key" AS "__sqlmesh_join_key" FROM "table_diff_target" AS "t"), "__stats" AS (SELECT "s"."key" AS "s__key", "s"."value" AS "s__value", "s"."__sqlmesh_join_key" AS "s____sqlmesh_join_key", "t"."key" AS "t__key", "t"."value" AS "t__value", "t"."__sqlmesh_join_key" AS "t____sqlmesh_join_key", CASE WHEN NOT "s"."__sqlmesh_join_key" IS NULL THEN 1 ELSE 0 END AS "s_exists", CASE WHEN NOT "t"."__sqlmesh_join_key" IS NULL THEN 1 ELSE 0 END AS "t_exists", CASE WHEN "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key" THEN 1 ELSE 0 END AS "row_joined", CASE WHEN "s"."key" IS NULL AND "t"."key" IS NULL THEN 1 ELSE 0 END AS "null_grain", CASE WHEN "s"."key" = "t"."key" THEN 1 WHEN ("s"."key" IS NULL) AND ("t"."key" IS NULL) THEN 1 WHEN ("s"."key" IS NULL) OR ("t"."key" IS NULL) THEN 0 ELSE 0 END AS "key_matches", CASE WHEN ROUND("s"."value", 3) = ROUND("t"."value", 3) THEN 1 WHEN ("s"."value" IS NULL) AND ("t"."value" IS NULL) THEN 1 WHEN ("s"."value" IS NULL) OR ("t"."value" IS NULL) THEN 0 ELSE 0 END AS "value_matches" FROM "__source" AS "s" FULL JOIN "__target" AS "t" ON "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key") SELECT *, CASE WHEN "key_matches" = 1 AND "value_matches" = 1 THEN 1 ELSE 0 END AS "row_full_match" FROM "__stats"'
333335
summary_query_sql = 'SELECT SUM("s_exists") AS "s_count", SUM("t_exists") AS "t_count", SUM("row_joined") AS "join_count", SUM("null_grain") AS "null_grain_count", SUM("row_full_match") AS "full_match_count", SUM("key_matches") AS "key_matches", SUM("value_matches") AS "value_matches", COUNT(DISTINCT ("s____sqlmesh_join_key")) AS "distinct_count_s", COUNT(DISTINCT ("t____sqlmesh_join_key")) AS "distinct_count_t" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh"'
334336
compare_sql = 'SELECT ROUND(100 * (CAST(SUM("key_matches") AS DECIMAL) / COUNT("key_matches")), 9) AS "key_matches", ROUND(100 * (CAST(SUM("value_matches") AS DECIMAL) / COUNT("value_matches")), 9) AS "value_matches" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "row_joined" = 1'
335337
sample_query_sql = 'WITH "source_only" AS (SELECT \'source_only\' AS "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "s_exists" = 1 AND "row_joined" = 0 ORDER BY "s__key" NULLS FIRST LIMIT 20), "target_only" AS (SELECT \'target_only\' AS "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "t_exists" = 1 AND "row_joined" = 0 ORDER BY "t__key" NULLS FIRST LIMIT 20), "common_rows" AS (SELECT \'common_rows\' AS "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "row_joined" = 1 AND "row_full_match" = 0 ORDER BY "s__key" NULLS FIRST, "t__key" NULLS FIRST LIMIT 20) SELECT "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "source_only" UNION ALL SELECT "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "target_only" UNION ALL SELECT "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "common_rows"'
@@ -369,7 +371,7 @@ def test_generated_sql(sushi_context_fixed_date: Context, mocker: MockerFixture)
369371
where="key = 2",
370372
)
371373

372-
query_sql_where = 'CREATE TABLE IF NOT EXISTS "memory"."sqlmesh_temp"."__temp_diff_abcdefgh" AS WITH "__source" AS (SELECT "s"."key", "s"."value", "s"."key" AS "__sqlmesh_join_key" FROM "table_diff_source" AS "s" WHERE "s"."key" = 2), "__target" AS (SELECT "t"."key", "t"."value", "t"."key" AS "__sqlmesh_join_key" FROM "table_diff_target" AS "t" WHERE "t"."key" = 2), "__stats" AS (SELECT "s"."key" AS "s__key", "s"."value" AS "s__value", "s"."__sqlmesh_join_key" AS "s____sqlmesh_join_key", "t"."key" AS "t__key", "t"."value" AS "t__value", "t"."__sqlmesh_join_key" AS "t____sqlmesh_join_key", CASE WHEN NOT "s"."key" IS NULL THEN 1 ELSE 0 END AS "s_exists", CASE WHEN NOT "t"."key" IS NULL THEN 1 ELSE 0 END AS "t_exists", CASE WHEN "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key" AND (NOT "s"."key" IS NULL AND NOT "t"."key" IS NULL) THEN 1 ELSE 0 END AS "row_joined", CASE WHEN "s"."key" IS NULL AND "t"."key" IS NULL THEN 1 ELSE 0 END AS "null_grain", CASE WHEN "s"."key" = "t"."key" THEN 1 WHEN ("s"."key" IS NULL) AND ("t"."key" IS NULL) THEN 1 WHEN ("s"."key" IS NULL) OR ("t"."key" IS NULL) THEN 0 ELSE 0 END AS "key_matches", CASE WHEN ROUND("s"."value", 3) = ROUND("t"."value", 3) THEN 1 WHEN ("s"."value" IS NULL) AND ("t"."value" IS NULL) THEN 1 WHEN ("s"."value" IS NULL) OR ("t"."value" IS NULL) THEN 0 ELSE 0 END AS "value_matches" FROM "__source" AS "s" FULL JOIN "__target" AS "t" ON "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key") SELECT *, CASE WHEN "key_matches" = 1 AND "value_matches" = 1 THEN 1 ELSE 0 END AS "row_full_match" FROM "__stats"'
374+
query_sql_where = 'CREATE TABLE IF NOT EXISTS "memory"."sqlmesh_temp"."__temp_diff_abcdefgh" AS WITH "__source" AS (SELECT "s"."key", "s"."value", "s"."key" AS "__sqlmesh_join_key" FROM "table_diff_source" AS "s" WHERE "s"."key" = 2), "__target" AS (SELECT "t"."key", "t"."value", "t"."key" AS "__sqlmesh_join_key" FROM "table_diff_target" AS "t" WHERE "t"."key" = 2), "__stats" AS (SELECT "s"."key" AS "s__key", "s"."value" AS "s__value", "s"."__sqlmesh_join_key" AS "s____sqlmesh_join_key", "t"."key" AS "t__key", "t"."value" AS "t__value", "t"."__sqlmesh_join_key" AS "t____sqlmesh_join_key", CASE WHEN NOT "s"."__sqlmesh_join_key" IS NULL THEN 1 ELSE 0 END AS "s_exists", CASE WHEN NOT "t"."__sqlmesh_join_key" IS NULL THEN 1 ELSE 0 END AS "t_exists", CASE WHEN "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key" THEN 1 ELSE 0 END AS "row_joined", CASE WHEN "s"."key" IS NULL AND "t"."key" IS NULL THEN 1 ELSE 0 END AS "null_grain", CASE WHEN "s"."key" = "t"."key" THEN 1 WHEN ("s"."key" IS NULL) AND ("t"."key" IS NULL) THEN 1 WHEN ("s"."key" IS NULL) OR ("t"."key" IS NULL) THEN 0 ELSE 0 END AS "key_matches", CASE WHEN ROUND("s"."value", 3) = ROUND("t"."value", 3) THEN 1 WHEN ("s"."value" IS NULL) AND ("t"."value" IS NULL) THEN 1 WHEN ("s"."value" IS NULL) OR ("t"."value" IS NULL) THEN 0 ELSE 0 END AS "value_matches" FROM "__source" AS "s" FULL JOIN "__target" AS "t" ON "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key") SELECT *, CASE WHEN "key_matches" = 1 AND "value_matches" = 1 THEN 1 ELSE 0 END AS "row_full_match" FROM "__stats"'
373375
spy_execute.assert_any_call(query_sql_where)
374376

375377

@@ -1137,3 +1139,54 @@ def test_data_diff_sample_limit():
11371139
assert len(diff.s_sample) == 3
11381140
assert len(diff.t_sample) == 3
11391141
assert len(diff.joined_sample) == 3
1142+
1143+
1144+
def test_data_diff_nulls_in_some_grain_columns():
1145+
engine_adapter = DuckDBConnectionConfig().create_engine_adapter()
1146+
1147+
columns_to_types = {
1148+
"key1": exp.DataType.build("int"),
1149+
"key2": exp.DataType.build("varchar"),
1150+
"key3": exp.DataType.build("int"),
1151+
"value": exp.DataType.build("varchar"),
1152+
}
1153+
1154+
engine_adapter.create_table("src", columns_to_types)
1155+
engine_adapter.create_table("target", columns_to_types)
1156+
1157+
src_records = [
1158+
(1, None, 1, "value"), # full match
1159+
(None, None, None, "null value"), # join, partial match
1160+
(2, None, None, "source only"), # source only
1161+
]
1162+
1163+
target_records = [
1164+
(1, None, 1, "value"), # full match
1165+
(None, None, None, "null value modified"), # join, partial match
1166+
(None, "three", 2, "target only"), # target only
1167+
]
1168+
1169+
src_df = pd.DataFrame(data=src_records, columns=columns_to_types.keys())
1170+
target_df = pd.DataFrame(data=target_records, columns=columns_to_types.keys())
1171+
1172+
engine_adapter.insert_append("src", src_df)
1173+
engine_adapter.insert_append("target", target_df)
1174+
1175+
table_diff = TableDiff(
1176+
adapter=engine_adapter, source="src", target="target", on=["key1", "key2", "key3"]
1177+
)
1178+
1179+
diff = table_diff.row_diff()
1180+
1181+
assert diff.join_count == 2
1182+
assert diff.s_only_count == 1
1183+
assert diff.t_only_count == 1
1184+
assert diff.full_match_count == 1
1185+
assert diff.partial_match_count == 1
1186+
1187+
assert diff.s_sample["value"].tolist() == ["source only"]
1188+
assert diff.t_sample["value"].tolist() == ["target only"]
1189+
assert diff.joined_sample[["s__value", "t__value"]].values.flatten().tolist() == [
1190+
"null value",
1191+
"null value modified",
1192+
]

0 commit comments

Comments
 (0)