55from pathlib import Path
66
77from astor import to_source
8+ from collections import defaultdict
89from difflib import get_close_matches
910from sqlglot import exp
1011from sqlglot .helper import ensure_list
2829 from sqlmesh .utils import registry_decorator
2930 from sqlmesh .utils .jinja import MacroReference
3031
31- MacroCallable = registry_decorator
32+ MacroCallable = t . Union [ Executable , registry_decorator ]
3233
3334
3435def make_python_env (
@@ -48,13 +49,17 @@ def make_python_env(
4849 dialect : DialectType = None ,
4950) -> t .Dict [str , Executable ]:
5051 python_env = {} if python_env is None else python_env
51- variables = variables or {}
5252 env : t .Dict [str , t .Tuple [t .Any , t .Optional [bool ]]] = {}
53- used_macros : t .Dict [
54- str ,
55- t .Tuple [t .Union [Executable | MacroCallable ], t .Optional [bool ]],
56- ] = {}
57- used_variables = (used_variables or set ()).copy ()
53+
54+ variables = variables or {}
55+ blueprint_variables = blueprint_variables or {}
56+
57+ used_macros : t .Dict [str , t .Tuple [MacroCallable , t .Optional [bool ]]] = {}
58+ used_variable_referenced_in_metadata_expression = dict .fromkeys (used_variables or set (), False )
59+
60+ # For an expression like @foo(@v1, @bar(@v1, @v2), @v3), the following mapping would be:
61+ # v1 -> {"foo", "bar"}, v2 -> {"bar"}, v3 -> {"foo"}
62+ macro_funcs_by_used_var : t .DefaultDict [str , t .Set [str ]] = defaultdict (set )
5863
5964 expressions = ensure_list (expressions )
6065 for expression_metadata in expressions :
@@ -77,16 +82,27 @@ def make_python_env(
7782 macros [name ],
7883 used_macros .get (name , (None , is_metadata ))[1 ] and is_metadata ,
7984 )
80- if name == c .VAR :
85+ if name in ( c .VAR , c . BLUEPRINT_VAR ) :
8186 args = macro_func_or_var .this .expressions
8287 if len (args ) < 1 :
83- raise_config_error ("Macro VAR requires at least one argument" , path )
88+ raise_config_error (
89+ f"Macro { name .upper ()} requires at least one argument" , path
90+ )
91+
8492 if not args [0 ].is_string :
8593 raise_config_error (
8694 f"The variable name must be a string literal, '{ args [0 ].sql ()} ' was given instead" ,
8795 path ,
8896 )
89- used_variables .add (args [0 ].this .lower ())
97+
98+ var_name = args [0 ].this .lower ()
99+ used_variable_referenced_in_metadata_expression [var_name ] = (
100+ used_variable_referenced_in_metadata_expression .get (var_name , True )
101+ and bool (is_metadata )
102+ )
103+ else :
104+ for var_ref in _extract_macro_func_variable_references (macro_func_or_var ):
105+ macro_funcs_by_used_var [var_ref ].add (name )
90106 elif macro_func_or_var .__class__ is d .MacroVar :
91107 name = macro_func_or_var .name .lower ()
92108 if name in macros :
@@ -95,17 +111,23 @@ def make_python_env(
95111 macros [name ],
96112 used_macros .get (name , (None , is_metadata ))[1 ] and is_metadata ,
97113 )
98- elif name in variables :
99- used_variables .add (name )
114+ elif name in variables or name in blueprint_variables :
115+ used_variable_referenced_in_metadata_expression [name ] = (
116+ used_variable_referenced_in_metadata_expression .get (name , True )
117+ and bool (is_metadata )
118+ )
100119 elif (
101120 isinstance (macro_func_or_var , (exp .Identifier , d .MacroStrReplace , d .MacroSQL ))
102121 ) and "@" in macro_func_or_var .name :
103122 for _ , identifier , braced_identifier , _ in MacroStrTemplate .pattern .findall (
104123 macro_func_or_var .name
105124 ):
106125 var_name = braced_identifier or identifier
107- if var_name in variables :
108- used_variables .add (var_name )
126+ if var_name in variables or var_name in blueprint_variables :
127+ used_variable_referenced_in_metadata_expression [var_name ] = (
128+ used_variable_referenced_in_metadata_expression .get (var_name , True )
129+ and bool (is_metadata )
130+ )
109131
110132 for macro_ref in jinja_macro_references or set ():
111133 if macro_ref .package is None and macro_ref .name in macros :
@@ -126,43 +148,101 @@ def make_python_env(
126148 python_env .update (serialize_env (env , path = module_path ))
127149 return _add_variables_to_python_env (
128150 python_env ,
129- used_variables ,
151+ used_variable_referenced_in_metadata_expression ,
130152 variables ,
131153 blueprint_variables = blueprint_variables ,
132154 dialect = dialect ,
133155 strict_resolution = strict_resolution ,
156+ macro_funcs_by_used_var = macro_funcs_by_used_var ,
134157 )
135158
136159
def _extract_macro_func_variable_references(macro_func: exp.Expression) -> t.Set[str]:
    """Gather the names of the variables referenced inside a macro function call.

    Every node yielded by ``macro_func.walk()``, except ``macro_func`` itself, is
    inspected. A lowercased name is recorded for:

    * ``MacroFunc`` nodes named ``c.VAR`` or ``c.BLUEPRINT_VAR`` that have at least
      one argument: the value of the first argument.
    * ``MacroVar`` nodes: the node's own name.
    * Any other ``Identifier`` / ``MacroStrReplace`` / ``MacroSQL`` node whose name
      contains ``@``: every match of ``MacroStrTemplate.pattern`` in that name,
      preferring the braced form of the identifier when it is present.

    Args:
        macro_func: The macro function expression to inspect.

    Returns:
        The set of lowercased variable names that were found.
    """
    found: t.Set[str] = set()

    for node in macro_func.walk():
        if node is macro_func:
            continue

        if isinstance(node, d.MacroFunc):
            # Of the nested macro functions, only @VAR() and @BLUEPRINT_VAR() contribute
            # a name directly; other macro functions are meant to be covered by their own
            # call of this helper.
            # NOTE(review): walk() is invoked without a prune callback, so the nodes
            # beneath a nested macro function still appear to be visited by this loop --
            # confirm whether that is intended.
            callee = node.this
            if callee.name.lower() in (c.VAR, c.BLUEPRINT_VAR) and callee.expressions:
                found.add(callee.expressions[0].this.lower())
            continue

        if isinstance(node, d.MacroVar):
            found.add(node.name.lower())
            continue

        if isinstance(node, (exp.Identifier, d.MacroStrReplace, d.MacroSQL)) and "@" in node.name:
            for _, plain, braced, _ in MacroStrTemplate.pattern.findall(node.name):
                found.add((braced or plain).lower())

    return found
182+
183+
def _add_variables_to_python_env(
    python_env: t.Dict[str, Executable],
    used_variable_referenced_in_metadata_expression: t.Dict[str, bool],
    variables: t.Optional[t.Dict[str, t.Any]],
    strict_resolution: bool = True,
    blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
    dialect: DialectType = None,
    macro_funcs_by_used_var: t.Optional[t.DefaultDict[str, t.Set[str]]] = None,
) -> t.Dict[str, Executable]:
    """Store the used variables in the python environment, split by metadata status.

    Args:
        python_env: The python environment to update. It is modified in place.
        used_variable_referenced_in_metadata_expression: Maps each used variable name
            to a flag that is True only if every reference seen so far was in a
            metadata context. It is modified in place: the references reported by
            ``parse_dependencies`` for ``python_env`` are merged into it.
        variables: The available variables.
        strict_resolution: Forwarded to ``parse_dependencies``.
        blueprint_variables: The available blueprint variables. Values that are
            ``exp.Expression`` instances are stored as ``SqlValue`` objects holding
            their SQL rendered in ``dialect``.
        dialect: The dialect used to render blueprint variable expressions.
        macro_funcs_by_used_var: Maps a variable name to the names of the macro
            functions that reference it.

    Returns:
        The same ``python_env`` object, with the following keys set whenever the
        corresponding mapping is non-empty: ``c.SQLMESH_VARS``,
        ``c.SQLMESH_VARS_METADATA``, ``c.SQLMESH_BLUEPRINT_VARS`` and
        ``c.SQLMESH_BLUEPRINT_VARS_METADATA``.
    """
    _, python_used_variable_referenced_in_metadata_expression = parse_dependencies(
        python_env,
        None,
        strict_resolution=strict_resolution,
        variables=variables,
        blueprint_variables=blueprint_variables,
    )

    # A variable keeps its metadata-only flag only if every reference agrees on it.
    for var_name, is_metadata in python_used_variable_referenced_in_metadata_expression.items():
        used_variable_referenced_in_metadata_expression[var_name] = (
            used_variable_referenced_in_metadata_expression.get(var_name, True) and is_metadata
        )

    # Among the variables referenced by macro functions, treat a variable as metadata-only
    # when it is already flagged as such, or when every macro function referencing it is a
    # metadata executable present in the python environment.
    metadata_used_variables = set()
    for used_var, macro_names in (macro_funcs_by_used_var or {}).items():
        if used_variable_referenced_in_metadata_expression.get(used_var) or all(
            name in python_env and python_env[name].is_metadata for name in macro_names
        ):
            metadata_used_variables.add(used_var)

    used_variables = set(used_variable_referenced_in_metadata_expression)
    non_metadata_used_variables = used_variables - metadata_used_variables

    metadata_variables = {
        k: v for k, v in (variables or {}).items() if k in metadata_used_variables
    }
    variables = {k: v for k, v in (variables or {}).items() if k in non_metadata_used_variables}

    if variables:
        python_env[c.SQLMESH_VARS] = Executable.value(variables, sort_root_dict=True)
    if metadata_variables:
        python_env[c.SQLMESH_VARS_METADATA] = Executable.value(
            metadata_variables, sort_root_dict=True, is_metadata=True
        )

    if blueprint_variables:
        metadata_blueprint_variables = {
            k: SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expression) else v
            for k, v in blueprint_variables.items()
            if k in metadata_used_variables
        }
        blueprint_variables = {
            k: SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expression) else v
            for k, v in blueprint_variables.items()
            if k in non_metadata_used_variables
        }
        if blueprint_variables:
            python_env[c.SQLMESH_BLUEPRINT_VARS] = Executable.value(
                blueprint_variables, sort_root_dict=True
            )
        if metadata_blueprint_variables:
            # The metadata key must carry the metadata-only subset; passing
            # `blueprint_variables` here would drop the metadata-only blueprint variables
            # and store the non-metadata ones under the metadata key instead.
            python_env[c.SQLMESH_BLUEPRINT_VARS_METADATA] = Executable.value(
                metadata_blueprint_variables, sort_root_dict=True, is_metadata=True
            )

    return python_env
