55import unittest
66from collections import Counter
77from contextlib import AbstractContextManager , nullcontext
8+ from itertools import chain
89from pathlib import Path
910from unittest .mock import patch
1011
@@ -126,25 +127,25 @@ def setUp(self) -> None:
126127
127128 for name , values in self .body .get ("inputs" , {}).items ():
128129 all_types_are_known = False
129- known_columns_to_types : t .Dict [str , exp .DataType ] = {}
130+ columns_to_known_types : t .Dict [str , exp .DataType ] = {}
130131
131132 model = self .models .get (name )
132133 if model :
133134 inferred_columns_to_types = model .columns_to_types or {}
134- known_columns_to_types = {
135+ columns_to_known_types = {
135136 c : t for c , t in inferred_columns_to_types .items () if type_is_known (t )
136137 }
137138 all_types_are_known = bool (inferred_columns_to_types ) and (
138- len (known_columns_to_types ) == len (inferred_columns_to_types )
139+ len (columns_to_known_types ) == len (inferred_columns_to_types )
139140 )
140141
141142 # Types specified in the test will override the corresponding inferred ones
142- known_columns_to_types .update (values .get ("columns" , {}))
143+ columns_to_known_types .update (values .get ("columns" , {}))
143144
144145 rows = values .get ("rows" )
145146 if not all_types_are_known and rows :
146147 for col , value in rows [0 ].items ():
147- if col not in known_columns_to_types :
148+ if col not in columns_to_known_types :
148149 v_type = annotate_types (exp .convert (value )).type or type (value ).__name__
149150 v_type = exp .maybe_parse (
150151 v_type , into = exp .DataType , dialect = self ._test_adapter_dialect
@@ -159,21 +160,21 @@ def setUp(self) -> None:
159160 self .path ,
160161 )
161162
162- known_columns_to_types [col ] = v_type
163+ columns_to_known_types [col ] = v_type
163164
164165 if rows is None :
165166 query_or_df : exp .Query | pd .DataFrame = self ._add_missing_columns (
166- values ["query" ], known_columns_to_types
167+ values ["query" ], columns_to_known_types
167168 )
168- if known_columns_to_types :
169- known_columns_to_types = {
170- col : known_columns_to_types [col ] for col in query_or_df .named_selects
169+ if columns_to_known_types :
170+ columns_to_known_types = {
171+ col : columns_to_known_types [col ] for col in query_or_df .named_selects
171172 }
172173 else :
173- query_or_df = self ._create_df (values , columns = known_columns_to_types )
174+ query_or_df = self ._create_df (values , columns = columns_to_known_types )
174175
175176 self .engine_adapter .create_view (
176- self ._test_fixture_table (name ), query_or_df , known_columns_to_types
177+ self ._test_fixture_table (name ), query_or_df , columns_to_known_types
177178 )
178179
179180 def tearDown (self ) -> None :
@@ -525,7 +526,7 @@ def _add_missing_columns(
525526
526527
527528class SqlModelTest (ModelTest ):
528- def test_ctes (self , ctes : t .Dict [str , exp .Expression ]) -> None :
529+ def test_ctes (self , ctes : t .Dict [str , exp .Expression ], recursive : bool = False ) -> None :
529530 """Run CTE queries and compare output to expected output"""
530531 for cte_name , values in self .body ["outputs" ].get ("ctes" , {}).items ():
531532 with self .subTest (cte = cte_name ):
@@ -535,11 +536,13 @@ def test_ctes(self, ctes: t.Dict[str, exp.Expression]) -> None:
535536 )
536537
537538 cte_query = ctes [cte_name ].this
538- for alias , cte in ctes .items ():
539- cte_query = cte_query .with_ (alias , cte .this )
540539
541- partial = values .get ("partial" )
542540 sort = cte_query .args .get ("order" ) is None
541+ partial = values .get ("partial" )
542+
543+ cte_query = exp .select (* _projection_identifiers (cte_query )).from_ (cte_name )
544+ for alias , cte in ctes .items ():
545+ cte_query = cte_query .with_ (alias , cte .this , recursive = recursive )
543546
544547 actual = self ._execute (cte_query )
545548 expected = self ._create_df (values , columns = cte_query .named_selects , partial = partial )
@@ -548,13 +551,16 @@ def test_ctes(self, ctes: t.Dict[str, exp.Expression]) -> None:
548551
549552 def runTest (self ) -> None :
550553 query = self ._render_model_query ()
551-
552- self .test_ctes (
553- {
554- self ._normalize_model_name (cte .alias , with_default_catalog = False ): cte
555- for cte in query .ctes
556- }
557- )
554+ with_clause = query .args .get ("with" )
555+
556+ if with_clause :
557+ self .test_ctes (
558+ {
559+ self ._normalize_model_name (cte .alias , with_default_catalog = False ): cte
560+ for cte in query .ctes
561+ },
562+ recursive = with_clause .recursive ,
563+ )
558564
559565 values = self .body ["outputs" ].get ("query" )
560566 if values is not None :
@@ -732,14 +738,23 @@ def generate_test(
732738 if isinstance (model , SqlModel ):
733739 assert isinstance (test , SqlModelTest )
734740 model_query = test ._render_model_query ()
741+ with_clause = model_query .args .get ("with" )
735742
736- if include_ctes :
743+ if with_clause and include_ctes :
737744 ctes = {}
745+ recursive = with_clause .recursive
738746 previous_ctes : t .List [exp .CTE ] = []
747+
739748 for cte in model_query .ctes :
740749 cte_query = cte .this
741- for prev in previous_ctes :
742- cte_query = cte_query .with_ (prev .alias , prev .this )
750+ cte_identifier = cte .args ["alias" ].this
751+
752+ cte_query = exp .select (* _projection_identifiers (cte_query )).from_ (cte_identifier )
753+
754+ for prev in chain (previous_ctes , [cte ]):
755+ cte_query = cte_query .with_ (
756+ prev .args ["alias" ].this , prev .this , recursive = recursive
757+ )
743758
744759 cte_output = test ._execute (cte_query )
745760 ctes [cte .alias ] = (
@@ -775,6 +790,19 @@ def generate_test(
775790 yaml .dump ({test_name : test_body }, file )
776791
777792
793+ def _projection_identifiers (query : exp .Query ) -> t .List [str | exp .Identifier ]:
794+ identifiers : t .List [str | exp .Identifier ] = []
795+ for select in query .selects :
796+ if isinstance (select , exp .Alias ):
797+ identifiers .append (select .args ["alias" ])
798+ elif isinstance (select , exp .Column ):
799+ identifiers .append (select .this )
800+ else :
801+ identifiers .append (select .output_name )
802+
803+ return identifiers
804+
805+
778806def _raise_if_unexpected_columns (
779807 expected_cols : t .Collection [str ], actual_cols : t .Collection [str ]
780808) -> None :
0 commit comments