Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions sqlmesh/core/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -1408,6 +1408,41 @@ def extract_func_call(
return func.lower(), kwargs


def extract_function_calls(func_calls: t.Any, allow_tuples: bool = False) -> t.Any:
"""Used for extracting function calls for signals or audits."""

if isinstance(func_calls, (exp.Tuple, exp.Array)):
return [extract_func_call(i, allow_tuples=allow_tuples) for i in func_calls.expressions]
if isinstance(func_calls, exp.Paren):
return [extract_func_call(func_calls.this, allow_tuples=allow_tuples)]
if isinstance(func_calls, exp.Expression):
return [extract_func_call(func_calls, allow_tuples=allow_tuples)]
if isinstance(func_calls, list):
function_calls = []
for entry in func_calls:
if isinstance(entry, dict):
args = entry
name = "" if allow_tuples else entry.pop("name")
elif isinstance(entry, (tuple, list)):
name, args = entry
else:
raise ConfigError(f"Audit must be a dictionary or named tuple. Got {entry}.")

function_calls.append(
(
name.lower(),
{
key: parse_one(value) if isinstance(value, str) else value
for key, value in args.items()
},
)
)

return function_calls

return func_calls or []


def is_meta_expression(v: t.Any) -> bool:
return isinstance(v, (Audit, Metric, Model))

Expand Down
1 change: 0 additions & 1 deletion sqlmesh/core/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,6 @@ def _load_sql_models(
macros=macros,
jinja_macros=jinja_macros,
audit_definitions=audits,
default_audits=self.config.model_defaults.audits,
module_path=self.config_path,
dialect=self.config.model_defaults.dialect,
time_column_format=self.config.time_column_format,
Expand Down
16 changes: 6 additions & 10 deletions sqlmesh/core/model/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
sorted_python_env_payloads,
validate_extra_and_required_fields,
)
from sqlmesh.core.model.meta import ModelMeta, FunctionCall
from sqlmesh.core.model.meta import ModelMeta
from sqlmesh.core.model.kind import (
ModelKindName,
SeedKind,
Expand Down Expand Up @@ -2038,7 +2038,6 @@ def load_sql_based_model(
macros: t.Optional[MacroRegistry] = None,
jinja_macros: t.Optional[JinjaMacroRegistry] = None,
audits: t.Optional[t.Dict[str, ModelAudit]] = None,
default_audits: t.Optional[t.List[FunctionCall]] = None,
python_env: t.Optional[t.Dict[str, Executable]] = None,
dialect: t.Optional[str] = None,
physical_schema_mapping: t.Optional[t.Dict[re.Pattern, str]] = None,
Expand Down Expand Up @@ -2211,7 +2210,6 @@ def load_sql_based_model(
physical_schema_mapping=physical_schema_mapping,
default_catalog=default_catalog,
variables=variables,
default_audits=default_audits,
inline_audits=inline_audits,
blueprint_variables=blueprint_variables,
**meta_fields,
Expand Down Expand Up @@ -2431,7 +2429,6 @@ def _create_model(
physical_schema_mapping: t.Optional[t.Dict[re.Pattern, str]] = None,
python_env: t.Optional[t.Dict[str, Executable]] = None,
audit_definitions: t.Optional[t.Dict[str, ModelAudit]] = None,
default_audits: t.Optional[t.List[FunctionCall]] = None,
inline_audits: t.Optional[t.Dict[str, ModelAudit]] = None,
module_path: Path = Path(),
macros: t.Optional[MacroRegistry] = None,
Expand Down Expand Up @@ -2541,6 +2538,10 @@ def _create_model(
for jinja_macro in jinja_macros.root_macros.values():
used_variables.update(extract_macro_references_and_variables(jinja_macro.definition)[1])

# Merge model-specific audits with default audits
if default_audits := defaults.pop("audits", None):
kwargs["audits"] = default_audits + d.extract_function_calls(kwargs.pop("audits", []))

model = klass(
name=name,
**{
Expand All @@ -2558,12 +2559,7 @@ def _create_model(
**(inline_audits or {}),
}

# TODO: default_audits needs to be merged with model.audits; the former's arguments
# are silently dropped today because we add them in audit_definitions. We also need
# to check for duplicates when we implement this merging logic.
used_audits: t.Set[str] = set()
used_audits.update(audit_name for audit_name, _ in default_audits or [])
used_audits.update(audit_name for audit_name, _ in model.audits)
used_audits: t.Set[str] = {audit_name for audit_name, _ in model.audits}

audit_definitions = {
audit_name: audit_definitions[audit_name]
Expand Down
34 changes: 2 additions & 32 deletions sqlmesh/core/model/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from sqlmesh.core import dialect as d
from sqlmesh.core.config.linter import LinterConfig
from sqlmesh.core.dialect import normalize_model_name, extract_func_call
from sqlmesh.core.dialect import normalize_model_name
from sqlmesh.core.model.common import (
bool_validator,
default_catalog_validator,
Expand Down Expand Up @@ -94,37 +94,7 @@ class ModelMeta(_Node):
def _func_call_validator(cls, v: t.Any, field: t.Any) -> t.Any:
is_signal = getattr(field, "name" if hasattr(field, "name") else "field_name") == "signals"

if isinstance(v, (exp.Tuple, exp.Array)):
return [extract_func_call(i, allow_tuples=is_signal) for i in v.expressions]
if isinstance(v, exp.Paren):
return [extract_func_call(v.this, allow_tuples=is_signal)]
if isinstance(v, exp.Expression):
return [extract_func_call(v, allow_tuples=is_signal)]
if isinstance(v, list):
audits = []

for entry in v:
if isinstance(entry, dict):
args = entry
name = "" if is_signal else entry.pop("name")
elif isinstance(entry, (tuple, list)):
name, args = entry
else:
raise ConfigError(f"Audit must be a dictionary or named tuple. Got {entry}.")

audits.append(
(
name.lower(),
{
key: d.parse_one(value) if isinstance(value, str) else value
for key, value in args.items()
},
)
)

return audits

return v or []
return d.extract_function_calls(v, allow_tuples=is_signal)

@field_validator("tags", mode="before")
def _value_or_tuple_validator(cls, v: t.Any, info: ValidationInfo) -> t.Any:
Expand Down
112 changes: 112 additions & 0 deletions tests/core/test_audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from sqlglot import exp, parse_one

from sqlmesh.core import constants as c
from sqlmesh.core.config.model import ModelDefaultsConfig
from sqlmesh.core.context import Context
from sqlmesh.core.audit import (
ModelAudit,
Expand Down Expand Up @@ -962,6 +963,117 @@ def test_multiple_audits_with_same_name():
assert model.audits[1][1] == model.audits[2][1]


def test_default_audits_included_when_no_model_audits():
expressions = parse("""
MODEL (
name test.basic_model
);
SELECT 1 as id, 'test' as name;
""")

model_defaults = ModelDefaultsConfig(
dialect="duckdb", audits=["not_null(columns := ['id'])", "unique_values(columns := ['id'])"]
)
model = load_sql_based_model(expressions, defaults=model_defaults.dict())

assert len(model.audits) == 2
audit_names = [audit[0] for audit in model.audits]
assert "not_null" in audit_names
assert "unique_values" in audit_names

# Verify arguments are preserved
for audit_name, audit_args in model.audits:
if audit_name == "not_null":
assert "columns" in audit_args
assert audit_args["columns"].expressions[0].this == "id"
elif audit_name == "unique_values":
assert "columns" in audit_args
assert audit_args["columns"].expressions[0].this == "id"

for audit_name, audit_args in model.audits_with_args:
if audit_name == "not_null":
assert "columns" in audit_args
assert audit_args["columns"].expressions[0].this == "id"
elif audit_name == "unique_values":
assert "columns" in audit_args
assert audit_args["columns"].expressions[0].this == "id"


def test_model_defaults_audits_with_same_name():
expressions = parse(
"""
MODEL (
name db.table,
dialect spark,
audits(
does_not_exceed_threshold(column := id, threshold := 1000),
does_not_exceed_threshold(column := price, threshold := 100),
unique_values(columns := ['id'])
)
);

SELECT id, price FROM tbl;

AUDIT (
name does_not_exceed_threshold,
);
SELECT * FROM @this_model
WHERE @column >= @threshold;
"""
)

model_defaults = ModelDefaultsConfig(
dialect="duckdb",
audits=[
"does_not_exceed_threshold(column := price, threshold := 33)",
"does_not_exceed_threshold(column := id, threshold := 65)",
"not_null(columns := ['id'])",
],
)
model = load_sql_based_model(expressions, defaults=model_defaults.dict())
assert len(model.audits) == 6
assert len(model.audits_with_args) == 6
assert len(model.audit_definitions) == 1

expected_audits = [
(
"does_not_exceed_threshold",
{"column": exp.column("price"), "threshold": exp.Literal.number(33)},
),
(
"does_not_exceed_threshold",
{"column": exp.column("id"), "threshold": exp.Literal.number(65)},
),
("not_null", {"columns": exp.convert(["id"])}),
(
"does_not_exceed_threshold",
{"column": exp.column("id"), "threshold": exp.Literal.number(1000)},
),
(
"does_not_exceed_threshold",
{"column": exp.column("price"), "threshold": exp.Literal.number(100)},
),
("unique_values", {"columns": exp.convert(["id"])}),
]

for (actual_name, actual_args), (expected_name, expected_args) in zip(
model.audits, expected_audits
):
# Validate the audit names are preserved
assert actual_name == expected_name
for key in expected_args:
# comparing sql representaion is easier
assert actual_args[key].sql() == expected_args[key].sql()

# Validate audits with args as well along with their arguments
for (actual_audit, actual_args), (expected_name, expected_args) in zip(
model.audits_with_args, expected_audits
):
assert actual_audit.name == expected_name
for key in expected_args:
assert actual_args[key].sql() == expected_args[key].sql()


def test_audit_formatting_flag_serde():
expressions = parse(
"""
Expand Down
Loading