diff --git a/sqlmesh/utils/jinja.py b/sqlmesh/utils/jinja.py index fc9d898159..d2d830c521 100644 --- a/sqlmesh/utils/jinja.py +++ b/sqlmesh/utils/jinja.py @@ -200,6 +200,14 @@ def _extract(node: nodes.Node, parent: t.Optional[nodes.Node] = None) -> None: return extracted +def is_variable_node(n: nodes.Node) -> bool: + return ( + isinstance(n, nodes.Call) + and isinstance(n.node, nodes.Name) + and n.node.name in (c.VAR, c.BLUEPRINT_VAR) + ) + + def extract_macro_references_and_variables( *jinja_strs: str, dbt_target_name: t.Optional[str] = None ) -> t.Tuple[t.Set[MacroReference], t.Set[str]]: @@ -230,7 +238,15 @@ def extract_macro_references_and_variables( for call_name, node in extract_call_names(jinja_str): if call_name[0] in (c.VAR, c.BLUEPRINT_VAR): - assert isinstance(node, nodes.Call) + if not is_variable_node(node): + # Find the variable node which could be nested + for n in node.find_all(nodes.Call): + if is_variable_node(n): + node = n + break + else: + raise ValueError(f"Could not find variable name in {jinja_str}") + node = t.cast(nodes.Call, node) args = [jinja_call_arg_name(arg) for arg in node.args] if args and args[0]: variable_name = args[0].lower() diff --git a/tests/dbt/converter/test_jinja.py b/tests/dbt/converter/test_jinja.py index 5d9e8f3d73..5d3e4508d3 100644 --- a/tests/dbt/converter/test_jinja.py +++ b/tests/dbt/converter/test_jinja.py @@ -1,5 +1,9 @@ import pytest -from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroExtractor +from sqlmesh.utils.jinja import ( + JinjaMacroRegistry, + MacroExtractor, + extract_macro_references_and_variables, +) from sqlmesh.dbt.converter.jinja import JinjaGenerator, convert_jinja_query, convert_jinja_macro import sqlmesh.dbt.converter.jinja_transforms as jt from pathlib import Path @@ -437,3 +441,10 @@ def test_convert_jinja_macro(input: str, expected: str, sushi_dbt_context: Conte result = convert_jinja_macro(sushi_dbt_context, input.strip()) assert " ".join(result.split()) == " ".join(expected.strip().split()) + + +def test_extract_macro_references_and_variables() -> None: + input = """JINJA_QUERY('{%- set something = "'"~var("variable").split("|") -%}""" + _, variables = extract_macro_references_and_variables(input) + assert len(variables) == 1 + assert variables == {"variable"}