Skip to content

Commit 1cdda84

Browse files
authored
Fix!: Improve handling of the --where option in table_diff (#2975)
1 parent a652a10 commit 1cdda84

File tree

2 files changed

+54
-24
lines changed

2 files changed

+54
-24
lines changed

sqlmesh/core/table_diff.py

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,22 @@ def _column_expr(name: str, table: str) -> exp.Expression:
272272
def name(e: exp.Expression) -> str:
273273
return e.args["alias"].sql(identify=True)
274274

275-
query = (
275+
source_query = (
276+
exp.select(*(exp.column(c) for c in source_schema))
277+
.from_(self.source)
278+
.where(self.where)
279+
)
280+
target_query = (
281+
exp.select(*(exp.column(c) for c in target_schema))
282+
.from_(self.target)
283+
.where(self.where)
284+
)
285+
286+
source_table = exp.table_("__source")
287+
target_table = exp.table_("__target")
288+
stats_table = exp.table_("__stats")
289+
290+
stats_query = (
276291
exp.select(
277292
*s_selects.values(),
278293
*t_selects.values(),
@@ -313,31 +328,32 @@ def name(e: exp.Expression) -> str:
313328
).as_("null_grain"),
314329
*comparisons,
315330
)
316-
.from_(exp.alias_(self.source, "s"))
317-
.join(
318-
self.target,
319-
on=self.on,
320-
join_type="FULL",
321-
join_alias="t",
322-
)
323-
.where(self.where)
331+
.from_(source_table.as_("s"))
332+
.join(target_table.as_("t"), on=self.on, join_type="FULL")
324333
)
325334

326-
query = exp.select(
327-
"*",
328-
exp.Case()
329-
.when(
330-
exp.and_(
331-
*[
332-
exp.column(f"{c}_matches").eq(exp.Literal.number(1))
333-
for c in matched_columns
334-
]
335-
),
336-
exp.Literal.number(1),
335+
query = (
336+
exp.Select()
337+
.with_(source_table, source_query)
338+
.with_(target_table, target_query)
339+
.with_(stats_table, stats_query)
340+
.select(
341+
"*",
342+
exp.Case()
343+
.when(
344+
exp.and_(
345+
*[
346+
exp.column(f"{c}_matches").eq(exp.Literal.number(1))
347+
for c in matched_columns
348+
]
349+
),
350+
exp.Literal.number(1),
351+
)
352+
.else_(exp.Literal.number(0))
353+
.as_("row_full_match"),
337354
)
338-
.else_(exp.Literal.number(0))
339-
.as_("row_full_match"),
340-
).from_(query.subquery("stats"))
355+
.from_(stats_table)
356+
)
341357

342358
query = quote_identifiers(query, dialect=self.model_dialect or self.dialect)
343359
temp_table = exp.table_("diff", db="sqlmesh_temp", quoted=True)

tests/core/test_table_diff.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def test_generated_sql(sushi_context_fixed_date: Context, mocker: MockerFixture)
246246
),
247247
)
248248

