Skip to content

Commit 98c858c

Browse files
Feat(dbt): Add dbt graph context variable support (#5159)
1 parent e574577 commit 98c858c

File tree

9 files changed

+124
-7
lines changed

9 files changed

+124
-7
lines changed

sqlmesh/dbt/adapter.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from sqlmesh.core.snapshot import DeployabilityIndex, Snapshot, to_table_mapping
1212
from sqlmesh.utils.errors import ConfigError, ParsetimeAdapterCallError
1313
from sqlmesh.utils.jinja import JinjaMacroRegistry
14+
from sqlmesh.utils import AttributeDict
1415

1516
if t.TYPE_CHECKING:
1617
import agate
@@ -158,6 +159,20 @@ def compare_dbr_version(self, major: int, minor: int) -> int:
158159
# Always return -1 to fallback to Spark macro implementations.
159160
return -1
160161

162+
@property
163+
def graph(self) -> t.Any:
164+
return AttributeDict(
165+
{
166+
"exposures": {},
167+
"groups": {},
168+
"metrics": {},
169+
"nodes": {},
170+
"sources": {},
171+
"semantic_models": {},
172+
"saved_queries": {},
173+
}
174+
)
175+
161176

162177
class ParsetimeAdapter(BaseAdapter):
163178
def get_relation(self, database: str, schema: str, identifier: str) -> t.Optional[BaseRelation]:
@@ -246,6 +261,10 @@ def __init__(
246261
**table_mapping,
247262
}
248263

264+
@property
265+
def graph(self) -> t.Any:
266+
return self.jinja_globals.get("flat_graph", super().graph)
267+
249268
def get_relation(
250269
self, database: t.Optional[str], schema: str, identifier: str
251270
) -> t.Optional[BaseRelation]:

sqlmesh/dbt/builtin.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,7 @@ def create_builtin_globals(
452452
"load_result": sql_execution.load_result,
453453
"run_query": sql_execution.run_query,
454454
"statement": sql_execution.statement,
455+
"graph": adapter.graph,
455456
}
456457
)
457458

sqlmesh/dbt/context.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,9 @@ def jinja_globals(self) -> t.Dict[str, JinjaGlobalAttribute]:
242242
# pass user-specified default dialect if we have already loaded the config
243243
if self.sqlmesh_config.dialect:
244244
output["dialect"] = self.sqlmesh_config.dialect
245+
# Pass flat graph structure like dbt
246+
if self._manifest is not None:
247+
output["flat_graph"] = AttributeDict(self.manifest.flat_graph)
245248
return output
246249

247250
def context_for_dependencies(self, dependencies: Dependencies) -> DbtContext:

sqlmesh/dbt/manifest.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
from dbt import constants as dbt_constants, flags
1414

15+
from sqlmesh.utils.conversions import make_serializable
16+
1517
# Override the file name to prevent dbt commands from invalidating the cache.
1618
dbt_constants.PARTIAL_PARSE_FILE_NAME = "sqlmesh_partial_parse.msgpack"
1719

@@ -155,6 +157,39 @@ def all_macros(self) -> t.Dict[str, t.Dict[str, MacroInfo]]:
155157
result[package_name][macro_name] = macro_config.info
156158
return result
157159

160+
@property
161+
def flat_graph(self) -> t.Dict[str, t.Any]:
162+
return {
163+
"exposures": {
164+
k: make_serializable(v.to_dict(omit_none=False))
165+
for k, v in getattr(self._manifest, "exposures", {}).items()
166+
},
167+
"groups": {
168+
k: make_serializable(v.to_dict(omit_none=False))
169+
for k, v in getattr(self._manifest, "groups", {}).items()
170+
},
171+
"metrics": {
172+
k: make_serializable(v.to_dict(omit_none=False))
173+
for k, v in getattr(self._manifest, "metrics", {}).items()
174+
},
175+
"nodes": {
176+
k: make_serializable(v.to_dict(omit_none=False))
177+
for k, v in self._manifest.nodes.items()
178+
},
179+
"sources": {
180+
k: make_serializable(v.to_dict(omit_none=False))
181+
for k, v in self._manifest.sources.items()
182+
},
183+
"semantic_models": {
184+
k: make_serializable(v.to_dict(omit_none=False))
185+
for k, v in getattr(self._manifest, "semantic_models", {}).items()
186+
},
187+
"saved_queries": {
188+
k: make_serializable(v.to_dict(omit_none=False))
189+
for k, v in getattr(self._manifest, "saved_queries", {}).items()
190+
},
191+
}
192+
158193
def _load_all(self) -> None:
159194
if self._is_loaded:
160195
return

sqlmesh/utils/conversions.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import typing as t
4+
from datetime import date, datetime
45

56

67
def ensure_bool(val: t.Any) -> bool:
@@ -19,3 +20,13 @@ def try_str_to_bool(val: str) -> t.Union[str, bool]:
1920
return maybe_bool == "true"
2021

2122
return val
23+
24+
25+
def make_serializable(obj: t.Any) -> t.Any:
26+
if isinstance(obj, (date, datetime)):
27+
return obj.isoformat()
28+
if isinstance(obj, dict):
29+
return {k: make_serializable(v) for k, v in obj.items()}
30+
if isinstance(obj, list):
31+
return [make_serializable(item) for item in obj]
32+
return obj

