diff --git a/sqlmesh/core/linter/rule.py b/sqlmesh/core/linter/rule.py index ec942928e7..8dd1a2ebbd 100644 --- a/sqlmesh/core/linter/rule.py +++ b/sqlmesh/core/linter/rule.py @@ -1,7 +1,7 @@ from __future__ import annotations import abc -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path from sqlmesh.core.model import Model @@ -49,12 +49,21 @@ class TextEdit: new_text: str +@dataclass(frozen=True) +class CreateFile: + """Create a new file with the provided text.""" + + path: Path + text: str + + @dataclass(frozen=True) class Fix: """A fix that can be applied to resolve a rule violation.""" title: str - edits: t.List[TextEdit] + edits: t.List[TextEdit] = field(default_factory=list) + create_files: t.List[CreateFile] = field(default_factory=list) class _Rule(abc.ABCMeta): diff --git a/sqlmesh/core/linter/rules/builtin.py b/sqlmesh/core/linter/rules/builtin.py index 1a96a4fcec..a793f79434 100644 --- a/sqlmesh/core/linter/rules/builtin.py +++ b/sqlmesh/core/linter/rules/builtin.py @@ -14,7 +14,15 @@ get_range_of_model_block, read_range_from_string, ) -from sqlmesh.core.linter.rule import Rule, RuleViolation, Range, Fix, TextEdit, Position +from sqlmesh.core.linter.rule import ( + Rule, + RuleViolation, + Range, + Fix, + TextEdit, + Position, + CreateFile, +) from sqlmesh.core.linter.definition import RuleSet from sqlmesh.core.model import Model, SqlModel, ExternalModel from sqlmesh.utils.lineage import extract_references_from_query, ExternalModelReference @@ -227,7 +235,16 @@ def create_fix(self, model_name: str) -> t.Optional[Fix]: external_models_path = root / EXTERNAL_MODELS_YAML if not external_models_path.exists(): - return None + return Fix( + title="Add external model file", + edits=[], + create_files=[ + CreateFile( + path=external_models_path, + text=f"- name: '{model_name}'\n", + ) + ], + ) # Figure out the position to insert the new external model at the end of the file, whether # needs new line or not. diff --git a/sqlmesh/lsp/context.py b/sqlmesh/lsp/context.py index 52b33453b2..50265ec306 100644 --- a/sqlmesh/lsp/context.py +++ b/sqlmesh/lsp/context.py @@ -273,13 +273,41 @@ def get_code_actions( if found_violation is not None and found_violation.fixes: # Create code actions for each fix for fix in found_violation.fixes: - # Convert our Fix to LSP TextEdits changes: t.Dict[str, t.List[types.TextEdit]] = {} + document_changes: t.List[ + t.Union[ + types.TextDocumentEdit, + types.CreateFile, + types.RenameFile, + types.DeleteFile, + ] + ] = [] + + for create in fix.create_files: + create_uri = URI.from_path(create.path).value + document_changes.append(types.CreateFile(uri=create_uri)) + document_changes.append( + types.TextDocumentEdit( + text_document=types.OptionalVersionedTextDocumentIdentifier( + uri=create_uri, + version=None, + ), + edits=[ + types.TextEdit( + range=types.Range( + start=types.Position(line=0, character=0), + end=types.Position(line=0, character=0), + ), + new_text=create.text, + ) + ], + ) + ) + for edit in fix.edits: uri_key = URI.from_path(edit.path).value if uri_key not in changes: changes[uri_key] = [] - # Create a TextEdit for the LSP changes[uri_key].append( types.TextEdit( range=types.Range( @@ -296,12 +324,15 @@ def get_code_actions( ) ) - # Create the code action + workspace_edit = types.WorkspaceEdit( + changes=changes if changes else None, + document_changes=document_changes if document_changes else None, + ) code_action = types.CodeAction( title=fix.title, kind=types.CodeActionKind.QuickFix, diagnostics=[diagnostic], - edit=types.WorkspaceEdit(changes=changes), + edit=workspace_edit, ) code_actions.append(code_action) diff --git a/tests/core/linter/test_builtin.py b/tests/core/linter/test_builtin.py index a5a73fcf87..1a19d036b5 100644 --- a/tests/core/linter/test_builtin.py +++ b/tests/core/linter/test_builtin.py @@ -51,7 +51,13 @@ def test_no_missing_external_models(tmp_path, copy_to_temp_path) -> None: lint.violation_msg == """Model '"memory"."sushi"."customers"' depends on unregistered external model '"memory"."raw"."demographics"'. Please register it in the external models file. This can be done by running 'sqlmesh create_external_models'.""" ) - assert len(lint.fixes) == 0 + assert len(lint.fixes) == 1 + fix = lint.fixes[0] + assert len(fix.edits) == 0 + assert len(fix.create_files) == 1 + create = fix.create_files[0] + assert create.path == sushi_path / "external_models.yaml" + assert create.text == '- name: \'"memory"."raw"."demographics"\'\n' def test_no_missing_external_models_with_existing_file_ending_in_newline( diff --git a/tests/lsp/test_code_actions.py b/tests/lsp/test_code_actions.py index b2f30feb47..509f49f9b1 100644 --- a/tests/lsp/test_code_actions.py +++ b/tests/lsp/test_code_actions.py @@ -1,4 +1,5 @@ import typing as t +import os from lsprotocol import types from sqlmesh.core.context import Context from sqlmesh.lsp.context import LSPContext @@ -109,3 +110,73 @@ def test_code_actions_with_linting(copy_to_temp_path: t.Callable): URI.from_path(sushi_path / "models" / "latest_order.sql").value ] assert len(text_edits) > 0 + + +def test_code_actions_create_file(copy_to_temp_path: t.Callable) -> None: + sushi_paths = copy_to_temp_path("examples/sushi") + sushi_path = sushi_paths[0] + + # Remove external models file and enable linter + os.remove(sushi_path / "external_models.yaml") + config_path = sushi_path / "config.py" + with config_path.open("r") as f: + content = f.read() + + before = """ linter=LinterConfig( + enabled=False, + rules=[ + "ambiguousorinvalidcolumn", + "invalidselectstarexpansion", + "noselectstar", + "nomissingaudits", + "nomissingowner", + "nomissingexternalmodels", + ], + ),""" + after = """linter=LinterConfig(enabled=True, rules=["nomissingexternalmodels"]),""" + content = content.replace(before, after) + with config_path.open("w") as f: + f.write(content) + + context = Context(paths=[str(sushi_path)]) + lsp_context = LSPContext(context) + + uri = URI.from_path(sushi_path / "models" / "customers.sql") + violations = lsp_context.lint_model(uri) + + diagnostics = [] + for violation in violations: + if violation.violation_range: + diagnostics.append( + types.Diagnostic( + range=types.Range( + start=types.Position( + line=violation.violation_range.start.line, + character=violation.violation_range.start.character, + ), + end=types.Position( + line=violation.violation_range.end.line, + character=violation.violation_range.end.character, + ), + ), + message=violation.violation_msg, + severity=types.DiagnosticSeverity.Warning, + ) + ) + + params = types.CodeActionParams( + text_document=types.TextDocumentIdentifier(uri=uri.value), + range=types.Range( + start=types.Position(line=0, character=0), end=types.Position(line=1, character=0) + ), + context=types.CodeActionContext(diagnostics=diagnostics), + ) + + actions = lsp_context.get_code_actions(uri, params) + assert actions is not None and len(actions) > 0 + action = next(a for a in actions if isinstance(a, types.CodeAction)) + assert action.edit is not None + assert action.edit.document_changes is not None + create_file = [c for c in action.edit.document_changes if isinstance(c, types.CreateFile)] + assert create_file, "Expected a CreateFile operation" + assert create_file[0].uri == URI.from_path(sushi_path / "external_models.yaml").value