Skip to content

Commit 66dfd29

Browse files
committed
feat(linter): allow fixes to create files
1 parent 751c38d commit 66dfd29

File tree

5 files changed

+138
-8
lines changed

5 files changed

+138
-8
lines changed

sqlmesh/core/linter/rule.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import abc
4-
from dataclasses import dataclass
4+
from dataclasses import dataclass, field
55
from pathlib import Path
66

77
from sqlmesh.core.model import Model
@@ -49,12 +49,21 @@ class TextEdit:
4949
new_text: str
5050

5151

52+
@dataclass(frozen=True)
53+
class CreateFile:
54+
"""Create a new file with the provided text."""
55+
56+
path: Path
57+
text: str
58+
59+
5260
@dataclass(frozen=True)
5361
class Fix:
5462
"""A fix that can be applied to resolve a rule violation."""
5563

5664
title: str
5765
edits: t.List[TextEdit]
66+
create_files: t.List[CreateFile] = field(default_factory=list)
5867

5968

6069
class _Rule(abc.ABCMeta):

sqlmesh/core/linter/rules/builtin.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,15 @@
1414
get_range_of_model_block,
1515
read_range_from_string,
1616
)
17-
from sqlmesh.core.linter.rule import Rule, RuleViolation, Range, Fix, TextEdit, Position
17+
from sqlmesh.core.linter.rule import (
18+
Rule,
19+
RuleViolation,
20+
Range,
21+
Fix,
22+
TextEdit,
23+
Position,
24+
CreateFile,
25+
)
1826
from sqlmesh.core.linter.definition import RuleSet
1927
from sqlmesh.core.model import Model, SqlModel, ExternalModel
2028
from sqlmesh.utils.lineage import extract_references_from_query, ExternalModelReference
@@ -227,7 +235,16 @@ def create_fix(self, model_name: str) -> t.Optional[Fix]:
227235

228236
external_models_path = root / EXTERNAL_MODELS_YAML
229237
if not external_models_path.exists():
230-
return None
238+
return Fix(
239+
title="Add external model file",
240+
edits=[],
241+
create_files=[
242+
CreateFile(
243+
path=external_models_path,
244+
text=f"- name: '{model_name}'\n",
245+
)
246+
],
247+
)
231248

232249
# Figure out the position to insert the new external model at the end of the file, whether
233250
# needs new line or not.

