Skip to content

Commit 6474484

Browse files
committed
Fix macro func variable extraction & add tests
1 parent 89d459f commit 6474484

File tree

2 files changed

+33
-8
lines changed

2 files changed

+33
-8
lines changed

sqlmesh/core/model/common.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -160,16 +160,22 @@ def make_python_env(
160160
def _extract_macro_func_variable_references(macro_func: exp.Expression) -> t.Set[str]:
161161
references = set()
162162

163-
for n in macro_func.walk():
164-
if n is macro_func:
165-
continue
163+
# Don't descend into nested MacroFunc nodes besides @VAR() and @BLUEPRINT_VAR(), because
164+
# they will be handled in a separate call of _extract_macro_func_variable_references.
165+
def _prune_nested_macro_func(expression: exp.Expression) -> bool:
166+
return (
167+
type(n) is d.MacroFunc
168+
and n is not macro_func
169+
and n.this.name.lower() not in (c.VAR, c.BLUEPRINT_VAR)
170+
)
166171

167-
# Don't descend into nested MacroFunc nodes besides @VAR() and @BLUEPRINT_VAR(), because
168-
# they will be handled in a separate call of _extract_macro_func_variable_references.
169-
if isinstance(n, d.MacroFunc):
172+
for n in macro_func.walk(prune=_prune_nested_macro_func):
173+
if type(n) is d.MacroFunc:
170174
this = n.this
171-
if this.name.lower() in (c.VAR, c.BLUEPRINT_VAR) and this.expressions:
172-
references.add(this.expressions[0].this.lower())
175+
args = this.expressions
176+
177+
if this.name.lower() in (c.VAR, c.BLUEPRINT_VAR) and args and args[0].is_string:
178+
references.add(args[0].this.lower())
173179
elif isinstance(n, d.MacroVar):
174180
references.add(n.name.lower())
175181
elif isinstance(n, (exp.Identifier, d.MacroStrReplace, d.MacroSQL)) and "@" in n.name:

tests/core/test_model.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11149,3 +11149,22 @@ def entrypoint(context, **kwargs):
1114911149

1115011150
assert model_daily is not None
1115111151
assert model_daily.cron == "@daily"
11152+
11153+
11154+
@pytest.mark.parametrize(
11155+
"macro_func, variables",
11156+
[
11157+
("@M(@v1)", {"v1"}),
11158+
("@M(@{v1})", {"v1"}),
11159+
("@M(@SQL('@v1'))", {"v1"}),
11160+
("@M(@'@{v1}_foo')", {"v1"}),
11161+
("@M1(@VAR('v1'))", {"v1"}),
11162+
("@M1(@v1, @M2(@v2), @BLUEPRINT_VAR('v3'))", {"v1", "v3"}),
11163+
("@M1(@BLUEPRINT_VAR(@VAR('v1')))", {"v1"}),
11164+
],
11165+
)
11166+
def test_extract_macro_func_variable_references(macro_func: str, variables: t.Set[str]) -> None:
11167+
from sqlmesh.core.model.common import _extract_macro_func_variable_references
11168+
11169+
macro_func_ast = parse_one(macro_func)
11170+
assert _extract_macro_func_variable_references(macro_func_ast) == variables

0 commit comments

Comments
 (0)