Skip to content

Commit fe6f377

Browse files
committed
Fix: Support dbt macro dispatch search order
1 parent 71f3eb7 commit fe6f377

File tree

14 files changed

+71
-30
lines changed

14 files changed

+71
-30
lines changed

sqlmesh/dbt/adapter.py

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -122,48 +122,30 @@ def dispatch(
122122
) -> t.Callable:
123123
"""Returns a dialect-specific version of a macro with the given name."""
124124
target_type = self.jinja_globals["target"]["type"]
125-
macro_suffix = f"__{macro_name}"
126-
127-
def _relevance(package_name_pair: t.Tuple[t.Optional[str], str]) -> t.Tuple[int, int]:
128-
"""Lower scores more relevant."""
129-
macro_package, name = package_name_pair
130-
131-
package_score = 0 if macro_package == macro_namespace else 1
132-
name_score = 1
133-
134-
if name.startswith("default"):
135-
name_score = 2
136-
elif name.startswith(target_type):
137-
name_score = 0
138-
139-
return name_score, package_score
140125

141126
jinja_env = self.jinja_macros.build_environment(**self.jinja_globals).globals
142127

143128
packages_to_check: t.List[t.Optional[str]] = [None]
144129
if macro_namespace is not None:
145-
if macro_namespace in jinja_env:
130+
if dispatch := self.jinja_globals.get("dispatch"):
131+
for entry in dispatch.get(self.jinja_macros.root_package_name, []):
132+
if entry.get("macro_namespace") == macro_namespace:
133+
packages_to_check = entry.get("search_order")
134+
break
135+
if packages_to_check == [None] and macro_namespace in jinja_env:
146136
packages_to_check = [self.jinja_macros.root_package_name, macro_namespace]
147137

148138
# Add dbt packages as fallback
149139
packages_to_check.extend(k for k in jinja_env if k.startswith("dbt"))
150140

151-
candidates = {}
152141
for macro_package in packages_to_check:
153142
macros = jinja_env.get(macro_package, {}) if macro_package else jinja_env
154143
if not isinstance(macros, dict):
155144
continue
156-
candidates.update(
157-
{
158-
(macro_package, macro_name): macro_callable
159-
for macro_name, macro_callable in macros.items()
160-
if macro_name.endswith(macro_suffix)
161-
}
162-
)
163145

164-
if candidates:
165-
sorted_candidates = sorted(candidates, key=_relevance)
166-
return candidates[sorted_candidates[0]]
146+
for prefix in (f"{target_type}__", "default__", ""):
147+
if macro := macros.get(f"{prefix}{macro_name}"):
148+
return macro
167149

168150
raise ConfigError(f"Macro '{macro_name}', package '{macro_namespace}' was not found.")
169151

sqlmesh/dbt/context.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class DbtContext:
5454
_seeds: t.Dict[str, SeedConfig] = field(default_factory=dict)
5555
_sources: t.Dict[str, SourceConfig] = field(default_factory=dict)
5656
_refs: t.Dict[str, t.Union[ModelConfig, SeedConfig]] = field(default_factory=dict)
57+
_dispatch: t.Dict[str, t.List[t.Dict[str, t.Any]]] = field(default_factory=dict)
5758

5859
_target: t.Optional[TargetConfig] = None
5960

@@ -136,6 +137,14 @@ def add_macros(self, macros: t.Dict[str, MacroInfo], package: str) -> None:
136137
self.jinja_macros.add_macros(macros, package=package)
137138
self._jinja_environment = None
138139

140+
@property
141+
def dispatch(self) -> t.Dict[str, t.List[t.Dict[str, t.Any]]]:
142+
return self._dispatch
143+
144+
def add_dispatch(self, dispatch: t.List[t.Dict[str, t.Any]], package: str) -> None:
145+
self._dispatch[package] = dispatch
146+
self._jinja_environment = None
147+
139148
@property
140149
def models(self) -> t.Dict[str, ModelConfig]:
141150
return self._models
@@ -249,6 +258,8 @@ def jinja_globals(self) -> t.Dict[str, JinjaGlobalAttribute]:
249258
# Pass flat graph structure like dbt
250259
if self._manifest is not None:
251260
output["flat_graph"] = AttributeDict(self.manifest.flat_graph)
261+
if self._dispatch is not None:
262+
output["dispatch"] = AttributeDict(self._dispatch)
252263
return output
253264

