Skip to content

Commit 0ed67af

Browse files
committed
feat(vscode): adding ability to update columns
1 parent 9238534 commit 0ed67af

File tree

8 files changed

+367
-19
lines changed

8 files changed

+367
-19
lines changed

pnpm-lock.yaml

Lines changed: 61 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sqlmesh/core/schema_loader.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -57,28 +57,17 @@ def create_external_models_file(
5757
external_model_fqns -= existing_model_fqns
5858

5959
with ThreadPoolExecutor(max_workers=max_workers) as pool:
60-
61-
def _get_columns(table: str) -> t.Optional[t.Dict[str, t.Any]]:
62-
try:
63-
return adapter.columns(table, include_pseudo_columns=True)
64-
except Exception as e:
65-
msg = f"Unable to get schema for '{table}': '{e}'."
66-
if strict:
67-
raise SQLMeshError(msg) from e
68-
get_console().log_warning(msg)
69-
return None
70-
7160
gateway_part = {"gateway": gateway} if gateway else {}
7261

7362
schemas = [
7463
{
7564
"name": exp.to_table(table).sql(dialect=dialect),
76-
"columns": {c: dtype.sql(dialect=dialect) for c, dtype in columns.items()},
65+
"columns": columns,
7766
**gateway_part,
7867
}
7968
for table, columns in sorted(
8069
pool.map(
81-
lambda table: (table, _get_columns(table)),
70+
lambda table: (table, get_columns(adapter, dialect, table, strict)),
8271
external_model_fqns,
8372
)
8473
)
@@ -94,3 +83,20 @@ def _get_columns(table: str) -> t.Optional[t.Dict[str, t.Any]]:
9483

9584
with open(path, "w", encoding="utf-8") as file:
9685
yaml.dump(entries_to_keep + schemas, file)
86+
87+
88+
def get_columns(
89+
adapter: EngineAdapter, dialect: DialectType, table: str, strict: bool
90+
) -> t.Optional[t.Dict[str, t.Any]]:
91+
"""
92+
Return the column and their types in a dictionary
93+
"""
94+
try:
95+
columns = adapter.columns(table, include_pseudo_columns=True)
96+
return {c: dtype.sql(dialect=dialect) for c, dtype in columns.items()}
97+
except Exception as e:
98+
msg = f"Unable to get schema for '{table}': '{e}'."
99+
if strict:
100+
raise SQLMeshError(msg) from e
101+
get_console().log_warning(msg)
102+
return None

sqlmesh/lsp/commands.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
EXTERNAL_MODEL_UPDATE_COLUMNS = "sqlmesh.external_model_update_columns"

sqlmesh/lsp/context.py

Lines changed: 112 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
11
from dataclasses import dataclass
22
from pathlib import Path
3+
from pygls.server import LanguageServer
34
from sqlmesh.core.context import Context
45
import typing as t
5-
66
from sqlmesh.core.linter.rule import Range
7-
from sqlmesh.core.model.definition import SqlModel
7+
from sqlmesh.core.model.definition import SqlModel, ExternalModel
88
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
9+
from sqlmesh.core.schema_loader import get_columns
10+
from sqlmesh.lsp.commands import EXTERNAL_MODEL_UPDATE_COLUMNS
911
from sqlmesh.lsp.custom import ModelForRendering, TestEntry, RunTestResponse
1012
from sqlmesh.lsp.custom import AllModelsResponse, RenderModelEntry
1113
from sqlmesh.lsp.tests_ranges import get_test_ranges
14+
from sqlmesh.lsp.helpers import to_lsp_range
1215
from sqlmesh.lsp.uri import URI
1316
from lsprotocol import types
17+
from sqlmesh.utils import yaml
18+
from sqlmesh.utils.lineage import get_yaml_model_name_ranges
1419

1520

1621
@dataclass
@@ -298,6 +303,42 @@ def get_code_actions(
298303

299304
return code_actions if code_actions else None
300305

306+
def get_code_lenses(self, uri: URI) -> t.Optional[t.List[types.CodeLens]]:
307+
models_in_file = self.map.get(uri.to_path())
308+
if isinstance(models_in_file, ModelTarget):
309+
models = [self.context.get_model(model) for model in models_in_file.names]
310+
if any(isinstance(model, ExternalModel) for model in models):
311+
code_lenses = self._get_external_model_code_lenses(uri)
312+
if code_lenses:
313+
return code_lenses
314+
315+
return None
316+
317+
def _get_external_model_code_lenses(self, uri: URI) -> t.List[types.CodeLens]:
318+
"""Get code lenses for external models YAML files."""
319+
ranges = get_yaml_model_name_ranges(uri.to_path())
320+
if not ranges:
321+
return []
322+
323+
code_lenses: t.List[types.CodeLens] = []
324+
for name, range in ranges.items():
325+
# Create a code action to update columns for external models
326+
command = types.Command(
327+
title="Update Columns",
328+
command=EXTERNAL_MODEL_UPDATE_COLUMNS,
329+
arguments=[
330+
name,
331+
],
332+
)
333+
code_lenses.append(
334+
types.CodeLens(
335+
range=to_lsp_range(range),
336+
command=command,
337+
)
338+
)
339+
340+
return code_lenses if code_lenses else []
341+
301342
def list_of_models_for_rendering(self) -> t.List[ModelForRendering]:
302343
"""Get a list of models for rendering.
303344
@@ -399,3 +440,72 @@ def diagnostic_to_lsp_diagnostic(
399440
code=diagnostic.rule.name,
400441
code_description=types.CodeDescription(href=rule_uri),
401442
)
443+
444+
def update_external_model_columns(self, ls: LanguageServer, uri: URI, model_name: str) -> bool:
445+
"""
446+
Update the columns for an external model in the YAML file. Returns True if changed, False if didn't because
447+
of the columns already being up to date.
448+
449+
Errors still throw exceptions to be handled by the caller.
450+
"""
451+
models = yaml.load(uri.to_path())
452+
if not isinstance(models, list):
453+
raise ValueError(
454+
f"Expected a list of models in {uri.to_path()}, but got {type(models).__name__}"
455+
)
456+
457+
existing_model = next((model for model in models if model.get("name") == model_name), None)
458+
if existing_model is None:
459+
raise ValueError(f"Could not find model {model_name} in {uri.to_path()}")
460+
461+
existing_model_columns = existing_model.get("columns")
462+
463+
# Get the adapter and fetch columns
464+
adapter = self.context.engine_adapter
465+
# Get columns for the model
466+
new_columns = get_columns(
467+
adapter=adapter,
468+
dialect=self.context.config.model_defaults.dialect,
469+
table=model_name,
470+
strict=True,
471+
)
472+
# Compare existing columns and matching types and if they are the same, do not update
473+
if existing_model_columns is not None:
474+
if existing_model_columns == new_columns:
475+
return False
476+
477+
# Model index to update
478+
model_index = next(
479+
(i for i, model in enumerate(models) if model.get("name") == model_name), None
480+
)
481+
if model_index is None:
482+
raise ValueError(f"Could not find model {model_name} in {uri.to_path()}")
483+
484+
# Get end of the file to set the edit range
485+
with open(uri.to_path(), "r", encoding="utf-8") as file:
486+
read_file = file.read()
487+
488+
end_line = read_file.count("\n")
489+
end_character = len(read_file.splitlines()[-1]) if end_line > 0 else 0
490+
491+
models[model_index]["columns"] = new_columns
492+
edit = types.TextDocumentEdit(
493+
text_document=types.OptionalVersionedTextDocumentIdentifier(
494+
uri=uri.value,
495+
version=None,
496+
),
497+
edits=[
498+
types.TextEdit(
499+
range=types.Range(
500+
start=types.Position(line=0, character=0),
501+
end=types.Position(
502+
line=end_line,
503+
character=end_character,
504+
),
505+
),
506+
new_text=yaml.dump(models),
507+
)
508+
],
509+
)
510+
ls.apply_edit(types.WorkspaceEdit(document_changes=[edit]))
511+
return True

sqlmesh/lsp/main.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
ApiResponseGetModels,
2525
)
2626

27+
from sqlmesh.lsp.commands import EXTERNAL_MODEL_UPDATE_COLUMNS
2728
from sqlmesh.lsp.completions import get_sql_completions
2829
from sqlmesh.lsp.context import (
2930
LSPContext,
@@ -368,6 +369,47 @@ def function_call(ls: LanguageServer, params: t.Any) -> t.Dict[str, t.Any]:
368369

369370
self.server.feature(name)(create_function_call(method))
370371

372+
# Check if command is already registered to prevent duplicate registration errors
373+
if EXTERNAL_MODEL_UPDATE_COLUMNS not in getattr(self.server, "_commands", {}):
374+
375+
@self.server.command(EXTERNAL_MODEL_UPDATE_COLUMNS)
376+
def command_external_models_update_columns(ls: LanguageServer, raw: t.Any) -> None:
377+
try:
378+
if not isinstance(raw, list):
379+
raise ValueError("Invalid command parameters")
380+
if len(raw) != 1:
381+
raise ValueError("Command expects exactly one parameter")
382+
model_name = raw[0]
383+
if not isinstance(model_name, str):
384+
raise ValueError("Command parameter must be a string")
385+
386+
context = self._context_get_or_load()
387+
if not isinstance(context, LSPContext):
388+
raise ValueError("Context is not loaded or invalid")
389+
model = context.context.get_model(model_name)
390+
if model is None:
391+
raise ValueError(f"External model '{model_name}' not found")
392+
if model._path is None:
393+
raise ValueError(f"External model '{model_name}' does not have a file path")
394+
uri = URI.from_path(model._path)
395+
updated = context.update_external_model_columns(
396+
ls=ls,
397+
uri=uri,
398+
model_name=model_name,
399+
)
400+
if updated:
401+
ls.show_message(
402+
f"Updated columns for '{model_name}'",
403+
types.MessageType.Info,
404+
)
405+
else:
406+
ls.show_message(
407+
f"Columns for '{model_name}' are already up to date",
408+
)
409+
except Exception as e:
410+
ls.show_message(f"Error executing command: {e}", types.MessageType.Error)
411+
return None
412+
371413
@self.server.feature(types.INITIALIZE)
372414
def initialize(ls: LanguageServer, params: types.InitializeParams) -> None:
373415
"""Initialize the server when the client connects."""
@@ -750,6 +792,17 @@ def code_action(
750792
ls.log_trace(f"Error getting code actions: {e}")
751793
return None
752794

795+
@self.server.feature(types.TEXT_DOCUMENT_CODE_LENS)
796+
def code_lens(ls: LanguageServer, params: types.CodeLensParams) -> t.List[types.CodeLens]:
797+
try:
798+
uri = URI(params.text_document.uri)
799+
context = self._context_get_or_load(uri)
800+
code_lenses = context.get_code_lenses(uri)
801+
return code_lenses if code_lenses else []
802+
except Exception as e:
803+
ls.log_trace(f"Error getting code lenses: {e}")
804+
return []
805+
753806
@self.server.feature(
754807
types.TEXT_DOCUMENT_COMPLETION,
755808
types.CompletionOptions(trigger_characters=["@"]), # advertise "@" for macros

sqlmesh/utils/lineage.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -387,18 +387,38 @@ def _get_yaml_model_range(path: Path, model_name: str) -> t.Optional[Range]:
387387
Returns:
388388
The Range of the model block in the YAML file, or None if not found
389389
"""
390+
model_name_ranges = get_yaml_model_name_ranges(path)
391+
if model_name_ranges is None:
392+
return None
393+
return model_name_ranges.get(model_name, None)
394+
395+
396+
def get_yaml_model_name_ranges(path: Path) -> t.Optional[t.Dict[str, Range]]:
397+
"""
398+
Get the ranges of all model names in a YAML file.
399+
400+
Args:
401+
path: Path to the YAML file
402+
403+
Returns:
404+
A dictionary mapping model names to their ranges in the YAML file.
405+
"""
390406
yaml = YAML()
391407
with path.open("r", encoding="utf-8") as f:
392408
data = yaml.load(f)
393409

394410
if not isinstance(data, list):
395411
return None
396412

413+
model_name_ranges = {}
397414
for item in data:
398-
if isinstance(item, dict) and item.get("name") == model_name:
399-
# Get size of block by taking the earliest line/col in the items block and the last line/col of the block
415+
if isinstance(item, dict):
400416
position_data = item.lc.data["name"] # type: ignore
401417
start = Position(line=position_data[2], character=position_data[3])
402418
end = Position(line=position_data[2], character=position_data[3] + len(item["name"]))
403-
return Range(start=start, end=end)
404-
return None
419+
name = item.get("name")
420+
if not name :
421+
continue
422+
model_name_ranges[name] = Range(start=start, end=end)
423+
424+
return model_name_ranges

0 commit comments

Comments
 (0)