Skip to content

Commit 8c01e2f

Browse files
committed
cli: add --fix option to lint
1 parent 6724d96 commit 8c01e2f

File tree

2 files changed

+66
-3
lines changed

2 files changed

+66
-3
lines changed

sqlmesh/cli/main.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from sqlmesh.core.context import Context
2424
from sqlmesh.utils import Verbosity
2525
from sqlmesh.utils.date import TimeLike
26-
from sqlmesh.utils.errors import MissingDependencyError, SQLMeshError
26+
from sqlmesh.utils.errors import MissingDependencyError, SQLMeshError, LinterError
2727

2828
logger = logging.getLogger(__name__)
2929

@@ -1201,15 +1201,34 @@ def environments(obj: Context) -> None:
12011201
multiple=True,
12021202
help="A model to lint. Multiple models can be linted. If no models are specified, every model will be linted.",
12031203
)
1204+
@click.option(
1205+
"--fix",
1206+
is_flag=True,
1207+
help="Automatically apply available fixes and fail if unfixable errors remain.",
1208+
)
12041209
@click.pass_obj
12051210
@error_handler
12061211
@cli_analytics
12071212
def lint(
12081213
obj: Context,
12091214
models: t.Iterator[str],
1215+
fix: bool = False,
12101216
) -> None:
12111217
"""Run the linter for the target model(s)."""
1212-
obj.lint_models(models)
1218+
if fix:
1219+
violations = obj.lint_models(models, raise_on_error=False)
1220+
if violations:
1221+
from sqlmesh.core.linter.helpers import apply_fixes
1222+
1223+
apply_fixes(violations)
1224+
violations = obj.lint_models(models, raise_on_error=False)
1225+
remaining_errors = [v for v in violations if v.violation_type == "error"]
1226+
if remaining_errors:
1227+
raise LinterError(
1228+
"Linter detected errors that could not be automatically fixed."
1229+
)
1230+
else:
1231+
obj.lint_models(models)
12131232

12141233

12151234
@cli.group(no_args_is_help=True)

sqlmesh/core/linter/helpers.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from pathlib import Path
22

3-
from sqlmesh.core.linter.rule import Range, Position
3+
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
4+
from sqlmesh.core.linter.rule import Range, Position, TextEdit
45
from sqlmesh.utils.pydantic import PydanticModel
56
from sqlglot import tokenize, TokenType
67
import typing as t
@@ -210,3 +211,46 @@ def get_range_of_a_key_in_model_block(
210211
end=key_token.end,
211212
)
212213
return position.to_range(sql.splitlines())
214+
215+
216+
def apply_text_edits(path: Path, edits: t.Sequence[TextEdit]) -> None:
217+
"""Apply a sequence of TextEdits to a file."""
218+
if not edits:
219+
return
220+
221+
with open(path, "r", encoding="utf-8") as file:
222+
content = file.read()
223+
224+
lines = content.splitlines(keepends=True)
225+
offsets = [0]
226+
for line in lines:
227+
offsets.append(offsets[-1] + len(line))
228+
229+
def to_offset(pos: Position) -> int:
230+
line = min(pos.line, len(lines) - 1)
231+
char = min(pos.character, len(lines[line]))
232+
return offsets[line] + char
233+
234+
sorted_edits = sorted(
235+
edits, key=lambda e: (e.range.start.line, e.range.start.character), reverse=True
236+
)
237+
for edit in sorted_edits:
238+
start = to_offset(edit.range.start)
239+
end = to_offset(edit.range.end)
240+
content = content[:start] + edit.new_text + content[end:]
241+
242+
with open(path, "w", encoding="utf-8") as file:
243+
file.write(content)
244+
245+
246+
def apply_fixes(violations: t.Iterable[AnnotatedRuleViolation]) -> None:
247+
"""Apply fixes from the provided violations."""
248+
edits_by_path: dict[Path, list[TextEdit]] = {}
249+
250+
for violation in violations:
251+
for fix in violation.fixes:
252+
for edit in fix.edits:
253+
edits_by_path.setdefault(edit.path, []).append(edit)
254+
255+
for path, edits in edits_by_path.items():
256+
apply_text_edits(path, edits)

0 commit comments

Comments
 (0)