Skip to content

Commit 55f59ce

Browse files
pr feedback 2
1 parent 4497de2 commit 55f59ce

File tree

3 files changed

+34
-21
lines changed

3 files changed

+34
-21
lines changed

sqlmesh/core/config/model.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,3 @@ def _audits_validator(cls, v: t.Any) -> t.Any:
7878
return [extract_func_call(parse_one(audit)) for audit in v]
7979

8080
return v
81-
82-
@field_validator("pre_statements", "post_statements", "on_virtual_update", mode="before")
83-
def _statements_validator(cls, v: t.Any) -> t.Any:
84-
if isinstance(v, list):
85-
return [parse_one(stmt) if isinstance(stmt, str) else stmt for stmt in v]
86-
return v

sqlmesh/core/model/definition.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2474,21 +2474,21 @@ def _create_model(
24742474

24752475
# Merge default pre_statements with model-specific pre_statements
24762476
if "pre_statements" in defaults:
2477-
kwargs["pre_statements"] = list(defaults["pre_statements"]) + list(
2478-
kwargs.get("pre_statements", [])
2479-
)
2477+
kwargs["pre_statements"] = [
2478+
exp.maybe_parse(stmt, dialect=dialect) for stmt in defaults["pre_statements"]
2479+
] + kwargs.get("pre_statements", [])
24802480

24812481
# Merge default post_statements with model-specific post_statements
24822482
if "post_statements" in defaults:
2483-
kwargs["post_statements"] = list(defaults["post_statements"]) + list(
2484-
kwargs.get("post_statements", [])
2485-
)
2483+
kwargs["post_statements"] = [
2484+
exp.maybe_parse(stmt, dialect=dialect) for stmt in defaults["post_statements"]
2485+
] + kwargs.get("post_statements", [])
24862486

24872487
# Merge default on_virtual_update with model-specific on_virtual_update
24882488
if "on_virtual_update" in defaults:
2489-
kwargs["on_virtual_update"] = list(defaults["on_virtual_update"]) + list(
2490-
kwargs.get("on_virtual_update", [])
2491-
)
2489+
kwargs["on_virtual_update"] = [
2490+
exp.maybe_parse(stmt, dialect=dialect) for stmt in defaults["on_virtual_update"]
2491+
] + kwargs.get("on_virtual_update", [])
24922492

24932493
if "pre_statements" in kwargs:
24942494
statements.extend(kwargs["pre_statements"])

tests/core/test_config.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -702,18 +702,37 @@ def test_load_model_defaults_statements(tmp_path):
702702

703703
assert config.model_defaults.pre_statements is not None
704704
assert len(config.model_defaults.pre_statements) == 2
705-
assert isinstance(config.model_defaults.pre_statements[0], exp.Set)
706-
assert isinstance(config.model_defaults.pre_statements[1], exp.Create)
705+
assert isinstance(exp.maybe_parse(config.model_defaults.pre_statements[0]), exp.Set)
706+
assert isinstance(exp.maybe_parse(config.model_defaults.pre_statements[1]), exp.Create)
707707

708708
assert config.model_defaults.post_statements is not None
709709
assert len(config.model_defaults.post_statements) == 3
710-
assert isinstance(config.model_defaults.post_statements[0], exp.Drop)
711-
assert isinstance(config.model_defaults.post_statements[1], exp.Analyze)
712-
assert isinstance(config.model_defaults.post_statements[2], exp.Set)
710+
assert isinstance(exp.maybe_parse(config.model_defaults.post_statements[0]), exp.Drop)
711+
assert isinstance(exp.maybe_parse(config.model_defaults.post_statements[1]), exp.Analyze)
712+
assert isinstance(exp.maybe_parse(config.model_defaults.post_statements[2]), exp.Set)
713713

714714
assert config.model_defaults.on_virtual_update is not None
715715
assert len(config.model_defaults.on_virtual_update) == 1
716-
assert isinstance(config.model_defaults.on_virtual_update[0], exp.Update)
716+
assert isinstance(exp.maybe_parse(config.model_defaults.on_virtual_update[0]), exp.Update)
717+
718+
719+
def test_load_model_defaults_validation_statements(tmp_path):
720+
config_path = tmp_path / "config_model_defaults_statements_wrong.yaml"
721+
with open(config_path, "w", encoding="utf-8") as fd:
722+
fd.write(
723+
"""
724+
model_defaults:
725+
dialect: duckdb
726+
pre_statements:
727+
- 313
728+
"""
729+
)
730+
731+
with pytest.raises(TypeError, match=r"expected str instance, int found"):
732+
config = load_config_from_paths(
733+
Config,
734+
project_paths=[config_path],
735+
)
717736

718737

719738
def test_scheduler_config(tmp_path_factory):

0 commit comments

Comments
 (0)