From 66dfd2943b02b334b62b19a77bdbf75525b83707 Mon Sep 17 00:00:00 2001 From: Ben <9087625+benfdking@users.noreply.github.com> Date: Thu, 31 Jul 2025 11:01:37 +0100 Subject: [PATCH 1/2] feat(linter): allow fixes to create files --- sqlmesh/core/linter/rule.py | 11 ++++- sqlmesh/core/linter/rules/builtin.py | 21 ++++++++- sqlmesh/lsp/context.py | 37 +++++++++++++-- tests/core/linter/test_builtin.py | 8 +++- tests/lsp/test_code_actions.py | 69 ++++++++++++++++++++++++++++ 5 files changed, 138 insertions(+), 8 deletions(-) diff --git a/sqlmesh/core/linter/rule.py b/sqlmesh/core/linter/rule.py index ec942928e7..452ee97347 100644 --- a/sqlmesh/core/linter/rule.py +++ b/sqlmesh/core/linter/rule.py @@ -1,7 +1,7 @@ from __future__ import annotations import abc -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path from sqlmesh.core.model import Model @@ -49,12 +49,21 @@ class TextEdit: new_text: str +@dataclass(frozen=True) +class CreateFile: + """Create a new file with the provided text.""" + + path: Path + text: str + + @dataclass(frozen=True) class Fix: """A fix that can be applied to resolve a rule violation.""" title: str edits: t.List[TextEdit] + create_files: t.List[CreateFile] = field(default_factory=list) class _Rule(abc.ABCMeta): diff --git a/sqlmesh/core/linter/rules/builtin.py b/sqlmesh/core/linter/rules/builtin.py index 1a96a4fcec..a793f79434 100644 --- a/sqlmesh/core/linter/rules/builtin.py +++ b/sqlmesh/core/linter/rules/builtin.py @@ -14,7 +14,15 @@ get_range_of_model_block, read_range_from_string, ) -from sqlmesh.core.linter.rule import Rule, RuleViolation, Range, Fix, TextEdit, Position +from sqlmesh.core.linter.rule import ( + Rule, + RuleViolation, + Range, + Fix, + TextEdit, + Position, + CreateFile, +) from sqlmesh.core.linter.definition import RuleSet from sqlmesh.core.model import Model, SqlModel, ExternalModel from sqlmesh.utils.lineage import extract_references_from_query, ExternalModelReference @@ -227,7 +235,16 @@ def create_fix(self, model_name: str) -> t.Optional[Fix]: external_models_path = root / EXTERNAL_MODELS_YAML if not external_models_path.exists(): - return None + return Fix( + title="Add external model file", + edits=[], + create_files=[ + CreateFile( + path=external_models_path, + text=f"- name: '{model_name}'\n", + ) + ], + ) # Figure out the position to insert the new external model at the end of the file, whether # needs new line or not. diff --git a/sqlmesh/lsp/context.py b/sqlmesh/lsp/context.py index 52b33453b2..5b72a3e871 100644 --- a/sqlmesh/lsp/context.py +++ b/sqlmesh/lsp/context.py @@ -273,13 +273,39 @@ def get_code_actions( if found_violation is not None and found_violation.fixes: # Create code actions for each fix for fix in found_violation.fixes: - # Convert our Fix to LSP TextEdits changes: t.Dict[str, t.List[types.TextEdit]] = {} + document_changes: t.List[ + t.Union[ + types.TextDocumentEdit, + types.CreateFile, + ] + ] = [] + + for create in getattr(fix, "create_files", []): + create_uri = URI.from_path(create.path).value + document_changes.append(types.CreateFile(uri=create_uri)) + document_changes.append( + types.TextDocumentEdit( + text_document=types.OptionalVersionedTextDocumentIdentifier( + uri=create_uri, + version=None, + ), + edits=[ + types.TextEdit( + range=types.Range( + start=types.Position(line=0, character=0), + end=types.Position(line=0, character=0), + ), + new_text=create.text, + ) + ], + ) + ) + for edit in fix.edits: uri_key = URI.from_path(edit.path).value if uri_key not in changes: changes[uri_key] = [] - # Create a TextEdit for the LSP changes[uri_key].append( types.TextEdit( range=types.Range( @@ -296,12 +322,15 @@ def get_code_actions( ) ) - # Create the code action + workspace_edit = types.WorkspaceEdit( + changes=changes if changes else None, + document_changes=document_changes if document_changes else None, + ) code_action = types.CodeAction( title=fix.title, kind=types.CodeActionKind.QuickFix, diagnostics=[diagnostic], - edit=types.WorkspaceEdit(changes=changes), + edit=workspace_edit, ) code_actions.append(code_action) diff --git a/tests/core/linter/test_builtin.py b/tests/core/linter/test_builtin.py index a5a73fcf87..1b9562c935 100644 --- a/tests/core/linter/test_builtin.py +++ b/tests/core/linter/test_builtin.py @@ -51,7 +51,13 @@ def test_no_missing_external_models(tmp_path, copy_to_temp_path) -> None: lint.violation_msg == """Model '"memory"."sushi"."customers"' depends on unregistered external model '"memory"."raw"."demographics"'. Please register it in the external models file. This can be done by running 'sqlmesh create_external_models'.""" ) - assert len(lint.fixes) == 0 + assert len(lint.fixes) == 1 + fix = lint.fixes[0] + assert len(fix.edits) == 0 + assert len(fix.create_files) == 1 + create = fix.create_files[0] + assert create.path == sushi_path / "external_models.yaml" + assert create.text == "- name: '\"memory\".\"raw\".\"demographics\"'\n" def test_no_missing_external_models_with_existing_file_ending_in_newline( diff --git a/tests/lsp/test_code_actions.py b/tests/lsp/test_code_actions.py index b2f30feb47..645d9e3a5e 100644 --- a/tests/lsp/test_code_actions.py +++ b/tests/lsp/test_code_actions.py @@ -1,4 +1,5 @@ import typing as t +import os from lsprotocol import types from sqlmesh.core.context import Context from sqlmesh.lsp.context import LSPContext @@ -109,3 +110,71 @@ def test_code_actions_with_linting(copy_to_temp_path: t.Callable): URI.from_path(sushi_path / "models" / "latest_order.sql").value ] assert len(text_edits) > 0 + + +def test_code_actions_create_file(copy_to_temp_path: t.Callable) -> None: + sushi_paths = copy_to_temp_path("examples/sushi") + sushi_path = sushi_paths[0] + + # Remove external models file and enable linter + os.remove(sushi_path / "external_models.yaml") + config_path = sushi_path / "config.py" + with config_path.open("r") as f: + content = f.read() + + before = """ linter=LinterConfig( + enabled=False, + rules=[ + "ambiguousorinvalidcolumn", + "invalidselectstarexpansion", + "noselectstar", + "nomissingaudits", + "nomissingowner", + "nomissingexternalmodels", + ], + ),""" + after = """linter=LinterConfig(enabled=True, rules=["nomissingexternalmodels"]),""" + content = content.replace(before, after) + with config_path.open("w") as f: + f.write(content) + + context = Context(paths=[str(sushi_path)]) + lsp_context = LSPContext(context) + + uri = URI.from_path(sushi_path / "models" / "customers.sql") + violations = lsp_context.lint_model(uri) + + diagnostics = [] + for violation in violations: + if violation.violation_range: + diagnostics.append( + types.Diagnostic( + range=types.Range( + start=types.Position( + line=violation.violation_range.start.line, + character=violation.violation_range.start.character, + ), + end=types.Position( + line=violation.violation_range.end.line, + character=violation.violation_range.end.character, + ), + ), + message=violation.violation_msg, + severity=types.DiagnosticSeverity.Warning, + ) + ) + + params = types.CodeActionParams( + text_document=types.TextDocumentIdentifier(uri=uri.value), + range=types.Range(start=types.Position(line=0, character=0), end=types.Position(line=1, character=0)), + context=types.CodeActionContext(diagnostics=diagnostics), + ) + + actions = lsp_context.get_code_actions(uri, params) + assert actions is not None and len(actions) > 0 + action = next(a for a in actions if isinstance(a, types.CodeAction)) + assert action.edit is not None + assert action.edit.document_changes is not None + create_file = [c for c in action.edit.document_changes if isinstance(c, types.CreateFile)] + assert create_file, "Expected a CreateFile operation" + assert create_file[0].uri == URI.from_path(sushi_path / "external_models.yaml").value From d23c05789e95f3fb60de8bd8790913e33642e315 Mon Sep 17 00:00:00 2001 From: Ben King <9087625+benfdking@users.noreply.github.com> Date: Thu, 31 Jul 2025 11:22:58 +0100 Subject: [PATCH 2/2] feat(lsp): no missing external will create file - before the fix would only apply if there was a file already there - now it also has the quick fix if the external models file is not present --- sqlmesh/core/linter/rule.py | 2 +- sqlmesh/lsp/context.py | 4 +++- tests/core/linter/test_builtin.py | 2 +- tests/lsp/test_code_actions.py | 4 +++- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/sqlmesh/core/linter/rule.py b/sqlmesh/core/linter/rule.py index 452ee97347..8dd1a2ebbd 100644 --- a/sqlmesh/core/linter/rule.py +++ b/sqlmesh/core/linter/rule.py @@ -62,7 +62,7 @@ class Fix: """A fix that can be applied to resolve a rule violation.""" title: str - edits: t.List[TextEdit] + edits: t.List[TextEdit] = field(default_factory=list) create_files: t.List[CreateFile] = field(default_factory=list) diff --git a/sqlmesh/lsp/context.py b/sqlmesh/lsp/context.py index 5b72a3e871..50265ec306 100644 --- a/sqlmesh/lsp/context.py +++ b/sqlmesh/lsp/context.py @@ -278,10 +278,12 @@ def get_code_actions( t.Union[ types.TextDocumentEdit, types.CreateFile, + types.RenameFile, + types.DeleteFile, ] ] = [] - for create in getattr(fix, "create_files", []): + for create in fix.create_files: create_uri = URI.from_path(create.path).value document_changes.append(types.CreateFile(uri=create_uri)) document_changes.append( diff --git a/tests/core/linter/test_builtin.py b/tests/core/linter/test_builtin.py index 1b9562c935..1a19d036b5 100644 --- a/tests/core/linter/test_builtin.py +++ b/tests/core/linter/test_builtin.py @@ -57,7 +57,7 @@ def test_no_missing_external_models(tmp_path, copy_to_temp_path) -> None: assert len(fix.create_files) == 1 create = fix.create_files[0] assert create.path == sushi_path / "external_models.yaml" - assert create.text == "- name: '\"memory\".\"raw\".\"demographics\"'\n" + assert create.text == '- name: \'"memory"."raw"."demographics"\'\n' def test_no_missing_external_models_with_existing_file_ending_in_newline( diff --git a/tests/lsp/test_code_actions.py b/tests/lsp/test_code_actions.py index 645d9e3a5e..509f49f9b1 100644 --- a/tests/lsp/test_code_actions.py +++ b/tests/lsp/test_code_actions.py @@ -166,7 +166,9 @@ def test_code_actions_create_file(copy_to_temp_path: t.Callable) -> None: params = types.CodeActionParams( text_document=types.TextDocumentIdentifier(uri=uri.value), - range=types.Range(start=types.Position(line=0, character=0), end=types.Position(line=1, character=0)), + range=types.Range( + start=types.Position(line=0, character=0), end=types.Position(line=1, character=0) + ), context=types.CodeActionContext(diagnostics=diagnostics), )