Skip to content

Commit a392ee4

Browse files
refactor extraction block in an util function
1 parent 0b9576c commit a392ee4

File tree

3 files changed

+39
-47
lines changed

3 files changed

+39
-47
lines changed

sqlmesh/core/dialect.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1408,6 +1408,41 @@ def extract_func_call(
14081408
return func.lower(), kwargs
14091409

14101410

1411+
def extract_function_calls(func_calls: t.Any, allow_tuples: bool = False) -> t.Any:
1412+
"""Used for extracting function calls for signals or audits."""
1413+
1414+
if isinstance(func_calls, (exp.Tuple, exp.Array)):
1415+
return [extract_func_call(i, allow_tuples=allow_tuples) for i in func_calls.expressions]
1416+
if isinstance(func_calls, exp.Paren):
1417+
return [extract_func_call(func_calls.this, allow_tuples=allow_tuples)]
1418+
if isinstance(func_calls, exp.Expression):
1419+
return [extract_func_call(func_calls, allow_tuples=allow_tuples)]
1420+
if isinstance(func_calls, list):
1421+
function_calls = []
1422+
for entry in func_calls:
1423+
if isinstance(entry, dict):
1424+
args = entry
1425+
name = "" if allow_tuples else entry.pop("name")
1426+
elif isinstance(entry, (tuple, list)):
1427+
name, args = entry
1428+
else:
1429+
raise ConfigError(f"Audit must be a dictionary or named tuple. Got {entry}.")
1430+
1431+
function_calls.append(
1432+
(
1433+
name.lower(),
1434+
{
1435+
key: parse_one(value) if isinstance(value, str) else value
1436+
for key, value in args.items()
1437+
},
1438+
)
1439+
)
1440+
1441+
return function_calls
1442+
1443+
return func_calls or []
1444+
1445+
14111446
def is_meta_expression(v: t.Any) -> bool:
14121447
return isinstance(v, (Audit, Metric, Model))
14131448

sqlmesh/core/model/definition.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2539,21 +2539,8 @@ def _create_model(
25392539
used_variables.update(extract_macro_references_and_variables(jinja_macro.definition)[1])
25402540

25412541
# Merge model-specific audits with default audits
2542-
model_audits = kwargs.pop("audits", [])
2543-
default_audits = defaults.pop("audits", [])
2544-
2545-
if isinstance(model_audits, (exp.Tuple, exp.Array)):
2546-
model_audits_list = [d.extract_func_call(i) for i in model_audits.expressions]
2547-
elif isinstance(model_audits, exp.Paren):
2548-
model_audits_list = [d.extract_func_call(model_audits.this)]
2549-
elif isinstance(model_audits, exp.Expression):
2550-
model_audits_list = [d.extract_func_call(model_audits)]
2551-
elif isinstance(model_audits, list):
2552-
model_audits_list = model_audits
2553-
else:
2554-
model_audits_list = []
2555-
merged_audits = default_audits + model_audits_list
2556-
kwargs["audits"] = merged_audits
2542+
if default_audits := defaults.pop("audits", None):
2543+
kwargs["audits"] = default_audits + d.extract_function_calls(kwargs.pop("audits", []))
25572544

25582545
model = klass(
25592546
name=name,

sqlmesh/core/model/meta.py

Lines changed: 2 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from sqlmesh.core import dialect as d
1313
from sqlmesh.core.config.linter import LinterConfig
14-
from sqlmesh.core.dialect import normalize_model_name, extract_func_call
14+
from sqlmesh.core.dialect import normalize_model_name
1515
from sqlmesh.core.model.common import (
1616
bool_validator,
1717
default_catalog_validator,
@@ -94,37 +94,7 @@ class ModelMeta(_Node):
9494
def _func_call_validator(cls, v: t.Any, field: t.Any) -> t.Any:
9595
is_signal = getattr(field, "name" if hasattr(field, "name") else "field_name") == "signals"
9696

97-
if isinstance(v, (exp.Tuple, exp.Array)):
98-
return [extract_func_call(i, allow_tuples=is_signal) for i in v.expressions]
99-
if isinstance(v, exp.Paren):
100-
return [extract_func_call(v.this, allow_tuples=is_signal)]
101-
if isinstance(v, exp.Expression):
102-
return [extract_func_call(v, allow_tuples=is_signal)]
103-
if isinstance(v, list):
104-
audits = []
105-
106-
for entry in v:
107-
if isinstance(entry, dict):
108-
args = entry
109-
name = "" if is_signal else entry.pop("name")
110-
elif isinstance(entry, (tuple, list)):
111-
name, args = entry
112-
else:
113-
raise ConfigError(f"Audit must be a dictionary or named tuple. Got {entry}.")
114-
115-
audits.append(
116-
(
117-
name.lower(),
118-
{
119-
key: d.parse_one(value) if isinstance(value, str) else value
120-
for key, value in args.items()
121-
},
122-
)
123-
)
124-
125-
return audits
126-
127-
return v or []
97+
return d.extract_function_calls(v, allow_tuples=is_signal)
12898

12999
@field_validator("tags", mode="before")
130100
def _value_or_tuple_validator(cls, v: t.Any, info: ValidationInfo) -> t.Any:

0 commit comments

Comments
 (0)