Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions sqlmesh/dbt/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from sqlmesh.core.snapshot import DeployabilityIndex, Snapshot, to_table_mapping
from sqlmesh.utils.errors import ConfigError, ParsetimeAdapterCallError
from sqlmesh.utils.jinja import JinjaMacroRegistry
from sqlmesh.utils import AttributeDict

if t.TYPE_CHECKING:
import agate
Expand Down Expand Up @@ -158,6 +159,20 @@ def compare_dbr_version(self, major: int, minor: int) -> int:
# Always return -1 to fallback to Spark macro implementations.
return -1

@property
def graph(self) -> t.Any:
    """Expose a dbt-style ``graph`` context variable with every section empty.

    Mirrors the flat-graph shape dbt makes available to macros; subclasses
    may override this to supply real manifest data.
    """
    sections = (
        "exposures",
        "groups",
        "metrics",
        "nodes",
        "sources",
        "semantic_models",
        "saved_queries",
    )
    return AttributeDict({section: {} for section in sections})


class ParsetimeAdapter(BaseAdapter):
def get_relation(self, database: str, schema: str, identifier: str) -> t.Optional[BaseRelation]:
Expand Down Expand Up @@ -246,6 +261,10 @@ def __init__(
**table_mapping,
}

@property
def graph(self) -> t.Any:
    """Return the dbt flat graph recorded in ``jinja_globals``.

    Falls back to the base adapter's empty graph when no ``flat_graph``
    global was captured for this context.
    """
    default_graph = super().graph
    return self.jinja_globals.get("flat_graph", default_graph)

def get_relation(
self, database: t.Optional[str], schema: str, identifier: str
) -> t.Optional[BaseRelation]:
Expand Down
1 change: 1 addition & 0 deletions sqlmesh/dbt/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ def create_builtin_globals(
"load_result": sql_execution.load_result,
"run_query": sql_execution.run_query,
"statement": sql_execution.statement,
"graph": adapter.graph,
}
)

Expand Down
3 changes: 3 additions & 0 deletions sqlmesh/dbt/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,9 @@ def jinja_globals(self) -> t.Dict[str, JinjaGlobalAttribute]:
# pass user-specified default dialect if we have already loaded the config
if self.sqlmesh_config.dialect:
output["dialect"] = self.sqlmesh_config.dialect
# Pass flat graph structure like dbt
if self._manifest is not None:
output["flat_graph"] = AttributeDict(self.manifest.flat_graph)
return output

def context_for_dependencies(self, dependencies: Dependencies) -> DbtContext:
Expand Down
35 changes: 35 additions & 0 deletions sqlmesh/dbt/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from dbt import constants as dbt_constants, flags

from sqlmesh.utils.conversions import make_serializable

# Override the file name to prevent dbt commands from invalidating the cache.
dbt_constants.PARTIAL_PARSE_FILE_NAME = "sqlmesh_partial_parse.msgpack"

Expand Down Expand Up @@ -155,6 +157,39 @@ def all_macros(self) -> t.Dict[str, t.Dict[str, MacroInfo]]:
result[package_name][macro_name] = macro_config.info
return result

@property
def flat_graph(self) -> t.Dict[str, t.Any]:
    """Build a dbt-compatible "flat graph" dictionary from the manifest.

    Mirrors the structure dbt exposes to macros through the ``graph``
    context variable: each section maps unique ids to plain-dict node
    representations, so macros see exactly what they would under dbt.

    Returns:
        A dict with the keys ``exposures``, ``groups``, ``metrics``,
        ``nodes``, ``sources``, ``semantic_models`` and ``saved_queries``.
    """

    def _section(attr: str) -> t.Dict[str, t.Any]:
        # Serialize each manifest entry down to primitives (dates become
        # ISO strings) so the result is safe to render and serialize.
        # getattr with a {} default tolerates manifests from older dbt
        # versions that lack newer sections (e.g. semantic_models).
        return {
            key: make_serializable(value.to_dict(omit_none=False))
            for key, value in getattr(self._manifest, attr, {}).items()
        }

    return {
        section: _section(section)
        for section in (
            "exposures",
            "groups",
            "metrics",
            "nodes",
            "sources",
            "semantic_models",
            "saved_queries",
        )
    }

def _load_all(self) -> None:
if self._is_loaded:
return
Expand Down
11 changes: 11 additions & 0 deletions sqlmesh/utils/conversions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import typing as t
from datetime import date, datetime


def ensure_bool(val: t.Any) -> bool:
Expand All @@ -19,3 +20,13 @@ def try_str_to_bool(val: str) -> t.Union[str, bool]:
return maybe_bool == "true"

return val


