Skip to content

Commit 1a88f27

Browse files
committed
feat(vscode): adding ability to update columns
1 parent c02022e commit 1a88f27

File tree

8 files changed

+371
-22
lines changed

8 files changed

+371
-22
lines changed

pnpm-lock.yaml

Lines changed: 61 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sqlmesh/core/schema_loader.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -57,28 +57,17 @@ def create_external_models_file(
5757
external_model_fqns -= existing_model_fqns
5858

5959
with ThreadPoolExecutor(max_workers=max_workers) as pool:
60-
61-
def _get_columns(table: str) -> t.Optional[t.Dict[str, t.Any]]:
62-
try:
63-
return adapter.columns(table, include_pseudo_columns=True)
64-
except Exception as e:
65-
msg = f"Unable to get schema for '{table}': '{e}'."
66-
if strict:
67-
raise SQLMeshError(msg) from e
68-
get_console().log_warning(msg)
69-
return None
70-
7160
gateway_part = {"gateway": gateway} if gateway else {}
7261

7362
schemas = [
7463
{
7564
"name": exp.to_table(table).sql(dialect=dialect),
76-
"columns": {c: dtype.sql(dialect=dialect) for c, dtype in columns.items()},
65+
"columns": columns,
7766
**gateway_part,
7867
}
7968
for table, columns in sorted(
8069
pool.map(
81-
lambda table: (table, _get_columns(table)),
70+
lambda table: (table, get_columns(adapter, dialect, table, strict)),
8271
external_model_fqns,
8372
)
8473
)
@@ -94,3 +83,20 @@ def _get_columns(table: str) -> t.Optional[t.Dict[str, t.Any]]:
9483

9584
with open(path, "w", encoding="utf-8") as file:
9685
yaml.dump(entries_to_keep + schemas, file)
86+
87+
88+
def get_columns(
89+
adapter: EngineAdapter, dialect: DialectType, table: str, strict: bool
90+
) -> t.Optional[t.Dict[str, t.Any]]:
91+
"""
92+
Return the column and their types in a dictionary
93+
"""
94+
try:
95+
columns = adapter.columns(table, include_pseudo_columns=True)
96+
return {c: dtype.sql(dialect=dialect) for c, dtype in columns.items()}
97+
except Exception as e:
98+
msg = f"Unable to get schema for '{table}': '{e}'."
99+
if strict:
100+
raise SQLMeshError(msg) from e
101+
get_console().log_warning(msg)
102+
return None

sqlmesh/lsp/commands.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
EXTERNAL_MODEL_UPDATE_COLUMNS = "sqlmesh.external_model_update_columns"

sqlmesh/lsp/context.py

Lines changed: 115 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
11
from dataclasses import dataclass
22
from pathlib import Path
3+
4+
from pygls.server import LanguageServer
5+
36
from sqlmesh.core.context import Context
47
import typing as t
58

6-
from sqlmesh.core.model.definition import SqlModel
9+
from sqlmesh.core.model.definition import SqlModel, ExternalModel
710
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
11+
from sqlmesh.core.schema_loader import get_columns
12+
from sqlmesh.lsp.commands import EXTERNAL_MODEL_UPDATE_COLUMNS
813
from sqlmesh.lsp.custom import ModelForRendering
914
from sqlmesh.lsp.custom import AllModelsResponse, RenderModelEntry
1015
from sqlmesh.lsp.uri import URI
1116
from lsprotocol import types
17+
from pathlib import Path
18+
19+
from sqlmesh.utils import yaml
1220

1321

1422
@dataclass
@@ -228,9 +236,44 @@ def get_code_actions(
228236
edit=types.WorkspaceEdit(changes={params.text_document.uri: text_edits}),
229237
)
230238
code_actions.append(code_action)
231-
232239
return code_actions if code_actions else None
233240

241+
def get_code_lenses(self, uri: URI) -> t.Optional[t.List[types.CodeLens]]:
242+
models_in_file = self.map.get(uri.to_path())
243+
if isinstance(models_in_file, ModelTarget):
244+
models = [self.context.get_model(model) for model in models_in_file.names]
245+
if any(isinstance(model, ExternalModel) for model in models):
246+
code_lenses = self._get_external_model_code_lenses(uri)
247+
if code_lenses:
248+
return code_lenses
249+
250+
return None
251+
252+
def _get_external_model_code_lenses(self, uri: URI) -> t.List[types.CodeLens]:
253+
from sqlmesh.lsp.reference import get_yaml_model_name_ranges
254+
255+
"""Get code actions for external models YAML files."""
256+
ranges = get_yaml_model_name_ranges(uri.to_path())
257+
258+
code_lenses: t.List[types.CodeLens] = []
259+
for name, range in ranges.items():
260+
# Create a code action to update columns for external models
261+
command = types.Command(
262+
title="Update Columns",
263+
command=EXTERNAL_MODEL_UPDATE_COLUMNS,
264+
arguments=[
265+
name,
266+
],
267+
)
268+
code_lenses.append(
269+
types.CodeLens(
270+
range=range,
271+
command=command,
272+
)
273+
)
274+
275+
return code_lenses if code_lenses else []
276+
234277
def list_of_models_for_rendering(self) -> t.List[ModelForRendering]:
235278
"""Get a list of models for rendering.
236279
@@ -332,3 +375,73 @@ def diagnostic_to_lsp_diagnostic(
332375
code=diagnostic.rule.name,
333376
code_description=types.CodeDescription(href=rule_uri),
334377
)
378+
379+
def update_external_model_columns(self, ls: LanguageServer, uri: URI, model_name: str) -> bool:
380+
"""
381+
Update the columns for an external model in the YAML file. Returns True if changed, False if didn't because
382+
of the columns already being up to date.
383+
384+
Errors still throw exceptions to be handled by the caller.
385+
"""
386+
models = yaml.load(uri.to_path())
387+
if not isinstance(models, list):
388+
raise ValueError(
389+
f"Expected a list of models in {uri.to_path()}, but got {type(models).__name__}"
390+
)
391+
392+
existing_model = next((model for model in models if model.get("name") == model_name), None)
393+
if existing_model is None:
394+
raise ValueError(f"Could not find model {model_name} in {uri.to_path()}")
395+
396+
existing_model_columns = existing_model.get("columns")
397+
398+
# Get the adapter and fetch columns
399+
adapter = self.context.engine_adapter
400+
# Get columns for the model
401+
new_columns = get_columns(
402+
adapter=adapter,
403+
dialect=self.context.config.model_defaults.dialect,
404+
table=model_name,
405+
strict=True,
406+
)
407+
# Compare existing columns and matching types and if they are the same, do not update
408+
if existing_model_columns is not None:
409+
if existing_model_columns == new_columns:
410+
ls.show_message("Columns already up to date")
411+
return False
412+
413+
# Model index to update
414+
model_index = next(
415+
(i for i, model in enumerate(models) if model.get("name") == model_name), None
416+
)
417+
if model_index is None:
418+
raise ValueError(f"Could not find model {model_name} in {uri.to_path()}")
419+
420+
# Get end of the file to set the edit range
421+
with open(uri.to_path(), "r", encoding="utf-8") as file:
422+
read_file = file.read()
423+
424+
end_line = read_file.count("\n")
425+
end_character = len(read_file.splitlines()[-1]) if end_line > 0 else 0
426+
427+
models[model_index]["columns"] = new_columns
428+
edit = types.TextDocumentEdit(
429+
text_document=types.OptionalVersionedTextDocumentIdentifier(
430+
uri=uri.value,
431+
version=None,
432+
),
433+
edits=[
434+
types.TextEdit(
435+
range=types.Range(
436+
start=types.Position(line=0, character=0),
437+
end=types.Position(
438+
line=end_line,
439+
character=end_character,
440+
),
441+
),
442+
new_text=yaml.dump(models),
443+
)
444+
],
445+
)
446+
ls.apply_edit(types.WorkspaceEdit(document_changes=[edit]))
447+
return True

sqlmesh/lsp/main.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
ApiResponseGetLineage,
2424
ApiResponseGetModels,
2525
)
26+
from sqlmesh.lsp.commands import EXTERNAL_MODEL_UPDATE_COLUMNS
2627
from sqlmesh.lsp.completions import get_sql_completions
2728
from sqlmesh.lsp.context import (
2829
LSPContext,
@@ -310,6 +311,44 @@ def function_call(ls: LanguageServer, params: t.Any) -> t.Dict[str, t.Any]:
310311

311312
self.server.feature(name)(create_function_call(method))
312313

314+
@self.server.command(EXTERNAL_MODEL_UPDATE_COLUMNS)
315+
def command_external_models_update_columns(ls: LanguageServer, raw: t.Any) -> None:
316+
try:
317+
if not isinstance(raw, list):
318+
raise ValueError("Invalid command parameters")
319+
if len(raw) != 1:
320+
raise ValueError("Command expects exactly one parameter")
321+
model_name = raw[0]
322+
if not isinstance(model_name, str):
323+
raise ValueError("Command parameter must be a string")
324+
325+
context = self._context_get_or_load()
326+
if not isinstance(context, LSPContext):
327+
raise ValueError("Context is not loaded or invalid")
328+
model = context.context.get_model(model_name)
329+
if model is None:
330+
raise ValueError(f"External model '{model_name}' not found")
331+
if model._path is None:
332+
raise ValueError(f"External model '{model_name}' does not have a file path")
333+
uri = URI.from_path(model._path)
334+
updated = context.update_external_model_columns(
335+
ls=ls,
336+
uri=uri,
337+
model_name=model_name,
338+
)
339+
if updated:
340+
ls.show_message(
341+
f"Updated columns for '{model_name}'",
342+
types.MessageType.Info,
343+
)
344+
else:
345+
ls.show_message(
346+
f"Columns for '{model_name}' are already up to date",
347+
)
348+
except Exception as e:
349+
ls.show_message(f"Error executing command: {e}", types.MessageType.Error)
350+
return None
351+
313352
@self.server.feature(types.INITIALIZE)
314353
def initialize(ls: LanguageServer, params: types.InitializeParams) -> None:
315354
"""Initialize the server when the client connects."""
@@ -685,6 +724,17 @@ def code_action(
685724
ls.log_trace(f"Error getting code actions: {e}")
686725
return None
687726

727+
@self.server.feature(types.TEXT_DOCUMENT_CODE_LENS)
728+
def code_lens(ls: LanguageServer, params: types.CodeLensParams) -> t.List[types.CodeLens]:
729+
try:
730+
uri = URI(params.text_document.uri)
731+
context = self._context_get_or_load(uri)
732+
code_lenses = context.get_code_lenses(uri)
733+
return code_lenses if code_lenses else []
734+
except Exception as e:
735+
ls.log_trace(f"Error getting code lenses: {e}")
736+
return []
737+
688738
@self.server.feature(
689739
types.TEXT_DOCUMENT_COMPLETION,
690740
types.CompletionOptions(trigger_characters=["@"]), # advertise "@" for macros

0 commit comments

Comments
 (0)