diff --git a/sqlmesh/dbt/builtin.py b/sqlmesh/dbt/builtin.py index 4edfea687a..24669807bb 100644 --- a/sqlmesh/dbt/builtin.py +++ b/sqlmesh/dbt/builtin.py @@ -164,6 +164,58 @@ def has_var(self, name: str) -> bool: return name in self.variables +class Config: + def __init__(self, config_dict: t.Dict[str, t.Any]) -> None: + self._config = config_dict + + def __call__(self, **kwargs: t.Any) -> str: + self._config.update(**kwargs) + return "" + + def set(self, name: str, value: t.Any) -> str: + self._config.update({name: value}) + return "" + + def _validate(self, name: str, validator: t.Callable, value: t.Optional[t.Any] = None) -> None: + try: + validator(value) + except Exception as e: + raise ConfigError(f"Config validation failed for '{name}': {e}") + + def require(self, name: str, validator: t.Optional[t.Callable] = None) -> t.Any: + if name not in self._config: + raise ConfigError(f"Missing required config: {name}") + + value = self._config[name] + + if validator is not None: + self._validate(name, validator, value) + + return value + + def get( + self, name: str, default: t.Any = None, validator: t.Optional[t.Callable] = None + ) -> t.Any: + value = self._config.get(name, default) + + if validator is not None and value is not None: + self._validate(name, validator, value) + + return value + + def persist_relation_docs(self) -> bool: + persist_docs = self.get("persist_docs", default={}) + if not isinstance(persist_docs, dict): + return False + return persist_docs.get("relation", False) + + def persist_column_docs(self) -> bool: + persist_docs = self.get("persist_docs", default={}) + if not isinstance(persist_docs, dict): + return False + return persist_docs.get("columns", False) + + def env_var(name: str, default: t.Optional[str] = None) -> t.Optional[str]: if name not in os.environ and default is None: raise ConfigError(f"Missing environment variable '{name}'") @@ -395,6 +447,8 @@ def create_builtin_globals( if variables is not None: builtin_globals["var"] = Var(variables) + builtin_globals["config"] = Config(jinja_globals.pop("config", {})) + deployability_index = ( jinja_globals.get("deployability_index") or DeployabilityIndex.all_deployable() ) diff --git a/tests/dbt/test_transformation.py b/tests/dbt/test_transformation.py index 779160c27d..e81fcbe862 100644 --- a/tests/dbt/test_transformation.py +++ b/tests/dbt/test_transformation.py @@ -946,7 +946,7 @@ def test_schema_jinja(sushi_test_project: Project, assert_exp_eq): @pytest.mark.xdist_group("dbt_manifest") def test_config_jinja(sushi_test_project: Project): - hook = "{{ config(alias='bar') }} {{ config.alias }}" + hook = "{{ config(alias='bar') }} {{ config.get('alias') }}" model_config = ModelConfig( name="model", package_name="package", @@ -961,6 +961,120 @@ def test_config_jinja(sushi_test_project: Project): assert model.render_pre_statements()[0].sql() == '"bar"' +@pytest.mark.xdist_group("dbt_manifest") +def test_config_jinja_get_methods(sushi_test_project: Project): + model_config = ModelConfig( + name="model_conf", + package_name="package", + schema="sushi", + sql="""SELECT 1 AS one FROM foo""", + alias="model_alias", + **{ + "pre-hook": [ + "{{ config(materialized='incremental', unique_key='id') }}" + "{{ config.get('missed', 'a') + config.get('missed', default='b')}}", + "{{ config.set('alias', 'new_alias')}}", + "{{ config.get('package_name') + '_' + config.require('unique_key')}}", + "{{ config.get('alias') or 'default'}}", + ] + }, + **{"post-hook": "{{config.require('missing_key')}}"}, + ) + context = sushi_test_project.context + model = t.cast(SqlModel, model_config.to_sqlmesh(context)) + + assert model.render_pre_statements()[0].sql() == '"ab"' + assert model.render_pre_statements()[1].sql() == '"package_id"' + assert model.render_pre_statements()[2].sql() == '"new_alias"' + + with pytest.raises(ConfigError, match="Missing required config: missing_key"): + model.render_post_statements() + + # test get methods with operations + model_2_config = ModelConfig( + name="model_2", + package_name="package", + schema="sushi", + sql="""SELECT 1 AS one FROM foo""", + alias="mod", + materialized="table", + threads=8, + partition_by="date", + cluster_by=["user_id", "product_id"], + **{ + "pre-hook": [ + "{{ config.get('partition_by', default='none') }}", + "{{ config.get('cluster_by', default=[]) | length }}", + "{% if config.get('threads') > 4 %}high_threads{% else %}low_threads{% endif %}", + ] + }, + ) + model2 = t.cast(SqlModel, model_2_config.to_sqlmesh(context)) + + pre_statements2 = model2.render_pre_statements() + assert pre_statements2[0].sql() == "ARRAY('date')" + assert pre_statements2[1].sql() == "2" + assert pre_statements2[2].sql() == '"high_threads"' + + # test seting variable and conditional + model_invalid_timeout = ModelConfig( + name="invalid_timeout_test", + package_name="package", + schema="sushi", + sql="""SELECT 1 AS one FROM foo""", + alias="invalid_timeout_alias", + connection_timeout=44, + **{ + "pre-hook": [ + """ + {%- set value = config.require('connection_timeout') -%} + {%- set is_valid = value >= 10 and value <= 30 -%} + {%- if not is_valid -%} + {{ exceptions.raise_compiler_error("Validation failed for 'connection_timeout': Value must be between 10 and 30, got: " ~ value) }} + {%- endif -%} + {{ value }} + """, + ] + }, + ) + + model_invalid = t.cast(SqlModel, model_invalid_timeout.to_sqlmesh(context)) + with pytest.raises( + ConfigError, + match="Validation failed for 'connection_timeout': Value must be between 10 and 30, got: 44", + ): + model_invalid.render_pre_statements() + + # test persist_docs methods + model_config_persist = ModelConfig( + name="persist_docs_model", + package_name="package", + schema="sushi", + sql="""SELECT 1 AS one FROM foo""", + alias="persist_alias", + **{ + "pre-hook": [ + "{{ config(persist_docs={'relation': true, 'columns': true}) }}", + "{{ config.persist_relation_docs() }}", + "{{ config.persist_column_docs() }}", + "{{ config(persist_docs={'relation': false, 'columns': true}) }}", + "{{ config.persist_relation_docs() }}", + "{{ config.persist_column_docs() }}", + ] + }, + ) + model3 = t.cast(SqlModel, model_config_persist.to_sqlmesh(context)) + + pre_statements3 = model3.render_pre_statements() + + # it should filter out empty returns, so we get 4 statements + assert len(pre_statements3) == 4 + assert pre_statements3[0].sql() == "TRUE" + assert pre_statements3[1].sql() == "TRUE" + assert pre_statements3[2].sql() == "FALSE" + assert pre_statements3[3].sql() == "TRUE" + + @pytest.mark.xdist_group("dbt_manifest") def test_model_this(assert_exp_eq, sushi_test_project: Project): model_config = ModelConfig(