Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions sqlmesh/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from sqlmesh.core.context import Context
from sqlmesh.utils import Verbosity
from sqlmesh.utils.date import TimeLike
from sqlmesh.utils.errors import MissingDependencyError, SQLMeshError
from sqlmesh.utils.errors import MissingDependencyError, SQLMeshError, LinterError

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1201,15 +1201,34 @@ def environments(obj: Context) -> None:
multiple=True,
help="A model to lint. Multiple models can be linted. If no models are specified, every model will be linted.",
)
@click.option(
"--fix",
is_flag=True,
help="Automatically apply available fixes and fail if unfixable errors remain.",
)
@click.pass_obj
@error_handler
@cli_analytics
def lint(
obj: Context,
models: t.Iterator[str],
fix: bool = False,
) -> None:
"""Run the linter for the target model(s)."""
obj.lint_models(models)
if fix:
violations = obj.lint_models(models, raise_on_error=False)
if violations:
from sqlmesh.core.linter.helpers import apply_fixes

apply_fixes(violations)
violations = obj.lint_models(models, raise_on_error=False)
remaining_errors = [v for v in violations if v.violation_type == "error"]
if remaining_errors:
raise LinterError(
"Linter detected errors that could not be automatically fixed."
)
else:
obj.lint_models(models)


@cli.group(no_args_is_help=True)
Expand Down
46 changes: 45 additions & 1 deletion sqlmesh/core/linter/helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path

from sqlmesh.core.linter.rule import Range, Position
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
from sqlmesh.core.linter.rule import Range, Position, TextEdit
from sqlmesh.utils.pydantic import PydanticModel
from sqlglot import tokenize, TokenType
import typing as t
Expand Down Expand Up @@ -210,3 +211,46 @@ def get_range_of_a_key_in_model_block(
end=key_token.end,
)
return position.to_range(sql.splitlines())


def apply_text_edits(path: Path, edits: t.Sequence[TextEdit]) -> None:
"""Apply a sequence of TextEdits to a file."""
if not edits:
return

with open(path, "r", encoding="utf-8") as file:
content = file.read()

lines = content.splitlines(keepends=True)
offsets = [0]
for line in lines:
offsets.append(offsets[-1] + len(line))

def to_offset(pos: Position) -> int:
line = min(pos.line, len(lines) - 1)
char = min(pos.character, len(lines[line]))
return offsets[line] + char

sorted_edits = sorted(
edits, key=lambda e: (e.range.start.line, e.range.start.character), reverse=True
)
for edit in sorted_edits:
start = to_offset(edit.range.start)
end = to_offset(edit.range.end)
content = content[:start] + edit.new_text + content[end:]

with open(path, "w", encoding="utf-8") as file:
file.write(content)


def apply_fixes(violations: t.Iterable[AnnotatedRuleViolation]) -> None:
"""Apply fixes from the provided violations."""
edits_by_path: dict[Path, list[TextEdit]] = {}

for violation in violations:
for fix in violation.fixes:
for edit in fix.edits:
edits_by_path.setdefault(edit.path, []).append(edit)

for path, edits in edits_by_path.items():
apply_text_edits(path, edits)