Skip to content

Commit acc8d37

Browse files
committed
Address PR comments
1 parent db464f3 commit acc8d37

File tree

4 files changed

+77
-66
lines changed

4 files changed

+77
-66
lines changed

sqlmesh/core/context.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@
8181
from sqlmesh.core.linter.rules import BUILTIN_RULES
8282
from sqlmesh.core.macros import ExecutableOrMacro, macro
8383
from sqlmesh.core.metric import Metric, rewrite
84-
from sqlmesh.core.model import Model, SqlModel, update_model_schemas
84+
from sqlmesh.core.model import Model, update_model_schemas
8585
from sqlmesh.core.config.model import ModelDefaultsConfig
8686
from sqlmesh.core.notification_target import (
8787
NotificationEvent,
@@ -681,13 +681,10 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
681681
cache_dir=self.cache_dir,
682682
)
683683

684-
# The model definition can be validated correctly only after the schema is set.
685-
for model in self.models.values():
686-
if isinstance(model, SqlModel):
687-
# Validates only the model metadata; SQL validations are handled by the linter
688-
super(SqlModel, model).validate_definition() # type: ignore
689-
else:
690-
model.validate_definition()
684+
models = self.models.values()
685+
for model in models:
686+
# The model definition can be validated correctly only after the schema is set.
687+
model.validate_definition()
691688

692689
duplicates = set(self._models) & set(self._standalone_audits)
693690
if duplicates:
@@ -3134,7 +3131,9 @@ def lint_models(
31343131
found_error = False
31353132

31363133
model_list = (
3137-
list(self.get_model(model) for model in models) if models else self.models.values()
3134+
list(self.get_model(model, raise_if_missing=True) for model in models)
3135+
if models
3136+
else self.models.values()
31383137
)
31393138
all_violations = []
31403139
for model in model_list:

sqlmesh/core/linter/rules/builtin.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
)
2626
from sqlmesh.core.linter.definition import RuleSet
2727
from sqlmesh.core.model import Model, SqlModel, ExternalModel
28-
from sqlmesh.utils.errors import ConfigError
2928
from sqlmesh.utils.lineage import extract_references_from_query, ExternalModelReference
3029

3130

@@ -275,20 +274,31 @@ def create_fix(self, model_name: str) -> t.Optional[Fix]:
275274
)
276275

277276

278-
class ValidateModelDefinition(Rule):
279-
"""
280-
Checks whether a model satisfies certain properties, such as (but not limited to):
281-
282-
- If SQL-based, it contains at least one projection & projection names are unique
283-
- Its kind is configured correctly (e.g., the VIEW kind is not supported for Python models)
284-
- Other metadata properties are well-formed (e.g., incremental-by-time models require a time column)
285-
"""
277+
class NoAmbiguousProjections(Rule):
278+
"""All projections in a model must have unique, inferrable names or explicit aliases."""
286279

287280
def check_model(self, model: Model) -> t.Optional[RuleViolation]:
288-
try:
289-
model.validate_definition()
290-
except ConfigError as ex:
291-
return self.violation(str(ex))
281+
query = model.render_query()
282+
if query is None:
283+
return None
284+
285+
name_counts: t.Dict[str, int] = {}
286+
projection_list = query.selects
287+
for expression in projection_list:
288+
alias = expression.output_name
289+
if alias == "*":
290+
continue
291+
292+
if not alias:
293+
return self.violation(
294+
f"Outer projection '{expression.sql(dialect=model.dialect)}' must have inferrable names or explicit aliases."
295+
)
296+
297+
name_counts[alias] = name_counts.get(alias, 0) + 1
298+
299+
for name, count in name_counts.items():
300+
if count > 1:
301+
return self.violation(f"Found duplicate outer select name '{name}'")
292302

293303
return None
294304

sqlmesh/core/model/definition.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1481,22 +1481,6 @@ def validate_definition(self) -> None:
14811481
if not projection_list:
14821482
raise_config_error("Query missing select statements", self._path)
14831483

