Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions sqlmesh/dbt/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,10 @@ def _validate_meta(cls, v: t.Dict[str, t.Union[str, t.Any]]) -> t.Dict[str, t.An
def config_attribute_dict(self) -> AttributeDict[str, t.Any]:
    """Return this config's fields wrapped in an ``AttributeDict``.

    The fields come from ``self.dict(...)``; every key listed in
    ``EXCLUDED_CONFIG_ATTRIBUTE_KEYS`` is left out of the result.
    """
    included_fields = self.dict(exclude=EXCLUDED_CONFIG_ATTRIBUTE_KEYS)
    return AttributeDict(included_fields)

def _get_field_value(self, field: str) -> t.Optional[t.Any]:
    """Look up ``field`` on this instance, falling back to ``self.meta``.

    Args:
        field: Name of the attribute / meta key to read.

    Returns:
        The attribute value when it exists and is not None (falsy values such
        as ``False`` or ``0`` are returned as-is); otherwise the result of
        ``self.meta.get(field, None)``.
    """
    attr_value = getattr(self, field, None)
    if attr_value is not None:
        return attr_value
    # Attribute missing or explicitly None: defer to the meta mapping.
    return self.meta.get(field, None)

def replace(self, other: T) -> None:
"""
Replace the contents of this instance with the passed in instance.
Expand All @@ -152,9 +156,7 @@ def sqlmesh_config_kwargs(self) -> t.Dict[str, t.Any]:
"""
kwargs = {}
for field in self.sqlmesh_config_fields:
field_val = getattr(self, field, None)
if field_val is None:
field_val = self.meta.get(field, None)
field_val = self._get_field_value(field)
if field_val is not None:
kwargs[field] = field_val
return kwargs
Expand Down
130 changes: 102 additions & 28 deletions sqlmesh/dbt/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import datetime
import typing as t

from sqlglot import exp
Expand Down Expand Up @@ -34,7 +35,7 @@
from sqlmesh.dbt.context import DbtContext


INCREMENTAL_BY_TIME_STRATEGIES = set(["delete+insert", "insert_overwrite"])
INCREMENTAL_BY_TIME_STRATEGIES = set(["delete+insert", "insert_overwrite", "microbatch"])
INCREMENTAL_BY_UNIQUE_KEY_STRATEGIES = set(["merge"])


Expand Down Expand Up @@ -73,9 +74,7 @@ class ModelConfig(BaseModelConfig):
time_column: t.Optional[str] = None
cron: t.Optional[str] = None
interval_unit: t.Optional[str] = None
batch_size: t.Optional[int] = None
batch_concurrency: t.Optional[int] = None
lookback: t.Optional[int] = None
forward_only: bool = True
disable_restatement: t.Optional[bool] = None
allow_partials: t.Optional[bool] = None
Expand All @@ -100,6 +99,15 @@ class ModelConfig(BaseModelConfig):
target_schema: t.Optional[str] = None
check_cols: t.Optional[t.Union[t.List[str], str]] = None

# Microbatch Fields
event_time: t.Optional[str] = None
begin: t.Optional[datetime.datetime] = None
concurrent_batches: t.Optional[bool] = None

# Shared SQLMesh and DBT configuration fields
batch_size: t.Optional[t.Union[int, str]] = None
lookback: t.Optional[int] = None

# redshift
bind: t.Optional[bool] = None

Expand Down Expand Up @@ -220,6 +228,17 @@ def snapshot_strategy(self) -> t.Optional[SnapshotStrategy]:
def table_schema(self) -> str:
    """Return ``target_schema`` when it is truthy, else the ``table_schema`` obtained via ``super()``."""
    explicit_schema = self.target_schema
    if explicit_schema:
        return explicit_schema
    return super().table_schema

def _get_overlapping_field_value(
    self, context: DbtContext, dbt_field_name: str, sqlmesh_field_name: str
) -> t.Optional[t.Any]:
    """Resolve a setting that can be given under either a dbt or a SQLMesh name.

    Args:
        context: Passed to ``self.canonical_name`` when building the warning text.
        dbt_field_name: Field name read through ``self._get_field_value``.
        sqlmesh_field_name: Attribute name read directly from this instance.

    Returns:
        The SQLMesh attribute value when it is not None, otherwise the dbt
        value (which may itself be None). A warning is logged through the
        console when both are set, since the SQLMesh value wins.
    """
    dbt_value = self._get_field_value(dbt_field_name)
    sqlmesh_value = getattr(self, sqlmesh_field_name, None)

    # No SQLMesh value: nothing to conflict with, use whatever dbt provided.
    if sqlmesh_value is None:
        return dbt_value

    if dbt_value is not None:
        get_console().log_warning(
            f"Both '{dbt_field_name}' and '{sqlmesh_field_name}' are set for model '{self.canonical_name(context)}'. '{sqlmesh_field_name}' will be used."
        )
    return sqlmesh_value

def model_kind(self, context: DbtContext) -> ModelKind:
"""
Get the sqlmesh ModelKind
Expand Down Expand Up @@ -256,27 +275,44 @@ def model_kind(self, context: DbtContext) -> ModelKind:

incremental_kind_kwargs["on_destructive_change"] = on_destructive_change
incremental_kind_kwargs["on_additive_change"] = on_additive_change
for field in ("forward_only", "auto_restatement_cron"):
field_val = getattr(self, field, None)
if field_val is None:
field_val = self.meta.get(field, None)
if field_val is not None:
incremental_kind_kwargs[field] = field_val
Comment on lines -259 to -264
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This used to add forward_only to the incremental_kind_kwargs when it should really be part of incremental_by_kind_kwargs. This PR fixes that.

auto_restatement_cron_value = self._get_field_value("auto_restatement_cron")
if auto_restatement_cron_value is not None:
incremental_kind_kwargs["auto_restatement_cron"] = auto_restatement_cron_value

if materialization == Materialization.TABLE:
return FullKind()
if materialization == Materialization.VIEW:
return ViewKind()
if materialization == Materialization.INCREMENTAL:
incremental_by_kind_kwargs: t.Dict[str, t.Any] = {"dialect": self.dialect(context)}
forward_only_value = self._get_field_value("forward_only")
if forward_only_value is not None:
incremental_kind_kwargs["forward_only"] = forward_only_value

is_incremental_by_time_range = self.time_column or (
self.incremental_strategy and self.incremental_strategy == "microbatch"
)
# Get shared incremental by kwargs
for field in ("batch_size", "batch_concurrency", "lookback"):
field_val = getattr(self, field, None)
if field_val is None:
field_val = self.meta.get(field, None)
field_val = self._get_field_value(field)
if field_val is not None:
# Check if `batch_size` is representing an interval unit and if so that will be handled at the model level
if field == "batch_size" and isinstance(field_val, str):
continue
incremental_by_kind_kwargs[field] = field_val

if self.time_column:
disable_restatement = self.disable_restatement
if disable_restatement is None:
if is_incremental_by_time_range:
disable_restatement = False
else:
disable_restatement = (
not self.full_refresh if self.full_refresh is not None else False
)
incremental_by_kind_kwargs["disable_restatement"] = disable_restatement
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

disable_restatement is part of IncrementBy, so it belongs in incremental_by_kind_kwargs; this PR adds it there.


# Incremental by time range which includes microbatch
if is_incremental_by_time_range:
strategy = self.incremental_strategy or target.default_incremental_strategy(
IncrementalByTimeRangeKind
)
Expand All @@ -287,22 +323,37 @@ def model_kind(self, context: DbtContext) -> ModelKind:
f"Supported strategies include {collection_to_str(INCREMENTAL_BY_TIME_STRATEGIES)}."
)

if strategy == "microbatch":
time_column = self._get_overlapping_field_value(
context, "event_time", "time_column"
)
if not time_column:
raise ConfigError(
f"{self.canonical_name(context)}: 'event_time' is required for microbatch incremental strategy."
)
concurrent_batches = self._get_field_value("concurrent_batches")
if concurrent_batches is True:
if incremental_by_kind_kwargs.get("batch_size"):
get_console().log_warning(
f"'concurrent_batches' is set to True and 'batch_size' are defined in '{self.canonical_name(context)}'. The batch size will be set to the value of `batch_size`."
)
incremental_by_kind_kwargs["batch_size"] = incremental_by_kind_kwargs.get(
"batch_size", 1
)
else:
if not self.time_column:
raise ConfigError(
f"{self.canonical_name(context)}: 'time_column' is required for incremental by time range models not defined using microbatch."
)
time_column = self.time_column

return IncrementalByTimeRangeKind(
time_column=self.time_column,
disable_restatement=(
self.disable_restatement if self.disable_restatement is not None else False
),
time_column=time_column,
auto_restatement_intervals=self.auto_restatement_intervals,
**incremental_kind_kwargs,
**incremental_by_kind_kwargs,
)

disable_restatement = self.disable_restatement
if disable_restatement is None:
disable_restatement = (
not self.full_refresh if self.full_refresh is not None else False
)

if self.unique_key:
strategy = self.incremental_strategy or target.default_incremental_strategy(
IncrementalByUniqueKeyKind
Expand All @@ -315,11 +366,11 @@ def model_kind(self, context: DbtContext) -> ModelKind:
f"Unique key is not compatible with '{strategy}' incremental strategy in model '{self.canonical_name(context)}'. "
f"Supported strategies include {collection_to_str(INCREMENTAL_BY_UNIQUE_KEY_STRATEGIES)}. Falling back to 'merge' strategy."
)
strategy = "merge"

merge_filter = None
if self.incremental_predicates:
dialect = self.dialect(context)
incremental_kind_kwargs["merge_filter"] = exp.and_(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

merge_filter is not part of the incremental kind model, so this PR stops adding it to incremental_kind_kwargs and passes it to IncrementalByUniqueKeyKind directly.

merge_filter = exp.and_(
*[
d.parse_one(predicate, dialect=dialect)
for predicate in self.incremental_predicates
Expand All @@ -329,7 +380,7 @@ def model_kind(self, context: DbtContext) -> ModelKind:

return IncrementalByUniqueKeyKind(
unique_key=self.unique_key,
disable_restatement=disable_restatement,
merge_filter=merge_filter,
**incremental_kind_kwargs,
**incremental_by_kind_kwargs,
)
Expand All @@ -339,7 +390,7 @@ def model_kind(self, context: DbtContext) -> ModelKind:
)
return IncrementalUnmanagedKind(
insert_overwrite=strategy in INCREMENTAL_BY_TIME_STRATEGIES,
disable_restatement=disable_restatement,
disable_restatement=incremental_by_kind_kwargs["disable_restatement"],
**incremental_kind_kwargs,
)
if materialization == Materialization.EPHEMERAL:
Expand Down Expand Up @@ -438,6 +489,9 @@ def sqlmesh_config_fields(self) -> t.Set[str]:
"interval_unit",
"allow_partials",
"physical_version",
"start",
# In microbatch models `begin` is the same as `start`
"begin",
}

def to_sqlmesh(
Expand Down Expand Up @@ -583,12 +637,32 @@ def to_sqlmesh(
# Set allow_partials to True for dbt models to preserve the original semantics.
allow_partials = True

if kind.is_incremental:
if self.batch_size and isinstance(self.batch_size, str):
if "interval_unit" in model_kwargs:
get_console().log_warning(
f"Both 'interval_unit' and 'batch_size' are set for model '{self.canonical_name(context)}'. 'interval_unit' will be used."
)
else:
model_kwargs["interval_unit"] = self.batch_size
self.batch_size = None
if begin := model_kwargs.pop("begin", None):
if "start" in model_kwargs:
get_console().log_warning(
f"Both 'begin' and 'start' are set for model '{self.canonical_name(context)}'. 'start' will be used."
)
else:
model_kwargs["start"] = begin

model_kwargs["start"] = model_kwargs.get(
"start", context.sqlmesh_config.model_defaults.start
)

model = create_sql_model(
self.canonical_name(context),
query,
dialect=model_dialect,
kind=kind,
start=self.start or context.sqlmesh_config.model_defaults.start,
audit_definitions=audit_definitions,
# This ensures that we bypass query rendering that would otherwise be required to extract additional
# dependencies from the model's SQL.
Expand Down
Loading