sqlmesh/utils/jinja.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,9 @@ def add_globals(self, globals: t.Dict[str, JinjaGlobalAttribute]) -> None:
363363
Args:
364364
globals: The global objects that should be added.
365365
"""
366+
# Keep the registry lightweight when the graph is not needed
367+
if not "graph" in self.packages:
368+
globals.pop("flat_graph", None)
366369
self.global_objs.update(**self._validate_global_objs(globals))
367370

368371
def build_macro(self, reference: MacroReference, **kwargs: t.Any) -> t.Optional[t.Callable]:

tests/dbt/test_transformation.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1606,6 +1606,7 @@ def test_on_run_start_end():
16061606
assert root_environment_statements.after_all == [
16071607
"JINJA_STATEMENT_BEGIN;\n{{ create_tables(schemas) }}\nJINJA_END;",
16081608
"JINJA_STATEMENT_BEGIN;\nDROP TABLE to_be_executed_last;\nJINJA_END;",
1609+
"JINJA_STATEMENT_BEGIN;\n{{ graph_usage() }}\nJINJA_END;",
16091610
]
16101611

16111612
assert root_environment_statements.jinja_macros.root_package_name == "sushi"
@@ -1626,6 +1627,7 @@ def test_on_run_start_end():
16261627
snapshots=sushi_context.snapshots,
16271628
runtime_stage=RuntimeStage.AFTER_ALL,
16281629
environment_naming_info=EnvironmentNamingInfo(name="dev"),
1630+
engine_adapter=sushi_context.engine_adapter,
16291631
)
16301632

16311633
assert rendered_before_all == [
@@ -1635,12 +1637,35 @@ def test_on_run_start_end():
16351637
]
16361638

16371639
# The jinja macro should have resolved the schemas for this environment and generated corresponding statements
1638-
assert sorted(rendered_after_all) == sorted(
1639-
[
1640-
"CREATE OR REPLACE TABLE schema_table_snapshots__dev AS SELECT 'snapshots__dev' AS schema",
1641-
"CREATE OR REPLACE TABLE schema_table_sushi__dev AS SELECT 'sushi__dev' AS schema",
1642-
"DROP TABLE to_be_executed_last",
1643-
]
1640+
expected_statements = [
1641+
"CREATE OR REPLACE TABLE schema_table_snapshots__dev AS SELECT 'snapshots__dev' AS schema",
1642+
"CREATE OR REPLACE TABLE schema_table_sushi__dev AS SELECT 'sushi__dev' AS schema",
1643+
"DROP TABLE to_be_executed_last",
1644+
]
1645+
assert sorted(rendered_after_all[:-1]) == sorted(expected_statements)
1646+
1647+
# Assert the models with their materialisations are present in the rendered graph_table statement
1648+
graph_table_stmt = rendered_after_all[-1]
1649+
assert "'model.sushi.simple_model_a' AS unique_id, 'table' AS materialized" in graph_table_stmt
1650+
assert "'model.sushi.waiters' AS unique_id, 'ephemeral' AS materialized" in graph_table_stmt
1651+
assert "'model.sushi.simple_model_b' AS unique_id, 'table' AS materialized" in graph_table_stmt
1652+
assert (
1653+
"'model.sushi.waiter_as_customer_by_day' AS unique_id, 'incremental' AS materialized"
1654+
in graph_table_stmt
1655+
)
1656+
assert "'model.sushi.top_waiters' AS unique_id, 'view' AS materialized" in graph_table_stmt
1657+
assert "'model.customers.customers' AS unique_id, 'view' AS materialized" in graph_table_stmt
1658+
assert (
1659+
"'model.customers.customer_revenue_by_day' AS unique_id, 'incremental' AS materialized"
1660+
in graph_table_stmt
1661+
)
1662+
assert (
1663+
"'model.sushi.waiter_revenue_by_day.v1' AS unique_id, 'incremental' AS materialized"
1664+
in graph_table_stmt
1665+
)
1666+
assert (
1667+
"'model.sushi.waiter_revenue_by_day.v2' AS unique_id, 'incremental' AS materialized"
1668+
in graph_table_stmt
16441669
)
16451670

16461671
# Nested dbt_packages on run start / on run end
@@ -1675,6 +1700,7 @@ def test_on_run_start_end():
16751700
snapshots=sushi_context.snapshots,
16761701
runtime_stage=RuntimeStage.AFTER_ALL,
16771702
environment_naming_info=EnvironmentNamingInfo(name="dev"),
1703+
engine_adapter=sushi_context.engine_adapter,
16781704
)
16791705

16801706
# Validate order of execution to match dbt's

tests/fixtures/dbt/sushi_test/dbt_project.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,4 +70,5 @@ on-run-start:
7070
- "{{ log_value('on-run-start') }}"
7171
on-run-end:
7272
- '{{ create_tables(schemas) }}'
73-
- 'DROP TABLE to_be_executed_last;'
73+
- 'DROP TABLE to_be_executed_last;'
74+
- '{{ graph_usage() }}'
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
{% macro graph_usage() %}
2+
{% if execute %}
3+
{% set model_nodes = graph.nodes.values()
4+
| selectattr("resource_type", "equalto", "model")
5+
| list %}
6+
7+
{% set out = [] %}
8+
{% for node in model_nodes %}
9+
{% set line = "select '" ~ node.unique_id ~ "' as unique_id, '" ~ node.config.materialized ~ "' as materialized" %}
10+
{% do out.append(line) %}
11+
{% endfor %}
12+
13+
{% if out %}
14+
{% set sql_statement = "create or replace table graph_table as\n" ~ (out | join('\nunion all\n')) %}
15+
{{ return(sql_statement) }}
16+
{% endif %}
17+
{% endif %}
18+
{% endmacro %}

0 commit comments

Comments
 (0)