55from pathlib import Path
66
77from astor import to_source
8+ from collections import defaultdict
89from difflib import get_close_matches
910from sqlglot import exp
1011from sqlglot .helper import ensure_list
2829 from sqlmesh .utils import registry_decorator
2930 from sqlmesh .utils .jinja import MacroReference
3031
31- MacroCallable = registry_decorator
32+ MacroCallable = t . Union [ Executable , registry_decorator ]
3233
3334
3435def make_python_env (
@@ -48,13 +49,17 @@ def make_python_env(
4849 dialect : DialectType = None ,
4950) -> t .Dict [str , Executable ]:
5051 python_env = {} if python_env is None else python_env
51- variables = variables or {}
5252 env : t .Dict [str , t .Tuple [t .Any , t .Optional [bool ]]] = {}
53- used_macros : t .Dict [
54- str ,
55- t .Tuple [t .Union [Executable | MacroCallable ], t .Optional [bool ]],
56- ] = {}
57- used_variables = (used_variables or set ()).copy ()
53+
54+ variables = variables or {}
55+ blueprint_variables = blueprint_variables or {}
56+
57+ used_macros : t .Dict [str , t .Tuple [MacroCallable , t .Optional [bool ]]] = {}
58+ used_variable_referenced_in_metadata_expression = dict .fromkeys (used_variables or set (), False )
59+
60+ # For an expression like @foo(@v1, @bar(@v1, @v2), @v3), the following mapping would be:
61+ # v1 -> {"foo", "bar"}, v2 -> {"bar"}, v3 -> "foo"
62+ macro_funcs_by_used_var : t .DefaultDict [str , t .Set [str ]] = defaultdict (set )
5863
5964 expressions = ensure_list (expressions )
6065 for expression_metadata in expressions :
@@ -77,16 +82,27 @@ def make_python_env(
7782 macros [name ],
7883 used_macros .get (name , (None , is_metadata ))[1 ] and is_metadata ,
7984 )
80- if name == c .VAR :
85+ if name in ( c .VAR , c . BLUEPRINT_VAR ) :
8186 args = macro_func_or_var .this .expressions
8287 if len (args ) < 1 :
83- raise_config_error ("Macro VAR requires at least one argument" , path )
88+ raise_config_error (
89+ f"Macro { name .upper ()} requires at least one argument" , path
90+ )
91+
8492 if not args [0 ].is_string :
8593 raise_config_error (
8694 f"The variable name must be a string literal, '{ args [0 ].sql ()} ' was given instead" ,
8795 path ,
8896 )
89- used_variables .add (args [0 ].this .lower ())
97+
98+ var_name = args [0 ].this .lower ()
99+ used_variable_referenced_in_metadata_expression [var_name ] = (
100+ used_variable_referenced_in_metadata_expression .get (var_name , True )
101+ and bool (is_metadata )
102+ )
103+ else :
104+ for var_ref in _extract_macro_func_variable_references (macro_func_or_var ):
105+ macro_funcs_by_used_var [var_ref ].add (name )
90106 elif macro_func_or_var .__class__ is d .MacroVar :
91107 name = macro_func_or_var .name .lower ()
92108 if name in macros :
@@ -95,17 +111,23 @@ def make_python_env(
95111 macros [name ],
96112 used_macros .get (name , (None , is_metadata ))[1 ] and is_metadata ,
97113 )
98- elif name in variables :
99- used_variables .add (name )
114+ elif name in variables or name in blueprint_variables :
115+ used_variable_referenced_in_metadata_expression [name ] = (
116+ used_variable_referenced_in_metadata_expression .get (name , True )
117+ and bool (is_metadata )
118+ )
100119 elif (
101120 isinstance (macro_func_or_var , (exp .Identifier , d .MacroStrReplace , d .MacroSQL ))
102121 ) and "@" in macro_func_or_var .name :
103122 for _ , identifier , braced_identifier , _ in MacroStrTemplate .pattern .findall (
104123 macro_func_or_var .name
105124 ):
106125 var_name = braced_identifier or identifier
107- if var_name in variables :
108- used_variables .add (var_name )
126+ if var_name in variables or var_name in blueprint_variables :
127+ used_variable_referenced_in_metadata_expression [var_name ] = (
128+ used_variable_referenced_in_metadata_expression .get (var_name , True )
129+ and bool (is_metadata )
130+ )
109131
110132 for macro_ref in jinja_macro_references or set ():
111133 if macro_ref .package is None and macro_ref .name in macros :
@@ -126,41 +148,97 @@ def make_python_env(
126148 python_env .update (serialize_env (env , path = module_path ))
127149 return _add_variables_to_python_env (
128150 python_env ,
129- used_variables ,
151+ used_variable_referenced_in_metadata_expression ,
130152 variables ,
131153 blueprint_variables = blueprint_variables ,
132154 dialect = dialect ,
133155 strict_resolution = strict_resolution ,
156+ macro_funcs_by_used_var = macro_funcs_by_used_var ,
134157 )
135158
136159
160+ def _extract_macro_func_variable_references (macro_func : exp .Expression ) -> t .Set [str ]:
161+ references = set ()
162+
163+ for n in macro_func .walk ():
164+ if n is macro_func :
165+ continue
166+
167+ # Don't descend into nested MacroFunc nodes besides @VAR() and @BLUEPRINT_VAR(), because
168+ # they will be handled in a separate call of _extract_macro_func_variable_references.
169+ if isinstance (n , d .MacroFunc ):
170+ this = n .this
171+ if this .name .lower () in (c .VAR , c .BLUEPRINT_VAR ) and this .expressions :
172+ references .add (this .expressions [0 ].this .lower ())
173+ elif isinstance (n , d .MacroVar ):
174+ references .add (n .name .lower ())
175+ elif isinstance (n , (exp .Identifier , d .MacroStrReplace , d .MacroSQL )) and "@" in n .name :
176+ references .update (
177+ (braced_identifier or identifier ).lower ()
178+ for _ , identifier , braced_identifier , _ in MacroStrTemplate .pattern .findall (n .name )
179+ )
180+
181+ return references
182+
183+
137184def _add_variables_to_python_env (
138185 python_env : t .Dict [str , Executable ],
139- used_variables : t .Optional [ t . Set [ str ] ],
186+ used_variable_referenced_in_metadata_expression : t .Dict [ str , bool ],
140187 variables : t .Optional [t .Dict [str , t .Any ]],
141188 strict_resolution : bool = True ,
142189 blueprint_variables : t .Optional [t .Dict [str , t .Any ]] = None ,
143190 dialect : DialectType = None ,
191+ macro_funcs_by_used_var : t .Optional [t .DefaultDict [str , t .Set [str ]]] = None ,
144192) -> t .Dict [str , Executable ]:
145- _ , python_used_variables = parse_dependencies (
193+ _ , python_used_variable_referenced_in_metadata_expression = parse_dependencies (
146194 python_env ,
147195 None ,
148196 strict_resolution = strict_resolution ,
149197 variables = variables ,
150198 blueprint_variables = blueprint_variables ,
151199 )
152- used_variables = (used_variables or set ()) | python_used_variables
200+ for var_name , is_metadata in python_used_variable_referenced_in_metadata_expression .items ():
201+ used_variable_referenced_in_metadata_expression [var_name ] = (
202+ used_variable_referenced_in_metadata_expression .get (var_name , True ) and is_metadata
203+ )
204+
205+ metadata_used_variables = set ()
206+ for used_var , macro_names in (macro_funcs_by_used_var or {}).items ():
207+ if used_variable_referenced_in_metadata_expression .get (used_var ) or all (
208+ name in python_env and python_env [name ].is_metadata for name in macro_names
209+ ):
210+ metadata_used_variables .add (used_var )
211+
212+ used_variables = set (used_variable_referenced_in_metadata_expression )
213+ non_metadata_used_variables = used_variables - metadata_used_variables
214+
215+ metadata_variables = {
216+ k : v for k , v in (variables or {}).items () if k in metadata_used_variables
217+ }
218+ variables = {k : v for k , v in (variables or {}).items () if k in non_metadata_used_variables }
153219
154- variables = {k : v for k , v in (variables or {}).items () if k in used_variables }
155220 if variables :
156221 python_env [c .SQLMESH_VARS ] = Executable .value (variables )
222+ if metadata_variables :
223+ python_env [c .SQLMESH_VARS_METADATA ] = Executable .value (metadata_variables , is_metadata = True )
157224
158225 if blueprint_variables :
226+ metadata_blueprint_variables = {
227+ k : SqlValue (sql = v .sql (dialect = dialect )) if isinstance (v , exp .Expression ) else v
228+ for k , v in blueprint_variables .items ()
229+ if k in metadata_used_variables
230+ }
159231 blueprint_variables = {
160232 k : SqlValue (sql = v .sql (dialect = dialect )) if isinstance (v , exp .Expression ) else v
161233 for k , v in blueprint_variables .items ()
234+ if k in non_metadata_used_variables
162235 }
163- python_env [c .SQLMESH_BLUEPRINT_VARS ] = Executable .value (blueprint_variables )
236+ if blueprint_variables :
237+ python_env [c .SQLMESH_BLUEPRINT_VARS ] = Executable .value (blueprint_variables )
238+ if metadata_blueprint_variables :
239+ python_env [c .SQLMESH_BLUEPRINT_VARS_METADATA ] = Executable .value (
240+ blueprint_variables , is_metadata = True
241+ )
164242
165243 return python_env
166244
@@ -171,7 +249,7 @@ def parse_dependencies(
171249 strict_resolution : bool = True ,
172250 variables : t .Optional [t .Dict [str , t .Any ]] = None ,
173251 blueprint_variables : t .Optional [t .Dict [str , t .Any ]] = None ,
174- ) -> t .Tuple [t .Set [str ], t .Set [str ]]:
252+ ) -> t .Tuple [t .Set [str ], t .Dict [str , bool ]]:
175253 """
176254 Parses the source of a model function and finds upstream table dependencies
177255 and referenced variables based on calls to context / evaluator.
@@ -185,7 +263,8 @@ def parse_dependencies(
185263 blueprint_variables: The blueprint variables available to the python environment.
186264
187265 Returns:
188- A tuple containing the set of upstream table dependencies and the set of referenced variables.
266+ A tuple containing the set of upstream table dependencies and a mapping of
267+ the referenced variables associated with their metadata status.
189268 """
190269
191270 class VariableResolutionContext :
@@ -203,12 +282,16 @@ def blueprint_var(var_name: str, default: t.Optional[t.Any] = None) -> t.Optiona
203282 local_env = dict .fromkeys (("context" , "evaluator" ), VariableResolutionContext )
204283
205284 depends_on = set ()
206- used_variables = set ()
285+ used_variable_referenced_in_metadata_expression : t . Dict [ str , bool ] = {}
207286
208287 for executable in python_env .values ():
209288 if not executable .is_definition :
210289 continue
290+
291+ is_metadata = executable .is_metadata
211292 for node in ast .walk (ast .parse (executable .payload )):
293+ next_variables = set ()
294+
212295 if isinstance (node , ast .Call ):
213296 func = node .func
214297 if not isinstance (func , ast .Attribute ) or not isinstance (func .value , ast .Name ):
@@ -239,26 +322,35 @@ def get_first_arg(keyword_arg_name: str) -> t.Any:
239322
240323 if func .value .id == "context" and func .attr in ("table" , "resolve_table" ):
241324 depends_on .add (get_first_arg ("model_name" ))
242- elif func .value .id in ("context" , "evaluator" ) and func .attr == c .VAR :
243- used_variables .add (get_first_arg ("var_name" ).lower ())
325+ elif func .value .id in ("context" , "evaluator" ) and func .attr in (
326+ c .VAR ,
327+ c .BLUEPRINT_VAR ,
328+ ):
329+ next_variables .add (get_first_arg ("var_name" ).lower ())
244330 elif (
245331 isinstance (node , ast .Attribute )
246332 and isinstance (node .value , ast .Name )
247333 and node .value .id in ("context" , "evaluator" )
248334 and node .attr == c .GATEWAY
249335 ):
250336 # Check whether the gateway attribute is referenced.
251- used_variables .add (c .GATEWAY )
337+ next_variables .add (c .GATEWAY )
252338 elif isinstance (node , ast .FunctionDef ) and node .name == entrypoint :
253- used_variables .update (
339+ next_variables .update (
254340 [
255341 arg .arg
256342 for arg in [* node .args .args , * node .args .kwonlyargs ]
257343 if arg .arg != "context"
258344 ]
259345 )
260346
261- return depends_on , used_variables
347+ for var_name in next_variables :
348+ used_variable_referenced_in_metadata_expression [var_name ] = (
349+ used_variable_referenced_in_metadata_expression .get (var_name , True )
350+ and bool (is_metadata )
351+ )
352+
353+ return depends_on , used_variable_referenced_in_metadata_expression
262354
263355
264356def validate_extra_and_required_fields (
0 commit comments