sqlmesh/lsp/context.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -273,13 +273,39 @@ def get_code_actions(
273273
if found_violation is not None and found_violation.fixes:
274274
# Create code actions for each fix
275275
for fix in found_violation.fixes:
276-
# Convert our Fix to LSP TextEdits
277276
changes: t.Dict[str, t.List[types.TextEdit]] = {}
277+
document_changes: t.List[
278+
t.Union[
279+
types.TextDocumentEdit,
280+
types.CreateFile,
281+
]
282+
] = []
283+
284+
for create in getattr(fix, "create_files", []):
285+
create_uri = URI.from_path(create.path).value
286+
document_changes.append(types.CreateFile(uri=create_uri))
287+
document_changes.append(
288+
types.TextDocumentEdit(
289+
text_document=types.OptionalVersionedTextDocumentIdentifier(
290+
uri=create_uri,
291+
version=None,
292+
),
293+
edits=[
294+
types.TextEdit(
295+
range=types.Range(
296+
start=types.Position(line=0, character=0),
297+
end=types.Position(line=0, character=0),
298+
),
299+
new_text=create.text,
300+
)
301+
],
302+
)
303+
)
304+
278305
for edit in fix.edits:
279306
uri_key = URI.from_path(edit.path).value
280307
if uri_key not in changes:
281308
changes[uri_key] = []
282-
# Create a TextEdit for the LSP
283309
changes[uri_key].append(
284310
types.TextEdit(
285311
range=types.Range(
@@ -296,12 +322,15 @@ def get_code_actions(
296322
)
297323
)
298324

299-
# Create the code action
325+
workspace_edit = types.WorkspaceEdit(
326+
changes=changes if changes else None,
327+
document_changes=document_changes if document_changes else None,
328+
)
300329
code_action = types.CodeAction(
301330
title=fix.title,
302331
kind=types.CodeActionKind.QuickFix,
303332
diagnostics=[diagnostic],
304-
edit=types.WorkspaceEdit(changes=changes),
333+
edit=workspace_edit,
305334
)
306335
code_actions.append(code_action)
307336

tests/core/linter/test_builtin.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,13 @@ def test_no_missing_external_models(tmp_path, copy_to_temp_path) -> None:
5151
lint.violation_msg
5252
== """Model '"memory"."sushi"."customers"' depends on unregistered external model '"memory"."raw"."demographics"'. Please register it in the external models file. This can be done by running 'sqlmesh create_external_models'."""
5353
)
54-
assert len(lint.fixes) == 0
54+
assert len(lint.fixes) == 1
55+
fix = lint.fixes[0]
56+
assert len(fix.edits) == 0
57+
assert len(fix.create_files) == 1
58+
create = fix.create_files[0]
59+
assert create.path == sushi_path / "external_models.yaml"
60+
assert create.text == "- name: '\"memory\".\"raw\".\"demographics\"'\n"
5561

5662

5763
def test_no_missing_external_models_with_existing_file_ending_in_newline(

tests/lsp/test_code_actions.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import typing as t
2+
import os
23
from lsprotocol import types
34
from sqlmesh.core.context import Context
45
from sqlmesh.lsp.context import LSPContext
@@ -109,3 +110,71 @@ def test_code_actions_with_linting(copy_to_temp_path: t.Callable):
109110
URI.from_path(sushi_path / "models" / "latest_order.sql").value
110111
]
111112
assert len(text_edits) > 0
113+
114+
115+
def test_code_actions_create_file(copy_to_temp_path: t.Callable) -> None:
116+
sushi_paths = copy_to_temp_path("examples/sushi")
117+
sushi_path = sushi_paths[0]
118+
119+
# Remove external models file and enable linter
120+
os.remove(sushi_path / "external_models.yaml")
121+
config_path = sushi_path / "config.py"
122+
with config_path.open("r") as f:
123+
content = f.read()
124+
125+
before = """ linter=LinterConfig(
126+
enabled=False,
127+
rules=[
128+
"ambiguousorinvalidcolumn",
129+
"invalidselectstarexpansion",
130+
"noselectstar",
131+
"nomissingaudits",
132+
"nomissingowner",
133+
"nomissingexternalmodels",
134+
],
135+
),"""
136+
after = """linter=LinterConfig(enabled=True, rules=["nomissingexternalmodels"]),"""
137+
content = content.replace(before, after)
138+
with config_path.open("w") as f:
139+
f.write(content)
140+
141+
context = Context(paths=[str(sushi_path)])
142+
lsp_context = LSPContext(context)
143+
144+
uri = URI.from_path(sushi_path / "models" / "customers.sql")
145+
violations = lsp_context.lint_model(uri)
146+
147+
diagnostics = []
148+
for violation in violations:
149+
if violation.violation_range:
150+
diagnostics.append(
151+
types.Diagnostic(
152+
range=types.Range(
153+
start=types.Position(
154+
line=violation.violation_range.start.line,
155+
character=violation.violation_range.start.character,
156+
),
157+
end=types.Position(
158+
line=violation.violation_range.end.line,
159+
character=violation.violation_range.end.character,
160+
),
161+
),
162+
message=violation.violation_msg,
163+
severity=types.DiagnosticSeverity.Warning,
164+
)
165+
)
166+
167+
params = types.CodeActionParams(
168+
text_document=types.TextDocumentIdentifier(uri=uri.value),
169+
range=types.Range(start=types.Position(line=0, character=0), end=types.Position(line=1, character=0)),
170+
context=types.CodeActionContext(diagnostics=diagnostics),
171+
)
172+
173+
actions = lsp_context.get_code_actions(uri, params)
174+
assert actions is not None and len(actions) > 0
175+
action = next(a for a in actions if isinstance(a, types.CodeAction))
176+
assert action.edit is not None
177+
assert action.edit.document_changes is not None
178+
create_file = [c for c in action.edit.document_changes if isinstance(c, types.CreateFile)]
179+
assert create_file, "Expected a CreateFile operation"
180+
assert create_file[0].uri == URI.from_path(sushi_path / "external_models.yaml").value

0 commit comments

Comments
 (0)