Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions sqlmesh/core/linter/rule.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import abc
from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path

from sqlmesh.core.model import Model
Expand Down Expand Up @@ -49,12 +49,21 @@ class TextEdit:
new_text: str


@dataclass(frozen=True)
class CreateFile:
"""Create a new file with the provided text."""

path: Path
text: str


@dataclass(frozen=True)
class Fix:
"""A fix that can be applied to resolve a rule violation."""

title: str
edits: t.List[TextEdit]
edits: t.List[TextEdit] = field(default_factory=list)
create_files: t.List[CreateFile] = field(default_factory=list)


class _Rule(abc.ABCMeta):
Expand Down
21 changes: 19 additions & 2 deletions sqlmesh/core/linter/rules/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,15 @@
get_range_of_model_block,
read_range_from_string,
)
from sqlmesh.core.linter.rule import Rule, RuleViolation, Range, Fix, TextEdit, Position
from sqlmesh.core.linter.rule import (
Rule,
RuleViolation,
Range,
Fix,
TextEdit,
Position,
CreateFile,
)
from sqlmesh.core.linter.definition import RuleSet
from sqlmesh.core.model import Model, SqlModel, ExternalModel
from sqlmesh.utils.lineage import extract_references_from_query, ExternalModelReference
Expand Down Expand Up @@ -227,7 +235,16 @@ def create_fix(self, model_name: str) -> t.Optional[Fix]:

external_models_path = root / EXTERNAL_MODELS_YAML
if not external_models_path.exists():
return None
return Fix(
title="Add external model file",
edits=[],
create_files=[
CreateFile(
path=external_models_path,
text=f"- name: '{model_name}'\n",
)
],
)

# Figure out the position to insert the new external model at the end of the file, whether
# needs new line or not.
Expand Down
39 changes: 35 additions & 4 deletions sqlmesh/lsp/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,13 +273,41 @@ def get_code_actions(
if found_violation is not None and found_violation.fixes:
# Create code actions for each fix
for fix in found_violation.fixes:
# Convert our Fix to LSP TextEdits
changes: t.Dict[str, t.List[types.TextEdit]] = {}
document_changes: t.List[
t.Union[
types.TextDocumentEdit,
types.CreateFile,
types.RenameFile,
types.DeleteFile,
]
] = []

for create in fix.create_files:
create_uri = URI.from_path(create.path).value
document_changes.append(types.CreateFile(uri=create_uri))
document_changes.append(
types.TextDocumentEdit(
text_document=types.OptionalVersionedTextDocumentIdentifier(
uri=create_uri,
version=None,
),
edits=[
types.TextEdit(
range=types.Range(
start=types.Position(line=0, character=0),
end=types.Position(line=0, character=0),
),
new_text=create.text,
)
],
)
)

for edit in fix.edits:
uri_key = URI.from_path(edit.path).value
if uri_key not in changes:
changes[uri_key] = []
# Create a TextEdit for the LSP
changes[uri_key].append(
types.TextEdit(
range=types.Range(
Expand All @@ -296,12 +324,15 @@ def get_code_actions(
)
)

# Create the code action
workspace_edit = types.WorkspaceEdit(
changes=changes if changes else None,
document_changes=document_changes if document_changes else None,
)
code_action = types.CodeAction(
title=fix.title,
kind=types.CodeActionKind.QuickFix,
diagnostics=[diagnostic],
edit=types.WorkspaceEdit(changes=changes),
edit=workspace_edit,
)
code_actions.append(code_action)

Expand Down
8 changes: 7 additions & 1 deletion tests/core/linter/test_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,13 @@ def test_no_missing_external_models(tmp_path, copy_to_temp_path) -> None:
lint.violation_msg
== """Model '"memory"."sushi"."customers"' depends on unregistered external model '"memory"."raw"."demographics"'. Please register it in the external models file. This can be done by running 'sqlmesh create_external_models'."""
)
assert len(lint.fixes) == 0
assert len(lint.fixes) == 1
fix = lint.fixes[0]
assert len(fix.edits) == 0
assert len(fix.create_files) == 1
create = fix.create_files[0]
assert create.path == sushi_path / "external_models.yaml"
assert create.text == '- name: \'"memory"."raw"."demographics"\'\n'


def test_no_missing_external_models_with_existing_file_ending_in_newline(
Expand Down
71 changes: 71 additions & 0 deletions tests/lsp/test_code_actions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import typing as t
import os
from lsprotocol import types
from sqlmesh.core.context import Context
from sqlmesh.lsp.context import LSPContext
Expand Down Expand Up @@ -109,3 +110,73 @@ def test_code_actions_with_linting(copy_to_temp_path: t.Callable):
URI.from_path(sushi_path / "models" / "latest_order.sql").value
]
assert len(text_edits) > 0


def test_code_actions_create_file(copy_to_temp_path: t.Callable) -> None:
sushi_paths = copy_to_temp_path("examples/sushi")
sushi_path = sushi_paths[0]

# Remove external models file and enable linter
os.remove(sushi_path / "external_models.yaml")
config_path = sushi_path / "config.py"
with config_path.open("r") as f:
content = f.read()

before = """ linter=LinterConfig(
enabled=False,
rules=[
"ambiguousorinvalidcolumn",
"invalidselectstarexpansion",
"noselectstar",
"nomissingaudits",
"nomissingowner",
"nomissingexternalmodels",
],
),"""
after = """linter=LinterConfig(enabled=True, rules=["nomissingexternalmodels"]),"""
content = content.replace(before, after)
with config_path.open("w") as f:
f.write(content)

context = Context(paths=[str(sushi_path)])
lsp_context = LSPContext(context)

uri = URI.from_path(sushi_path / "models" / "customers.sql")
violations = lsp_context.lint_model(uri)

diagnostics = []
for violation in violations:
if violation.violation_range:
diagnostics.append(
types.Diagnostic(
range=types.Range(
start=types.Position(
line=violation.violation_range.start.line,
character=violation.violation_range.start.character,
),
end=types.Position(
line=violation.violation_range.end.line,
character=violation.violation_range.end.character,
),
),
message=violation.violation_msg,
severity=types.DiagnosticSeverity.Warning,
)
)

params = types.CodeActionParams(
text_document=types.TextDocumentIdentifier(uri=uri.value),
range=types.Range(
start=types.Position(line=0, character=0), end=types.Position(line=1, character=0)
),
context=types.CodeActionContext(diagnostics=diagnostics),
)

actions = lsp_context.get_code_actions(uri, params)
assert actions is not None and len(actions) > 0
action = next(a for a in actions if isinstance(a, types.CodeAction))
assert action.edit is not None
assert action.edit.document_changes is not None
create_file = [c for c in action.edit.document_changes if isinstance(c, types.CreateFile)]
assert create_file, "Expected a CreateFile operation"
assert create_file[0].uri == URI.from_path(sushi_path / "external_models.yaml").value