From cd11ba5b81adb259d04209904ee2926c7702832b Mon Sep 17 00:00:00 2001 From: eakmanrq <6326532+eakmanrq@users.noreply.github.com> Date: Wed, 27 Aug 2025 08:45:51 -0700 Subject: [PATCH 1/2] fix: find variable node if nested --- sqlmesh/utils/jinja.py | 17 ++++++++++++++++- tests/dbt/converter/test_jinja.py | 13 ++++++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/sqlmesh/utils/jinja.py b/sqlmesh/utils/jinja.py index fc9d898159..dfc4ebe3c6 100644 --- a/sqlmesh/utils/jinja.py +++ b/sqlmesh/utils/jinja.py @@ -228,9 +228,24 @@ def extract_macro_references_and_variables( ) ) + def is_variable_node(n: nodes.Node) -> bool: + return ( + isinstance(n, nodes.Call) + and isinstance(n.node, nodes.Name) + and n.node.name in (c.VAR, c.BLUEPRINT_VAR) + ) + for call_name, node in extract_call_names(jinja_str): if call_name[0] in (c.VAR, c.BLUEPRINT_VAR): - assert isinstance(node, nodes.Call) + if not is_variable_node(node): + # Find the variable node which could be nested + for n in node.find_all(nodes.Call): + if is_variable_node(n): + node = n + break + else: + raise ValueError(f"Could not find variable name in {jinja_str}") + node = t.cast(nodes.Call, node) args = [jinja_call_arg_name(arg) for arg in node.args] if args and args[0]: variable_name = args[0].lower() diff --git a/tests/dbt/converter/test_jinja.py b/tests/dbt/converter/test_jinja.py index 5d9e8f3d73..5d3e4508d3 100644 --- a/tests/dbt/converter/test_jinja.py +++ b/tests/dbt/converter/test_jinja.py @@ -1,5 +1,9 @@ import pytest -from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroExtractor +from sqlmesh.utils.jinja import ( + JinjaMacroRegistry, + MacroExtractor, + extract_macro_references_and_variables, +) from sqlmesh.dbt.converter.jinja import JinjaGenerator, convert_jinja_query, convert_jinja_macro import sqlmesh.dbt.converter.jinja_transforms as jt from pathlib import Path @@ -437,3 +441,10 @@ def test_convert_jinja_macro(input: str, expected: str, sushi_dbt_context: Conte result = convert_jinja_macro(sushi_dbt_context, input.strip()) assert " ".join(result.split()) == " ".join(expected.strip().split()) + + +def test_extract_macro_references_and_variables() -> None: + input = """JINJA_QUERY('{%- set something = "'"~var("variable").split("|") -%}""" + _, variables = extract_macro_references_and_variables(input) + assert len(variables) == 1 + assert variables == {"variable"} From 4d30445970c9a0daeabf711e159615a3658e88c1 Mon Sep 17 00:00:00 2001 From: eakmanrq <6326532+eakmanrq@users.noreply.github.com> Date: Wed, 27 Aug 2025 13:27:55 -0700 Subject: [PATCH 2/2] feedback --- sqlmesh/utils/jinja.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/sqlmesh/utils/jinja.py b/sqlmesh/utils/jinja.py index dfc4ebe3c6..d2d830c521 100644 --- a/sqlmesh/utils/jinja.py +++ b/sqlmesh/utils/jinja.py @@ -200,6 +200,14 @@ def _extract(node: nodes.Node, parent: t.Optional[nodes.Node] = None) -> None: return extracted +def is_variable_node(n: nodes.Node) -> bool: + return ( + isinstance(n, nodes.Call) + and isinstance(n.node, nodes.Name) + and n.node.name in (c.VAR, c.BLUEPRINT_VAR) + ) + + def extract_macro_references_and_variables( *jinja_strs: str, dbt_target_name: t.Optional[str] = None ) -> t.Tuple[t.Set[MacroReference], t.Set[str]]: @@ -228,13 +236,6 @@ def extract_macro_references_and_variables( ) ) - def is_variable_node(n: nodes.Node) -> bool: - return ( - isinstance(n, nodes.Call) - and isinstance(n.node, nodes.Name) - and n.node.name in (c.VAR, c.BLUEPRINT_VAR) - ) - for call_name, node in extract_call_names(jinja_str): if call_name[0] in (c.VAR, c.BLUEPRINT_VAR): if not is_variable_node(node):