Skip to content

Commit 42a1fa5

Browse files
committed
feat(vscode): adding ability to update columns
1 parent 95b1f6e commit 42a1fa5

File tree

8 files changed

+375
-22
lines changed

8 files changed

+375
-22
lines changed

pnpm-lock.yaml

Lines changed: 61 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sqlmesh/core/schema_loader.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -57,28 +57,17 @@ def create_external_models_file(
5757
external_model_fqns -= existing_model_fqns
5858

5959
with ThreadPoolExecutor(max_workers=max_workers) as pool:
60-
61-
def _get_columns(table: str) -> t.Optional[t.Dict[str, t.Any]]:
62-
try:
63-
return adapter.columns(table, include_pseudo_columns=True)
64-
except Exception as e:
65-
msg = f"Unable to get schema for '{table}': '{e}'."
66-
if strict:
67-
raise SQLMeshError(msg) from e
68-
get_console().log_warning(msg)
69-
return None
70-
7160
gateway_part = {"gateway": gateway} if gateway else {}
7261

7362
schemas = [
7463
{
7564
"name": exp.to_table(table).sql(dialect=dialect),
76-
"columns": {c: dtype.sql(dialect=dialect) for c, dtype in columns.items()},
65+
"columns": columns,
7766
**gateway_part,
7867
}
7968
for table, columns in sorted(
8069
pool.map(
81-
lambda table: (table, _get_columns(table)),
70+
lambda table: (table, get_columns(adapter, dialect, table, strict)),
8271
external_model_fqns,
8372
)
8473
)
@@ -94,3 +83,20 @@ def _get_columns(table: str) -> t.Optional[t.Dict[str, t.Any]]:
9483

9584
with open(path, "w", encoding="utf-8") as file:
9685
yaml.dump(entries_to_keep + schemas, file)
86+
87+
88+
def get_columns(
89+
adapter: EngineAdapter, dialect: DialectType, table: str, strict: bool
90+
) -> t.Optional[t.Dict[str, t.Any]]:
91+
"""
92+
Return the column and their types in a dictionary
93+
"""
94+
try:
95+
columns = adapter.columns(table, include_pseudo_columns=True)
96+
return {c: dtype.sql(dialect=dialect) for c, dtype in columns.items()}
97+
except Exception as e:
98+
msg = f"Unable to get schema for '{table}': '{e}'."
99+
if strict:
100+
raise SQLMeshError(msg) from e
101+
get_console().log_warning(msg)
102+
return None

sqlmesh/lsp/commands.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from sqlmesh.utils.pydantic import PydanticModel
2+
3+
EXTERNAL_MODEL_UPDATE_COLUMNS = "sqlmesh.external_model_update_columns"
4+
5+
6+
class ExternalModelUpdateColumnsRequest(PydanticModel):
7+
model_name: str

sqlmesh/lsp/context.py

Lines changed: 114 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
11
from dataclasses import dataclass
22
from pathlib import Path
3+
4+
from pygls.server import LanguageServer
5+
36
from sqlmesh.core.context import Context
47
import typing as t
58

6-
from sqlmesh.core.model.definition import SqlModel
9+
from sqlmesh.core.model.definition import SqlModel, ExternalModel
710
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
11+
from sqlmesh.core.schema_loader import get_columns
12+
from sqlmesh.lsp.commands import EXTERNAL_MODEL_UPDATE_COLUMNS
813
from sqlmesh.lsp.custom import ModelForRendering
914
from sqlmesh.lsp.custom import AllModelsResponse, RenderModelEntry
1015
from sqlmesh.lsp.uri import URI
1116
from lsprotocol import types
17+
from pathlib import Path
18+
19+
from sqlmesh.utils import yaml
1220

1321

1422
@dataclass
@@ -228,8 +236,42 @@ def get_code_actions(
228236
edit=types.WorkspaceEdit(changes={params.text_document.uri: text_edits}),
229237
)
230238
code_actions.append(code_action)
239+
return code_actions
240+
241+
def get_code_lenses(self, uri: URI) -> t.Optional[t.List[types.CodeLens]]:
242+
models_in_file = self.map.get(uri.to_path())
243+
if isinstance(models_in_file, ModelTarget):
244+
models = [self.context.get_model(model) for model in models_in_file.names]
245+
if any(isinstance(model, ExternalModel) for model in models):
246+
code_lenses = self._get_external_model_code_lenses(uri)
247+
if code_lenses:
248+
return code_lenses
249+
250+
return None
251+
252+
def _get_external_model_code_lenses(self, uri: URI) -> t.List[types.CodeLens]:
253+
from sqlmesh.lsp.reference import get_yaml_model_name_ranges
254+
255+
"""Get code actions for external models YAML files."""
256+
ranges = get_yaml_model_name_ranges(uri.to_path())
257+
258+
code_actions: t.List[types.CodeLens] = []
259+
for name, range in ranges.items():
260+
# Create a code action to update columns for external models
261+
command = types.Command(
262+
title="Update Columns",
263+
command=EXTERNAL_MODEL_UPDATE_COLUMNS,
264+
arguments=[
265+
name,
266+
],
267+
)
268+
code_action = types.CodeLens(
269+
range=range,
270+
command=command,
271+
)
272+
code_actions.append(code_action)
231273

232-
return code_actions if code_actions else None
274+
return code_actions if code_actions else []
233275

234276
def list_of_models_for_rendering(self) -> t.List[ModelForRendering]:
235277
"""Get a list of models for rendering.
@@ -332,3 +374,73 @@ def diagnostic_to_lsp_diagnostic(
332374
code=diagnostic.rule.name,
333375
code_description=types.CodeDescription(href=rule_uri),
334376
)
377+
378+
def update_external_model_columns(self, ls: LanguageServer, uri: URI, model_name: str) -> bool:
379+
"""
380+
Update the columns for an external model in the YAML file. Returns True if changed, False if didn't because
381+
of the columns already being up to date.
382+
383+
Errors still throw exceptions to be handled by the caller.
384+
"""
385+
models = yaml.load(uri.to_path())
386+
if not isinstance(models, list):
387+
raise ValueError(
388+
f"Expected a list of models in {uri.to_path()}, but got {type(models).__name__}"
389+
)
390+
391+
existing_model = next((model for model in models if model.get("name") == model_name), None)
392+
if existing_model is None:
393+
raise ValueError(f"Could not find model {model_name} in {uri.to_path()}")
394+
395+
existing_model_columns = existing_model.get("columns")
396+
397+
# Get the adapter and fetch columns
398+
adapter = self.context.engine_adapter
399+
# Get columns for the model
400+
new_columns = get_columns(
401+
adapter=adapter,
402+
dialect=self.context.config.model_defaults.dialect,
403+
table=model_name,
404+
strict=True,
405+
)
406+
# Compare existing columns and matching types and if they are the same, do not update
407+
if existing_model_columns is not None:
408+
if existing_model_columns == new_columns:
409+
ls.show_message("Columns already up to date")
410+
return False
411+
412+
# Model index to update
413+
model_index = next(
414+
(i for i, model in enumerate(models) if model.get("name") == model_name), None
415+
)
416+
if model_index is None:
417+
raise ValueError(f"Could not find model {model_name} in {uri.to_path()}")
418+
419+
# Get end of the file to set the edit range
420+
with open(uri.to_path(), "r", encoding="utf-8") as file:
421+
read_file = file.read()
422+
423+
end_line = read_file.count("\n")
424+
end_character = len(read_file.splitlines()[-1]) if end_line > 0 else 0
425+
426+
models[model_index]["columns"] = new_columns
427+
edit = types.TextDocumentEdit(
428+
text_document=types.OptionalVersionedTextDocumentIdentifier(
429+
uri=uri.value,
430+
version=None,
431+
),
432+
edits=[
433+
types.TextEdit(
434+
range=types.Range(
435+
start=types.Position(line=0, character=0),
436+
end=types.Position(
437+
line=end_line,
438+
character=end_character,
439+
),
440+
),
441+
new_text=yaml.dump(models),
442+
)
443+
],
444+
)
445+
ls.apply_edit(types.WorkspaceEdit(document_changes=[edit]))
446+
return True

sqlmesh/lsp/main.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
ApiResponseGetLineage,
2424
ApiResponseGetModels,
2525
)
26+
from sqlmesh.lsp.commands import ExternalModelUpdateColumnsRequest, EXTERNAL_MODEL_UPDATE_COLUMNS
2627
from sqlmesh.lsp.completions import get_sql_completions
2728
from sqlmesh.lsp.context import (
2829
LSPContext,
@@ -313,6 +314,47 @@ def function_call(ls: LanguageServer, params: t.Any) -> t.Dict[str, t.Any]:
313314

314315
self.server.feature(name)(create_function_call(method))
315316

317+
@self.server.command(EXTERNAL_MODEL_UPDATE_COLUMNS)
318+
def command_external_models_update_columns(ls: LanguageServer, raw: t.Any) -> None:
319+
try:
320+
if not isinstance(raw, list):
321+
raise ValueError("Invalid command parameters")
322+
if len(raw) != 1:
323+
raise ValueError("Command expects exactly one parameter")
324+
model_name = raw[0]
325+
if not isinstance(model_name, str):
326+
raise ValueError("Command parameter must be a string")
327+
328+
request = ExternalModelUpdateColumnsRequest(model_name=model_name)
329+
context = self._context_get_or_load()
330+
if not isinstance(context, LSPContext):
331+
raise ValueError("Context is not loaded or invalid")
332+
model = context.context.get_model(request.model_name)
333+
if model is None:
334+
raise ValueError(f"External model '{request.model_name}' not found")
335+
if model._path is None:
336+
raise ValueError(
337+
f"External model '{request.model_name}' does not have a file path"
338+
)
339+
uri = URI.from_path(model._path)
340+
updated = context.update_external_model_columns(
341+
ls=ls,
342+
uri=uri,
343+
model_name=request.model_name,
344+
)
345+
if updated:
346+
ls.show_message(
347+
f"Updated columns for '{request.model_name}'",
348+
types.MessageType.Info,
349+
)
350+
else:
351+
ls.show_message(
352+
f"Columns for '{request.model_name}' are already up to date",
353+
)
354+
except Exception as e:
355+
ls.show_message(f"Error executing command: {e}", types.MessageType.Error)
356+
return None
357+
316358
@self.server.feature(types.INITIALIZE)
317359
def initialize(ls: LanguageServer, params: types.InitializeParams) -> None:
318360
"""Initialize the server when the client connects."""
@@ -688,6 +730,17 @@ def code_action(
688730
ls.log_trace(f"Error getting code actions: {e}")
689731
return None
690732

733+
@self.server.feature(types.TEXT_DOCUMENT_CODE_LENS)
734+
def code_lens(ls: LanguageServer, params: types.CodeLensParams) -> t.List[types.CodeLens]:
735+
try:
736+
uri = URI(params.text_document.uri)
737+
context = self._context_get_or_load(uri)
738+
code_lenses = context.get_code_lenses(uri)
739+
return code_lenses if code_lenses else []
740+
except Exception as e:
741+
ls.log_trace(f"Error getting code lenses: {e}")
742+
return []
743+
691744
@self.server.feature(
692745
types.TEXT_DOCUMENT_COMPLETION,
693746
types.CompletionOptions(trigger_characters=["@"]), # advertise "@" for macros

0 commit comments

Comments
 (0)