diff --git a/sqlmesh/core/dialect.py b/sqlmesh/core/dialect.py index 37568f2b27..568d9f5f73 100644 --- a/sqlmesh/core/dialect.py +++ b/sqlmesh/core/dialect.py @@ -1408,6 +1408,41 @@ def extract_func_call( return func.lower(), kwargs +def extract_function_calls(func_calls: t.Any, allow_tuples: bool = False) -> t.Any: + """Used for extracting function calls for signals or audits.""" + + if isinstance(func_calls, (exp.Tuple, exp.Array)): + return [extract_func_call(i, allow_tuples=allow_tuples) for i in func_calls.expressions] + if isinstance(func_calls, exp.Paren): + return [extract_func_call(func_calls.this, allow_tuples=allow_tuples)] + if isinstance(func_calls, exp.Expression): + return [extract_func_call(func_calls, allow_tuples=allow_tuples)] + if isinstance(func_calls, list): + function_calls = [] + for entry in func_calls: + if isinstance(entry, dict): + args = entry + name = "" if allow_tuples else entry.pop("name") + elif isinstance(entry, (tuple, list)): + name, args = entry + else: + raise ConfigError(f"Audit must be a dictionary or named tuple. 
Got {entry}.") + + function_calls.append( + ( + name.lower(), + { + key: parse_one(value) if isinstance(value, str) else value + for key, value in args.items() + }, + ) + ) + + return function_calls + + return func_calls or [] + + def is_meta_expression(v: t.Any) -> bool: return isinstance(v, (Audit, Metric, Model)) diff --git a/sqlmesh/core/loader.py b/sqlmesh/core/loader.py index f4c9147d1b..b593da1ad0 100644 --- a/sqlmesh/core/loader.py +++ b/sqlmesh/core/loader.py @@ -594,7 +594,6 @@ def _load_sql_models( macros=macros, jinja_macros=jinja_macros, audit_definitions=audits, - default_audits=self.config.model_defaults.audits, module_path=self.config_path, dialect=self.config.model_defaults.dialect, time_column_format=self.config.time_column_format, diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index 4485875df8..954b03eff8 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -32,7 +32,7 @@ sorted_python_env_payloads, validate_extra_and_required_fields, ) -from sqlmesh.core.model.meta import ModelMeta, FunctionCall +from sqlmesh.core.model.meta import ModelMeta from sqlmesh.core.model.kind import ( ModelKindName, SeedKind, @@ -2038,7 +2038,6 @@ def load_sql_based_model( macros: t.Optional[MacroRegistry] = None, jinja_macros: t.Optional[JinjaMacroRegistry] = None, audits: t.Optional[t.Dict[str, ModelAudit]] = None, - default_audits: t.Optional[t.List[FunctionCall]] = None, python_env: t.Optional[t.Dict[str, Executable]] = None, dialect: t.Optional[str] = None, physical_schema_mapping: t.Optional[t.Dict[re.Pattern, str]] = None, @@ -2211,7 +2210,6 @@ def load_sql_based_model( physical_schema_mapping=physical_schema_mapping, default_catalog=default_catalog, variables=variables, - default_audits=default_audits, inline_audits=inline_audits, blueprint_variables=blueprint_variables, **meta_fields, @@ -2431,7 +2429,6 @@ def _create_model( physical_schema_mapping: t.Optional[t.Dict[re.Pattern, str]] = 
None, python_env: t.Optional[t.Dict[str, Executable]] = None, audit_definitions: t.Optional[t.Dict[str, ModelAudit]] = None, - default_audits: t.Optional[t.List[FunctionCall]] = None, inline_audits: t.Optional[t.Dict[str, ModelAudit]] = None, module_path: Path = Path(), macros: t.Optional[MacroRegistry] = None, @@ -2541,6 +2538,10 @@ def _create_model( for jinja_macro in jinja_macros.root_macros.values(): used_variables.update(extract_macro_references_and_variables(jinja_macro.definition)[1]) + # Merge model-specific audits with default audits + if default_audits := defaults.pop("audits", None): + kwargs["audits"] = default_audits + d.extract_function_calls(kwargs.pop("audits", [])) + model = klass( name=name, **{ @@ -2558,12 +2559,7 @@ def _create_model( **(inline_audits or {}), } - # TODO: default_audits needs to be merged with model.audits; the former's arguments - # are silently dropped today because we add them in audit_definitions. We also need - # to check for duplicates when we implement this merging logic. 
- used_audits: t.Set[str] = set() - used_audits.update(audit_name for audit_name, _ in default_audits or []) - used_audits.update(audit_name for audit_name, _ in model.audits) + used_audits: t.Set[str] = {audit_name for audit_name, _ in model.audits} audit_definitions = { audit_name: audit_definitions[audit_name] diff --git a/sqlmesh/core/model/meta.py b/sqlmesh/core/model/meta.py index 585bb15a6c..b5371ab811 100644 --- a/sqlmesh/core/model/meta.py +++ b/sqlmesh/core/model/meta.py @@ -11,7 +11,7 @@ from sqlmesh.core import dialect as d from sqlmesh.core.config.linter import LinterConfig -from sqlmesh.core.dialect import normalize_model_name, extract_func_call +from sqlmesh.core.dialect import normalize_model_name from sqlmesh.core.model.common import ( bool_validator, default_catalog_validator, @@ -94,37 +94,7 @@ class ModelMeta(_Node): def _func_call_validator(cls, v: t.Any, field: t.Any) -> t.Any: is_signal = getattr(field, "name" if hasattr(field, "name") else "field_name") == "signals" - if isinstance(v, (exp.Tuple, exp.Array)): - return [extract_func_call(i, allow_tuples=is_signal) for i in v.expressions] - if isinstance(v, exp.Paren): - return [extract_func_call(v.this, allow_tuples=is_signal)] - if isinstance(v, exp.Expression): - return [extract_func_call(v, allow_tuples=is_signal)] - if isinstance(v, list): - audits = [] - - for entry in v: - if isinstance(entry, dict): - args = entry - name = "" if is_signal else entry.pop("name") - elif isinstance(entry, (tuple, list)): - name, args = entry - else: - raise ConfigError(f"Audit must be a dictionary or named tuple. 
Got {entry}.") - - audits.append( - ( - name.lower(), - { - key: d.parse_one(value) if isinstance(value, str) else value - for key, value in args.items() - }, - ) - ) - - return audits - - return v or [] + return d.extract_function_calls(v, allow_tuples=is_signal) @field_validator("tags", mode="before") def _value_or_tuple_validator(cls, v: t.Any, info: ValidationInfo) -> t.Any: diff --git a/tests/core/test_audit.py b/tests/core/test_audit.py index da8145ba87..81335e5f1a 100644 --- a/tests/core/test_audit.py +++ b/tests/core/test_audit.py @@ -3,6 +3,7 @@ from sqlglot import exp, parse_one from sqlmesh.core import constants as c +from sqlmesh.core.config.model import ModelDefaultsConfig from sqlmesh.core.context import Context from sqlmesh.core.audit import ( ModelAudit, @@ -962,6 +963,117 @@ def test_multiple_audits_with_same_name(): assert model.audits[1][1] == model.audits[2][1] +def test_default_audits_included_when_no_model_audits(): + expressions = parse(""" + MODEL ( + name test.basic_model + ); + SELECT 1 as id, 'test' as name; + """) + + model_defaults = ModelDefaultsConfig( + dialect="duckdb", audits=["not_null(columns := ['id'])", "unique_values(columns := ['id'])"] + ) + model = load_sql_based_model(expressions, defaults=model_defaults.dict()) + + assert len(model.audits) == 2 + audit_names = [audit[0] for audit in model.audits] + assert "not_null" in audit_names + assert "unique_values" in audit_names + + # Verify arguments are preserved + for audit_name, audit_args in model.audits: + if audit_name == "not_null": + assert "columns" in audit_args + assert audit_args["columns"].expressions[0].this == "id" + elif audit_name == "unique_values": + assert "columns" in audit_args + assert audit_args["columns"].expressions[0].this == "id" + + for audit_name, audit_args in model.audits_with_args: + if audit_name == "not_null": + assert "columns" in audit_args + assert audit_args["columns"].expressions[0].this == "id" + elif audit_name == "unique_values": + 
assert "columns" in audit_args + assert audit_args["columns"].expressions[0].this == "id" + + +def test_model_defaults_audits_with_same_name(): + expressions = parse( + """ + MODEL ( + name db.table, + dialect spark, + audits( + does_not_exceed_threshold(column := id, threshold := 1000), + does_not_exceed_threshold(column := price, threshold := 100), + unique_values(columns := ['id']) + ) + ); + + SELECT id, price FROM tbl; + + AUDIT ( + name does_not_exceed_threshold, + ); + SELECT * FROM @this_model + WHERE @column >= @threshold; + """ + ) + + model_defaults = ModelDefaultsConfig( + dialect="duckdb", + audits=[ + "does_not_exceed_threshold(column := price, threshold := 33)", + "does_not_exceed_threshold(column := id, threshold := 65)", + "not_null(columns := ['id'])", + ], + ) + model = load_sql_based_model(expressions, defaults=model_defaults.dict()) + assert len(model.audits) == 6 + assert len(model.audits_with_args) == 6 + assert len(model.audit_definitions) == 1 + + expected_audits = [ + ( + "does_not_exceed_threshold", + {"column": exp.column("price"), "threshold": exp.Literal.number(33)}, + ), + ( + "does_not_exceed_threshold", + {"column": exp.column("id"), "threshold": exp.Literal.number(65)}, + ), + ("not_null", {"columns": exp.convert(["id"])}), + ( + "does_not_exceed_threshold", + {"column": exp.column("id"), "threshold": exp.Literal.number(1000)}, + ), + ( + "does_not_exceed_threshold", + {"column": exp.column("price"), "threshold": exp.Literal.number(100)}, + ), + ("unique_values", {"columns": exp.convert(["id"])}), + ] + + for (actual_name, actual_args), (expected_name, expected_args) in zip( + model.audits, expected_audits + ): + # Validate the audit names are preserved + assert actual_name == expected_name + for key in expected_args: + # comparing sql representation is easier + assert actual_args[key].sql() == expected_args[key].sql() + + # Validate audits with args as well along with their arguments + for (actual_audit, actual_args), 
(expected_name, expected_args) in zip( + model.audits_with_args, expected_audits + ): + assert actual_audit.name == expected_name + for key in expected_args: + assert actual_args[key].sql() == expected_args[key].sql() + + def test_audit_formatting_flag_serde(): expressions = parse( """ diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index d7edd8d131..d15e097875 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -7242,3 +7242,265 @@ def test_physical_table_naming_strategy_hash_md5(copy_to_temp_path: t.Callable): s.table_naming_convention == TableNamingConvention.HASH_MD5 for s in prod_env_snapshots.values() ) + + +@pytest.mark.slow +def test_default_audits_applied_in_plan(tmp_path: Path): + models_dir = tmp_path / "models" + models_dir.mkdir(exist_ok=True) + + # Create a model with data that will pass the audits + create_temp_file( + tmp_path, + models_dir / "orders.sql", + dedent(""" + MODEL ( + name test.orders, + kind FULL + ); + + SELECT + 1 AS order_id, + 'customer_1' AS customer_id, + 100.50 AS amount, + '2024-01-01'::DATE AS order_date + UNION ALL + SELECT + 2 AS order_id, + 'customer_2' AS customer_id, + 200.75 AS amount, + '2024-01-02'::DATE AS order_date + """), + ) + + config = Config( + model_defaults=ModelDefaultsConfig( + dialect="duckdb", + audits=[ + "not_null(columns := [order_id, customer_id])", + "unique_values(columns := [order_id])", + ], + ) + ) + + context = Context(paths=tmp_path, config=config) + + # Create and apply plan, here audits should pass + plan = context.plan("prod", no_prompts=True) + context.apply(plan) + + # Verify model has the default audits + model = context.get_model("test.orders") + assert len(model.audits) == 2 + + audit_names = [audit[0] for audit in model.audits] + assert "not_null" in audit_names + assert "unique_values" in audit_names + + # Verify audit arguments are preserved + for audit_name, audit_args in model.audits: + if audit_name == "not_null": + 
assert "columns" in audit_args + columns = [col.name for col in audit_args["columns"].expressions] + assert "order_id" in columns + assert "customer_id" in columns + elif audit_name == "unique_values": + assert "columns" in audit_args + columns = [col.name for col in audit_args["columns"].expressions] + assert "order_id" in columns + + +@pytest.mark.slow +def test_default_audits_fail_on_bad_data(tmp_path: Path): + models_dir = tmp_path / "models" + models_dir.mkdir(exist_ok=True) + + # Create a model with data that violates NOT NULL constraint + create_temp_file( + tmp_path, + models_dir / "bad_orders.sql", + dedent(""" + MODEL ( + name test.bad_orders, + kind FULL + ); + + SELECT + 1 AS order_id, + NULL AS customer_id, -- This violates NOT NULL + 100.50 AS amount, + '2024-01-01'::DATE AS order_date + UNION ALL + SELECT + 2 AS order_id, + 'customer_2' AS customer_id, + 200.75 AS amount, + '2024-01-02'::DATE AS order_date + """), + ) + + config = Config( + model_defaults=ModelDefaultsConfig( + dialect="duckdb", audits=["not_null(columns := [customer_id])"] + ) + ) + + context = Context(paths=tmp_path, config=config) + + # Plan should fail due to audit failure + with pytest.raises(PlanError): + context.plan("prod", no_prompts=True, auto_apply=True) + + +@pytest.mark.slow +def test_default_audits_with_model_specific_audits(tmp_path: Path): + models_dir = tmp_path / "models" + models_dir.mkdir(exist_ok=True) + audits_dir = tmp_path / "audits" + audits_dir.mkdir(exist_ok=True) + + create_temp_file( + tmp_path, + audits_dir / "range_check.sql", + dedent(""" + AUDIT ( + name range_check + ); + + SELECT * FROM @this_model + WHERE @column < @min_value OR @column > @max_value + """), + ) + + # Create a model with its own audits in addition to defaults + create_temp_file( + tmp_path, + models_dir / "products.sql", + dedent(""" + MODEL ( + name test.products, + kind FULL, + audits ( + range_check(column := price, min_value := 0, max_value := 10000) + ) + ); + + SELECT + 1 AS 
product_id, + 'Widget' AS product_name, + 99.99 AS price + UNION ALL + SELECT + 2 AS product_id, + 'Gadget' AS product_name, + 149.99 AS price + """), + ) + + config = Config( + model_defaults=ModelDefaultsConfig( + dialect="duckdb", + audits=[ + "not_null(columns := [product_id, product_name])", + "unique_values(columns := [product_id])", + ], + ) + ) + + context = Context(paths=tmp_path, config=config) + + # Create and apply plan + plan = context.plan("prod", no_prompts=True) + context.apply(plan) + + # Verify model has both default and model-specific audits + model = context.get_model("test.products") + assert len(model.audits) == 3 + + audit_names = [audit[0] for audit in model.audits] + assert "not_null" in audit_names + assert "unique_values" in audit_names + assert "range_check" in audit_names + + # Verify audit execution order, default audits first then model-specific + assert model.audits[0][0] == "not_null" + assert model.audits[1][0] == "unique_values" + assert model.audits[2][0] == "range_check" + + +@pytest.mark.slow +def test_default_audits_with_custom_audit_definitions(tmp_path: Path): + models_dir = tmp_path / "models" + models_dir.mkdir(exist_ok=True) + audits_dir = tmp_path / "audits" + audits_dir.mkdir(exist_ok=True) + + # Create custom audit definition + create_temp_file( + tmp_path, + audits_dir / "positive_amount.sql", + dedent(""" + AUDIT ( + name positive_amount + ); + + SELECT * FROM @this_model + WHERE @column <= 0 + """), + ) + + # Create a model + create_temp_file( + tmp_path, + models_dir / "transactions.sql", + dedent(""" + MODEL ( + name test.transactions, + kind FULL + ); + + SELECT + 1 AS transaction_id, + 'TXN001' AS transaction_code, + 250.00 AS amount, + '2024-01-01'::DATE AS transaction_date + UNION ALL + SELECT + 2 AS transaction_id, + 'TXN002' AS transaction_code, + 150.00 AS amount, + '2024-01-02'::DATE AS transaction_date + """), + ) + + config = Config( + model_defaults=ModelDefaultsConfig( + dialect="duckdb", + audits=[ + 
"not_null(columns := [transaction_id, transaction_code])", + "unique_values(columns := [transaction_id])", + "positive_amount(column := amount)", + ], + ) + ) + + context = Context(paths=tmp_path, config=config) + + # Create and apply plan + plan = context.plan("prod", no_prompts=True) + context.apply(plan) + + # Verify model has all default audits including custom + model = context.get_model("test.transactions") + assert len(model.audits) == 3 + + audit_names = [audit[0] for audit in model.audits] + assert "not_null" in audit_names + assert "unique_values" in audit_names + assert "positive_amount" in audit_names + + # Verify custom audit arguments + for audit_name, audit_args in model.audits: + if audit_name == "positive_amount": + assert "column" in audit_args + assert audit_args["column"].name == "amount" diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 3cadbae9ca..0be1702fa1 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -1672,13 +1672,13 @@ def test_enable_audits_from_model_defaults(): model = load_sql_based_model( expressions, path=Path("./examples/sushi/models/test_model.sql"), - default_audits=model_defaults.audits, + defaults=model_defaults.dict(), ) - assert len(model.audits) == 0 + assert len(model.audits) == 1 config = Config(model_defaults=model_defaults) - assert config.model_defaults.audits[0] == ("assert_positive_order_ids", {}) + assert config.model_defaults.audits[0] == ("assert_positive_order_ids", {}) == model.audits[0] audits_with_args = model.audits_with_args assert len(audits_with_args) == 1 @@ -7253,23 +7253,26 @@ def max_value(evaluator: MacroEvaluator) -> int: "assert_max_value": load_audit(audit_expression, dialect="duckdb"), "assert_not_zero": load_audit(not_zero_audit, dialect="duckdb"), } - config = Config( - model_defaults=ModelDefaultsConfig(dialect="duckdb", audits=["assert_not_zero"]) - ) + model_defaults = ModelDefaultsConfig(dialect="duckdb", audits=["assert_not_zero"]) + model = 
load_sql_based_model( model_expression, - audits=audits, - default_audits=config.model_defaults.audits, + defaults=model_defaults.dict(), audit_definitions=audits, ) - assert len(model.audits) == 2 + assert len(model.audits) == 3 audits_with_args = model.audits_with_args assert len(audits_with_args) == 3 assert len(model.python_env) == 3 - assert config.model_defaults.audits == [("assert_not_zero", {})] - assert model.audits == [("assert_max_value", {}), ("assert_positive_ids", {})] + assert model.audits == [ + ("assert_not_zero", {}), + ("assert_max_value", {}), + ("assert_positive_ids", {}), + ] assert isinstance(audits_with_args[0][0], ModelAudit) + assert isinstance(audits_with_args[1][0], ModelAudit) + assert isinstance(audits_with_args[2][0], ModelAudit) assert isinstance(model.python_env["min_value"], Executable) assert isinstance(model.python_env["max_value"], Executable) assert isinstance(model.python_env["zero_value"], Executable)