diff --git a/sqlmesh/dbt/basemodel.py b/sqlmesh/dbt/basemodel.py index 73e0252332..f1e1dbed03 100644 --- a/sqlmesh/dbt/basemodel.py +++ b/sqlmesh/dbt/basemodel.py @@ -313,33 +313,21 @@ def sqlmesh_model_kwargs( """Get common sqlmesh model parameters""" self.remove_tests_with_invalid_refs(context) self.check_for_circular_test_refs(context) + + dependencies = self.dependencies.copy() + if dependencies.has_dynamic_var_names: + # Include ALL variables as dependencies since we couldn't determine + # precisely which variables are referenced in the model + dependencies.variables |= set(context.variables) + model_dialect = self.dialect(context) model_context = context.context_for_dependencies( - self.dependencies.union(self.tests_ref_source_dependencies) + dependencies.union(self.tests_ref_source_dependencies) ) jinja_macros = model_context.jinja_macros.trim( - self.dependencies.macros, package=self.package_name - ) - - model_node: AttributeDict[str, t.Any] = AttributeDict( - { - k: v - for k, v in context._manifest._manifest.nodes[self.node_name].to_dict().items() - if k in self.dependencies.model_attrs - } - if context._manifest and self.node_name in context._manifest._manifest.nodes - else {} - ) - - jinja_macros.add_globals( - { - "this": self.relation_info, - "model": model_node, - "schema": self.table_schema, - "config": self.config_attribute_dict, - **model_context.jinja_globals, # type: ignore - } + dependencies.macros, package=self.package_name ) + jinja_macros.add_globals(self._model_jinja_context(model_context, dependencies)) return { "audits": [(test.name, {}) for test in self.tests], "columns": column_types_to_sqlmesh( @@ -369,3 +357,23 @@ def to_sqlmesh( virtual_environment_mode: VirtualEnvironmentMode = VirtualEnvironmentMode.default, ) -> Model: """Convert DBT model into sqlmesh Model""" + + def _model_jinja_context( + self, context: DbtContext, dependencies: Dependencies + ) -> t.Dict[str, t.Any]: + model_node: AttributeDict[str, t.Any] = AttributeDict( + { + k: v + for k, v in context._manifest._manifest.nodes[self.node_name].to_dict().items() + if k in dependencies.model_attrs + } + if context._manifest and self.node_name in context._manifest._manifest.nodes + else {} + ) + return { + "this": self.relation_info, + "model": model_node, + "schema": self.table_schema, + "config": self.config_attribute_dict, + **context.jinja_globals, + } diff --git a/sqlmesh/dbt/common.py b/sqlmesh/dbt/common.py index d9db5a472c..ec928576ed 100644 --- a/sqlmesh/dbt/common.py +++ b/sqlmesh/dbt/common.py @@ -184,6 +184,8 @@ class Dependencies(PydanticModel): variables: t.Set[str] = set() model_attrs: t.Set[str] = set() + has_dynamic_var_names: bool = False + def union(self, other: Dependencies) -> Dependencies: return Dependencies( macros=list(set(self.macros) | set(other.macros)), @@ -191,6 +193,7 @@ def union(self, other: Dependencies) -> Dependencies: refs=self.refs | other.refs, variables=self.variables | other.variables, model_attrs=self.model_attrs | other.model_attrs, + has_dynamic_var_names=self.has_dynamic_var_names or other.has_dynamic_var_names, ) @field_validator("macros", mode="after") diff --git a/sqlmesh/dbt/context.py b/sqlmesh/dbt/context.py index 307dea6477..d76cccbce7 100644 --- a/sqlmesh/dbt/context.py +++ b/sqlmesh/dbt/context.py @@ -8,6 +8,7 @@ from sqlmesh.core.config import Config as SQLMeshConfig from sqlmesh.dbt.builtin import _relation_info_to_relation +from sqlmesh.dbt.common import Dependencies from sqlmesh.dbt.manifest import ManifestHelper from sqlmesh.dbt.target import TargetConfig from sqlmesh.utils import AttributeDict @@ -22,7 +23,6 @@ if t.TYPE_CHECKING: from jinja2 import Environment - from sqlmesh.dbt.basemodel import Dependencies from sqlmesh.dbt.model import ModelConfig from sqlmesh.dbt.relation import Policy from sqlmesh.dbt.seed import SeedConfig @@ -101,8 +101,6 @@ def add_variables(self, variables: t.Dict[str, t.Any]) -> None: self._jinja_environment = None def set_and_render_variables(self, variables: t.Dict[str, t.Any], package: str) -> None: - self.variables = variables - jinja_environment = self.jinja_macros.build_environment(**self.jinja_globals) def _render_var(value: t.Any) -> t.Any: diff --git a/sqlmesh/dbt/loader.py b/sqlmesh/dbt/loader.py index be0ff59aa4..d321246896 100644 --- a/sqlmesh/dbt/loader.py +++ b/sqlmesh/dbt/loader.py @@ -124,8 +124,6 @@ def _to_sqlmesh(config: BMC, context: DbtContext) -> Model: ) for project in self._load_projects(): - context = project.context.copy() - macros_max_mtime = self._macros_max_mtime yaml_max_mtimes = self._compute_yaml_max_mtime_per_subfolder( project.context.project_root @@ -135,12 +133,13 @@ def _to_sqlmesh(config: BMC, context: DbtContext) -> Model: logger.debug("Converting models to sqlmesh") # Now that config is rendered, create the sqlmesh models for package in project.packages.values(): - context.set_and_render_variables(package.variables, package.name) + package_context = project.context.copy() + package_context.set_and_render_variables(package.variables, package.name) package_models: t.Dict[str, BaseModelConfig] = {**package.models, **package.seeds} for model in package_models.values(): sqlmesh_model = cache.get_or_load_models( - model.path, loader=lambda: [_to_sqlmesh(model, context)] + model.path, loader=lambda: [_to_sqlmesh(model, package_context)] )[0] models[sqlmesh_model.fqn] = sqlmesh_model @@ -155,15 +154,14 @@ def _load_audits( audits: UniqueKeyDict = UniqueKeyDict("audits") for project in self._load_projects(): - context = project.context - logger.debug("Converting audits to sqlmesh") for package in project.packages.values(): - context.set_and_render_variables(package.variables, package.name) + package_context = project.context.copy() + package_context.set_and_render_variables(package.variables, package.name) for test in package.tests.values(): logger.debug("Converting '%s' to sqlmesh format", test.name) try: - audits[test.name] = test.to_sqlmesh(context) + audits[test.name] = test.to_sqlmesh(package_context) except MissingModelError as e: logger.warning( "Skipping audit '%s' because model '%s' is not a valid ref", @@ -244,9 +242,9 @@ def _load_environment_statements(self, macros: MacroRegistry) -> t.List[Environm project_names: t.Set[str] = set() dialect = self.config.dialect for project in self._load_projects(): - context = project.context for package_name, package in project.packages.items(): - context.set_and_render_variables(package.variables, package_name) + package_context = project.context.copy() + package_context.set_and_render_variables(package.variables, package_name) on_run_start: t.List[str] = [ on_run_hook.sql for on_run_hook in sorted(package.on_run_start.values(), key=lambda h: h.index) @@ -261,7 +259,7 @@ def _load_environment_statements(self, macros: MacroRegistry) -> t.List[Environm for hook in [*package.on_run_start.values(), *package.on_run_end.values()]: dependencies = dependencies.union(hook.dependencies) - statements_context = context.context_for_dependencies(dependencies) + statements_context = package_context.context_for_dependencies(dependencies) jinja_registry = make_jinja_registry( statements_context.jinja_macros, package_name, set(dependencies.macros) ) diff --git a/sqlmesh/dbt/manifest.py b/sqlmesh/dbt/manifest.py index 125b204270..7414325902 100644 --- a/sqlmesh/dbt/manifest.py +++ b/sqlmesh/dbt/manifest.py @@ -554,6 +554,9 @@ def _extra_dependencies(self, target: str, package: str) -> Dependencies: args = [jinja_call_arg_name(arg) for arg in node.args] if args and args[0]: dependencies.variables.add(args[0]) + else: + # We couldn't determine the var name statically + dependencies.has_dynamic_var_names = True dependencies.macros.append(MacroReference(name="var")) elif len(call_name) == 1: macro_name = call_name[0] diff --git a/sqlmesh/dbt/project.py b/sqlmesh/dbt/project.py index d37c9cc6c4..581660943a 100644 --- a/sqlmesh/dbt/project.py +++ b/sqlmesh/dbt/project.py @@ -55,9 +55,6 @@ def load(cls, context: DbtContext, variables: t.Optional[t.Dict[str, t.Any]] = N raise ConfigError(f"Could not find {PROJECT_FILENAME} in {context.project_root}") project_yaml = load_yaml(project_file_path) - variable_overrides = variables - variables = {**project_yaml.get("vars", {}), **(variables or {})} - project_name = context.render(project_yaml.get("name", "")) context.project_name = project_name if not context.project_name: @@ -69,6 +66,7 @@ def load(cls, context: DbtContext, variables: t.Optional[t.Dict[str, t.Any]] = N profile = Profile.load(context, context.target_name) context.target = profile.target + variable_overrides = variables or {} context.manifest = ManifestHelper( project_file_path.parent, profile.path.parent, @@ -101,13 +99,17 @@ def load(cls, context: DbtContext, variables: t.Optional[t.Dict[str, t.Any]] = N package = package_loader.load(path.parent) packages[package.name] = package + all_project_variables = {**project_yaml.get("vars", {}), **(variable_overrides or {})} for name, package in packages.items(): - package_vars = variables.get(name) + package_vars = all_project_variables.get(name) if isinstance(package_vars, dict): package.variables.update(package_vars) - package.variables.update(variables) + if name == context.project_name: + package.variables.update(all_project_variables) + else: + package.variables.update(variable_overrides) return Project(context, profile, packages) diff --git a/tests/dbt/test_config.py b/tests/dbt/test_config.py index bc6f878801..99426ebb97 100644 --- a/tests/dbt/test_config.py +++ b/tests/dbt/test_config.py @@ -362,6 +362,7 @@ def test_variables(assert_exp_eq, sushi_test_project): "nested_vars": { "some_nested_var": 2, }, + "dynamic_test_var": 3, "list_var": [ {"name": "item1", "value": 1}, {"name": "item2", "value": 2}, @@ -375,25 +376,10 @@ def test_variables(assert_exp_eq, sushi_test_project): expected_customer_variables = { "some_var": ["foo", "bar"], "some_other_var": 5, - "yet_another_var": 1, + "yet_another_var": 5, "customers:bla": False, "customers:customer_id": "customer_id", "start": "Jan 1 2022", - "top_waiters:limit": 10, - "top_waiters:revenue": "revenue", - "customers:boo": ["a", "b"], - "nested_vars": { - "some_nested_var": 2, - }, - "list_var": [ - {"name": "item1", "value": 1}, - {"name": "item2", "value": 2}, - ], - "customers": { - "customers:bla": False, - "customers:customer_id": "customer_id", - "some_var": ["foo", "bar"], - }, } assert sushi_test_project.packages["sushi"].variables == expected_sushi_variables @@ -406,7 +392,9 @@ def test_nested_variables(sushi_test_project): sql="SELECT {{ var('nested_vars')['some_nested_var'] }}", dependencies=Dependencies(variables=["nested_vars"]), ) - sqlmesh_model = model_config.to_sqlmesh(sushi_test_project.context) + context = sushi_test_project.context.copy() + context.set_and_render_variables(sushi_test_project.packages["sushi"].variables, "sushi") + sqlmesh_model = model_config.to_sqlmesh(context) assert sqlmesh_model.jinja_macros.global_objs["vars"]["nested_vars"] == {"some_nested_var": 2} diff --git a/tests/dbt/test_manifest.py b/tests/dbt/test_manifest.py index bf64e4b8b3..2bed6acb55 100644 --- a/tests/dbt/test_manifest.py +++ b/tests/dbt/test_manifest.py @@ -79,6 +79,7 @@ def test_manifest_helper(caplog): waiter_revenue_by_day_config = models["waiter_revenue_by_day_v2"] assert waiter_revenue_by_day_config.dependencies == Dependencies( macros={ + MacroReference(name="dynamic_var_name_dependency"), MacroReference(name="log_value"), MacroReference(name="test_dependencies"), MacroReference(package="customers", name="duckdb__current_engine"), @@ -87,6 +88,7 @@ def test_manifest_helper(caplog): }, sources={"streaming.items", "streaming.orders", "streaming.order_items"}, variables={"yet_another_var", "nested_vars"}, + has_dynamic_var_names=True, ) assert waiter_revenue_by_day_config.materialized == "incremental" assert waiter_revenue_by_day_config.incremental_strategy == "delete+insert" diff --git a/tests/dbt/test_transformation.py b/tests/dbt/test_transformation.py index fdb8345398..1bcc3081f7 100644 --- a/tests/dbt/test_transformation.py +++ b/tests/dbt/test_transformation.py @@ -37,6 +37,7 @@ ) from sqlmesh.core.state_sync.db.snapshot import _snapshot_to_json from sqlmesh.dbt.builtin import _relation_info_to_relation +from sqlmesh.dbt.common import Dependencies from sqlmesh.dbt.column import ( ColumnConfig, column_descriptions_to_sqlmesh, @@ -50,6 +51,7 @@ from sqlmesh.dbt.target import BigQueryConfig, DuckDbConfig, SnowflakeConfig, ClickhouseConfig from sqlmesh.dbt.test import TestConfig from sqlmesh.utils.errors import ConfigError, MacroEvalError, SQLMeshError +from sqlmesh.utils.jinja import MacroReference pytestmark = [pytest.mark.dbt, pytest.mark.slow] @@ -1530,6 +1532,9 @@ def test_dbt_package_macros(sushi_test_project: Project): @pytest.mark.xdist_group("dbt_manifest") def test_dbt_vars(sushi_test_project: Project): context = sushi_test_project.context + context.set_and_render_variables( + sushi_test_project.packages["customers"].variables, "customers" + ) assert context.render("{{ var('some_other_var') }}") == "5" assert context.render("{{ var('some_other_var', 0) }}") == "5" @@ -1854,3 +1859,65 @@ def test_on_run_start_end(): "CREATE OR REPLACE TABLE schema_table_sushi__dev_nested_package AS SELECT 'sushi__dev' AS schema", ] ) + + +@pytest.mark.xdist_group("dbt_manifest") +def test_dynamic_var_names(sushi_test_project: Project, sushi_test_dbt_context: Context): + context = sushi_test_project.context + context.set_and_render_variables(sushi_test_project.packages["sushi"].variables, "sushi") + context.target = BigQueryConfig(name="production", database="main", schema="sushi") + model_config = ModelConfig( + name="model", + alias="model", + schema="test", + package_name="package", + materialized="table", + unique_key="ds", + partition_by={"field": "ds", "granularity": "month"}, + sql=""" + {% set var_name = "yet_" + "another_" + "var" %} + {% set results = run_query('select 1 as one') %} + {% if results %} + SELECT {{ results.columns[0].values()[0] }} AS one {{ var(var_name) }} AS var FROM {{ this.identifier }} + {% else %} + SELECT NULL AS one {{ var(var_name) }} AS var FROM {{ this.identifier }} + {% endif %} + """, + dependencies=Dependencies(has_dynamic_var_names=True), + ) + converted_model = model_config.to_sqlmesh(context) + assert "yet_another_var" in converted_model.jinja_macros.global_objs["vars"] # type: ignore + + # Test the existing model in the sushi project + assert ( + "dynamic_test_var" # type: ignore + in sushi_test_dbt_context.get_model( + "sushi.waiter_revenue_by_day_v2" + ).jinja_macros.global_objs["vars"] + ) + + +@pytest.mark.xdist_group("dbt_manifest") +def test_dynamic_var_names_in_macro(sushi_test_project: Project): + context = sushi_test_project.context + context.set_and_render_variables(sushi_test_project.packages["sushi"].variables, "sushi") + context.target = BigQueryConfig(name="production", database="main", schema="sushi") + model_config = ModelConfig( + name="model", + alias="model", + schema="test", + package_name="package", + materialized="table", + unique_key="ds", + partition_by={"field": "ds", "granularity": "month"}, + sql=""" + {% set var_name = "dynamic_" + "test_" + "var" %} + SELECT {{ sushi.dynamic_var_name_dependency(var_name) }} AS var + """, + dependencies=Dependencies( + macros=[MacroReference(package="sushi", name="dynamic_var_name_dependency")], + has_dynamic_var_names=True, + ), + ) + converted_model = model_config.to_sqlmesh(context) + assert "dynamic_test_var" in converted_model.jinja_macros.global_objs["vars"] # type: ignore diff --git a/tests/fixtures/dbt/sushi_test/dbt_project.yml b/tests/fixtures/dbt/sushi_test/dbt_project.yml index 073d85b4d4..c86057c928 100644 --- a/tests/fixtures/dbt/sushi_test/dbt_project.yml +++ b/tests/fixtures/dbt/sushi_test/dbt_project.yml @@ -47,6 +47,7 @@ vars: customers:boo: ["a", "b"] yet_another_var: 1 + dynamic_test_var: 3 customers: some_var: ["foo", "bar"] diff --git a/tests/fixtures/dbt/sushi_test/macros/test_dependencies.sql b/tests/fixtures/dbt/sushi_test/macros/test_dependencies.sql index 931ce88a84..88518df380 100644 --- a/tests/fixtures/dbt/sushi_test/macros/test_dependencies.sql +++ b/tests/fixtures/dbt/sushi_test/macros/test_dependencies.sql @@ -6,3 +6,9 @@ {{ log(var("yet_another_var", 2)) }} {{ log(var("nested_vars")['some_nested_var']) }} {% endmacro %} + + +{% macro dynamic_var_name_dependency(var_name) %} + {% set results = run_query('select 1 as one') %} + {{ return(var(var_name)) }} +{% endmacro %} diff --git a/tests/fixtures/dbt/sushi_test/models/waiter_revenue_by_day.sql b/tests/fixtures/dbt/sushi_test/models/waiter_revenue_by_day.sql index 335e7ab799..317cc87e68 100644 --- a/tests/fixtures/dbt/sushi_test/models/waiter_revenue_by_day.sql +++ b/tests/fixtures/dbt/sushi_test/models/waiter_revenue_by_day.sql @@ -13,7 +13,8 @@ {{ test_dependencies() }} -{% set results = run_query('select 1 as constant') %} +{% set var_name = "dynamic_" + "test_" + "var" %} +{% set results = run_query('select ' ~ dynamic_var_name_dependency(var_name) ~ ' as constant') %} SELECT o.waiter_id::INT AS waiter_id, /* Waiter id */