From ed36d5e77800a78f7e1e15f282c86022dbbd5c12 Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Wed, 20 Aug 2025 15:15:42 -0700 Subject: [PATCH 1/2] Fix: Improve tracking of var dependencies in dbt models --- sqlmesh/dbt/adapter.py | 50 +++++++++++++++ sqlmesh/dbt/basemodel.py | 56 +++++++++------- sqlmesh/dbt/builtin.py | 17 +++-- sqlmesh/dbt/common.py | 3 + sqlmesh/dbt/context.py | 36 ++++++++++- sqlmesh/dbt/manifest.py | 3 + sqlmesh/dbt/model.py | 27 +++++++- tests/dbt/test_config.py | 2 + tests/dbt/test_manifest.py | 2 + tests/dbt/test_transformation.py | 64 +++++++++++++++++++ tests/fixtures/dbt/sushi_test/dbt_project.yml | 1 + .../sushi_test/macros/test_dependencies.sql | 6 ++ .../models/waiter_revenue_by_day.sql | 3 +- 13 files changed, 238 insertions(+), 32 deletions(-) diff --git a/sqlmesh/dbt/adapter.py b/sqlmesh/dbt/adapter.py index 2dc9890ca4..83f8a368ad 100644 --- a/sqlmesh/dbt/adapter.py +++ b/sqlmesh/dbt/adapter.py @@ -245,6 +245,56 @@ def _raise_parsetime_adapter_call_error(action: str) -> None: raise ParsetimeAdapterCallError(f"Can't {action} at parse time.") +class StubParsetimeAdapter(BaseAdapter): + """Same as ParsetimeAdapter, but returns stub / empty values instead of raising an error.""" + + def get_relation(self, database: str, schema: str, identifier: str) -> t.Optional[BaseRelation]: + return None + + def load_relation(self, relation: BaseRelation) -> t.Optional[BaseRelation]: + return None + + def list_relations(self, database: t.Optional[str], schema: str) -> t.List[BaseRelation]: + return [] + + def list_relations_without_caching(self, schema_relation: BaseRelation) -> t.List[BaseRelation]: + return [] + + def get_columns_in_relation(self, relation: BaseRelation) -> t.List[Column]: + return [] + + def get_missing_columns( + self, from_relation: BaseRelation, to_relation: BaseRelation + ) -> t.List[Column]: + return [] + + def create_schema(self, relation: BaseRelation) -> None: + pass + + def drop_schema(self, relation: BaseRelation) -> None: + pass + + def drop_relation(self, relation: BaseRelation) -> None: + pass + + def rename_relation(self, from_relation: BaseRelation, to_relation: BaseRelation) -> None: + pass + + def execute( + self, sql: str, auto_begin: bool = False, fetch: bool = False + ) -> t.Tuple[AdapterResponse, agate.Table]: + from dbt.adapters.base.impl import AdapterResponse + from sqlmesh.dbt.util import empty_table + + return AdapterResponse(""), empty_table() + + def resolve_schema(self, relation: BaseRelation) -> t.Optional[str]: + return relation.schema + + def resolve_identifier(self, relation: BaseRelation) -> t.Optional[str]: + return relation.identifier + + class RuntimeAdapter(BaseAdapter): def __init__( self, diff --git a/sqlmesh/dbt/basemodel.py b/sqlmesh/dbt/basemodel.py index 73e0252332..9e23fba245 100644 --- a/sqlmesh/dbt/basemodel.py +++ b/sqlmesh/dbt/basemodel.py @@ -309,37 +309,24 @@ def sqlmesh_model_kwargs( self, context: DbtContext, column_types_override: t.Optional[t.Dict[str, ColumnConfig]] = None, + extra_dependencies: t.Optional[Dependencies] = None, ) -> t.Dict[str, t.Any]: """Get common sqlmesh model parameters""" self.remove_tests_with_invalid_refs(context) self.check_for_circular_test_refs(context) + + dependencies = self.dependencies + if extra_dependencies: + dependencies = dependencies.union(extra_dependencies) + model_dialect = self.dialect(context) model_context = context.context_for_dependencies( - self.dependencies.union(self.tests_ref_source_dependencies) + dependencies.union(self.tests_ref_source_dependencies) ) jinja_macros = model_context.jinja_macros.trim( - self.dependencies.macros, package=self.package_name - ) - - model_node: AttributeDict[str, t.Any] = AttributeDict( - { - k: v - for k, v in context._manifest._manifest.nodes[self.node_name].to_dict().items() - if k in self.dependencies.model_attrs - } - if context._manifest and self.node_name in context._manifest._manifest.nodes - else {} - ) - - jinja_macros.add_globals( - { - "this": self.relation_info, - "model": model_node, - "schema": self.table_schema, - "config": self.config_attribute_dict, - **model_context.jinja_globals, # type: ignore - } + dependencies.macros, package=self.package_name ) + jinja_macros.add_globals(self._model_jinja_context(model_context, dependencies)) return { "audits": [(test.name, {}) for test in self.tests], "columns": column_types_to_sqlmesh( @@ -369,3 +356,28 @@ def to_sqlmesh( virtual_environment_mode: VirtualEnvironmentMode = VirtualEnvironmentMode.default, ) -> Model: """Convert DBT model into sqlmesh Model""" + + def _model_jinja_context( + self, context: DbtContext, dependencies: Dependencies + ) -> t.Dict[str, t.Any]: + model_node: AttributeDict[str, t.Any] = AttributeDict( + { + k: v + for k, v in context._manifest._manifest.nodes[self.node_name].to_dict().items() + if k in dependencies.model_attrs + } + if context._manifest and self.node_name in context._manifest._manifest.nodes + else {} + ) + return { + "this": self.relation_info, + "model": model_node, + "schema": self.table_schema, + "config": self.config_attribute_dict, + **context.jinja_globals, + } + + def _track_dependencies_on_render(self, input: str, context: DbtContext) -> Dependencies: + return context.track_dependencies_on_render( + input, self._model_jinja_context(context, self.dependencies), self.package_name + ) diff --git a/sqlmesh/dbt/builtin.py b/sqlmesh/dbt/builtin.py index 70e1b10099..037d5a03ed 100644 --- a/sqlmesh/dbt/builtin.py +++ b/sqlmesh/dbt/builtin.py @@ -17,7 +17,7 @@ from sqlmesh.core.console import get_console from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.snapshot.definition import DeployabilityIndex -from sqlmesh.dbt.adapter import BaseAdapter, ParsetimeAdapter, RuntimeAdapter +from sqlmesh.dbt.adapter import BaseAdapter, ParsetimeAdapter, RuntimeAdapter, StubParsetimeAdapter from sqlmesh.dbt.relation import Policy from sqlmesh.dbt.target import TARGET_TYPE_TO_CONFIG_CLASS from sqlmesh.dbt.util import DBT_VERSION @@ -384,15 +384,15 @@ def create_builtin_globals( builtin_globals["this"] = this sources = jinja_globals.pop("sources", None) - if sources is not None: + if sources is not None and "source" not in jinja_globals: builtin_globals["source"] = generate_source(sources, api) refs = jinja_globals.pop("refs", None) - if refs is not None: + if refs is not None and "ref" not in jinja_globals: builtin_globals["ref"] = generate_ref(refs, api) variables = jinja_globals.pop("vars", None) - if variables is not None: + if variables is not None and "var" not in jinja_globals: builtin_globals["var"] = Var(variables) deployability_index = ( @@ -415,6 +415,7 @@ def create_builtin_globals( {k: builtin_globals.get(k) for k in ("ref", "source", "config", "var")} ) + execute = True if engine_adapter is not None: builtin_globals["flags"] = Flags(which="run") adapter: BaseAdapter = RuntimeAdapter( @@ -435,7 +436,11 @@ def create_builtin_globals( ) else: builtin_globals["flags"] = Flags(which="parse") - adapter = ParsetimeAdapter( + adapter_class: t.Type[BaseAdapter] = ParsetimeAdapter + if jinja_globals.get("use_stub_adapter", False): + adapter_class = StubParsetimeAdapter + execute = False + adapter = adapter_class( jinja_macros, jinja_globals={**builtin_globals, **jinja_globals}, project_dialect=project_dialect, @@ -446,7 +451,7 @@ def create_builtin_globals( builtin_globals.update( { "adapter": adapter, - "execute": True, + "execute": execute, "load_relation": adapter.load_relation, "store_result": sql_execution.store_result, "load_result": sql_execution.load_result, diff --git a/sqlmesh/dbt/common.py b/sqlmesh/dbt/common.py index d9db5a472c..ec928576ed 100644 --- a/sqlmesh/dbt/common.py +++ b/sqlmesh/dbt/common.py @@ -184,6 +184,8 @@ class Dependencies(PydanticModel): variables: t.Set[str] = set() model_attrs: t.Set[str] = set() + has_dynamic_var_names: bool = False + def union(self, other: Dependencies) -> Dependencies: return Dependencies( macros=list(set(self.macros) | set(other.macros)), @@ -191,6 +193,7 @@ def union(self, other: Dependencies) -> Dependencies: refs=self.refs | other.refs, variables=self.variables | other.variables, model_attrs=self.model_attrs | other.model_attrs, + has_dynamic_var_names=self.has_dynamic_var_names or other.has_dynamic_var_names, ) @field_validator("macros", mode="after") diff --git a/sqlmesh/dbt/context.py b/sqlmesh/dbt/context.py index 307dea6477..c3e965ebdb 100644 --- a/sqlmesh/dbt/context.py +++ b/sqlmesh/dbt/context.py @@ -7,7 +7,8 @@ from dbt.adapters.base import BaseRelation from sqlmesh.core.config import Config as SQLMeshConfig -from sqlmesh.dbt.builtin import _relation_info_to_relation +from sqlmesh.dbt.builtin import _relation_info_to_relation, Var +from sqlmesh.dbt.common import Dependencies from sqlmesh.dbt.manifest import ManifestHelper from sqlmesh.dbt.target import TargetConfig from sqlmesh.utils import AttributeDict @@ -22,7 +23,6 @@ if t.TYPE_CHECKING: from jinja2 import Environment - from sqlmesh.dbt.basemodel import Dependencies from sqlmesh.dbt.model import ModelConfig from sqlmesh.dbt.relation import Policy from sqlmesh.dbt.seed import SeedConfig @@ -212,6 +212,38 @@ def target(self, value: TargetConfig) -> None: def render(self, source: str, **kwargs: t.Any) -> str: return self.jinja_environment.from_string(source).render(**kwargs) + def track_dependencies_on_render( + self, input: str, jinja_context: t.Dict[str, t.Any], package_name: t.Optional[str] = None + ) -> Dependencies: + dependencies_on_render = Dependencies() + + class TrackingVar(Var): + def __call__( + self, name: str, default: t.Optional[t.Any] = None, **kwargs: t.Any + ) -> t.Any: + dependencies_on_render.variables.add(name) + return super().__call__(name, default, **kwargs) + + def has_var(self, name: str) -> bool: + dependencies_on_render.variables.add(name) + return super().has_var(name) + + if package_name: + top_level_packages = [*self.jinja_macros.top_level_packages, package_name] + jinja_macros = self.jinja_macros.copy(update={"top_level_packages": top_level_packages}) + else: + jinja_macros = self.jinja_macros + + jinja_environment = jinja_macros.build_environment( + **{ + **jinja_context, + "var": TrackingVar(self.variables), + "use_stub_adapter": True, + } + ) + jinja_environment.from_string(input).render() + return dependencies_on_render + def get_callable_macro( self, name: str, package: t.Optional[str] = None ) -> t.Optional[t.Callable]: diff --git a/sqlmesh/dbt/manifest.py b/sqlmesh/dbt/manifest.py index 125b204270..7414325902 100644 --- a/sqlmesh/dbt/manifest.py +++ b/sqlmesh/dbt/manifest.py @@ -554,6 +554,9 @@ def _extra_dependencies(self, target: str, package: str) -> Dependencies: args = [jinja_call_arg_name(arg) for arg in node.args] if args and args[0]: dependencies.variables.add(args[0]) + else: + # We couldn't determine the var name statically + dependencies.has_dynamic_var_names = True dependencies.macros.append(MacroReference(name="var")) elif len(call_name) == 1: macro_name = call_name[0] diff --git a/sqlmesh/dbt/model.py b/sqlmesh/dbt/model.py index c646392368..51bd32f2fb 100644 --- a/sqlmesh/dbt/model.py +++ b/sqlmesh/dbt/model.py @@ -25,7 +25,8 @@ ) from sqlmesh.core.model.kind import SCDType2ByTimeKind, OnDestructiveChange, OnAdditiveChange from sqlmesh.dbt.basemodel import BaseModelConfig, Materialization, SnapshotStrategy -from sqlmesh.dbt.common import SqlStr, extract_jinja_config, sql_str_validator +from sqlmesh.dbt.column import ColumnConfig +from sqlmesh.dbt.common import SqlStr, extract_jinja_config, sql_str_validator, Dependencies from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.pydantic import field_validator @@ -436,6 +437,30 @@ def sqlmesh_config_fields(self) -> t.Set[str]: "physical_version", } + def sqlmesh_model_kwargs( + self, + context: DbtContext, + column_types_override: t.Optional[t.Dict[str, ColumnConfig]] = None, + extra_dependencies: t.Optional[Dependencies] = None, + ) -> t.Dict[str, t.Any]: + if not self.dependencies.has_dynamic_var_names: + return super().sqlmesh_model_kwargs(context, column_types_override, extra_dependencies) + + extra_dependencies = extra_dependencies or Dependencies() + extra_dependencies = extra_dependencies.union( + self._track_dependencies_on_render(self.sql_no_config, context) + ) + for pre_hook in self.pre_hook: + extra_dependencies = extra_dependencies.union( + self._track_dependencies_on_render(pre_hook.sql, context) + ) + for post_hook in self.post_hook: + extra_dependencies = extra_dependencies.union( + self._track_dependencies_on_render(post_hook.sql, context) + ) + + return super().sqlmesh_model_kwargs(context, column_types_override, extra_dependencies) + def to_sqlmesh( self, context: DbtContext, diff --git a/tests/dbt/test_config.py b/tests/dbt/test_config.py index bc6f878801..becbfc7ac2 100644 --- a/tests/dbt/test_config.py +++ b/tests/dbt/test_config.py @@ -362,6 +362,7 @@ def test_variables(assert_exp_eq, sushi_test_project): "nested_vars": { "some_nested_var": 2, }, + "dynamic_test_var": 3, "list_var": [ {"name": "item1", "value": 1}, {"name": "item2", "value": 2}, @@ -385,6 +386,7 @@ def test_variables(assert_exp_eq, sushi_test_project): "nested_vars": { "some_nested_var": 2, }, + "dynamic_test_var": 3, "list_var": [ {"name": "item1", "value": 1}, {"name": "item2", "value": 2}, diff --git a/tests/dbt/test_manifest.py b/tests/dbt/test_manifest.py index bf64e4b8b3..2bed6acb55 100644 --- a/tests/dbt/test_manifest.py +++ b/tests/dbt/test_manifest.py @@ -79,6 +79,7 @@ def test_manifest_helper(caplog): waiter_revenue_by_day_config = models["waiter_revenue_by_day_v2"] assert waiter_revenue_by_day_config.dependencies == Dependencies( macros={ + MacroReference(name="dynamic_var_name_dependency"), MacroReference(name="log_value"), MacroReference(name="test_dependencies"), MacroReference(package="customers", name="duckdb__current_engine"), @@ -87,6 +88,7 @@ def test_manifest_helper(caplog): }, sources={"streaming.items", "streaming.orders", "streaming.order_items"}, variables={"yet_another_var", "nested_vars"}, + has_dynamic_var_names=True, ) assert waiter_revenue_by_day_config.materialized == "incremental" assert waiter_revenue_by_day_config.incremental_strategy == "delete+insert" diff --git a/tests/dbt/test_transformation.py b/tests/dbt/test_transformation.py index fdb8345398..31b464f859 100644 --- a/tests/dbt/test_transformation.py +++ b/tests/dbt/test_transformation.py @@ -37,6 +37,7 @@ ) from sqlmesh.core.state_sync.db.snapshot import _snapshot_to_json from sqlmesh.dbt.builtin import _relation_info_to_relation +from sqlmesh.dbt.common import Dependencies from sqlmesh.dbt.column import ( ColumnConfig, column_descriptions_to_sqlmesh, @@ -50,6 +51,7 @@ from sqlmesh.dbt.target import BigQueryConfig, DuckDbConfig, SnowflakeConfig, ClickhouseConfig from sqlmesh.dbt.test import TestConfig from sqlmesh.utils.errors import ConfigError, MacroEvalError, SQLMeshError +from sqlmesh.utils.jinja import MacroReference pytestmark = [pytest.mark.dbt, pytest.mark.slow] @@ -1854,3 +1856,65 @@ def test_on_run_start_end(): "CREATE OR REPLACE TABLE schema_table_sushi__dev_nested_package AS SELECT 'sushi__dev' AS schema", ] ) + + +@pytest.mark.xdist_group("dbt_manifest") +def test_track_dynamic_var_names_on_render( + sushi_test_project: Project, sushi_test_dbt_context: Context +): + context = sushi_test_project.context + context.target = BigQueryConfig(name="production", database="main", schema="sushi") + model_config = ModelConfig( + name="model", + alias="model", + schema="test", + package_name="package", + materialized="table", + unique_key="ds", + partition_by={"field": "ds", "granularity": "month"}, + sql=""" + {% set var_name = "yet_" + "another_" + "var" %} + {% set results = run_query('select 1 as one') %} + {% if results %} + SELECT {{ results.columns[0].values()[0] }} AS one {{ var(var_name) }} AS var FROM {{ this.identifier }} + {% else %} + SELECT NULL AS one {{ var(var_name) }} AS var FROM {{ this.identifier }} + {% endif %} + """, + dependencies=Dependencies(has_dynamic_var_names=True), + ) + converted_model = model_config.to_sqlmesh(context) + assert converted_model.jinja_macros.global_objs["vars"] == {"yet_another_var": 1} + + # Test the existing model in the sushi project + assert ( + "dynamic_test_var" # type: ignore + in sushi_test_dbt_context.get_model( + "sushi.waiter_revenue_by_day_v2" + ).jinja_macros.global_objs["vars"] + ) + + +@pytest.mark.xdist_group("dbt_manifest") +def test_track_dynamic_var_names_on_render_in_macro(sushi_test_project: Project): + context = sushi_test_project.context + context.target = BigQueryConfig(name="production", database="main", schema="sushi") + model_config = ModelConfig( + name="model", + alias="model", + schema="test", + package_name="package", + materialized="table", + unique_key="ds", + partition_by={"field": "ds", "granularity": "month"}, + sql=""" + {% set var_name = "yet_" + "another_" + "var" %} + SELECT {{ sushi.dynamic_var_name_dependency(var_name) }} AS var + """, + dependencies=Dependencies( + macros=[MacroReference(package="sushi", name="dynamic_var_name_dependency")], + has_dynamic_var_names=True, + ), + ) + converted_model = model_config.to_sqlmesh(context) + assert converted_model.jinja_macros.global_objs["vars"] == {"yet_another_var": 1} diff --git a/tests/fixtures/dbt/sushi_test/dbt_project.yml b/tests/fixtures/dbt/sushi_test/dbt_project.yml index 073d85b4d4..c86057c928 100644 --- a/tests/fixtures/dbt/sushi_test/dbt_project.yml +++ b/tests/fixtures/dbt/sushi_test/dbt_project.yml @@ -47,6 +47,7 @@ vars: customers:boo: ["a", "b"] yet_another_var: 1 + dynamic_test_var: 3 customers: some_var: ["foo", "bar"] diff --git a/tests/fixtures/dbt/sushi_test/macros/test_dependencies.sql b/tests/fixtures/dbt/sushi_test/macros/test_dependencies.sql index 931ce88a84..88518df380 100644 --- a/tests/fixtures/dbt/sushi_test/macros/test_dependencies.sql +++ b/tests/fixtures/dbt/sushi_test/macros/test_dependencies.sql @@ -6,3 +6,9 @@ {{ log(var("yet_another_var", 2)) }} {{ log(var("nested_vars")['some_nested_var']) }} {% endmacro %} + + +{% macro dynamic_var_name_dependency(var_name) %} + {% set results = run_query('select 1 as one') %} + {{ return(var(var_name)) }} +{% endmacro %} diff --git a/tests/fixtures/dbt/sushi_test/models/waiter_revenue_by_day.sql b/tests/fixtures/dbt/sushi_test/models/waiter_revenue_by_day.sql index 335e7ab799..317cc87e68 100644 --- a/tests/fixtures/dbt/sushi_test/models/waiter_revenue_by_day.sql +++ b/tests/fixtures/dbt/sushi_test/models/waiter_revenue_by_day.sql @@ -13,7 +13,8 @@ {{ test_dependencies() }} -{% set results = run_query('select 1 as constant') %} +{% set var_name = "dynamic_" + "test_" + "var" %} +{% set results = run_query('select ' ~ dynamic_var_name_dependency(var_name) ~ ' as constant') %} SELECT o.waiter_id::INT AS waiter_id, /* Waiter id */ From af8f31b2a13ffc2127c6ea91c4f0837b2df5488a Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Fri, 22 Aug 2025 10:26:38 -0700 Subject: [PATCH 2/2] use all vars when dynamic vars are detected --- sqlmesh/dbt/adapter.py | 50 -------------------------------- sqlmesh/dbt/basemodel.py | 14 ++++----- sqlmesh/dbt/builtin.py | 17 ++++------- sqlmesh/dbt/context.py | 36 +---------------------- sqlmesh/dbt/loader.py | 20 ++++++------- sqlmesh/dbt/model.py | 27 +---------------- sqlmesh/dbt/project.py | 12 ++++---- tests/dbt/test_config.py | 22 +++----------- tests/dbt/test_transformation.py | 17 ++++++----- 9 files changed, 43 insertions(+), 172 deletions(-) diff --git a/sqlmesh/dbt/adapter.py b/sqlmesh/dbt/adapter.py index 83f8a368ad..2dc9890ca4 100644 --- a/sqlmesh/dbt/adapter.py +++ b/sqlmesh/dbt/adapter.py @@ -245,56 +245,6 @@ def _raise_parsetime_adapter_call_error(action: str) -> None: raise ParsetimeAdapterCallError(f"Can't {action} at parse time.") -class StubParsetimeAdapter(BaseAdapter): - """Same as ParsetimeAdapter, but returns stub / empty values instead of raising an error.""" - - def get_relation(self, database: str, schema: str, identifier: str) -> t.Optional[BaseRelation]: - return None - - def load_relation(self, relation: BaseRelation) -> t.Optional[BaseRelation]: - return None - - def list_relations(self, database: t.Optional[str], schema: str) -> t.List[BaseRelation]: - return [] - - def list_relations_without_caching(self, schema_relation: BaseRelation) -> t.List[BaseRelation]: - return [] - - def get_columns_in_relation(self, relation: BaseRelation) -> t.List[Column]: - return [] - - def get_missing_columns( - self, from_relation: BaseRelation, to_relation: BaseRelation - ) -> t.List[Column]: - return [] - - def create_schema(self, relation: BaseRelation) -> None: - pass - - def drop_schema(self, relation: BaseRelation) -> None: - pass - - def drop_relation(self, relation: BaseRelation) -> None: - pass - - def rename_relation(self, from_relation: BaseRelation, to_relation: BaseRelation) -> None: - pass - - def execute( - self, sql: str, auto_begin: bool = False, fetch: bool = False - ) -> t.Tuple[AdapterResponse, agate.Table]: - from dbt.adapters.base.impl import AdapterResponse - from sqlmesh.dbt.util import empty_table - - return AdapterResponse(""), empty_table() - - def resolve_schema(self, relation: BaseRelation) -> t.Optional[str]: - return relation.schema - - def resolve_identifier(self, relation: BaseRelation) -> t.Optional[str]: - return relation.identifier - - class RuntimeAdapter(BaseAdapter): def __init__( self, diff --git a/sqlmesh/dbt/basemodel.py b/sqlmesh/dbt/basemodel.py index 9e23fba245..f1e1dbed03 100644 --- a/sqlmesh/dbt/basemodel.py +++ b/sqlmesh/dbt/basemodel.py @@ -309,15 +309,16 @@ def sqlmesh_model_kwargs( self, context: DbtContext, column_types_override: t.Optional[t.Dict[str, ColumnConfig]] = None, - extra_dependencies: t.Optional[Dependencies] = None, ) -> t.Dict[str, t.Any]: """Get common sqlmesh model parameters""" self.remove_tests_with_invalid_refs(context) self.check_for_circular_test_refs(context) - dependencies = self.dependencies - if extra_dependencies: - dependencies = dependencies.union(extra_dependencies) + dependencies = self.dependencies.copy() + if dependencies.has_dynamic_var_names: + # Include ALL variables as dependencies since we couldn't determine + # precisely which variables are referenced in the model + dependencies.variables |= set(context.variables) model_dialect = self.dialect(context) model_context = context.context_for_dependencies( @@ -376,8 +377,3 @@ def _model_jinja_context( "config": self.config_attribute_dict, **context.jinja_globals, } - - def _track_dependencies_on_render(self, input: str, context: DbtContext) -> Dependencies: - return context.track_dependencies_on_render( - input, self._model_jinja_context(context, self.dependencies), self.package_name - ) diff --git a/sqlmesh/dbt/builtin.py b/sqlmesh/dbt/builtin.py index 037d5a03ed..70e1b10099 100644 --- a/sqlmesh/dbt/builtin.py +++ b/sqlmesh/dbt/builtin.py @@ -17,7 +17,7 @@ from sqlmesh.core.console import get_console from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.snapshot.definition import DeployabilityIndex -from sqlmesh.dbt.adapter import BaseAdapter, ParsetimeAdapter, RuntimeAdapter, StubParsetimeAdapter +from sqlmesh.dbt.adapter import BaseAdapter, ParsetimeAdapter, RuntimeAdapter from sqlmesh.dbt.relation import Policy from sqlmesh.dbt.target import TARGET_TYPE_TO_CONFIG_CLASS from sqlmesh.dbt.util import DBT_VERSION @@ -384,15 +384,15 @@ def create_builtin_globals( builtin_globals["this"] = this sources = jinja_globals.pop("sources", None) - if sources is not None and "source" not in jinja_globals: + if sources is not None: builtin_globals["source"] = generate_source(sources, api) refs = jinja_globals.pop("refs", None) - if refs is not None and "ref" not in jinja_globals: + if refs is not None: builtin_globals["ref"] = generate_ref(refs, api) variables = jinja_globals.pop("vars", None) - if variables is not None and "var" not in jinja_globals: + if variables is not None: builtin_globals["var"] = Var(variables) deployability_index = ( @@ -415,7 +415,6 @@ def create_builtin_globals( {k: builtin_globals.get(k) for k in ("ref", "source", "config", "var")} ) - execute = True if engine_adapter is not None: builtin_globals["flags"] = Flags(which="run") adapter: BaseAdapter = RuntimeAdapter( @@ -436,11 +435,7 @@ def create_builtin_globals( ) else: builtin_globals["flags"] = Flags(which="parse") - adapter_class: t.Type[BaseAdapter] = ParsetimeAdapter - if jinja_globals.get("use_stub_adapter", False): - adapter_class = StubParsetimeAdapter - execute = False - adapter = adapter_class( + adapter = ParsetimeAdapter( jinja_macros, jinja_globals={**builtin_globals, **jinja_globals}, project_dialect=project_dialect, @@ -451,7 +446,7 @@ def create_builtin_globals( builtin_globals.update( { "adapter": adapter, - "execute": execute, + "execute": True, "load_relation": adapter.load_relation, "store_result": sql_execution.store_result, "load_result": sql_execution.load_result, diff --git a/sqlmesh/dbt/context.py b/sqlmesh/dbt/context.py index c3e965ebdb..d76cccbce7 100644 --- a/sqlmesh/dbt/context.py +++ b/sqlmesh/dbt/context.py @@ -7,7 +7,7 @@ from dbt.adapters.base import BaseRelation from sqlmesh.core.config import Config as SQLMeshConfig -from sqlmesh.dbt.builtin import _relation_info_to_relation, Var +from sqlmesh.dbt.builtin import _relation_info_to_relation from sqlmesh.dbt.common import Dependencies from sqlmesh.dbt.manifest import ManifestHelper from sqlmesh.dbt.target import TargetConfig @@ -101,8 +101,6 @@ def add_variables(self, variables: t.Dict[str, t.Any]) -> None: self._jinja_environment = None def set_and_render_variables(self, variables: t.Dict[str, t.Any], package: str) -> None: - self.variables = variables - jinja_environment = self.jinja_macros.build_environment(**self.jinja_globals) def _render_var(value: t.Any) -> t.Any: @@ -212,38 +210,6 @@ def target(self, value: TargetConfig) -> None: def render(self, source: str, **kwargs: t.Any) -> str: return self.jinja_environment.from_string(source).render(**kwargs) - def track_dependencies_on_render( - self, input: str, jinja_context: t.Dict[str, t.Any], package_name: t.Optional[str] = None - ) -> Dependencies: - dependencies_on_render = Dependencies() - - class TrackingVar(Var): - def __call__( - self, name: str, default: t.Optional[t.Any] = None, **kwargs: t.Any - ) -> t.Any: - dependencies_on_render.variables.add(name) - return super().__call__(name, default, **kwargs) - - def has_var(self, name: str) -> bool: - dependencies_on_render.variables.add(name) - return super().has_var(name) - - if package_name: - top_level_packages = [*self.jinja_macros.top_level_packages, package_name] - jinja_macros = self.jinja_macros.copy(update={"top_level_packages": top_level_packages}) - else: - jinja_macros = self.jinja_macros - - jinja_environment = jinja_macros.build_environment( - **{ - **jinja_context, - "var": TrackingVar(self.variables), - "use_stub_adapter": True, - } - ) - jinja_environment.from_string(input).render() - return dependencies_on_render - def get_callable_macro( self, name: str, package: t.Optional[str] = None ) -> t.Optional[t.Callable]: diff --git a/sqlmesh/dbt/loader.py b/sqlmesh/dbt/loader.py index be0ff59aa4..d321246896 100644 --- a/sqlmesh/dbt/loader.py +++ b/sqlmesh/dbt/loader.py @@ -124,8 +124,6 @@ def _to_sqlmesh(config: BMC, context: DbtContext) -> Model: ) for project in self._load_projects(): - context = project.context.copy() - macros_max_mtime = self._macros_max_mtime yaml_max_mtimes = self._compute_yaml_max_mtime_per_subfolder( project.context.project_root @@ -135,12 +133,13 @@ def _to_sqlmesh(config: BMC, context: DbtContext) -> Model: logger.debug("Converting models to sqlmesh") # Now that config is rendered, create the sqlmesh models for package in project.packages.values(): - context.set_and_render_variables(package.variables, package.name) + package_context = project.context.copy() + package_context.set_and_render_variables(package.variables, package.name) package_models: t.Dict[str, BaseModelConfig] = {**package.models, **package.seeds} for model in package_models.values(): sqlmesh_model = cache.get_or_load_models( - model.path, loader=lambda: [_to_sqlmesh(model, context)] + model.path, loader=lambda: [_to_sqlmesh(model, package_context)] )[0] models[sqlmesh_model.fqn] = sqlmesh_model @@ -155,15 +154,14 @@ def _load_audits( audits: UniqueKeyDict = UniqueKeyDict("audits") for project in self._load_projects(): - context = project.context - logger.debug("Converting audits to sqlmesh") for package in project.packages.values(): - context.set_and_render_variables(package.variables, package.name) + package_context = project.context.copy() + package_context.set_and_render_variables(package.variables, package.name) for test in package.tests.values(): logger.debug("Converting '%s' to sqlmesh format", test.name) try: - audits[test.name] = test.to_sqlmesh(context) + audits[test.name] = test.to_sqlmesh(package_context) except MissingModelError as e: logger.warning( "Skipping audit '%s' because model '%s' is not a valid ref", @@ -244,9 +242,9 @@ def _load_environment_statements(self, macros: MacroRegistry) -> t.List[Environm project_names: t.Set[str] = set() dialect = self.config.dialect for project in self._load_projects(): - context = project.context for package_name, package in project.packages.items(): - context.set_and_render_variables(package.variables, package_name) + package_context = project.context.copy() + package_context.set_and_render_variables(package.variables, package_name) on_run_start: t.List[str] = [ on_run_hook.sql for on_run_hook in sorted(package.on_run_start.values(), key=lambda h: h.index) @@ -261,7 +259,7 @@ def _load_environment_statements(self, macros: MacroRegistry) -> t.List[Environm for hook in [*package.on_run_start.values(), *package.on_run_end.values()]: dependencies = dependencies.union(hook.dependencies) - statements_context = context.context_for_dependencies(dependencies) + statements_context = package_context.context_for_dependencies(dependencies) jinja_registry = make_jinja_registry( statements_context.jinja_macros, package_name, set(dependencies.macros) ) diff --git a/sqlmesh/dbt/model.py b/sqlmesh/dbt/model.py index 51bd32f2fb..c646392368 100644 --- a/sqlmesh/dbt/model.py +++ b/sqlmesh/dbt/model.py @@ -25,8 +25,7 @@ ) from sqlmesh.core.model.kind import SCDType2ByTimeKind, OnDestructiveChange, OnAdditiveChange from sqlmesh.dbt.basemodel import BaseModelConfig, Materialization, SnapshotStrategy -from sqlmesh.dbt.column import ColumnConfig -from sqlmesh.dbt.common import SqlStr, extract_jinja_config, sql_str_validator, Dependencies +from sqlmesh.dbt.common import SqlStr, extract_jinja_config, sql_str_validator from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.pydantic import field_validator @@ -437,30 +436,6 @@ def sqlmesh_config_fields(self) -> t.Set[str]: "physical_version", } - def sqlmesh_model_kwargs( - self, - context: DbtContext, - column_types_override: t.Optional[t.Dict[str, ColumnConfig]] = None, - extra_dependencies: t.Optional[Dependencies] = None, - ) -> t.Dict[str, t.Any]: - if not self.dependencies.has_dynamic_var_names: - return super().sqlmesh_model_kwargs(context, column_types_override, extra_dependencies) - - extra_dependencies = extra_dependencies or Dependencies() - extra_dependencies = extra_dependencies.union( - self._track_dependencies_on_render(self.sql_no_config, context) - ) - for pre_hook in self.pre_hook: - extra_dependencies = extra_dependencies.union( - self._track_dependencies_on_render(pre_hook.sql, context) - ) - for post_hook in self.post_hook: - extra_dependencies = extra_dependencies.union( - self._track_dependencies_on_render(post_hook.sql, context) - ) - - return super().sqlmesh_model_kwargs(context, column_types_override, extra_dependencies) - def to_sqlmesh( self, context: DbtContext, diff --git a/sqlmesh/dbt/project.py b/sqlmesh/dbt/project.py index d37c9cc6c4..581660943a 100644 --- a/sqlmesh/dbt/project.py +++ b/sqlmesh/dbt/project.py @@ -55,9 +55,6 @@ def load(cls, context: DbtContext, variables: t.Optional[t.Dict[str, t.Any]] = N raise ConfigError(f"Could not find {PROJECT_FILENAME} in {context.project_root}") project_yaml = load_yaml(project_file_path) - variable_overrides = variables - variables = {**project_yaml.get("vars", {}), **(variables or {})} - project_name = context.render(project_yaml.get("name", "")) context.project_name = project_name if not context.project_name: @@ -69,6 +66,7 @@ def load(cls, context: DbtContext, variables: t.Optional[t.Dict[str, t.Any]] = N profile = Profile.load(context, context.target_name) context.target = profile.target + variable_overrides = variables or {} context.manifest = ManifestHelper( project_file_path.parent, profile.path.parent, @@ -101,13 +99,17 @@ def load(cls, context: DbtContext, variables: t.Optional[t.Dict[str, t.Any]] = N package = package_loader.load(path.parent) packages[package.name] = package + all_project_variables = {**project_yaml.get("vars", {}), **(variable_overrides or {})} for name, package in packages.items(): - package_vars = variables.get(name) + package_vars = all_project_variables.get(name) if isinstance(package_vars, dict): package.variables.update(package_vars) - package.variables.update(variables) + if name == context.project_name: + package.variables.update(all_project_variables) + else: + package.variables.update(variable_overrides) return Project(context, profile, packages) diff --git a/tests/dbt/test_config.py b/tests/dbt/test_config.py index becbfc7ac2..99426ebb97 100644 --- a/tests/dbt/test_config.py +++ b/tests/dbt/test_config.py @@ -376,26 +376,10 @@ def test_variables(assert_exp_eq, sushi_test_project): expected_customer_variables = { "some_var": ["foo", "bar"], "some_other_var": 5, - "yet_another_var": 1, + "yet_another_var": 5, "customers:bla": False, "customers:customer_id": "customer_id", "start": "Jan 1 2022", - "top_waiters:limit": 10, - "top_waiters:revenue": "revenue", - "customers:boo": ["a", "b"], - "nested_vars": { - "some_nested_var": 2, - }, - "dynamic_test_var": 3, - "list_var": [ - {"name": "item1", "value": 1}, - {"name": "item2", "value": 2}, - ], - "customers": { - "customers:bla": False, - "customers:customer_id": "customer_id", - "some_var": ["foo", "bar"], - }, } assert sushi_test_project.packages["sushi"].variables == expected_sushi_variables @@ -408,7 +392,9 @@ def test_nested_variables(sushi_test_project): sql="SELECT {{ var('nested_vars')['some_nested_var'] }}", dependencies=Dependencies(variables=["nested_vars"]), ) - sqlmesh_model = model_config.to_sqlmesh(sushi_test_project.context) + context = sushi_test_project.context.copy() + context.set_and_render_variables(sushi_test_project.packages["sushi"].variables, "sushi") + sqlmesh_model = model_config.to_sqlmesh(context) assert sqlmesh_model.jinja_macros.global_objs["vars"]["nested_vars"] == {"some_nested_var": 2} diff --git a/tests/dbt/test_transformation.py b/tests/dbt/test_transformation.py index 31b464f859..1bcc3081f7 100644 --- a/tests/dbt/test_transformation.py +++ b/tests/dbt/test_transformation.py @@ -1532,6 +1532,9 @@ def test_dbt_package_macros(sushi_test_project: Project): @pytest.mark.xdist_group("dbt_manifest") def test_dbt_vars(sushi_test_project: Project): context = sushi_test_project.context + context.set_and_render_variables( + sushi_test_project.packages["customers"].variables, "customers" + ) assert context.render("{{ var('some_other_var') }}") == "5" assert context.render("{{ var('some_other_var', 0) }}") == "5" @@ -1859,10 +1862,9 @@ def test_on_run_start_end(): @pytest.mark.xdist_group("dbt_manifest") -def test_track_dynamic_var_names_on_render( - sushi_test_project: Project, sushi_test_dbt_context: Context -): +def test_dynamic_var_names(sushi_test_project: Project, sushi_test_dbt_context: Context): context = sushi_test_project.context + context.set_and_render_variables(sushi_test_project.packages["sushi"].variables, "sushi") context.target = BigQueryConfig(name="production", database="main", schema="sushi") model_config = ModelConfig( name="model", @@ -1884,7 +1886,7 @@ def test_track_dynamic_var_names_on_render( dependencies=Dependencies(has_dynamic_var_names=True), ) converted_model = model_config.to_sqlmesh(context) - assert converted_model.jinja_macros.global_objs["vars"] == {"yet_another_var": 1} + assert "yet_another_var" in converted_model.jinja_macros.global_objs["vars"] # type: ignore # Test the existing model in the sushi project assert ( @@ -1896,8 +1898,9 @@ def test_track_dynamic_var_names_on_render( @pytest.mark.xdist_group("dbt_manifest") -def test_track_dynamic_var_names_on_render_in_macro(sushi_test_project: Project): +def test_dynamic_var_names_in_macro(sushi_test_project: Project): context = sushi_test_project.context + context.set_and_render_variables(sushi_test_project.packages["sushi"].variables, "sushi") context.target = BigQueryConfig(name="production", database="main", schema="sushi") model_config = ModelConfig( name="model", @@ -1908,7 +1911,7 @@ def test_track_dynamic_var_names_on_render_in_macro(sushi_test_project: Project) unique_key="ds", partition_by={"field": "ds", "granularity": "month"}, sql=""" - {% set var_name = "yet_" + "another_" + "var" %} + {% set var_name = "dynamic_" + "test_" + "var" %} SELECT {{ sushi.dynamic_var_name_dependency(var_name) }} AS var """, dependencies=Dependencies( @@ -1917,4 +1920,4 @@ def test_track_dynamic_var_names_on_render_in_macro(sushi_test_project: Project) ), ) converted_model = model_config.to_sqlmesh(context) - assert converted_model.jinja_macros.global_objs["vars"] == {"yet_another_var": 1} + assert "dynamic_test_var" in converted_model.jinja_macros.global_objs["vars"] # type: ignore