diff --git a/sqlmesh/cli/main.py b/sqlmesh/cli/main.py index 8982efc9f8..677e5c599e 100644 --- a/sqlmesh/cli/main.py +++ b/sqlmesh/cli/main.py @@ -23,7 +23,7 @@ from sqlmesh.core.context import Context from sqlmesh.utils import Verbosity from sqlmesh.utils.date import TimeLike -from sqlmesh.utils.errors import MissingDependencyError, SQLMeshError +from sqlmesh.utils.errors import MissingDependencyError, SQLMeshError, LinterError logger = logging.getLogger(__name__) @@ -1201,15 +1201,34 @@ def environments(obj: Context) -> None: multiple=True, help="A model to lint. Multiple models can be linted. If no models are specified, every model will be linted.", ) +@click.option( + "--fix", + is_flag=True, + help="Automatically apply available fixes and fail if unfixable errors remain.", +) @click.pass_obj @error_handler @cli_analytics def lint( obj: Context, models: t.Iterator[str], + fix: bool = False, ) -> None: """Run the linter for the target model(s).""" - obj.lint_models(models) + if fix: + violations = obj.lint_models(models, raise_on_error=False) + if violations: + from sqlmesh.core.linter.helpers import apply_fixes + + apply_fixes(violations) + violations = obj.lint_models(models, raise_on_error=False) + remaining_errors = [v for v in violations if v.violation_type == "error"] + if remaining_errors: + raise LinterError( + "Linter detected errors that could not be automatically fixed." + ) + else: + obj.lint_models(models) @cli.group(no_args_is_help=True) diff --git a/sqlmesh/core/linter/helpers.py b/sqlmesh/core/linter/helpers.py index 3f6e96765f..ffed7193ab 100644 --- a/sqlmesh/core/linter/helpers.py +++ b/sqlmesh/core/linter/helpers.py @@ -1,6 +1,7 @@ from pathlib import Path -from sqlmesh.core.linter.rule import Range, Position +from sqlmesh.core.linter.definition import AnnotatedRuleViolation +from sqlmesh.core.linter.rule import Range, Position, TextEdit from sqlmesh.utils.pydantic import PydanticModel from sqlglot import tokenize, TokenType import typing as t @@ -210,3 +211,46 @@ def get_range_of_a_key_in_model_block( end=key_token.end, ) return position.to_range(sql.splitlines()) + + +def apply_text_edits(path: Path, edits: t.Sequence[TextEdit]) -> None: + """Apply a sequence of TextEdits to a file.""" + if not edits: + return + + with open(path, "r", encoding="utf-8") as file: + content = file.read() + + lines = content.splitlines(keepends=True) + offsets = [0] + for line in lines: + offsets.append(offsets[-1] + len(line)) + + def to_offset(pos: Position) -> int: + line = min(pos.line, len(lines) - 1) + char = min(pos.character, len(lines[line])) + return offsets[line] + char + + sorted_edits = sorted( + edits, key=lambda e: (e.range.start.line, e.range.start.character), reverse=True + ) + for edit in sorted_edits: + start = to_offset(edit.range.start) + end = to_offset(edit.range.end) + content = content[:start] + edit.new_text + content[end:] + + with open(path, "w", encoding="utf-8") as file: + file.write(content) + + +def apply_fixes(violations: t.Iterable[AnnotatedRuleViolation]) -> None: + """Apply fixes from the provided violations.""" + edits_by_path: dict[Path, list[TextEdit]] = {} + + for violation in violations: + for fix in violation.fixes: + for edit in fix.edits: + edits_by_path.setdefault(edit.path, []).append(edit) + + for path, edits in edits_by_path.items(): + apply_text_edits(path, edits)