Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions sqlmesh/core/model/kind.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ def is_custom(self) -> bool:
def is_managed(self) -> bool:
return self.model_kind_name == ModelKindName.MANAGED

@property
def is_dbt_custom(self) -> bool:
    """Whether this kind's name is ``ModelKindName.DBT_CUSTOM``."""
    kind_name = self.model_kind_name
    return kind_name == ModelKindName.DBT_CUSTOM

@property
def is_symbolic(self) -> bool:
"""A symbolic model is one that doesn't execute at all."""
Expand Down Expand Up @@ -170,6 +174,7 @@ class ModelKindName(str, ModelKindMixin, Enum):
EXTERNAL = "EXTERNAL"
CUSTOM = "CUSTOM"
MANAGED = "MANAGED"
DBT_CUSTOM = "DBT_CUSTOM"

@property
def model_kind_name(self) -> t.Optional[ModelKindName]:
Expand Down Expand Up @@ -887,6 +892,46 @@ def supports_python_models(self) -> bool:
return False


class DbtCustomKind(_ModelKind):
name: t.Literal[ModelKindName.DBT_CUSTOM] = ModelKindName.DBT_CUSTOM
materialization: str
adapter: str = "default"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is this adapter?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It lets users define different materializations under the same name, each tied to a different adapter, so that dispatch can then pick the one that matches the adapter in use:

{% materialization my_materialization_name, default %}
 -- cross-adapter materialization... assume Redshift is not supported
{% endmaterialization %}

{% materialization my_materialization_name, adapter='redshift' %}
-- override the materialization for Redshift
{% endmaterialization %}

test for this example: test_adapter_specific_materialization_override

definition: str
dialect: t.Optional[str] = Field(None, validate_default=True)

_dialect_validator = kind_dialect_validator

@field_validator("materialization", "adapter", "definition", mode="before")
@classmethod
def _validate_fields(cls, v: t.Any) -> str:
    """Run the raw ``materialization``, ``adapter`` and ``definition`` values through ``validate_string``.

    Registered with ``mode="before"``, so it receives the value as it was supplied.
    """
    validated = validate_string(v)
    return validated

@property
def data_hash_values(self) -> t.List[t.Optional[str]]:
    """Values contributing to the data hash: the parent's values followed by this kind's own fields."""
    hash_values = list(super().data_hash_values)
    # Order matters for the resulting hash, so the fields are appended in a fixed sequence.
    hash_values.extend(
        (
            self.materialization,
            self.definition,
            self.adapter,
            self.dialect,
        )
    )
    return hash_values

def to_expression(
    self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any
) -> d.ModelKind:
    """Build the kind expression, adding ``materialization`` and ``adapter`` properties.

    Args:
        expressions: Optional expressions placed before this kind's own properties.
        **kwargs: Accepted for interface compatibility; not forwarded to the parent.

    Returns:
        The expression produced by the parent implementation for the combined list.
    """
    # NOTE(review): `definition` is not serialized here; confirm that this is intentional.
    leading = expressions or []
    kind_properties = _properties(
        {
            "materialization": exp.Literal.string(self.materialization),
            "adapter": exp.Literal.string(self.adapter),
        }
    )
    combined = [*leading, *kind_properties]
    return super().to_expression(expressions=combined)


class EmbeddedKind(_ModelKind):
name: t.Literal[ModelKindName.EMBEDDED] = ModelKindName.EMBEDDED

Expand Down Expand Up @@ -992,6 +1037,7 @@ def to_expression(
SCDType2ByColumnKind,
CustomKind,
ManagedKind,
DbtCustomKind,
],
Field(discriminator="name"),
]
Expand All @@ -1011,6 +1057,7 @@ def to_expression(
ModelKindName.SCD_TYPE_2_BY_COLUMN: SCDType2ByColumnKind,
ModelKindName.CUSTOM: CustomKind,
ModelKindName.MANAGED: ManagedKind,
ModelKindName.DBT_CUSTOM: DbtCustomKind,
}


