Skip to content

Commit 10921a9

Browse files
committed
feat(vscode): adding ability to update columns
1 parent a94c4f0 commit 10921a9

File tree

8 files changed

+372
-22
lines changed

8 files changed

+372
-22
lines changed

pnpm-lock.yaml

Lines changed: 61 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sqlmesh/core/schema_loader.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -57,28 +57,17 @@ def create_external_models_file(
5757
external_model_fqns -= existing_model_fqns
5858

5959
with ThreadPoolExecutor(max_workers=max_workers) as pool:
60-
61-
def _get_columns(table: str) -> t.Optional[t.Dict[str, t.Any]]:
62-
try:
63-
return adapter.columns(table, include_pseudo_columns=True)
64-
except Exception as e:
65-
msg = f"Unable to get schema for '{table}': '{e}'."
66-
if strict:
67-
raise SQLMeshError(msg) from e
68-
get_console().log_warning(msg)
69-
return None
70-
7160
gateway_part = {"gateway": gateway} if gateway else {}
7261

7362
schemas = [
7463
{
7564
"name": exp.to_table(table).sql(dialect=dialect),
76-
"columns": {c: dtype.sql(dialect=dialect) for c, dtype in columns.items()},
65+
"columns": columns,
7766
**gateway_part,
7867
}
7968
for table, columns in sorted(
8069
pool.map(
81-
lambda table: (table, _get_columns(table)),
70+
lambda table: (table, get_columns(adapter, dialect, table, strict)),
8271
external_model_fqns,
8372
)
8473
)
@@ -94,3 +83,20 @@ def _get_columns(table: str) -> t.Optional[t.Dict[str, t.Any]]:
9483

9584
with open(path, "w", encoding="utf-8") as file:
9685
yaml.dump(entries_to_keep + schemas, file)
86+
87+
88+
def get_columns(
89+
adapter: EngineAdapter, dialect: DialectType, table: str, strict: bool
90+
) -> t.Optional[t.Dict[str, t.Any]]:
91+
"""
92+
Return the column and their types in a dictionary
93+
"""
94+
try:
95+
columns = adapter.columns(table, include_pseudo_columns=True)
96+
return {c: dtype.sql(dialect=dialect) for c, dtype in columns.items()}
97+
except Exception as e:
98+
msg = f"Unable to get schema for '{table}': '{e}'."
99+
if strict:
100+
raise SQLMeshError(msg) from e
101+
get_console().log_warning(msg)
102+
return None

sqlmesh/lsp/commands.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
EXTERNAL_MODEL_UPDATE_COLUMNS = "sqlmesh.external_model_update_columns"

sqlmesh/lsp/context.py

Lines changed: 112 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
11
from dataclasses import dataclass
22
from pathlib import Path
3+
from pygls.server import LanguageServer
34
from sqlmesh.core.context import Context
45
import typing as t
5-
66
from sqlmesh.core.linter.rule import Range
7-
from sqlmesh.core.model.definition import SqlModel
7+
from sqlmesh.core.model.definition import SqlModel, ExternalModel
88
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
9+
from sqlmesh.core.schema_loader import get_columns
10+
from sqlmesh.lsp.commands import EXTERNAL_MODEL_UPDATE_COLUMNS
911
from sqlmesh.lsp.custom import ModelForRendering, TestEntry, RunTestResponse
1012
from sqlmesh.lsp.custom import AllModelsResponse, RenderModelEntry
1113
from sqlmesh.lsp.tests_ranges import get_test_ranges
14+
from sqlmesh.lsp.helpers import to_lsp_range
1215
from sqlmesh.lsp.uri import URI
1316
from lsprotocol import types
17+
from sqlmesh.utils import yaml
1418

1519

1620
@dataclass
@@ -298,6 +302,42 @@ def get_code_actions(
298302

299303
return code_actions if code_actions else None
300304

305+
def get_code_lenses(self, uri: URI) -> t.Optional[t.List[types.CodeLens]]:
306+
models_in_file = self.map.get(uri.to_path())
307+
if isinstance(models_in_file, ModelTarget):
308+
models = [self.context.get_model(model) for model in models_in_file.names]
309+
if any(isinstance(model, ExternalModel) for model in models):
310+
code_lenses = self._get_external_model_code_lenses(uri)
311+
if code_lenses:
312+
return code_lenses
313+
314+
return None
315+
316+
def _get_external_model_code_lenses(self, uri: URI) -> t.List[types.CodeLens]:
317+
from sqlmesh.lsp.reference import get_yaml_model_name_ranges
318+
319+
"""Get code actions for external models YAML files."""
320+
ranges = get_yaml_model_name_ranges(uri.to_path())
321+
322+
code_lenses: t.List[types.CodeLens] = []
323+
for name, range in ranges.items():
324+
# Create a code action to update columns for external models
325+
command = types.Command(
326+
title="Update Columns",
327+
command=EXTERNAL_MODEL_UPDATE_COLUMNS,
328+
arguments=[
329+
name,
330+
],
331+
)
332+
code_lenses.append(
333+
types.CodeLens(
334+
range=to_lsp_range(range),
335+
command=command,
336+
)
337+
)
338+
339+
return code_lenses if code_lenses else []
340+
301341
def list_of_models_for_rendering(self) -> t.List[ModelForRendering]:
302342
"""Get a list of models for rendering.
303343
@@ -399,3 +439,73 @@ def diagnostic_to_lsp_diagnostic(
399439
code=diagnostic.rule.name,
400440
code_description=types.CodeDescription(href=rule_uri),
401441
)
442+
443+
def update_external_model_columns(self, ls: LanguageServer, uri: URI, model_name: str) -> bool:
444+
"""
445+
Update the columns for an external model in the YAML file. Returns True if changed, False if didn't because
446+
of the columns already being up to date.
447+
448+
Errors still throw exceptions to be handled by the caller.
449+
"""
450+
models = yaml.load(uri.to_path())
451+
if not isinstance(models, list):
452+
raise ValueError(
453+
f"Expected a list of models in {uri.to_path()}, but got {type(models).__name__}"
454+
)
455+
456+
existing_model = next((model for model in models if model.get("name") == model_name), None)
457+
if existing_model is None:
458+
raise ValueError(f"Could not find model {model_name} in {uri.to_path()}")
459+
460+
existing_model_columns = existing_model.get("columns")
461+
462+
# Get the adapter and fetch columns
463+
adapter = self.context.engine_adapter
464+
# Get columns for the model
465+
new_columns = get_columns(
466+
adapter=adapter,
467+
dialect=self.context.config.model_defaults.dialect,
468+
table=model_name,
469+
strict=True,
470+
)
471+
# Compare existing columns and matching types and if they are the same, do not update
472+
if existing_model_columns is not None:
473+
if existing_model_columns == new_columns:
474+
ls.show_message("Columns already up to date")
475+
return False
476+
477+
# Model index to update
478+
model_index = next(
479+
(i for i, model in enumerate(models) if model.get("name") == model_name), None
480+
)
481+
if model_index is None:
482+
raise ValueError(f"Could not find model {model_name} in {uri.to_path()}")
483+
484+
# Get end of the file to set the edit range
485+
with open(uri.to_path(), "r", encoding="utf-8") as file:
486+
read_file = file.read()
487+
488+
end_line = read_file.count("\n")
489+
end_character = len(read_file.splitlines()[-1]) if end_line > 0 else 0
490+
491+
models[model_index]["columns"] = new_columns
492+
edit = types.TextDocumentEdit(
493+
text_document=types.OptionalVersionedTextDocumentIdentifier(
494+
uri=uri.value,
495+
version=None,
496+
),
497+
edits=[
498+
types.TextEdit(
499+
range=types.Range(
500+
start=types.Position(line=0, character=0),
501+
end=types.Position(
502+
line=end_line,
503+
character=end_character,
504+
),
505+
),
506+
new_text=yaml.dump(models),
507+
)
508+
],
509+
)
510+
ls.apply_edit(types.WorkspaceEdit(document_changes=[edit]))
511+
return True

sqlmesh/lsp/main.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
ApiResponseGetLineage,
2424
ApiResponseGetModels,
2525
)
26+
from sqlmesh.lsp.commands import EXTERNAL_MODEL_UPDATE_COLUMNS
2627
from sqlmesh.lsp.completions import get_sql_completions
2728
from sqlmesh.lsp.context import (
2829
LSPContext,
@@ -367,6 +368,44 @@ def function_call(ls: LanguageServer, params: t.Any) -> t.Dict[str, t.Any]:
367368

368369
self.server.feature(name)(create_function_call(method))
369370

371+
@self.server.command(EXTERNAL_MODEL_UPDATE_COLUMNS)
372+
def command_external_models_update_columns(ls: LanguageServer, raw: t.Any) -> None:
373+
try:
374+
if not isinstance(raw, list):
375+
raise ValueError("Invalid command parameters")
376+
if len(raw) != 1:
377+
raise ValueError("Command expects exactly one parameter")
378+
model_name = raw[0]
379+
if not isinstance(model_name, str):
380+
raise ValueError("Command parameter must be a string")
381+
382+
context = self._context_get_or_load()
383+
if not isinstance(context, LSPContext):
384+
raise ValueError("Context is not loaded or invalid")
385+
model = context.context.get_model(model_name)
386+
if model is None:
387+
raise ValueError(f"External model '{model_name}' not found")
388+
if model._path is None:
389+
raise ValueError(f"External model '{model_name}' does not have a file path")
390+
uri = URI.from_path(model._path)
391+
updated = context.update_external_model_columns(
392+
ls=ls,
393+
uri=uri,
394+
model_name=model_name,
395+
)
396+
if updated:
397+
ls.show_message(
398+
f"Updated columns for '{model_name}'",
399+
types.MessageType.Info,
400+
)
401+
else:
402+
ls.show_message(
403+
f"Columns for '{model_name}' are already up to date",
404+
)
405+
except Exception as e:
406+
ls.show_message(f"Error executing command: {e}", types.MessageType.Error)
407+
return None
408+
370409
@self.server.feature(types.INITIALIZE)
371410
def initialize(ls: LanguageServer, params: types.InitializeParams) -> None:
372411
"""Initialize the server when the client connects."""
@@ -749,6 +788,17 @@ def code_action(
749788
ls.log_trace(f"Error getting code actions: {e}")
750789
return None
751790

791+
@self.server.feature(types.TEXT_DOCUMENT_CODE_LENS)
792+
def code_lens(ls: LanguageServer, params: types.CodeLensParams) -> t.List[types.CodeLens]:
793+
try:
794+
uri = URI(params.text_document.uri)
795+
context = self._context_get_or_load(uri)
796+
code_lenses = context.get_code_lenses(uri)
797+
return code_lenses if code_lenses else []
798+
except Exception as e:
799+
ls.log_trace(f"Error getting code lenses: {e}")
800+
return []
801+
752802
@self.server.feature(
753803
types.TEXT_DOCUMENT_COMPLETION,
754804
types.CompletionOptions(trigger_characters=["@"]), # advertise "@" for macros

sqlmesh/lsp/reference.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -886,6 +886,34 @@ def _process_column_references(
886886
return references
887887

888888

889+
def get_yaml_model_name_ranges(path: Path) -> t.Dict[str, Range]:
890+
"""
891+
Get a mapping of model names to their ranges in an external models YAML file.
892+
893+
Args:
894+
path: The path to the YAML file.
895+
Returns:
896+
A dictionary mapping model names to their ranges in the YAML file.
897+
"""
898+
yaml = YAML()
899+
model_ranges: t.Dict[str, Range] = {}
900+
with path.open("r", encoding="utf-8") as f:
901+
data = yaml.load(f)
902+
903+
if not isinstance(data, list):
904+
return model_ranges
905+
906+
for item in data:
907+
if isinstance(item, dict) and "name" in item:
908+
# Get size of block by taking the earliest line/col in the items block and the last line/col of the block
909+
position_data = item.lc.data["name"] # type: ignore
910+
start = Position(line=position_data[2], character=position_data[3])
911+
end = Position(line=position_data[2], character=position_data[3] + len(item["name"]))
912+
model_ranges[item["name"]] = Range(start=start, end=end)
913+
914+
return model_ranges
915+
916+
889917
def _get_yaml_model_range(path: Path, model_name: str) -> t.Optional[Range]:
890918
"""
891919
Find the range of a specific model block in a YAML file.
@@ -904,11 +932,8 @@ def _get_yaml_model_range(path: Path, model_name: str) -> t.Optional[Range]:
904932
if not isinstance(data, list):
905933
return None
906934

907-
for item in data:
908-
if isinstance(item, dict) and item.get("name") == model_name:
909-
# Get size of block by taking the earliest line/col in the items block and the last line/col of the block
910-
position_data = item.lc.data["name"] # type: ignore
911-
start = Position(line=position_data[2], character=position_data[3])
912-
end = Position(line=position_data[2], character=position_data[3] + len(item["name"]))
913-
return Range(start=start, end=end)
935+
# Get all model ranges in the YAML file
936+
model_ranges = get_yaml_model_name_ranges(path)
937+
if model_name in model_ranges:
938+
return model_ranges[model_name]
914939
return None

0 commit comments

Comments
 (0)