|
3 | 3 | from sqlglot import exp, parse_one |
4 | 4 |
|
5 | 5 | from sqlmesh.core import constants as c |
| 6 | +from sqlmesh.core.config.model import ModelDefaultsConfig |
6 | 7 | from sqlmesh.core.context import Context |
7 | 8 | from sqlmesh.core.audit import ( |
8 | 9 | ModelAudit, |
@@ -962,6 +963,117 @@ def test_multiple_audits_with_same_name(): |
962 | 963 | assert model.audits[1][1] == model.audits[2][1] |
963 | 964 |
|
964 | 965 |
|
| 966 | +def test_default_audits_included_when_no_model_audits(): |
| 967 | + expressions = parse(""" |
| 968 | + MODEL ( |
| 969 | + name test.basic_model |
| 970 | + ); |
| 971 | + SELECT 1 as id, 'test' as name; |
| 972 | + """) |
| 973 | + |
| 974 | + model_defaults = ModelDefaultsConfig( |
| 975 | + dialect="duckdb", audits=["not_null(columns := ['id'])", "unique_values(columns := ['id'])"] |
| 976 | + ) |
| 977 | + model = load_sql_based_model(expressions, defaults=model_defaults.dict()) |
| 978 | + |
| 979 | + assert len(model.audits) == 2 |
| 980 | + audit_names = [audit[0] for audit in model.audits] |
| 981 | + assert "not_null" in audit_names |
| 982 | + assert "unique_values" in audit_names |
| 983 | + |
| 984 | + # Verify arguments are preserved |
| 985 | + for audit_name, audit_args in model.audits: |
| 986 | + if audit_name == "not_null": |
| 987 | + assert "columns" in audit_args |
| 988 | + assert audit_args["columns"].expressions[0].this == "id" |
| 989 | + elif audit_name == "unique_values": |
| 990 | + assert "columns" in audit_args |
| 991 | + assert audit_args["columns"].expressions[0].this == "id" |
| 992 | + |
| 993 | + for audit_name, audit_args in model.audits_with_args: |
| 994 | + if audit_name == "not_null": |
| 995 | + assert "columns" in audit_args |
| 996 | + assert audit_args["columns"].expressions[0].this == "id" |
| 997 | + elif audit_name == "unique_values": |
| 998 | + assert "columns" in audit_args |
| 999 | + assert audit_args["columns"].expressions[0].this == "id" |
| 1000 | + |
| 1001 | + |
| 1002 | +def test_model_defaults_audits_with_same_name(): |
| 1003 | + expressions = parse( |
| 1004 | + """ |
| 1005 | + MODEL ( |
| 1006 | + name db.table, |
| 1007 | + dialect spark, |
| 1008 | + audits( |
| 1009 | + does_not_exceed_threshold(column := id, threshold := 1000), |
| 1010 | + does_not_exceed_threshold(column := price, threshold := 100), |
| 1011 | + unique_values(columns := ['id']) |
| 1012 | + ) |
| 1013 | + ); |
| 1014 | +
|
| 1015 | + SELECT id, price FROM tbl; |
| 1016 | +
|
| 1017 | + AUDIT ( |
| 1018 | + name does_not_exceed_threshold, |
| 1019 | + ); |
| 1020 | + SELECT * FROM @this_model |
| 1021 | + WHERE @column >= @threshold; |
| 1022 | + """ |
| 1023 | + ) |
| 1024 | + |
| 1025 | + model_defaults = ModelDefaultsConfig( |
| 1026 | + dialect="duckdb", |
| 1027 | + audits=[ |
| 1028 | + "does_not_exceed_threshold(column := price, threshold := 33)", |
| 1029 | + "does_not_exceed_threshold(column := id, threshold := 65)", |
| 1030 | + "not_null(columns := ['id'])", |
| 1031 | + ], |
| 1032 | + ) |
| 1033 | + model = load_sql_based_model(expressions, defaults=model_defaults.dict()) |
| 1034 | + assert len(model.audits) == 6 |
| 1035 | + assert len(model.audits_with_args) == 6 |
| 1036 | + assert len(model.audit_definitions) == 1 |
| 1037 | + |
| 1038 | + expected_audits = [ |
| 1039 | + ( |
| 1040 | + "does_not_exceed_threshold", |
| 1041 | + {"column": exp.column("price"), "threshold": exp.Literal.number(33)}, |
| 1042 | + ), |
| 1043 | + ( |
| 1044 | + "does_not_exceed_threshold", |
| 1045 | + {"column": exp.column("id"), "threshold": exp.Literal.number(65)}, |
| 1046 | + ), |
| 1047 | + ("not_null", {"columns": exp.convert(["id"])}), |
| 1048 | + ( |
| 1049 | + "does_not_exceed_threshold", |
| 1050 | + {"column": exp.column("id"), "threshold": exp.Literal.number(1000)}, |
| 1051 | + ), |
| 1052 | + ( |
| 1053 | + "does_not_exceed_threshold", |
| 1054 | + {"column": exp.column("price"), "threshold": exp.Literal.number(100)}, |
| 1055 | + ), |
| 1056 | + ("unique_values", {"columns": exp.convert(["id"])}), |
| 1057 | + ] |
| 1058 | + |
| 1059 | + for (actual_name, actual_args), (expected_name, expected_args) in zip( |
| 1060 | + model.audits, expected_audits |
| 1061 | + ): |
| 1062 | + # Validate the audit names are preserved |
| 1063 | + assert actual_name == expected_name |
| 1064 | + for key in expected_args: |
| 1065 | + # comparing sql representaion is easier |
| 1066 | + assert actual_args[key].sql() == expected_args[key].sql() |
| 1067 | + |
| 1068 | + # Validate audits with args as well along with their arguments |
| 1069 | + for (actual_audit, actual_args), (expected_name, expected_args) in zip( |
| 1070 | + model.audits_with_args, expected_audits |
| 1071 | + ): |
| 1072 | + assert actual_audit.name == expected_name |
| 1073 | + for key in expected_args: |
| 1074 | + assert actual_args[key].sql() == expected_args[key].sql() |
| 1075 | + |
| 1076 | + |
965 | 1077 | def test_audit_formatting_flag_serde(): |
966 | 1078 | expressions = parse( |
967 | 1079 | """ |
|
0 commit comments