diff --git a/sqlmesh/dbt/model.py b/sqlmesh/dbt/model.py index f47283d06e..3cc9b301bc 100644 --- a/sqlmesh/dbt/model.py +++ b/sqlmesh/dbt/model.py @@ -172,6 +172,22 @@ def _validate_check_cols(cls, v: t.Union[str, t.List[str]]) -> t.Union[str, t.Li return "*" return ensure_list(v) + @field_validator("updated_at", mode="before") + @classmethod + def _validate_updated_at(cls, v: t.Optional[str]) -> t.Optional[str]: + """ + Extract column name if updated_at contains a cast. + + SCDType2ByTimeKind and SCDType2ByColumnKind expect a column, and the casting is done later. + """ + if v is None: + return None + parsed = d.parse_one(v) + if isinstance(parsed, exp.Cast) and isinstance(parsed.this, exp.Column): + return parsed.this.name + + return v + @field_validator("sql", mode="before") @classmethod def _validate_sql(cls, v: t.Union[str, SqlStr]) -> SqlStr: diff --git a/tests/dbt/test_transformation.py b/tests/dbt/test_transformation.py index a33e3ed843..141c160e7e 100644 --- a/tests/dbt/test_transformation.py +++ b/tests/dbt/test_transformation.py @@ -652,6 +652,23 @@ def test_model_kind(): == ManagedKind() ) + assert ModelConfig( + materialized=Materialization.SNAPSHOT, + unique_key=["id"], + updated_at="updated_at::timestamp", + strategy="timestamp", + dialect="redshift", + ).model_kind(context) == SCDType2ByTimeKind( + unique_key=["id"], + valid_from_name="dbt_valid_from", + valid_to_name="dbt_valid_to", + updated_at_as_valid_from=True, + updated_at_name="updated_at", + dialect="redshift", + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.ALLOW, + ) + def test_model_kind_snapshot_bigquery(): context = DbtContext()