1484-
name_counts: t.Dict[str, int] = {}
1485-
for expression in projection_list:
1486-
alias = expression.output_name
1487-
if alias == "*":
1488-
continue
1489-
if not alias:
1490-
raise_config_error(
1491-
f"Outer projection '{expression.sql(dialect=self.dialect)}' must have inferrable names or explicit aliases.",
1492-
self._path,
1493-
)
1494-
name_counts[alias] = name_counts.get(alias, 0) + 1
1495-
1496-
for name, count in name_counts.items():
1497-
if count > 1:
1498-
raise_config_error(f"Found duplicate outer select name '{name}'", self._path)
1499-
15001484
if self.depends_on_self and not self.annotated:
15011485
raise_config_error(
15021486
"Self-referencing models require inferrable column types. There are three options available to mitigate this issue: add explicit types to all projections in the outermost SELECT statement, leverage external models (https://sqlmesh.readthedocs.io/en/stable/concepts/models/external_models/), or use the `columns` model attribute (https://sqlmesh.readthedocs.io/en/stable/concepts/models/overview/#columns).",

tests/core/test_model.py

Lines changed: 46 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,12 @@
3131
ModelDefaultsConfig,
3232
LinterConfig,
3333
)
34+
from sqlmesh.core import constants as c
3435
from sqlmesh.core.context import Context, ExecutionContext
3536
from sqlmesh.core.dialect import parse
3637
from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS
3738
from sqlmesh.core.engine_adapter.duckdb import DuckDBEngineAdapter
3839
from sqlmesh.core.macros import MacroEvaluator, macro
39-
from sqlmesh.core import constants as c
4040
from sqlmesh.core.model import (
4141
CustomKind,
4242
PythonModel,
@@ -198,29 +198,64 @@ def test_model_multiple_select_statements():
198198
load_sql_based_model(expressions)
199199

200200

201-
@pytest.mark.parametrize(
202-
"query, error",
203-
[
204-
("y::int, x::int AS y", "duplicate"),
205-
("* FROM db.table", "require inferrable column types"),
206-
],
207-
)
208-
def test_model_validation(query, error):
201+
def test_model_validation(tmp_path):
209202
expressions = d.parse(
210203
f"""
211204
MODEL (
212205
name db.table,
213206
kind FULL,
214207
);
215208
216-
SELECT {query}
209+
SELECT
210+
y::int,
211+
x::int AS y
212+
FROM db.ext
213+
"""
214+
)
215+
216+
ctx = Context(
217+
config=Config(linter=LinterConfig(enabled=True, rules=["noambiguousprojections"])),
218+
paths=tmp_path,
219+
)
220+
ctx.upsert_model(load_sql_based_model(expressions, default_catalog="memory"))
221+
222+
errors = ctx.lint_models(["db.table"], raise_on_error=False)
223+
assert errors, "Expected NoAmbiguousProjections violation"
224+
assert errors[0].violation_msg == "Found duplicate outer select name 'y'"
225+
226+
expressions = d.parse(
227+
"""
228+
MODEL (
229+
name db.table,
230+
kind FULL,
231+
);
232+
233+
SELECT a, a UNION SELECT c, c
234+
"""
235+
)
236+
237+
ctx.upsert_model(load_sql_based_model(expressions, default_catalog="memory"))
238+
239+
errors = ctx.lint_models(["db.table"], raise_on_error=False)
240+
assert errors, "Expected NoAmbiguousProjections violation"
241+
assert errors[0].violation_msg == "Found duplicate outer select name 'a'"
242+
243+
expressions = d.parse(
244+
f"""
245+
MODEL (
246+
name db.table,
247+
kind FULL,
248+
);
249+
250+
SELECT * FROM db.table
217251
"""
218252
)
219253

220254
model = load_sql_based_model(expressions)
221255
with pytest.raises(ConfigError) as ex:
222256
model.validate_definition()
223-
assert error in str(ex.value)
257+
258+
assert "require inferrable column types" in str(ex.value)
224259

225260

226261
def test_model_union_query(sushi_context, assert_exp_eq):
@@ -405,23 +440,6 @@ def get_date(evaluator):
405440
)
406441

407442

408-
def test_model_validation_union_query():
409-
expressions = d.parse(
410-
"""
411-
MODEL (
412-
name db.table,
413-
kind FULL,
414-
);
415-
416-
SELECT a, a UNION SELECT c, c
417-
"""
418-
)
419-
420-
model = load_sql_based_model(expressions)
421-
with pytest.raises(ConfigError, match=r"Found duplicate outer select name 'a'"):
422-
model.validate_definition()
423-
424-
425443
@use_terminal_console
426444
def test_model_qualification(tmp_path: Path):
427445
with patch.object(get_console(), "log_warning") as mock_logger:

0 commit comments

Comments
 (0)