Skip to content

Commit 4fa0d07

Browse files
committed
Fix!: mark vars referenced in metadata macros as metadata
1 parent 3555e73 commit 4fa0d07

File tree

7 files changed

+211
-45
lines changed

7 files changed

+211
-45
lines changed

sqlmesh/core/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@
8080
DEFAULT_SCHEMA = "default"
8181

8282
SQLMESH_VARS = "__sqlmesh__vars__"
83+
SQLMESH_VARS_METADATA = "__sqlmesh__vars__metadata__"
8384
SQLMESH_BLUEPRINT_VARS = "__sqlmesh__blueprint__vars__"
85+
SQLMESH_BLUEPRINT_VARS_METADATA = "__sqlmesh__blueprint__vars__metadata__"
8486

8587
VAR = "var"
8688
BLUEPRINT_VAR = "blueprint_var"

sqlmesh/core/macros.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,12 @@ def __init__(
210210
self.macros[normalize_macro_name(k)] = self.env[k]
211211
elif v.is_value:
212212
value = self.env[k]
213-
if k in (c.SQLMESH_VARS, c.SQLMESH_BLUEPRINT_VARS):
213+
if k in (
214+
c.SQLMESH_VARS,
215+
c.SQLMESH_VARS_METADATA,
216+
c.SQLMESH_BLUEPRINT_VARS,
217+
c.SQLMESH_BLUEPRINT_VARS_METADATA,
218+
):
214219
value = {
215220
var_name: (
216221
self.parse_one(var_value.sql)
@@ -557,17 +562,25 @@ def views(self) -> t.List[str]:
557562

558563
def var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
559564
"""Returns the value of the specified variable, or the default value if it doesn't exist."""
560-
return (self.locals.get(c.SQLMESH_VARS) or {}).get(var_name.lower(), default)
565+
return (
566+
self.locals.get(c.SQLMESH_VARS) or self.locals.get(c.SQLMESH_VARS_METADATA) or {}
567+
).get(var_name.lower(), default)
561568

562569
def blueprint_var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
563570
"""Returns the value of the specified blueprint variable, or the default value if it doesn't exist."""
564-
return (self.locals.get(c.SQLMESH_BLUEPRINT_VARS) or {}).get(var_name.lower(), default)
571+
return (
572+
self.locals.get(c.SQLMESH_BLUEPRINT_VARS)
573+
or self.locals.get(c.SQLMESH_BLUEPRINT_VARS_METADATA)
574+
or {}
575+
).get(var_name.lower(), default)
565576

566577
@property
567578
def variables(self) -> t.Dict[str, t.Any]:
568579
return {
569580
**self.locals.get(c.SQLMESH_VARS, {}),
581+
**self.locals.get(c.SQLMESH_VARS_METADATA, {}),
570582
**self.locals.get(c.SQLMESH_BLUEPRINT_VARS, {}),
583+
**self.locals.get(c.SQLMESH_BLUEPRINT_VARS_METADATA, {}),
571584
}
572585

573586
def _coerce(self, expr: exp.Expression, typ: t.Any, strict: bool = False) -> t.Any:

sqlmesh/core/model/common.py

Lines changed: 124 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pathlib import Path
66

77
from astor import to_source
8+
from collections import defaultdict
89
from difflib import get_close_matches
910
from sqlglot import exp
1011
from sqlglot.helper import ensure_list
@@ -28,7 +29,7 @@
2829
from sqlmesh.utils import registry_decorator
2930
from sqlmesh.utils.jinja import MacroReference
3031

31-
MacroCallable = registry_decorator
32+
MacroCallable = t.Union[Executable, registry_decorator]
3233

3334

3435
def make_python_env(
@@ -48,13 +49,17 @@ def make_python_env(
4849
dialect: DialectType = None,
4950
) -> t.Dict[str, Executable]:
5051
python_env = {} if python_env is None else python_env
51-
variables = variables or {}
5252
env: t.Dict[str, t.Tuple[t.Any, t.Optional[bool]]] = {}
53-
used_macros: t.Dict[
54-
str,
55-
t.Tuple[t.Union[Executable | MacroCallable], t.Optional[bool]],
56-
] = {}
57-
used_variables = (used_variables or set()).copy()
53+
54+
variables = variables or {}
55+
blueprint_variables = blueprint_variables or {}
56+
57+
used_macros: t.Dict[str, t.Tuple[MacroCallable, t.Optional[bool]]] = {}
58+
used_variable_referenced_in_metadata_expression = dict.fromkeys(used_variables or set(), False)
59+
60+
# For an expression like @foo(@v1, @bar(@v1, @v2), @v3), the following mapping would be:
61+
# v1 -> {"foo", "bar"}, v2 -> {"bar"}, v3 -> "foo"
62+
macro_funcs_by_used_var: t.DefaultDict[str, t.Set[str]] = defaultdict(set)
5863

5964
expressions = ensure_list(expressions)
6065
for expression_metadata in expressions:
@@ -77,16 +82,27 @@ def make_python_env(
7782
macros[name],
7883
used_macros.get(name, (None, is_metadata))[1] and is_metadata,
7984
)
80-
if name == c.VAR:
85+
if name in (c.VAR, c.BLUEPRINT_VAR):
8186
args = macro_func_or_var.this.expressions
8287
if len(args) < 1:
83-
raise_config_error("Macro VAR requires at least one argument", path)
88+
raise_config_error(
89+
f"Macro {name.upper()} requires at least one argument", path
90+
)
91+
8492
if not args[0].is_string:
8593
raise_config_error(
8694
f"The variable name must be a string literal, '{args[0].sql()}' was given instead",
8795
path,
8896
)
89-
used_variables.add(args[0].this.lower())
97+
98+
var_name = args[0].this.lower()
99+
used_variable_referenced_in_metadata_expression[var_name] = (
100+
used_variable_referenced_in_metadata_expression.get(var_name, True)
101+
and bool(is_metadata)
102+
)
103+
else:
104+
for var_ref in _extract_macro_func_variable_references(macro_func_or_var):
105+
macro_funcs_by_used_var[var_ref].add(name)
90106
elif macro_func_or_var.__class__ is d.MacroVar:
91107
name = macro_func_or_var.name.lower()
92108
if name in macros:
@@ -95,17 +111,23 @@ def make_python_env(
95111
macros[name],
96112
used_macros.get(name, (None, is_metadata))[1] and is_metadata,
97113
)
98-
elif name in variables:
99-
used_variables.add(name)
114+
elif name in variables or name in blueprint_variables:
115+
used_variable_referenced_in_metadata_expression[name] = (
116+
used_variable_referenced_in_metadata_expression.get(name, True)
117+
and bool(is_metadata)
118+
)
100119
elif (
101120
isinstance(macro_func_or_var, (exp.Identifier, d.MacroStrReplace, d.MacroSQL))
102121
) and "@" in macro_func_or_var.name:
103122
for _, identifier, braced_identifier, _ in MacroStrTemplate.pattern.findall(
104123
macro_func_or_var.name
105124
):
106125
var_name = braced_identifier or identifier
107-
if var_name in variables:
108-
used_variables.add(var_name)
126+
if var_name in variables or var_name in blueprint_variables:
127+
used_variable_referenced_in_metadata_expression[var_name] = (
128+
used_variable_referenced_in_metadata_expression.get(var_name, True)
129+
and bool(is_metadata)
130+
)
109131

110132
for macro_ref in jinja_macro_references or set():
111133
if macro_ref.package is None and macro_ref.name in macros:
@@ -126,43 +148,101 @@ def make_python_env(
126148
python_env.update(serialize_env(env, path=module_path))
127149
return _add_variables_to_python_env(
128150
python_env,
129-
used_variables,
151+
used_variable_referenced_in_metadata_expression,
130152
variables,
131153
blueprint_variables=blueprint_variables,
132154
dialect=dialect,
133155
strict_resolution=strict_resolution,
156+
macro_funcs_by_used_var=macro_funcs_by_used_var,
134157
)
135158

136159

160+
def _extract_macro_func_variable_references(macro_func: exp.Expression) -> t.Set[str]:
161+
references = set()
162+
163+
for n in macro_func.walk():
164+
if n is macro_func:
165+
continue
166+
167+
# Don't descend into nested MacroFunc nodes besides @VAR() and @BLUEPRINT_VAR(), because
168+
# they will be handled in a separate call of _extract_macro_func_variable_references.
169+
if isinstance(n, d.MacroFunc):
170+
this = n.this
171+
if this.name.lower() in (c.VAR, c.BLUEPRINT_VAR) and this.expressions:
172+
references.add(this.expressions[0].this.lower())
173+
elif isinstance(n, d.MacroVar):
174+
references.add(n.name.lower())
175+
elif isinstance(n, (exp.Identifier, d.MacroStrReplace, d.MacroSQL)) and "@" in n.name:
176+
references.update(
177+
(braced_identifier or identifier).lower()
178+
for _, identifier, braced_identifier, _ in MacroStrTemplate.pattern.findall(n.name)
179+
)
180+
181+
return references
182+
183+
137184
def _add_variables_to_python_env(
138185
python_env: t.Dict[str, Executable],
139-
used_variables: t.Optional[t.Set[str]],
186+
used_variable_referenced_in_metadata_expression: t.Dict[str, bool],
140187
variables: t.Optional[t.Dict[str, t.Any]],
141188
strict_resolution: bool = True,
142189
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
143190
dialect: DialectType = None,
191+
macro_funcs_by_used_var: t.Optional[t.DefaultDict[str, t.Set[str]]] = None,
144192
) -> t.Dict[str, Executable]:
145-
_, python_used_variables = parse_dependencies(
193+
_, python_used_variable_referenced_in_metadata_expression = parse_dependencies(
146194
python_env,
147195
None,
148196
strict_resolution=strict_resolution,
149197
variables=variables,
150198
blueprint_variables=blueprint_variables,
151199
)
152-
used_variables = (used_variables or set()) | python_used_variables
200+
for var_name, is_metadata in python_used_variable_referenced_in_metadata_expression.items():
201+
used_variable_referenced_in_metadata_expression[var_name] = (
202+
used_variable_referenced_in_metadata_expression.get(var_name, True) and is_metadata
203+
)
204+
205+
metadata_used_variables = set()
206+
for used_var, macro_names in (macro_funcs_by_used_var or {}).items():
207+
if used_variable_referenced_in_metadata_expression.get(used_var) or all(
208+
name in python_env and python_env[name].is_metadata for name in macro_names
209+
):
210+
metadata_used_variables.add(used_var)
211+
212+
used_variables = set(used_variable_referenced_in_metadata_expression)
213+
non_metadata_used_variables = used_variables - metadata_used_variables
214+
215+
metadata_variables = {
216+
k: v for k, v in (variables or {}).items() if k in metadata_used_variables
217+
}
218+
variables = {k: v for k, v in (variables or {}).items() if k in non_metadata_used_variables}
153219

154-
variables = {k: v for k, v in (variables or {}).items() if k in used_variables}
155220
if variables:
156221
python_env[c.SQLMESH_VARS] = Executable.value(variables, sort_root_dict=True)
222+
if metadata_variables:
223+
python_env[c.SQLMESH_VARS_METADATA] = Executable.value(
224+
metadata_variables, sort_root_dict=True, is_metadata=True
225+
)
157226

158227
if blueprint_variables:
228+
metadata_blueprint_variables = {
229+
k: SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expression) else v
230+
for k, v in blueprint_variables.items()
231+
if k in metadata_used_variables
232+
}
159233
blueprint_variables = {
160234
k: SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expression) else v
161235
for k, v in blueprint_variables.items()
236+
if k in non_metadata_used_variables
162237
}
163-
python_env[c.SQLMESH_BLUEPRINT_VARS] = Executable.value(
164-
blueprint_variables, sort_root_dict=True
165-
)
238+
if blueprint_variables:
239+
python_env[c.SQLMESH_BLUEPRINT_VARS] = Executable.value(
240+
blueprint_variables, sort_root_dict=True
241+
)
242+
if metadata_blueprint_variables:
243+
python_env[c.SQLMESH_BLUEPRINT_VARS_METADATA] = Executable.value(
244+
blueprint_variables, sort_root_dict=True, is_metadata=True
245+
)
166246

167247
return python_env
168248

@@ -173,7 +253,7 @@ def parse_dependencies(
173253
strict_resolution: bool = True,
174254
variables: t.Optional[t.Dict[str, t.Any]] = None,
175255
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
176-
) -> t.Tuple[t.Set[str], t.Set[str]]:
256+
) -> t.Tuple[t.Set[str], t.Dict[str, bool]]:
177257
"""
178258
Parses the source of a model function and finds upstream table dependencies
179259
and referenced variables based on calls to context / evaluator.
@@ -187,7 +267,8 @@ def parse_dependencies(
187267
blueprint_variables: The blueprint variables available to the python environment.
188268
189269
Returns:
190-
A tuple containing the set of upstream table dependencies and the set of referenced variables.
270+
A tuple containing the set of upstream table dependencies and a mapping of
271+
the referenced variables associated with their metadata status.
191272
"""
192273

193274
class VariableResolutionContext:
@@ -205,12 +286,16 @@ def blueprint_var(var_name: str, default: t.Optional[t.Any] = None) -> t.Optiona
205286
local_env = dict.fromkeys(("context", "evaluator"), VariableResolutionContext)
206287

207288
depends_on = set()
208-
used_variables = set()
289+
used_variable_referenced_in_metadata_expression: t.Dict[str, bool] = {}
209290

210291
for executable in python_env.values():
211292
if not executable.is_definition:
212293
continue
294+
295+
is_metadata = executable.is_metadata
213296
for node in ast.walk(ast.parse(executable.payload)):
297+
next_variables = set()
298+
214299
if isinstance(node, ast.Call):
215300
func = node.func
216301
if not isinstance(func, ast.Attribute) or not isinstance(func.value, ast.Name):
@@ -241,26 +326,35 @@ def get_first_arg(keyword_arg_name: str) -> t.Any:
241326

242327
if func.value.id == "context" and func.attr in ("table", "resolve_table"):
243328
depends_on.add(get_first_arg("model_name"))
244-
elif func.value.id in ("context", "evaluator") and func.attr == c.VAR:
245-
used_variables.add(get_first_arg("var_name").lower())
329+
elif func.value.id in ("context", "evaluator") and func.attr in (
330+
c.VAR,
331+
c.BLUEPRINT_VAR,
332+
):
333+
next_variables.add(get_first_arg("var_name").lower())
246334
elif (
247335
isinstance(node, ast.Attribute)
248336
and isinstance(node.value, ast.Name)
249337
and node.value.id in ("context", "evaluator")
250338
and node.attr == c.GATEWAY
251339
):
252340
# Check whether the gateway attribute is referenced.
253-
used_variables.add(c.GATEWAY)
341+
next_variables.add(c.GATEWAY)
254342
elif isinstance(node, ast.FunctionDef) and node.name == entrypoint:
255-
used_variables.update(
343+
next_variables.update(
256344
[
257345
arg.arg
258346
for arg in [*node.args.args, *node.args.kwonlyargs]
259347
if arg.arg != "context"
260348
]
261349
)
262350

263-
return depends_on, used_variables
351+
for var_name in next_variables:
352+
used_variable_referenced_in_metadata_expression[var_name] = (
353+
used_variable_referenced_in_metadata_expression.get(var_name, True)
354+
and bool(is_metadata)
355+
)
356+
357+
return depends_on, used_variable_referenced_in_metadata_expression
264358

265359

266360
def validate_extra_and_required_fields(

sqlmesh/core/model/definition.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1782,12 +1782,17 @@ def render(
17821782
start, end = make_inclusive(start or c.EPOCH, end or c.EPOCH, self.dialect)
17831783
execution_time = to_datetime(execution_time or c.EPOCH)
17841784

1785-
variables = env.get(c.SQLMESH_VARS, {})
1786-
variables.update(kwargs.pop("variables", {}))
1787-
1785+
variables = {
1786+
**env.get(c.SQLMESH_VARS, {}),
1787+
**env.get(c.SQLMESH_VARS_METADATA, {}),
1788+
**kwargs.pop("variables", {}),
1789+
}
17881790
blueprint_variables = {
17891791
k: d.parse_one(v.sql, dialect=self.dialect) if isinstance(v, SqlValue) else v
1790-
for k, v in env.get(c.SQLMESH_BLUEPRINT_VARS, {}).items()
1792+
for k, v in {
1793+
**env.get(c.SQLMESH_BLUEPRINT_VARS, {}),
1794+
**env.get(c.SQLMESH_BLUEPRINT_VARS_METADATA, {}),
1795+
}.items()
17911796
}
17921797
try:
17931798
kwargs = {

sqlmesh/core/renderer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ def _resolve_table(table: str | exp.Table) -> str:
234234

235235
if variables:
236236
macro_evaluator.locals.setdefault(c.SQLMESH_VARS, {}).update(variables)
237+
macro_evaluator.locals.setdefault(c.SQLMESH_VARS_METADATA, {})
237238

238239
for definition in self._macro_definitions:
239240
try:

sqlmesh/utils/jinja.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def extract_macro_references_and_variables(
229229
)
230230

231231
for call_name, node in extract_call_names(jinja_str):
232-
if call_name[0] == c.VAR:
232+
if call_name[0] in (c.VAR, c.BLUEPRINT_VAR):
233233
assert isinstance(node, nodes.Call)
234234
args = [jinja_call_arg_name(arg) for arg in node.args]
235235
if args and args[0]:

0 commit comments

Comments
 (0)