Skip to content

Commit 0bea61e

Browse files
authored
fix: find variable node if nested (#5238)
1 parent 95d20fb commit 0bea61e

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

sqlmesh/utils/jinja.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,14 @@ def _extract(node: nodes.Node, parent: t.Optional[nodes.Node] = None) -> None:
200200
return extracted
201201

202202

203+
def is_variable_node(n: nodes.Node) -> bool:
204+
return (
205+
isinstance(n, nodes.Call)
206+
and isinstance(n.node, nodes.Name)
207+
and n.node.name in (c.VAR, c.BLUEPRINT_VAR)
208+
)
209+
210+
203211
def extract_macro_references_and_variables(
204212
*jinja_strs: str, dbt_target_name: t.Optional[str] = None
205213
) -> t.Tuple[t.Set[MacroReference], t.Set[str]]:
@@ -230,7 +238,15 @@ def extract_macro_references_and_variables(
230238

231239
for call_name, node in extract_call_names(jinja_str):
232240
if call_name[0] in (c.VAR, c.BLUEPRINT_VAR):
233-
assert isinstance(node, nodes.Call)
241+
if not is_variable_node(node):
242+
# Find the variable node which could be nested
243+
for n in node.find_all(nodes.Call):
244+
if is_variable_node(n):
245+
node = n
246+
break
247+
else:
248+
raise ValueError(f"Could not find variable name in {jinja_str}")
249+
node = t.cast(nodes.Call, node)
234250
args = [jinja_call_arg_name(arg) for arg in node.args]
235251
if args and args[0]:
236252
variable_name = args[0].lower()

tests/dbt/converter/test_jinja.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
import pytest
2-
from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroExtractor
2+
from sqlmesh.utils.jinja import (
3+
JinjaMacroRegistry,
4+
MacroExtractor,
5+
extract_macro_references_and_variables,
6+
)
37
from sqlmesh.dbt.converter.jinja import JinjaGenerator, convert_jinja_query, convert_jinja_macro
48
import sqlmesh.dbt.converter.jinja_transforms as jt
59
from pathlib import Path
@@ -437,3 +441,10 @@ def test_convert_jinja_macro(input: str, expected: str, sushi_dbt_context: Conte
437441
result = convert_jinja_macro(sushi_dbt_context, input.strip())
438442

439443
assert " ".join(result.split()) == " ".join(expected.strip().split())
444+
445+
446+
def test_extract_macro_references_and_variables() -> None:
447+
input = """JINJA_QUERY('{%- set something = "'"~var("variable").split("|") -%}"""
448+
_, variables = extract_macro_references_and_variables(input)
449+
assert len(variables) == 1
450+
assert variables == {"variable"}

0 commit comments

Comments
 (0)