Skip to content

Commit fd4ec21

Browse files
committed
Fix!: depend on all attributes of dbt model when passed to a macro
1 parent 61a65ac commit fd4ec21

File tree

8 files changed

+134
-36
lines changed

8 files changed

+134
-36
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
{#- Raise a compiler error unless the given model's config uses the 'table' materialization. -#}
{%- macro check_model_is_table(model) -%}
2+
{%- if model.config.materialized != 'table' -%}
3+
{%- do exceptions.raise_compiler_error(
4+
"Model must use the table materialization. Please check any model overrides."
5+
) -%}
6+
{%- endif -%}
7+
{%- endmacro -%}
8+
9+
{#- Same check as check_model_is_table, but the model is received under the parameter name `foo`. -#}
{#- NOTE(review): presumably exists to exercise the case where `model` is aliased inside a macro; confirm. -#}
{%- macro check_model_is_table_alt(foo) -%}
10+
{%- if foo.config.materialized != 'table' -%}
11+
{%- do exceptions.raise_compiler_error(
12+
"Model must use the table materialization. Please check any model overrides."
13+
) -%}
14+
{%- endif -%}
15+
{%- endmacro -%}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
{{ check_model_is_table(model) }}
2+
3+
{% if 'DISTINCT' in model.raw_code %}
4+
{{ check_model_is_table_alt(model) }}
5+
{% endif %}
6+
17
SELECT DISTINCT
28
customer_id::INT AS customer_id
39
FROM {{ ref('orders') }} as o

sqlmesh/core/renderer.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -188,30 +188,32 @@ def _resolve_table(table: str | exp.Table) -> str:
188188
}
189189

190190
variables = kwargs.pop("variables", {})
191-
jinja_env_kwargs = {
192-
**{
193-
**render_kwargs,
194-
**_prepare_python_env_for_jinja(macro_evaluator, self._python_env),
195-
**variables,
196-
},
197-
"snapshots": snapshots or {},
198-
"table_mapping": table_mapping,
199-
"deployability_index": deployability_index,
200-
"default_catalog": self._default_catalog,
201-
"runtime_stage": runtime_stage.value,
202-
"resolve_table": _resolve_table,
203-
}
204-
if this_model:
205-
render_kwargs["this_model"] = this_model
206-
jinja_env_kwargs["this_model"] = this_model.sql(
207-
dialect=self._dialect, identify=True, comments=False
208-
)
209-
210-
jinja_env = self._jinja_macro_registry.build_environment(**jinja_env_kwargs)
211191

212192
expressions = [self._expression]
213193
if isinstance(self._expression, d.Jinja):
214194
try:
195+
jinja_env_kwargs = {
196+
**{
197+
**render_kwargs,
198+
**_prepare_python_env_for_jinja(macro_evaluator, self._python_env),
199+
**variables,
200+
},
201+
"snapshots": snapshots or {},
202+
"table_mapping": table_mapping,
203+
"deployability_index": deployability_index,
204+
"default_catalog": self._default_catalog,
205+
"runtime_stage": runtime_stage.value,
206+
"resolve_table": _resolve_table,
207+
"raw_code": self._expression.name,
208+
}
209+
210+
if this_model:
211+
jinja_env_kwargs["this_model"] = this_model.sql(
212+
dialect=self._dialect, identify=True, comments=False
213+
)
214+
215+
jinja_env = self._jinja_macro_registry.build_environment(**jinja_env_kwargs)
216+
215217
expressions = []
216218
rendered_expression = jinja_env.from_string(self._expression.name).render()
217219
logger.debug(
@@ -229,6 +231,9 @@ def _resolve_table(table: str | exp.Table) -> str:
229231
f"Could not render or parse jinja at '{self._path}'.\n{ex}"
230232
) from ex
231233

234+
if this_model:
235+
render_kwargs["this_model"] = this_model
236+
232237
macro_evaluator.locals.update(render_kwargs)
233238

234239
if variables:

sqlmesh/dbt/basemodel.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
column_types_to_sqlmesh,
2020
)
2121
from sqlmesh.dbt.common import (
22+
DBT_ALL_MODEL_ATTRS,
2223
DbtConfig,
2324
Dependencies,
2425
GeneralConfig,
@@ -27,6 +28,7 @@
2728
)
2829
from sqlmesh.dbt.relation import Policy, RelationType
2930
from sqlmesh.dbt.test import TestConfig
31+
from sqlmesh.dbt.util import DBT_VERSION
3032
from sqlmesh.utils import AttributeDict
3133
from sqlmesh.utils.errors import ConfigError
3234
from sqlmesh.utils.pydantic import field_validator
@@ -375,15 +377,23 @@ def to_sqlmesh(
375377
def _model_jinja_context(
376378
self, context: DbtContext, dependencies: Dependencies
377379
) -> t.Dict[str, t.Any]:
378-
model_node: AttributeDict[str, t.Any] = AttributeDict(
379-
{
380-
k: v
381-
for k, v in context._manifest._manifest.nodes[self.node_name].to_dict().items()
382-
if k in dependencies.model_attrs
383-
}
384-
if context._manifest and self.node_name in context._manifest._manifest.nodes
385-
else {}
386-
)
380+
if context._manifest and self.node_name in context._manifest._manifest.nodes:
381+
attributes = context._manifest._manifest.nodes[self.node_name].to_dict()
382+
if DBT_ALL_MODEL_ATTRS in dependencies.model_attrs:
383+
model_node: AttributeDict[str, t.Any] = AttributeDict(attributes)
384+
else:
385+
model_node = AttributeDict(
386+
filter(lambda kv: kv[0] in dependencies.model_attrs, attributes.items())
387+
)
388+
389+
raw_code_key = "raw_code" if DBT_VERSION >= (1, 3, 0) else "raw_sql" # type: ignore
390+
391+
# We exclude the raw SQL code to reduce the payload size. It's still accessible through
392+
# the JinjaQuery instance stored in the resulting SQLMesh model's `query` field.
393+
model_node.pop(raw_code_key, None)
394+
else:
395+
model_node = AttributeDict({})
396+
387397
return {
388398
"this": self.relation_info,
389399
"model": model_node,

sqlmesh/dbt/builtin.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@
2626
from sqlmesh.utils.errors import ConfigError, MacroEvalError
2727
from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroReference, MacroReturnVal
2828

29+
if t.TYPE_CHECKING:
30+
from typing import Protocol
31+
32+
# Typing-only stand-in used as the return annotation of `generate_model`: it declares that
# any attribute lookup on the returned object is allowed.
class Model(Protocol):
33+
def __getattr__(self, key: str) -> t.Any: ...
34+
35+
2936
logger = logging.getLogger(__name__)
3037

3138

@@ -249,6 +256,21 @@ def source(package: str, name: str) -> t.Optional[BaseRelation]:
249256
return source
250257

251258

259+
def generate_model(model: AttributeDict, raw_code: str) -> Model:
    """Wrap a dbt model node so that its raw code attribute is served from `raw_code`.

    Args:
        model: Attributes of the dbt model node. Every attribute other than the raw code
            is looked up on this object.
        raw_code: Value returned for the raw code attribute, which is named `raw_code`,
            or `raw_sql` when DBT_VERSION is older than 1.3.0.

    Returns:
        An object that exposes the attributes of `model` plus the raw code attribute.
    """

    class Model:
        def __init__(self, model: AttributeDict) -> None:
            self._model = model
            self._raw_code_key = "raw_code" if DBT_VERSION >= (1, 3, 0) else "raw_sql"  # type: ignore

        def __getattr__(self, key: str) -> t.Any:
            # __getattr__ only runs after regular lookup has failed, so being asked for one
            # of our own attributes means __init__ has not run (e.g. copy / pickle rebuild
            # the instance without calling it). Fail fast instead of recursing forever.
            if key in ("_model", "_raw_code_key"):
                raise AttributeError(key)

            if key == self._raw_code_key:
                return raw_code

            return getattr(self._model, key)

    return Model(model)
272+
273+
252274
def return_val(val: t.Any) -> None:
253275
raise MacroReturnVal(val)
254276

@@ -415,12 +437,16 @@ def create_builtin_globals(
415437
is_incremental &= snapshot_table_exists
416438
else:
417439
is_incremental = False
440+
418441
builtin_globals["is_incremental"] = lambda: is_incremental
419442

420443
builtin_globals["builtins"] = AttributeDict(
421444
{k: builtin_globals.get(k) for k in ("ref", "source", "config", "var")}
422445
)
423446

447+
if (model := jinja_globals.pop("model", None)) is not None:
    # "model" was already consumed by the pop above, so popping it a second time would
    # always fall back to "" and leave the model's raw code empty. The renderer passes the
    # raw code to the environment under the "raw_code" key, so read it from there.
    builtin_globals["model"] = generate_model(model, jinja_globals.pop("raw_code", ""))
449+
424450
if engine_adapter is not None:
425451
builtin_globals["flags"] = Flags(which="run")
426452
adapter: BaseAdapter = RuntimeAdapter(

sqlmesh/dbt/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
T = t.TypeVar("T", bound="GeneralConfig")
2020

2121
PROJECT_FILENAME = DBT_PROJECT_FILENAME
22+
DBT_ALL_MODEL_ATTRS = "__DBT_ALL_MODEL_ATTRS__"
2223

2324
JINJA_ONLY = {
2425
"adapter",

sqlmesh/dbt/manifest.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@
4444
from sqlmesh.core import constants as c
4545
from sqlmesh.utils.errors import SQLMeshError
4646
from sqlmesh.core.config import ModelDefaultsConfig
47-
from sqlmesh.dbt.basemodel import Dependencies
4847
from sqlmesh.dbt.builtin import BUILTIN_FILTERS, BUILTIN_GLOBALS, OVERRIDDEN_MACROS
48+
from sqlmesh.dbt.common import DBT_ALL_MODEL_ATTRS, Dependencies
4949
from sqlmesh.dbt.model import ModelConfig
5050
from sqlmesh.dbt.package import HookConfig, MacroConfig
5151
from sqlmesh.dbt.seed import SeedConfig
@@ -354,7 +354,9 @@ def _load_models_and_seeds(self) -> None:
354354
dependencies = Dependencies(
355355
macros=macro_references, refs=_refs(node), sources=_sources(node)
356356
)
357-
dependencies = dependencies.union(self._extra_dependencies(sql, node.package_name))
357+
dependencies = dependencies.union(
358+
self._extra_dependencies(sql, node.package_name, track_all_model_attrs=True)
359+
)
358360
dependencies = dependencies.union(
359361
self._flatten_dependencies_from_macros(dependencies.macros, node.package_name)
360362
)
@@ -548,15 +550,35 @@ def _flatten_dependencies_from_macros(
548550
dependencies = dependencies.union(macro_dependencies)
549551
return dependencies
550552

551-
def _extra_dependencies(self, target: str, package: str) -> Dependencies:
552-
# We sometimes observe that the manifest doesn't capture all macros, refs, and sources within a macro.
553-
# This behavior has been observed with macros like dbt.current_timestamp(), dbt_utils.slugify(), and source().
554-
# Here we apply our custom extractor to make a best effort to supplement references captured in the manifest.
553+
def _extra_dependencies(
554+
self,
555+
target: str,
556+
package: str,
557+
track_all_model_attrs: bool = False,
558+
) -> Dependencies:
559+
"""
560+
We sometimes observe that the manifest doesn't capture all macros, refs, and sources within a macro.
561+
This behavior has been observed with macros like dbt.current_timestamp(), dbt_utils.slugify(), and source().
562+
Here we apply our custom extractor to make a best effort to supplement references captured in the manifest.
563+
"""
555564
dependencies = Dependencies()
565+
566+
# Whether all `model` attributes (e.g., `model.config`) should be included in the dependencies
567+
all_model_attrs = False
568+
556569
for call_name, node in extract_call_names(target, cache=self._calls):
557570
if call_name[0] == "config":
558571
continue
559-
elif isinstance(node, jinja2.nodes.Getattr):
572+
573+
if (
574+
track_all_model_attrs
575+
and not all_model_attrs
576+
and isinstance(node, jinja2.nodes.Call)
577+
and any(isinstance(a, jinja2.nodes.Name) and a.name == "model" for a in node.args)
578+
):
579+
all_model_attrs = True
580+
581+
if isinstance(node, jinja2.nodes.Getattr):
560582
if call_name[0] == "model":
561583
dependencies.model_attrs.add(call_name[1])
562584
elif call_name[0] == "source":
@@ -602,6 +624,14 @@ def _extra_dependencies(self, target: str, package: str) -> Dependencies:
602624
call_name[0], call_name[1], dependencies.macros.append
603625
)
604626

627+
# When `model` is referenced as-is, e.g. it's passed as an argument to a macro call like
628+
# `{{ foo(model) }}`, we can't easily track the attributes that are actually used, because
629+
# it may be aliased and hence tracking actual uses of `model` requires a proper data flow
630+
# analysis. We conservatively deal with this by including all of its supported attributes
631+
# if a standalone reference is found.
632+
if all_model_attrs:
633+
dependencies.model_attrs = {DBT_ALL_MODEL_ATTRS}
634+
605635
return dependencies
606636

607637

tests/core/test_context.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1533,6 +1533,11 @@ def test_plan_enable_preview_default(sushi_context: Context, sushi_dbt_context:
15331533
assert sushi_dbt_context._plan_preview_enabled
15341534

15351535

1536+
def test_raw_code_missing_from_model_attributes(sushi_dbt_context: Context):
    """The raw SQL must not be carried in the model attributes exposed to Jinja."""
    customers_model = sushi_dbt_context.models['"memory"."sushi"."customers"']
    model_attrs = customers_model.jinja_macros.global_objs["model"]

    # The attribute is named `raw_sql` instead of `raw_code` on dbt versions older than
    # 1.3.0, so check both names to keep the test meaningful on any installed version.
    assert "raw_code" not in model_attrs
    assert "raw_sql" not in model_attrs
1539+
1540+
15361541
def test_catalog_name_needs_to_be_quoted():
15371542
config = Config(
15381543
model_defaults=ModelDefaultsConfig(dialect="duckdb"),

0 commit comments

Comments
 (0)