From 51a58384763fb438f69fd53bd8b033a0fa267090 Mon Sep 17 00:00:00 2001 From: eakmanrq <6326532+eakmanrq@users.noreply.github.com> Date: Mon, 1 Sep 2025 13:51:34 -0700 Subject: [PATCH] chore: add dbt microbatch interface --- sqlmesh/dbt/common.py | 8 +- sqlmesh/dbt/model.py | 130 +++++++++++++++++++------ tests/dbt/test_model.py | 209 ++++++++++++++++++++++++++++++++-------- 3 files changed, 278 insertions(+), 69 deletions(-) diff --git a/sqlmesh/dbt/common.py b/sqlmesh/dbt/common.py index c74fd933da..ba982c2bb2 100644 --- a/sqlmesh/dbt/common.py +++ b/sqlmesh/dbt/common.py @@ -132,6 +132,10 @@ def _validate_meta(cls, v: t.Dict[str, t.Union[str, t.Any]]) -> t.Dict[str, t.An def config_attribute_dict(self) -> AttributeDict[str, t.Any]: return AttributeDict(self.dict(exclude=EXCLUDED_CONFIG_ATTRIBUTE_KEYS)) + def _get_field_value(self, field: str) -> t.Optional[t.Any]: + field_val = getattr(self, field, None) + return field_val if field_val is not None else self.meta.get(field, None) + def replace(self, other: T) -> None: """ Replace the contents of this instance with the passed in instance. @@ -152,9 +156,7 @@ def sqlmesh_config_kwargs(self) -> t.Dict[str, t.Any]: """ kwargs = {} for field in self.sqlmesh_config_fields: - field_val = getattr(self, field, None) - if field_val is None: - field_val = self.meta.get(field, None) + field_val = self._get_field_value(field) if field_val is not None: kwargs[field] = field_val return kwargs diff --git a/sqlmesh/dbt/model.py b/sqlmesh/dbt/model.py index d2d1a52abc..080900eace 100644 --- a/sqlmesh/dbt/model.py +++ b/sqlmesh/dbt/model.py @@ -1,5 +1,6 @@ from __future__ import annotations +import datetime import typing as t from sqlglot import exp @@ -34,7 +35,7 @@ from sqlmesh.dbt.context import DbtContext -INCREMENTAL_BY_TIME_STRATEGIES = set(["delete+insert", "insert_overwrite"]) +INCREMENTAL_BY_TIME_STRATEGIES = set(["delete+insert", "insert_overwrite", "microbatch"]) INCREMENTAL_BY_UNIQUE_KEY_STRATEGIES = set(["merge"]) @@ -73,9 +74,7 @@ class ModelConfig(BaseModelConfig): time_column: t.Optional[str] = None cron: t.Optional[str] = None interval_unit: t.Optional[str] = None - batch_size: t.Optional[int] = None batch_concurrency: t.Optional[int] = None - lookback: t.Optional[int] = None forward_only: bool = True disable_restatement: t.Optional[bool] = None allow_partials: t.Optional[bool] = None @@ -100,6 +99,15 @@ class ModelConfig(BaseModelConfig): target_schema: t.Optional[str] = None check_cols: t.Optional[t.Union[t.List[str], str]] = None + # Microbatch Fields + event_time: t.Optional[str] = None + begin: t.Optional[datetime.datetime] = None + concurrent_batches: t.Optional[bool] = None + + # Shared SQLMesh and DBT configuration fields + batch_size: t.Optional[t.Union[int, str]] = None + lookback: t.Optional[int] = None + # redshift bind: t.Optional[bool] = None @@ -220,6 +228,17 @@ def snapshot_strategy(self) -> t.Optional[SnapshotStrategy]: def table_schema(self) -> str: return self.target_schema or super().table_schema + def _get_overlapping_field_value( + self, context: DbtContext, dbt_field_name: str, sqlmesh_field_name: str + ) -> t.Optional[t.Any]: + dbt_field = self._get_field_value(dbt_field_name) + sqlmesh_field = getattr(self, sqlmesh_field_name, None) + if dbt_field is not None and sqlmesh_field is not None: + get_console().log_warning( + f"Both '{dbt_field_name}' and '{sqlmesh_field_name}' are set for model '{self.canonical_name(context)}'. '{sqlmesh_field_name}' will be used." + ) + return sqlmesh_field if sqlmesh_field is not None else dbt_field + def model_kind(self, context: DbtContext) -> ModelKind: """ Get the sqlmesh ModelKind @@ -256,12 +275,9 @@ def model_kind(self, context: DbtContext) -> ModelKind: incremental_kind_kwargs["on_destructive_change"] = on_destructive_change incremental_kind_kwargs["on_additive_change"] = on_additive_change - for field in ("forward_only", "auto_restatement_cron"): - field_val = getattr(self, field, None) - if field_val is None: - field_val = self.meta.get(field, None) - if field_val is not None: - incremental_kind_kwargs[field] = field_val + auto_restatement_cron_value = self._get_field_value("auto_restatement_cron") + if auto_restatement_cron_value is not None: + incremental_kind_kwargs["auto_restatement_cron"] = auto_restatement_cron_value if materialization == Materialization.TABLE: return FullKind() @@ -269,14 +285,34 @@ def model_kind(self, context: DbtContext) -> ModelKind: return ViewKind() if materialization == Materialization.INCREMENTAL: incremental_by_kind_kwargs: t.Dict[str, t.Any] = {"dialect": self.dialect(context)} + forward_only_value = self._get_field_value("forward_only") + if forward_only_value is not None: + incremental_kind_kwargs["forward_only"] = forward_only_value + + is_incremental_by_time_range = self.time_column or ( + self.incremental_strategy and self.incremental_strategy == "microbatch" + ) + # Get shared incremental by kwargs for field in ("batch_size", "batch_concurrency", "lookback"): - field_val = getattr(self, field, None) - if field_val is None: - field_val = self.meta.get(field, None) + field_val = self._get_field_value(field) if field_val is not None: + # Check if `batch_size` is representing an interval unit and if so that will be handled at the model level + if field == "batch_size" and isinstance(field_val, str): + continue incremental_by_kind_kwargs[field] = field_val - if self.time_column: + disable_restatement = self.disable_restatement + if disable_restatement is None: + if is_incremental_by_time_range: + disable_restatement = False + else: + disable_restatement = ( + not self.full_refresh if self.full_refresh is not None else False + ) + incremental_by_kind_kwargs["disable_restatement"] = disable_restatement + + # Incremental by time range which includes microbatch + if is_incremental_by_time_range: strategy = self.incremental_strategy or target.default_incremental_strategy( IncrementalByTimeRangeKind ) @@ -287,22 +323,37 @@ def model_kind(self, context: DbtContext) -> ModelKind: f"Supported strategies include {collection_to_str(INCREMENTAL_BY_TIME_STRATEGIES)}." ) + if strategy == "microbatch": + time_column = self._get_overlapping_field_value( + context, "event_time", "time_column" + ) + if not time_column: + raise ConfigError( + f"{self.canonical_name(context)}: 'event_time' is required for microbatch incremental strategy." + ) + concurrent_batches = self._get_field_value("concurrent_batches") + if concurrent_batches is True: + if incremental_by_kind_kwargs.get("batch_size"): + get_console().log_warning( + f"'concurrent_batches' is set to True and 'batch_size' are defined in '{self.canonical_name(context)}'. The batch size will be set to the value of `batch_size`." + ) + incremental_by_kind_kwargs["batch_size"] = incremental_by_kind_kwargs.get( + "batch_size", 1 + ) + else: + if not self.time_column: + raise ConfigError( + f"{self.canonical_name(context)}: 'time_column' is required for incremental by time range models not defined using microbatch." + ) + time_column = self.time_column + return IncrementalByTimeRangeKind( - time_column=self.time_column, - disable_restatement=( - self.disable_restatement if self.disable_restatement is not None else False - ), + time_column=time_column, auto_restatement_intervals=self.auto_restatement_intervals, **incremental_kind_kwargs, **incremental_by_kind_kwargs, ) - disable_restatement = self.disable_restatement - if disable_restatement is None: - disable_restatement = ( - not self.full_refresh if self.full_refresh is not None else False - ) - if self.unique_key: strategy = self.incremental_strategy or target.default_incremental_strategy( IncrementalByUniqueKeyKind @@ -315,11 +366,11 @@ def model_kind(self, context: DbtContext) -> ModelKind: f"Unique key is not compatible with '{strategy}' incremental strategy in model '{self.canonical_name(context)}'. " f"Supported strategies include {collection_to_str(INCREMENTAL_BY_UNIQUE_KEY_STRATEGIES)}. Falling back to 'merge' strategy." ) - strategy = "merge" + merge_filter = None if self.incremental_predicates: dialect = self.dialect(context) - incremental_kind_kwargs["merge_filter"] = exp.and_( + merge_filter = exp.and_( *[ d.parse_one(predicate, dialect=dialect) for predicate in self.incremental_predicates @@ -329,7 +380,7 @@ def model_kind(self, context: DbtContext) -> ModelKind: return IncrementalByUniqueKeyKind( unique_key=self.unique_key, - disable_restatement=disable_restatement, + merge_filter=merge_filter, **incremental_kind_kwargs, **incremental_by_kind_kwargs, ) @@ -339,7 +390,7 @@ def model_kind(self, context: DbtContext) -> ModelKind: ) return IncrementalUnmanagedKind( insert_overwrite=strategy in INCREMENTAL_BY_TIME_STRATEGIES, - disable_restatement=disable_restatement, + disable_restatement=incremental_by_kind_kwargs["disable_restatement"], **incremental_kind_kwargs, ) if materialization == Materialization.EPHEMERAL: @@ -438,6 +489,9 @@ def sqlmesh_config_fields(self) -> t.Set[str]: "interval_unit", "allow_partials", "physical_version", + "start", + # In microbatch models `begin` is the same as `start` + "begin", } def to_sqlmesh( @@ -583,12 +637,32 @@ def to_sqlmesh( # Set allow_partials to True for dbt models to preserve the original semantics. allow_partials = True + if kind.is_incremental: + if self.batch_size and isinstance(self.batch_size, str): + if "interval_unit" in model_kwargs: + get_console().log_warning( + f"Both 'interval_unit' and 'batch_size' are set for model '{self.canonical_name(context)}'. 'interval_unit' will be used." + ) + else: + model_kwargs["interval_unit"] = self.batch_size + self.batch_size = None + if begin := model_kwargs.pop("begin", None): + if "start" in model_kwargs: + get_console().log_warning( + f"Both 'begin' and 'start' are set for model '{self.canonical_name(context)}'. 'start' will be used." + ) + else: + model_kwargs["start"] = begin + + model_kwargs["start"] = model_kwargs.get( + "start", context.sqlmesh_config.model_defaults.start + ) + model = create_sql_model( self.canonical_name(context), query, dialect=model_dialect, kind=kind, - start=self.start or context.sqlmesh_config.model_defaults.start, audit_definitions=audit_definitions, # This ensures that we bypass query rendering that would otherwise be required to extract additional # dependencies from the model's SQL. diff --git a/tests/dbt/test_model.py b/tests/dbt/test_model.py index e33c41e68c..7d4672c512 100644 --- a/tests/dbt/test_model.py +++ b/tests/dbt/test_model.py @@ -1,8 +1,12 @@ +import datetime +import typing as t import pytest from pathlib import Path +from sqlglot import exp from sqlmesh import Context +from sqlmesh.core.model import TimeColumn, IncrementalByTimeRangeKind from sqlmesh.dbt.common import Dependencies from sqlmesh.dbt.context import DbtContext from sqlmesh.dbt.model import ModelConfig @@ -13,6 +17,49 @@ pytestmark = pytest.mark.dbt +@pytest.fixture +def create_empty_project(tmp_path: Path) -> t.Callable[[], t.Tuple[Path, Path]]: + def _create_empty_project() -> t.Tuple[Path, Path]: + yaml = YAML() + dbt_project_dir = tmp_path / "dbt" + dbt_project_dir.mkdir() + dbt_model_dir = dbt_project_dir / "models" + dbt_model_dir.mkdir() + dbt_project_config = { + "name": "empty_project", + "version": "1.0.0", + "config-version": 2, + "profile": "test", + "model-paths": ["models"], + } + dbt_project_file = dbt_project_dir / "dbt_project.yml" + with open(dbt_project_file, "w", encoding="utf-8") as f: + YAML().dump(dbt_project_config, f) + sqlmesh_config = { + "model_defaults": { + "start": "2025-01-01", + } + } + sqlmesh_config_file = dbt_project_dir / "sqlmesh.yaml" + with open(sqlmesh_config_file, "w", encoding="utf-8") as f: + YAML().dump(sqlmesh_config, f) + dbt_data_dir = tmp_path / "dbt_data" + dbt_data_dir.mkdir() + dbt_data_file = dbt_data_dir / "local.db" + dbt_profile_config = { + "test": { + "outputs": {"duckdb": {"type": "duckdb", "path": str(dbt_data_file)}}, + "target": "duckdb", + } + } + db_profile_file = dbt_project_dir / "profiles.yml" + with open(db_profile_file, "w", encoding="utf-8") as f: + yaml.dump(dbt_profile_config, f) + return dbt_project_dir, dbt_model_dir + + return _create_empty_project + + def test_model_test_circular_references() -> None: upstream_model = ModelConfig(name="upstream") downstream_model = ModelConfig(name="downstream", dependencies=Dependencies(refs={"upstream"})) @@ -68,16 +115,13 @@ def test_model_test_circular_references() -> None: @pytest.mark.slow def test_load_invalid_ref_audit_constraints( - tmp_path: Path, caplog, dbt_dummy_postgres_config: PostgresConfig + tmp_path: Path, caplog, dbt_dummy_postgres_config: PostgresConfig, create_empty_project ) -> None: yaml = YAML() - dbt_project_dir = tmp_path / "dbt" - dbt_project_dir.mkdir() - dbt_model_dir = dbt_project_dir / "models" - dbt_model_dir.mkdir() + project_dir, model_dir = create_empty_project() # add `tests` to model config since this is loaded by dbt and ignored and we shouldn't error when loading it full_model_contents = """{{ config(tags=["blah"], tests=[{"blah": {"to": "ref('completely_ignored')", "field": "blah2"} }]) }} SELECT 1 as cola""" - full_model_file = dbt_model_dir / "full_model.sql" + full_model_file = model_dir / "full_model.sql" with open(full_model_file, "w", encoding="utf-8") as f: f.write(full_model_contents) model_schema = { @@ -118,41 +162,11 @@ def test_load_invalid_ref_audit_constraints( } ], } - model_schema_file = dbt_model_dir / "schema.yml" + model_schema_file = model_dir / "schema.yml" with open(model_schema_file, "w", encoding="utf-8") as f: yaml.dump(model_schema, f) - dbt_project_config = { - "name": "invalid_ref_audit_constraints", - "version": "1.0.0", - "config-version": 2, - "profile": "test", - "model-paths": ["models"], - } - dbt_project_file = dbt_project_dir / "dbt_project.yml" - with open(dbt_project_file, "w", encoding="utf-8") as f: - yaml.dump(dbt_project_config, f) - sqlmesh_config = { - "model_defaults": { - "start": "2025-01-01", - } - } - sqlmesh_config_file = dbt_project_dir / "sqlmesh.yaml" - with open(sqlmesh_config_file, "w", encoding="utf-8") as f: - yaml.dump(sqlmesh_config, f) - dbt_data_dir = tmp_path / "dbt_data" - dbt_data_dir.mkdir() - dbt_data_file = dbt_data_dir / "local.db" - dbt_profile_config = { - "test": { - "outputs": {"duckdb": {"type": "duckdb", "path": str(dbt_data_file)}}, - "target": "duckdb", - } - } - db_profile_file = dbt_project_dir / "profiles.yml" - with open(db_profile_file, "w", encoding="utf-8") as f: - yaml.dump(dbt_profile_config, f) - context = Context(paths=dbt_project_dir) + context = Context(paths=project_dir) assert ( "Skipping audit 'relationships_full_model_cola__cola__ref_not_real_model_' because model 'not_real_model' is not a valid ref" in caplog.text @@ -165,3 +179,122 @@ def test_load_invalid_ref_audit_constraints( assert fqn in context.snapshots # The audit isn't loaded due to the invalid ref assert context.snapshots[fqn].model.audits == [] + + +@pytest.mark.slow +def test_load_microbatch_all_defined( + tmp_path: Path, caplog, dbt_dummy_postgres_config: PostgresConfig, create_empty_project +) -> None: + project_dir, model_dir = create_empty_project() + # add `tests` to model config since this is loaded by dbt and ignored and we shouldn't error when loading it + microbatch_contents = """ + {{ + config( + materialized='incremental', + incremental_strategy='microbatch', + event_time='ds', + begin='2020-01-01', + batch_size='day', + lookback=2, + concurrent_batches=true + ) + }} + + SELECT 1 as cola, '2025-01-01' as ds + """ + microbatch_model_file = model_dir / "microbatch.sql" + with open(microbatch_model_file, "w", encoding="utf-8") as f: + f.write(microbatch_contents) + + snapshot_fqn = '"local"."main"."microbatch"' + context = Context(paths=project_dir) + model = context.snapshots[snapshot_fqn].model + # Validate model-level attributes + assert model.start == datetime.datetime(2020, 1, 1, 0, 0) + assert model.interval_unit.is_day + # Validate model kind attributes + assert isinstance(model.kind, IncrementalByTimeRangeKind) + assert model.kind.lookback == 2 + assert model.kind.time_column == TimeColumn( + column=exp.to_column("ds", quoted=True), format="%Y-%m-%d" + ) + assert model.kind.batch_size == 1 + + +@pytest.mark.slow +def test_load_microbatch_all_defined_diff_values( + tmp_path: Path, caplog, dbt_dummy_postgres_config: PostgresConfig, create_empty_project +) -> None: + project_dir, model_dir = create_empty_project() + # add `tests` to model config since this is loaded by dbt and ignored and we shouldn't error when loading it + microbatch_contents = """ + {{ + config( + materialized='incremental', + incremental_strategy='microbatch', + cron='@yearly', + event_time='blah', + begin='2022-01-01', + batch_size='year', + lookback=20, + concurrent_batches=false + ) + }} + + SELECT 1 as cola, '2022-01-01' as blah + """ + microbatch_model_file = model_dir / "microbatch.sql" + with open(microbatch_model_file, "w", encoding="utf-8") as f: + f.write(microbatch_contents) + + snapshot_fqn = '"local"."main"."microbatch"' + context = Context(paths=project_dir) + model = context.snapshots[snapshot_fqn].model + # Validate model-level attributes + assert model.start == datetime.datetime(2022, 1, 1, 0, 0) + assert model.interval_unit.is_year + # Validate model kind attributes + assert isinstance(model.kind, IncrementalByTimeRangeKind) + assert model.kind.lookback == 20 + assert model.kind.time_column == TimeColumn( + column=exp.to_column("blah", quoted=True), format="%Y-%m-%d" + ) + assert model.kind.batch_size is None + + +@pytest.mark.slow +def test_load_microbatch_required_only( + tmp_path: Path, caplog, dbt_dummy_postgres_config: PostgresConfig, create_empty_project +) -> None: + project_dir, model_dir = create_empty_project() + # add `tests` to model config since this is loaded by dbt and ignored and we shouldn't error when loading it + microbatch_contents = """ + {{ + config( + materialized='incremental', + incremental_strategy='microbatch', + begin='2021-01-01', + event_time='ds', + batch_size='hour', + ) + }} + + SELECT 1 as cola, '2021-01-01' as ds + """ + microbatch_model_file = model_dir / "microbatch.sql" + with open(microbatch_model_file, "w", encoding="utf-8") as f: + f.write(microbatch_contents) + + snapshot_fqn = '"local"."main"."microbatch"' + context = Context(paths=project_dir) + model = context.snapshots[snapshot_fqn].model + # Validate model-level attributes + assert model.start == datetime.datetime(2021, 1, 1, 0, 0) + assert model.interval_unit.is_hour + # Validate model kind attributes + assert isinstance(model.kind, IncrementalByTimeRangeKind) + assert model.kind.lookback == 1 + assert model.kind.time_column == TimeColumn( + column=exp.to_column("ds", quoted=True), format="%Y-%m-%d" + ) + assert model.kind.batch_size is None