Expand Down
205 changes: 199 additions & 6 deletions sqlmesh/core/snapshot/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
ViewKind,
CustomKind,
)
from sqlmesh.core.model.kind import _Incremental
from sqlmesh.core.model.kind import _Incremental, DbtCustomKind
from sqlmesh.utils import CompletionStatus, columns_to_types_all_known
from sqlmesh.core.schema_diff import (
has_drop_alteration,
Expand All @@ -67,7 +67,7 @@
SnapshotTableCleanupTask,
)
from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker
from sqlmesh.utils import random_id, CorrelationId
from sqlmesh.utils import random_id, CorrelationId, AttributeDict
from sqlmesh.utils.concurrency import (
concurrent_apply_to_snapshots,
concurrent_apply_to_values,
Expand All @@ -83,6 +83,7 @@
format_additive_change_msg,
AdditiveChangeError,
)
from sqlmesh.utils.jinja import MacroReturnVal

if sys.version_info >= (3, 12):
from importlib import metadata
Expand Down Expand Up @@ -747,7 +748,10 @@ def _evaluate_snapshot(
adapter.transaction(),
adapter.session(snapshot.model.render_session_properties(**render_statements_kwargs)),
):
adapter.execute(model.render_pre_statements(**render_statements_kwargs))
evaluation_strategy = _evaluation_strategy(snapshot, adapter)
evaluation_strategy.run_pre_statements(
snapshot=snapshot, render_kwargs=render_statements_kwargs
)

if not target_table_exists or (model.is_seed and not snapshot.intervals):
# Only create the empty table if the columns were provided explicitly by the user
Expand Down Expand Up @@ -817,7 +821,9 @@ def _evaluate_snapshot(
batch_index=batch_index,
)

adapter.execute(model.render_post_statements(**render_statements_kwargs))
evaluation_strategy.run_post_statements(
snapshot=snapshot, render_kwargs=render_statements_kwargs
)

return wap_id

