Skip to content

Commit 9450d29

Browse files
committed
Fix: Improve tracking of var dependencies in dbt models
1 parent 4d8e831 commit 9450d29

File tree

13 files changed

+238
-32
lines changed

13 files changed

+238
-32
lines changed

sqlmesh/dbt/adapter.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,56 @@ def _raise_parsetime_adapter_call_error(action: str) -> None:
233233
raise ParsetimeAdapterCallError(f"Can't {action} at parse time.")
234234

235235

236+
class StubParsetimeAdapter(BaseAdapter):
237+
"""Same as ParsetimeAdapter, but returns stub / empty values instead of raising an error."""
238+
239+
def get_relation(self, database: str, schema: str, identifier: str) -> t.Optional[BaseRelation]:
240+
return None
241+
242+
def load_relation(self, relation: BaseRelation) -> t.Optional[BaseRelation]:
243+
return None
244+
245+
def list_relations(self, database: t.Optional[str], schema: str) -> t.List[BaseRelation]:
246+
return []
247+
248+
def list_relations_without_caching(self, schema_relation: BaseRelation) -> t.List[BaseRelation]:
249+
return []
250+
251+
def get_columns_in_relation(self, relation: BaseRelation) -> t.List[Column]:
252+
return []
253+
254+
def get_missing_columns(
255+
self, from_relation: BaseRelation, to_relation: BaseRelation
256+
) -> t.List[Column]:
257+
return []
258+
259+
def create_schema(self, relation: BaseRelation) -> None:
260+
pass
261+
262+
def drop_schema(self, relation: BaseRelation) -> None:
263+
pass
264+
265+
def drop_relation(self, relation: BaseRelation) -> None:
266+
pass
267+
268+
def rename_relation(self, from_relation: BaseRelation, to_relation: BaseRelation) -> None:
269+
pass
270+
271+
def execute(
272+
self, sql: str, auto_begin: bool = False, fetch: bool = False
273+
) -> t.Tuple[AdapterResponse, agate.Table]:
274+
from dbt.adapters.base.impl import AdapterResponse
275+
from sqlmesh.dbt.util import empty_table
276+
277+
return AdapterResponse(""), empty_table()
278+
279+
def resolve_schema(self, relation: BaseRelation) -> t.Optional[str]:
280+
return relation.schema
281+
282+
def resolve_identifier(self, relation: BaseRelation) -> t.Optional[str]:
283+
return relation.identifier
284+
285+
236286
class RuntimeAdapter(BaseAdapter):
237287
def __init__(
238288
self,

sqlmesh/dbt/basemodel.py

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -293,36 +293,23 @@ def sqlmesh_model_kwargs(
293293
self,
294294
context: DbtContext,
295295
column_types_override: t.Optional[t.Dict[str, ColumnConfig]] = None,
296+
extra_dependencies: t.Optional[Dependencies] = None,
296297
) -> t.Dict[str, t.Any]:
297298
"""Get common sqlmesh model parameters"""
298299
self.check_for_circular_test_refs(context)
300+
301+
dependencies = self.dependencies
302+
if extra_dependencies:
303+
dependencies = dependencies.union(extra_dependencies)
304+
299305
model_dialect = self.dialect(context)
300306
model_context = context.context_for_dependencies(
301-
self.dependencies.union(self.tests_ref_source_dependencies)
307+
dependencies.union(self.tests_ref_source_dependencies)
302308
)
303309
jinja_macros = model_context.jinja_macros.trim(
304-
self.dependencies.macros, package=self.package_name
305-
)
306-
307-
model_node: AttributeDict[str, t.Any] = AttributeDict(
308-
{
309-
k: v
310-
for k, v in context._manifest._manifest.nodes[self.node_name].to_dict().items()
311-
if k in self.dependencies.model_attrs
312-
}
313-
if context._manifest and self.node_name in context._manifest._manifest.nodes
314-
else {}
315-
)
316-
317-
jinja_macros.add_globals(
318-
{
319-
"this": self.relation_info,
320-
"model": model_node,
321-
"schema": self.table_schema,
322-
"config": self.config_attribute_dict,
323-
**model_context.jinja_globals, # type: ignore
324-
}
310+
dependencies.macros, package=self.package_name
325311
)
312+
jinja_macros.add_globals(self._model_jinja_context(model_context, dependencies))
326313
return {
327314
"audits": [(test.name, {}) for test in self.tests],
328315
"columns": column_types_to_sqlmesh(
@@ -352,3 +339,28 @@ def to_sqlmesh(
352339
virtual_environment_mode: VirtualEnvironmentMode = VirtualEnvironmentMode.default,
353340
) -> Model:
354341
"""Convert DBT model into sqlmesh Model"""
342+
343+
def _model_jinja_context(
344+
self, context: DbtContext, dependencies: Dependencies
345+
) -> t.Dict[str, t.Any]:
346+
model_node: AttributeDict[str, t.Any] = AttributeDict(
347+
{
348+
k: v
349+
for k, v in context._manifest._manifest.nodes[self.node_name].to_dict().items()
350+
if k in dependencies.model_attrs
351+
}
352+
if context._manifest and self.node_name in context._manifest._manifest.nodes
353+
else {}
354+
)
355+
return {
356+
"this": self.relation_info,
357+
"model": model_node,
358+
"schema": self.table_schema,
359+
"config": self.config_attribute_dict,
360+
**context.jinja_globals,
361+
}
362+
363+
def _track_dependencies_on_render(self, input: str, context: DbtContext) -> Dependencies:
364+
return context.track_dependencies_on_render(
365+
input, self._model_jinja_context(context, self.dependencies), self.package_name
366+
)

sqlmesh/dbt/builtin.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from sqlmesh.core.console import get_console
1818
from sqlmesh.core.engine_adapter import EngineAdapter
1919
from sqlmesh.core.snapshot.definition import DeployabilityIndex
20-
from sqlmesh.dbt.adapter import BaseAdapter, ParsetimeAdapter, RuntimeAdapter
20+
from sqlmesh.dbt.adapter import BaseAdapter, ParsetimeAdapter, RuntimeAdapter, StubParsetimeAdapter
2121
from sqlmesh.dbt.relation import Policy
2222
from sqlmesh.dbt.target import TARGET_TYPE_TO_CONFIG_CLASS
2323
from sqlmesh.dbt.util import DBT_VERSION
@@ -384,15 +384,15 @@ def create_builtin_globals(
384384
builtin_globals["this"] = this
385385

386386
sources = jinja_globals.pop("sources", None)
387-
if sources is not None:
387+
if sources is not None and "source" not in jinja_globals:
388388
builtin_globals["source"] = generate_source(sources, api)
389389

390390
refs = jinja_globals.pop("refs", None)
391-
if refs is not None:
391+
if refs is not None and "ref" not in jinja_globals:
392392
builtin_globals["ref"] = generate_ref(refs, api)
393393

394394
variables = jinja_globals.pop("vars", None)
395-
if variables is not None:
395+
if variables is not None and "var" not in jinja_globals:
396396
builtin_globals["var"] = Var(variables)
397397

398398
deployability_index = (
@@ -415,6 +415,7 @@ def create_builtin_globals(
415415
{k: builtin_globals.get(k) for k in ("ref", "source", "config", "var")}
416416
)
417417

418+
execute = True
418419
if engine_adapter is not None:
419420
builtin_globals["flags"] = Flags(which="run")
420421
adapter: BaseAdapter = RuntimeAdapter(
@@ -435,7 +436,11 @@ def create_builtin_globals(
435436
)
436437
else:
437438
builtin_globals["flags"] = Flags(which="parse")
438-
adapter = ParsetimeAdapter(
439+
adapter_class: t.Type[BaseAdapter] = ParsetimeAdapter
440+
if jinja_globals.get("use_stub_adapter", False):
441+
adapter_class = StubParsetimeAdapter
442+
execute = False
443+
adapter = adapter_class(
439444
jinja_macros,
440445
jinja_globals={**builtin_globals, **jinja_globals},
441446
project_dialect=project_dialect,
@@ -446,7 +451,7 @@ def create_builtin_globals(
446451
builtin_globals.update(
447452
{
448453
"adapter": adapter,
449-
"execute": True,
454+
"execute": execute,
450455
"load_relation": adapter.load_relation,
451456
"store_result": sql_execution.store_result,
452457
"load_result": sql_execution.load_result,

sqlmesh/dbt/common.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,13 +184,16 @@ class Dependencies(PydanticModel):
184184
variables: t.Set[str] = set()
185185
model_attrs: t.Set[str] = set()
186186

187+
has_dynamic_var_names: bool = False
188+
187189
def union(self, other: Dependencies) -> Dependencies:
188190
return Dependencies(
189191
macros=list(set(self.macros) | set(other.macros)),
190192
sources=self.sources | other.sources,
191193
refs=self.refs | other.refs,
192194
variables=self.variables | other.variables,
193195
model_attrs=self.model_attrs | other.model_attrs,
196+
has_dynamic_var_names=self.has_dynamic_var_names or other.has_dynamic_var_names,
194197
)
195198

196199
@field_validator("macros", mode="after")

sqlmesh/dbt/context.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from dbt.adapters.base import BaseRelation
88

99
from sqlmesh.core.config import Config as SQLMeshConfig
10-
from sqlmesh.dbt.builtin import _relation_info_to_relation
10+
from sqlmesh.dbt.builtin import _relation_info_to_relation, Var
11+
from sqlmesh.dbt.common import Dependencies
1112
from sqlmesh.dbt.manifest import ManifestHelper
1213
from sqlmesh.dbt.target import TargetConfig
1314
from sqlmesh.utils import AttributeDict
@@ -22,7 +23,6 @@
2223
if t.TYPE_CHECKING:
2324
from jinja2 import Environment
2425

25-
from sqlmesh.dbt.basemodel import Dependencies
2626
from sqlmesh.dbt.model import ModelConfig
2727
from sqlmesh.dbt.relation import Policy
2828
from sqlmesh.dbt.seed import SeedConfig
@@ -212,6 +212,38 @@ def target(self, value: TargetConfig) -> None:
212212
def render(self, source: str, **kwargs: t.Any) -> str:
213213
return self.jinja_environment.from_string(source).render(**kwargs)
214214

215+
def track_dependencies_on_render(
216+
self, input: str, jinja_context: t.Dict[str, t.Any], package_name: t.Optional[str] = None
217+
) -> Dependencies:
218+
dependencies_on_render = Dependencies()
219+
220+
class TrackingVar(Var):
221+
def __call__(
222+
self, name: str, default: t.Optional[t.Any] = None, **kwargs: t.Any
223+
) -> t.Any:
224+
dependencies_on_render.variables.add(name)
225+
return super().__call__(name, default, **kwargs)
226+
227+
def has_var(self, name: str) -> bool:
228+
dependencies_on_render.variables.add(name)
229+
return super().has_var(name)
230+
231+
if package_name:
232+
top_level_packages = [*self.jinja_macros.top_level_packages, package_name]
233+
jinja_macros = self.jinja_macros.copy(update={"top_level_packages": top_level_packages})
234+
else:
235+
jinja_macros = self.jinja_macros
236+
237+
jinja_environment = jinja_macros.build_environment(
238+
**{
239+
**jinja_context,
240+
"var": TrackingVar(self.variables),
241+
"use_stub_adapter": True,
242+
}
243+
)
244+
jinja_environment.from_string(input).render()
245+
return dependencies_on_render
246+
215247
def get_callable_macro(
216248
self, name: str, package: t.Optional[str] = None
217249
) -> t.Optional[t.Callable]:

sqlmesh/dbt/manifest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,9 @@ def _extra_dependencies(self, target: str, package: str) -> Dependencies:
550550
args = [jinja_call_arg_name(arg) for arg in node.args]
551551
if args and args[0]:
552552
dependencies.variables.add(args[0])
553+
else:
554+
# We couldn't determine the var name statically
555+
dependencies.has_dynamic_var_names = True
553556
dependencies.macros.append(MacroReference(name="var"))
554557
elif len(call_name) == 1:
555558
macro_name = call_name[0]

sqlmesh/dbt/model.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
)
2626
from sqlmesh.core.model.kind import SCDType2ByTimeKind, OnDestructiveChange, OnAdditiveChange
2727
from sqlmesh.dbt.basemodel import BaseModelConfig, Materialization, SnapshotStrategy
28-
from sqlmesh.dbt.common import SqlStr, extract_jinja_config, sql_str_validator
28+
from sqlmesh.dbt.column import ColumnConfig
29+
from sqlmesh.dbt.common import SqlStr, extract_jinja_config, sql_str_validator, Dependencies
2930
from sqlmesh.utils.errors import ConfigError
3031
from sqlmesh.utils.pydantic import field_validator
3132

@@ -436,6 +437,30 @@ def sqlmesh_config_fields(self) -> t.Set[str]:
436437
"physical_version",
437438
}
438439

440+
def sqlmesh_model_kwargs(
441+
self,
442+
context: DbtContext,
443+
column_types_override: t.Optional[t.Dict[str, ColumnConfig]] = None,
444+
extra_dependencies: t.Optional[Dependencies] = None,
445+
) -> t.Dict[str, t.Any]:
446+
if not self.dependencies.has_dynamic_var_names:
447+
return super().sqlmesh_model_kwargs(context, column_types_override, extra_dependencies)
448+
449+
extra_dependencies = extra_dependencies or Dependencies()
450+
extra_dependencies = extra_dependencies.union(
451+
self._track_dependencies_on_render(self.sql_no_config, context)
452+
)
453+
for pre_hook in self.pre_hook:
454+
extra_dependencies = extra_dependencies.union(
455+
self._track_dependencies_on_render(pre_hook.sql, context)
456+
)
457+
for post_hook in self.post_hook:
458+
extra_dependencies = extra_dependencies.union(
459+
self._track_dependencies_on_render(post_hook.sql, context)
460+
)
461+
462+
return super().sqlmesh_model_kwargs(context, column_types_override, extra_dependencies)
463+
439464
def to_sqlmesh(
440465
self,
441466
context: DbtContext,

tests/dbt/test_config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,7 @@ def test_variables(assert_exp_eq, sushi_test_project):
362362
"nested_vars": {
363363
"some_nested_var": 2,
364364
},
365+
"dynamic_test_var": 3,
365366
"list_var": [
366367
{"name": "item1", "value": 1},
367368
{"name": "item2", "value": 2},
@@ -385,6 +386,7 @@ def test_variables(assert_exp_eq, sushi_test_project):
385386
"nested_vars": {
386387
"some_nested_var": 2,
387388
},
389+
"dynamic_test_var": 3,
388390
"list_var": [
389391
{"name": "item1", "value": 1},
390392
{"name": "item2", "value": 2},

tests/dbt/test_manifest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def test_manifest_helper(caplog):
7979
waiter_revenue_by_day_config = models["waiter_revenue_by_day_v2"]
8080
assert waiter_revenue_by_day_config.dependencies == Dependencies(
8181
macros={
82+
MacroReference(name="dynamic_var_name_dependency"),
8283
MacroReference(name="log_value"),
8384
MacroReference(name="test_dependencies"),
8485
MacroReference(package="customers", name="duckdb__current_engine"),
@@ -87,6 +88,7 @@ def test_manifest_helper(caplog):
8788
},
8889
sources={"streaming.items", "streaming.orders", "streaming.order_items"},
8990
variables={"yet_another_var", "nested_vars"},
91+
has_dynamic_var_names=True,
9092
)
9193
assert waiter_revenue_by_day_config.materialized == "incremental"
9294
assert waiter_revenue_by_day_config.incremental_strategy == "delete+insert"

0 commit comments

Comments
 (0)