import typing as t

from sqlglot import TokenType, tokenize

from sqlmesh.core.linter.rule import Range


def get_range_of_model_block(
    sql: str,
    dialect: str,
) -> t.Optional[Range]:
    """
    Get the range of the model block in an SQL file.

    The block is taken to start at the first ``MODEL`` token (a ``VAR`` token
    whose text is ``MODEL``, compared case-insensitively) and to end at the
    first semicolon that follows it.

    Args:
        sql: Full text of the SQL file.
        dialect: Dialect name forwarded to sqlglot's ``tokenize``.

    Returns:
        A ``Range`` running from the start of the ``MODEL`` token to the end of
        the terminating semicolon, or ``None`` when either token is missing.
    """
    start_token = None
    end_token = None

    # Single forward scan: semicolons are only considered once MODEL has been
    # seen, so a statement that precedes the model block can never produce a
    # range whose end lies before its start.
    for token in tokenize(sql, dialect=dialect):
        if start_token is None:
            if token.token_type is TokenType.VAR and token.text.upper() == "MODEL":
                start_token = token
        elif token.token_type is TokenType.SEMICOLON:
            end_token = token
            break

    if start_token is None or end_token is None:
        return None

    start_position = TokenPositionDetails(
        line=start_token.line,
        col=start_token.col,
        start=start_token.start,
        end=start_token.end,
    )
    end_position = TokenPositionDetails(
        line=end_token.line,
        col=end_token.col,
        start=end_token.start,
        end=end_token.end,
    )

    # to_range() is given the source split into lines; only the start of the
    # MODEL token's range and the end of the semicolon's range are kept.
    splitlines = sql.splitlines()
    return Range(
        start=start_position.to_range(splitlines).start,
        end=end_position.to_range(splitlines).end,
    )
# --- sqlmesh/core/linter/rules/builtin.py ---
class NoMissingAudits(Rule):
    """Model `audits` must be configured to test data quality."""

    def check_model(self, model: Model) -> t.Optional[RuleViolation]:
        # Report a violation when the model has no audits and its kind is not
        # symbolic. For models whose path ends in ".sql" the file is read and
        # the violation is anchored to the range of the MODEL block when that
        # range can be determined; otherwise a violation without a range is
        # returned. Returns None when the rule is satisfied.
        if model.audits or model.kind.is_symbolic:
            return None
        if model._path is None or not str(model._path).endswith(".sql"):
            return self.violation()

        # Locating the block is best-effort: failing to read or tokenize the
        # file must not hide the violation, so any error here simply drops the
        # range and the plain violation is reported instead.
        block_range = None
        try:
            with open(model._path, "r", encoding="utf-8") as file:
                content = file.read()
            block_range = get_range_of_model_block(content, model.dialect)
        except Exception:
            block_range = None

        if block_range is not None:
            return self.violation(violation_range=block_range)
        return self.violation()


BUILTIN_RULES = RuleSet(subclasses(__name__, Rule, (Rule,)))


# --- tests/core/linter/test_helpers.py ---
from sqlmesh import Context
from sqlmesh.core.linter.helpers import read_range_from_file, get_range_of_model_block
from sqlmesh.core.model import SqlModel


def test_get_position_of_model_block():
    """Every SQL-file model in the sushi example yields a `MODEL ... ;` range."""
    context = Context(paths=["examples/sushi"])

    sql_models = [
        model
        for model in context.models.values()
        if isinstance(model, SqlModel)
        and model._path is not None
        and str(model._path).endswith(".sql")
    ]
    assert sql_models

    for model in sql_models:
        dialect = model.dialect
        assert dialect is not None

        path = model._path
        assert path is not None

        with open(path, "r", encoding="utf-8") as file:
            content = file.read()

        model_block_range = get_range_of_model_block(content, dialect)
        assert model_block_range is not None

        # Check that the range starts with MODEL and ends with ;
        read_range = read_range_from_file(path, model_block_range)
        assert read_range.startswith("MODEL")
        assert read_range.endswith(";")