Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 30 additions & 22 deletions sqlmesh/dbt/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,33 +313,21 @@ def sqlmesh_model_kwargs(
"""Get common sqlmesh model parameters"""
self.remove_tests_with_invalid_refs(context)
self.check_for_circular_test_refs(context)

dependencies = self.dependencies.copy()
if dependencies.has_dynamic_var_names:
# Include ALL variables as dependencies since we couldn't determine
# precisely which variables are referenced in the model
dependencies.variables |= set(context.variables)

model_dialect = self.dialect(context)
model_context = context.context_for_dependencies(
self.dependencies.union(self.tests_ref_source_dependencies)
dependencies.union(self.tests_ref_source_dependencies)
)
jinja_macros = model_context.jinja_macros.trim(
self.dependencies.macros, package=self.package_name
)

model_node: AttributeDict[str, t.Any] = AttributeDict(
{
k: v
for k, v in context._manifest._manifest.nodes[self.node_name].to_dict().items()
if k in self.dependencies.model_attrs
}
if context._manifest and self.node_name in context._manifest._manifest.nodes
else {}
)

jinja_macros.add_globals(
{
"this": self.relation_info,
"model": model_node,
"schema": self.table_schema,
"config": self.config_attribute_dict,
**model_context.jinja_globals, # type: ignore
}
dependencies.macros, package=self.package_name
)
jinja_macros.add_globals(self._model_jinja_context(model_context, dependencies))
return {
"audits": [(test.name, {}) for test in self.tests],
"columns": column_types_to_sqlmesh(
Expand Down Expand Up @@ -369,3 +357,23 @@ def to_sqlmesh(
virtual_environment_mode: VirtualEnvironmentMode = VirtualEnvironmentMode.default,
) -> Model:
"""Convert DBT model into sqlmesh Model"""

def _model_jinja_context(
    self, context: DbtContext, dependencies: Dependencies
) -> t.Dict[str, t.Any]:
    """Assemble the model-specific globals exposed to Jinja.

    The returned mapping holds ``this``, ``model``, ``schema`` and ``config``
    for this model, followed by everything in ``context.jinja_globals``.
    Because the context globals are spread last, they win on any key clash.

    Args:
        context: The dbt context supplying the manifest and the jinja globals.
        dependencies: The dependencies whose ``model_attrs`` select which
            manifest node attributes are exposed through ``model``.

    Returns:
        A dict of Jinja global names to their values.
    """
    manifest = context._manifest

    # `model` stays empty when there is no manifest or the node is unknown to it.
    selected_attrs: t.Dict[str, t.Any] = {}
    if manifest and self.node_name in manifest._manifest.nodes:
        node_dict = manifest._manifest.nodes[self.node_name].to_dict()
        # Only expose the node attributes the dependencies actually name.
        selected_attrs = {
            attr: value for attr, value in node_dict.items() if attr in dependencies.model_attrs
        }
    model_node: AttributeDict[str, t.Any] = AttributeDict(selected_attrs)

    return {
        "this": self.relation_info,
        "model": model_node,
        "schema": self.table_schema,
        "config": self.config_attribute_dict,
        **context.jinja_globals,
    }
3 changes: 3 additions & 0 deletions sqlmesh/dbt/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,16 @@ class Dependencies(PydanticModel):
variables: t.Set[str] = set()
model_attrs: t.Set[str] = set()

has_dynamic_var_names: bool = False

def union(self, other: Dependencies) -> Dependencies:
    """Return a new ``Dependencies`` combining this instance with ``other``.

    Macros are de-duplicated through a set and passed on as a list; sources,
    refs, variables and model attributes are merged with ``|``. The result is
    flagged as having dynamic var names if either operand is.
    """
    merged_macros = set(self.macros) | set(other.macros)
    any_dynamic_var_names = self.has_dynamic_var_names or other.has_dynamic_var_names

    return Dependencies(
        macros=list(merged_macros),
        sources=self.sources | other.sources,
        refs=self.refs | other.refs,
        variables=self.variables | other.variables,
        model_attrs=self.model_attrs | other.model_attrs,
        has_dynamic_var_names=any_dynamic_var_names,
    )

@field_validator("macros", mode="after")
Expand Down
4 changes: 1 addition & 3 deletions sqlmesh/dbt/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from sqlmesh.core.config import Config as SQLMeshConfig
from sqlmesh.dbt.builtin import _relation_info_to_relation
from sqlmesh.dbt.common import Dependencies
from sqlmesh.dbt.manifest import ManifestHelper
from sqlmesh.dbt.target import TargetConfig
from sqlmesh.utils import AttributeDict
Expand All @@ -22,7 +23,6 @@
if t.TYPE_CHECKING:
from jinja2 import Environment

from sqlmesh.dbt.basemodel import Dependencies
from sqlmesh.dbt.model import ModelConfig
from sqlmesh.dbt.relation import Policy
from sqlmesh.dbt.seed import SeedConfig
Expand Down Expand Up @@ -101,8 +101,6 @@ def add_variables(self, variables: t.Dict[str, t.Any]) -> None:
self._jinja_environment = None

def set_and_render_variables(self, variables: t.Dict[str, t.Any], package: str) -> None:
self.variables = variables

jinja_environment = self.jinja_macros.build_environment(**self.jinja_globals)

def _render_var(value: t.Any) -> t.Any:
Expand Down
20 changes: 9 additions & 11 deletions sqlmesh/dbt/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,6 @@ def _to_sqlmesh(config: BMC, context: DbtContext) -> Model:
)

for project in self._load_projects():
context = project.context.copy()

macros_max_mtime = self._macros_max_mtime
yaml_max_mtimes = self._compute_yaml_max_mtime_per_subfolder(
project.context.project_root
Expand All @@ -135,12 +133,13 @@ def _to_sqlmesh(config: BMC, context: DbtContext) -> Model:
logger.debug("Converting models to sqlmesh")
# Now that config is rendered, create the sqlmesh models
for package in project.packages.values():
context.set_and_render_variables(package.variables, package.name)
package_context = project.context.copy()
package_context.set_and_render_variables(package.variables, package.name)
package_models: t.Dict[str, BaseModelConfig] = {**package.models, **package.seeds}

for model in package_models.values():
sqlmesh_model = cache.get_or_load_models(
model.path, loader=lambda: [_to_sqlmesh(model, context)]
model.path, loader=lambda: [_to_sqlmesh(model, package_context)]
)[0]

models[sqlmesh_model.fqn] = sqlmesh_model
Expand All @@ -155,15 +154,14 @@ def _load_audits(
audits: UniqueKeyDict = UniqueKeyDict("audits")

for project in self._load_projects():
context = project.context

logger.debug("Converting audits to sqlmesh")
for package in project.packages.values():
context.set_and_render_variables(package.variables, package.name)
package_context = project.context.copy()
package_context.set_and_render_variables(package.variables, package.name)
for test in package.tests.values():
logger.debug("Converting '%s' to sqlmesh format", test.name)
try:
audits[test.name] = test.to_sqlmesh(context)
audits[test.name] = test.to_sqlmesh(package_context)
except MissingModelError as e:
logger.warning(
"Skipping audit '%s' because model '%s' is not a valid ref",
Expand Down Expand Up @@ -244,9 +242,9 @@ def _load_environment_statements(self, macros: MacroRegistry) -> t.List[Environm
project_names: t.Set[str] = set()
dialect = self.config.dialect
for project in self._load_projects():
context = project.context
for package_name, package in project.packages.items():
context.set_and_render_variables(package.variables, package_name)
package_context = project.context.copy()
package_context.set_and_render_variables(package.variables, package_name)
on_run_start: t.List[str] = [
on_run_hook.sql
for on_run_hook in sorted(package.on_run_start.values(), key=lambda h: h.index)
Expand All @@ -261,7 +259,7 @@ def _load_environment_statements(self, macros: MacroRegistry) -> t.List[Environm
for hook in [*package.on_run_start.values(), *package.on_run_end.values()]:
dependencies = dependencies.union(hook.dependencies)

statements_context = context.context_for_dependencies(dependencies)
statements_context = package_context.context_for_dependencies(dependencies)
jinja_registry = make_jinja_registry(
statements_context.jinja_macros, package_name, set(dependencies.macros)
)
Expand Down
3 changes: 3 additions & 0 deletions sqlmesh/dbt/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,9 @@ def _extra_dependencies(self, target: str, package: str) -> Dependencies:
args = [jinja_call_arg_name(arg) for arg in node.args]
if args and args[0]:
dependencies.variables.add(args[0])
else:
# We couldn't determine the var name statically
dependencies.has_dynamic_var_names = True
dependencies.macros.append(MacroReference(name="var"))
elif len(call_name) == 1:
macro_name = call_name[0]
Expand Down
12 changes: 7 additions & 5 deletions sqlmesh/dbt/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,6 @@ def load(cls, context: DbtContext, variables: t.Optional[t.Dict[str, t.Any]] = N
raise ConfigError(f"Could not find {PROJECT_FILENAME} in {context.project_root}")
project_yaml = load_yaml(project_file_path)

variable_overrides = variables
variables = {**project_yaml.get("vars", {}), **(variables or {})}

project_name = context.render(project_yaml.get("name", ""))
context.project_name = project_name
if not context.project_name:
Expand All @@ -69,6 +66,7 @@ def load(cls, context: DbtContext, variables: t.Optional[t.Dict[str, t.Any]] = N
profile = Profile.load(context, context.target_name)
context.target = profile.target

variable_overrides = variables or {}
context.manifest = ManifestHelper(
project_file_path.parent,
profile.path.parent,
Expand Down Expand Up @@ -101,13 +99,17 @@ def load(cls, context: DbtContext, variables: t.Optional[t.Dict[str, t.Any]] = N
package = package_loader.load(path.parent)
packages[package.name] = package

all_project_variables = {**project_yaml.get("vars", {}), **(variable_overrides or {})}
for name, package in packages.items():
package_vars = variables.get(name)
package_vars = all_project_variables.get(name)

if isinstance(package_vars, dict):
package.variables.update(package_vars)

package.variables.update(variables)
if name == context.project_name:
package.variables.update(all_project_variables)
else:
package.variables.update(variable_overrides)

return Project(context, profile, packages)

Expand Down
22 changes: 5 additions & 17 deletions tests/dbt/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ def test_variables(assert_exp_eq, sushi_test_project):
"nested_vars": {
"some_nested_var": 2,
},
"dynamic_test_var": 3,
"list_var": [
{"name": "item1", "value": 1},
{"name": "item2", "value": 2},
Expand All @@ -375,25 +376,10 @@ def test_variables(assert_exp_eq, sushi_test_project):
expected_customer_variables = {
"some_var": ["foo", "bar"],
"some_other_var": 5,
"yet_another_var": 1,
"yet_another_var": 5,
"customers:bla": False,
"customers:customer_id": "customer_id",
"start": "Jan 1 2022",
"top_waiters:limit": 10,
"top_waiters:revenue": "revenue",
"customers:boo": ["a", "b"],
"nested_vars": {
"some_nested_var": 2,
},
"list_var": [
{"name": "item1", "value": 1},
{"name": "item2", "value": 2},
],
"customers": {
"customers:bla": False,
"customers:customer_id": "customer_id",
"some_var": ["foo", "bar"],
},
}

assert sushi_test_project.packages["sushi"].variables == expected_sushi_variables
Expand All @@ -406,7 +392,9 @@ def test_nested_variables(sushi_test_project):
sql="SELECT {{ var('nested_vars')['some_nested_var'] }}",
dependencies=Dependencies(variables=["nested_vars"]),
)
sqlmesh_model = model_config.to_sqlmesh(sushi_test_project.context)
context = sushi_test_project.context.copy()
context.set_and_render_variables(sushi_test_project.packages["sushi"].variables, "sushi")
sqlmesh_model = model_config.to_sqlmesh(context)
assert sqlmesh_model.jinja_macros.global_objs["vars"]["nested_vars"] == {"some_nested_var": 2}


Expand Down
2 changes: 2 additions & 0 deletions tests/dbt/test_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def test_manifest_helper(caplog):
waiter_revenue_by_day_config = models["waiter_revenue_by_day_v2"]
assert waiter_revenue_by_day_config.dependencies == Dependencies(
macros={
MacroReference(name="dynamic_var_name_dependency"),
MacroReference(name="log_value"),
MacroReference(name="test_dependencies"),
MacroReference(package="customers", name="duckdb__current_engine"),
Expand All @@ -87,6 +88,7 @@ def test_manifest_helper(caplog):
},
sources={"streaming.items", "streaming.orders", "streaming.order_items"},
variables={"yet_another_var", "nested_vars"},
has_dynamic_var_names=True,
)
assert waiter_revenue_by_day_config.materialized == "incremental"
assert waiter_revenue_by_day_config.incremental_strategy == "delete+insert"
Expand Down
67 changes: 67 additions & 0 deletions tests/dbt/test_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
)
from sqlmesh.core.state_sync.db.snapshot import _snapshot_to_json
from sqlmesh.dbt.builtin import _relation_info_to_relation
from sqlmesh.dbt.common import Dependencies
from sqlmesh.dbt.column import (
ColumnConfig,
column_descriptions_to_sqlmesh,
Expand All @@ -50,6 +51,7 @@
from sqlmesh.dbt.target import BigQueryConfig, DuckDbConfig, SnowflakeConfig, ClickhouseConfig
from sqlmesh.dbt.test import TestConfig
from sqlmesh.utils.errors import ConfigError, MacroEvalError, SQLMeshError
from sqlmesh.utils.jinja import MacroReference

pytestmark = [pytest.mark.dbt, pytest.mark.slow]

Expand Down Expand Up @@ -1530,6 +1532,9 @@ def test_dbt_package_macros(sushi_test_project: Project):
@pytest.mark.xdist_group("dbt_manifest")
def test_dbt_vars(sushi_test_project: Project):
context = sushi_test_project.context
context.set_and_render_variables(
sushi_test_project.packages["customers"].variables, "customers"
)

assert context.render("{{ var('some_other_var') }}") == "5"
assert context.render("{{ var('some_other_var', 0) }}") == "5"
Expand Down Expand Up @@ -1854,3 +1859,65 @@ def test_on_run_start_end():
"CREATE OR REPLACE TABLE schema_table_sushi__dev_nested_package AS SELECT 'sushi__dev' AS schema",
]
)


@pytest.mark.xdist_group("dbt_manifest")
def test_dynamic_var_names(sushi_test_project: Project, sushi_test_dbt_context: Context):
    """Check that variables are available when a var name is built at render time.

    The ad-hoc model declares no explicit variable dependencies, only
    ``has_dynamic_var_names=True``, yet ``yet_another_var`` must end up in the
    converted model's Jinja ``vars`` globals. The same is then checked for
    ``dynamic_test_var`` on the existing ``sushi.waiter_revenue_by_day_v2`` model.
    """
    # NOTE(review): this is the project's own context object (no .copy()), so the
    # variables and target set below mutate it in place -- confirm the fixture is
    # not shared with tests that expect the original state.
    context = sushi_test_project.context
    # Load the "sushi" package variables into the context before converting the model.
    context.set_and_render_variables(sushi_test_project.packages["sushi"].variables, "sushi")
    context.target = BigQueryConfig(name="production", database="main", schema="sushi")
    # The var name is assembled by string concatenation inside the template.
    # NOTE(review): the SELECTs below have no comma between `AS one` and the var
    # column; the assertions here only inspect Jinja globals -- confirm whether
    # that is intentional.
    model_config = ModelConfig(
        name="model",
        alias="model",
        schema="test",
        package_name="package",
        materialized="table",
        unique_key="ds",
        partition_by={"field": "ds", "granularity": "month"},
        sql="""
{% set var_name = "yet_" + "another_" + "var" %}
{% set results = run_query('select 1 as one') %}
{% if results %}
SELECT {{ results.columns[0].values()[0] }} AS one {{ var(var_name) }} AS var FROM {{ this.identifier }}
{% else %}
SELECT NULL AS one {{ var(var_name) }} AS var FROM {{ this.identifier }}
{% endif %}
""",
        dependencies=Dependencies(has_dynamic_var_names=True),
    )
    converted_model = model_config.to_sqlmesh(context)
    assert "yet_another_var" in converted_model.jinja_macros.global_objs["vars"]  # type: ignore

    # Test the existing model in the sushi project
    assert (
        "dynamic_test_var"  # type: ignore
        in sushi_test_dbt_context.get_model(
            "sushi.waiter_revenue_by_day_v2"
        ).jinja_macros.global_objs["vars"]
    )


@pytest.mark.xdist_group("dbt_manifest")
def test_dynamic_var_names_in_macro(sushi_test_project: Project):
    """Check that variables are available when a macro receives the var name as an argument.

    The model calls ``sushi.dynamic_var_name_dependency`` with a name built at
    render time. Its dependencies list that macro and set
    ``has_dynamic_var_names=True``; ``dynamic_test_var`` must then be present
    in the converted model's Jinja ``vars`` globals.
    """
    # NOTE(review): this is the project's own context object (no .copy()), so the
    # variables and target set below mutate it in place -- confirm the fixture is
    # not shared with tests that expect the original state.
    context = sushi_test_project.context
    # Load the "sushi" package variables into the context before converting the model.
    context.set_and_render_variables(sushi_test_project.packages["sushi"].variables, "sushi")
    context.target = BigQueryConfig(name="production", database="main", schema="sushi")
    model_config = ModelConfig(
        name="model",
        alias="model",
        schema="test",
        package_name="package",
        materialized="table",
        unique_key="ds",
        partition_by={"field": "ds", "granularity": "month"},
        sql="""
{% set var_name = "dynamic_" + "test_" + "var" %}
SELECT {{ sushi.dynamic_var_name_dependency(var_name) }} AS var
""",
        # No variables are listed explicitly; only the macro reference and the
        # dynamic-var-names flag are declared.
        dependencies=Dependencies(
            macros=[MacroReference(package="sushi", name="dynamic_var_name_dependency")],
            has_dynamic_var_names=True,
        ),
    )
    converted_model = model_config.to_sqlmesh(context)
    assert "dynamic_test_var" in converted_model.jinja_macros.global_objs["vars"]  # type: ignore
1 change: 1 addition & 0 deletions tests/fixtures/dbt/sushi_test/dbt_project.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ vars:
customers:boo: ["a", "b"]

yet_another_var: 1
dynamic_test_var: 3

customers:
some_var: ["foo", "bar"]
Expand Down
6 changes: 6 additions & 0 deletions tests/fixtures/dbt/sushi_test/macros/test_dependencies.sql
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,9 @@
{{ log(var("yet_another_var", 2)) }}
{{ log(var("nested_vars")['some_nested_var']) }}
{% endmacro %}


{# Return the value of the variable whose name is passed in `var_name`. #}
{# The name arrives as a macro argument rather than a string literal in the var() call. #}
{# `results` is assigned from run_query but not used afterwards in this macro. #}
{% macro dynamic_var_name_dependency(var_name) %}
{% set results = run_query('select 1 as one') %}
{{ return(var(var_name)) }}
{% endmacro %}
Loading