Skip to content

Commit a0348cd

Browse files
committed
feat(vscode): adding ability to update columns
[ci skip]
1 parent 95b1f6e commit a0348cd

File tree

5 files changed

+211
-8
lines changed

5 files changed

+211
-8
lines changed

requirements.md

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
Ability to click update columns on a model in external models
2+
3+
# Requirements
4+
5+
* [ ] For each model in external models file
6+
* [ ] Below each model in the yaml file, provide an LSP code action to update columns
7+
* [ ] The code action should be able to update the columns in the model
8+
9+
# Useful files
10+
11+
sqlmesh/lsp/main.py is the main file for the LSP server.
12+
sqlmesh/cli/main.py and create_external_model provides similar functionality for creating external models but for the CLI and only does the whole shebang.
13+
14+
# Implementation Plan
15+
16+
1. **Create a new code action handler for external models YAML files**
17+
- Add support for YAML files in the LSP context's `get_code_actions` method
18+
- Detect when cursor is on or near a model definition in external_models.yaml
19+
20+
2. **Add YAML file parsing to identify model positions**
21+
- Parse the YAML file to identify model definitions and their line positions
22+
- Track the name field and columns section for each model
23+
24+
3. **Implement the "Update Columns" code action**
25+
- When triggered, the action will:
26+
- Extract the model name from the YAML
27+
- Use the engine adapter to fetch current columns from the database
28+
- Generate the updated columns section in YAML format
29+
- Replace the existing columns section with the new one
30+
31+
4. **Integration points**
32+
- Extend `LSPContext.get_code_actions()` to handle YAML files
33+
- Add YAML parsing logic to identify model boundaries
34+
- Reuse existing column fetching logic from `schema_loader.py`
35+
- Use the adapter's `columns()` method to get schema information
36+
37+
5. **Code action workflow**
38+
- User opens external_models.yaml in editor
39+
- LSP server identifies model definitions in the file
40+
- When cursor is on a model, show "Update columns" code action
41+
- On activation, fetch latest schema and update the YAML
42+

sqlmesh/lsp/commands.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from sqlmesh.utils.pydantic import PydanticModel
2+
3+
EXTERNAL_MODEL_UPDATE_COLUMNS = 'sqlmesh.external_model_update_columns'
4+
5+
class ExternalModelUpdateColumnsRequest(PydanticModel):
6+
model_name: str

sqlmesh/lsp/context.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
from sqlmesh.core.context import Context
44
import typing as t
55

6-
from sqlmesh.core.model.definition import SqlModel
6+
from sqlmesh.core.model.definition import SqlModel, ExternalModel
77
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
8+
from sqlmesh.lsp.commands import EXTERNAL_MODEL_UPDATE_COLUMNS
89
from sqlmesh.lsp.custom import ModelForRendering
910
from sqlmesh.lsp.custom import AllModelsResponse, RenderModelEntry
1011
from sqlmesh.lsp.uri import URI
1112
from lsprotocol import types
13+
from pathlib import Path
1214

1315

1416
@dataclass
@@ -228,6 +230,40 @@ def get_code_actions(
228230
edit=types.WorkspaceEdit(changes={params.text_document.uri: text_edits}),
229231
)
230232
code_actions.append(code_action)
233+
return code_actions
234+
235+
def get_code_lenses(self, uri: URI ) -> t.Optional[t.List[types.CodeLens]]:
236+
models_in_file = self.map.get(uri.to_path())
237+
if isinstance(models_in_file, ModelTarget):
238+
models = [self.context.get_model(model) for model in models_in_file.names]
239+
if any(isinstance(model, ExternalModel) for model in models):
240+
code_lenses = self._get_external_model_code_lenses(uri)
241+
if code_lenses:
242+
return code_lenses
243+
244+
return None
245+
246+
def _get_external_model_code_lenses(self, uri: URI) -> t.List[types.CodeLens]:
247+
from sqlmesh.lsp.reference import get_yaml_model_name_ranges
248+
249+
"""Get code actions for external models YAML files."""
250+
ranges = get_yaml_model_name_ranges(uri.to_path())
251+
252+
code_actions: t.List[types.CodeLens] = []
253+
for name, range in ranges.items():
254+
# Create a code action to update columns for external models
255+
command = types.Command(
256+
title="Update Columns",
257+
command=EXTERNAL_MODEL_UPDATE_COLUMNS,
258+
arguments=[
259+
name,
260+
],
261+
)
262+
code_action = types.CodeLens(
263+
range=range,
264+
command=command,
265+
)
266+
code_actions.append(code_action)
231267

