diff --git a/sqlmesh/dbt/project.py b/sqlmesh/dbt/project.py index 4af30958f5..355b18630e 100644 --- a/sqlmesh/dbt/project.py +++ b/sqlmesh/dbt/project.py @@ -99,17 +99,20 @@ def load(cls, context: DbtContext, variables: t.Optional[t.Dict[str, t.Any]] = N package = package_loader.load(path.parent) packages[package.name] = package + # Variable resolution precedence: + # 1. Variable overrides + # 2. Package-scoped variables in the root project's dbt_project.yml + # 3. Global project variables in the root project's dbt_project.yml + # 4. Variables in the package's dbt_project.yml all_project_variables = {**(project_yaml.get("vars") or {}), **(variable_overrides or {})} for name, package in packages.items(): - package_vars = all_project_variables.get(name) - - if isinstance(package_vars, dict): - package.variables.update(package_vars) - - if name == context.project_name: - package.variables.update(all_project_variables) + if isinstance(all_project_variables.get(name), dict): + project_vars_copy = all_project_variables.copy() + package_scoped_vars = project_vars_copy.pop(name) + package.variables.update(project_vars_copy) + package.variables.update(package_scoped_vars) else: - package.variables.update(variable_overrides) + package.variables.update(all_project_variables) return Project(context, profile, packages) diff --git a/tests/dbt/test_config.py b/tests/dbt/test_config.py index 4e3e78eea9..30dae478a1 100644 --- a/tests/dbt/test_config.py +++ b/tests/dbt/test_config.py @@ -343,15 +343,26 @@ def test_variables(assert_exp_eq, sushi_test_project): "customers:customer_id": "customer_id", "some_var": ["foo", "bar"], }, + "some_var": "should be overridden in customers package", } expected_customer_variables = { - "some_var": ["foo", "bar"], + "some_var": ["foo", "bar"], # Takes precedence over the root project variable "some_other_var": 5, - "yet_another_var": 5, "customers:bla": False, "customers:customer_id": "customer_id", + "yet_another_var": 1, # Make sure that the project variable takes precedence + "top_waiters:limit": "{{ get_top_waiters_limit() }}", + "top_waiters:revenue": "revenue", + "customers:boo": ["a", "b"], + "nested_vars": { + "some_nested_var": 2, + }, + "dynamic_test_var": 3, + "list_var": [ + {"name": "item1", "value": 1}, + {"name": "item2", "value": 2}, + ], } - assert sushi_test_project.packages["sushi"].variables == expected_sushi_variables assert sushi_test_project.packages["customers"].variables == expected_customer_variables diff --git a/tests/fixtures/dbt/sushi_test/dbt_project.yml b/tests/fixtures/dbt/sushi_test/dbt_project.yml index 2a25389e43..920dea7216 100644 --- a/tests/fixtures/dbt/sushi_test/dbt_project.yml +++ b/tests/fixtures/dbt/sushi_test/dbt_project.yml @@ -50,6 +50,7 @@ vars: yet_another_var: 1 dynamic_test_var: 3 + some_var: 'should be overridden in customers package' customers: some_var: ["foo", "bar"] @@ -74,4 +75,4 @@ on-run-start: on-run-end: - '{{ create_tables(schemas) }}' - 'DROP TABLE to_be_executed_last;' - - '{{ graph_usage() }}' \ No newline at end of file + - '{{ graph_usage() }}'