168248
@@ -173,7 +253,7 @@ def parse_dependencies(
173253 strict_resolution : bool = True ,
174254 variables : t .Optional [t .Dict [str , t .Any ]] = None ,
175255 blueprint_variables : t .Optional [t .Dict [str , t .Any ]] = None ,
176- ) -> t .Tuple [t .Set [str ], t .Set [str ]]:
256+ ) -> t .Tuple [t .Set [str ], t .Dict [str , bool ]]:
177257 """
178258 Parses the source of a model function and finds upstream table dependencies
179259 and referenced variables based on calls to context / evaluator.
@@ -187,7 +267,8 @@ def parse_dependencies(
187267 blueprint_variables: The blueprint variables available to the python environment.
188268
189269 Returns:
190- A tuple containing the set of upstream table dependencies and the set of referenced variables.
270+ A tuple containing the set of upstream table dependencies and a mapping of
271+ the referenced variables associated with their metadata status.
191272 """
192273
193274 class VariableResolutionContext :
@@ -205,12 +286,16 @@ def blueprint_var(var_name: str, default: t.Optional[t.Any] = None) -> t.Optiona
205286 local_env = dict .fromkeys (("context" , "evaluator" ), VariableResolutionContext )
206287
207288 depends_on = set ()
208- used_variables = set ()
289+ used_variable_referenced_in_metadata_expression : t . Dict [ str , bool ] = {}
209290
210291 for executable in python_env .values ():
211292 if not executable .is_definition :
212293 continue
294+
295+ is_metadata = executable .is_metadata
213296 for node in ast .walk (ast .parse (executable .payload )):
297+ next_variables = set ()
298+
214299 if isinstance (node , ast .Call ):
215300 func = node .func
216301 if not isinstance (func , ast .Attribute ) or not isinstance (func .value , ast .Name ):
@@ -241,26 +326,35 @@ def get_first_arg(keyword_arg_name: str) -> t.Any:
241326
242327 if func .value .id == "context" and func .attr in ("table" , "resolve_table" ):
243328 depends_on .add (get_first_arg ("model_name" ))
244- elif func .value .id in ("context" , "evaluator" ) and func .attr == c .VAR :
245- used_variables .add (get_first_arg ("var_name" ).lower ())
329+ elif func .value .id in ("context" , "evaluator" ) and func .attr in (
330+ c .VAR ,
331+ c .BLUEPRINT_VAR ,
332+ ):
333+ next_variables .add (get_first_arg ("var_name" ).lower ())
246334 elif (
247335 isinstance (node , ast .Attribute )
248336 and isinstance (node .value , ast .Name )
249337 and node .value .id in ("context" , "evaluator" )
250338 and node .attr == c .GATEWAY
251339 ):
252340 # Check whether the gateway attribute is referenced.
253- used_variables .add (c .GATEWAY )
341+ next_variables .add (c .GATEWAY )
254342 elif isinstance (node , ast .FunctionDef ) and node .name == entrypoint :
255- used_variables .update (
343+ next_variables .update (
256344 [
257345 arg .arg
258346 for arg in [* node .args .args , * node .args .kwonlyargs ]
259347 if arg .arg != "context"
260348 ]
261349 )
262350
263- return depends_on , used_variables
351+ for var_name in next_variables :
352+ used_variable_referenced_in_metadata_expression [var_name ] = (
353+ used_variable_referenced_in_metadata_expression .get (var_name , True )
354+ and bool (is_metadata )
355+ )
356+
357+ return depends_on , used_variable_referenced_in_metadata_expression
264358
265359
266360def validate_extra_and_required_fields (
0 commit comments