From cf49aba0482fc27520ebd89d19d1e85cff5d7f11 Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+themisvaltinos@users.noreply.github.com> Date: Wed, 24 Sep 2025 17:47:05 +0200 Subject: [PATCH 1/7] Feat(dbt): Add support for dbt custom materializations --- sqlmesh/core/model/kind.py | 53 ++ sqlmesh/core/snapshot/evaluator.py | 160 +++- sqlmesh/dbt/adapter.py | 18 + sqlmesh/dbt/basemodel.py | 14 + sqlmesh/dbt/builtin.py | 1 + sqlmesh/dbt/manifest.py | 66 +- sqlmesh/dbt/model.py | 27 + sqlmesh/dbt/package.py | 19 +- tests/dbt/test_custom_materializations.py | 721 ++++++++++++++++++ tests/dbt/test_model.py | 11 + tests/dbt/test_transformation.py | 125 ++- .../materializations/custom_incremental.sql | 58 ++ .../models/custom_incremental_model.sql | 20 + .../models/custom_incremental_with_filter.sql | 9 + 14 files changed, 1293 insertions(+), 9 deletions(-) create mode 100644 tests/dbt/test_custom_materializations.py create mode 100644 tests/fixtures/dbt/sushi_test/macros/materializations/custom_incremental.sql create mode 100644 tests/fixtures/dbt/sushi_test/models/custom_incremental_model.sql create mode 100644 tests/fixtures/dbt/sushi_test/models/custom_incremental_with_filter.sql diff --git a/sqlmesh/core/model/kind.py b/sqlmesh/core/model/kind.py index dc5f533c21..68ea8bf523 100644 --- a/sqlmesh/core/model/kind.py +++ b/sqlmesh/core/model/kind.py @@ -119,6 +119,10 @@ def is_custom(self) -> bool: def is_managed(self) -> bool: return self.model_kind_name == ModelKindName.MANAGED + @property + def is_dbt_custom(self) -> bool: + return self.model_kind_name == ModelKindName.DBT_CUSTOM + @property def is_symbolic(self) -> bool: """A symbolic model is one that doesn't execute at all.""" @@ -170,6 +174,7 @@ class ModelKindName(str, ModelKindMixin, Enum): EXTERNAL = "EXTERNAL" CUSTOM = "CUSTOM" MANAGED = "MANAGED" + DBT_CUSTOM = "DBT_CUSTOM" @property def model_kind_name(self) -> t.Optional[ModelKindName]: @@ -887,6 +892,52 @@ def supports_python_models(self) -> bool: 
return False +class DbtCustomKind(_ModelKind): + name: t.Literal[ModelKindName.DBT_CUSTOM] = ModelKindName.DBT_CUSTOM + materialization: str + adapter: str = "default" + definition: str + dialect: t.Optional[str] = Field(None, validate_default=True) + + _dialect_validator = kind_dialect_validator + + @field_validator("materialization", "adapter", "definition", mode="before") + @classmethod + def _validate_fields(cls, v: t.Any) -> str: + return validate_string(v) + + @property + def data_hash_values(self) -> t.List[t.Optional[str]]: + return [ + *super().data_hash_values, + self.materialization, + self.definition, + self.adapter, + self.dialect, + ] + + @property + def metadata_hash_values(self) -> t.List[t.Optional[str]]: + return [ + *super().metadata_hash_values, + ] + + def to_expression( + self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any + ) -> d.ModelKind: + return super().to_expression( + expressions=[ + *(expressions or []), + *_properties( + { + "materialization": exp.Literal.string(self.materialization), + "adapter": exp.Literal.string(self.adapter), + } + ), + ], + ) + + class EmbeddedKind(_ModelKind): name: t.Literal[ModelKindName.EMBEDDED] = ModelKindName.EMBEDDED @@ -992,6 +1043,7 @@ def to_expression( SCDType2ByColumnKind, CustomKind, ManagedKind, + DbtCustomKind, ], Field(discriminator="name"), ] @@ -1011,6 +1063,7 @@ def to_expression( ModelKindName.SCD_TYPE_2_BY_COLUMN: SCDType2ByColumnKind, ModelKindName.CUSTOM: CustomKind, ModelKindName.MANAGED: ManagedKind, + ModelKindName.DBT_CUSTOM: DbtCustomKind, } diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 70cc31b0a4..75e32463c3 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -50,7 +50,7 @@ ViewKind, CustomKind, ) -from sqlmesh.core.model.kind import _Incremental +from sqlmesh.core.model.kind import _Incremental, DbtCustomKind from sqlmesh.utils import CompletionStatus, 
columns_to_types_all_known from sqlmesh.core.schema_diff import ( has_drop_alteration, @@ -83,6 +83,7 @@ format_additive_change_msg, AdditiveChangeError, ) +from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroReturnVal if sys.version_info >= (3, 12): from importlib import metadata @@ -747,7 +748,8 @@ def _evaluate_snapshot( adapter.transaction(), adapter.session(snapshot.model.render_session_properties(**render_statements_kwargs)), ): - adapter.execute(model.render_pre_statements(**render_statements_kwargs)) + if not snapshot.is_dbt_custom: + adapter.execute(model.render_pre_statements(**render_statements_kwargs)) if not target_table_exists or (model.is_seed and not snapshot.intervals): # Only create the empty table if the columns were provided explicitly by the user @@ -817,7 +819,8 @@ def _evaluate_snapshot( batch_index=batch_index, ) - adapter.execute(model.render_post_statements(**render_statements_kwargs)) + if not snapshot.is_dbt_custom: + adapter.execute(model.render_post_statements(**render_statements_kwargs)) return wap_id @@ -1432,7 +1435,7 @@ def _execute_create( **create_render_kwargs, "table_mapping": {snapshot.name: table_name}, } - if run_pre_post_statements: + if run_pre_post_statements and not snapshot.is_dbt_custom: adapter.execute(snapshot.model.render_pre_statements(**create_render_kwargs)) evaluation_strategy.create( table_name=table_name, @@ -1444,7 +1447,7 @@ def _execute_create( dry_run=dry_run, physical_properties=rendered_physical_properties, ) - if run_pre_post_statements: + if run_pre_post_statements and not snapshot.is_dbt_custom: adapter.execute(snapshot.model.render_post_statements(**create_render_kwargs)) def _can_clone(self, snapshot: Snapshot, deployability_index: DeployabilityIndex) -> bool: @@ -1456,6 +1459,7 @@ def _can_clone(self, snapshot: Snapshot, deployability_index: DeployabilityIndex and adapter.SUPPORTS_CLONING # managed models cannot have their schema mutated because theyre based on queries, so clone + alter wont 
work and not snapshot.is_managed + and not snapshot.is_dbt_custom and not deployability_index.is_deployable(snapshot) # If the deployable table is missing we can't clone it and adapter.table_exists(snapshot.table_name()) @@ -1540,6 +1544,19 @@ def _evaluation_strategy(snapshot: SnapshotInfoLike, adapter: EngineAdapter) -> klass = ViewStrategy elif snapshot.is_scd_type_2: klass = SCDType2Strategy + elif snapshot.is_dbt_custom: + if hasattr(snapshot, "model") and isinstance( + (model_kind := snapshot.model.kind), DbtCustomKind + ): + return DbtCustomMaterialization( + adapter=adapter, + materialization_name=model_kind.materialization, + materialization_template=model_kind.definition, + ) + + raise SQLMeshError( + f"Expected DbtCustomKind for dbt custom materialization in model '{snapshot.name}'" + ) elif snapshot.is_custom: if snapshot.custom_materialization is None: raise SQLMeshError( @@ -2593,6 +2610,139 @@ def get_custom_materialization_type_or_raise( raise SQLMeshError(f"Custom materialization '{name}' not present in the Python environment") +class DbtCustomMaterialization(MaterializableStrategy): + def __init__( + self, + adapter: EngineAdapter, + materialization_name: str, + materialization_template: str, + ): + super().__init__(adapter) + self.materialization_name = materialization_name + self.materialization_template = materialization_template + + def create( + self, + table_name: str, + model: Model, + is_table_deployable: bool, + render_kwargs: t.Dict[str, t.Any], + **kwargs: t.Any, + ) -> None: + original_query = model.render_query_or_raise(**render_kwargs) + self._execute_materialization( + table_name=table_name, + query_or_df=original_query.limit(0), + model=model, + is_first_insert=True, + render_kwargs=render_kwargs, + create_only=True, + **kwargs, + ) + + def insert( + self, + table_name: str, + query_or_df: QueryOrDF, + model: Model, + is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], + **kwargs: t.Any, + ) -> None: + 
self._execute_materialization( + table_name=table_name, + query_or_df=query_or_df, + model=model, + is_first_insert=is_first_insert, + render_kwargs=render_kwargs, + **kwargs, + ) + + def append( + self, + table_name: str, + query_or_df: QueryOrDF, + model: Model, + render_kwargs: t.Dict[str, t.Any], + **kwargs: t.Any, + ) -> None: + return self.insert( + table_name, + query_or_df, + model, + is_first_insert=False, + render_kwargs=render_kwargs, + **kwargs, + ) + + def _execute_materialization( + self, + table_name: str, + query_or_df: QueryOrDF, + model: Model, + is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], + create_only: bool = False, + **kwargs: t.Any, + ) -> None: + from sqlmesh.dbt.builtin import create_builtin_globals + + jinja_macros = getattr(model, "jinja_macros", JinjaMacroRegistry()) + existing_globals = jinja_macros.global_objs.copy() + + # For vdes we need to use the table, since we don't know the schema/table at parse time + parts = exp.to_table(table_name, dialect=self.adapter.dialect) + + relation_info = existing_globals.pop("this") + if isinstance(relation_info, dict): + relation_info["database"] = parts.catalog + relation_info["identifier"] = parts.name + relation_info["name"] = parts.name + + jinja_globals = { + **existing_globals, + "this": relation_info, + "database": parts.catalog, + "schema": parts.db, + "identifier": parts.name, + "target": existing_globals.get("target", {"type": self.adapter.dialect}), + "execution_dt": kwargs.get("execution_time"), + } + + context = create_builtin_globals( + jinja_macros=jinja_macros, jinja_globals=jinja_globals, engine_adapter=self.adapter + ) + + context.update( + { + "sql": str(query_or_df), + "is_first_insert": is_first_insert, + "create_only": create_only, + "pre_hooks": model.render_pre_statements(**render_kwargs), + "post_hooks": model.render_post_statements(**render_kwargs), + **kwargs, + } + ) + + try: + jinja_env = jinja_macros.build_environment(**context) + template = 
jinja_env.from_string(self.materialization_template) + + try: + template.render(**context) + except MacroReturnVal as ret: + # this is a succesful return from a macro call (dbt uses this list of Relations to update their relation cache) + returned_relations = ret.value.get("relations", []) + logger.info( + f"Materialization {self.materialization_name} returned relations: {returned_relations}" + ) + + except Exception as e: + raise SQLMeshError( + f"Failed to execute dbt materialization '{self.materialization_name}': {e}" + ) from e + + class EngineManagedStrategy(MaterializableStrategy): def create( self, diff --git a/sqlmesh/dbt/adapter.py b/sqlmesh/dbt/adapter.py index 7f7c7eb4fb..a8b2b9af72 100644 --- a/sqlmesh/dbt/adapter.py +++ b/sqlmesh/dbt/adapter.py @@ -99,6 +99,12 @@ def execute( ) -> t.Tuple[AdapterResponse, agate.Table]: """Executes the given SQL statement and returns the results as an agate table.""" + @abc.abstractmethod + def run_hooks( + self, hooks: t.List[str | exp.Expression], inside_transaction: bool = True + ) -> None: + """Executes the given hooks.""" + @abc.abstractmethod def resolve_schema(self, relation: BaseRelation) -> t.Optional[str]: """Resolves the relation's schema to its physical schema.""" @@ -241,6 +247,12 @@ def execute( self._raise_parsetime_adapter_call_error("execute SQL") raise + def run_hooks( + self, hooks: t.List[str | exp.Expression], inside_transaction: bool = True + ) -> None: + self._raise_parsetime_adapter_call_error("run hooks") + raise + def resolve_schema(self, relation: BaseRelation) -> t.Optional[str]: return relation.schema @@ -451,6 +463,12 @@ def resolve_identifier(self, relation: BaseRelation) -> t.Optional[str]: identifier = self._map_table_name(self._normalize(self._relation_to_table(relation))).name return identifier if identifier else None + def run_hooks( + self, hooks: t.List[str | exp.Expression], inside_transaction: bool = True + ) -> None: + # inside_transaction not yet supported similarly to 
transaction + self.engine_adapter.execute([exp.maybe_parse(hook) for hook in hooks]) + def _map_table_name(self, table: exp.Table) -> exp.Table: # Use the default dialect since this is the dialect used to normalize and quote keys in the # mapping table. diff --git a/sqlmesh/dbt/basemodel.py b/sqlmesh/dbt/basemodel.py index 4dcf44a0af..0b75955129 100644 --- a/sqlmesh/dbt/basemodel.py +++ b/sqlmesh/dbt/basemodel.py @@ -57,6 +57,12 @@ class Materialization(str, Enum): # Snowflake, https://docs.getdbt.com/reference/resource-configs/snowflake-configs#dynamic-tables DYNAMIC_TABLE = "dynamic_table" + CUSTOM = "custom" + + @classmethod + def _missing_(cls, value): # type: ignore + return cls.CUSTOM + class SnapshotStrategy(str, Enum): """DBT snapshot strategies""" @@ -295,6 +301,14 @@ def sqlmesh_model_kwargs( # precisely which variables are referenced in the model dependencies.variables |= set(context.variables) + if ( + getattr(self, "model_materialization", None) == Materialization.CUSTOM + and hasattr(self, "_get_custom_materialization") + and (custom_mat := self._get_custom_materialization(context)) + ): + # include custom materialization dependencies as they might use macros + dependencies = dependencies.union(custom_mat.dependencies) + model_dialect = self.dialect(context) model_context = context.context_for_dependencies( dependencies.union(self.tests_ref_source_dependencies) diff --git a/sqlmesh/dbt/builtin.py b/sqlmesh/dbt/builtin.py index e284c11797..8690eb91fa 100644 --- a/sqlmesh/dbt/builtin.py +++ b/sqlmesh/dbt/builtin.py @@ -544,6 +544,7 @@ def create_builtin_globals( "load_result": sql_execution.load_result, "run_query": sql_execution.run_query, "statement": sql_execution.statement, + "run_hooks": adapter.run_hooks, "graph": adapter.graph, "selected_resources": list(jinja_globals.get("selected_models") or []), } diff --git a/sqlmesh/dbt/manifest.py b/sqlmesh/dbt/manifest.py index 0e33569888..7e12147e03 100644 --- a/sqlmesh/dbt/manifest.py +++ 
b/sqlmesh/dbt/manifest.py @@ -47,7 +47,7 @@ from sqlmesh.dbt.builtin import BUILTIN_FILTERS, BUILTIN_GLOBALS, OVERRIDDEN_MACROS from sqlmesh.dbt.common import Dependencies from sqlmesh.dbt.model import ModelConfig -from sqlmesh.dbt.package import HookConfig, MacroConfig +from sqlmesh.dbt.package import HookConfig, MacroConfig, MaterializationConfig from sqlmesh.dbt.seed import SeedConfig from sqlmesh.dbt.source import SourceConfig from sqlmesh.dbt.target import TargetConfig @@ -75,6 +75,7 @@ SourceConfigs = t.Dict[str, SourceConfig] MacroConfigs = t.Dict[str, MacroConfig] HookConfigs = t.Dict[str, HookConfig] +MaterializationConfigs = t.Dict[str, MaterializationConfig] IGNORED_PACKAGES = {"elementary"} @@ -135,6 +136,7 @@ def __init__( self._on_run_start_per_package: t.Dict[str, HookConfigs] = defaultdict(dict) self._on_run_end_per_package: t.Dict[str, HookConfigs] = defaultdict(dict) + self._materializations_per_package: t.Dict[str, MaterializationConfigs] = defaultdict(dict) def tests(self, package_name: t.Optional[str] = None) -> TestConfigs: self._load_all() @@ -164,6 +166,10 @@ def on_run_end(self, package_name: t.Optional[str] = None) -> HookConfigs: self._load_all() return self._on_run_end_per_package[package_name or self._project_name] + def materializations(self, package_name: t.Optional[str] = None) -> MaterializationConfigs: + self._load_all() + return self._materializations_per_package[package_name or self._project_name] + @property def all_macros(self) -> t.Dict[str, t.Dict[str, MacroInfo]]: self._load_all() @@ -213,6 +219,7 @@ def _load_all(self) -> None: self._calls = {k: (v, False) for k, v in (self._call_cache.get("") or {}).items()} self._load_macros() + self._load_materializations() self._load_sources() self._load_tests() self._load_models_and_seeds() @@ -250,11 +257,14 @@ def _load_sources(self) -> None: def _load_macros(self) -> None: for macro in self._manifest.macros.values(): + if macro.name.startswith("materialization_"): + continue + if 
macro.name.startswith("test_"): macro.macro_sql = _convert_jinja_test_to_macro(macro.macro_sql) dependencies = Dependencies(macros=_macro_references(self._manifest, macro)) - if not macro.name.startswith("materialization_") and not macro.name.startswith("test_"): + if not macro.name.startswith("test_"): dependencies = dependencies.union( self._extra_dependencies(macro.macro_sql, macro.package_name) ) @@ -281,6 +291,34 @@ def _load_macros(self) -> None: if pos > 0 and name[pos + 2 :] in adapter_macro_names: macro_config.info.is_top_level = True + def _load_materializations(self) -> None: + for macro in self._manifest.macros.values(): + if macro.name.startswith("materialization_"): + # Extract name and adapter ( "materialization_{name}_{adapter}" or "materialization_{name}_default") + name_parts = macro.name.split("_") + if len(name_parts) >= 3: + mat_name = "_".join(name_parts[1:-1]) + adapter = name_parts[-1] + + dependencies = Dependencies(macros=_macro_references(self._manifest, macro)) + macro.macro_sql = _strip_jinja_materialization_tags(macro.macro_sql) + dependencies = dependencies.union( + self._extra_dependencies(macro.macro_sql, macro.package_name) + ) + + materialization_config = MaterializationConfig( + name=mat_name, + adapter=adapter, + definition=macro.macro_sql, + dependencies=dependencies, + path=Path(macro.original_file_path), + ) + + key = f"{mat_name}_{adapter}" + self._materializations_per_package[macro.package_name][key] = ( + materialization_config + ) + def _load_tests(self) -> None: for node in self._manifest.nodes.values(): if node.resource_type != "test": @@ -732,3 +770,27 @@ def _convert_jinja_test_to_macro(test_jinja: str) -> str: macro = macro_tag + test_jinja[match.span()[-1] :] return re.sub(ENDTEST_REGEX, lambda m: m.group(0).replace("endtest", "endmacro"), macro) + + +def _strip_jinja_materialization_tags(materialization_jinja: str) -> str: + MATERIALIZATION_TAG_REGEX = r"{%-?\s*materialization\s+[^%]*%}\s*\n?" 
+ ENDMATERIALIZATION_REGEX = r"{%-?\s*endmaterialization\s*-?%}\s*\n?" + + if not re.match(MATERIALIZATION_TAG_REGEX, materialization_jinja): + return materialization_jinja + + materialization_jinja = re.sub( + MATERIALIZATION_TAG_REGEX, + "", + materialization_jinja, + flags=re.IGNORECASE, + ) + + materialization_jinja = re.sub( + ENDMATERIALIZATION_REGEX, + "", + materialization_jinja, + flags=re.IGNORECASE, + ) + + return materialization_jinja.strip() diff --git a/sqlmesh/dbt/model.py b/sqlmesh/dbt/model.py index f6cb81f30f..f47283d06e 100644 --- a/sqlmesh/dbt/model.py +++ b/sqlmesh/dbt/model.py @@ -31,6 +31,7 @@ OnAdditiveChange, on_destructive_change_validator, on_additive_change_validator, + DbtCustomKind, ) from sqlmesh.dbt.basemodel import BaseModelConfig, Materialization, SnapshotStrategy from sqlmesh.dbt.common import SqlStr, sql_str_validator @@ -40,6 +41,7 @@ if t.TYPE_CHECKING: from sqlmesh.core.audit.definition import ModelAudit from sqlmesh.dbt.context import DbtContext + from sqlmesh.dbt.package import MaterializationConfig logger = logging.getLogger(__name__) @@ -444,6 +446,19 @@ def model_kind(self, context: DbtContext) -> ModelKind: if materialization == Materialization.DYNAMIC_TABLE: return ManagedKind() + if materialization == Materialization.CUSTOM: + if custom_materialization := self._get_custom_materialization(context): + return DbtCustomKind( + materialization=self.materialized, + adapter=custom_materialization.adapter, + dialect=self.dialect(context), + definition=custom_materialization.definition, + ) + + raise ConfigError( + f"Unknown materialization '{self.materialized}'. Custom materializations must be defined in your dbt project." 
+ ) + raise ConfigError(f"{materialization.value} materialization not supported.") def _big_query_partition_by_expr(self, context: DbtContext) -> exp.Expression: @@ -483,6 +498,18 @@ def _big_query_partition_by_expr(self, context: DbtContext) -> exp.Expression: dialect="bigquery", ) + def _get_custom_materialization(self, context: DbtContext) -> t.Optional[MaterializationConfig]: + materializations = context.manifest.materializations() + name, target_adapter = self.materialized, context.target.dialect + + adapter_specific_key = f"{name}_{target_adapter}" + default_key = f"{name}_default" + if adapter_specific_key in materializations: + return materializations[adapter_specific_key] + if default_key in materializations: + return materializations[default_key] + return None + @property def sqlmesh_config_fields(self) -> t.Set[str]: return super().sqlmesh_config_fields | { diff --git a/sqlmesh/dbt/package.py b/sqlmesh/dbt/package.py index 420cf3cb73..dd6425ea83 100644 --- a/sqlmesh/dbt/package.py +++ b/sqlmesh/dbt/package.py @@ -37,6 +37,16 @@ class HookConfig(PydanticModel): dependencies: Dependencies +class MaterializationConfig(PydanticModel): + """Class to contain custom materialization configuration.""" + + name: str + adapter: str + definition: str + dependencies: Dependencies + path: Path + + class Package(PydanticModel): """Class to contain package configuration""" @@ -47,6 +57,7 @@ class Package(PydanticModel): models: t.Dict[str, ModelConfig] variables: t.Dict[str, t.Any] macros: t.Dict[str, MacroConfig] + materializations: t.Dict[str, MaterializationConfig] on_run_start: t.Dict[str, HookConfig] on_run_end: t.Dict[str, HookConfig] files: t.Set[Path] @@ -94,6 +105,9 @@ def load(self, package_root: Path) -> Package: models = _fix_paths(self._context.manifest.models(package_name), package_root) seeds = _fix_paths(self._context.manifest.seeds(package_name), package_root) macros = _fix_paths(self._context.manifest.macros(package_name), package_root) + 
materializations = _fix_paths( + self._context.manifest.materializations(package_name), package_root + ) on_run_start = _fix_paths(self._context.manifest.on_run_start(package_name), package_root) on_run_end = _fix_paths(self._context.manifest.on_run_end(package_name), package_root) sources = self._context.manifest.sources(package_name) @@ -114,13 +128,16 @@ def load(self, package_root: Path) -> Package: seeds=seeds, variables=package_variables, macros=macros, + materializations=materializations, files=config_paths, on_run_start=on_run_start, on_run_end=on_run_end, ) -T = t.TypeVar("T", TestConfig, ModelConfig, MacroConfig, SeedConfig, HookConfig) +T = t.TypeVar( + "T", TestConfig, ModelConfig, MacroConfig, MaterializationConfig, SeedConfig, HookConfig +) def _fix_paths(configs: t.Dict[str, T], package_root: Path) -> t.Dict[str, T]: diff --git a/tests/dbt/test_custom_materializations.py b/tests/dbt/test_custom_materializations.py new file mode 100644 index 0000000000..bd961136d2 --- /dev/null +++ b/tests/dbt/test_custom_materializations.py @@ -0,0 +1,721 @@ +from __future__ import annotations + +import typing as t +from pathlib import Path + +import pytest + +from sqlmesh import Context +from sqlmesh.core.config import ModelDefaultsConfig +from sqlmesh.core.model.kind import DbtCustomKind +from sqlmesh.dbt.context import DbtContext +from sqlmesh.dbt.manifest import ManifestHelper +from sqlmesh.dbt.model import ModelConfig +from sqlmesh.dbt.profile import Profile +from sqlmesh.dbt.basemodel import Materialization + +pytestmark = pytest.mark.dbt + + +@pytest.mark.xdist_group("dbt_manifest") +def test_custom_materialization_manifest_loading(): + project_path = Path("tests/fixtures/dbt/sushi_test") + profile = Profile.load(DbtContext(project_path)) + + helper = ManifestHelper( + project_path, + project_path, + "sushi", + profile.target, + model_defaults=ModelDefaultsConfig(start="2020-01-01"), + ) + materializations = helper.materializations() + + # custom 
materialization should have loaded from the manifest + assert "custom_incremental_default" in materializations + custom_incremental = materializations["custom_incremental_default"] + assert custom_incremental.name == "custom_incremental" + assert custom_incremental.adapter == "default" + assert "make_temp_relation(new_relation)" in custom_incremental.definition + assert "run_hooks(pre_hooks, inside_transaction=False)" in custom_incremental.definition + assert " {{ return({'relations': [new_relation]}) }}" in custom_incremental.definition + + +@pytest.mark.xdist_group("dbt_manifest") +def test_custom_materialization_model_config(): + project_path = Path("tests/fixtures/dbt/sushi_test") + profile = Profile.load(DbtContext(project_path)) + + helper = ManifestHelper( + project_path, + project_path, + "sushi", + profile.target, + model_defaults=ModelDefaultsConfig(start="2020-01-01"), + ) + + models = helper.models() + + custom_model = models["custom_incremental_model"] + assert isinstance(custom_model, ModelConfig) + assert custom_model.materialized == "custom_incremental" + assert custom_model.model_materialization == Materialization.CUSTOM + + # pre and post hooks should also be handled in custom materializations + assert len(custom_model.pre_hook) == 2 + assert ( + custom_model.pre_hook[1].sql + == "CREATE TABLE IF NOT EXISTS hook_table (id INTEGER, length_col TEXT, updated_at TIMESTAMP)" + ) + assert len(custom_model.post_hook) == 2 + assert "COALESCE(MAX(id), 0)" in custom_model.post_hook[1].sql + + custom_filter_model = models["custom_incremental_with_filter"] + assert isinstance(custom_filter_model, ModelConfig) + assert custom_filter_model.materialized == "custom_incremental" + assert custom_filter_model.model_materialization == Materialization.CUSTOM + assert custom_filter_model.interval == "2 day" + assert custom_filter_model.time_column == "created_at" + + # verify also that the global hooks are inherited in the model without + assert 
len(custom_filter_model.pre_hook) == 1 + assert len(custom_filter_model.post_hook) == 1 + + +@pytest.mark.xdist_group("dbt_manifest") +def test_custom_materialization_model_kind(): + project_path = Path("tests/fixtures/dbt/sushi_test") + context = DbtContext(project_path) + profile = Profile.load(DbtContext(project_path)) + + helper = ManifestHelper( + project_path, + project_path, + "sushi", + profile.target, + model_defaults=ModelDefaultsConfig(start="2020-01-01"), + ) + + context._target = profile.target + context._manifest = helper + models = helper.models() + + # custom materialization models get DbtCustomKind populated + custom_model = models["custom_incremental_model"] + kind = custom_model.model_kind(context) + assert isinstance(kind, DbtCustomKind) + assert kind.materialization == "custom_incremental" + assert kind.adapter == "default" + assert "create_table_as" in kind.definition + + custom_filter_model = models["custom_incremental_with_filter"] + kind = custom_filter_model.model_kind(context) + assert isinstance(kind, DbtCustomKind) + assert kind.materialization == "custom_incremental" + assert kind.adapter == "default" + assert "run_hooks" in kind.definition + + # the DbtCustomKind shouldnt be set for normal strategies + regular_model = models["simple_model_a"] + regular_kind = regular_model.model_kind(context) + assert not isinstance(regular_kind, DbtCustomKind) + + # verify in sqlmesh as well + sqlmesh_context = Context( + paths=["tests/fixtures/dbt/sushi_test"], + config=None, + ) + + custom_incremental = sqlmesh_context.get_model("sushi.custom_incremental_model") + assert isinstance(custom_incremental.kind, DbtCustomKind) + assert custom_incremental.kind.materialization == "custom_incremental" + + custom_with_filter = sqlmesh_context.get_model("sushi.custom_incremental_with_filter") + assert isinstance(custom_with_filter.kind, DbtCustomKind) + assert custom_with_filter.kind.materialization == "custom_incremental" + + 
+@pytest.mark.xdist_group("dbt_manifest") +def test_custom_materialization_dependencies(): + project_path = Path("tests/fixtures/dbt/sushi_test") + context = DbtContext(project_path) + profile = Profile.load(DbtContext(project_path)) + + helper = ManifestHelper( + project_path, + project_path, + "sushi", + profile.target, + model_defaults=ModelDefaultsConfig(start="2020-01-01"), + ) + + context._target = profile.target + context._manifest = helper + models = helper.models() + + # custom materialization uses macros that should appear in dependencies + for model_name in ["custom_incremental_model", "custom_incremental_with_filter"]: + materialization_deps = models[model_name]._get_custom_materialization(context) + assert materialization_deps is not None + assert len(materialization_deps.dependencies.macros) > 0 + macro_names = [macro.name for macro in materialization_deps.dependencies.macros] + expected_macros = [ + "build_incremental_filter_sql", + "Relation", + "create_table_as", + "make_temp_relation", + "run_hooks", + "statement", + ] + assert any(macro in macro_names for macro in expected_macros) + + +@pytest.mark.xdist_group("dbt_manifest") +def test_adapter_specific_materialization_override(copy_to_temp_path: t.Callable): + path = copy_to_temp_path("tests/fixtures/dbt/sushi_test") + temp_project = path[0] + + macros_dir = temp_project / "macros" / "materializations" + macros_dir.mkdir(parents=True, exist_ok=True) + + adapter_mat_content = """ +{%- materialization custom_adapter_test, default -%} + {%- set new_relation = api.Relation.create(database=database, schema=schema, identifier=identifier) -%} + + {{ run_hooks(pre_hooks, inside_transaction=False) }} + + {%- call statement('main') -%} + CREATE TABLE {{ new_relation }} AS ( + SELECT 'default_adapter' as adapter_type, * FROM ({{ sql }}) AS subquery + ) + {%- endcall -%} + + {{ run_hooks(post_hooks, inside_transaction=False) }} + + {{ return({'relations': [new_relation]}) }} +{%- endmaterialization -%} + 
+{%- materialization custom_adapter_test, adapter='postgres' -%} + {%- set new_relation = api.Relation.create(database=database, schema=schema, identifier=identifier) -%} + + {{ run_hooks(pre_hooks, inside_transaction=False) }} + + {%- call statement('main') -%} + CREATE TABLE {{ new_relation }} AS ( + SELECT 'postgres_adapter'::text as adapter_type, * FROM ({{ sql }}) AS subquery + ) + {%- endcall -%} + + {{ run_hooks(post_hooks, inside_transaction=False) }} + + {{ return({'relations': [new_relation]}) }} +{%- endmaterialization -%} + +{%- materialization custom_adapter_test, adapter='duckdb' -%} + {%- set new_relation = api.Relation.create(database=database, schema=schema, identifier=identifier) -%} + + {{ run_hooks(pre_hooks, inside_transaction=False) }} + + {%- call statement('main') -%} + CREATE TABLE {{ new_relation }} AS ( + SELECT 'duckdb_adapter' as adapter_type, * FROM ({{ sql }}) AS subquery + ) + {%- endcall -%} + + {{ run_hooks(post_hooks, inside_transaction=False) }} + + {{ return({'relations': [new_relation]}) }} +{%- endmaterialization -%} +""".strip() + + (macros_dir / "custom_adapter_test.sql").write_text(adapter_mat_content) + + models_dir = temp_project / "models" + models_dir.mkdir(parents=True, exist_ok=True) + + test_model_content = """ +{{ config( + materialized='custom_adapter_test', +) }} + +SELECT + 1 as id, + 'test' as name +""".strip() + + (models_dir / "test_adapter_specific.sql").write_text(test_model_content) + + context = DbtContext(temp_project) + profile = Profile.load(context) + + helper = ManifestHelper( + temp_project, + temp_project, + "sushi", + profile.target, + model_defaults=ModelDefaultsConfig(start="2020-01-01"), + ) + + materializations = helper.materializations() + assert "custom_adapter_test_default" in materializations + assert "custom_adapter_test_duckdb" in materializations + assert "custom_adapter_test_postgres" in materializations + + default_mat = materializations["custom_adapter_test_default"] + assert 
"default_adapter" in default_mat.definition + assert default_mat.adapter == "default" + + duckdb_mat = materializations["custom_adapter_test_duckdb"] + assert "duckdb_adapter" in duckdb_mat.definition + assert duckdb_mat.adapter == "duckdb" + + postgres_mat = materializations["custom_adapter_test_postgres"] + assert "postgres_adapter" in postgres_mat.definition + assert postgres_mat.adapter == "postgres" + + # verify that the correct adapter is selected based on target + context._target = profile.target + context._manifest = helper + models = helper.models() + + test_model = models["test_adapter_specific"] + + kind = test_model.model_kind(context) + assert isinstance(kind, DbtCustomKind) + assert kind.materialization == "custom_adapter_test" + # Should use duckdb adapter since that's the default target + assert "duckdb_adapter" in kind.definition or "default_adapter" in kind.definition + + # test also that adapter-specific materializations execute with correct adapter + sushi_context = Context(paths=path) + + plan = sushi_context.plan(select_models=["sushi.test_adapter_specific"]) + sushi_context.apply(plan) + + # check that the table was created with the correct adapter type + result = sushi_context.engine_adapter.fetchdf("SELECT * FROM sushi.test_adapter_specific") + assert len(result) == 1 + assert "adapter_type" in result.columns + assert result["adapter_type"][0] == "duckdb_adapter" + assert result["id"][0] == 1 + assert result["name"][0] == "test" + + +@pytest.mark.xdist_group("dbt_manifest") +def test_missing_custom_materialization_error(): + from sqlmesh.utils.errors import ConfigError + + project_path = Path("tests/fixtures/dbt/sushi_test") + context = DbtContext(project_path) + profile = Profile.load(context) + + # the materialization is non-existent + fake_model_config = ModelConfig( + name="test_model", + path=project_path / "models" / "fake_model.sql", + raw_code="SELECT 1 as id", + materialized="non_existent_custom", + schema="test_schema", + ) + + 
context._target = profile.target + helper = ManifestHelper( + project_path, + project_path, + "sushi", + profile.target, + model_defaults=ModelDefaultsConfig(start="2020-01-01"), + ) + context._manifest = helper + + # Should raise ConfigError when trying to get the model kind + with pytest.raises(ConfigError) as e: + fake_model_config.model_kind(context) + + assert "Unknown materialization 'non_existent_custom'" in str(e.value) + assert "Custom materializations must be defined" in str(e.value) + + +@pytest.mark.xdist_group("dbt_manifest") +def test_broken_jinja_materialization_error(copy_to_temp_path: t.Callable): + path = copy_to_temp_path("tests/fixtures/dbt/sushi_test") + temp_project = path[0] + + macros_dir = temp_project / "macros" / "materializations" + macros_dir.mkdir(parents=True, exist_ok=True) + + # Create broken Jinja materialization + broken_mat_content = """ +{%- materialization broken_jinja, default -%} + {%- set new_relation = api.Relation.create(database=database, schema=schema, identifier=identifier) -%} + + {{ run_hooks(pre_hooks, inside_transaction=False) }} + + {# An intentional undefined variable that will cause runtime error #} + {%- set broken_var = undefined_variable_that_does_not_exist + 10 -%} + + {%- call statement('main') -%} + CREATE TABLE {{ new_relation }} AS ( + SELECT * FROM ({{ sql }}) AS subquery + WHERE 1 = {{ broken_var }} + ) + {%- endcall -%} + + {{ run_hooks(post_hooks, inside_transaction=False) }} + + {{ return({'relations': [new_relation]}) }} +{%- endmaterialization -%} +""".strip() + + (macros_dir / "broken_jinja.sql").write_text(broken_mat_content) + + models_dir = temp_project / "models" + models_dir.mkdir(parents=True, exist_ok=True) + + test_model_content = """ +{{ config( + materialized='broken_jinja', +) }} + +SELECT + 1 as id, + 'This should fail with Jinja error' as error_msg +""".strip() + + (models_dir / "test_broken_jinja.sql").write_text(test_model_content) + + sushi_context = Context(paths=path) + + # The 
model will load fine jinja won't fail at parse time + model = sushi_context.get_model("sushi.test_broken_jinja") + assert isinstance(model.kind, DbtCustomKind) + assert model.kind.materialization == "broken_jinja" + + # but execution should fail + with pytest.raises(Exception) as e: + plan = sushi_context.plan(select_models=["sushi.test_broken_jinja"]) + sushi_context.apply(plan) + + assert "plan application failed" in str(e.value).lower() + + +@pytest.mark.xdist_group("dbt_manifest") +def test_failing_hooks_in_materialization(copy_to_temp_path: t.Callable): + path = copy_to_temp_path("tests/fixtures/dbt/sushi_test") + temp_project = path[0] + + models_dir = temp_project / "models" + models_dir.mkdir(parents=True, exist_ok=True) + + test_model_content = """ +{{ config( + materialized='custom_incremental', + pre_hook="CREATE TABLE will_fail_due_to_intentional_syntax_error (", + post_hook="DROP TABLE non_existent_table_that_will_fail", +) }} + +SELECT + 1 as id, + 'Testing hook failures' as test_msg +""".strip() + + (models_dir / "test_failing_hooks.sql").write_text(test_model_content) + + sushi_context = Context(paths=[str(temp_project)]) + + # in this case the pre_hook has invalid syntax + with pytest.raises(Exception) as e: + plan = sushi_context.plan(select_models=["sushi.test_failing_hooks"]) + sushi_context.apply(plan) + + assert "plan application failed" in str(e.value).lower() + + +@pytest.mark.xdist_group("dbt_manifest") +def test_custom_materialization_virtual_environments(copy_to_temp_path: t.Callable): + path = copy_to_temp_path("tests/fixtures/dbt/sushi_test") + temp_project = path[0] + + models_dir = temp_project / "models" + models_dir.mkdir(parents=True, exist_ok=True) + + test_model_content = """ +{{ config( + materialized='custom_incremental', + time_column='created_at', +) }} + +SELECT + CURRENT_TIMESTAMP as created_at, + 1 as id, + 'venv_test' as test_type +""".strip() + + (models_dir / "test_venv_model.sql").write_text(test_model_content) + + 
sushi_context = Context(paths=path) + prod_plan = sushi_context.plan(select_models=["sushi.test_venv_model"]) + sushi_context.apply(prod_plan) + prod_result = sushi_context.engine_adapter.fetchdf( + "SELECT * FROM sushi.test_venv_model ORDER BY id" + ) + assert len(prod_result) == 1 + assert prod_result["id"][0] == 1 + assert prod_result["test_type"][0] == "venv_test" + + # Create dev environment and check the dev table was created with proper naming + dev_plan = sushi_context.plan("dev", select_models=["sushi.test_venv_model"]) + sushi_context.apply(dev_plan) + dev_result = sushi_context.engine_adapter.fetchdf( + "SELECT * FROM sushi__dev.test_venv_model ORDER BY id" + ) + assert len(dev_result) == 1 + assert dev_result["id"][0] == 1 + assert dev_result["test_type"][0] == "venv_test" + + dev_tables = sushi_context.engine_adapter.fetchdf(""" + SELECT table_name, table_schema + FROM system.information_schema.tables + WHERE table_schema LIKE 'sushi%dev%' + AND table_name LIKE '%test_venv_model%' + """) + + prod_tables = sushi_context.engine_adapter.fetchdf(""" + SELECT table_name, table_schema + FROM system.information_schema.tables + WHERE table_schema = 'sushi' + AND table_name LIKE '%test_venv_model%' + """) + + # Verify both environments have their own tables + assert len(dev_tables) >= 1 + assert len(prod_tables) >= 1 + + +@pytest.mark.xdist_group("dbt_manifest") +def test_virtual_environment_schema_names(copy_to_temp_path: t.Callable): + path = copy_to_temp_path("tests/fixtures/dbt/sushi_test") + temp_project = path[0] + + models_dir = temp_project / "models" + models_dir.mkdir(parents=True, exist_ok=True) + + test_model_content = """ +{{ config( + materialized='custom_incremental', + time_column='created_at', +) }} + +SELECT + CURRENT_TIMESTAMP as created_at, + 1 as id, + 'schema_naming_test' as test_type +""".strip() + + (models_dir / "test_schema_naming.sql").write_text(test_model_content) + + context = Context(paths=path) + prod_plan = 
context.plan(select_models=["sushi.test_schema_naming"]) + context.apply(prod_plan) + + dev_plan = context.plan("dev", select_models=["sushi.test_schema_naming"]) + context.apply(dev_plan) + + prod_result = context.engine_adapter.fetchdf( + "SELECT * FROM sushi.test_schema_naming ORDER BY id" + ) + assert len(prod_result) == 1 + assert prod_result["test_type"][0] == "schema_naming_test" + + dev_result = context.engine_adapter.fetchdf( + "SELECT * FROM sushi__dev.test_schema_naming ORDER BY id" + ) + assert len(dev_result) == 1 + assert dev_result["test_type"][0] == "schema_naming_test" + + # to examine the schema structure + all_schemas_query = """ + SELECT DISTINCT table_schema, COUNT(*) as table_count + FROM system.information_schema.tables + WHERE table_schema LIKE '%sushi%' + AND table_name LIKE '%test_schema_naming%' + GROUP BY table_schema + ORDER BY table_schema + """ + + schema_info = context.engine_adapter.fetchdf(all_schemas_query) + + schema_names = schema_info["table_schema"].tolist() + + # - virtual schemas: sushi, sushi__dev (for views) + view_schemas = [s for s in schema_names if not s.startswith("sqlmesh__")] + + # - physical schema: sqlmesh__sushi (for actual data tables) + physical_schemas = [s for s in schema_names if s.startswith("sqlmesh__")] + + # verify we got both of them + assert len(view_schemas) >= 2 + assert len(physical_schemas) >= 1 + assert "sushi" in view_schemas + assert "sushi__dev" in view_schemas + assert any("sqlmesh__sushi" in s for s in physical_schemas) + + +@pytest.mark.xdist_group("dbt_manifest") +def test_custom_materialization_lineage_tracking(copy_to_temp_path: t.Callable): + path = copy_to_temp_path("tests/fixtures/dbt/sushi_test") + temp_project = path[0] + + models_dir = temp_project / "models" + models_dir.mkdir(parents=True, exist_ok=True) + + # create a custom materialization model that depends on simple_model_a and waiter_names seed + lineage_model_content = """ +{{ config( + materialized='custom_incremental', + 
time_column='created_at', +) }} + +SELECT + CURRENT_TIMESTAMP as created_at, + w.id as waiter_id, + w.name as waiter_name, + s.a as simple_value, + w.id * s.a as computed_value, + 'lineage_test' as model_type +FROM {{ ref('waiter_names') }} w +CROSS JOIN {{ ref('simple_model_a') }} s +""".strip() + + (models_dir / "enhanced_waiter_data.sql").write_text(lineage_model_content) + + # Create another custom materialization model that depends on the first one and simple_model_b + downstream_model_content = """ +{{ config( + materialized='custom_incremental', + time_column='analysis_date', +) }} + +SELECT + CURRENT_TIMESTAMP as analysis_date, + e.waiter_name, + e.simple_value, + e.computed_value, + b.a as model_b_value, + e.computed_value + b.a as final_computation, + CASE + WHEN e.computed_value >= 5 THEN 'High' + WHEN e.computed_value >= 2 THEN 'Medium' + ELSE 'Low' + END as category, + 'downstream_lineage_test' as model_type +FROM {{ ref('enhanced_waiter_data') }} e +CROSS JOIN {{ ref('simple_model_b') }} b +WHERE e.computed_value >= 0 +""".strip() + + (models_dir / "waiter_analytics_summary.sql").write_text(downstream_model_content) + + context = Context(paths=path) + enhanced_data_model = context.get_model("sushi.enhanced_waiter_data") + analytics_summary_model = context.get_model("sushi.waiter_analytics_summary") + + # Verify that custom materialization models have proper model kinds + assert isinstance(enhanced_data_model.kind, DbtCustomKind) + assert enhanced_data_model.kind.materialization == "custom_incremental" + + assert isinstance(analytics_summary_model.kind, DbtCustomKind) + assert analytics_summary_model.kind.materialization == "custom_incremental" + + # - enhanced_waiter_data should depend on waiter_names and simple_model_a + enhanced_data_deps = enhanced_data_model.depends_on + assert '"memory"."sushi"."simple_model_a"' in enhanced_data_deps + assert '"memory"."sushi"."waiter_names"' in enhanced_data_deps + + # - waiter_analytics_summary should depend on 
enhanced_waiter_data and simple_model_b +    analytics_deps = analytics_summary_model.depends_on +    assert '"memory"."sushi"."enhanced_waiter_data"' in analytics_deps +    assert '"memory"."sushi"."simple_model_b"' in analytics_deps + +    # build only the models that have dependencies +    plan = context.plan( +        select_models=[ +            "sushi.waiter_names", +            "sushi.simple_model_a", +            "sushi.simple_model_b", +            "sushi.enhanced_waiter_data", +            "sushi.waiter_analytics_summary", +        ] +    ) +    context.apply(plan) + +    # Verify that all downstream models were built and contain expected data +    waiter_names_result = context.engine_adapter.fetchdf( +        "SELECT COUNT(*) as count FROM sushi.waiter_names" +    ) +    assert waiter_names_result["count"][0] > 0 + +    simple_a_result = context.engine_adapter.fetchdf("SELECT a FROM sushi.simple_model_a") +    assert len(simple_a_result) > 0 +    assert simple_a_result["a"][0] == 1 + +    simple_b_result = context.engine_adapter.fetchdf("SELECT a FROM sushi.simple_model_b") +    assert len(simple_b_result) > 0 +    assert simple_b_result["a"][0] == 1 + +    # Check intermediate custom materialization model +    enhanced_data_result = context.engine_adapter.fetchdf(""" +        SELECT +            waiter_name, +            simple_value, +            computed_value, +            model_type +        FROM sushi.enhanced_waiter_data +        ORDER BY waiter_id +        LIMIT 5 +    """) + +    assert len(enhanced_data_result) > 0 +    assert enhanced_data_result["model_type"][0] == "lineage_test" +    assert all(val == 1 for val in enhanced_data_result["simple_value"]) +    assert all(val >= 0 for val in enhanced_data_result["computed_value"]) +    assert any(val == "Ryan" for val in enhanced_data_result["waiter_name"]) + +    # Check final downstream custom materialization model +    analytics_summary_result = context.engine_adapter.fetchdf(""" +        SELECT +            waiter_name, +            category, +            model_type, +            final_computation +        FROM sushi.waiter_analytics_summary +        ORDER BY waiter_name +        LIMIT 5 +    """) + +    assert len(analytics_summary_result) > 0 +    assert analytics_summary_result["model_type"][0] 
== "downstream_lineage_test" + assert all(cat in ["High", "Medium", "Low"] for cat in analytics_summary_result["category"]) + assert all(val >= 0 for val in analytics_summary_result["final_computation"]) + + # Test that lineage information is preserved in dev environments + dev_plan = context.plan("dev", select_models=["sushi.waiter_analytics_summary"]) + context.apply(dev_plan) + + dev_analytics_result = context.engine_adapter.fetchdf(""" + SELECT + COUNT(*) as count, + COUNT(DISTINCT waiter_name) as unique_waiters + FROM sushi__dev.waiter_analytics_summary + """) + + prod_analytics_result = context.engine_adapter.fetchdf(""" + SELECT + COUNT(*) as count, + COUNT(DISTINCT waiter_name) as unique_waiters + FROM sushi.waiter_analytics_summary + """) + + # Dev and prod should have the same data as they share physical data + assert dev_analytics_result["count"][0] == prod_analytics_result["count"][0] + assert dev_analytics_result["unique_waiters"][0] == prod_analytics_result["unique_waiters"][0] diff --git a/tests/dbt/test_model.py b/tests/dbt/test_model.py index d212872cb7..e29c6768bf 100644 --- a/tests/dbt/test_model.py +++ b/tests/dbt/test_model.py @@ -842,3 +842,14 @@ def test_jinja_config_no_query(create_empty_project): # loads without error and contains empty query (which will error at runtime) assert not context.snapshots['"local"."main"."comment_config_model"'].model.render_query() + + +@pytest.mark.slow +def test_load_custom_materialisations(sushi_test_dbt_context: Context) -> None: + context = sushi_test_dbt_context + assert context.get_model("sushi.custom_incremental_model") + assert context.get_model("sushi.custom_incremental_with_filter") + + context.load() + assert context.get_model("sushi.custom_incremental_model") + assert context.get_model("sushi.custom_incremental_with_filter") diff --git a/tests/dbt/test_transformation.py b/tests/dbt/test_transformation.py index a640d620b7..9a9ce8f906 100644 --- a/tests/dbt/test_transformation.py +++ 
b/tests/dbt/test_transformation.py @@ -1,5 +1,5 @@ import agate -from datetime import datetime +from datetime import datetime, timedelta import json import logging import typing as t @@ -113,6 +113,129 @@ def test_materialization(): ModelConfig(name="model", alias="model", schema="schema", materialized="dictionary") +def test_dbt_custom_materialization(): +    sushi_context = Context(paths=["tests/fixtures/dbt/sushi_test"]) + +    plan_builder = sushi_context.plan_builder(select_models=["sushi.custom_incremental_model"]) +    plan = plan_builder.build() +    assert len(plan.selected_models) == 1 +    selected_model = list(plan.selected_models)[0] +    assert selected_model == "model.sushi.custom_incremental_model" + +    query = "SELECT * FROM sushi.custom_incremental_model ORDER BY created_at" +    hook_table = "SELECT * FROM hook_table ORDER BY id" +    sushi_context.apply(plan) +    result = sushi_context.engine_adapter.fetchdf(query) +    assert len(result) == 1 +    assert {"created_at", "id"}.issubset(result.columns) + +    # assert the pre/post hooks executed as well as part of the custom materialization +    hook_result = sushi_context.engine_adapter.fetchdf(hook_table) +    assert len(hook_result) == 1 +    assert {"length_col", "id", "updated_at"}.issubset(hook_result.columns) +    assert int(hook_result["length_col"][0]) >= 519 +    assert hook_result["id"][0] == 1 + +    # running with execution time one day in the future to simulate an incremental insert +    tomorrow = datetime.now() + timedelta(days=1) +    sushi_context.run(select_models=["sushi.custom_incremental_model"], execution_time=tomorrow) + +    result_after_run = sushi_context.engine_adapter.fetchdf(query) +    assert {"created_at", "id"}.issubset(result_after_run.columns) + +    # this should have added new unique values for the new row +    assert len(result_after_run) == 2 +    assert result_after_run["id"].is_unique +    assert result_after_run["created_at"].is_unique + +    # validate the hooks executed as part of the run as well +    hook_result = 
sushi_context.engine_adapter.fetchdf(hook_table) +    assert len(hook_result) == 2 +    assert hook_result["id"][1] == 2 +    assert int(hook_result["length_col"][1]) >= 519 +    assert hook_result["id"].is_monotonic_increasing +    assert hook_result["updated_at"].is_unique +    assert not hook_result["length_col"].is_unique + + +def test_dbt_custom_materialization_with_time_filter_and_macro(): +    sushi_context = Context(paths=["tests/fixtures/dbt/sushi_test"]) +    today = datetime.now() + +    # select both custom materialisation models with the wildcard +    selector = ["sushi.custom_incremental*"] +    plan_builder = sushi_context.plan_builder(select_models=selector, execution_time=today) +    plan = plan_builder.build() + +    assert len(plan.selected_models) == 2 +    assert { +        "model.sushi.custom_incremental_model", +        "model.sushi.custom_incremental_with_filter", +    }.issubset(plan.selected_models) + +    # the model that populates with data daily (default cron) +    select_daily = "SELECT * FROM sushi.custom_incremental_model ORDER BY created_at" + +    # this model uses `run_started_at` as a filter (which we populate with execution time) with 2 day interval +    select_filter = "SELECT * FROM sushi.custom_incremental_with_filter ORDER BY created_at" + +    sushi_context.apply(plan) +    result = sushi_context.engine_adapter.fetchdf(select_daily) +    assert len(result) == 1 +    assert {"created_at", "id"}.issubset(result.columns) + +    result = sushi_context.engine_adapter.fetchdf(select_filter) +    assert len(result) == 1 +    assert {"created_at", "id"}.issubset(result.columns) + +    # - run ONE DAY LATER +    a_day_later = today + timedelta(days=1) +    sushi_context.run(select_models=selector, execution_time=a_day_later) +    result_after_run = sushi_context.engine_adapter.fetchdf(select_daily) + +    # the new row is inserted in the normal incremental model +    assert len(result_after_run) == 2 +    assert {"created_at", "id"}.issubset(result_after_run.columns) +    assert result_after_run["id"].is_unique +    assert 
result_after_run["created_at"].is_unique + +    # this model due to the filter shouldn't populate with any new data +    result_after_run_filter = sushi_context.engine_adapter.fetchdf(select_filter) +    assert len(result_after_run_filter) == 1 +    assert {"created_at", "id"}.issubset(result_after_run_filter.columns) +    assert result.equals(result_after_run_filter) +    assert result_after_run_filter["id"].is_unique +    assert result_after_run_filter["created_at"][0].date() == today.date() + +    # - run TWO DAYS LATER +    two_days_later = a_day_later + timedelta(days=1) +    sushi_context.run(select_models=selector, execution_time=two_days_later) +    result_after_run = sushi_context.engine_adapter.fetchdf(select_daily) + +    # again a new row is inserted in the normal model +    assert len(result_after_run) == 3 +    assert {"created_at", "id"}.issubset(result_after_run.columns) +    assert result_after_run["id"].is_unique +    assert result_after_run["created_at"].is_unique + +    # the model with the filter now should populate as well +    result_after_run_filter = sushi_context.engine_adapter.fetchdf(select_filter) +    assert len(result_after_run_filter) == 2 +    assert {"created_at", "id"}.issubset(result_after_run_filter.columns) +    assert result_after_run_filter["id"].is_unique +    assert result_after_run_filter["created_at"][0].date() == today.date() +    assert result_after_run_filter["created_at"][1].date() == two_days_later.date() + +    # assert hooks have executed for both plan and incremental runs +    hook_result = sushi_context.engine_adapter.fetchdf("SELECT * FROM hook_table ORDER BY id") +    assert len(hook_result) == 3 +    assert hook_result["id"][0] == 1 +    assert hook_result["id"].is_monotonic_increasing +    assert hook_result["updated_at"].is_unique +    assert int(hook_result["length_col"][1]) >= 519 +    assert not hook_result["length_col"].is_unique + + def test_model_kind(): context = DbtContext() context.project_name = "Test" diff --git a/tests/fixtures/dbt/sushi_test/macros/materializations/custom_incremental.sql 
b/tests/fixtures/dbt/sushi_test/macros/materializations/custom_incremental.sql new file mode 100644 index 0000000000..afb53bc1c6 --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/macros/materializations/custom_incremental.sql @@ -0,0 +1,58 @@ +{%- macro build_incremental_filter_sql(sql, time_column, existing_relation, interval_config) -%} + {# macro to build the filter and also test use of macro inside materialisation #} + WITH source_data AS ( + {{ sql }} + ) + SELECT * FROM source_data + WHERE {{ time_column }} >= ( + SELECT COALESCE(MAX({{ time_column }}), '1900-01-01') + {%- if interval_config %} + INTERVAL {{ interval_config }} {%- endif %} + FROM {{ existing_relation }} + ) +{%- endmacro -%} + +{%- materialization custom_incremental, default -%} + {%- set existing_relation = adapter.get_relation(database=database, schema=schema, identifier=identifier) -%} + {%- set new_relation = api.Relation.create(database=database, schema=schema, identifier=identifier) -%} + {%- set temp_relation = make_temp_relation(new_relation) -%} + + {%- set time_column = config.get('time_column') -%} + {%- set interval_config = config.get('interval') -%} + + {{ run_hooks(pre_hooks, inside_transaction=False) }} + + {%- if existing_relation is none -%} + {# The first insert creates new table if it doesn't exist #} + {%- call statement('main') -%} + {{ create_table_as(False, new_relation, sql) }} + {%- endcall -%} + {%- else -%} + {# Incremental load, appending new data with optional time filtering #} + {%- if time_column is not none -%} + {%- set filtered_sql -%} + {{ build_incremental_filter_sql(sql, time_column, existing_relation, interval_config) }} + {%- endset -%} + {%- else -%} + {%- set filtered_sql = sql -%} + {%- endif -%} + + {{log(filtered_sql, info=true)}} + + {%- call statement('create_temp') -%} + {{ create_table_as(True, temp_relation, filtered_sql) }} + {%- endcall -%} + + {%- call statement('insert') -%} + INSERT INTO {{ new_relation }} + SELECT * FROM {{ temp_relation 
}} + {%- endcall -%} + + {%- call statement('drop_temp') -%} + DROP TABLE {{ temp_relation }} + {%- endcall -%} + {%- endif -%} + + {{ run_hooks(post_hooks, inside_transaction=False) }} + + {{ return({'relations': [new_relation]}) }} +{%- endmaterialization -%} \ No newline at end of file diff --git a/tests/fixtures/dbt/sushi_test/models/custom_incremental_model.sql b/tests/fixtures/dbt/sushi_test/models/custom_incremental_model.sql new file mode 100644 index 0000000000..c7e9a8f7ea --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/models/custom_incremental_model.sql @@ -0,0 +1,20 @@ +{{ config( + materialized='custom_incremental', + pre_hook=[ + "CREATE TABLE IF NOT EXISTS hook_table (id INTEGER, length_col TEXT, updated_at TIMESTAMP)" + ], + post_hook=[ + """ + INSERT INTO hook_table + SELECT + COALESCE(MAX(id), 0) + 1 AS id, + '{{ model.raw_code | length }}' AS length_col, + CURRENT_TIMESTAMP AS updated_at + FROM hook_table + """ + ] +) }} + +SELECT + current_timestamp as created_at, + hash(current_timestamp) as id, \ No newline at end of file diff --git a/tests/fixtures/dbt/sushi_test/models/custom_incremental_with_filter.sql b/tests/fixtures/dbt/sushi_test/models/custom_incremental_with_filter.sql new file mode 100644 index 0000000000..94cbdc9333 --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/models/custom_incremental_with_filter.sql @@ -0,0 +1,9 @@ +{{ config( + materialized='custom_incremental', + time_column='created_at', + interval='2 day' +) }} + +SELECT + CAST('{{ run_started_at }}' AS TIMESTAMP) as created_at, + hash('{{ run_started_at }}') as id, \ No newline at end of file From 2a1442c40867db178a51dbe35775c268e78dc5de Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+themisvaltinos@users.noreply.github.com> Date: Thu, 25 Sep 2025 15:02:18 +0200 Subject: [PATCH 2/7] pr feedback --- sqlmesh/core/model/kind.py | 6 ------ sqlmesh/core/snapshot/evaluator.py | 12 +++++------- sqlmesh/dbt/manifest.py | 10 ++++------ sqlmesh/dbt/package.py | 4 +--- 4 
files changed, 10 insertions(+), 22 deletions(-) diff --git a/sqlmesh/core/model/kind.py b/sqlmesh/core/model/kind.py index 68ea8bf523..7b8e88ac17 100644 --- a/sqlmesh/core/model/kind.py +++ b/sqlmesh/core/model/kind.py @@ -916,12 +916,6 @@ def data_hash_values(self) -> t.List[t.Optional[str]]: self.dialect, ] - @property - def metadata_hash_values(self) -> t.List[t.Optional[str]]: - return [ - *super().metadata_hash_values, - ] - def to_expression( self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any ) -> d.ModelKind: diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 75e32463c3..89cc0540dd 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -83,7 +83,7 @@ format_additive_change_msg, AdditiveChangeError, ) -from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroReturnVal +from sqlmesh.utils.jinja import MacroReturnVal if sys.version_info >= (3, 12): from importlib import metadata @@ -2685,9 +2685,7 @@ def _execute_materialization( create_only: bool = False, **kwargs: t.Any, ) -> None: - from sqlmesh.dbt.builtin import create_builtin_globals - - jinja_macros = getattr(model, "jinja_macros", JinjaMacroRegistry()) + jinja_macros = model.jinja_macros existing_globals = jinja_macros.global_objs.copy() # For vdes we need to use the table, since we don't know the schema/table at parse time @@ -2709,8 +2707,8 @@ def _execute_materialization( "execution_dt": kwargs.get("execution_time"), } - context = create_builtin_globals( - jinja_macros=jinja_macros, jinja_globals=jinja_globals, engine_adapter=self.adapter + context = jinja_macros._create_builtin_globals( + {"engine_adapter": self.adapter, **jinja_globals} ) context.update( @@ -2731,7 +2729,7 @@ def _execute_materialization( try: template.render(**context) except MacroReturnVal as ret: - # this is a succesful return from a macro call (dbt uses this list of Relations to update their relation cache) + # this is a 
successful return from a macro call (dbt uses this list of Relations to update their relation cache) returned_relations = ret.value.get("relations", []) logger.info( f"Materialization {self.materialization_name} returned relations: {returned_relations}" diff --git a/sqlmesh/dbt/manifest.py b/sqlmesh/dbt/manifest.py index 7e12147e03..2eb215cb1b 100644 --- a/sqlmesh/dbt/manifest.py +++ b/sqlmesh/dbt/manifest.py @@ -136,7 +136,7 @@ def __init__( self._on_run_start_per_package: t.Dict[str, HookConfigs] = defaultdict(dict) self._on_run_end_per_package: t.Dict[str, HookConfigs] = defaultdict(dict) - self._materializations_per_package: t.Dict[str, MaterializationConfigs] = defaultdict(dict) + self._materializations: MaterializationConfigs = {} def tests(self, package_name: t.Optional[str] = None) -> TestConfigs: self._load_all() @@ -166,9 +166,9 @@ def on_run_end(self, package_name: t.Optional[str] = None) -> HookConfigs: self._load_all() return self._on_run_end_per_package[package_name or self._project_name] - def materializations(self, package_name: t.Optional[str] = None) -> MaterializationConfigs: + def materializations(self) -> MaterializationConfigs: self._load_all() - return self._materializations_per_package[package_name or self._project_name] + return self._materializations @property def all_macros(self) -> t.Dict[str, t.Dict[str, MacroInfo]]: @@ -315,9 +315,7 @@ def _load_materializations(self) -> None: ) key = f"{mat_name}_{adapter}" - self._materializations_per_package[macro.package_name][key] = ( - materialization_config - ) + self._materializations[key] = materialization_config def _load_tests(self) -> None: for node in self._manifest.nodes.values(): diff --git a/sqlmesh/dbt/package.py b/sqlmesh/dbt/package.py index dd6425ea83..dbaa832c22 100644 --- a/sqlmesh/dbt/package.py +++ b/sqlmesh/dbt/package.py @@ -105,9 +105,7 @@ def load(self, package_root: Path) -> Package: models = _fix_paths(self._context.manifest.models(package_name), package_root) seeds = 
_fix_paths(self._context.manifest.seeds(package_name), package_root) macros = _fix_paths(self._context.manifest.macros(package_name), package_root) - materializations = _fix_paths( - self._context.manifest.materializations(package_name), package_root - ) + materializations = _fix_paths(self._context.manifest.materializations(), package_root) on_run_start = _fix_paths(self._context.manifest.on_run_start(package_name), package_root) on_run_end = _fix_paths(self._context.manifest.on_run_end(package_name), package_root) sources = self._context.manifest.sources(package_name) From 0693a6339b151aadd0d092990290bc8fe7d58400 Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+themisvaltinos@users.noreply.github.com> Date: Fri, 26 Sep 2025 17:21:53 +0200 Subject: [PATCH 3/7] refactor to run pre post statements methods --- sqlmesh/core/snapshot/evaluator.py | 65 +++++++++++++++++++++++++----- 1 file changed, 55 insertions(+), 10 deletions(-) diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 89cc0540dd..4770a9f650 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -748,8 +748,10 @@ def _evaluate_snapshot( adapter.transaction(), adapter.session(snapshot.model.render_session_properties(**render_statements_kwargs)), ): - if not snapshot.is_dbt_custom: - adapter.execute(model.render_pre_statements(**render_statements_kwargs)) + evaluation_strategy = _evaluation_strategy(snapshot, adapter) + evaluation_strategy.run_pre_statements( + snapshot=snapshot, render_kwargs=render_statements_kwargs + ) if not target_table_exists or (model.is_seed and not snapshot.intervals): # Only create the empty table if the columns were provided explicitly by the user @@ -819,8 +821,9 @@ def _evaluate_snapshot( batch_index=batch_index, ) - if not snapshot.is_dbt_custom: - adapter.execute(model.render_post_statements(**render_statements_kwargs)) + evaluation_strategy.run_post_statements( + snapshot=snapshot, 
render_kwargs=render_statements_kwargs + ) return wap_id @@ -1435,8 +1438,10 @@ def _execute_create( **create_render_kwargs, "table_mapping": {snapshot.name: table_name}, } - if run_pre_post_statements and not snapshot.is_dbt_custom: - adapter.execute(snapshot.model.render_pre_statements(**create_render_kwargs)) + if run_pre_post_statements: + evaluation_strategy.run_pre_statements( + snapshot=snapshot, render_kwargs=create_render_kwargs + ) evaluation_strategy.create( table_name=table_name, model=snapshot.model, @@ -1447,8 +1452,10 @@ def _execute_create( dry_run=dry_run, physical_properties=rendered_physical_properties, ) - if run_pre_post_statements and not snapshot.is_dbt_custom: - adapter.execute(snapshot.model.render_post_statements(**create_render_kwargs)) + if run_pre_post_statements: + evaluation_strategy.run_post_statements( + snapshot=snapshot, render_kwargs=create_render_kwargs + ) def _can_clone(self, snapshot: Snapshot, deployability_index: DeployabilityIndex) -> bool: adapter = self.get_adapter(snapshot.model.gateway) @@ -1548,7 +1555,7 @@ def _evaluation_strategy(snapshot: SnapshotInfoLike, adapter: EngineAdapter) -> if hasattr(snapshot, "model") and isinstance( (model_kind := snapshot.model.kind), DbtCustomKind ): - return DbtCustomMaterialization( + return DbtCustomMaterializationStrategy( adapter=adapter, materialization_name=model_kind.materialization, materialization_template=model_kind.definition, @@ -1696,6 +1703,24 @@ def demote(self, view_name: str, **kwargs: t.Any) -> None: view_name: The name of the target view in the virtual layer. """ + @abc.abstractmethod + def run_pre_statements(self, snapshot: Snapshot, render_kwargs: t.Any) -> None: + """Executes the snapshot's pre statements. + + Args: + snapshot: The target snapshot. + render_kwargs: Additional key-value arguments to pass when rendering the statements. 
+ """ + + @abc.abstractmethod + def run_post_statements(self, snapshot: Snapshot, render_kwargs: t.Any) -> None: + """Executes the snapshot's post statements. + + Args: + snapshot: The target snapshot. + render_kwargs: Additional key-value arguments to pass when rendering the statements. + """ + class SymbolicStrategy(EvaluationStrategy): def insert( @@ -1757,6 +1782,12 @@ def promote( def demote(self, view_name: str, **kwargs: t.Any) -> None: pass + def run_pre_statements(self, snapshot: Snapshot, render_kwargs: t.Dict[str, t.Any]) -> None: + pass + + def run_post_statements(self, snapshot: Snapshot, render_kwargs: t.Dict[str, t.Any]) -> None: + pass + class EmbeddedStrategy(SymbolicStrategy): def promote( @@ -1804,6 +1835,12 @@ def demote(self, view_name: str, **kwargs: t.Any) -> None: logger.info("Dropping view '%s'", view_name) self.adapter.drop_view(view_name, cascade=False) + def run_pre_statements(self, snapshot: Snapshot, render_kwargs: t.Any) -> None: + self.adapter.execute(snapshot.model.render_pre_statements(**render_kwargs)) + + def run_post_statements(self, snapshot: Snapshot, render_kwargs: t.Any) -> None: + self.adapter.execute(snapshot.model.render_post_statements(**render_kwargs)) + class MaterializableStrategy(PromotableStrategy, abc.ABC): def create( @@ -2610,7 +2647,7 @@ def get_custom_materialization_type_or_raise( raise SQLMeshError(f"Custom materialization '{name}' not present in the Python environment") -class DbtCustomMaterialization(MaterializableStrategy): +class DbtCustomMaterializationStrategy(MaterializableStrategy): def __init__( self, adapter: EngineAdapter, @@ -2675,6 +2712,14 @@ def append( **kwargs, ) + def run_pre_statements(self, snapshot: Snapshot, render_kwargs: t.Any) -> None: + # in dbt custom materialisations it's up to the user when to run the pre hooks + pass + + def run_post_statements(self, snapshot: Snapshot, render_kwargs: t.Any) -> None: + # in dbt custom materialisations it's up to the user when to run the post 
hooks + pass + def _execute_materialization( self, table_name: str, From bf3e3bf374d87c2c71db5617ecf815f0e89c5ff8 Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+themisvaltinos@users.noreply.github.com> Date: Mon, 29 Sep 2025 10:15:51 +0300 Subject: [PATCH 4/7] to make tests work for < 1.5 versions --- tests/fixtures/dbt/sushi_test/profiles.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/fixtures/dbt/sushi_test/profiles.yml b/tests/fixtures/dbt/sushi_test/profiles.yml index 056c3c2b91..f49ad8ea0f 100644 --- a/tests/fixtures/dbt/sushi_test/profiles.yml +++ b/tests/fixtures/dbt/sushi_test/profiles.yml @@ -3,6 +3,7 @@ sushi: in_memory: type: duckdb schema: sushi + database: memory duckdb: type: duckdb path: 'local.duckdb' From 99d99cc95726bce2380ec3b03cd02d82dbb85c68 Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+themisvaltinos@users.noreply.github.com> Date: Mon, 29 Sep 2025 11:37:49 +0300 Subject: [PATCH 5/7] adapt tests to work for earlier dbt versions --- tests/dbt/test_manifest.py | 2 +- .../macros/materializations/custom_incremental.sql | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/dbt/test_manifest.py b/tests/dbt/test_manifest.py index e6c02bcb4c..e2e7bc706c 100644 --- a/tests/dbt/test_manifest.py +++ b/tests/dbt/test_manifest.py @@ -232,7 +232,7 @@ def test_source_meta_external_location(): expected = ( "read_parquet('path/to/external/items.parquet')" if DBT_VERSION >= (1, 4, 0) - else '"main"."parquet_file".items' + else '"memory"."parquet_file".items' ) assert relation.render() == expected diff --git a/tests/fixtures/dbt/sushi_test/macros/materializations/custom_incremental.sql b/tests/fixtures/dbt/sushi_test/macros/materializations/custom_incremental.sql index afb53bc1c6..d39453f1c6 100644 --- a/tests/fixtures/dbt/sushi_test/macros/materializations/custom_incremental.sql +++ b/tests/fixtures/dbt/sushi_test/macros/materializations/custom_incremental.sql @@ -24,7 +24,8 @@ {%- if existing_relation 
is none -%} {# The first insert creates new table if it doesn't exist #} {%- call statement('main') -%} - {{ create_table_as(False, new_relation, sql) }} + CREATE TABLE {{ new_relation }} + AS {{ sql }} {%- endcall -%} {%- else -%} {# Incremental load, appending new data with optional time filtering #} @@ -40,6 +41,8 @@ {%- call statement('create_temp') -%} {{ create_table_as(True, temp_relation, filtered_sql) }} + CREATE TABLE {{ temp_relation }} + AS {{ filtered_sql }} {%- endcall -%} {%- call statement('insert') -%} From 4c367cc7b173d94f4c409bab097d88dd4d15b44c Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+themisvaltinos@users.noreply.github.com> Date: Mon, 29 Sep 2025 14:16:59 +0300 Subject: [PATCH 6/7] handle leading whitespace --- sqlmesh/dbt/manifest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmesh/dbt/manifest.py b/sqlmesh/dbt/manifest.py index 2eb215cb1b..f8e6e01fc4 100644 --- a/sqlmesh/dbt/manifest.py +++ b/sqlmesh/dbt/manifest.py @@ -771,7 +771,7 @@ def _convert_jinja_test_to_macro(test_jinja: str) -> str: def _strip_jinja_materialization_tags(materialization_jinja: str) -> str: - MATERIALIZATION_TAG_REGEX = r"{%-?\s*materialization\s+[^%]*%}\s*\n?" + MATERIALIZATION_TAG_REGEX = r"\s*{%-?\s*materialization\s+[^%]*%}\s*\n?" ENDMATERIALIZATION_REGEX = r"{%-?\s*endmaterialization\s*-?%}\s*\n?" 
if not re.match(MATERIALIZATION_TAG_REGEX, materialization_jinja): From 7de824704b7fa797f43514d9beb131138f9470ce Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Mon, 29 Sep 2025 13:14:12 -0700 Subject: [PATCH 7/7] Use native Macro Registry APIs (#5452) --- sqlmesh/core/snapshot/evaluator.py | 40 +++++++++---------- sqlmesh/dbt/adapter.py | 18 --------- sqlmesh/dbt/builtin.py | 2 +- sqlmesh/dbt/manifest.py | 6 +++ sqlmesh/utils/jinja.py | 1 + tests/dbt/test_custom_materializations.py | 2 +- .../materializations/custom_incremental.sql | 6 +-- 7 files changed, 32 insertions(+), 43 deletions(-) diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 4770a9f650..4ac87199c6 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -67,7 +67,7 @@ SnapshotTableCleanupTask, ) from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker -from sqlmesh.utils import random_id, CorrelationId +from sqlmesh.utils import random_id, CorrelationId, AttributeDict from sqlmesh.utils.concurrency import ( concurrent_apply_to_snapshots, concurrent_apply_to_values, @@ -2731,12 +2731,12 @@ def _execute_materialization( **kwargs: t.Any, ) -> None: jinja_macros = model.jinja_macros - existing_globals = jinja_macros.global_objs.copy() # For vdes we need to use the table, since we don't know the schema/table at parse time parts = exp.to_table(table_name, dialect=self.adapter.dialect) - relation_info = existing_globals.pop("this") + existing_globals = jinja_macros.global_objs + relation_info = existing_globals.get("this") if isinstance(relation_info, dict): relation_info["database"] = parts.catalog relation_info["identifier"] = parts.name @@ -2750,29 +2750,29 @@ def _execute_materialization( "identifier": parts.name, "target": existing_globals.get("target", {"type": self.adapter.dialect}), "execution_dt": kwargs.get("execution_time"), + "engine_adapter": self.adapter, + "sql": str(query_or_df), + 
"is_first_insert": is_first_insert, + "create_only": create_only, + # FIXME: Add support for transaction=False + "pre_hooks": [ + AttributeDict({"sql": s.this.this, "transaction": True}) + for s in model.pre_statements + ], + "post_hooks": [ + AttributeDict({"sql": s.this.this, "transaction": True}) + for s in model.post_statements + ], + "model_instance": model, + **kwargs, } - context = jinja_macros._create_builtin_globals( - {"engine_adapter": self.adapter, **jinja_globals} - ) - - context.update( - { - "sql": str(query_or_df), - "is_first_insert": is_first_insert, - "create_only": create_only, - "pre_hooks": model.render_pre_statements(**render_kwargs), - "post_hooks": model.render_post_statements(**render_kwargs), - **kwargs, - } - ) - try: - jinja_env = jinja_macros.build_environment(**context) + jinja_env = jinja_macros.build_environment(**jinja_globals) template = jinja_env.from_string(self.materialization_template) try: - template.render(**context) + template.render() except MacroReturnVal as ret: # this is a successful return from a macro call (dbt uses this list of Relations to update their relation cache) returned_relations = ret.value.get("relations", []) diff --git a/sqlmesh/dbt/adapter.py b/sqlmesh/dbt/adapter.py index a8b2b9af72..7f7c7eb4fb 100644 --- a/sqlmesh/dbt/adapter.py +++ b/sqlmesh/dbt/adapter.py @@ -99,12 +99,6 @@ def execute( ) -> t.Tuple[AdapterResponse, agate.Table]: """Executes the given SQL statement and returns the results as an agate table.""" - @abc.abstractmethod - def run_hooks( - self, hooks: t.List[str | exp.Expression], inside_transaction: bool = True - ) -> None: - """Executes the given hooks.""" - @abc.abstractmethod def resolve_schema(self, relation: BaseRelation) -> t.Optional[str]: """Resolves the relation's schema to its physical schema.""" @@ -247,12 +241,6 @@ def execute( self._raise_parsetime_adapter_call_error("execute SQL") raise - def run_hooks( - self, hooks: t.List[str | exp.Expression], inside_transaction: bool = 
True - ) -> None: - self._raise_parsetime_adapter_call_error("run hooks") - raise - def resolve_schema(self, relation: BaseRelation) -> t.Optional[str]: return relation.schema @@ -463,12 +451,6 @@ def resolve_identifier(self, relation: BaseRelation) -> t.Optional[str]: identifier = self._map_table_name(self._normalize(self._relation_to_table(relation))).name return identifier if identifier else None - def run_hooks( - self, hooks: t.List[str | exp.Expression], inside_transaction: bool = True - ) -> None: - # inside_transaction not yet supported similarly to transaction - self.engine_adapter.execute([exp.maybe_parse(hook) for hook in hooks]) - def _map_table_name(self, table: exp.Table) -> exp.Table: # Use the default dialect since this is the dialect used to normalize and quote keys in the # mapping table. diff --git a/sqlmesh/dbt/builtin.py b/sqlmesh/dbt/builtin.py index 8690eb91fa..b8180bc011 100644 --- a/sqlmesh/dbt/builtin.py +++ b/sqlmesh/dbt/builtin.py @@ -544,9 +544,9 @@ def create_builtin_globals( "load_result": sql_execution.load_result, "run_query": sql_execution.run_query, "statement": sql_execution.statement, - "run_hooks": adapter.run_hooks, "graph": adapter.graph, "selected_resources": list(jinja_globals.get("selected_models") or []), + "write": lambda input: None, # We don't support writing yet } ) diff --git a/sqlmesh/dbt/manifest.py b/sqlmesh/dbt/manifest.py index f8e6e01fc4..17c5e91700 100644 --- a/sqlmesh/dbt/manifest.py +++ b/sqlmesh/dbt/manifest.py @@ -395,6 +395,12 @@ def _load_models_and_seeds(self) -> None: dependencies = dependencies.union( self._extra_dependencies(sql, node.package_name, track_all_model_attrs=True) ) + for hook in [*node_config.get("pre-hook", []), *node_config.get("post-hook", [])]: + dependencies = dependencies.union( + self._extra_dependencies( + hook["sql"], node.package_name, track_all_model_attrs=True + ) + ) dependencies = dependencies.union( self._flatten_dependencies_from_macros(dependencies.macros, 
node.package_name) ) diff --git a/sqlmesh/utils/jinja.py b/sqlmesh/utils/jinja.py index 508c6dce2d..59e9f6dd2f 100644 --- a/sqlmesh/utils/jinja.py +++ b/sqlmesh/utils/jinja.py @@ -369,6 +369,7 @@ def build_environment(self, **kwargs: t.Any) -> Environment: context.update(builtin_globals) context.update(root_macros) context.update(package_macros) + context["render"] = lambda input: env.from_string(input).render() env.globals.update(context) env.filters.update(self._environment.filters) diff --git a/tests/dbt/test_custom_materializations.py b/tests/dbt/test_custom_materializations.py index bd961136d2..9e7a94315c 100644 --- a/tests/dbt/test_custom_materializations.py +++ b/tests/dbt/test_custom_materializations.py @@ -37,7 +37,7 @@ def test_custom_materialization_manifest_loading(): assert custom_incremental.name == "custom_incremental" assert custom_incremental.adapter == "default" assert "make_temp_relation(new_relation)" in custom_incremental.definition - assert "run_hooks(pre_hooks, inside_transaction=False)" in custom_incremental.definition + assert "run_hooks(pre_hooks)" in custom_incremental.definition assert " {{ return({'relations': [new_relation]}) }}" in custom_incremental.definition diff --git a/tests/fixtures/dbt/sushi_test/macros/materializations/custom_incremental.sql b/tests/fixtures/dbt/sushi_test/macros/materializations/custom_incremental.sql index d39453f1c6..c61899c8ff 100644 --- a/tests/fixtures/dbt/sushi_test/macros/materializations/custom_incremental.sql +++ b/tests/fixtures/dbt/sushi_test/macros/materializations/custom_incremental.sql @@ -19,7 +19,7 @@ {%- set time_column = config.get('time_column') -%} {%- set interval_config = config.get('interval') -%} - {{ run_hooks(pre_hooks, inside_transaction=False) }} + {{ run_hooks(pre_hooks) }} {%- if existing_relation is none -%} {# The first insert creates new table if it doesn't exist #} @@ -55,7 +55,7 @@ {%- endcall -%} {%- endif -%} - {{ run_hooks(post_hooks, inside_transaction=False) }} + {{ 
run_hooks(post_hooks) }} {{ return({'relations': [new_relation]}) }} -{%- endmaterialization -%} \ No newline at end of file +{%- endmaterialization -%}