254265
def context_for_dependencies(self, dependencies: Dependencies) -> DbtContext:

sqlmesh/dbt/loader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ def _load_projects(self) -> t.List[Project]:
218218
context.add_sources(package.sources)
219219
context.add_seeds(package.seeds)
220220
context.add_models(package.models)
221+
context.add_dispatch(package.dispatch, package_name)
221222
macros_mtimes.extend(
222223
[
223224
self._path_mtimes[m.path]

sqlmesh/dbt/package.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class Package(PydanticModel):
5050
on_run_start: t.Dict[str, HookConfig]
5151
on_run_end: t.Dict[str, HookConfig]
5252
files: t.Set[Path]
53+
dispatch: t.List[t.Dict[str, t.Any]]
5354

5455
@property
5556
def macro_infos(self) -> t.Dict[str, MacroInfo]:
@@ -90,6 +91,8 @@ def load(self, package_root: Path) -> Package:
9091
var: value for var, value in all_variables.items() if not isinstance(value, dict)
9192
}
9293

94+
dispatch = project_yaml.get("dispatch") or []
95+
9396
tests = _fix_paths(self._context.manifest.tests(package_name), package_root)
9497
models = _fix_paths(self._context.manifest.models(package_name), package_root)
9598
seeds = _fix_paths(self._context.manifest.seeds(package_name), package_root)
@@ -113,6 +116,7 @@ def load(self, package_root: Path) -> Package:
113116
sources=sources,
114117
seeds=seeds,
115118
variables=package_variables,
119+
dispatch=dispatch,
116120
macros=macros,
117121
files=config_paths,
118122
on_run_start=on_run_start,

tests/dbt/test_adapter.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,11 @@ def test_adapter_dispatch(sushi_test_project: Project, runtime_renderer: t.Calla
242242
assert renderer("{{ adapter.dispatch('current_engine', 'customers')() }}") == "duckdb"
243243
assert renderer("{{ adapter.dispatch('current_timestamp')() }}") == "now()"
244244
assert renderer("{{ adapter.dispatch('current_timestamp', 'dbt')() }}") == "now()"
245-
assert renderer("{{ adapter.dispatch('select_distinct', 'customers')() }}") == "distinct"
245+
246+
# Macros in root project overrides macros in dependent packages
247+
assert (
248+
renderer("{{ adapter.dispatch('hello_world', 'my_helpers')() }}") == "hello from sushi_test"
249+
)
246250

247251
# test with keyword arguments
248252
assert (
@@ -276,6 +280,14 @@ def test_adapter_dispatch(sushi_test_project: Project, runtime_renderer: t.Calla
276280
renderer("{{ adapter.dispatch('current_engine')() }}")
277281

278282

283+
@pytest.mark.slow
284+
def test_adapter_dispatch_search_order(sushi_test_project: Project, runtime_renderer: t.Callable):
285+
context = sushi_test_project.context
286+
renderer = runtime_renderer(context)
287+
assert renderer("{{ adapter.dispatch('current_package', 'my_helpers')() }}") == "my_helpers"
288+
assert renderer("{{ adapter.dispatch('current_package', 'customers')() }}") == "my_helpers"
289+
290+
279291
@pytest.mark.parametrize("project_dialect", ["duckdb", "bigquery"])
280292
@pytest.mark.slow
281293
def test_adapter_map_snapshot_tables(
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
../packages/my_helpers

tests/fixtures/dbt/sushi_test/dbt_project.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,8 @@ on-run-end:
7979
- '{{ create_tables(schemas) }}'
8080
- 'DROP TABLE to_be_executed_last;'
8181
- '{{ graph_usage() }}'
82+
83+
84+
dispatch:
85+
- macro_namespace: customers
86+
search_order: ["my_helpers", "customers"]

tests/fixtures/dbt/sushi_test/macros/distinct.sql

Lines changed: 0 additions & 1 deletion
This file was deleted.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
{% macro hello_world() %}hello from sushi_test{% endmacro %}
2+
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
packages:
22
- local: packages/customers
3-
3+
- local: dbt_packages/my_helpers

0 commit comments

Comments
 (0)