Skip to content

Commit 6e47ba0

Browse files
committed
feat(vscode): adding ability to update columns
1 parent 9238534 commit 6e47ba0

File tree

8 files changed

+358
-19
lines changed

8 files changed

+358
-19
lines changed

pnpm-lock.yaml

Lines changed: 61 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sqlmesh/core/schema_loader.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -57,28 +57,17 @@ def create_external_models_file(
5757
external_model_fqns -= existing_model_fqns
5858

5959
with ThreadPoolExecutor(max_workers=max_workers) as pool:
60-
61-
def _get_columns(table: str) -> t.Optional[t.Dict[str, t.Any]]:
62-
try:
63-
return adapter.columns(table, include_pseudo_columns=True)
64-
except Exception as e:
65-
msg = f"Unable to get schema for '{table}': '{e}'."
66-
if strict:
67-
raise SQLMeshError(msg) from e
68-
get_console().log_warning(msg)
69-
return None
70-
7160
gateway_part = {"gateway": gateway} if gateway else {}
7261

7362
schemas = [
7463
{
7564
"name": exp.to_table(table).sql(dialect=dialect),
76-
"columns": {c: dtype.sql(dialect=dialect) for c, dtype in columns.items()},
65+
"columns": columns,
7766
**gateway_part,
7867
}
7968
for table, columns in sorted(
8069
pool.map(
81-
lambda table: (table, _get_columns(table)),
70+
lambda table: (table, get_columns(adapter, dialect, table, strict)),
8271
external_model_fqns,
8372
)
8473
)
@@ -94,3 +83,20 @@ def _get_columns(table: str) -> t.Optional[t.Dict[str, t.Any]]:
9483

9584
with open(path, "w", encoding="utf-8") as file:
9685
yaml.dump(entries_to_keep + schemas, file)
86+
87+
88+
def get_columns(
89+
adapter: EngineAdapter, dialect: DialectType, table: str, strict: bool
90+
) -> t.Optional[t.Dict[str, t.Any]]:
91+
"""
92+
Return the column and their types in a dictionary
93+
"""
94+
try:
95+
columns = adapter.columns(table, include_pseudo_columns=True)
96+
return {c: dtype.sql(dialect=dialect) for c, dtype in columns.items()}
97+
except Exception as e:
98+
msg = f"Unable to get schema for '{table}': '{e}'."
99+
if strict:
100+
raise SQLMeshError(msg) from e
101+
get_console().log_warning(msg)
102+
return None

sqlmesh/lsp/commands.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
EXTERNAL_MODEL_UPDATE_COLUMNS = "sqlmesh.external_model_update_columns"

sqlmesh/lsp/context.py

Lines changed: 106 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
11
from dataclasses import dataclass
22
from pathlib import Path
3+
from pygls.server import LanguageServer
34
from sqlmesh.core.context import Context
45
import typing as t
5-
66
from sqlmesh.core.linter.rule import Range
7-
from sqlmesh.core.model.definition import SqlModel
7+
from sqlmesh.core.model.definition import SqlModel, ExternalModel
88
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
9+
from sqlmesh.core.schema_loader import get_columns
10+
from sqlmesh.lsp.commands import EXTERNAL_MODEL_UPDATE_COLUMNS
911
from sqlmesh.lsp.custom import ModelForRendering, TestEntry, RunTestResponse
1012
from sqlmesh.lsp.custom import AllModelsResponse, RenderModelEntry
1113
from sqlmesh.lsp.tests_ranges import get_test_ranges
14+
from sqlmesh.lsp.helpers import to_lsp_range
1215
from sqlmesh.lsp.uri import URI
1316
from lsprotocol import types
17+
from sqlmesh.utils import yaml
18+
from sqlmesh.utils.lineage import get_yaml_model_name_ranges
1419

1520

1621
@dataclass
@@ -298,6 +303,36 @@ def get_code_actions(
298303

299304
return code_actions if code_actions else None
300305

306+
def get_code_lenses(self, uri: URI) -> t.Optional[t.List[types.CodeLens]]:
307+
models_in_file = self.map.get(uri.to_path())
308+
if isinstance(models_in_file, ModelTarget):
309+
models = [self.context.get_model(model) for model in models_in_file.names]
310+
if any(isinstance(model, ExternalModel) for model in models):
311+
code_lenses = self._get_external_model_code_lenses(uri)
312+
if code_lenses:
313+
return code_lenses
314+
315+
return None
316+
317+
def _get_external_model_code_lenses(self, uri: URI) -> t.List[types.CodeLens]:
318+
"""Get code lenses for external models YAML files."""
319+
ranges = get_yaml_model_name_ranges(uri.to_path())
320+
if ranges is None:
321+
return []
322+
return [
323+
types.CodeLens(
324+
range=to_lsp_range(range),
325+
command=types.Command(
326+
title="Update Columns",
327+
command=EXTERNAL_MODEL_UPDATE_COLUMNS,
328+
arguments=[
329+
name,
330+
],
331+
),
332+
)
333+
for name, range in ranges.items()
334+
]
335+
301336
def list_of_models_for_rendering(self) -> t.List[ModelForRendering]:
302337
"""Get a list of models for rendering.
303338
@@ -399,3 +434,72 @@ def diagnostic_to_lsp_diagnostic(
399434
code=diagnostic.rule.name,
400435
code_description=types.CodeDescription(href=rule_uri),
401436
)
437+
438+
def update_external_model_columns(self, ls: LanguageServer, uri: URI, model_name: str) -> bool:
439+
"""
440+
Update the columns for an external model in the YAML file. Returns True if changed, False if didn't because
441+
of the columns already being up to date.
442+
443+
Errors still throw exceptions to be handled by the caller.
444+
"""
445+
models = yaml.load(uri.to_path())
446+
if not isinstance(models, list):
447+
raise ValueError(
448+
f"Expected a list of models in {uri.to_path()}, but got {type(models).__name__}"
449+
)
450+
451+
existing_model = next((model for model in models if model.get("name") == model_name), None)
452+
if existing_model is None:
453+
raise ValueError(f"Could not find model {model_name} in {uri.to_path()}")
454+
455+
existing_model_columns = existing_model.get("columns")
456+
457+
# Get the adapter and fetch columns
458+
adapter = self.context.engine_adapter
459+
# Get columns for the model
460+
new_columns = get_columns(
461+
adapter=adapter,
462+
dialect=self.context.config.model_defaults.dialect,
463+
table=model_name,
464+
strict=True,
465+
)
466+
# Compare existing columns and matching types and if they are the same, do not update
467+
if existing_model_columns is not None:
468+
if existing_model_columns == new_columns:
469+
return False
470+
471+
# Model index to update
472+
model_index = next(
473+
(i for i, model in enumerate(models) if model.get("name") == model_name), None
474+
)
475+
if model_index is None:
476+
raise ValueError(f"Could not find model {model_name} in {uri.to_path()}")
477+
478+
# Get end of the file to set the edit range
479+
with open(uri.to_path(), "r", encoding="utf-8") as file:
480+
read_file = file.read()
481+
482+
end_line = read_file.count("\n")
483+
end_character = len(read_file.splitlines()[-1]) if end_line > 0 else 0
484+
485+
models[model_index]["columns"] = new_columns
486+
edit = types.TextDocumentEdit(
487+
text_document=types.OptionalVersionedTextDocumentIdentifier(
488+
uri=uri.value,
489+
version=None,
490+
),
491+
edits=[
492+
types.TextEdit(
493+
range=types.Range(
494+
start=types.Position(line=0, character=0),
495+
end=types.Position(
496+
line=end_line,
497+
character=end_character,
498+
),
499+
),
500+
new_text=yaml.dump(models),
501+
)
502+
],
503+
)
504+
ls.apply_edit(types.WorkspaceEdit(document_changes=[edit]))
505+
return True

sqlmesh/lsp/main.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
ApiResponseGetModels,
2525
)
2626

27+
from sqlmesh.lsp.commands import EXTERNAL_MODEL_UPDATE_COLUMNS
2728
from sqlmesh.lsp.completions import get_sql_completions
2829
from sqlmesh.lsp.context import (
2930
LSPContext,
@@ -368,6 +369,44 @@ def function_call(ls: LanguageServer, params: t.Any) -> t.Dict[str, t.Any]:
368369

369370
self.server.feature(name)(create_function_call(method))
370371

372+
@self.server.command(EXTERNAL_MODEL_UPDATE_COLUMNS)
373+
def command_external_models_update_columns(ls: LanguageServer, raw: t.Any) -> None:
374+
try:
375+
if not isinstance(raw, list):
376+
raise ValueError("Invalid command parameters")
377+
if len(raw) != 1:
378+
raise ValueError("Command expects exactly one parameter")
379+
model_name = raw[0]
380+
if not isinstance(model_name, str):
381+
raise ValueError("Command parameter must be a string")
382+
383+
context = self._context_get_or_load()
384+
if not isinstance(context, LSPContext):
385+
raise ValueError("Context is not loaded or invalid")
386+
model = context.context.get_model(model_name)
387+
if model is None:
388+
raise ValueError(f"External model '{model_name}' not found")
389+
if model._path is None:
390+
raise ValueError(f"External model '{model_name}' does not have a file path")
391+
uri = URI.from_path(model._path)
392+
updated = context.update_external_model_columns(
393+
ls=ls,
394+
uri=uri,
395+
model_name=model_name,
396+
)
397+
if updated:
398+
ls.show_message(
399+
f"Updated columns for '{model_name}'",
400+
types.MessageType.Info,
401+
)
402+
else:
403+
ls.show_message(
404+
f"Columns for '{model_name}' are already up to date",
405+
)
406+
except Exception as e:
407+
ls.show_message(f"Error executing command: {e}", types.MessageType.Error)
408+
return None
409+
371410
@self.server.feature(types.INITIALIZE)
372411
def initialize(ls: LanguageServer, params: types.InitializeParams) -> None:
373412
"""Initialize the server when the client connects."""
@@ -750,6 +789,17 @@ def code_action(
750789
ls.log_trace(f"Error getting code actions: {e}")
751790
return None
752791

792+
@self.server.feature(types.TEXT_DOCUMENT_CODE_LENS)
793+
def code_lens(ls: LanguageServer, params: types.CodeLensParams) -> t.List[types.CodeLens]:
794+
try:
795+
uri = URI(params.text_document.uri)
796+
context = self._context_get_or_load(uri)
797+
code_lenses = context.get_code_lenses(uri)
798+
return code_lenses if code_lenses else []
799+
except Exception as e:
800+
ls.log_trace(f"Error getting code lenses: {e}")
801+
return []
802+
753803
@self.server.feature(
754804
types.TEXT_DOCUMENT_COMPLETION,
755805
types.CompletionOptions(trigger_characters=["@"]), # advertise "@" for macros

sqlmesh/utils/lineage.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -387,18 +387,38 @@ def _get_yaml_model_range(path: Path, model_name: str) -> t.Optional[Range]:
387387
Returns:
388388
The Range of the model block in the YAML file, or None if not found
389389
"""
390+
model_name_ranges = get_yaml_model_name_ranges(path)
391+
if model_name_ranges is None:
392+
return None
393+
return model_name_ranges.get(model_name, None)
394+
395+
396+
def get_yaml_model_name_ranges(path: Path) -> t.Optional[t.Dict[str, Range]]:
397+
"""
398+
Get the ranges of all model names in a YAML file.
399+
400+
Args:
401+
path: Path to the YAML file
402+
403+
Returns:
404+
A dictionary mapping model names to their ranges in the YAML file.
405+
"""
390406
yaml = YAML()
391407
with path.open("r", encoding="utf-8") as f:
392408
data = yaml.load(f)
393409

394410
if not isinstance(data, list):
395411
return None
396412

413+
model_name_ranges = {}
397414
for item in data:
398-
if isinstance(item, dict) and item.get("name") == model_name:
399-
# Get size of block by taking the earliest line/col in the items block and the last line/col of the block
415+
if isinstance(item, dict):
400416
position_data = item.lc.data["name"] # type: ignore
401417
start = Position(line=position_data[2], character=position_data[3])
402418
end = Position(line=position_data[2], character=position_data[3] + len(item["name"]))
403-
return Range(start=start, end=end)
404-
return None
419+
name = item.get("name")
420+
if not name:
421+
continue
422+
model_name_ranges[name] = Range(start=start, end=end)
423+
424+
return model_name_ranges

vscode/extension/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@
134134
"package": "rm -rf ./src_react && mkdir -p ./src_react && cd ../react && pnpm run build && cd ../extension && cp -r ../react/dist/* ./src_react && pnpm run check-types && node esbuild.js --production"
135135
},
136136
"dependencies": {
137+
"@duckdb/node-api": "1.3.2-alpha.25",
137138
"@types/fs-extra": "^11.0.4",
138139
"@vscode/python-extension": "^1.0.5",
139140
"fs-extra": "^11.3.0",

0 commit comments

Comments
 (0)