Skip to content

Commit 855688d

Browse files
authored
Feat: add support for recursive CTEs in unit tests (#3351)
1 parent 5315e2c commit 855688d

File tree

6 files changed

+188
-33
lines changed

6 files changed

+188
-33
lines changed

docs/reference/cli.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ Usage: sqlmesh create_test [OPTIONS] MODEL
9292
9393
Options:
9494
-q, --query <TEXT TEXT>... Queries that will be used to generate data for
95-
the model's dependencies. [required]
95+
the model's dependencies.
9696
-o, --overwrite When true, the fixture file will be overwritten
9797
in case it already exists.
9898
-v, --var <TEXT TEXT>... Key-value pairs that will define variables

docs/reference/notebook.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,8 +366,9 @@ options:
366366

367367
#### create_test
368368
```
369-
%create_test --query QUERY [QUERY ...] [--overwrite]
369+
%create_test [--query QUERY [QUERY ...]] [--overwrite]
370370
[--var VAR [VAR ...]] [--path PATH] [--name NAME]
371+
[--include-ctes]
371372
model
372373
373374
Generate a unit test fixture for a given model.

sqlmesh/cli/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,7 @@ def dag(ctx: click.Context, file: str, select_model: t.List[str]) -> None:
530530
"queries",
531531
type=(str, str),
532532
multiple=True,
533-
required=True,
533+
default=[],
534534
help="Queries that will be used to generate data for the model's dependencies.",
535535
)
536536
@click.option(

sqlmesh/core/test/definition.py

Lines changed: 54 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import unittest
66
from collections import Counter
77
from contextlib import AbstractContextManager, nullcontext
8+
from itertools import chain
89
from pathlib import Path
910
from unittest.mock import patch
1011

@@ -126,25 +127,25 @@ def setUp(self) -> None:
126127

127128
for name, values in self.body.get("inputs", {}).items():
128129
all_types_are_known = False
129-
known_columns_to_types: t.Dict[str, exp.DataType] = {}
130+
columns_to_known_types: t.Dict[str, exp.DataType] = {}
130131

131132
model = self.models.get(name)
132133
if model:
133134
inferred_columns_to_types = model.columns_to_types or {}
134-
known_columns_to_types = {
135+
columns_to_known_types = {
135136
c: t for c, t in inferred_columns_to_types.items() if type_is_known(t)
136137
}
137138
all_types_are_known = bool(inferred_columns_to_types) and (
138-
len(known_columns_to_types) == len(inferred_columns_to_types)
139+
len(columns_to_known_types) == len(inferred_columns_to_types)
139140
)
140141

141142
# Types specified in the test will override the corresponding inferred ones
142-
known_columns_to_types.update(values.get("columns", {}))
143+
columns_to_known_types.update(values.get("columns", {}))
143144

144145
rows = values.get("rows")
145146
if not all_types_are_known and rows:
146147
for col, value in rows[0].items():
147-
if col not in known_columns_to_types:
148+
if col not in columns_to_known_types:
148149
v_type = annotate_types(exp.convert(value)).type or type(value).__name__
149150
v_type = exp.maybe_parse(
150151
v_type, into=exp.DataType, dialect=self._test_adapter_dialect
@@ -159,21 +160,21 @@ def setUp(self) -> None:
159160
self.path,
160161
)
161162

162-
known_columns_to_types[col] = v_type
163+
columns_to_known_types[col] = v_type
163164

164165
if rows is None:
165166
query_or_df: exp.Query | pd.DataFrame = self._add_missing_columns(
166-
values["query"], known_columns_to_types
167+
values["query"], columns_to_known_types
167168
)
168-
if known_columns_to_types:
169-
known_columns_to_types = {
170-
col: known_columns_to_types[col] for col in query_or_df.named_selects
169+
if columns_to_known_types:
170+
columns_to_known_types = {
171+
col: columns_to_known_types[col] for col in query_or_df.named_selects
171172
}
172173
else:
173-
query_or_df = self._create_df(values, columns=known_columns_to_types)
174+
query_or_df = self._create_df(values, columns=columns_to_known_types)
174175

175176
self.engine_adapter.create_view(
176-
self._test_fixture_table(name), query_or_df, known_columns_to_types
177+
self._test_fixture_table(name), query_or_df, columns_to_known_types
177178
)
178179

179180
def tearDown(self) -> None:
@@ -525,7 +526,7 @@ def _add_missing_columns(
525526

526527

527528
class SqlModelTest(ModelTest):
528-
def test_ctes(self, ctes: t.Dict[str, exp.Expression]) -> None:
529+
def test_ctes(self, ctes: t.Dict[str, exp.Expression], recursive: bool = False) -> None:
529530
"""Run CTE queries and compare output to expected output"""
530531
for cte_name, values in self.body["outputs"].get("ctes", {}).items():
531532
with self.subTest(cte=cte_name):
@@ -535,11 +536,13 @@ def test_ctes(self, ctes: t.Dict[str, exp.Expression]) -> None:
535536
)
536537

537538
cte_query = ctes[cte_name].this
538-
for alias, cte in ctes.items():
539-
cte_query = cte_query.with_(alias, cte.this)
540539

541-
partial = values.get("partial")
542540
sort = cte_query.args.get("order") is None
541+
partial = values.get("partial")
542+
543+
cte_query = exp.select(*_projection_identifiers(cte_query)).from_(cte_name)
544+
for alias, cte in ctes.items():
545+
cte_query = cte_query.with_(alias, cte.this, recursive=recursive)
543546

544547
actual = self._execute(cte_query)
545548
expected = self._create_df(values, columns=cte_query.named_selects, partial=partial)
@@ -548,13 +551,16 @@ def test_ctes(self, ctes: t.Dict[str, exp.Expression]) -> None:
548551

549552
def runTest(self) -> None:
550553
query = self._render_model_query()
551-
552-
self.test_ctes(
553-
{
554-
self._normalize_model_name(cte.alias, with_default_catalog=False): cte
555-
for cte in query.ctes
556-
}
557-
)
554+
with_clause = query.args.get("with")
555+
556+
if with_clause:
557+
self.test_ctes(
558+
{
559+
self._normalize_model_name(cte.alias, with_default_catalog=False): cte
560+
for cte in query.ctes
561+
},
562+
recursive=with_clause.recursive,
563+
)
558564

559565
values = self.body["outputs"].get("query")
560566
if values is not None:
@@ -732,14 +738,23 @@ def generate_test(
732738
if isinstance(model, SqlModel):
733739
assert isinstance(test, SqlModelTest)
734740
model_query = test._render_model_query()
741+
with_clause = model_query.args.get("with")
735742

736-
if include_ctes:
743+
if with_clause and include_ctes:
737744
ctes = {}
745+
recursive = with_clause.recursive
738746
previous_ctes: t.List[exp.CTE] = []
747+
739748
for cte in model_query.ctes:
740749
cte_query = cte.this
741-
for prev in previous_ctes:
742-
cte_query = cte_query.with_(prev.alias, prev.this)
750+
cte_identifier = cte.args["alias"].this
751+
752+
cte_query = exp.select(*_projection_identifiers(cte_query)).from_(cte_identifier)
753+
754+
for prev in chain(previous_ctes, [cte]):
755+
cte_query = cte_query.with_(
756+
prev.args["alias"].this, prev.this, recursive=recursive
757+
)
743758

744759
cte_output = test._execute(cte_query)
745760
ctes[cte.alias] = (
@@ -775,6 +790,19 @@ def generate_test(
775790
yaml.dump({test_name: test_body}, file)
776791

777792

793+
def _projection_identifiers(query: exp.Query) -> t.List[str | exp.Identifier]:
794+
identifiers: t.List[str | exp.Identifier] = []
795+
for select in query.selects:
796+
if isinstance(select, exp.Alias):
797+
identifiers.append(select.args["alias"])
798+
elif isinstance(select, exp.Column):
799+
identifiers.append(select.this)
800+
else:
801+
identifiers.append(select.output_name)
802+
803+
return identifiers
804+
805+
778806
def _raise_if_unexpected_columns(
779807
expected_cols: t.Collection[str], actual_cols: t.Collection[str]
780808
) -> None:

sqlmesh/magics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -858,7 +858,7 @@ def janitor(self, context: Context, line: str) -> None:
858858
"-q",
859859
type=str,
860860
nargs="+",
861-
required=True,
861+
default=[],
862862
help="Queries that will be used to generate data for the model's dependencies.",
863863
)
864864
@argument(

tests/core/test_test.py

Lines changed: 129 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1450,12 +1450,12 @@ def test_generate_input_data_using_sql(mocker: MockerFixture, tmp_path: Path) ->
14501450
)
14511451

14521452

1453-
def test_pyspark_python_model() -> None:
1453+
def test_pyspark_python_model(tmp_path: Path) -> None:
14541454
spark_connection_config = SparkConnectionConfig(
14551455
config={
14561456
"spark.master": "local",
1457-
"spark.sql.warehouse.dir": "/tmp/data_dir",
1458-
"spark.driver.extraJavaOptions": "-Dderby.system.home=/tmp/derby_dir",
1457+
"spark.sql.warehouse.dir": f"{tmp_path}/data_dir",
1458+
"spark.driver.extraJavaOptions": f"-Dderby.system.home={tmp_path}/derby_dir",
14591459
},
14601460
)
14611461
config = Config(
@@ -1572,6 +1572,100 @@ def test_custom_testing_schema(mocker: MockerFixture) -> None:
15721572
)
15731573

15741574

1575+
def test_complicated_recursive_cte() -> None:
1576+
model_sql = """
1577+
WITH
1578+
RECURSIVE
1579+
chained_contacts AS (
1580+
-- Start with the initial set of contacts and their immediate nodes
1581+
SELECT
1582+
id_contact_a,
1583+
id_contact_b
1584+
FROM
1585+
source
1586+
1587+
UNION ALL
1588+
1589+
-- Recursive step to find further connected nodes
1590+
SELECT
1591+
chained_contacts.id_contact_a,
1592+
unfactorized_duplicates.id_contact_b
1593+
FROM
1594+
chained_contacts
1595+
JOIN source AS unfactorized_duplicates
1596+
ON chained_contacts.id_contact_b = unfactorized_duplicates.id_contact_a
1597+
),
1598+
id_contact_a_with_aggregated_id_contact_bs AS (
1599+
SELECT
1600+
id_contact_a,
1601+
ARRAY_AGG(DISTINCT id_contact_b ORDER BY id_contact_b) AS aggregated_id_contact_bs
1602+
FROM
1603+
chained_contacts
1604+
GROUP BY
1605+
id_contact_a
1606+
)
1607+
SELECT
1608+
ARRAY_CONCAT([id_contact_a], aggregated_id_contact_bs) AS aggregated_duplicates
1609+
FROM
1610+
id_contact_a_with_aggregated_id_contact_bs
1611+
WHERE
1612+
id_contact_a NOT IN (
1613+
SELECT DISTINCT
1614+
id_contact_b
1615+
FROM
1616+
source
1617+
)
1618+
ORDER BY
1619+
id_contact_a
1620+
"""
1621+
1622+
_check_successful_or_raise(
1623+
_create_test(
1624+
body=load_yaml(
1625+
"""
1626+
test_recursive_ctes:
1627+
model: test
1628+
inputs:
1629+
source:
1630+
rows:
1631+
- id_contact_a: "a"
1632+
id_contact_b: "b"
1633+
- id_contact_a: "b"
1634+
id_contact_b: "c"
1635+
- id_contact_a: "c"
1636+
id_contact_b: "d"
1637+
- id_contact_a: "a"
1638+
id_contact_b: "g"
1639+
- id_contact_a: "b"
1640+
id_contact_b: "e"
1641+
- id_contact_a: "c"
1642+
id_contact_b: "f"
1643+
- id_contact_a: "x"
1644+
id_contact_b: "y"
1645+
outputs:
1646+
ctes:
1647+
id_contact_a_with_aggregated_id_contact_bs:
1648+
- id_contact_a: a
1649+
aggregated_id_contact_bs: [b, c, d, e, f, g]
1650+
- id_contact_a: x
1651+
aggregated_id_contact_bs: [y]
1652+
- id_contact_a: b
1653+
aggregated_id_contact_bs: [c, d, e, f]
1654+
- id_contact_a: c
1655+
aggregated_id_contact_bs: [d, f]
1656+
query:
1657+
rows:
1658+
- aggregated_duplicates: [a, b, c, d, e, f, g]
1659+
- aggregated_duplicates: [x, y]
1660+
"""
1661+
),
1662+
test_name="test_recursive_ctes",
1663+
model=_create_model(model_sql),
1664+
context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))),
1665+
).run()
1666+
)
1667+
1668+
15751669
def test_test_generation(tmp_path: Path) -> None:
15761670
init_example_project(tmp_path, dialect="duckdb")
15771671

@@ -1789,3 +1883,35 @@ def test_test_generation_with_decimal(tmp_path: Path, mocker: MockerFixture) ->
17891883
assert "test_foo" in test
17901884
assert test["test_foo"]["inputs"] == {"sqlmesh_example.bar": [{"dec_col": "1.23"}]}
17911885
assert test["test_foo"]["outputs"] == {"query": [{"dec_col": "1.23"}]}
1886+
1887+
1888+
def test_test_generation_with_recursive_ctes(tmp_path: Path) -> None:
1889+
init_example_project(tmp_path, dialect="duckdb")
1890+
1891+
config = Config(
1892+
default_connection=DuckDBConnectionConfig(),
1893+
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
1894+
)
1895+
foo_sql_file = tmp_path / "models" / "foo.sql"
1896+
foo_sql_file.write_text(
1897+
"MODEL (name sqlmesh_example.foo);"
1898+
"WITH RECURSIVE t AS (SELECT 1 AS c UNION ALL SELECT c + 1 FROM t WHERE c < 3) SELECT c FROM t"
1899+
)
1900+
1901+
context = Context(paths=tmp_path, config=config)
1902+
context.plan(auto_apply=True)
1903+
1904+
context.create_test("sqlmesh_example.foo", input_queries={}, overwrite=True, include_ctes=True)
1905+
1906+
test = load_yaml(context.path / c.TESTS / "test_foo.yaml")
1907+
assert len(test) == 1
1908+
assert "test_foo" in test
1909+
assert test["test_foo"]["inputs"] == {}
1910+
assert test["test_foo"]["outputs"] == {
1911+
"query": [{"c": 1}, {"c": 2}, {"c": 3}],
1912+
"ctes": {
1913+
"t": [{"c": 1}, {"c": 2}, {"c": 3}],
1914+
},
1915+
}
1916+
1917+
_check_successful_or_raise(context.test())

0 commit comments

Comments (0)