Skip to content

Commit b448d1c

Browse files
authored
feat: lint rule no missing external will create file (#5078)
1 parent 6724d96 commit b448d1c

File tree

5 files changed

+143
-9
lines changed

5 files changed

+143
-9
lines changed

sqlmesh/core/linter/rule.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import abc
4-
from dataclasses import dataclass
4+
from dataclasses import dataclass, field
55
from pathlib import Path
66

77
from sqlmesh.core.model import Model
@@ -49,12 +49,21 @@ class TextEdit:
4949
new_text: str
5050

5151

52+
@dataclass(frozen=True)
53+
class CreateFile:
54+
"""Create a new file with the provided text."""
55+
56+
path: Path
57+
text: str
58+
59+
5260
@dataclass(frozen=True)
5361
class Fix:
5462
"""A fix that can be applied to resolve a rule violation."""
5563

5664
title: str
57-
edits: t.List[TextEdit]
65+
edits: t.List[TextEdit] = field(default_factory=list)
66+
create_files: t.List[CreateFile] = field(default_factory=list)
5867

5968

6069
class _Rule(abc.ABCMeta):

sqlmesh/core/linter/rules/builtin.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,15 @@
1414
get_range_of_model_block,
1515
read_range_from_string,
1616
)
17-
from sqlmesh.core.linter.rule import Rule, RuleViolation, Range, Fix, TextEdit, Position
17+
from sqlmesh.core.linter.rule import (
18+
Rule,
19+
RuleViolation,
20+
Range,
21+
Fix,
22+
TextEdit,
23+
Position,
24+
CreateFile,
25+
)
1826
from sqlmesh.core.linter.definition import RuleSet
1927
from sqlmesh.core.model import Model, SqlModel, ExternalModel
2028
from sqlmesh.utils.lineage import extract_references_from_query, ExternalModelReference
@@ -227,7 +235,16 @@ def create_fix(self, model_name: str) -> t.Optional[Fix]:
227235

228236
external_models_path = root / EXTERNAL_MODELS_YAML
229237
if not external_models_path.exists():
230-
return None
238+
return Fix(
239+
title="Add external model file",
240+
edits=[],
241+
create_files=[
242+
CreateFile(
243+
path=external_models_path,
244+
text=f"- name: '{model_name}'\n",
245+
)
246+
],
247+
)
231248

232249
# Figure out the position to insert the new external model at the end of the file, whether
233250
# needs new line or not.

