Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions sqlmesh/dbt/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,58 @@ def has_var(self, name: str) -> bool:
return name in self.variables


class Config:
def __init__(self, config_dict: t.Dict[str, t.Any]) -> None:
self._config = config_dict

def __call__(self, **kwargs: t.Any) -> str:
self._config.update(**kwargs)
return ""

def set(self, name: str, value: t.Any) -> str:
self._config.update({name: value})
return ""

def _validate(self, name: str, validator: t.Callable, value: t.Optional[t.Any] = None) -> None:
try:
validator(value)
except Exception as e:
raise ConfigError(f"Config validation failed for '{name}': {e}")

def require(self, name: str, validator: t.Optional[t.Callable] = None) -> t.Any:
if name not in self._config:
raise ConfigError(f"Missing required config: {name}")

value = self._config[name]

if validator is not None:
self._validate(name, validator, value)

return value

def get(
self, name: str, default: t.Any = None, validator: t.Optional[t.Callable] = None
) -> t.Any:
value = self._config.get(name, default)

if validator is not None and value is not None:
self._validate(name, validator, value)

return value

def persist_relation_docs(self) -> bool:
persist_docs = self.get("persist_docs", default={})
if not isinstance(persist_docs, dict):
return False
return persist_docs.get("relation", False)

def persist_column_docs(self) -> bool:
persist_docs = self.get("persist_docs", default={})
if not isinstance(persist_docs, dict):
return False
return persist_docs.get("columns", False)


def env_var(name: str, default: t.Optional[str] = None) -> t.Optional[str]:
if name not in os.environ and default is None:
raise ConfigError(f"Missing environment variable '{name}'")
Expand Down Expand Up @@ -395,6 +447,8 @@ def create_builtin_globals(
if variables is not None:
builtin_globals["var"] = Var(variables)

builtin_globals["config"] = Config(jinja_globals.pop("config", {}))

deployability_index = (
jinja_globals.get("deployability_index") or DeployabilityIndex.all_deployable()
)
Expand Down
116 changes: 115 additions & 1 deletion tests/dbt/test_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,7 +946,7 @@ def test_schema_jinja(sushi_test_project: Project, assert_exp_eq):

@pytest.mark.xdist_group("dbt_manifest")
def test_config_jinja(sushi_test_project: Project):
hook = "{{ config(alias='bar') }} {{ config.alias }}"
hook = "{{ config(alias='bar') }} {{ config.get('alias') }}"
model_config = ModelConfig(
name="model",
package_name="package",
Expand All @@ -961,6 +961,120 @@ def test_config_jinja(sushi_test_project: Project):
assert model.render_pre_statements()[0].sql() == '"bar"'


@pytest.mark.xdist_group("dbt_manifest")
def test_config_jinja_get_methods(sushi_test_project: Project):
model_config = ModelConfig(
name="model_conf",
package_name="package",
schema="sushi",
sql="""SELECT 1 AS one FROM foo""",
alias="model_alias",
**{
"pre-hook": [
"{{ config(materialized='incremental', unique_key='id') }}"
"{{ config.get('missed', 'a') + config.get('missed', default='b')}}",
"{{ config.set('alias', 'new_alias')}}",
"{{ config.get('package_name') + '_' + config.require('unique_key')}}",
"{{ config.get('alias') or 'default'}}",
]
},
**{"post-hook": "{{config.require('missing_key')}}"},
)
context = sushi_test_project.context
model = t.cast(SqlModel, model_config.to_sqlmesh(context))

assert model.render_pre_statements()[0].sql() == '"ab"'
assert model.render_pre_statements()[1].sql() == '"package_id"'
assert model.render_pre_statements()[2].sql() == '"new_alias"'

with pytest.raises(ConfigError, match="Missing required config: missing_key"):
model.render_post_statements()

# test get methods with operations
model_2_config = ModelConfig(
name="model_2",
package_name="package",
schema="sushi",
sql="""SELECT 1 AS one FROM foo""",
alias="mod",
materialized="table",
threads=8,
partition_by="date",
cluster_by=["user_id", "product_id"],
**{
"pre-hook": [
"{{ config.get('partition_by', default='none') }}",
"{{ config.get('cluster_by', default=[]) | length }}",
"{% if config.get('threads') > 4 %}high_threads{% else %}low_threads{% endif %}",
]
},
)
model2 = t.cast(SqlModel, model_2_config.to_sqlmesh(context))

pre_statements2 = model2.render_pre_statements()
assert pre_statements2[0].sql() == "ARRAY('date')"
assert pre_statements2[1].sql() == "2"
assert pre_statements2[2].sql() == '"high_threads"'

# test seting variable and conditional
model_invalid_timeout = ModelConfig(
name="invalid_timeout_test",
package_name="package",
schema="sushi",
sql="""SELECT 1 AS one FROM foo""",
alias="invalid_timeout_alias",
connection_timeout=44,
**{
"pre-hook": [
"""
{%- set value = config.require('connection_timeout') -%}
{%- set is_valid = value >= 10 and value <= 30 -%}
{%- if not is_valid -%}
{{ exceptions.raise_compiler_error("Validation failed for 'connection_timeout': Value must be between 10 and 30, got: " ~ value) }}
{%- endif -%}
{{ value }}
""",
]
},
)

model_invalid = t.cast(SqlModel, model_invalid_timeout.to_sqlmesh(context))
with pytest.raises(
ConfigError,
match="Validation failed for 'connection_timeout': Value must be between 10 and 30, got: 44",
):
model_invalid.render_pre_statements()

# test persist_docs methods
model_config_persist = ModelConfig(
name="persist_docs_model",
package_name="package",
schema="sushi",
sql="""SELECT 1 AS one FROM foo""",
alias="persist_alias",
**{
"pre-hook": [
"{{ config(persist_docs={'relation': true, 'columns': true}) }}",
"{{ config.persist_relation_docs() }}",
"{{ config.persist_column_docs() }}",
"{{ config(persist_docs={'relation': false, 'columns': true}) }}",
"{{ config.persist_relation_docs() }}",
"{{ config.persist_column_docs() }}",
]
},
)
model3 = t.cast(SqlModel, model_config_persist.to_sqlmesh(context))

pre_statements3 = model3.render_pre_statements()

# it should filter out empty returns, so we get 4 statements
assert len(pre_statements3) == 4
assert pre_statements3[0].sql() == "TRUE"
assert pre_statements3[1].sql() == "TRUE"
assert pre_statements3[2].sql() == "FALSE"
assert pre_statements3[3].sql() == "TRUE"


@pytest.mark.xdist_group("dbt_manifest")
def test_model_this(assert_exp_eq, sushi_test_project: Project):
model_config = ModelConfig(
Expand Down