249-
query_sql = 'CREATE TABLE IF NOT EXISTS "sqlmesh_temp"."__temp_diff_abcdefgh" AS SELECT *, CASE WHEN "key_matches" = 1 AND "value_matches" = 1 THEN 1 ELSE 0 END AS "row_full_match" FROM (SELECT "s"."key" AS "s__key", "s"."value" AS "s__value", "t"."key" AS "t__key", "t"."value" AS "t__value", CASE WHEN NOT "s"."key" IS NULL THEN 1 ELSE 0 END AS "s_exists", CASE WHEN NOT "t"."key" IS NULL THEN 1 ELSE 0 END AS "t_exists", CASE WHEN "s"."key" = "t"."key" AND NOT "s"."key" IS NULL AND NOT "t"."key" IS NULL THEN 1 ELSE 0 END AS "row_joined", CASE WHEN "s"."key" IS NULL AND "t"."key" IS NULL THEN 1 ELSE 0 END AS "null_grain", CASE WHEN "s"."key" = "t"."key" THEN 1 WHEN ("s"."key" IS NULL) AND ("t"."key" IS NULL) THEN 1 WHEN ("s"."key" IS NULL) OR ("t"."key" IS NULL) THEN 0 ELSE 0 END AS "key_matches", CASE WHEN ROUND("s"."value", 3) = ROUND("t"."value", 3) THEN 1 WHEN ("s"."value" IS NULL) AND ("t"."value" IS NULL) THEN 1 WHEN ("s"."value" IS NULL) OR ("t"."value" IS NULL) THEN 0 ELSE 0 END AS "value_matches" FROM "table_diff_source" AS "s" FULL JOIN "table_diff_target" AS "t" ON ("s"."key" = "t"."key") OR (("s"."key" IS NULL) AND ("t"."key" IS NULL))) AS "stats"'
249+
query_sql = 'CREATE TABLE IF NOT EXISTS "sqlmesh_temp"."__temp_diff_abcdefgh" AS WITH "__source" AS (SELECT "key", "value" FROM "table_diff_source"), "__target" AS (SELECT "key", "value" FROM "table_diff_target"), "__stats" AS (SELECT "s"."key" AS "s__key", "s"."value" AS "s__value", "t"."key" AS "t__key", "t"."value" AS "t__value", CASE WHEN NOT "s"."key" IS NULL THEN 1 ELSE 0 END AS "s_exists", CASE WHEN NOT "t"."key" IS NULL THEN 1 ELSE 0 END AS "t_exists", CASE WHEN "s"."key" = "t"."key" AND NOT "s"."key" IS NULL AND NOT "t"."key" IS NULL THEN 1 ELSE 0 END AS "row_joined", CASE WHEN "s"."key" IS NULL AND "t"."key" IS NULL THEN 1 ELSE 0 END AS "null_grain", CASE WHEN "s"."key" = "t"."key" THEN 1 WHEN ("s"."key" IS NULL) AND ("t"."key" IS NULL) THEN 1 WHEN ("s"."key" IS NULL) OR ("t"."key" IS NULL) THEN 0 ELSE 0 END AS "key_matches", CASE WHEN ROUND("s"."value", 3) = ROUND("t"."value", 3) THEN 1 WHEN ("s"."value" IS NULL) AND ("t"."value" IS NULL) THEN 1 WHEN ("s"."value" IS NULL) OR ("t"."value" IS NULL) THEN 0 ELSE 0 END AS "value_matches" FROM "__source" AS "s" FULL JOIN "__target" AS "t" ON ("s"."key" = "t"."key") OR (("s"."key" IS NULL) AND ("t"."key" IS NULL))) SELECT *, CASE WHEN "key_matches" = 1 AND "value_matches" = 1 THEN 1 ELSE 0 END AS "row_full_match" FROM "__stats"'
250250
summary_query_sql = 'SELECT SUM("s_exists") AS "s_count", SUM("t_exists") AS "t_count", SUM("row_joined") AS "join_count", SUM("null_grain") AS "null_grain_count", SUM("row_full_match") AS "full_match_count", SUM("key_matches") AS "key_matches", SUM("value_matches") AS "value_matches", COUNT(DISTINCT ("s__key")) AS "distinct_count_s", COUNT(DISTINCT ("t__key")) AS "distinct_count_t" FROM "sqlmesh_temp"."__temp_diff_abcdefgh"'
251251
sample_query_sql = 'SELECT "s_exists", "t_exists", "row_joined", "row_full_match", "s__key", "s__value", "t__key", "t__value" FROM "sqlmesh_temp"."__temp_diff_abcdefgh" WHERE "key_matches" = 0 OR "value_matches" = 0 ORDER BY "s__key" NULLS FIRST, "t__key" NULLS FIRST LIMIT 20'
252252

@@ -263,3 +263,17 @@ def test_generated_sql(sushi_context_fixed_date: Context, mocker: MockerFixture)
263263
spy_execute.assert_any_call(query_sql)
264264
spy_execute.assert_any_call(summary_query_sql)
265265
spy_execute.assert_any_call(sample_query_sql)
266+
267+
spy_execute.reset_mock()
268+
269+
# Also check WHERE clause is propagated correctly
270+
sushi_context_fixed_date.table_diff(
271+
source="table_diff_source",
272+
target="table_diff_target",
273+
on=["key"],
274+
skip_columns=["ignored"],
275+
where="key = 2",
276+
)
277+
278+
query_sql_where = 'CREATE TABLE IF NOT EXISTS "sqlmesh_temp"."__temp_diff_abcdefgh" AS WITH "__source" AS (SELECT "key", "value" FROM "table_diff_source" WHERE "key" = 2), "__target" AS (SELECT "key", "value" FROM "table_diff_target" WHERE "key" = 2), "__stats" AS (SELECT "s"."key" AS "s__key", "s"."value" AS "s__value", "t"."key" AS "t__key", "t"."value" AS "t__value", CASE WHEN NOT "s"."key" IS NULL THEN 1 ELSE 0 END AS "s_exists", CASE WHEN NOT "t"."key" IS NULL THEN 1 ELSE 0 END AS "t_exists", CASE WHEN "s"."key" = "t"."key" AND NOT "s"."key" IS NULL AND NOT "t"."key" IS NULL THEN 1 ELSE 0 END AS "row_joined", CASE WHEN "s"."key" IS NULL AND "t"."key" IS NULL THEN 1 ELSE 0 END AS "null_grain", CASE WHEN "s"."key" = "t"."key" THEN 1 WHEN ("s"."key" IS NULL) AND ("t"."key" IS NULL) THEN 1 WHEN ("s"."key" IS NULL) OR ("t"."key" IS NULL) THEN 0 ELSE 0 END AS "key_matches", CASE WHEN ROUND("s"."value", 3) = ROUND("t"."value", 3) THEN 1 WHEN ("s"."value" IS NULL) AND ("t"."value" IS NULL) THEN 1 WHEN ("s"."value" IS NULL) OR ("t"."value" IS NULL) THEN 0 ELSE 0 END AS "value_matches" FROM "__source" AS "s" FULL JOIN "__target" AS "t" ON ("s"."key" = "t"."key") OR (("s"."key" IS NULL) AND ("t"."key" IS NULL))) SELECT *, CASE WHEN "key_matches" = 1 AND "value_matches" = 1 THEN 1 ELSE 0 END AS "row_full_match" FROM "__stats"'
279+
spy_execute.assert_any_call(query_sql_where)

0 commit comments

Comments
 (0)