Skip to content

Commit a255e17

Browse files
Feat(dbt): Add support for dbt custom materializations (#5435)
Co-authored-by: Iaroslav Zeigerman <zeigerman.ia@gmail.com>
1 parent cbcb6d2 commit a255e17

File tree

16 files changed

+1321
-11
lines changed

16 files changed

+1321
-11
lines changed

sqlmesh/core/model/kind.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,10 @@ def is_custom(self) -> bool:
119119
def is_managed(self) -> bool:
120120
return self.model_kind_name == ModelKindName.MANAGED
121121

122+
@property
123+
def is_dbt_custom(self) -> bool:
124+
return self.model_kind_name == ModelKindName.DBT_CUSTOM
125+
122126
@property
123127
def is_symbolic(self) -> bool:
124128
"""A symbolic model is one that doesn't execute at all."""
@@ -170,6 +174,7 @@ class ModelKindName(str, ModelKindMixin, Enum):
170174
EXTERNAL = "EXTERNAL"
171175
CUSTOM = "CUSTOM"
172176
MANAGED = "MANAGED"
177+
DBT_CUSTOM = "DBT_CUSTOM"
173178

174179
@property
175180
def model_kind_name(self) -> t.Optional[ModelKindName]:
@@ -887,6 +892,46 @@ def supports_python_models(self) -> bool:
887892
return False
888893

889894

895+
class DbtCustomKind(_ModelKind):
896+
name: t.Literal[ModelKindName.DBT_CUSTOM] = ModelKindName.DBT_CUSTOM
897+
materialization: str
898+
adapter: str = "default"
899+
definition: str
900+
dialect: t.Optional[str] = Field(None, validate_default=True)
901+
902+
_dialect_validator = kind_dialect_validator
903+
904+
@field_validator("materialization", "adapter", "definition", mode="before")
905+
@classmethod
906+
def _validate_fields(cls, v: t.Any) -> str:
907+
return validate_string(v)
908+
909+
@property
910+
def data_hash_values(self) -> t.List[t.Optional[str]]:
911+
return [
912+
*super().data_hash_values,
913+
self.materialization,
914+
self.definition,
915+
self.adapter,
916+
self.dialect,
917+
]
918+
919+
def to_expression(
920+
self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
921+
) -> d.ModelKind:
922+
return super().to_expression(
923+
expressions=[
924+
*(expressions or []),
925+
*_properties(
926+
{
927+
"materialization": exp.Literal.string(self.materialization),
928+
"adapter": exp.Literal.string(self.adapter),
929+
}
930+
),
931+
],
932+
)
933+
934+
890935
class EmbeddedKind(_ModelKind):
891936
name: t.Literal[ModelKindName.EMBEDDED] = ModelKindName.EMBEDDED
892937

@@ -992,6 +1037,7 @@ def to_expression(
9921037
SCDType2ByColumnKind,
9931038
CustomKind,
9941039
ManagedKind,
1040+
DbtCustomKind,
9951041
],
9961042
Field(discriminator="name"),
9971043
]
@@ -1011,6 +1057,7 @@ def to_expression(
10111057
ModelKindName.SCD_TYPE_2_BY_COLUMN: SCDType2ByColumnKind,
10121058
ModelKindName.CUSTOM: CustomKind,
10131059
ModelKindName.MANAGED: ManagedKind,
1060+
ModelKindName.DBT_CUSTOM: DbtCustomKind,
10141061
}
10151062

10161063

sqlmesh/core/snapshot/evaluator.py

Lines changed: 199 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
ViewKind,
5151
CustomKind,
5252
)
53-
from sqlmesh.core.model.kind import _Incremental
53+
from sqlmesh.core.model.kind import _Incremental, DbtCustomKind
5454
from sqlmesh.utils import CompletionStatus, columns_to_types_all_known
5555
from sqlmesh.core.schema_diff import (
5656
has_drop_alteration,
@@ -67,7 +67,7 @@
6767
SnapshotTableCleanupTask,
6868
)
6969
from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker
70-
from sqlmesh.utils import random_id, CorrelationId
70+
from sqlmesh.utils import random_id, CorrelationId, AttributeDict
7171
from sqlmesh.utils.concurrency import (
7272
concurrent_apply_to_snapshots,
7373
concurrent_apply_to_values,
@@ -83,6 +83,7 @@
8383
format_additive_change_msg,
8484
AdditiveChangeError,
8585
)
86+
from sqlmesh.utils.jinja import MacroReturnVal
8687