232268
return code_actions if code_actions else None
233269

sqlmesh/lsp/main.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
ApiResponseGetLineage,
2424
ApiResponseGetModels,
2525
)
26+
from sqlmesh.lsp.commands import ExternalModelUpdateColumnsRequest, EXTERNAL_MODEL_UPDATE_COLUMNS
2627
from sqlmesh.lsp.completions import get_sql_completions
2728
from sqlmesh.lsp.context import (
2829
LSPContext,
@@ -63,6 +64,7 @@
6364
from web.server.api.endpoints.models import get_models
6465
from typing import Union
6566
from dataclasses import dataclass, field
67+
from sqlmesh.utils import yaml
6668

6769

6870
@dataclass
@@ -313,6 +315,25 @@ def function_call(ls: LanguageServer, params: t.Any) -> t.Dict[str, t.Any]:
313315

314316
self.server.feature(name)(create_function_call(method))
315317

318+
@self.server.command(EXTERNAL_MODEL_UPDATE_COLUMNS)
319+
def execute_command(ls: LanguageServer, raw: t.Any) -> None:
320+
try:
321+
if not isinstance(raw, list) or not len(list) == 1 or not isinstance(raw[0], str):
322+
raise ValueError("Invalid command parameters")
323+
324+
request = ExternalModelUpdateColumnsRequest.model_validate(raw[0])
325+
context = self._context_get_or_load()
326+
if not isinstance(context, LSPContext):
327+
raise ValueError("Context is not loaded or invalid")
328+
context.context.update_external_model_columns(
329+
model_name=request.model_name,
330+
)
331+
332+
ls.show_message(f"Executing command to update external model columns {raw}", types.MessageType.Info)
333+
except Exception as e:
334+
ls.show_message(f"Error executing command: {e}", types.MessageType.Error)
335+
return None
336+
316337
@self.server.feature(types.INITIALIZE)
317338
def initialize(ls: LanguageServer, params: types.InitializeParams) -> None:
318339
"""Initialize the server when the client connects."""
@@ -688,6 +709,17 @@ def code_action(
688709
ls.log_trace(f"Error getting code actions: {e}")
689710
return None
690711

712+
@self.server.feature(types.TEXT_DOCUMENT_CODE_LENS)
713+
def code_lens(ls: LanguageServer, params: types.CodeLensParams) -> t.List[types.CodeLens]:
714+
try:
715+
uri = URI(params.text_document.uri)
716+
context = self._context_get_or_load(uri)
717+
code_lenses = context.get_code_lenses(uri)
718+
return code_lenses
719+
except Exception as e:
720+
ls.log_trace(f"Error getting code lenses: {e}")
721+
return []
722+
691723
@self.server.feature(
692724
types.TEXT_DOCUMENT_COMPLETION,
693725
types.CompletionOptions(trigger_characters=["@"]), # advertise "@" for macros
@@ -902,6 +934,68 @@ def _uri_to_path(uri: str) -> Path:
902934
"""Convert a URI to a path."""
903935
return URI(uri).to_path()
904936

937+
def _update_external_model_columns(self, ls: LanguageServer, arguments: t.List[t.Any]) -> t.Any:
938+
"""Update the columns for an external model in the YAML file."""
939+
if len(arguments) != 3:
940+
ls.show_message("Invalid arguments for update columns command", types.MessageType.Error)
941+
return None
942+
943+
uri_str, model_name, model_idx = arguments
944+
uri = URI(uri_str)
945+
path = uri.to_path()
946+
947+
try:
948+
# Get the context
949+
context = self._context_get_or_load(uri)
950+
951+
# Read the YAML file
952+
models = yaml.load(path)
953+
if not isinstance(models, list) or model_idx >= len(models):
954+
ls.show_message("Invalid model index", types.MessageType.Error)
955+
return None
956+
957+
model = models[model_idx]
958+
if model.get("name") != model_name:
959+
ls.show_message("Model name mismatch", types.MessageType.Error)
960+
return None
961+
962+
# Get the adapter and fetch columns
963+
adapter = context.context.engine_adapter
964+
965+
# Get columns for the model
966+
try:
967+
columns = adapter.columns(model_name, include_pseudo_columns=True)
968+
except Exception as e:
969+
ls.show_message(
970+
f"Unable to fetch columns for {model_name}: {e}", types.MessageType.Error
971+
)
972+
return None
973+
974+
if not columns:
975+
ls.show_message(f"No columns found for {model_name}", types.MessageType.Warning)
976+
return None
977+
978+
# Update the model's columns
979+
dialect = context.context.config.model_defaults.dialect
980+
model["columns"] = {
981+
col_name: dtype.sql(dialect=dialect) for col_name, dtype in columns.items()
982+
}
983+
984+
# Write back to the file
985+
with open(path, "w", encoding="utf-8") as f:
986+
yaml.dump(models, f)
987+
988+
ls.show_message(f"Updated columns for {model_name}", types.MessageType.Info)
989+
990+
# Reload the context to pick up the changes
991+
self._reload_context_and_publish_diagnostics(ls, uri, uri_str)
992+
993+
except Exception as e:
994+
ls.show_message(f"Error updating columns: {e}", types.MessageType.Error)
995+
ls.log_trace(f"Error updating columns: {e}")
996+
997+
return None
998+
905999
def start(self) -> None:
9061000
"""Start the server with I/O transport."""
9071001
logging.basicConfig(level=logging.DEBUG)

sqlmesh/lsp/reference.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -877,6 +877,34 @@ def _process_column_references(
877877
return references
878878

879879

880+
def get_yaml_model_name_ranges(path: Path) -> t.Dict[str, Range]:
881+
"""
882+
Get a mapping of model names to their ranges in an external models YAML file.
883+
884+
Args:
885+
path: The path to the YAML file.
886+
Returns:
887+
A dictionary mapping model names to their ranges in the YAML file.
888+
"""
889+
yaml = YAML()
890+
model_ranges: t.Dict[str, Range] = {}
891+
with path.open("r", encoding="utf-8") as f:
892+
data = yaml.load(f)
893+
894+
if not isinstance(data, list):
895+
return model_ranges
896+
897+
for item in data:
898+
if isinstance(item, dict) and "name" in item:
899+
# Get size of block by taking the earliest line/col in the items block and the last line/col of the block
900+
position_data = item.lc.data["name"] # type: ignore
901+
start = Position(line=position_data[2], character=position_data[3])
902+
end = Position(line=position_data[2], character=position_data[3] + len(item["name"]))
903+
model_ranges[item["name"]] = Range(start=start, end=end)
904+
905+
return model_ranges
906+
907+
880908
def _get_yaml_model_range(path: Path, model_name: str) -> t.Optional[Range]:
881909
"""
882910
Find the range of a specific model block in a YAML file.
@@ -895,11 +923,8 @@ def _get_yaml_model_range(path: Path, model_name: str) -> t.Optional[Range]:
895923
if not isinstance(data, list):
896924
return None
897925

898-
for item in data:
899-
if isinstance(item, dict) and item.get("name") == model_name:
900-
# Get size of block by taking the earliest line/col in the items block and the last line/col of the block
901-
position_data = item.lc.data["name"] # type: ignore
902-
start = Position(line=position_data[2], character=position_data[3])
903-
end = Position(line=position_data[2], character=position_data[3] + len(item["name"]))
904-
return Range(start=start, end=end)
926+
# Get all model ranges in the YAML file
927+
model_ranges = get_yaml_model_name_ranges(path)
928+
if model_name in model_ranges:
929+
return model_ranges[model_name]
905930
return None

0 commit comments

Comments
 (0)