Skip to content

Commit 6e71e5f

Browse files
deduplicate SQLMesh Native Macro (#2960)
Co-authored-by: Jo <46752250+georgesittas@users.noreply.github.com>
1 parent 1cdda84 commit 6e71e5f

File tree

3 files changed

+182
-0
lines changed

3 files changed

+182
-0
lines changed

docs/concepts/macros/sqlmesh_macros.md

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -901,6 +901,40 @@ FROM rides
901901
GROUP BY 1
902902
```
903903

### @DEDUPLICATE

`@DEDUPLICATE` deduplicates rows in a table, using a window function over the specified partition and order columns.

It supports the following arguments, in this order:

- `relation`: The table or CTE name to deduplicate
- `partition_by`: Column names or expressions that identify a window of rows, out of which one is selected as the deduplicated row
- `order_by`: A list of strings representing the `ORDER BY` clause; nulls ordering is optional: `'<column> <asc|desc> nulls <first|last>'`

For example, the following query:

```sql linenums="1"
with raw_data as (
  @deduplicate(my_table, [id, cast(event_date as date)], ['event_date DESC', 'status ASC'])
)

select * from raw_data
```

would be rendered as:

```sql linenums="1"
WITH "raw_data" AS (
  SELECT
    *
  FROM "my_table" AS "my_table"
  QUALIFY
    ROW_NUMBER() OVER (PARTITION BY "id", CAST("event_date" AS DATE) ORDER BY "event_date" DESC, "status" ASC) = 1
)
SELECT
  *
FROM "raw_data" AS "raw_data"
```
904938
### @AND
905939

906940
`@AND` combines a sequence of operands using the `AND` operator, filtering out any NULL expressions.

sqlmesh/core/macros.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,6 +1100,63 @@ def var(
11001100
return exp.convert(evaluator.var(var_name.this, default))
11011101

11021102

1103+
@macro()
def deduplicate(
    evaluator: MacroEvaluator,
    relation: exp.Expression,
    partition_by: t.List[exp.Expression],
    order_by: t.List[str],
) -> exp.Query:
    """Returns a QUERY to deduplicate rows within a table

    Args:
        relation: table or CTE name to deduplicate
        partition_by: column names, or expressions to use to identify a window of rows out of which to select one as the deduplicated row
        order_by: A list of strings representing the ORDER BY clause

    Example:
        >>> from sqlglot import parse_one
        >>> from sqlglot.schema import MappingSchema
        >>> from sqlmesh.core.macros import MacroEvaluator
        >>> sql = "@deduplicate(demo.table, [user_id, cast(timestamp as date)], ['timestamp desc', 'status asc'])"
        >>> MacroEvaluator().transform(parse_one(sql)).sql()
        'SELECT * FROM demo.table QUALIFY ROW_NUMBER() OVER (PARTITION BY user_id, CAST(timestamp AS DATE) ORDER BY timestamp DESC, status ASC) = 1'
    """
    # Validate the macro arguments up front with actionable error messages.
    if not isinstance(partition_by, list):
        raise SQLMeshError(
            "partition_by must be a list of columns: [<column>, cast(<column> as <type>)]"
        )

    # A non-list and an empty list are both invalid, and both produce the
    # same error message, so a single guard covers both cases.
    if not isinstance(order_by, list) or not order_by:
        raise SQLMeshError(
            "order_by must be a list of strings, optional - nulls ordering: ['<column> <asc|desc> nulls <first|last>']"
        )

    # Each order item is a raw SQL string (e.g. "timestamp desc nulls last");
    # parse it into an Ordered node in the evaluator's dialect.
    ordering = exp.Order(
        expressions=[
            evaluator.transform(parse_one(item, into=exp.Ordered, dialect=evaluator.dialect))
            for item in order_by
        ]
    )

    # ROW_NUMBER() OVER (PARTITION BY ... ORDER BY ...) = 1 keeps exactly one
    # row per partition, chosen by the ordering above.
    row_number = exp.Window(
        this=exp.RowNumber(),
        partition_by=exp.tuple_(*partition_by),
        order=ordering,
    )

    return exp.select("*").from_(relation).qualify(row_number.eq(1))
1158+
1159+
11031160
def normalize_macro_name(name: str) -> str:
    """Prefix macro name with @ and upcase"""
    return "@" + name.upper()

tests/core/test_macros.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,3 +663,94 @@ def test_macro_first_value_ignore_respect_nulls(assert_exp_eq) -> None:
663663
"SELECT FIRST_VALUE(@test(x) RESPECT NULLS) OVER (ORDER BY y) AS column_test"
664664
)
665665
assert_exp_eq(evaluator.transform(actual_expr), expected_sql, dialect="duckdb")
666+
667+
668+
# Shared macro invocation used by every test_deduplicate parametrization.
DEDUPLICATE_SQL = """
@deduplicate(
    my_table,
    [user_id, CAST(timestamp AS DATE)],
    ['timestamp DESC', 'status ASC nulls last']
)
"""

# Dialects that support QUALIFY render the window predicate directly.
_QUALIFY_EXPECTED = """
SELECT *
FROM my_table
QUALIFY ROW_NUMBER() OVER (
    PARTITION BY user_id, CAST(timestamp AS DATE)
    ORDER BY timestamp DESC, status ASC NULLS LAST
) = 1
"""

# Redshift quotes the reserved word "timestamp".
_REDSHIFT_EXPECTED = """
SELECT *
FROM my_table
QUALIFY ROW_NUMBER() OVER (
    PARTITION BY user_id, CAST("timestamp" AS DATE)
    ORDER BY "timestamp" DESC, status ASC NULLS LAST
) = 1
"""

# Dialects without QUALIFY fall back to a filtered subquery.
_SUBQUERY_EXPECTED = """
SELECT *
FROM (
    SELECT *, ROW_NUMBER() OVER (
        PARTITION BY user_id, CAST(timestamp AS DATE)
        ORDER BY timestamp DESC, status ASC NULLS LAST
    ) AS _w
    FROM my_table
) as _t
WHERE _w = 1
"""


@pytest.mark.parametrize(
    "dialect, sql, expected_sql",
    [
        (dialect, DEDUPLICATE_SQL, _QUALIFY_EXPECTED)
        for dialect in ("bigquery", "databricks", "snowflake", "duckdb")
    ]
    + [("redshift", DEDUPLICATE_SQL, _REDSHIFT_EXPECTED)]
    + [
        (dialect, DEDUPLICATE_SQL, _SUBQUERY_EXPECTED)
        for dialect in ("trino", "postgres")
    ],
)
def test_deduplicate(assert_exp_eq, dialect, sql, expected_sql):
    """The @deduplicate macro renders correctly across dialects with and without QUALIFY."""
    evaluator = MacroEvaluator(schema=MappingSchema({}, dialect=dialect), dialect=dialect)
    assert_exp_eq(evaluator.transform(parse_one(sql)), expected_sql, dialect=dialect)
731+
732+
733+
def test_deduplicate_error_handling(macro_evaluator):
    """Invalid @deduplicate arguments raise SQLMeshError with a helpful message."""

    def _assert_raises(sql: str, message: str) -> None:
        # The macro's SQLMeshError surfaces as the __cause__ of the
        # evaluation error, so the message is checked there.
        with pytest.raises(SQLMeshError) as e:
            macro_evaluator.evaluate(parse_one(sql))
        assert str(e.value.__cause__) == message

    order_by_message = (
        "order_by must be a list of strings, optional - nulls ordering: "
        "['<column> <asc|desc> nulls <first|last>']"
    )

    # Non-list partition_by
    _assert_raises(
        "@deduplicate(my_table, user_id, ['timestamp DESC'])",
        "partition_by must be a list of columns: [<column>, cast(<column> as <type>)]",
    )

    # Non-list order_by
    _assert_raises(
        "@deduplicate(my_table, [user_id], 'timestamp DESC')",
        order_by_message,
    )

    # Empty order_by
    _assert_raises(
        "@deduplicate(my_table, [user_id], [])",
        order_by_message,
    )

0 commit comments

Comments
 (0)