Skip to content

Commit 69d658b

Browse files
Feat(dbt): Add dbt graph context variable support
1 parent 21f06dd commit 69d658b

File tree

7 files changed

+109
-8
lines changed

7 files changed

+109
-8
lines changed

sqlmesh/dbt/adapter.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from sqlmesh.core.snapshot import DeployabilityIndex, Snapshot, to_table_mapping
1212
from sqlmesh.utils.errors import ConfigError, ParsetimeAdapterCallError
1313
from sqlmesh.utils.jinja import JinjaMacroRegistry
14+
from sqlmesh.utils import AttributeDict
1415

1516
if t.TYPE_CHECKING:
1617
import agate
@@ -158,6 +159,20 @@ def compare_dbr_version(self, major: int, minor: int) -> int:
158159
# Always return -1 to fallback to Spark macro implementations.
159160
return -1
160161

162+
@property
163+
def graph(self) -> t.Any:
164+
return AttributeDict(
165+
{
166+
"exposures": {},
167+
"groups": {},
168+
"metrics": {},
169+
"nodes": {},
170+
"sources": {},
171+
"semantic_models": {},
172+
"saved_queries": {},
173+
}
174+
)
175+
161176

162177
class ParsetimeAdapter(BaseAdapter):
163178
def get_relation(self, database: str, schema: str, identifier: str) -> t.Optional[BaseRelation]:
@@ -246,6 +261,10 @@ def __init__(
246261
**table_mapping,
247262
}
248263

264+
@property
265+
def graph(self) -> t.Any:
266+
return self.jinja_globals.get("flat_graph", super().graph)
267+
249268
def get_relation(
250269
self, database: t.Optional[str], schema: str, identifier: str
251270
) -> t.Optional[BaseRelation]:

sqlmesh/dbt/builtin.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,7 @@ def create_builtin_globals(
452452
"load_result": sql_execution.load_result,
453453
"run_query": sql_execution.run_query,
454454
"statement": sql_execution.statement,
455+
"graph": adapter.graph,
455456
}
456457
)
457458

sqlmesh/dbt/context.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from sqlmesh.dbt.manifest import ManifestHelper
1212
from sqlmesh.dbt.target import TargetConfig
1313
from sqlmesh.utils import AttributeDict
14+
from sqlmesh.utils.conversions import serializable
1415
from sqlmesh.utils.errors import ConfigError, SQLMeshError
1516
from sqlmesh.utils.jinja import (
1617
JinjaGlobalAttribute,
@@ -195,6 +196,49 @@ def refs(self) -> t.Dict[str, t.Union[ModelConfig, SeedConfig]]:
195196
self._refs[f"{config_name}_v{model.version}"] = model
196197
return self._refs
197198

199+
@property
200+
def flat_graph(self) -> t.Dict[str, t.Any]:
201+
if self._manifest is None:
202+
return {
203+
"exposures": {},
204+
"groups": {},
205+
"metrics": {},
206+
"nodes": {},
207+
"sources": {},
208+
"semantic_models": {},
209+
"saved_queries": {},
210+
}
211+
212+
manifest = self._manifest._manifest
213+
return {
214+
"exposures": {
215+
k: serializable(v.to_dict(omit_none=False))
216+
for k, v in getattr(manifest, "exposures", {}).items()
217+
},
218+
"groups": {
219+
k: serializable(v.to_dict(omit_none=False))
220+
for k, v in getattr(manifest, "groups", {}).items()
221+
},
222+
"metrics": {
223+
k: serializable(v.to_dict(omit_none=False))
224+
for k, v in getattr(manifest, "metrics", {}).items()
225+
},
226+
"nodes": {
227+
k: serializable(v.to_dict(omit_none=False)) for k, v in manifest.nodes.items()
228+
},
229+
"sources": {
230+
k: serializable(v.to_dict(omit_none=False)) for k, v in manifest.sources.items()
231+
},
232+
"semantic_models": {
233+
k: serializable(v.to_dict(omit_none=False))
234+
for k, v in getattr(manifest, "semantic_models", {}).items()
235+
},
236+
"saved_queries": {
237+
k: serializable(v.to_dict(omit_none=False))
238+
for k, v in getattr(manifest, "saved_queries", {}).items()
239+
},
240+
}
241+
198242
@property
199243
def target(self) -> TargetConfig:
200244
if not self._target:
@@ -242,6 +286,9 @@ def jinja_globals(self) -> t.Dict[str, JinjaGlobalAttribute]:
242286
# pass user-specified default dialect if we have already loaded the config
243287
if self.sqlmesh_config.dialect:
244288
output["dialect"] = self.sqlmesh_config.dialect
289+
# Pass flat graph structure like dbt
290+
if self._manifest is not None:
291+
output["flat_graph"] = AttributeDict(self.flat_graph)
245292
return output
246293

247294
def context_for_dependencies(self, dependencies: Dependencies) -> DbtContext:

sqlmesh/utils/conversions.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import typing as t
4+
from datetime import date, datetime
45

56

67
def ensure_bool(val: t.Any) -> bool:
@@ -19,3 +20,13 @@ def try_str_to_bool(val: str) -> t.Union[str, bool]:
1920
return maybe_bool == "true"
2021

2122
return val
23+
24+
25+
def serializable(obj: t.Any) -> t.Any:
26+
if isinstance(obj, (date, datetime)):
27+
return obj.isoformat()
28+
if isinstance(obj, dict):
29+
return {k: serializable(v) for k, v in obj.items()}
30+
if isinstance(obj, list):
31+
return [serializable(item) for item in obj]
32+
return obj

tests/dbt/test_transformation.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1606,6 +1606,7 @@ def test_on_run_start_end():
16061606
assert root_environment_statements.after_all == [
16071607
"JINJA_STATEMENT_BEGIN;\n{{ create_tables(schemas) }}\nJINJA_END;",
16081608
"JINJA_STATEMENT_BEGIN;\nDROP TABLE to_be_executed_last;\nJINJA_END;",
1609+
"JINJA_STATEMENT_BEGIN;\n{{ graph_usage() }}\nJINJA_END;",
16091610
]
16101611

16111612
assert root_environment_statements.jinja_macros.root_package_name == "sushi"
@@ -1626,6 +1627,7 @@ def test_on_run_start_end():
16261627
snapshots=sushi_context.snapshots,
16271628
runtime_stage=RuntimeStage.AFTER_ALL,
16281629
environment_naming_info=EnvironmentNamingInfo(name="dev"),
1630+
engine_adapter=sushi_context.engine_adapter,
16291631
)
16301632

