Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sqlmesh/core/test/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,7 +799,7 @@ def _execute_model(self) -> pd.DataFrame:
with self._concurrent_render_context():
variables = self.body.get("vars", {}).copy()
time_kwargs = {key: variables.pop(key) for key in TIME_KWARG_KEYS if key in variables}
df = next(self.model.render(context=self.context, **time_kwargs, **variables))
df = next(self.model.render(context=self.context, variables=variables, **time_kwargs))

assert not isinstance(df, exp.Expression)
return df if isinstance(df, pd.DataFrame) else df.toPandas()
Expand Down
100 changes: 100 additions & 0 deletions tests/core/test_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3227,3 +3227,103 @@ def execute(
test.context.resolve_table("silver.sushi.bar")

_check_successful_or_raise(test.run())


def test_python_model_test_variables_override(tmp_path: Path) -> None:
py_model = tmp_path / "models" / "test_var_model.py"
py_model.parent.mkdir(parents=True, exist_ok=True)
py_model.write_text(
"""
import pandas as pd # noqa: TID253
from sqlmesh import model, ExecutionContext
import typing as t

@model(
name="test_var_model",
columns={"id": "int", "flag_value": "boolean", "var_value": "varchar"},
)
def execute(context: ExecutionContext, **kwargs: t.Any) -> pd.DataFrame:
my_flag = context.var("my_flag")
other_var = context.var("other_var")

return pd.DataFrame([{
"id": 1 if my_flag else 2,
"flag_value": my_flag,
"var_value": other_var,
}])"""
)

config = Config(
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
variables={"my_flag": False, "other_var": "default_value"},
)
context = Context(config=config, paths=tmp_path)

python_model = context.models['"test_var_model"']

# Test when Flag is True
# Overriding the config default flag_value to True
# AND the var_value to use test one
test_flag_true = _create_test(
body=load_yaml("""
test_flag_true:
model: test_var_model
vars:
my_flag: true
other_var: "test_value"
outputs:
query:
rows:
- id: 1
flag_value: true
var_value: "test_value"
"""),
test_name="test_flag_true",
model=python_model,
context=context,
)

_check_successful_or_raise(test_flag_true.run())

# Test when Flag is False
# Overriding the config default flag_value to False
# AND the var_value to use test one (since the above would be false for both)
test_flag_false = _create_test(
body=load_yaml("""
test_flag_false:
model: test_var_model
vars:
my_flag: false
other_var: "another_test_value"
outputs:
query:
rows:
- id: 2
flag_value: false
var_value: "another_test_value"
"""),
test_name="test_flag_false",
model=python_model,
context=context,
)

_check_successful_or_raise(test_flag_false.run())

# Test with no vars specified
# (should use config defaults for both flag and var_value)
test_default_vars = _create_test(
body=load_yaml("""
test_default_vars:
model: test_var_model
outputs:
query:
rows:
- id: 2
flag_value: false
var_value: "default_value"
"""),
test_name="test_default_vars",
model=python_model,
context=context,
)
_check_successful_or_raise(test_default_vars.run())