8788
if sys.version_info >= (3, 12):
8889
from importlib import metadata
@@ -747,7 +748,10 @@ def _evaluate_snapshot(
747748
adapter.transaction(),
748749
adapter.session(snapshot.model.render_session_properties(**render_statements_kwargs)),
749750
):
750-
adapter.execute(model.render_pre_statements(**render_statements_kwargs))
751+
evaluation_strategy = _evaluation_strategy(snapshot, adapter)
752+
evaluation_strategy.run_pre_statements(
753+
snapshot=snapshot, render_kwargs=render_statements_kwargs
754+
)
751755

752756
if not target_table_exists or (model.is_seed and not snapshot.intervals):
753757
# Only create the empty table if the columns were provided explicitly by the user
@@ -817,7 +821,9 @@ def _evaluate_snapshot(
817821
batch_index=batch_index,
818822
)
819823

820-
adapter.execute(model.render_post_statements(**render_statements_kwargs))
824+
evaluation_strategy.run_post_statements(
825+
snapshot=snapshot, render_kwargs=render_statements_kwargs
826+
)
821827

822828
return wap_id
823829

@@ -1433,7 +1439,9 @@ def _execute_create(
14331439
"table_mapping": {snapshot.name: table_name},
14341440
}
14351441
if run_pre_post_statements:
1436-
adapter.execute(snapshot.model.render_pre_statements(**create_render_kwargs))
1442+
evaluation_strategy.run_pre_statements(
1443+
snapshot=snapshot, render_kwargs=create_render_kwargs
1444+
)
14371445
evaluation_strategy.create(
14381446
table_name=table_name,
14391447
model=snapshot.model,
@@ -1445,7 +1453,9 @@ def _execute_create(
14451453
physical_properties=rendered_physical_properties,
14461454
)
14471455
if run_pre_post_statements:
1448-
adapter.execute(snapshot.model.render_post_statements(**create_render_kwargs))
1456+
evaluation_strategy.run_post_statements(
1457+
snapshot=snapshot, render_kwargs=create_render_kwargs
1458+
)
14491459

14501460
def _can_clone(self, snapshot: Snapshot, deployability_index: DeployabilityIndex) -> bool:
14511461
adapter = self.get_adapter(snapshot.model.gateway)
@@ -1456,6 +1466,7 @@ def _can_clone(self, snapshot: Snapshot, deployability_index: DeployabilityIndex
14561466
and adapter.SUPPORTS_CLONING
14571467
# managed models cannot have their schema mutated because theyre based on queries, so clone + alter wont work
14581468
and not snapshot.is_managed
1469+
and not snapshot.is_dbt_custom
14591470
and not deployability_index.is_deployable(snapshot)
14601471
# If the deployable table is missing we can't clone it
14611472
and adapter.table_exists(snapshot.table_name())
@@ -1540,6 +1551,19 @@ def _evaluation_strategy(snapshot: SnapshotInfoLike, adapter: EngineAdapter) ->
15401551
klass = ViewStrategy
15411552
elif snapshot.is_scd_type_2:
15421553
klass = SCDType2Strategy
1554+
elif snapshot.is_dbt_custom:
1555+
if hasattr(snapshot, "model") and isinstance(
1556+
(model_kind := snapshot.model.kind), DbtCustomKind
1557+
):
1558+
return DbtCustomMaterializationStrategy(
1559+
adapter=adapter,
1560+
materialization_name=model_kind.materialization,
1561+
materialization_template=model_kind.definition,
1562+
)
1563+
1564+
raise SQLMeshError(
1565+
f"Expected DbtCustomKind for dbt custom materialization in model '{snapshot.name}'"
1566+
)
15431567
elif snapshot.is_custom:
15441568
if snapshot.custom_materialization is None:
15451569
raise SQLMeshError(
@@ -1679,6 +1703,24 @@ def demote(self, view_name: str, **kwargs: t.Any) -> None:
16791703
view_name: The name of the target view in the virtual layer.
16801704
"""
16811705

1706+
@abc.abstractmethod
1707+
def run_pre_statements(self, snapshot: Snapshot, render_kwargs: t.Any) -> None:
1708+
"""Executes the snapshot's pre statements.
1709+
1710+
Args:
1711+
snapshot: The target snapshot.
1712+
render_kwargs: Additional key-value arguments to pass when rendering the statements.
1713+
"""
1714+
1715+
@abc.abstractmethod
1716+
def run_post_statements(self, snapshot: Snapshot, render_kwargs: t.Any) -> None:
1717+
"""Executes the snapshot's post statements.
1718+
1719+
Args:
1720+
snapshot: The target snapshot.
1721+
render_kwargs: Additional key-value arguments to pass when rendering the statements.
1722+
"""
1723+
16821724

16831725
class SymbolicStrategy(EvaluationStrategy):
16841726
def insert(
@@ -1740,6 +1782,12 @@ def promote(
17401782
def demote(self, view_name: str, **kwargs: t.Any) -> None:
17411783
pass
17421784

1785+
def run_pre_statements(self, snapshot: Snapshot, render_kwargs: t.Dict[str, t.Any]) -> None:
1786+
pass
1787+
1788+
def run_post_statements(self, snapshot: Snapshot, render_kwargs: t.Dict[str, t.Any]) -> None:
1789+
pass
1790+
17431791

17441792
class EmbeddedStrategy(SymbolicStrategy):
17451793
def promote(
@@ -1787,6 +1835,12 @@ def demote(self, view_name: str, **kwargs: t.Any) -> None:
17871835
logger.info("Dropping view '%s'", view_name)
17881836
self.adapter.drop_view(view_name, cascade=False)
17891837

1838+
def run_pre_statements(self, snapshot: Snapshot, render_kwargs: t.Any) -> None:
1839+
self.adapter.execute(snapshot.model.render_pre_statements(**render_kwargs))
1840+
1841+
def run_post_statements(self, snapshot: Snapshot, render_kwargs: t.Any) -> None:
1842+
self.adapter.execute(snapshot.model.render_post_statements(**render_kwargs))
1843+
17901844

17911845
class MaterializableStrategy(PromotableStrategy, abc.ABC):
17921846
def create(
@@ -2593,6 +2647,145 @@ def get_custom_materialization_type_or_raise(
25932647
raise SQLMeshError(f"Custom materialization '{name}' not present in the Python environment")
25942648

25952649

2650+
class DbtCustomMaterializationStrategy(MaterializableStrategy):
2651+
def __init__(
2652+
self,
2653+
adapter: EngineAdapter,
2654+
materialization_name: str,
2655+
materialization_template: str,
2656+
):
2657+
super().__init__(adapter)
2658+
self.materialization_name = materialization_name
2659+
self.materialization_template = materialization_template
2660+
2661+
def create(
2662+
self,
2663+
table_name: str,
2664+
model: Model,
2665+
is_table_deployable: bool,
2666+
render_kwargs: t.Dict[str, t.Any],
2667+
**kwargs: t.Any,
2668+
) -> None:
2669+
original_query = model.render_query_or_raise(**render_kwargs)
2670+
self._execute_materialization(
2671+
table_name=table_name,
2672+
query_or_df=original_query.limit(0),
2673+
model=model,
2674+
is_first_insert=True,
2675+
render_kwargs=render_kwargs,
2676+
create_only=True,
2677+
**kwargs,
2678+
)
2679+
2680+
def insert(
2681+
self,
2682+
table_name: str,
2683+
query_or_df: QueryOrDF,
2684+
model: Model,
2685+
is_first_insert: bool,
2686+
render_kwargs: t.Dict[str, t.Any],
2687+
**kwargs: t.Any,
2688+
) -> None:
2689+
self._execute_materialization(
2690+
table_name=table_name,
2691+
query_or_df=query_or_df,
2692+
model=model,
2693+
is_first_insert=is_first_insert,
2694+
render_kwargs=render_kwargs,
2695+
**kwargs,
2696+
)
2697+
2698+
def append(
2699+
self,
2700+
table_name: str,
2701+
query_or_df: QueryOrDF,
2702+
model: Model,
2703+
render_kwargs: t.Dict[str, t.Any],
2704+
**kwargs: t.Any,
2705+
) -> None:
2706+
return self.insert(
2707+
table_name,
2708+
query_or_df,
2709+
model,
2710+
is_first_insert=False,
2711+
render_kwargs=render_kwargs,
2712+
**kwargs,
2713+
)
2714+
2715+
def run_pre_statements(self, snapshot: Snapshot, render_kwargs: t.Any) -> None:
2716+
# in dbt custom materialisations it's up to the user when to run the pre hooks
2717+
pass
2718+
2719+
def run_post_statements(self, snapshot: Snapshot, render_kwargs: t.Any) -> None:
2720+
# in dbt custom materialisations it's up to the user when to run the post hooks
2721+
pass
2722+
2723+
def _execute_materialization(
2724+
self,
2725+
table_name: str,
2726+
query_or_df: QueryOrDF,
2727+
model: Model,
2728+
is_first_insert: bool,
2729+
render_kwargs: t.Dict[str, t.Any],
2730+
create_only: bool = False,
2731+
**kwargs: t.Any,
2732+
) -> None:
2733+
jinja_macros = model.jinja_macros
2734+
2735+
# For vdes we need to use the table, since we don't know the schema/table at parse time
2736+
parts = exp.to_table(table_name, dialect=self.adapter.dialect)
2737+
2738+
existing_globals = jinja_macros.global_objs
2739+
relation_info = existing_globals.get("this")
2740+
if isinstance(relation_info, dict):
2741+
relation_info["database"] = parts.catalog
2742+
relation_info["identifier"] = parts.name
2743+
relation_info["name"] = parts.name
2744+
2745+
jinja_globals = {
2746+
**existing_globals,
2747+
"this": relation_info,
2748+
"database": parts.catalog,
2749+
"schema": parts.db,
2750+
"identifier": parts.name,
2751+
"target": existing_globals.get("target", {"type": self.adapter.dialect}),
2752+
"execution_dt": kwargs.get("execution_time"),
2753+
"engine_adapter": self.adapter,
2754+
"sql": str(query_or_df),
2755+
"is_first_insert": is_first_insert,
2756+
"create_only": create_only,
2757+
# FIXME: Add support for transaction=False
2758+
"pre_hooks": [
2759+
AttributeDict({"sql": s.this.this, "transaction": True})
2760+
for s in model.pre_statements
2761+
],
2762+
"post_hooks": [
2763+
AttributeDict({"sql": s.this.this, "transaction": True})
2764+
for s in model.post_statements
2765+
],
2766+
"model_instance": model,
2767+
**kwargs,
2768+
}
2769+
2770+
try:
2771+
jinja_env = jinja_macros.build_environment(**jinja_globals)
2772+
template = jinja_env.from_string(self.materialization_template)
2773+
2774+
try:
2775+
template.render()
2776+
except MacroReturnVal as ret:
2777+
# this is a successful return from a macro call (dbt uses this list of Relations to update their relation cache)
2778+
returned_relations = ret.value.get("relations", [])
2779+
logger.info(
2780+
f"Materialization {self.materialization_name} returned relations: {returned_relations}"
2781+
)
2782+
2783+
except Exception as e:
2784+
raise SQLMeshError(
2785+
f"Failed to execute dbt materialization '{self.materialization_name}': {e}"
2786+
) from e
2787+
2788+
25962789
class EngineManagedStrategy(MaterializableStrategy):
25972790
def create(
25982791
self,

0 commit comments

Comments
 (0)