16311633
assert rendered_before_all == [
@@ -1635,13 +1637,14 @@ def test_on_run_start_end():
16351637
]
16361638

16371639
# The jinja macro should have resolved the schemas for this environment and generated corresponding statements
1638-
assert sorted(rendered_after_all) == sorted(
1639-
[
1640-
"CREATE OR REPLACE TABLE schema_table_snapshots__dev AS SELECT 'snapshots__dev' AS schema",
1641-
"CREATE OR REPLACE TABLE schema_table_sushi__dev AS SELECT 'sushi__dev' AS schema",
1642-
"DROP TABLE to_be_executed_last",
1643-
]
1644-
)
1640+
expected_statements = [
1641+
"CREATE OR REPLACE TABLE schema_table_snapshots__dev AS SELECT 'snapshots__dev' AS schema",
1642+
"CREATE OR REPLACE TABLE schema_table_sushi__dev AS SELECT 'sushi__dev' AS schema",
1643+
"DROP TABLE to_be_executed_last",
1644+
"CREATE OR REPLACE TABLE graph_table AS SELECT 'model.sushi.simple_model_a' AS unique_id, 'table' AS materialized UNION ALL SELECT 'model.sushi.waiters' AS unique_id, 'ephemeral' AS materialized UNION ALL SELECT 'model.sushi.simple_model_b' AS unique_id, 'table' AS materialized UNION ALL SELECT 'model.sushi.waiter_as_customer_by_day' AS unique_id, 'incremental' AS materialized UNION ALL SELECT 'model.sushi.top_waiters' AS unique_id, 'view' AS materialized UNION ALL SELECT 'model.customers.customers' AS unique_id, 'view' AS materialized UNION ALL SELECT 'model.customers.customer_revenue_by_day' AS unique_id, 'incremental' AS materialized UNION ALL SELECT 'model.sushi.waiter_revenue_by_day.v1' AS unique_id, 'incremental' AS materialized UNION ALL SELECT 'model.sushi.waiter_revenue_by_day.v2' AS unique_id, 'incremental' AS materialized",
1645+
]
1646+
1647+
assert sorted(rendered_after_all) == sorted(expected_statements)
16451648

16461649
# Nested dbt_packages on run start / on run end
16471650
packaged_environment_statements = sushi_context._environment_statements[1]
@@ -1675,6 +1678,7 @@ def test_on_run_start_end():
16751678
snapshots=sushi_context.snapshots,
16761679
runtime_stage=RuntimeStage.AFTER_ALL,
16771680
environment_naming_info=EnvironmentNamingInfo(name="dev"),
1681+
engine_adapter=sushi_context.engine_adapter,
16781682
)
16791683

16801684
# Validate order of execution to match dbt's

tests/fixtures/dbt/sushi_test/dbt_project.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,4 +70,5 @@ on-run-start:
7070
- "{{ log_value('on-run-start') }}"
7171
on-run-end:
7272
- '{{ create_tables(schemas) }}'
73-
- 'DROP TABLE to_be_executed_last;'
73+
- 'DROP TABLE to_be_executed_last;'
74+
- '{{ graph_usage() }}'
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
{% macro graph_usage() %}
2+
{% if execute %}
3+
{% set model_nodes = graph.nodes.values()
4+
| selectattr("resource_type", "equalto", "model")
5+
| list %}
6+
7+
{% set out = [] %}
8+
{% for node in model_nodes %}
9+
{% set line = "select '" ~ node.unique_id ~ "' as unique_id, '" ~ node.config.materialized ~ "' as materialized" %}
10+
{% do out.append(line) %}
11+
{% endfor %}
12+
13+
{% if out %}
14+
{% set sql_statement = "create or replace table graph_table as\n" ~ (out | join('\nunion all\n')) %}
15+
{{ return(sql_statement) }}
16+
{% endif %}
17+
{% endif %}
18+
{% endmacro %}

0 commit comments

Comments
 (0)