Expand Down Expand Up @@ -1433,7 +1439,9 @@ def _execute_create(
"table_mapping": {snapshot.name: table_name},
}
if run_pre_post_statements:
adapter.execute(snapshot.model.render_pre_statements(**create_render_kwargs))
evaluation_strategy.run_pre_statements(
snapshot=snapshot, render_kwargs=create_render_kwargs
)
evaluation_strategy.create(
table_name=table_name,
model=snapshot.model,
Expand All @@ -1445,7 +1453,9 @@ def _execute_create(
physical_properties=rendered_physical_properties,
)
if run_pre_post_statements:
adapter.execute(snapshot.model.render_post_statements(**create_render_kwargs))
evaluation_strategy.run_post_statements(
snapshot=snapshot, render_kwargs=create_render_kwargs
)

def _can_clone(self, snapshot: Snapshot, deployability_index: DeployabilityIndex) -> bool:
adapter = self.get_adapter(snapshot.model.gateway)
Expand All @@ -1456,6 +1466,7 @@ def _can_clone(self, snapshot: Snapshot, deployability_index: DeployabilityIndex
and adapter.SUPPORTS_CLONING
# managed models cannot have their schema mutated because theyre based on queries, so clone + alter wont work
and not snapshot.is_managed
and not snapshot.is_dbt_custom
and not deployability_index.is_deployable(snapshot)
# If the deployable table is missing we can't clone it
and adapter.table_exists(snapshot.table_name())
Expand Down Expand Up @@ -1540,6 +1551,19 @@ def _evaluation_strategy(snapshot: SnapshotInfoLike, adapter: EngineAdapter) ->
klass = ViewStrategy
elif snapshot.is_scd_type_2:
klass = SCDType2Strategy
elif snapshot.is_dbt_custom:
if hasattr(snapshot, "model") and isinstance(
(model_kind := snapshot.model.kind), DbtCustomKind
):
return DbtCustomMaterializationStrategy(
adapter=adapter,
materialization_name=model_kind.materialization,
materialization_template=model_kind.definition,
)

raise SQLMeshError(
f"Expected DbtCustomKind for dbt custom materialization in model '{snapshot.name}'"
)
elif snapshot.is_custom:
if snapshot.custom_materialization is None:
raise SQLMeshError(
Expand Down Expand Up @@ -1679,6 +1703,24 @@ def demote(self, view_name: str, **kwargs: t.Any) -> None:
view_name: The name of the target view in the virtual layer.
"""

@abc.abstractmethod
def run_pre_statements(self, snapshot: Snapshot, render_kwargs: t.Any) -> None:
    """Run the pre statements of the given snapshot.

    Args:
        snapshot: The snapshot whose pre statements should be executed.
        render_kwargs: Extra key-value arguments supplied when the statements are rendered.
    """

@abc.abstractmethod
def run_post_statements(self, snapshot: Snapshot, render_kwargs: t.Any) -> None:
    """Run the post statements of the given snapshot.

    Args:
        snapshot: The snapshot whose post statements should be executed.
        render_kwargs: Extra key-value arguments supplied when the statements are rendered.
    """


class SymbolicStrategy(EvaluationStrategy):
def insert(
Expand Down Expand Up @@ -1740,6 +1782,12 @@ def promote(
def demote(self, view_name: str, **kwargs: t.Any) -> None:
pass

def run_pre_statements(self, snapshot: Snapshot, render_kwargs: t.Dict[str, t.Any]) -> None:
    """Do nothing: this strategy does not execute pre statements."""
    return None

def run_post_statements(self, snapshot: Snapshot, render_kwargs: t.Dict[str, t.Any]) -> None:
    """Do nothing: this strategy does not execute post statements."""
    return None


class EmbeddedStrategy(SymbolicStrategy):
def promote(
Expand Down Expand Up @@ -1787,6 +1835,12 @@ def demote(self, view_name: str, **kwargs: t.Any) -> None:
logger.info("Dropping view '%s'", view_name)
self.adapter.drop_view(view_name, cascade=False)

def run_pre_statements(self, snapshot: Snapshot, render_kwargs: t.Any) -> None:
    """Render the snapshot model's pre statements and execute them on this strategy's adapter.

    Args:
        snapshot: The snapshot whose model provides the pre statements.
        render_kwargs: Mapping unpacked as keyword arguments into the render call.
    """
    execute = self.adapter.execute
    rendered_statements = snapshot.model.render_pre_statements(**render_kwargs)
    execute(rendered_statements)

def run_post_statements(self, snapshot: Snapshot, render_kwargs: t.Any) -> None:
    """Render the snapshot model's post statements and execute them on this strategy's adapter.

    Args:
        snapshot: The snapshot whose model provides the post statements.
        render_kwargs: Mapping unpacked as keyword arguments into the render call.
    """
    execute = self.adapter.execute
    rendered_statements = snapshot.model.render_post_statements(**render_kwargs)
    execute(rendered_statements)


class MaterializableStrategy(PromotableStrategy, abc.ABC):
def create(
Expand Down Expand Up @@ -2593,6 +2647,145 @@ def get_custom_materialization_type_or_raise(
raise SQLMeshError(f"Custom materialization '{name}' not present in the Python environment")


class DbtCustomMaterializationStrategy(MaterializableStrategy):
    """Evaluation strategy that delegates materialization to a dbt custom materialization template.

    Table creation and data insertion are both performed by rendering the materialization's
    Jinja template. The model's pre/post statements are not executed by the strategy itself;
    they are exposed to the template as ``pre_hooks`` / ``post_hooks``.
    """

    def __init__(
        self,
        adapter: EngineAdapter,
        materialization_name: str,
        materialization_template: str,
    ) -> None:
        """Store the materialization to render.

        Args:
            adapter: The engine adapter, passed on to the base strategy.
            materialization_name: Name of the materialization; used in log and error messages.
            materialization_template: Jinja source of the materialization that gets rendered.
        """
        super().__init__(adapter)
        self.materialization_name = materialization_name
        self.materialization_template = materialization_template

    def create(
        self,
        table_name: str,
        model: Model,
        is_table_deployable: bool,
        render_kwargs: t.Dict[str, t.Any],
        **kwargs: t.Any,
    ) -> None:
        """Run the materialization in create-only mode.

        The model query is rendered and limited to zero rows before being handed to the
        materialization, with ``create_only=True`` and ``is_first_insert=True``.

        Args:
            table_name: The name of the target table.
            model: The model being materialized.
            is_table_deployable: Not used by this strategy.
            render_kwargs: Keyword arguments used to render the model query.
            **kwargs: Extra values forwarded to the materialization as Jinja globals.
        """
        original_query = model.render_query_or_raise(**render_kwargs)
        self._execute_materialization(
            table_name=table_name,
            query_or_df=original_query.limit(0),
            model=model,
            is_first_insert=True,
            render_kwargs=render_kwargs,
            create_only=True,
            **kwargs,
        )

    def insert(
        self,
        table_name: str,
        query_or_df: QueryOrDF,
        model: Model,
        is_first_insert: bool,
        render_kwargs: t.Dict[str, t.Any],
        **kwargs: t.Any,
    ) -> None:
        """Run the materialization for the given query or dataframe.

        Args:
            table_name: The name of the target table.
            query_or_df: The data to materialize.
            model: The model being materialized.
            is_first_insert: Exposed to the template as the ``is_first_insert`` global.
            render_kwargs: Render arguments, passed through to the materialization runner.
            **kwargs: Extra values forwarded to the materialization as Jinja globals.
        """
        self._execute_materialization(
            table_name=table_name,
            query_or_df=query_or_df,
            model=model,
            is_first_insert=is_first_insert,
            render_kwargs=render_kwargs,
            **kwargs,
        )

    def append(
        self,
        table_name: str,
        query_or_df: QueryOrDF,
        model: Model,
        render_kwargs: t.Dict[str, t.Any],
        **kwargs: t.Any,
    ) -> None:
        """Append data by delegating to ``insert`` with ``is_first_insert=False``."""
        return self.insert(
            table_name,
            query_or_df,
            model,
            is_first_insert=False,
            render_kwargs=render_kwargs,
            **kwargs,
        )

    def run_pre_statements(self, snapshot: Snapshot, render_kwargs: t.Any) -> None:
        """Do nothing; pre statements are handed to the template as ``pre_hooks`` instead."""
        # in dbt custom materialisations it's up to the user when to run the pre hooks
        pass

    def run_post_statements(self, snapshot: Snapshot, render_kwargs: t.Any) -> None:
        """Do nothing; post statements are handed to the template as ``post_hooks`` instead."""
        # in dbt custom materialisations it's up to the user when to run the post hooks
        pass

    @staticmethod
    def _hooks(statements: t.Iterable[t.Any]) -> t.List[t.Any]:
        """Wrap each statement as a hook object exposing ``sql`` and ``transaction``."""
        # FIXME: Add support for transaction=False
        return [AttributeDict({"sql": s.this.this, "transaction": True}) for s in statements]

    def _execute_materialization(
        self,
        table_name: str,
        query_or_df: QueryOrDF,
        model: Model,
        is_first_insert: bool,
        render_kwargs: t.Dict[str, t.Any],
        create_only: bool = False,
        **kwargs: t.Any,
    ) -> None:
        """Render the materialization template against the target table.

        Args:
            table_name: The name of the target table; parsed with the adapter's dialect.
            query_or_df: The data to materialize; exposed to the template as ``sql`` via ``str()``.
            model: The model whose Jinja macros, globals and pre/post statements are used.
            is_first_insert: Exposed to the template as ``is_first_insert``.
            render_kwargs: Not used directly by this method.
            create_only: Exposed to the template as ``create_only``.
            **kwargs: Extra Jinja globals; these take precedence over every other global.

        Raises:
            SQLMeshError: If building the environment or rendering the template fails.
        """
        jinja_macros = model.jinja_macros

        # For vdes we need to use the table, since we don't know the schema/table at parse time
        parts = exp.to_table(table_name, dialect=self.adapter.dialect)

        existing_globals = jinja_macros.global_objs
        relation_info = existing_globals.get("this")
        if isinstance(relation_info, dict):
            # NOTE(review): this updates the dict held in the model's global objects in place.
            relation_info["database"] = parts.catalog
            relation_info["identifier"] = parts.name
            relation_info["name"] = parts.name

        jinja_globals = {
            **existing_globals,
            "this": relation_info,
            "database": parts.catalog,
            "schema": parts.db,
            "identifier": parts.name,
            "target": existing_globals.get("target", {"type": self.adapter.dialect}),
            "execution_dt": kwargs.get("execution_time"),
            "engine_adapter": self.adapter,
            "sql": str(query_or_df),
            "is_first_insert": is_first_insert,
            "create_only": create_only,
            "pre_hooks": self._hooks(model.pre_statements),
            "post_hooks": self._hooks(model.post_statements),
            "model_instance": model,
            **kwargs,
        }

        try:
            jinja_env = jinja_macros.build_environment(**jinja_globals)
            template = jinja_env.from_string(self.materialization_template)

            try:
                template.render()
            except MacroReturnVal as ret:
                # this is a successful return from a macro call (dbt uses this list of Relations to update their relation cache)
                returned_relations = ret.value.get("relations", [])
                # Lazy %-style arguments: the message is only formatted if INFO is enabled.
                logger.info(
                    "Materialization %s returned relations: %s",
                    self.materialization_name,
                    returned_relations,
                )

        except Exception as e:
            raise SQLMeshError(
                f"Failed to execute dbt materialization '{self.materialization_name}': {e}"
            ) from e


class EngineManagedStrategy(MaterializableStrategy):
def create(
self,
Expand Down
Loading