def make_serializable(obj: t.Any) -> t.Any:
    """Recursively convert *obj* into JSON/msgpack-friendly primitives.

    ``date``/``datetime`` values become ISO-8601 strings; dicts are
    converted value-wise (and ``date``/``datetime`` keys are stringified —
    other key types are left untouched to preserve hashability); lists and
    tuples are converted element-wise, with tuples normalized to lists.
    Any other value is returned unchanged.
    """
    if isinstance(obj, (date, datetime)):
        return obj.isoformat()
    if isinstance(obj, dict):
        return {
            # Stringify temporal keys too — a date key would otherwise leak
            # through and break serialization downstream.
            k.isoformat() if isinstance(k, (date, datetime)) else k: make_serializable(v)
            for k, v in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return [make_serializable(item) for item in obj]
    return obj
3 changes: 3 additions & 0 deletions sqlmesh/utils/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,9 @@ def add_globals(self, globals: t.Dict[str, JinjaGlobalAttribute]) -> None:
Args:
globals: The global objects that should be added.
"""
# Keep the registry lightweight when the graph is not needed
if not "graph" in self.packages:
globals.pop("flat_graph", None)
self.global_objs.update(**self._validate_global_objs(globals))

def build_macro(self, reference: MacroReference, **kwargs: t.Any) -> t.Optional[t.Callable]:
Expand Down
38 changes: 32 additions & 6 deletions tests/dbt/test_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1606,6 +1606,7 @@ def test_on_run_start_end():
assert root_environment_statements.after_all == [
"JINJA_STATEMENT_BEGIN;\n{{ create_tables(schemas) }}\nJINJA_END;",
"JINJA_STATEMENT_BEGIN;\nDROP TABLE to_be_executed_last;\nJINJA_END;",
"JINJA_STATEMENT_BEGIN;\n{{ graph_usage() }}\nJINJA_END;",
]

assert root_environment_statements.jinja_macros.root_package_name == "sushi"
Expand All @@ -1626,6 +1627,7 @@ def test_on_run_start_end():
snapshots=sushi_context.snapshots,
runtime_stage=RuntimeStage.AFTER_ALL,
environment_naming_info=EnvironmentNamingInfo(name="dev"),
engine_adapter=sushi_context.engine_adapter,
)

assert rendered_before_all == [
Expand All @@ -1635,12 +1637,35 @@ def test_on_run_start_end():
]

# The jinja macro should have resolved the schemas for this environment and generated corresponding statements
assert sorted(rendered_after_all) == sorted(
[
"CREATE OR REPLACE TABLE schema_table_snapshots__dev AS SELECT 'snapshots__dev' AS schema",
"CREATE OR REPLACE TABLE schema_table_sushi__dev AS SELECT 'sushi__dev' AS schema",
"DROP TABLE to_be_executed_last",
]
expected_statements = [
"CREATE OR REPLACE TABLE schema_table_snapshots__dev AS SELECT 'snapshots__dev' AS schema",
"CREATE OR REPLACE TABLE schema_table_sushi__dev AS SELECT 'sushi__dev' AS schema",
"DROP TABLE to_be_executed_last",
]
assert sorted(rendered_after_all[:-1]) == sorted(expected_statements)

# Assert the models with their materialisations are present in the rendered graph_table statement
graph_table_stmt = rendered_after_all[-1]
assert "'model.sushi.simple_model_a' AS unique_id, 'table' AS materialized" in graph_table_stmt
assert "'model.sushi.waiters' AS unique_id, 'ephemeral' AS materialized" in graph_table_stmt
assert "'model.sushi.simple_model_b' AS unique_id, 'table' AS materialized" in graph_table_stmt
assert (
"'model.sushi.waiter_as_customer_by_day' AS unique_id, 'incremental' AS materialized"
in graph_table_stmt
)
assert "'model.sushi.top_waiters' AS unique_id, 'view' AS materialized" in graph_table_stmt
assert "'model.customers.customers' AS unique_id, 'view' AS materialized" in graph_table_stmt
assert (
"'model.customers.customer_revenue_by_day' AS unique_id, 'incremental' AS materialized"
in graph_table_stmt
)
assert (
"'model.sushi.waiter_revenue_by_day.v1' AS unique_id, 'incremental' AS materialized"
in graph_table_stmt
)
assert (
"'model.sushi.waiter_revenue_by_day.v2' AS unique_id, 'incremental' AS materialized"
in graph_table_stmt
)

# Nested dbt_packages on run start / on run end
Expand Down Expand Up @@ -1675,6 +1700,7 @@ def test_on_run_start_end():
snapshots=sushi_context.snapshots,
runtime_stage=RuntimeStage.AFTER_ALL,
environment_naming_info=EnvironmentNamingInfo(name="dev"),
engine_adapter=sushi_context.engine_adapter,
)

# Validate order of execution to match dbt's
Expand Down
3 changes: 2 additions & 1 deletion tests/fixtures/dbt/sushi_test/dbt_project.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,5 @@ on-run-start:
- "{{ log_value('on-run-start') }}"
on-run-end:
- '{{ create_tables(schemas) }}'
- 'DROP TABLE to_be_executed_last;'
- 'DROP TABLE to_be_executed_last;'
- '{{ graph_usage() }}'
18 changes: 18 additions & 0 deletions tests/fixtures/dbt/sushi_test/macros/graph_usage.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{# Test-fixture macro exercising the dbt `graph` context variable: collects #}
{# every model node and emits a CREATE TABLE statement listing each model's #}
{# unique_id and materialization. Returns nothing when there are no models. #}
{% macro graph_usage() %}
{% if execute %}
    {# Keep only nodes whose resource_type is "model" #}
    {% set model_nodes = graph.nodes.values()
        | selectattr("resource_type", "equalto", "model")
        | list %}

    {% set out = [] %}
    {% for node in model_nodes %}
        {# One SELECT per model; rows are later combined with UNION ALL #}
        {% set line = "select '" ~ node.unique_id ~ "' as unique_id, '" ~ node.config.materialized ~ "' as materialized" %}
        {% do out.append(line) %}
    {% endfor %}

    {% if out %}
        {% set sql_statement = "create or replace table graph_table as\n" ~ (out | join('\nunion all\n')) %}
        {{ return(sql_statement) }}
    {% endif %}
{% endif %}
{% endmacro %}