Skip to content

Commit 9a451d0

Browse files
Fix!: Add default audits in the model properly with their args
1 parent 49b5574 commit 9a451d0

File tree

6 files changed

+400
-22
lines changed

6 files changed

+400
-22
lines changed

sqlmesh/core/loader.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,6 @@ def _load_sql_models(
594594
macros=macros,
595595
jinja_macros=jinja_macros,
596596
audit_definitions=audits,
597-
default_audits=self.config.model_defaults.audits,
598597
module_path=self.config_path,
599598
dialect=self.config.model_defaults.dialect,
600599
time_column_format=self.config.time_column_format,

sqlmesh/core/model/definition.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
sorted_python_env_payloads,
3333
validate_extra_and_required_fields,
3434
)
35-
from sqlmesh.core.model.meta import ModelMeta, FunctionCall
35+
from sqlmesh.core.model.meta import ModelMeta
3636
from sqlmesh.core.model.kind import (
3737
ModelKindName,
3838
SeedKind,
@@ -2038,7 +2038,6 @@ def load_sql_based_model(
20382038
macros: t.Optional[MacroRegistry] = None,
20392039
jinja_macros: t.Optional[JinjaMacroRegistry] = None,
20402040
audits: t.Optional[t.Dict[str, ModelAudit]] = None,
2041-
default_audits: t.Optional[t.List[FunctionCall]] = None,
20422041
python_env: t.Optional[t.Dict[str, Executable]] = None,
20432042
dialect: t.Optional[str] = None,
20442043
physical_schema_mapping: t.Optional[t.Dict[re.Pattern, str]] = None,
@@ -2211,7 +2210,6 @@ def load_sql_based_model(
22112210
physical_schema_mapping=physical_schema_mapping,
22122211
default_catalog=default_catalog,
22132212
variables=variables,
2214-
default_audits=default_audits,
22152213
inline_audits=inline_audits,
22162214
blueprint_variables=blueprint_variables,
22172215
**meta_fields,
@@ -2431,7 +2429,6 @@ def _create_model(
24312429
physical_schema_mapping: t.Optional[t.Dict[re.Pattern, str]] = None,
24322430
python_env: t.Optional[t.Dict[str, Executable]] = None,
24332431
audit_definitions: t.Optional[t.Dict[str, ModelAudit]] = None,
2434-
default_audits: t.Optional[t.List[FunctionCall]] = None,
24352432
inline_audits: t.Optional[t.Dict[str, ModelAudit]] = None,
24362433
module_path: Path = Path(),
24372434
macros: t.Optional[MacroRegistry] = None,
@@ -2541,6 +2538,8 @@ def _create_model(
25412538
for jinja_macro in jinja_macros.root_macros.values():
25422539
used_variables.update(extract_macro_references_and_variables(jinja_macro.definition)[1])
25432540

2541+
default_audits = defaults.get("audits", None) if kwargs.get("audits") else None
2542+
25442543
model = klass(
25452544
name=name,
25462545
**{
@@ -2558,12 +2557,10 @@ def _create_model(
25582557
**(inline_audits or {}),
25592558
}
25602559

2561-
# TODO: default_audits needs to be merged with model.audits; the former's arguments
2562-
# are silently dropped today because we add them in audit_definitions. We also need
2563-
# to check for duplicates when we implement this merging logic.
2564-
used_audits: t.Set[str] = set()
2565-
used_audits.update(audit_name for audit_name, _ in default_audits or [])
2566-
used_audits.update(audit_name for audit_name, _ in model.audits)
2560+
if default_audits:
2561+
model = model.copy(update={"audits": default_audits + model.audits})
2562+
2563+
used_audits: t.Set[str] = {audit_name for audit_name, _ in model.audits}
25672564

25682565
audit_definitions = {
25692566
audit_name: audit_definitions[audit_name]
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Include the model defaults audits along with their args in the model."""
2+
3+
4+
def migrate(state_sync, **kwargs): # type: ignore
5+
pass

tests/core/test_audit.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from sqlglot import exp, parse_one
44

55
from sqlmesh.core import constants as c
6+
from sqlmesh.core.config.model import ModelDefaultsConfig
67
from sqlmesh.core.context import Context
78
from sqlmesh.core.audit import (
89
ModelAudit,
@@ -962,6 +963,117 @@ def test_multiple_audits_with_same_name():
962963
assert model.audits[1][1] == model.audits[2][1]
963964

964965

966+
def test_default_audits_included_when_no_model_audits():
967+
expressions = parse("""
968+
MODEL (
969+
name test.basic_model
970+
);
971+
SELECT 1 as id, 'test' as name;
972+
""")
973+
974+
model_defaults = ModelDefaultsConfig(
975+
dialect="duckdb", audits=["not_null(columns := ['id'])", "unique_values(columns := ['id'])"]
976+
)
977+
model = load_sql_based_model(expressions, defaults=model_defaults.dict())
978+
979+
assert len(model.audits) == 2
980+
audit_names = [audit[0] for audit in model.audits]
981+
assert "not_null" in audit_names
982+
assert "unique_values" in audit_names
983+
984+
# Verify arguments are preserved
985+
for audit_name, audit_args in model.audits:
986+
if audit_name == "not_null":
987+
assert "columns" in audit_args
988+
assert audit_args["columns"].expressions[0].this == "id"
989+
elif audit_name == "unique_values":
990+
assert "columns" in audit_args
991+
assert audit_args["columns"].expressions[0].this == "id"
992+
993+
for audit_name, audit_args in model.audits_with_args:
994+
if audit_name == "not_null":
995+
assert "columns" in audit_args
996+
assert audit_args["columns"].expressions[0].this == "id"
997+
elif audit_name == "unique_values":
998+
assert "columns" in audit_args
999+
assert audit_args["columns"].expressions[0].this == "id"
1000+
1001+
1002+
def test_model_defaults_audits_with_same_name():
1003+
expressions = parse(
1004+
"""
1005+
MODEL (
1006+
name db.table,
1007+
dialect spark,
1008+
audits(
1009+
does_not_exceed_threshold(column := id, threshold := 1000),
1010+
does_not_exceed_threshold(column := price, threshold := 100),
1011+
unique_values(columns := ['id'])
1012+
)
1013+
);
1014+
1015+
SELECT id, price FROM tbl;
1016+
1017+
AUDIT (
1018+
name does_not_exceed_threshold,
1019+
);
1020+
SELECT * FROM @this_model
1021+
WHERE @column >= @threshold;
1022+
"""
1023+
)
1024+
1025+
model_defaults = ModelDefaultsConfig(
1026+
dialect="duckdb",
1027+
audits=[
1028+
"does_not_exceed_threshold(column := price, threshold := 33)",
1029+
"does_not_exceed_threshold(column := id, threshold := 65)",
1030+
"not_null(columns := ['id'])",
1031+
],
1032+
)
1033+
model = load_sql_based_model(expressions, defaults=model_defaults.dict())
1034+
assert len(model.audits) == 6
1035+
assert len(model.audits_with_args) == 6
1036+
assert len(model.audit_definitions) == 1
1037+
1038+
expected_audits = [
1039+
(
1040+
"does_not_exceed_threshold",
1041+
{"column": exp.column("price"), "threshold": exp.Literal.number(33)},
1042+
),
1043+
(
1044+
"does_not_exceed_threshold",
1045+
{"column": exp.column("id"), "threshold": exp.Literal.number(65)},
1046+
),
1047+
("not_null", {"columns": exp.convert(["id"])}),
1048+
(
1049+
"does_not_exceed_threshold",
1050+
{"column": exp.column("id"), "threshold": exp.Literal.number(1000)},
1051+
),
1052+
(
1053+
"does_not_exceed_threshold",
1054+
{"column": exp.column("price"), "threshold": exp.Literal.number(100)},
1055+
),
1056+
("unique_values", {"columns": exp.convert(["id"])}),
1057+
]
1058+
1059+
for (actual_name, actual_args), (expected_name, expected_args) in zip(
1060+
model.audits, expected_audits
1061+
):
1062+
# Validate the audit names are preserved
1063+
assert actual_name == expected_name
1064+
for key in expected_args:
1065+
# comparing the SQL representation is easier
1066+
assert actual_args[key].sql() == expected_args[key].sql()
1067+
1068+
# Validate audits_with_args as well, including their arguments
1069+
for (actual_audit, actual_args), (expected_name, expected_args) in zip(
1070+
model.audits_with_args, expected_audits
1071+
):
1072+
assert actual_audit.name == expected_name
1073+
for key in expected_args:
1074+
assert actual_args[key].sql() == expected_args[key].sql()
1075+
1076+
9651077
def test_audit_formatting_flag_serde():
9661078
expressions = parse(
9671079
"""

0 commit comments

Comments
 (0)