Skip to content

Commit ae25572

Browse files
committed
feat: add fix option to lint command
1 parent 9eba5c1 commit ae25572

File tree

6 files changed

+100
-2
lines changed

6 files changed

+100
-2
lines changed

docs/examples/sqlmesh_cli_crash_course.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -675,6 +675,8 @@ This is a great way to catch SQL issues before wasting runtime in your data ware
675675

676676
```bash
677677
sqlmesh lint
678+
# or apply fixes automatically
679+
sqlmesh lint --fix
678680
```
679681

680682
=== "Tobiko Cloud"

docs/guides/linter.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ Place a rule's code in the project's `linter/` directory. SQLMesh will load all
100100

101101
If the rule is specified in the project's [configuration file](#applying-linting-rules), SQLMesh will run it when:
102102
- A plan is created during `sqlmesh plan`
103-
- The command `sqlmesh lint` is ran
103+
- The command `sqlmesh lint` is ran. Add `--fix` to automatically apply available fixes and fail if errors remain.
104104

105105
SQLMesh will error if a model violates the rule, informing you which model(s) violated the rule. In this example, `full_model.sql` violated the `NoMissingOwner` rule, essentially halting execution:
106106

docs/reference/cli.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,7 @@ Usage: sqlmesh lint [OPTIONS]
636636
637637
Options:
638638
--model TEXT A model to lint. Multiple models can be linted. If no models are specified, every model will be linted.
639+
--fix Apply fixes for lint errors. Fails if errors remain after fixes are applied.
639640
--help Show this message and exit.
640641
641642
```

sqlmesh/cli/main.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1201,15 +1201,21 @@ def environments(obj: Context) -> None:
12011201
multiple=True,
12021202
help="A model to lint. Multiple models can be linted. If no models are specified, every model will be linted.",
12031203
)
1204+
@click.option(
1205+
"--fix",
1206+
is_flag=True,
1207+
help="Apply fixes for lint errors. Fails if errors remain after fixes are applied.",
1208+
)
12041209
@click.pass_obj
12051210
@error_handler
12061211
@cli_analytics
12071212
def lint(
12081213
obj: Context,
12091214
models: t.Iterator[str],
1215+
fix: bool,
12101216
) -> None:
12111217
"""Run the linter for the target model(s)."""
1212-
obj.lint_models(models)
1218+
obj.lint_models(models, fix=fix)
12131219

12141220

12151221
@cli.group(no_args_is_help=True)

sqlmesh/core/context.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
from sqlmesh.core.environment import Environment, EnvironmentNamingInfo, EnvironmentStatements
7979
from sqlmesh.core.loader import Loader
8080
from sqlmesh.core.linter.definition import AnnotatedRuleViolation, Linter
81+
from sqlmesh.core.linter.rule import TextEdit, Position
8182
from sqlmesh.core.linter.rules import BUILTIN_RULES
8283
from sqlmesh.core.macros import ExecutableOrMacro, macro
8384
from sqlmesh.core.metric import Metric, rewrite
@@ -3099,6 +3100,7 @@ def lint_models(
30993100
self,
31003101
models: t.Optional[t.Iterable[t.Union[str, Model]]] = None,
31013102
raise_on_error: bool = True,
3103+
fix: bool = False,
31023104
) -> t.List[AnnotatedRuleViolation]:
31033105
found_error = False
31043106

@@ -3116,13 +3118,45 @@ def lint_models(
31163118
found_error = True
31173119
all_violations.extend(violations)
31183120

3121+
if fix:
3122+
self._apply_fixes(all_violations)
3123+
self.refresh()
3124+
return self.lint_models(models, raise_on_error=raise_on_error, fix=False)
3125+
31193126
if raise_on_error and found_error:
31203127
raise LinterError(
31213128
"Linter detected errors in the code. Please fix them before proceeding."
31223129
)
31233130

31243131
return all_violations
31253132

3133+
def _apply_fixes(self, violations: t.List[AnnotatedRuleViolation]) -> None:
3134+
edits_by_file: t.Dict[Path, t.List[TextEdit]] = {}
3135+
for violation in violations:
3136+
for fix in violation.fixes:
3137+
for create in fix.create_files:
3138+
create.path.parent.mkdir(parents=True, exist_ok=True)
3139+
create.path.write_text(create.text, encoding="utf-8")
3140+
for edit in fix.edits:
3141+
edits_by_file.setdefault(edit.path, []).append(edit)
3142+
3143+
for path, edits in edits_by_file.items():
3144+
content = path.read_text(encoding="utf-8")
3145+
lines = content.splitlines(keepends=True)
3146+
3147+
def _offset(pos: Position) -> int:
3148+
return sum(len(lines[i]) for i in range(pos.line)) + pos.character
3149+
3150+
for edit in sorted(
3151+
edits, key=lambda e: (e.range.start.line, e.range.start.character), reverse=True
3152+
):
3153+
start = _offset(edit.range.start)
3154+
end = _offset(edit.range.end)
3155+
content = content[:start] + edit.new_text + content[end:]
3156+
lines = content.splitlines(keepends=True)
3157+
3158+
path.write_text(content, encoding="utf-8")
3159+
31263160
def load_model_tests(
31273161
self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None
31283162
) -> t.List[ModelTestMetadata]:

tests/cli/test_cli.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1328,6 +1328,61 @@ def test_lint(runner, tmp_path):
13281328
assert result.exit_code == 1
13291329

13301330

1331+
def test_lint_fix(runner, tmp_path):
1332+
create_example_project(tmp_path)
1333+
1334+
with open(tmp_path / "config.yaml", "a", encoding="utf-8") as f:
1335+
f.write(
1336+
"""linter:
1337+
enabled: True
1338+
rules: ["noselectstar"]
1339+
"""
1340+
)
1341+
1342+
model_path = tmp_path / "models" / "incremental_model.sql"
1343+
with open(model_path, "r", encoding="utf-8") as f:
1344+
content = f.read()
1345+
content = content.replace(
1346+
"SELECT\n id,\n item_id,\n event_date,\nFROM",
1347+
"SELECT *\nFROM",
1348+
)
1349+
with open(model_path, "w", encoding="utf-8") as f:
1350+
f.write(content)
1351+
1352+
result = runner.invoke(cli, ["--paths", tmp_path, "lint", "--fix"])
1353+
assert result.exit_code == 0
1354+
with open(model_path, "r", encoding="utf-8") as f:
1355+
assert "SELECT *" not in f.read()
1356+
1357+
1358+
def test_lint_fix_unfixable_error(runner, tmp_path):
1359+
create_example_project(tmp_path)
1360+
1361+
with open(tmp_path / "config.yaml", "a", encoding="utf-8") as f:
1362+
f.write(
1363+
"""linter:
1364+
enabled: True
1365+
rules: ["noselectstar", "nomissingaudits"]
1366+
"""
1367+
)
1368+
1369+
model_path = tmp_path / "models" / "incremental_model.sql"
1370+
with open(model_path, "r", encoding="utf-8") as f:
1371+
content = f.read()
1372+
content = content.replace(
1373+
"SELECT\n id,\n item_id,\n event_date,\nFROM",
1374+
"SELECT *\nFROM",
1375+
)
1376+
with open(model_path, "w", encoding="utf-8") as f:
1377+
f.write(content)
1378+
1379+
result = runner.invoke(cli, ["--paths", tmp_path, "lint", "--fix"])
1380+
assert result.exit_code == 1
1381+
assert "nomissingaudits" in result.output
1382+
with open(model_path, "r", encoding="utf-8") as f:
1383+
assert "SELECT *" not in f.read()
1384+
1385+
13311386
def test_state_export(runner: CliRunner, tmp_path: Path) -> None:
13321387
create_example_project(tmp_path)
13331388

0 commit comments

Comments
 (0)