Skip to content

Commit 442450b

Browse files
committed
Fix macro func variable extraction & add tests
1 parent 53abda3 commit 442450b

File tree

2 files changed

+33
-8
lines changed

2 files changed

+33
-8
lines changed

sqlmesh/core/model/common.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -160,16 +160,22 @@ def make_python_env(
160160
def _extract_macro_func_variable_references(macro_func: exp.Expression) -> t.Set[str]:
161161
references = set()
162162

163-
for n in macro_func.walk():
164-
if n is macro_func:
165-
continue
163+
# Don't descend into nested MacroFunc nodes besides @VAR() and @BLUEPRINT_VAR(), because
164+
# they will be handled in a separate call of _extract_macro_func_variable_references.
165+
def _prune_nested_macro_func(expression: exp.Expression) -> bool:
166+
return (
167+
type(n) is d.MacroFunc
168+
and n is not macro_func
169+
and n.this.name.lower() not in (c.VAR, c.BLUEPRINT_VAR)
170+
)
166171

167-
# Don't descend into nested MacroFunc nodes besides @VAR() and @BLUEPRINT_VAR(), because
168-
# they will be handled in a separate call of _extract_macro_func_variable_references.
169-
if isinstance(n, d.MacroFunc):
172+
for n in macro_func.walk(prune=_prune_nested_macro_func):
173+
if type(n) is d.MacroFunc:
170174
this = n.this
171-
if this.name.lower() in (c.VAR, c.BLUEPRINT_VAR) and this.expressions:
172-
references.add(this.expressions[0].this.lower())
175+
args = this.expressions
176+
177+
if this.name.lower() in (c.VAR, c.BLUEPRINT_VAR) and args and args[0].is_string:
178+
references.add(args[0].this.lower())
173179
elif isinstance(n, d.MacroVar):
174180
references.add(n.name.lower())
175181
elif isinstance(n, (exp.Identifier, d.MacroStrReplace, d.MacroSQL)) and "@" in n.name:

tests/core/test_model.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11117,3 +11117,22 @@ def entrypoint(context, **kwargs):
1111711117

1111811118
assert model_daily is not None
1111911119
assert model_daily.cron == "@daily"
11120+
11121+
11122+
@pytest.mark.parametrize(
11123+
"macro_func, variables",
11124+
[
11125+
("@M(@v1)", {"v1"}),
11126+
("@M(@{v1})", {"v1"}),
11127+
("@M(@SQL('@v1'))", {"v1"}),
11128+
("@M(@'@{v1}_foo')", {"v1"}),
11129+
("@M1(@VAR('v1'))", {"v1"}),
11130+
("@M1(@v1, @M2(@v2), @BLUEPRINT_VAR('v3'))", {"v1", "v3"}),
11131+
("@M1(@BLUEPRINT_VAR(@VAR('v1')))", {"v1"}),
11132+
],
11133+
)
11134+
def test_extract_macro_func_variable_references(macro_func: str, variables: t.Set[str]) -> None:
11135+
from sqlmesh.core.model.common import _extract_macro_func_variable_references
11136+
11137+
macro_func_ast = parse_one(macro_func)
11138+
assert _extract_macro_func_variable_references(macro_func_ast) == variables

0 commit comments

Comments
 (0)