From 9dbe718fbbfb28579bc296d15b528d9566b8fef4 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Fri, 29 Aug 2025 21:24:04 +0300 Subject: [PATCH 1/6] Fix!: depend on all attributes of dbt `model` when passed to a macro --- .../sushi_dbt/macros/check_model_is_table.sql | 15 +++++++ examples/sushi_dbt/models/customers.sql | 6 +++ sqlmesh/core/renderer.py | 45 ++++++++++--------- sqlmesh/dbt/basemodel.py | 28 ++++++++---- sqlmesh/dbt/builtin.py | 26 +++++++++++ sqlmesh/dbt/common.py | 1 + sqlmesh/dbt/manifest.py | 44 +++++++++++++++--- tests/core/test_context.py | 5 +++ 8 files changed, 134 insertions(+), 36 deletions(-) create mode 100644 examples/sushi_dbt/macros/check_model_is_table.sql diff --git a/examples/sushi_dbt/macros/check_model_is_table.sql b/examples/sushi_dbt/macros/check_model_is_table.sql new file mode 100644 index 0000000000..42dc5615e4 --- /dev/null +++ b/examples/sushi_dbt/macros/check_model_is_table.sql @@ -0,0 +1,15 @@ +{%- macro check_model_is_table(model) -%} + {%- if model.config.materialized != 'table' -%} + {%- do exceptions.raise_compiler_error( + "Model must use the table materialization. Please check any model overrides." + ) -%} + {%- endif -%} +{%- endmacro -%} + +{%- macro check_model_is_table_alt(foo) -%} + {%- if foo.config.materialized != 'table' -%} + {%- do exceptions.raise_compiler_error( + "Model must use the table materialization. Please check any model overrides." + ) -%} + {%- endif -%} +{%- endmacro -%} diff --git a/examples/sushi_dbt/models/customers.sql b/examples/sushi_dbt/models/customers.sql index ac82126fc7..df602a88fa 100644 --- a/examples/sushi_dbt/models/customers.sql +++ b/examples/sushi_dbt/models/customers.sql @@ -1,3 +1,9 @@ +{{ check_model_is_table(model) }} + +{% if 'DISTINCT' in model.raw_code %} + {{ check_model_is_table_alt(model) }} +{% endif %} + SELECT DISTINCT customer_id::INT AS customer_id FROM {{ ref('orders') }} as o diff --git a/sqlmesh/core/renderer.py b/sqlmesh/core/renderer.py index 4078d718a6..a1c40b31b6 100644 --- a/sqlmesh/core/renderer.py +++ b/sqlmesh/core/renderer.py @@ -188,30 +188,32 @@ def _resolve_table(table: str | exp.Table) -> str: } variables = kwargs.pop("variables", {}) - jinja_env_kwargs = { - **{ - **render_kwargs, - **_prepare_python_env_for_jinja(macro_evaluator, self._python_env), - **variables, - }, - "snapshots": snapshots or {}, - "table_mapping": table_mapping, - "deployability_index": deployability_index, - "default_catalog": self._default_catalog, - "runtime_stage": runtime_stage.value, - "resolve_table": _resolve_table, - } - if this_model: - render_kwargs["this_model"] = this_model - jinja_env_kwargs["this_model"] = this_model.sql( - dialect=self._dialect, identify=True, comments=False - ) - - jinja_env = self._jinja_macro_registry.build_environment(**jinja_env_kwargs) expressions = [self._expression] if isinstance(self._expression, d.Jinja): try: + jinja_env_kwargs = { + **{ + **render_kwargs, + **_prepare_python_env_for_jinja(macro_evaluator, self._python_env), + **variables, + }, + "snapshots": snapshots or {}, + "table_mapping": table_mapping, + "deployability_index": deployability_index, + "default_catalog": self._default_catalog, + "runtime_stage": runtime_stage.value, + "resolve_table": _resolve_table, + "raw_code": self._expression.name, + } + + if this_model: + jinja_env_kwargs["this_model"] = this_model.sql( + dialect=self._dialect, identify=True, comments=False + ) + + jinja_env = self._jinja_macro_registry.build_environment(**jinja_env_kwargs) + expressions = [] rendered_expression = jinja_env.from_string(self._expression.name).render() logger.debug( @@ -229,6 +231,9 @@ def _resolve_table(table: str | exp.Table) -> str: f"Could not render or parse jinja at '{self._path}'.\n{ex}" ) from ex + if this_model: + render_kwargs["this_model"] = this_model + macro_evaluator.locals.update(render_kwargs) if variables: diff --git a/sqlmesh/dbt/basemodel.py b/sqlmesh/dbt/basemodel.py index 212b314997..d28dbcd22e 100644 --- a/sqlmesh/dbt/basemodel.py +++ b/sqlmesh/dbt/basemodel.py @@ -19,6 +19,7 @@ column_types_to_sqlmesh, ) from sqlmesh.dbt.common import ( + DBT_ALL_MODEL_ATTRS, DbtConfig, Dependencies, GeneralConfig, @@ -27,6 +28,7 @@ ) from sqlmesh.dbt.relation import Policy, RelationType from sqlmesh.dbt.test import TestConfig +from sqlmesh.dbt.util import DBT_VERSION from sqlmesh.utils import AttributeDict from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.pydantic import field_validator @@ -375,15 +377,23 @@ def to_sqlmesh( def _model_jinja_context( self, context: DbtContext, dependencies: Dependencies ) -> t.Dict[str, t.Any]: - model_node: AttributeDict[str, t.Any] = AttributeDict( - { - k: v - for k, v in context._manifest._manifest.nodes[self.node_name].to_dict().items() - if k in dependencies.model_attrs - } - if context._manifest and self.node_name in context._manifest._manifest.nodes - else {} - ) + if context._manifest and self.node_name in context._manifest._manifest.nodes: + attributes = context._manifest._manifest.nodes[self.node_name].to_dict() + if DBT_ALL_MODEL_ATTRS in dependencies.model_attrs: + model_node: AttributeDict[str, t.Any] = AttributeDict(attributes) + else: + model_node = AttributeDict( + filter(lambda kv: kv[0] in dependencies.model_attrs, attributes.items()) + ) + + raw_code_key = "raw_code" if DBT_VERSION >= (1, 3, 0) else "raw_sql" # type: ignore + + # We exclude the raw SQL code to reduce the payload size. It's still accessible through + # the JinjaQuery instance stored in the resulting SQLMesh model's `query` field. + model_node.pop(raw_code_key, None) + else: + model_node = AttributeDict({}) + return { "this": self.relation_info, "model": model_node, diff --git a/sqlmesh/dbt/builtin.py b/sqlmesh/dbt/builtin.py index 24669807bb..6e3cd82113 100644 --- a/sqlmesh/dbt/builtin.py +++ b/sqlmesh/dbt/builtin.py @@ -26,6 +26,13 @@ from sqlmesh.utils.errors import ConfigError, MacroEvalError from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroReference, MacroReturnVal +if t.TYPE_CHECKING: + from typing import Protocol + + class Model(Protocol): + def __getattr__(self, key: str) -> t.Any: ... + + logger = logging.getLogger(__name__) @@ -301,6 +308,21 @@ def source(package: str, name: str) -> t.Optional[BaseRelation]: return source +def generate_model(model: AttributeDict, raw_code: str) -> Model: + class Model: + def __init__(self, model: AttributeDict) -> None: + self._model = model + self._raw_code_key = "raw_code" if DBT_VERSION >= (1, 3, 0) else "raw_sql" # type: ignore + + def __getattr__(self, key: str) -> t.Any: + if key == self._raw_code_key: + return raw_code + + return getattr(self._model, key) + + return Model(model) + + def return_val(val: t.Any) -> None: raise MacroReturnVal(val) @@ -469,12 +491,16 @@ def create_builtin_globals( is_incremental &= snapshot_table_exists else: is_incremental = False + builtin_globals["is_incremental"] = lambda: is_incremental builtin_globals["builtins"] = AttributeDict( {k: builtin_globals.get(k) for k in ("ref", "source", "config", "var")} ) + if (model := jinja_globals.pop("model", None)) is not None: + builtin_globals["model"] = generate_model(model, jinja_globals.pop("model", "")) + if engine_adapter is not None: builtin_globals["flags"] = Flags(which="run") adapter: BaseAdapter = RuntimeAdapter( diff --git a/sqlmesh/dbt/common.py b/sqlmesh/dbt/common.py index ba982c2bb2..6a04167409 100644 --- a/sqlmesh/dbt/common.py +++ b/sqlmesh/dbt/common.py @@ -19,6 +19,7 @@ T = t.TypeVar("T", bound="GeneralConfig") PROJECT_FILENAME = DBT_PROJECT_FILENAME +DBT_ALL_MODEL_ATTRS = "__DBT_ALL_MODEL_ATTRS__" JINJA_ONLY = { "adapter", diff --git a/sqlmesh/dbt/manifest.py b/sqlmesh/dbt/manifest.py index ca20554e3b..182945dd94 100644 --- a/sqlmesh/dbt/manifest.py +++ b/sqlmesh/dbt/manifest.py @@ -44,8 +44,8 @@ from sqlmesh.core import constants as c from sqlmesh.utils.errors import SQLMeshError from sqlmesh.core.config import ModelDefaultsConfig -from sqlmesh.dbt.basemodel import Dependencies from sqlmesh.dbt.builtin import BUILTIN_FILTERS, BUILTIN_GLOBALS, OVERRIDDEN_MACROS +from sqlmesh.dbt.common import DBT_ALL_MODEL_ATTRS, Dependencies from sqlmesh.dbt.model import ModelConfig from sqlmesh.dbt.package import HookConfig, MacroConfig from sqlmesh.dbt.seed import SeedConfig @@ -354,7 +354,9 @@ def _load_models_and_seeds(self) -> None: dependencies = Dependencies( macros=macro_references, refs=_refs(node), sources=_sources(node) ) - dependencies = dependencies.union(self._extra_dependencies(sql, node.package_name)) + dependencies = dependencies.union( + self._extra_dependencies(sql, node.package_name, track_all_model_attrs=True) + ) dependencies = dependencies.union( self._flatten_dependencies_from_macros(dependencies.macros, node.package_name) ) @@ -552,15 +554,35 @@ def _flatten_dependencies_from_macros( dependencies = dependencies.union(macro_dependencies) return dependencies - def _extra_dependencies(self, target: str, package: str) -> Dependencies: - # We sometimes observe that the manifest doesn't capture all macros, refs, and sources within a macro. - # This behavior has been observed with macros like dbt.current_timestamp(), dbt_utils.slugify(), and source(). - # Here we apply our custom extractor to make a best effort to supplement references captured in the manifest. + def _extra_dependencies( + self, + target: str, + package: str, + track_all_model_attrs: bool = False, + ) -> Dependencies: + """ + We sometimes observe that the manifest doesn't capture all macros, refs, and sources within a macro. + This behavior has been observed with macros like dbt.current_timestamp(), dbt_utils.slugify(), and source(). + Here we apply our custom extractor to make a best effort to supplement references captured in the manifest. + """ dependencies = Dependencies() + + # Whether all `model` attributes (e.g., `model.config`) should be included in the dependencies + all_model_attrs = False + for call_name, node in extract_call_names(target, cache=self._calls): if call_name[0] == "config": continue - elif isinstance(node, jinja2.nodes.Getattr): + + if ( + track_all_model_attrs + and not all_model_attrs + and isinstance(node, jinja2.nodes.Call) + and any(isinstance(a, jinja2.nodes.Name) and a.name == "model" for a in node.args) + ): + all_model_attrs = True + + if isinstance(node, jinja2.nodes.Getattr): if call_name[0] == "model": dependencies.model_attrs.add(call_name[1]) elif call_name[0] == "source": @@ -606,6 +628,14 @@ def _extra_dependencies(self, target: str, package: str) -> Dependencies: call_name[0], call_name[1], dependencies.macros.append ) + # When `model` is referenced as-is, e.g. it's passed as an argument to a macro call like + # `{{ foo(model) }}`, we can't easily track the attributes that are actually used, because + # it may be aliased and hence tracking actual uses of `model` requires a proper data flow + # analysis. We conservatively deal with this by including all of its supported attributes + # if a standalone reference is found. + if all_model_attrs: + dependencies.model_attrs = {DBT_ALL_MODEL_ATTRS} + return dependencies diff --git a/tests/core/test_context.py b/tests/core/test_context.py index 196889a87c..c994c3e888 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -1534,6 +1534,11 @@ def test_plan_enable_preview_default(sushi_context: Context, sushi_dbt_context: assert sushi_dbt_context._plan_preview_enabled +def test_raw_code_missing_from_model_attributes(sushi_dbt_context: Context): + customers_model = sushi_dbt_context.models['"memory"."sushi"."customers"'] + assert "raw_code" not in customers_model.jinja_macros.global_objs["model"] # type: ignore + + def test_catalog_name_needs_to_be_quoted(): config = Config( model_defaults=ModelDefaultsConfig(dialect="duckdb"), From 66f385c5c93bf2d0b6ad7e42a69303a985193f5f Mon Sep 17 00:00:00 2001 From: George Sittas Date: Wed, 3 Sep 2025 13:35:31 +0300 Subject: [PATCH 2/6] PR feedback: revert sushi_dbt modifications, use tests/fixtures/dbt instead --- examples/sushi_dbt/models/customers.sql | 6 ------ tests/core/test_context.py | 4 ++-- .../dbt/sushi_test}/macros/check_model_is_table.sql | 0 tests/fixtures/dbt/sushi_test/models/simple_model_a.sql | 5 +++++ 4 files changed, 7 insertions(+), 8 deletions(-) rename {examples/sushi_dbt => tests/fixtures/dbt/sushi_test}/macros/check_model_is_table.sql (100%) diff --git a/examples/sushi_dbt/models/customers.sql b/examples/sushi_dbt/models/customers.sql index df602a88fa..ac82126fc7 100644 --- a/examples/sushi_dbt/models/customers.sql +++ b/examples/sushi_dbt/models/customers.sql @@ -1,9 +1,3 @@ -{{ check_model_is_table(model) }} - -{% if 'DISTINCT' in model.raw_code %} - {{ check_model_is_table_alt(model) }} -{% endif %} - SELECT DISTINCT customer_id::INT AS customer_id FROM {{ ref('orders') }} as o diff --git a/tests/core/test_context.py b/tests/core/test_context.py index c994c3e888..c2771efe92 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -1534,8 +1534,8 @@ def test_plan_enable_preview_default(sushi_context: Context, sushi_dbt_context: assert sushi_dbt_context._plan_preview_enabled -def test_raw_code_missing_from_model_attributes(sushi_dbt_context: Context): - customers_model = sushi_dbt_context.models['"memory"."sushi"."customers"'] +def test_raw_code_missing_from_model_attributes(sushi_test_dbt_context: Context): + customers_model = sushi_test_dbt_context.models['"memory"."sushi"."simple_model_a"'] assert "raw_code" not in customers_model.jinja_macros.global_objs["model"] # type: ignore diff --git a/examples/sushi_dbt/macros/check_model_is_table.sql b/tests/fixtures/dbt/sushi_test/macros/check_model_is_table.sql similarity index 100% rename from examples/sushi_dbt/macros/check_model_is_table.sql rename to tests/fixtures/dbt/sushi_test/macros/check_model_is_table.sql diff --git a/tests/fixtures/dbt/sushi_test/models/simple_model_a.sql b/tests/fixtures/dbt/sushi_test/models/simple_model_a.sql index e9441e35da..c5a0b90fa6 100644 --- a/tests/fixtures/dbt/sushi_test/models/simple_model_a.sql +++ b/tests/fixtures/dbt/sushi_test/models/simple_model_a.sql @@ -1,2 +1,7 @@ +{{ check_model_is_table(model) }} + +{% if 'SELECT' in model.raw_code %} + {{ check_model_is_table_alt(model) }} +{% endif %} SELECT 1 AS a From 5f1c19c8d95b6e23528f5f5c8c18f9dce54deae6 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Wed, 3 Sep 2025 13:51:16 +0300 Subject: [PATCH 3/6] PR feedback: remove `DBT_ALL_MODEL_ATTRS` placeholder --- sqlmesh/dbt/basemodel.py | 5 ++--- sqlmesh/dbt/common.py | 15 ++++++++++++--- sqlmesh/dbt/manifest.py | 6 +++--- tests/dbt/test_manifest.py | 3 ++- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/sqlmesh/dbt/basemodel.py b/sqlmesh/dbt/basemodel.py index d28dbcd22e..4fe2c24d91 100644 --- a/sqlmesh/dbt/basemodel.py +++ b/sqlmesh/dbt/basemodel.py @@ -19,7 +19,6 @@ column_types_to_sqlmesh, ) from sqlmesh.dbt.common import ( - DBT_ALL_MODEL_ATTRS, DbtConfig, Dependencies, GeneralConfig, @@ -379,11 +378,11 @@ def _model_jinja_context( ) -> t.Dict[str, t.Any]: if context._manifest and self.node_name in context._manifest._manifest.nodes: attributes = context._manifest._manifest.nodes[self.node_name].to_dict() - if DBT_ALL_MODEL_ATTRS in dependencies.model_attrs: + if dependencies.model_attrs.all_attrs: model_node: AttributeDict[str, t.Any] = AttributeDict(attributes) else: model_node = AttributeDict( - filter(lambda kv: kv[0] in dependencies.model_attrs, attributes.items()) + filter(lambda kv: kv[0] in dependencies.model_attrs.attrs, attributes.items()) ) raw_code_key = "raw_code" if DBT_VERSION >= (1, 3, 0) else "raw_sql" # type: ignore diff --git a/sqlmesh/dbt/common.py b/sqlmesh/dbt/common.py index 6a04167409..4cb1897562 100644 --- a/sqlmesh/dbt/common.py +++ b/sqlmesh/dbt/common.py @@ -2,6 +2,7 @@ import re import typing as t +from dataclasses import dataclass from pathlib import Path from ruamel.yaml.constructor import DuplicateKeyError @@ -19,7 +20,6 @@ T = t.TypeVar("T", bound="GeneralConfig") PROJECT_FILENAME = DBT_PROJECT_FILENAME -DBT_ALL_MODEL_ATTRS = "__DBT_ALL_MODEL_ATTRS__" JINJA_ONLY = { "adapter", @@ -173,6 +173,12 @@ def sqlmesh_config_fields(self) -> t.Set[str]: return set() +@dataclass +class ModelAttrs: + attrs: t.Set[str] + all_attrs: bool = False + + class Dependencies(PydanticModel): """ DBT dependencies for a model, macro, etc. @@ -187,7 +193,7 @@ class Dependencies(PydanticModel): sources: t.Set[str] = set() refs: t.Set[str] = set() variables: t.Set[str] = set() - model_attrs: t.Set[str] = set() + model_attrs: ModelAttrs = ModelAttrs(attrs=set()) has_dynamic_var_names: bool = False @@ -197,7 +203,10 @@ def union(self, other: Dependencies) -> Dependencies: sources=self.sources | other.sources, refs=self.refs | other.refs, variables=self.variables | other.variables, - model_attrs=self.model_attrs | other.model_attrs, + model_attrs=ModelAttrs( + attrs=self.model_attrs.attrs | other.model_attrs.attrs, + all_attrs=self.model_attrs.all_attrs or other.model_attrs.all_attrs, + ), has_dynamic_var_names=self.has_dynamic_var_names or other.has_dynamic_var_names, ) diff --git a/sqlmesh/dbt/manifest.py b/sqlmesh/dbt/manifest.py index 182945dd94..67a1ae7ed3 100644 --- a/sqlmesh/dbt/manifest.py +++ b/sqlmesh/dbt/manifest.py @@ -45,7 +45,7 @@ from sqlmesh.utils.errors import SQLMeshError from sqlmesh.core.config import ModelDefaultsConfig from sqlmesh.dbt.builtin import BUILTIN_FILTERS, BUILTIN_GLOBALS, OVERRIDDEN_MACROS -from sqlmesh.dbt.common import DBT_ALL_MODEL_ATTRS, Dependencies +from sqlmesh.dbt.common import Dependencies from sqlmesh.dbt.model import ModelConfig from sqlmesh.dbt.package import HookConfig, MacroConfig from sqlmesh.dbt.seed import SeedConfig @@ -584,7 +584,7 @@ def _extra_dependencies( if isinstance(node, jinja2.nodes.Getattr): if call_name[0] == "model": - dependencies.model_attrs.add(call_name[1]) + dependencies.model_attrs.attrs.add(call_name[1]) elif call_name[0] == "source": args = [jinja_call_arg_name(arg) for arg in node.args] if args and all(arg for arg in args): @@ -634,7 +634,7 @@ def _extra_dependencies( # analysis. We conservatively deal with this by including all of its supported attributes # if a standalone reference is found. if all_model_attrs: - dependencies.model_attrs = {DBT_ALL_MODEL_ATTRS} + dependencies.model_attrs.all_attrs = True return dependencies diff --git a/tests/dbt/test_manifest.py b/tests/dbt/test_manifest.py index 5f2f6fb37f..ba8971e9b2 100644 --- a/tests/dbt/test_manifest.py +++ b/tests/dbt/test_manifest.py @@ -6,6 +6,7 @@ from sqlmesh.core.config import ModelDefaultsConfig from sqlmesh.dbt.basemodel import Dependencies +from sqlmesh.dbt.common import ModelAttrs from sqlmesh.dbt.context import DbtContext from sqlmesh.dbt.manifest import ManifestHelper, _convert_jinja_test_to_macro from sqlmesh.dbt.profile import Profile @@ -33,7 +34,7 @@ def test_manifest_helper(caplog): assert models["top_waiters"].dependencies == Dependencies( refs={"sushi.waiter_revenue_by_day", "waiter_revenue_by_day"}, variables={"top_waiters:revenue", "top_waiters:limit"}, - model_attrs={"columns", "config"}, + model_attrs=ModelAttrs(attrs={"columns", "config"}), macros=[ MacroReference(name="get_top_waiters_limit"), MacroReference(name="ref"), From 3b3f7b0d59c498650219dc05b62c983d1c9d8fd8 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Wed, 3 Sep 2025 14:23:27 +0300 Subject: [PATCH 4/6] PR feedback: simplify `model` builtin global --- sqlmesh/dbt/basemodel.py | 6 ++---- sqlmesh/dbt/builtin.py | 19 +++---------------- sqlmesh/dbt/common.py | 4 +++- 3 files changed, 8 insertions(+), 21 deletions(-) diff --git a/sqlmesh/dbt/basemodel.py b/sqlmesh/dbt/basemodel.py index 4fe2c24d91..4c2feed7ad 100644 --- a/sqlmesh/dbt/basemodel.py +++ b/sqlmesh/dbt/basemodel.py @@ -22,12 +22,12 @@ DbtConfig, Dependencies, GeneralConfig, + RAW_CODE_KEY, SqlStr, sql_str_validator, ) from sqlmesh.dbt.relation import Policy, RelationType from sqlmesh.dbt.test import TestConfig -from sqlmesh.dbt.util import DBT_VERSION from sqlmesh.utils import AttributeDict from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.pydantic import field_validator @@ -385,11 +385,9 @@ def _model_jinja_context( filter(lambda kv: kv[0] in dependencies.model_attrs.attrs, attributes.items()) ) - raw_code_key = "raw_code" if DBT_VERSION >= (1, 3, 0) else "raw_sql" # type: ignore - # We exclude the raw SQL code to reduce the payload size. It's still accessible through # the JinjaQuery instance stored in the resulting SQLMesh model's `query` field. - model_node.pop(raw_code_key, None) + model_node.pop(RAW_CODE_KEY, None) else: model_node = AttributeDict({}) diff --git a/sqlmesh/dbt/builtin.py b/sqlmesh/dbt/builtin.py index 6e3cd82113..a64d0f9852 100644 --- a/sqlmesh/dbt/builtin.py +++ b/sqlmesh/dbt/builtin.py @@ -18,6 +18,7 @@ from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.snapshot.definition import DeployabilityIndex from sqlmesh.dbt.adapter import BaseAdapter, ParsetimeAdapter, RuntimeAdapter +from sqlmesh.dbt.common import RAW_CODE_KEY from sqlmesh.dbt.relation import Policy from sqlmesh.dbt.target import TARGET_TYPE_TO_CONFIG_CLASS from sqlmesh.dbt.util import DBT_VERSION @@ -308,21 +309,6 @@ def source(package: str, name: str) -> t.Optional[BaseRelation]: return source -def generate_model(model: AttributeDict, raw_code: str) -> Model: - class Model: - def __init__(self, model: AttributeDict) -> None: - self._model = model - self._raw_code_key = "raw_code" if DBT_VERSION >= (1, 3, 0) else "raw_sql" # type: ignore - - def __getattr__(self, key: str) -> t.Any: - if key == self._raw_code_key: - return raw_code - - return getattr(self._model, key) - - return Model(model) - - def return_val(val: t.Any) -> None: raise MacroReturnVal(val) @@ -499,7 +485,8 @@ def create_builtin_globals( ) if (model := jinja_globals.pop("model", None)) is not None: - builtin_globals["model"] = generate_model(model, jinja_globals.pop("model", "")) + raw_code = jinja_globals.pop("raw_code", "") + builtin_globals["model"] = AttributeDict({**model, RAW_CODE_KEY: raw_code}) if engine_adapter is not None: builtin_globals["flags"] = Flags(which="run") diff --git a/sqlmesh/dbt/common.py b/sqlmesh/dbt/common.py index 4cb1897562..240d59084a 100644 --- a/sqlmesh/dbt/common.py +++ b/sqlmesh/dbt/common.py @@ -8,18 +8,20 @@ from ruamel.yaml.constructor import DuplicateKeyError from sqlglot.helper import ensure_list +from sqlmesh.dbt.util import DBT_VERSION from sqlmesh.core.config.base import BaseConfig, UpdateStrategy +from sqlmesh.core.config.common import DBT_PROJECT_FILENAME from sqlmesh.utils import AttributeDict from sqlmesh.utils.conversions import ensure_bool, try_str_to_bool from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.jinja import MacroReference from sqlmesh.utils.pydantic import PydanticModel, field_validator from sqlmesh.utils.yaml import load -from sqlmesh.core.config.common import DBT_PROJECT_FILENAME T = t.TypeVar("T", bound="GeneralConfig") PROJECT_FILENAME = DBT_PROJECT_FILENAME +RAW_CODE_KEY = "raw_code" if DBT_VERSION >= (1, 3, 0) else "raw_sql" # type: ignore JINJA_ONLY = { "adapter", From 8919c487ac662314c7d1ca55468293586267b1d1 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Wed, 3 Sep 2025 22:34:03 +0300 Subject: [PATCH 5/6] Include config in dbt query payload --- sqlmesh/core/model/definition.py | 11 +++++ sqlmesh/core/renderer.py | 4 +- sqlmesh/dbt/basemodel.py | 8 ---- sqlmesh/dbt/loader.py | 2 +- sqlmesh/dbt/model.py | 27 +---------- sqlmesh/dbt/test.py | 5 +- .../v0095_warn_about_dbt_raw_sql_diff.py | 47 +++++++++++++++++++ tests/core/test_context.py | 21 +++++++-- tests/dbt/test_config.py | 37 --------------- .../sushi_test/models/model_with_raw_code.sql | 11 +++++ .../dbt/sushi_test/models/simple_model_a.sql | 5 -- 11 files changed, 94 insertions(+), 84 deletions(-) create mode 100644 sqlmesh/migrations/v0095_warn_about_dbt_raw_sql_diff.py create mode 100644 tests/fixtures/dbt/sushi_test/models/model_with_raw_code.sql diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index f3ffcde05a..733fd1530a 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -595,6 +595,7 @@ def _statement_renderer(self, expression: exp.Expression) -> ExpressionRenderer: only_execution_time=False, default_catalog=self.default_catalog, model_fqn=self.fqn, + raw_code=self._raw_code, ) return self._statement_renderer_cache[expression_key] @@ -1305,6 +1306,10 @@ def _is_time_column_in_partitioned_by(self) -> bool: def violated_rules_for_query(self) -> t.Dict[type[Rule], t.Any]: return {} + @property + def _raw_code(self) -> t.Optional[str]: + return None + class SqlModel(_Model): """The model definition which relies on a SQL query to fetch the data. @@ -1581,6 +1586,7 @@ def _query_renderer(self) -> QueryRenderer: default_catalog=self.default_catalog, quote_identifiers=not no_quote_identifiers, optimize_query=self.optimize_query, + raw_code=self._raw_code, ) @property @@ -1606,6 +1612,11 @@ def violated_rules_for_query(self) -> t.Dict[type[Rule], t.Any]: self.render_query() return self._query_renderer._violated_rules + @property + def _raw_code(self) -> t.Optional[str]: + query = self.query + return query.name if isinstance(query, d.JinjaQuery) else None + class SeedModel(_Model): """The model definition which uses a pre-built static dataset to source the data from. diff --git a/sqlmesh/core/renderer.py b/sqlmesh/core/renderer.py index a1c40b31b6..86f0ef9890 100644 --- a/sqlmesh/core/renderer.py +++ b/sqlmesh/core/renderer.py @@ -53,6 +53,7 @@ def __init__( model_fqn: t.Optional[str] = None, normalize_identifiers: bool = True, optimize_query: t.Optional[bool] = True, + raw_code: t.Optional[str] = None, ): self._expression = expression self._dialect = dialect @@ -68,6 +69,7 @@ def __init__( self._cache: t.List[t.Optional[exp.Expression]] = [] self._model_fqn = model_fqn self._optimize_query_flag = optimize_query is not False + self._raw_code = raw_code def update_schema(self, schema: t.Dict[str, t.Any]) -> None: self.schema = d.normalize_mapping_schema(schema, dialect=self._dialect) @@ -204,7 +206,7 @@ def _resolve_table(table: str | exp.Table) -> str: "default_catalog": self._default_catalog, "runtime_stage": runtime_stage.value, "resolve_table": _resolve_table, - "raw_code": self._expression.name, + "raw_code": self._raw_code, } if this_model: diff --git a/sqlmesh/dbt/basemodel.py b/sqlmesh/dbt/basemodel.py index 4c2feed7ad..548718cf89 100644 --- a/sqlmesh/dbt/basemodel.py +++ b/sqlmesh/dbt/basemodel.py @@ -168,14 +168,6 @@ def _validate_grants(cls, v: t.Dict[str, str]) -> t.Dict[str, t.List[str]]: }, } - @property - def sql_no_config(self) -> SqlStr: - return SqlStr("") - - @property - def sql_embedded_config(self) -> SqlStr: - return SqlStr("") - @property def table_schema(self) -> str: """ diff --git a/sqlmesh/dbt/loader.py b/sqlmesh/dbt/loader.py index 4f473a20ab..695aff3c45 100644 --- a/sqlmesh/dbt/loader.py +++ b/sqlmesh/dbt/loader.py @@ -138,7 +138,7 @@ def _to_sqlmesh(config: BMC, context: DbtContext) -> Model: package_models: t.Dict[str, BaseModelConfig] = {**package.models, **package.seeds} for model in package_models.values(): - if isinstance(model, ModelConfig) and not model.sql_no_config: + if isinstance(model, ModelConfig) and not model.sql.strip(): logger.info(f"Skipping empty model '{model.name}' at path '{model.path}'.") continue diff --git a/sqlmesh/dbt/model.py b/sqlmesh/dbt/model.py index a941eb8880..20d0f8cd1a 100644 --- a/sqlmesh/dbt/model.py +++ b/sqlmesh/dbt/model.py @@ -27,7 +27,7 @@ ) from sqlmesh.core.model.kind import SCDType2ByTimeKind, OnDestructiveChange, OnAdditiveChange from sqlmesh.dbt.basemodel import BaseModelConfig, Materialization, SnapshotStrategy -from sqlmesh.dbt.common import SqlStr, extract_jinja_config, sql_str_validator +from sqlmesh.dbt.common import SqlStr, sql_str_validator from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.pydantic import field_validator @@ -138,10 +138,6 @@ class ModelConfig(BaseModelConfig): inserts_only: t.Optional[bool] = None incremental_predicates: t.Optional[t.List[str]] = None - # Private fields - _sql_embedded_config: t.Optional[SqlStr] = None - _sql_no_config: t.Optional[SqlStr] = None - _sql_validator = sql_str_validator @field_validator( @@ -432,25 +428,6 @@ def model_kind(self, context: DbtContext) -> ModelKind: raise ConfigError(f"{materialization.value} materialization not supported.") - @property - def sql_no_config(self) -> SqlStr: - if self._sql_no_config is None: - self._sql_no_config = SqlStr("") - self._extract_sql_config() - return self._sql_no_config - - @property - def sql_embedded_config(self) -> SqlStr: - if self._sql_embedded_config is None: - self._sql_embedded_config = SqlStr("") - self._extract_sql_config() - return self._sql_embedded_config - - def _extract_sql_config(self) -> None: - no_config, embedded_config = extract_jinja_config(self.sql) - self._sql_no_config = SqlStr(no_config) - self._sql_embedded_config = SqlStr(embedded_config) - def _big_query_partition_by_expr(self, context: DbtContext) -> exp.Expression: assert isinstance(self.partition_by, dict) data_type = self.partition_by["data_type"].lower() @@ -508,7 +485,7 @@ def to_sqlmesh( ) -> Model: """Converts the dbt model into a SQLMesh model.""" model_dialect = self.dialect(context) - query = d.jinja_query(self.sql_no_config) + query = d.jinja_query(self.sql) kind = self.model_kind(context) optional_kwargs: t.Dict[str, t.Any] = {} diff --git a/sqlmesh/dbt/test.py b/sqlmesh/dbt/test.py index 035c62acda..b5eec21623 100644 --- a/sqlmesh/dbt/test.py +++ b/sqlmesh/dbt/test.py @@ -12,7 +12,6 @@ Dependencies, GeneralConfig, SqlStr, - extract_jinja_config, sql_str_validator, ) from sqlmesh.utils import AttributeDict @@ -134,9 +133,7 @@ def to_sqlmesh(self, context: DbtContext) -> Audit: } ) - sql_no_config, _sql_config_only = extract_jinja_config(self.sql) - sql_no_config = sql_no_config.replace("**_dbt_generic_test_kwargs", self._kwargs()) - query = d.jinja_query(sql_no_config) + query = d.jinja_query(self.sql.replace("**_dbt_generic_test_kwargs", self._kwargs())) skip = not self.enabled blocking = self.severity == Severity.ERROR diff --git a/sqlmesh/migrations/v0095_warn_about_dbt_raw_sql_diff.py b/sqlmesh/migrations/v0095_warn_about_dbt_raw_sql_diff.py new file mode 100644 index 0000000000..be005de117 --- /dev/null +++ b/sqlmesh/migrations/v0095_warn_about_dbt_raw_sql_diff.py @@ -0,0 +1,47 @@ +""" +Warns dbt users about potential diffs due to inclusion of {{ config(...) }} blocks in model SQL. + +Prior to this fix, SQLMesh wasn't including the {{ config(...) }} block in the model's SQL payload +when processing dbt models. Now these config blocks are properly included in the raw SQL, which +may cause diffs to appear for existing dbt models even though the actual SQL logic hasn't changed. + +This is a one-time diff that will appear after upgrading, and applying a plan will resolve it. +""" + +import json + +from sqlglot import exp + +from sqlmesh.core.console import get_console + +SQLMESH_DBT_PACKAGE = "sqlmesh.dbt" + + +def migrate(state_sync, **kwargs): # type: ignore + engine_adapter = state_sync.engine_adapter + schema = state_sync.schema + snapshots_table = "_snapshots" + if schema: + snapshots_table = f"{schema}.{snapshots_table}" + + warning = ( + "SQLMesh now includes dbt's {{ config(...) }} blocks in the model's raw SQL when " + "processing dbt models. This change ensures that all model attributes referenced " + "in macros are properly tracked for fingerprinting. As a result, you may see diffs " + "for existing dbt models even though the actual SQL logic hasn't changed. This is " + "a one-time diff that will be resolved after applying a plan. Run 'sqlmesh diff prod' " + "to review any changes, then apply a plan if the diffs look expected." + ) + + for (snapshot,) in engine_adapter.fetchall( + exp.select("snapshot").from_(snapshots_table), quote_identifiers=True + ): + parsed_snapshot = json.loads(snapshot) + node = parsed_snapshot["node"] + + jinja_macros = node.get("jinja_macros") or {} + create_builtins_module = jinja_macros.get("create_builtins_module") or "" + + if create_builtins_module == SQLMESH_DBT_PACKAGE: + get_console().log_warning(warning) + return diff --git a/tests/core/test_context.py b/tests/core/test_context.py index c2771efe92..f09ff25c33 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -1534,9 +1534,24 @@ def test_plan_enable_preview_default(sushi_context: Context, sushi_dbt_context: assert sushi_dbt_context._plan_preview_enabled -def test_raw_code_missing_from_model_attributes(sushi_test_dbt_context: Context): - customers_model = sushi_test_dbt_context.models['"memory"."sushi"."simple_model_a"'] - assert "raw_code" not in customers_model.jinja_macros.global_objs["model"] # type: ignore +@pytest.mark.slow +def test_raw_code_handling(sushi_test_dbt_context: Context): + model = sushi_test_dbt_context.models['"memory"."sushi"."model_with_raw_code"'] + assert "raw_code" not in model.jinja_macros.global_objs["model"] # type: ignore + + # logging "pre-hook" (in dbt_projects.yml) + the actual pre-hook in the model file + assert len(model.pre_statements) == 2 + + original_file_path = model.jinja_macros.global_objs["model"]["original_file_path"] # type: ignore + model_file_path = sushi_test_dbt_context.path / original_file_path + + raw_code_length = len(model_file_path.read_text()) - 1 + + hook = model.render_pre_statements()[0] + assert ( + hook.sql() + == f'''CREATE TABLE "t" AS SELECT 'Length is {raw_code_length}' AS "length_col"''' + ) def test_catalog_name_needs_to_be_quoted(): diff --git a/tests/dbt/test_config.py b/tests/dbt/test_config.py index ae8713d933..ecd95a43c4 100644 --- a/tests/dbt/test_config.py +++ b/tests/dbt/test_config.py @@ -276,42 +276,6 @@ def test_singular_test_to_standalone_audit(dbt_dummy_postgres_config: PostgresCo assert standalone_audit.dialect == "bigquery" -def test_model_config_sql_no_config(): - assert ( - ModelConfig( - sql="""{{ - config( - materialized='table', - incremental_strategy='delete+"insert' - ) -}} -query""" - ).sql_no_config.strip() - == "query" - ) - - assert ( - ModelConfig( - sql="""{{ - config( - materialized='table', - incremental_strategy='delete+insert', - post_hook=" '{{ var('new') }}' " - ) -}} -query""" - ).sql_no_config.strip() - == "query" - ) - - assert ( - ModelConfig( - sql="""before {{config(materialized='table', post_hook=" {{ var('new') }} ")}} after""" - ).sql_no_config.strip() - == "before after" - ) - - @pytest.mark.slow def test_variables(assert_exp_eq, sushi_test_project): # Case 1: using an undefined variable without a default value @@ -350,7 +314,6 @@ def test_variables(assert_exp_eq, sushi_test_project): # Case 3: using a defined variable with a default value model_config.sql = "SELECT {{ var('foo', 5) }}" - model_config._sql_no_config = None assert_exp_eq(model_config.to_sqlmesh(**kwargs).render_query(), 'SELECT 6 AS "6"') diff --git a/tests/fixtures/dbt/sushi_test/models/model_with_raw_code.sql b/tests/fixtures/dbt/sushi_test/models/model_with_raw_code.sql new file mode 100644 index 0000000000..386e7f40ef --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/models/model_with_raw_code.sql @@ -0,0 +1,11 @@ +{{ + config( + pre_hook=['CREATE TABLE t AS SELECT \'Length is {{ model.raw_code|length }}\' AS length_col'] + ) +}} + +{{ check_model_is_table(model) }} +{{ check_model_is_table_alt(model) }} + +SELECT + 1 AS c diff --git a/tests/fixtures/dbt/sushi_test/models/simple_model_a.sql b/tests/fixtures/dbt/sushi_test/models/simple_model_a.sql index c5a0b90fa6..e9441e35da 100644 --- a/tests/fixtures/dbt/sushi_test/models/simple_model_a.sql +++ b/tests/fixtures/dbt/sushi_test/models/simple_model_a.sql @@ -1,7 +1,2 @@ -{{ check_model_is_table(model) }} - -{% if 'SELECT' in model.raw_code %} - {{ check_model_is_table_alt(model) }} -{% endif %} SELECT 1 AS a From 0c1829946f00743219e292b87abd0040ae39f728 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Thu, 4 Sep 2025 03:02:35 +0300 Subject: [PATCH 6/6] PR feedback --- sqlmesh/core/model/definition.py | 15 ++------------- sqlmesh/core/renderer.py | 10 +++++----- sqlmesh/dbt/builtin.py | 16 +++++++--------- .../v0095_warn_about_dbt_raw_sql_diff.py | 12 ++++++------ 4 files changed, 20 insertions(+), 33 deletions(-) diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index 733fd1530a..b772a9d9d9 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -594,8 +594,7 @@ def _statement_renderer(self, expression: exp.Expression) -> ExpressionRenderer: python_env=self.python_env, only_execution_time=False, default_catalog=self.default_catalog, - model_fqn=self.fqn, - raw_code=self._raw_code, + model=self, ) return self._statement_renderer_cache[expression_key] @@ -1306,10 +1305,6 @@ def _is_time_column_in_partitioned_by(self) -> bool: def violated_rules_for_query(self) -> t.Dict[type[Rule], t.Any]: return {} - @property - def _raw_code(self) -> t.Optional[str]: - return None - class SqlModel(_Model): """The model definition which relies on a SQL query to fetch the data. @@ -1578,7 +1573,6 @@ def _query_renderer(self) -> QueryRenderer: self.dialect, self.macro_definitions, schema=self.mapping_schema, - model_fqn=self.fqn, path=self._path, jinja_macro_registry=self.jinja_macros, python_env=self.python_env, @@ -1586,7 +1580,7 @@ def _query_renderer(self) -> QueryRenderer: default_catalog=self.default_catalog, quote_identifiers=not no_quote_identifiers, optimize_query=self.optimize_query, - raw_code=self._raw_code, + model=self, ) @property @@ -1612,11 +1606,6 @@ def violated_rules_for_query(self) -> t.Dict[type[Rule], t.Any]: self.render_query() return self._query_renderer._violated_rules - @property - def _raw_code(self) -> t.Optional[str]: - query = self.query - return query.name if isinstance(query, d.JinjaQuery) else None - class SeedModel(_Model): """The model definition which uses a pre-built static dataset to source the data from. diff --git a/sqlmesh/core/renderer.py b/sqlmesh/core/renderer.py index 86f0ef9890..49144bf55c 100644 --- a/sqlmesh/core/renderer.py +++ b/sqlmesh/core/renderer.py @@ -31,6 +31,7 @@ from sqlglot.dialects.dialect import DialectType from sqlmesh.core.linter.rule import Rule + from sqlmesh.core.model.definition import _Model from sqlmesh.core.snapshot import DeployabilityIndex, Snapshot @@ -50,10 +51,9 @@ def __init__( schema: t.Optional[t.Dict[str, t.Any]] = None, default_catalog: t.Optional[str] = None, quote_identifiers: bool = True, - model_fqn: t.Optional[str] = None, normalize_identifiers: bool = True, optimize_query: t.Optional[bool] = True, - raw_code: t.Optional[str] = None, + model: t.Optional[_Model] = None, ): self._expression = expression self._dialect = dialect @@ -67,9 +67,9 @@ def __init__( self._quote_identifiers = quote_identifiers self.update_schema({} if schema is None else schema) self._cache: t.List[t.Optional[exp.Expression]] = [] - self._model_fqn = model_fqn + self._model_fqn = model.fqn if model else None self._optimize_query_flag = optimize_query is not False - self._raw_code = raw_code + self._model = model def update_schema(self, schema: t.Dict[str, t.Any]) -> None: self.schema = d.normalize_mapping_schema(schema, dialect=self._dialect) @@ -206,7 +206,7 @@ def _resolve_table(table: str | exp.Table) -> str: "default_catalog": self._default_catalog, "runtime_stage": runtime_stage.value, "resolve_table": _resolve_table, - "raw_code": self._raw_code, + "model_instance": self._model, } if this_model: diff --git a/sqlmesh/dbt/builtin.py b/sqlmesh/dbt/builtin.py index a64d0f9852..0a2d837c28 100644 --- a/sqlmesh/dbt/builtin.py +++ b/sqlmesh/dbt/builtin.py @@ -16,6 +16,7 @@ from sqlmesh.core.console import get_console from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.model.definition import SqlModel from sqlmesh.core.snapshot.definition import DeployabilityIndex from sqlmesh.dbt.adapter import BaseAdapter, ParsetimeAdapter, RuntimeAdapter from sqlmesh.dbt.common import RAW_CODE_KEY @@ -27,13 +28,6 @@ from sqlmesh.utils.errors import ConfigError, MacroEvalError from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroReference, MacroReturnVal -if t.TYPE_CHECKING: - from typing import Protocol - - class Model(Protocol): - def __getattr__(self, key: str) -> t.Any: ... - - logger = logging.getLogger(__name__) @@ -485,8 +479,12 @@ def create_builtin_globals( ) if (model := jinja_globals.pop("model", None)) is not None: - raw_code = jinja_globals.pop("raw_code", "") - builtin_globals["model"] = AttributeDict({**model, RAW_CODE_KEY: raw_code}) + if isinstance(model_instance := jinja_globals.pop("model_instance", None), SqlModel): + builtin_globals["model"] = AttributeDict( + {**model, RAW_CODE_KEY: model_instance.query.name} + ) + else: + builtin_globals["model"] = AttributeDict(model.copy()) if engine_adapter is not None: builtin_globals["flags"] = Flags(which="run") diff --git a/sqlmesh/migrations/v0095_warn_about_dbt_raw_sql_diff.py b/sqlmesh/migrations/v0095_warn_about_dbt_raw_sql_diff.py index be005de117..ce39946b0d 100644 --- a/sqlmesh/migrations/v0095_warn_about_dbt_raw_sql_diff.py +++ b/sqlmesh/migrations/v0095_warn_about_dbt_raw_sql_diff.py @@ -25,12 +25,12 @@ def migrate(state_sync, **kwargs): # type: ignore snapshots_table = f"{schema}.{snapshots_table}" warning = ( - "SQLMesh now includes dbt's {{ config(...) }} blocks in the model's raw SQL when " - "processing dbt models. This change ensures that all model attributes referenced " - "in macros are properly tracked for fingerprinting. As a result, you may see diffs " - "for existing dbt models even though the actual SQL logic hasn't changed. This is " - "a one-time diff that will be resolved after applying a plan. Run 'sqlmesh diff prod' " - "to review any changes, then apply a plan if the diffs look expected." + "SQLMesh detected that it may not be able to fully migrate the state database. This should not impact " + "the migration process, but may result in unexpected changes being reported by the next `sqlmesh plan` " + "command. Please run `sqlmesh diff prod` after the migration has completed, before making any new " + "changes. If any unexpected changes are reported, consider running a forward-only plan to apply these " + "changes and avoid unnecessary backfills: sqlmesh plan prod --forward-only. " + "See https://sqlmesh.readthedocs.io/en/stable/concepts/plans/#forward-only-plans for more details.\n" ) for (snapshot,) in engine_adapter.fetchall(