sqlmesh/lsp/context.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -273,13 +273,41 @@ def get_code_actions(
273273
if found_violation is not None and found_violation.fixes:
274274
# Create code actions for each fix
275275
for fix in found_violation.fixes:
276-
# Convert our Fix to LSP TextEdits
277276
changes: t.Dict[str, t.List[types.TextEdit]] = {}
277+
document_changes: t.List[
278+
t.Union[
279+
types.TextDocumentEdit,
280+
types.CreateFile,
281+
types.RenameFile,
282+
types.DeleteFile,
283+
]
284+
] = []
285+
286+
for create in fix.create_files:
287+
create_uri = URI.from_path(create.path).value
288+
document_changes.append(types.CreateFile(uri=create_uri))
289+
document_changes.append(
290+
types.TextDocumentEdit(
291+
text_document=types.OptionalVersionedTextDocumentIdentifier(
292+
uri=create_uri,
293+
version=None,
294+
),
295+
edits=[
296+
types.TextEdit(
297+
range=types.Range(
298+
start=types.Position(line=0, character=0),
299+
end=types.Position(line=0, character=0),
300+
),
301+
new_text=create.text,
302+
)
303+
],
304+
)
305+
)
306+
278307
for edit in fix.edits:
279308
uri_key = URI.from_path(edit.path).value
280309
if uri_key not in changes:
281310
changes[uri_key] = []
282-
# Create a TextEdit for the LSP
283311
changes[uri_key].append(
284312
types.TextEdit(
285313
range=types.Range(
@@ -296,12 +324,15 @@ def get_code_actions(
296324
)
297325
)
298326

299-
# Create the code action
327+
workspace_edit = types.WorkspaceEdit(
328+
changes=changes if changes else None,
329+
document_changes=document_changes if document_changes else None,
330+
)
300331
code_action = types.CodeAction(
301332
title=fix.title,
302333
kind=types.CodeActionKind.QuickFix,
303334
diagnostics=[diagnostic],
304-
edit=types.WorkspaceEdit(changes=changes),
335+
edit=workspace_edit,
305336
)
306337
code_actions.append(code_action)
307338

tests/core/linter/test_builtin.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,13 @@ def test_no_missing_external_models(tmp_path, copy_to_temp_path) -> None:
5151
lint.violation_msg
5252
== """Model '"memory"."sushi"."customers"' depends on unregistered external model '"memory"."raw"."demographics"'. Please register it in the external models file. This can be done by running 'sqlmesh create_external_models'."""
5353
)
54-
assert len(lint.fixes) == 0
54+
assert len(lint.fixes) == 1
55+
fix = lint.fixes[0]
56+
assert len(fix.edits) == 0
57+
assert len(fix.create_files) == 1
58+
create = fix.create_files[0]
59+
assert create.path == sushi_path / "external_models.yaml"
60+
assert create.text == '- name: \'"memory"."raw"."demographics"\'\n'
5561

5662

5763
def test_no_missing_external_models_with_existing_file_ending_in_newline(

tests/lsp/test_code_actions.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import typing as t
2+
import os
23
from lsprotocol import types
34
from sqlmesh.core.context import Context
45
from sqlmesh.lsp.context import LSPContext
@@ -109,3 +110,73 @@ def test_code_actions_with_linting(copy_to_temp_path: t.Callable):
109110
URI.from_path(sushi_path / "models" / "latest_order.sql").value
110111
]
111112
assert len(text_edits) > 0
113+
114+
115+
def test_code_actions_create_file(copy_to_temp_path: t.Callable) -> None:
116+
sushi_paths = copy_to_temp_path("examples/sushi")
117+
sushi_path = sushi_paths[0]
118+
119+
# Remove external models file and enable linter
120+
os.remove(sushi_path / "external_models.yaml")
121+
config_path = sushi_path / "config.py"
122+
with config_path.open("r") as f:
123+
content = f.read()
124+
125+
before = """ linter=LinterConfig(
126+
enabled=False,
127+
rules=[
128+
"ambiguousorinvalidcolumn",
129+
"invalidselectstarexpansion",
130+
"noselectstar",
131+
"nomissingaudits",
132+
"nomissingowner",
133+
"nomissingexternalmodels",
134+
],
135+
),"""
136+
after = """linter=LinterConfig(enabled=True, rules=["nomissingexternalmodels"]),"""
137+
content = content.replace(before, after)
138+
with config_path.open("w") as f:
139+
f.write(content)
140+
141+
context = Context(paths=[str(sushi_path)])
142+
lsp_context = LSPContext(context)
143+
144+
uri = URI.from_path(sushi_path / "models" / "customers.sql")
145+
violations = lsp_context.lint_model(uri)
146+
147+
diagnostics = []
148+
for violation in violations:
149+
if violation.violation_range:
150+
diagnostics.append(
151+
types.Diagnostic(
152+
range=types.Range(
153+
start=types.Position(
154+
line=violation.violation_range.start.line,
155+
character=violation.violation_range.start.character,
156+
),
157+
end=types.Position(
158+
line=violation.violation_range.end.line,
159+
character=violation.violation_range.end.character,
160+
),
161+
),
162+
message=violation.violation_msg,
163+
severity=types.DiagnosticSeverity.Warning,
164+
)
165+
)
166+
167+
params = types.CodeActionParams(
168+
text_document=types.TextDocumentIdentifier(uri=uri.value),
169+
range=types.Range(
170+
start=types.Position(line=0, character=0), end=types.Position(line=1, character=0)
171+
),
172+
context=types.CodeActionContext(diagnostics=diagnostics),
173+
)
174+
175+
actions = lsp_context.get_code_actions(uri, params)
176+
assert actions is not None and len(actions) > 0
177+
action = next(a for a in actions if isinstance(a, types.CodeAction))
178+
assert action.edit is not None
179+
assert action.edit.document_changes is not None
180+
create_file = [c for c in action.edit.document_changes if isinstance(c, types.CreateFile)]
181+
assert create_file, "Expected a CreateFile operation"
182+
assert create_file[0].uri == URI.from_path(sushi_path / "external_models.yaml").value

0 commit comments

Comments
 (0)