Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions sqlmesh/core/linter/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from sqlmesh.core.linter.rule import Position, Range
from sqlmesh.utils.pydantic import PydanticModel
from sqlglot import tokenize, TokenType
import typing as t


Expand Down Expand Up @@ -113,3 +114,41 @@ def read_range_from_file(file: Path, text_range: Range) -> str:
result.append(line[start_char:end_char])

return "".join(result)


def get_range_of_model_block(
sql: str,
dialect: str,
) -> t.Optional[Range]:
"""
Get the range of the model block in an SQL file.
"""
tokens = tokenize(sql, dialect=dialect)

# Find start of the model block
start = next(
(t for t in tokens if t.token_type is TokenType.VAR and t.text.upper() == "MODEL"),
None,
)
end = next((t for t in tokens if t.token_type is TokenType.SEMICOLON), None)

if start is None or end is None:
return None

start_position = TokenPositionDetails(
line=start.line,
col=start.col,
start=start.start,
end=start.end,
)
end_position = TokenPositionDetails(
line=end.line,
col=end.col,
start=end.start,
end=end.end,
)

splitlines = sql.splitlines()
return Range(
start=start_position.to_range(splitlines).start, end=end_position.to_range(splitlines).end
)
18 changes: 16 additions & 2 deletions sqlmesh/core/linter/rules/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sqlglot.expressions import Star
from sqlglot.helper import subclasses

from sqlmesh.core.linter.helpers import TokenPositionDetails
from sqlmesh.core.linter.helpers import TokenPositionDetails, get_range_of_model_block
from sqlmesh.core.linter.rule import Rule, RuleViolation, Range, Fix, TextEdit
from sqlmesh.core.linter.definition import RuleSet
from sqlmesh.core.model import Model, SqlModel
Expand Down Expand Up @@ -93,7 +93,21 @@ class NoMissingAudits(Rule):
"""Model `audits` must be configured to test data quality."""

def check_model(self, model: Model) -> t.Optional[RuleViolation]:
return self.violation() if not model.audits and not model.kind.is_symbolic else None
if model.audits or model.kind.is_symbolic:
return None
if model._path is None or not str(model._path).endswith(".sql"):
return self.violation()

try:
with open(model._path, "r", encoding="utf-8") as file:
content = file.read()

range = get_range_of_model_block(content, model.dialect)
if range:
return self.violation(violation_range=range)
return self.violation()
except Exception:
return self.violation()


BUILTIN_RULES = RuleSet(subclasses(__name__, Rule, (Rule,)))
36 changes: 36 additions & 0 deletions tests/core/linter/test_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from sqlmesh import Context
from sqlmesh.core.linter.helpers import read_range_from_file, get_range_of_model_block
from sqlmesh.core.model import SqlModel


def test_get_position_of_model_block():
context = Context(paths=["examples/sushi"])

sql_models = [
model
for model in context.models.values()
if isinstance(model, SqlModel)
and model._path is not None
and str(model._path).endswith(".sql")
]
assert len(sql_models) > 0

for model in sql_models:
dialect = model.dialect
assert dialect is not None

path = model._path
assert path is not None

with open(path, "r", encoding="utf-8") as file:
content = file.read()

as_lines = content.splitlines()

range = get_range_of_model_block(content, dialect)
assert range is not None

# Check that the range starts with MODEL and ends with ;
read_range = read_range_from_file(path, range)
assert read_range.startswith("MODEL")
assert read_range.endswith(";")