Skip to content

Commit 78949d4

Browse files
pr feedback 2
1 parent aaeb854 commit 78949d4

File tree

3 files changed

+34
-21
lines changed

3 files changed

+34
-21
lines changed

sqlmesh/core/config/model.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,3 @@ def _audits_validator(cls, v: t.Any) -> t.Any:
7878
return [extract_func_call(parse_one(audit)) for audit in v]
7979

8080
return v
81-
82-
@field_validator("pre_statements", "post_statements", "on_virtual_update", mode="before")
83-
def _statements_validator(cls, v: t.Any) -> t.Any:
84-
if isinstance(v, list):
85-
return [parse_one(stmt) if isinstance(stmt, str) else stmt for stmt in v]
86-
return v

sqlmesh/core/model/definition.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2466,21 +2466,21 @@ def _create_model(
24662466

24672467
# Merge default pre_statements with model-specific pre_statements
24682468
if "pre_statements" in defaults:
2469-
kwargs["pre_statements"] = list(defaults["pre_statements"]) + list(
2470-
kwargs.get("pre_statements", [])
2471-
)
2469+
kwargs["pre_statements"] = [
2470+
exp.maybe_parse(stmt, dialect=dialect) for stmt in defaults["pre_statements"]
2471+
] + kwargs.get("pre_statements", [])
24722472

24732473
# Merge default post_statements with model-specific post_statements
24742474
if "post_statements" in defaults:
2475-
kwargs["post_statements"] = list(defaults["post_statements"]) + list(
2476-
kwargs.get("post_statements", [])
2477-
)
2475+
kwargs["post_statements"] = [
2476+
exp.maybe_parse(stmt, dialect=dialect) for stmt in defaults["post_statements"]
2477+
] + kwargs.get("post_statements", [])
24782478

24792479
# Merge default on_virtual_update with model-specific on_virtual_update
24802480
if "on_virtual_update" in defaults:
2481-
kwargs["on_virtual_update"] = list(defaults["on_virtual_update"]) + list(
2482-
kwargs.get("on_virtual_update", [])
2483-
)
2481+
kwargs["on_virtual_update"] = [
2482+
exp.maybe_parse(stmt, dialect=dialect) for stmt in defaults["on_virtual_update"]
2483+
] + kwargs.get("on_virtual_update", [])
24842484

24852485
if "pre_statements" in kwargs:
24862486
statements.extend(kwargs["pre_statements"])

tests/core/test_config.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -702,18 +702,37 @@ def test_load_model_defaults_statements(tmp_path):
702702

703703
assert config.model_defaults.pre_statements is not None
704704
assert len(config.model_defaults.pre_statements) == 2
705-
assert isinstance(config.model_defaults.pre_statements[0], exp.Set)
706-
assert isinstance(config.model_defaults.pre_statements[1], exp.Create)
705+
assert isinstance(exp.maybe_parse(config.model_defaults.pre_statements[0]), exp.Set)
706+
assert isinstance(exp.maybe_parse(config.model_defaults.pre_statements[1]), exp.Create)
707707

708708
assert config.model_defaults.post_statements is not None
709709
assert len(config.model_defaults.post_statements) == 3
710-
assert isinstance(config.model_defaults.post_statements[0], exp.Drop)
711-
assert isinstance(config.model_defaults.post_statements[1], exp.Analyze)
712-
assert isinstance(config.model_defaults.post_statements[2], exp.Set)
710+
assert isinstance(exp.maybe_parse(config.model_defaults.post_statements[0]), exp.Drop)
711+
assert isinstance(exp.maybe_parse(config.model_defaults.post_statements[1]), exp.Analyze)
712+
assert isinstance(exp.maybe_parse(config.model_defaults.post_statements[2]), exp.Set)
713713

714714
assert config.model_defaults.on_virtual_update is not None
715715
assert len(config.model_defaults.on_virtual_update) == 1
716-
assert isinstance(config.model_defaults.on_virtual_update[0], exp.Update)
716+
assert isinstance(exp.maybe_parse(config.model_defaults.on_virtual_update[0]), exp.Update)
717+
718+
719+
def test_load_model_defaults_validation_statements(tmp_path):
720+
config_path = tmp_path / "config_model_defaults_statements_wrong.yaml"
721+
with open(config_path, "w", encoding="utf-8") as fd:
722+
fd.write(
723+
"""
724+
model_defaults:
725+
dialect: duckdb
726+
pre_statements:
727+
- 313
728+
"""
729+
)
730+
731+
with pytest.raises(TypeError, match=r"expected str instance, int found"):
732+
config = load_config_from_paths(
733+
Config,
734+
project_paths=[config_path],
735+
)
717736

718737

719738
def test_scheduler_config(tmp_path_factory):

0 commit comments

Comments
 (0)