From 5b8f62af70cc0a3a72af3cbce30d9f165c4ed96e Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+themisvaltinos@users.noreply.github.com> Date: Thu, 14 Aug 2025 20:45:38 +0300 Subject: [PATCH 1/3] Feat(dbt): Add dbt graph context variable support --- sqlmesh/dbt/adapter.py | 19 ++++++++ sqlmesh/dbt/builtin.py | 1 + sqlmesh/dbt/context.py | 47 +++++++++++++++++++ sqlmesh/utils/conversions.py | 11 +++++ tests/dbt/test_transformation.py | 18 ++++--- tests/fixtures/dbt/sushi_test/dbt_project.yml | 3 +- .../dbt/sushi_test/macros/graph_usage.sql | 18 +++++++ 7 files changed, 109 insertions(+), 8 deletions(-) create mode 100644 tests/fixtures/dbt/sushi_test/macros/graph_usage.sql diff --git a/sqlmesh/dbt/adapter.py b/sqlmesh/dbt/adapter.py index 00a1d86ba2..b524acc160 100644 --- a/sqlmesh/dbt/adapter.py +++ b/sqlmesh/dbt/adapter.py @@ -11,6 +11,7 @@ from sqlmesh.core.snapshot import DeployabilityIndex, Snapshot, to_table_mapping from sqlmesh.utils.errors import ConfigError, ParsetimeAdapterCallError from sqlmesh.utils.jinja import JinjaMacroRegistry +from sqlmesh.utils import AttributeDict if t.TYPE_CHECKING: import agate @@ -158,6 +159,20 @@ def compare_dbr_version(self, major: int, minor: int) -> int: # Always return -1 to fallback to Spark macro implementations. return -1 + @property + def graph(self) -> t.Any: + return AttributeDict( + { + "exposures": {}, + "groups": {}, + "metrics": {}, + "nodes": {}, + "sources": {}, + "semantic_models": {}, + "saved_queries": {}, + } + ) + class ParsetimeAdapter(BaseAdapter): def get_relation(self, database: str, schema: str, identifier: str) -> t.Optional[BaseRelation]: @@ -246,6 +261,10 @@ def __init__( **table_mapping, } + @property + def graph(self) -> t.Any: + return self.jinja_globals.get("flat_graph", super().graph) + def get_relation( self, database: t.Optional[str], schema: str, identifier: str ) -> t.Optional[BaseRelation]: diff --git a/sqlmesh/dbt/builtin.py b/sqlmesh/dbt/builtin.py index 07edeefa2e..70e1b10099 100644 --- a/sqlmesh/dbt/builtin.py +++ b/sqlmesh/dbt/builtin.py @@ -452,6 +452,7 @@ def create_builtin_globals( "load_result": sql_execution.load_result, "run_query": sql_execution.run_query, "statement": sql_execution.statement, + "graph": adapter.graph, } ) diff --git a/sqlmesh/dbt/context.py b/sqlmesh/dbt/context.py index d29dc43574..d7b9901623 100644 --- a/sqlmesh/dbt/context.py +++ b/sqlmesh/dbt/context.py @@ -11,6 +11,7 @@ from sqlmesh.dbt.manifest import ManifestHelper from sqlmesh.dbt.target import TargetConfig from sqlmesh.utils import AttributeDict +from sqlmesh.utils.conversions import serializable from sqlmesh.utils.errors import ConfigError, SQLMeshError from sqlmesh.utils.jinja import ( JinjaGlobalAttribute, @@ -195,6 +196,49 @@ def refs(self) -> t.Dict[str, t.Union[ModelConfig, SeedConfig]]: self._refs[f"{config_name}_v{model.version}"] = model return self._refs + @property + def flat_graph(self) -> t.Dict[str, t.Any]: + if self._manifest is None: + return { + "exposures": {}, + "groups": {}, + "metrics": {}, + "nodes": {}, + "sources": {}, + "semantic_models": {}, + "saved_queries": {}, + } + + manifest = self._manifest._manifest + return { + "exposures": { + k: serializable(v.to_dict(omit_none=False)) + for k, v in getattr(manifest, "exposures", {}).items() + }, + "groups": { + k: serializable(v.to_dict(omit_none=False)) + for k, v in getattr(manifest, "groups", {}).items() + }, + "metrics": { + k: serializable(v.to_dict(omit_none=False)) + for k, v in getattr(manifest, "metrics", {}).items() + }, + "nodes": { + k: serializable(v.to_dict(omit_none=False)) for k, v in manifest.nodes.items() + }, + "sources": { + k: serializable(v.to_dict(omit_none=False)) for k, v in manifest.sources.items() + }, + "semantic_models": { + k: serializable(v.to_dict(omit_none=False)) + for k, v in getattr(manifest, "semantic_models", {}).items() + }, + "saved_queries": { + k: serializable(v.to_dict(omit_none=False)) + for k, v in getattr(manifest, "saved_queries", {}).items() + }, + } + @property def target(self) -> TargetConfig: if not self._target: @@ -242,6 +286,9 @@ def jinja_globals(self) -> t.Dict[str, JinjaGlobalAttribute]: # pass user-specified default dialect if we have already loaded the config if self.sqlmesh_config.dialect: output["dialect"] = self.sqlmesh_config.dialect + # Pass flat graph structure like dbt + if self._manifest is not None: + output["flat_graph"] = AttributeDict(self.flat_graph) return output def context_for_dependencies(self, dependencies: Dependencies) -> DbtContext: diff --git a/sqlmesh/utils/conversions.py b/sqlmesh/utils/conversions.py index 2b92772022..f2a7ea9f49 100644 --- a/sqlmesh/utils/conversions.py +++ b/sqlmesh/utils/conversions.py @@ -1,6 +1,7 @@ from __future__ import annotations import typing as t +from datetime import date, datetime def ensure_bool(val: t.Any) -> bool: @@ -19,3 +20,13 @@ def try_str_to_bool(val: str) -> t.Union[str, bool]: return maybe_bool == "true" return val + + +def serializable(obj: t.Any) -> t.Any: + if isinstance(obj, (date, datetime)): + return obj.isoformat() + if isinstance(obj, dict): + return {k: serializable(v) for k, v in obj.items()} + if isinstance(obj, list): + return [serializable(item) for item in obj] + return obj diff --git a/tests/dbt/test_transformation.py b/tests/dbt/test_transformation.py index a16cc16f43..ef5eb65a6e 100644 --- a/tests/dbt/test_transformation.py +++ b/tests/dbt/test_transformation.py @@ -1606,6 +1606,7 @@ def test_on_run_start_end(): assert root_environment_statements.after_all == [ "JINJA_STATEMENT_BEGIN;\n{{ create_tables(schemas) }}\nJINJA_END;", "JINJA_STATEMENT_BEGIN;\nDROP TABLE to_be_executed_last;\nJINJA_END;", + "JINJA_STATEMENT_BEGIN;\n{{ graph_usage() }}\nJINJA_END;", ] assert root_environment_statements.jinja_macros.root_package_name == "sushi" @@ -1626,6 +1627,7 @@ def test_on_run_start_end(): snapshots=sushi_context.snapshots, runtime_stage=RuntimeStage.AFTER_ALL, environment_naming_info=EnvironmentNamingInfo(name="dev"), + engine_adapter=sushi_context.engine_adapter, ) assert rendered_before_all == [ @@ -1635,13 +1637,14 @@ def test_on_run_start_end(): ] # The jinja macro should have resolved the schemas for this environment and generated corresponding statements - assert sorted(rendered_after_all) == sorted( - [ - "CREATE OR REPLACE TABLE schema_table_snapshots__dev AS SELECT 'snapshots__dev' AS schema", - "CREATE OR REPLACE TABLE schema_table_sushi__dev AS SELECT 'sushi__dev' AS schema", - "DROP TABLE to_be_executed_last", - ] - ) + expected_statements = [ + "CREATE OR REPLACE TABLE schema_table_snapshots__dev AS SELECT 'snapshots__dev' AS schema", + "CREATE OR REPLACE TABLE schema_table_sushi__dev AS SELECT 'sushi__dev' AS schema", + "DROP TABLE to_be_executed_last", + "CREATE OR REPLACE TABLE graph_table AS SELECT 'model.sushi.simple_model_a' AS unique_id, 'table' AS materialized UNION ALL SELECT 'model.sushi.waiters' AS unique_id, 'ephemeral' AS materialized UNION ALL SELECT 'model.sushi.simple_model_b' AS unique_id, 'table' AS materialized UNION ALL SELECT 'model.sushi.waiter_as_customer_by_day' AS unique_id, 'incremental' AS materialized UNION ALL SELECT 'model.sushi.top_waiters' AS unique_id, 'view' AS materialized UNION ALL SELECT 'model.customers.customers' AS unique_id, 'view' AS materialized UNION ALL SELECT 'model.customers.customer_revenue_by_day' AS unique_id, 'incremental' AS materialized UNION ALL SELECT 'model.sushi.waiter_revenue_by_day.v1' AS unique_id, 'incremental' AS materialized UNION ALL SELECT 'model.sushi.waiter_revenue_by_day.v2' AS unique_id, 'incremental' AS materialized", + ] + + assert sorted(rendered_after_all) == sorted(expected_statements) # Nested dbt_packages on run start / on run end packaged_environment_statements = sushi_context._environment_statements[1] @@ -1675,6 +1678,7 @@ def test_on_run_start_end(): snapshots=sushi_context.snapshots, runtime_stage=RuntimeStage.AFTER_ALL, environment_naming_info=EnvironmentNamingInfo(name="dev"), + engine_adapter=sushi_context.engine_adapter, ) # Validate order of execution to match dbt's diff --git a/tests/fixtures/dbt/sushi_test/dbt_project.yml b/tests/fixtures/dbt/sushi_test/dbt_project.yml index 1afa7dd2c6..073d85b4d4 100644 --- a/tests/fixtures/dbt/sushi_test/dbt_project.yml +++ b/tests/fixtures/dbt/sushi_test/dbt_project.yml @@ -70,4 +70,5 @@ on-run-start: - "{{ log_value('on-run-start') }}" on-run-end: - '{{ create_tables(schemas) }}' - - 'DROP TABLE to_be_executed_last;' \ No newline at end of file + - 'DROP TABLE to_be_executed_last;' + - '{{ graph_usage() }}' \ No newline at end of file diff --git a/tests/fixtures/dbt/sushi_test/macros/graph_usage.sql b/tests/fixtures/dbt/sushi_test/macros/graph_usage.sql new file mode 100644 index 0000000000..8b133ec280 --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/macros/graph_usage.sql @@ -0,0 +1,18 @@ +{% macro graph_usage() %} +{% if execute %} + {% set model_nodes = graph.nodes.values() + | selectattr("resource_type", "equalto", "model") + | list %} + + {% set out = [] %} + {% for node in model_nodes %} + {% set line = "select '" ~ node.unique_id ~ "' as unique_id, '" ~ node.config.materialized ~ "' as materialized" %} + {% do out.append(line) %} + {% endfor %} + + {% if out %} + {% set sql_statement = "create or replace table graph_table as\n" ~ (out | join('\nunion all\n')) %} + {{ return(sql_statement) }} + {% endif %} +{% endif %} +{% endmacro %} From a50873694ddb5b2fd3dffe20d84fd9d0e4866845 Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+themisvaltinos@users.noreply.github.com> Date: Fri, 15 Aug 2025 17:50:48 +0300 Subject: [PATCH 2/3] refactor graph to the manifest; make test more deterministic --- sqlmesh/dbt/context.py | 46 +------------------------------- sqlmesh/dbt/manifest.py | 34 +++++++++++++++++++++++ sqlmesh/utils/jinja.py | 3 +++ tests/dbt/test_transformation.py | 26 ++++++++++++++++-- 4 files changed, 62 insertions(+), 47 deletions(-) diff --git a/sqlmesh/dbt/context.py b/sqlmesh/dbt/context.py index d7b9901623..2eceb005a7 100644 --- a/sqlmesh/dbt/context.py +++ b/sqlmesh/dbt/context.py @@ -11,7 +11,6 @@ from sqlmesh.dbt.manifest import ManifestHelper from sqlmesh.dbt.target import TargetConfig from sqlmesh.utils import AttributeDict -from sqlmesh.utils.conversions import serializable from sqlmesh.utils.errors import ConfigError, SQLMeshError from sqlmesh.utils.jinja import ( JinjaGlobalAttribute, @@ -196,49 +195,6 @@ def refs(self) -> t.Dict[str, t.Union[ModelConfig, SeedConfig]]: self._refs[f"{config_name}_v{model.version}"] = model return self._refs - @property - def flat_graph(self) -> t.Dict[str, t.Any]: - if self._manifest is None: - return { - "exposures": {}, - "groups": {}, - "metrics": {}, - "nodes": {}, - "sources": {}, - "semantic_models": {}, - "saved_queries": {}, - } - - manifest = self._manifest._manifest - return { - "exposures": { - k: serializable(v.to_dict(omit_none=False)) - for k, v in getattr(manifest, "exposures", {}).items() - }, - "groups": { - k: serializable(v.to_dict(omit_none=False)) - for k, v in getattr(manifest, "groups", {}).items() - }, - "metrics": { - k: serializable(v.to_dict(omit_none=False)) - for k, v in getattr(manifest, "metrics", {}).items() - }, - "nodes": { - k: serializable(v.to_dict(omit_none=False)) for k, v in manifest.nodes.items() - }, - "sources": { - k: serializable(v.to_dict(omit_none=False)) for k, v in manifest.sources.items() - }, - "semantic_models": { - k: serializable(v.to_dict(omit_none=False)) - for k, v in getattr(manifest, "semantic_models", {}).items() - }, - "saved_queries": { - k: serializable(v.to_dict(omit_none=False)) - for k, v in getattr(manifest, "saved_queries", {}).items() - }, - } - @property def target(self) -> TargetConfig: if not self._target: @@ -288,7 +244,7 @@ def jinja_globals(self) -> t.Dict[str, JinjaGlobalAttribute]: output["dialect"] = self.sqlmesh_config.dialect # Pass flat graph structure like dbt if self._manifest is not None: - output["flat_graph"] = AttributeDict(self.flat_graph) + output["flat_graph"] = AttributeDict(self.manifest.flat_graph) return output def context_for_dependencies(self, dependencies: Dependencies) -> DbtContext: diff --git a/sqlmesh/dbt/manifest.py b/sqlmesh/dbt/manifest.py index 4f839b9c9b..f08d0aa828 100644 --- a/sqlmesh/dbt/manifest.py +++ b/sqlmesh/dbt/manifest.py @@ -12,6 +12,8 @@ from dbt import constants as dbt_constants, flags +from sqlmesh.utils.conversions import serializable + # Override the file name to prevent dbt commands from invalidating the cache. dbt_constants.PARTIAL_PARSE_FILE_NAME = "sqlmesh_partial_parse.msgpack" @@ -155,6 +157,38 @@ def all_macros(self) -> t.Dict[str, t.Dict[str, MacroInfo]]: result[package_name][macro_name] = macro_config.info return result + @property + def flat_graph(self) -> t.Dict[str, t.Any]: + return { + "exposures": { + k: serializable(v.to_dict(omit_none=False)) + for k, v in getattr(self._manifest, "exposures", {}).items() + }, + "groups": { + k: serializable(v.to_dict(omit_none=False)) + for k, v in getattr(self._manifest, "groups", {}).items() + }, + "metrics": { + k: serializable(v.to_dict(omit_none=False)) + for k, v in getattr(self._manifest, "metrics", {}).items() + }, + "nodes": { + k: serializable(v.to_dict(omit_none=False)) for k, v in self._manifest.nodes.items() + }, + "sources": { + k: serializable(v.to_dict(omit_none=False)) + for k, v in self._manifest.sources.items() + }, + "semantic_models": { + k: serializable(v.to_dict(omit_none=False)) + for k, v in getattr(self._manifest, "semantic_models", {}).items() + }, + "saved_queries": { + k: serializable(v.to_dict(omit_none=False)) + for k, v in getattr(self._manifest, "saved_queries", {}).items() + }, + } + def _load_all(self) -> None: if self._is_loaded: return diff --git a/sqlmesh/utils/jinja.py b/sqlmesh/utils/jinja.py index 6720c24581..fc9d898159 100644 --- a/sqlmesh/utils/jinja.py +++ b/sqlmesh/utils/jinja.py @@ -363,6 +363,9 @@ def add_globals(self, globals: t.Dict[str, JinjaGlobalAttribute]) -> None: Args: globals: The global objects that should be added. """ + # Keep the registry lightweight when the graph is not needed + if not "graph" in self.packages: + globals.pop("flat_graph", None) self.global_objs.update(**self._validate_global_objs(globals)) def build_macro(self, reference: MacroReference, **kwargs: t.Any) -> t.Optional[t.Callable]: diff --git a/tests/dbt/test_transformation.py b/tests/dbt/test_transformation.py index ef5eb65a6e..cefedd6814 100644 --- a/tests/dbt/test_transformation.py +++ b/tests/dbt/test_transformation.py @@ -1641,10 +1641,32 @@ def test_on_run_start_end(): "CREATE OR REPLACE TABLE schema_table_snapshots__dev AS SELECT 'snapshots__dev' AS schema", "CREATE OR REPLACE TABLE schema_table_sushi__dev AS SELECT 'sushi__dev' AS schema", "DROP TABLE to_be_executed_last", - "CREATE OR REPLACE TABLE graph_table AS SELECT 'model.sushi.simple_model_a' AS unique_id, 'table' AS materialized UNION ALL SELECT 'model.sushi.waiters' AS unique_id, 'ephemeral' AS materialized UNION ALL SELECT 'model.sushi.simple_model_b' AS unique_id, 'table' AS materialized UNION ALL SELECT 'model.sushi.waiter_as_customer_by_day' AS unique_id, 'incremental' AS materialized UNION ALL SELECT 'model.sushi.top_waiters' AS unique_id, 'view' AS materialized UNION ALL SELECT 'model.customers.customers' AS unique_id, 'view' AS materialized UNION ALL SELECT 'model.customers.customer_revenue_by_day' AS unique_id, 'incremental' AS materialized UNION ALL SELECT 'model.sushi.waiter_revenue_by_day.v1' AS unique_id, 'incremental' AS materialized UNION ALL SELECT 'model.sushi.waiter_revenue_by_day.v2' AS unique_id, 'incremental' AS materialized", ] + assert sorted(rendered_after_all[:-1]) == sorted(expected_statements) - assert sorted(rendered_after_all) == sorted(expected_statements) + # Assert the models with their materialisations are present in the rendered graph_table statement + graph_table_stmt = rendered_after_all[-1] + assert "'model.sushi.simple_model_a' AS unique_id, 'table' AS materialized" in graph_table_stmt + assert "'model.sushi.waiters' AS unique_id, 'ephemeral' AS materialized" in graph_table_stmt + assert "'model.sushi.simple_model_b' AS unique_id, 'table' AS materialized" in graph_table_stmt + assert ( + "'model.sushi.waiter_as_customer_by_day' AS unique_id, 'incremental' AS materialized" + in graph_table_stmt + ) + assert "'model.sushi.top_waiters' AS unique_id, 'view' AS materialized" in graph_table_stmt + assert "'model.customers.customers' AS unique_id, 'view' AS materialized" in graph_table_stmt + assert ( + "'model.customers.customer_revenue_by_day' AS unique_id, 'incremental' AS materialized" + in graph_table_stmt + ) + assert ( + "'model.sushi.waiter_revenue_by_day.v1' AS unique_id, 'incremental' AS materialized" + in graph_table_stmt + ) + assert ( + "'model.sushi.waiter_revenue_by_day.v2' AS unique_id, 'incremental' AS materialized" + in graph_table_stmt + ) # Nested dbt_packages on run start / on run end packaged_environment_statements = sushi_context._environment_statements[1] From d577edd59a6adfbf5c9bd68f505b5f656790abbd Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+themisvaltinos@users.noreply.github.com> Date: Mon, 18 Aug 2025 16:39:21 +0300 Subject: [PATCH 3/3] rename function --- sqlmesh/dbt/manifest.py | 17 +++++++++-------- sqlmesh/utils/conversions.py | 6 +++--- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/sqlmesh/dbt/manifest.py b/sqlmesh/dbt/manifest.py index f08d0aa828..91c87f413e 100644 --- a/sqlmesh/dbt/manifest.py +++ b/sqlmesh/dbt/manifest.py @@ -12,7 +12,7 @@ from dbt import constants as dbt_constants, flags -from sqlmesh.utils.conversions import serializable +from sqlmesh.utils.conversions import make_serializable # Override the file name to prevent dbt commands from invalidating the cache. dbt_constants.PARTIAL_PARSE_FILE_NAME = "sqlmesh_partial_parse.msgpack" @@ -161,30 +161,31 @@ def all_macros(self) -> t.Dict[str, t.Dict[str, MacroInfo]]: def flat_graph(self) -> t.Dict[str, t.Any]: return { "exposures": { - k: serializable(v.to_dict(omit_none=False)) + k: make_serializable(v.to_dict(omit_none=False)) for k, v in getattr(self._manifest, "exposures", {}).items() }, "groups": { - k: serializable(v.to_dict(omit_none=False)) + k: make_serializable(v.to_dict(omit_none=False)) for k, v in getattr(self._manifest, "groups", {}).items() }, "metrics": { - k: serializable(v.to_dict(omit_none=False)) + k: make_serializable(v.to_dict(omit_none=False)) for k, v in getattr(self._manifest, "metrics", {}).items() }, "nodes": { - k: serializable(v.to_dict(omit_none=False)) for k, v in self._manifest.nodes.items() + k: make_serializable(v.to_dict(omit_none=False)) + for k, v in self._manifest.nodes.items() }, "sources": { - k: serializable(v.to_dict(omit_none=False)) + k: make_serializable(v.to_dict(omit_none=False)) for k, v in self._manifest.sources.items() }, "semantic_models": { - k: serializable(v.to_dict(omit_none=False)) + k: make_serializable(v.to_dict(omit_none=False)) for k, v in getattr(self._manifest, "semantic_models", {}).items() }, "saved_queries": { - k: serializable(v.to_dict(omit_none=False)) + k: make_serializable(v.to_dict(omit_none=False)) for k, v in getattr(self._manifest, "saved_queries", {}).items() }, } diff --git a/sqlmesh/utils/conversions.py b/sqlmesh/utils/conversions.py index f2a7ea9f49..411f3c8ab1 100644 --- a/sqlmesh/utils/conversions.py +++ b/sqlmesh/utils/conversions.py @@ -22,11 +22,11 @@ def try_str_to_bool(val: str) -> t.Union[str, bool]: return val -def serializable(obj: t.Any) -> t.Any: +def make_serializable(obj: t.Any) -> t.Any: if isinstance(obj, (date, datetime)): return obj.isoformat() if isinstance(obj, dict): - return {k: serializable(v) for k, v in obj.items()} + return {k: make_serializable(v) for k, v in obj.items()} if isinstance(obj, list): - return [serializable(item) for item in obj] + return [make_serializable(item) for item in obj] return obj