Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sqlmesh/cli/project_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def _gen_config(
rules:
- ambiguousorinvalidcolumn
- invalidselectstarexpansion
- noambiguousprojections
""",
ProjectTemplate.DBT: f"""# --- Virtual Data Environment Mode ---
# Enable Virtual Data Environments (VDE) for *development* environments.
Expand Down
4 changes: 3 additions & 1 deletion sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3130,7 +3130,9 @@ def lint_models(
found_error = False

model_list = (
list(self.get_model(model) for model in models) if models else self.models.values()
list(self.get_model(model, raise_if_missing=True) for model in models)
if models
else self.models.values()
)
all_violations = []
for model in model_list:
Expand Down
26 changes: 13 additions & 13 deletions sqlmesh/core/linter/definition.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from __future__ import annotations
import typing as t
from sqlmesh.core.config.linter import LinterConfig
from sqlmesh.core.model import Model
from sqlmesh.utils.errors import raise_config_error
from sqlmesh.core.console import LinterConsole, get_console

import operator as op
import typing as t
from collections.abc import Iterator, Iterable, Set, Mapping, Callable
from functools import reduce
from sqlmesh.core.model import Model
from sqlmesh.core.linter.rule import Rule, RuleViolation, Range, Fix

from sqlmesh.core.config.linter import LinterConfig
from sqlmesh.core.console import LinterConsole, get_console
from sqlmesh.core.linter.rule import Rule, RuleViolation, Range, Fix
from sqlmesh.core.model import Model
from sqlmesh.utils.errors import raise_config_error

if t.TYPE_CHECKING:
from sqlmesh.core.context import GenericContext
Expand Down Expand Up @@ -38,6 +38,12 @@ def __init__(
self.rules = rules
self.warn_rules = warn_rules

if overlapping := rules.intersection(warn_rules):
overlapping_rules = ", ".join(rule for rule in overlapping)
raise_config_error(
f"Rules cannot simultaneously warn and raise an error: [{overlapping_rules}]"
)

@classmethod
def from_rules(cls, all_rules: RuleSet, config: LinterConfig) -> Linter:
ignored_rules = select_rules(all_rules, config.ignored_rules)
Expand All @@ -46,12 +52,6 @@ def from_rules(cls, all_rules: RuleSet, config: LinterConfig) -> Linter:
rules = select_rules(included_rules, config.rules)
warn_rules = select_rules(included_rules, config.warn_rules)

if overlapping := rules.intersection(warn_rules):
overlapping_rules = ", ".join(rule for rule in overlapping)
raise_config_error(
f"Rules cannot simultaneously warn and raise an error: [{overlapping_rules}]"
)

return Linter(config.enabled, all_rules, rules, warn_rules)

def lint_model(
Expand Down
29 changes: 29 additions & 0 deletions sqlmesh/core/linter/rules/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,4 +274,33 @@ def create_fix(self, model_name: str) -> t.Optional[Fix]:
)


class NoAmbiguousProjections(Rule):
"""All projections in a model must have unique & inferrable names or explicit aliases."""

def check_model(self, model: Model) -> t.Optional[RuleViolation]:
query = model.render_query()
if query is None:
return None

name_counts: t.Dict[str, int] = {}
projection_list = query.selects
for expression in projection_list:
alias = expression.output_name
if alias == "*":
continue

if not alias:
return self.violation(
f"Outer projection '{expression.sql(dialect=model.dialect)}' must have inferrable names or explicit aliases."
)

name_counts[alias] = name_counts.get(alias, 0) + 1

for name, count in name_counts.items():
if count > 1:
return self.violation(f"Found duplicate outer select name '{name}'")

return None


BUILTIN_RULES = RuleSet(subclasses(__name__, Rule, (Rule,)))
37 changes: 15 additions & 22 deletions sqlmesh/core/model/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -1417,12 +1417,20 @@ def columns_to_types(self) -> t.Optional[t.Dict[str, exp.DataType]]:

unknown = exp.DataType.build("unknown")

self._columns_to_types = {
columns_to_types = {}
for select in query.selects:
output_name = select.output_name

# If model validation is disabled, we cannot assume that projections
# will have inferrable output names or even that they will be unique
if not output_name or output_name in columns_to_types:
return None

# copy data type because it is used in the engine to build CTAS and other queries
# this can change the parent which will mess up the diffing algo
select.output_name: (select.type or unknown).copy()
for select in query.selects
}
columns_to_types[output_name] = (select.type or unknown).copy()

self._columns_to_types = columns_to_types

if "*" in self._columns_to_types:
return None
Expand Down Expand Up @@ -1473,22 +1481,6 @@ def validate_definition(self) -> None:
if not projection_list:
raise_config_error("Query missing select statements", self._path)

name_counts: t.Dict[str, int] = {}
for expression in projection_list:
alias = expression.output_name
if alias == "*":
continue
if not alias:
raise_config_error(
f"Outer projection '{expression.sql(dialect=self.dialect)}' must have inferrable names or explicit aliases.",
self._path,
)
name_counts[alias] = name_counts.get(alias, 0) + 1

for name, count in name_counts.items():
if count > 1:
raise_config_error(f"Found duplicate outer select name '{name}'", self._path)

if self.depends_on_self and not self.annotated:
raise_config_error(
"Self-referencing models require inferrable column types. There are three options available to mitigate this issue: add explicit types to all projections in the outermost SELECT statement, leverage external models (https://sqlmesh.readthedocs.io/en/stable/concepts/models/external_models/), or use the `columns` model attribute (https://sqlmesh.readthedocs.io/en/stable/concepts/models/overview/#columns).",
Expand Down Expand Up @@ -1846,8 +1838,9 @@ def validate_definition(self) -> None:
super().validate_definition()

if self.kind and not self.kind.supports_python_models:
raise SQLMeshError(
f"Cannot create Python model '{self.name}' as the '{self.kind.name}' kind doesn't support Python models"
raise_config_error(
f"Cannot create Python model '{self.name}' as the '{self.kind.name}' kind doesn't support Python models",
self._path,
)

def render(
Expand Down
3 changes: 3 additions & 0 deletions tests/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,7 @@ def test_dlt_filesystem_pipeline(tmp_path):
" rules:\n"
" - ambiguousorinvalidcolumn\n"
" - invalidselectstarexpansion\n"
" - noambiguousprojections\n"
)

with open(config_path) as file:
Expand Down Expand Up @@ -1048,6 +1049,7 @@ def test_dlt_pipeline(runner, tmp_path):
rules:
- ambiguousorinvalidcolumn
- invalidselectstarexpansion
- noambiguousprojections
"""

with open(tmp_path / "config.yaml") as file:
Expand Down Expand Up @@ -1990,6 +1992,7 @@ def test_init_project_engine_configs(tmp_path):
rules:
- ambiguousorinvalidcolumn
- invalidselectstarexpansion
- noambiguousprojections
"""

with open(tmp_path / "config.yaml") as file:
Expand Down
49 changes: 49 additions & 0 deletions tests/core/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -1554,6 +1554,38 @@ def test_raw_code_handling(sushi_test_dbt_context: Context):
)


@pytest.mark.slow
def test_dbt_models_are_not_validated(sushi_test_dbt_context: Context):
model = sushi_test_dbt_context.models['"memory"."sushi"."non_validated_model"']

assert model.render_query_or_raise().sql(comments=False) == 'SELECT 1 AS "c", 2 AS "c"'
assert sushi_test_dbt_context.fetchdf(
'SELECT * FROM "memory"."sushi"."non_validated_model"'
).to_dict() == {"c": {0: 1}, "c_1": {0: 2}}

# Write a new incremental model file that should fail validation
models_dir = sushi_test_dbt_context.path / "models"
incremental_model_path = models_dir / "invalid_incremental.sql"
incremental_model_content = """{{
config(
materialized='incremental',
incremental_strategy='delete+insert',
)
}}

SELECT
1 AS c"""

incremental_model_path.write_text(incremental_model_content)

# Reload the context - this should raise a validation error for the incremental model
with pytest.raises(
ConfigError,
match="Unmanaged incremental models with insert / overwrite enabled must specify the partitioned_by field",
):
Context(paths=sushi_test_dbt_context.path, config="test_config")


def test_catalog_name_needs_to_be_quoted():
config = Config(
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
Expand Down Expand Up @@ -3085,3 +3117,20 @@ def test_plan_no_start_configured():
match=r"Model '.*xvg.*': Start date / time .* can't be greater than end date / time .*\.\nSet the `start` attribute in your project config model defaults to avoid this issue",
):
context.plan("dev", execution_time="1999-01-05")


def test_lint_model_projections(tmp_path: Path):
init_example_project(tmp_path, engine_type="duckdb", dialect="duckdb")

context = Context(paths=tmp_path)
context.upsert_model(
load_sql_based_model(
parse("""MODEL(name sqlmesh_example.m); SELECT 1 AS x, 2 AS x"""),
default_catalog="db",
)
)

config_err = "Linter detected errors in the code. Please fix them before proceeding."

with pytest.raises(LinterError, match=config_err):
prod_plan = context.plan(no_prompts=True, auto_apply=True)
74 changes: 46 additions & 28 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@
ModelDefaultsConfig,
LinterConfig,
)
from sqlmesh.core import constants as c
from sqlmesh.core.context import Context, ExecutionContext
from sqlmesh.core.dialect import parse
from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS
from sqlmesh.core.engine_adapter.duckdb import DuckDBEngineAdapter
from sqlmesh.core.macros import MacroEvaluator, macro
from sqlmesh.core import constants as c
from sqlmesh.core.model import (
CustomKind,
PythonModel,
Expand Down Expand Up @@ -198,29 +198,64 @@ def test_model_multiple_select_statements():
load_sql_based_model(expressions)


@pytest.mark.parametrize(
"query, error",
[
("y::int, x::int AS y", "duplicate"),
("* FROM db.table", "require inferrable column types"),
],
)
def test_model_validation(query, error):
def test_model_validation(tmp_path):
expressions = d.parse(
f"""
MODEL (
name db.table,
kind FULL,
);

SELECT {query}
SELECT
y::int,
x::int AS y
FROM db.ext
"""
)

ctx = Context(
config=Config(linter=LinterConfig(enabled=True, rules=["noambiguousprojections"])),
paths=tmp_path,
)
ctx.upsert_model(load_sql_based_model(expressions, default_catalog="memory"))

errors = ctx.lint_models(["db.table"], raise_on_error=False)
assert errors, "Expected NoAmbiguousProjections violation"
assert errors[0].violation_msg == "Found duplicate outer select name 'y'"

expressions = d.parse(
"""
MODEL (
name db.table,
kind FULL,
);

SELECT a, a UNION SELECT c, c
"""
)

ctx.upsert_model(load_sql_based_model(expressions, default_catalog="memory"))

errors = ctx.lint_models(["db.table"], raise_on_error=False)
assert errors, "Expected NoAmbiguousProjections violation"
assert errors[0].violation_msg == "Found duplicate outer select name 'a'"

expressions = d.parse(
f"""
MODEL (
name db.table,
kind FULL,
);

SELECT * FROM db.table
"""
)

model = load_sql_based_model(expressions)
with pytest.raises(ConfigError) as ex:
model.validate_definition()
assert error in str(ex.value)

assert "require inferrable column types" in str(ex.value)


def test_model_union_query(sushi_context, assert_exp_eq):
Expand Down Expand Up @@ -405,23 +440,6 @@ def get_date(evaluator):
)


def test_model_validation_union_query():
expressions = d.parse(
"""
MODEL (
name db.table,
kind FULL,
);

SELECT a, a UNION SELECT c, c
"""
)

model = load_sql_based_model(expressions)
with pytest.raises(ConfigError, match=r"Found duplicate outer select name 'a'"):
model.validate_definition()


@use_terminal_console
def test_model_qualification(tmp_path: Path):
with patch.object(get_console(), "log_warning") as mock_logger:
Expand Down
5 changes: 5 additions & 0 deletions tests/fixtures/dbt/sushi_test/models/non_validated_model.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{{ config(materialized='table') }}

SELECT
1 AS c,
2 AS c,