Skip to content

Commit 006fdc1

Browse files
committed
feat(vscode): adding ability to update columns
1 parent 95b1f6e commit 006fdc1

File tree

11 files changed

+423
-28
lines changed

11 files changed

+423
-28
lines changed

examples/sushi/data/duckdb.db

5.26 MB
Binary file not shown.

examples/sushi/external_models.yaml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
dialect: duckdb
44
start: 1 week ago
55
audits:
6-
- name: not_null
7-
columns: "[customer_id]"
8-
- name: accepted_range
9-
column: zip
10-
min_v: "'00000'"
11-
max_v: "'99999'"
12-
- name: assert_raw_demographics
6+
- name: not_null
7+
columns: '[customer_id]'
8+
- name: accepted_range
9+
column: zip
10+
min_v: "'00000'"
11+
max_v: "'99999'"
12+
- name: assert_raw_demographics
1313
columns:
1414
customer_id: int
1515
zip: text

pnpm-lock.yaml

Lines changed: 61 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

requirements.md

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
Ability to click update columns on a model in external models
2+
3+
# Requirements
4+
5+
* [ ] For each model in external models file
6+
* [ ] Below each model in the yaml file, provide an LSP code action to update columns
7+
* [ ] The code action should be able to update the columns in the model
8+
9+
# Useful files
10+
11+
sqlmesh/lsp/main.py is the main file for the LSP server.
12+
sqlmesh/cli/main.py and create_external_model provides similar functionality for creating external models but for the CLI and only does the whole shebang.
13+
14+
# Implementation Plan
15+
16+
1. **Create a new code action handler for external models YAML files**
17+
- Add support for YAML files in the LSP context's `get_code_actions` method
18+
- Detect when cursor is on or near a model definition in external_models.yaml
19+
20+
2. **Add YAML file parsing to identify model positions**
21+
- Parse the YAML file to identify model definitions and their line positions
22+
- Track the name field and columns section for each model
23+
24+
3. **Implement the "Update Columns" code action**
25+
- When triggered, the action will:
26+
- Extract the model name from the YAML
27+
- Use the engine adapter to fetch current columns from the database
28+
- Generate the updated columns section in YAML format
29+
- Replace the existing columns section with the new one
30+
31+
4. **Integration points**
32+
- Extend `LSPContext.get_code_actions()` to handle YAML files
33+
- Add YAML parsing logic to identify model boundaries
34+
- Reuse existing column fetching logic from `schema_loader.py`
35+
- Use the adapter's `columns()` method to get schema information
36+
37+
5. **Code action workflow**
38+
- User opens external_models.yaml in editor
39+
- LSP server identifies model definitions in the file
40+
- When cursor is on a model, show "Update columns" code action
41+
- On activation, fetch latest schema and update the YAML
42+

sqlmesh/core/schema_loader.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,17 +57,6 @@ def create_external_models_file(
5757
external_model_fqns -= existing_model_fqns
5858

5959
with ThreadPoolExecutor(max_workers=max_workers) as pool:
60-
61-
def _get_columns(table: str) -> t.Optional[t.Dict[str, t.Any]]:
62-
try:
63-
return adapter.columns(table, include_pseudo_columns=True)
64-
except Exception as e:
65-
msg = f"Unable to get schema for '{table}': '{e}'."
66-
if strict:
67-
raise SQLMeshError(msg) from e
68-
get_console().log_warning(msg)
69-
return None
70-
7160
gateway_part = {"gateway": gateway} if gateway else {}
7261

7362
schemas = [
@@ -78,7 +67,7 @@ def _get_columns(table: str) -> t.Optional[t.Dict[str, t.Any]]:
7867
}
7968
for table, columns in sorted(
8069
pool.map(
81-
lambda table: (table, _get_columns(table)),
70+
lambda table: (table, get_columns(adapter, dialect, table, strict)),
8271
external_model_fqns,
8372
)
8473
)
@@ -94,3 +83,20 @@ def _get_columns(table: str) -> t.Optional[t.Dict[str, t.Any]]:
9483

9584
with open(path, "w", encoding="utf-8") as file:
9685
yaml.dump(entries_to_keep + schemas, file)
86+
87+
88+
def get_columns(
89+
adapter: EngineAdapter, dialect: DialectType, table: str, strict: bool
90+
) -> t.Optional[t.Dict[str, t.Any]]:
91+
"""
92+
Return the column and their types in a dictionary
93+
"""
94+
try:
95+
columns = adapter.columns(table, include_pseudo_columns=True)
96+
return {c: dtype.sql(dialect=dialect) for c, dtype in columns.items()}
97+
except Exception as e:
98+
msg = f"Unable to get schema for '{table}': '{e}'."
99+
if strict:
100+
raise SQLMeshError(msg) from e
101+
get_console().log_warning(msg)
102+
return None

sqlmesh/lsp/commands.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from sqlmesh.utils.pydantic import PydanticModel
2+
3+
EXTERNAL_MODEL_UPDATE_COLUMNS = "sqlmesh.external_model_update_columns"
4+
5+
6+
class ExternalModelUpdateColumnsRequest(PydanticModel):
7+
model_name: str

sqlmesh/lsp/context.py

Lines changed: 114 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
11
from dataclasses import dataclass
22
from pathlib import Path
3+
4+
from pygls.server import LanguageServer
5+
36
from sqlmesh.core.context import Context
47
import typing as t
58

6-
from sqlmesh.core.model.definition import SqlModel
9+
from sqlmesh.core.model.definition import SqlModel, ExternalModel
710
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
11+
from sqlmesh.core.schema_loader import get_columns
12+
from sqlmesh.lsp.commands import EXTERNAL_MODEL_UPDATE_COLUMNS
813
from sqlmesh.lsp.custom import ModelForRendering
914
from sqlmesh.lsp.custom import AllModelsResponse, RenderModelEntry
1015
from sqlmesh.lsp.uri import URI
1116
from lsprotocol import types
17+
from pathlib import Path
18+
19+
from sqlmesh.utils import yaml
1220

1321

1422
@dataclass
@@ -228,8 +236,42 @@ def get_code_actions(
228236
edit=types.WorkspaceEdit(changes={params.text_document.uri: text_edits}),
229237
)
230238
code_actions.append(code_action)
239+
return code_actions
240+
241+
def get_code_lenses(self, uri: URI) -> t.Optional[t.List[types.CodeLens]]:
242+
models_in_file = self.map.get(uri.to_path())
243+
if isinstance(models_in_file, ModelTarget):
244+
models = [self.context.get_model(model) for model in models_in_file.names]
245+
if any(isinstance(model, ExternalModel) for model in models):
246+
code_lenses = self._get_external_model_code_lenses(uri)
247+
if code_lenses:
248+
return code_lenses
249+
250+
return None
251+
252+
def _get_external_model_code_lenses(self, uri: URI) -> t.List[types.CodeLens]:
253+
from sqlmesh.lsp.reference import get_yaml_model_name_ranges
254+
255+
"""Get code actions for external models YAML files."""
256+
ranges = get_yaml_model_name_ranges(uri.to_path())
257+
258+
code_actions: t.List[types.CodeLens] = []
259+
for name, range in ranges.items():
260+
# Create a code action to update columns for external models
261+
command = types.Command(
262+
title="Update Columns",
263+
command=EXTERNAL_MODEL_UPDATE_COLUMNS,
264+
arguments=[
265+
name,
266+
],
267+
)
268+
code_action = types.CodeLens(
269+
range=range,
270+
command=command,
271+
)
272+
code_actions.append(code_action)
231273

232-
return code_actions if code_actions else None
274+
return code_actions if code_actions else []
233275

234276
def list_of_models_for_rendering(self) -> t.List[ModelForRendering]:
235277
"""Get a list of models for rendering.
@@ -332,3 +374,73 @@ def diagnostic_to_lsp_diagnostic(
332374
code=diagnostic.rule.name,
333375
code_description=types.CodeDescription(href=rule_uri),
334376
)
377+
378+
def update_external_model_columns(self, ls: LanguageServer, uri: URI, model_name: str) -> bool:
379+
"""
380+
Update the columns for an external model in the YAML file. Returns True if changed, False if didn't because
381+
of the columns already being up to date.
382+
383+
Errors still throw exceptions to be handled by the caller.
384+
"""
385+
models = yaml.load(uri.to_path())
386+
if not isinstance(models, list):
387+
raise ValueError(
388+
f"Expected a list of models in {uri.to_path()}, but got {type(models).__name__}"
389+
)
390+
391+
existing_model = next((model for model in models if model.get("name") == model_name), None)
392+
if existing_model is None:
393+
raise ValueError(f"Could not find model {model_name} in {uri.to_path()}")
394+
395+
existing_model_columns = existing_model.get("columns")
396+
397+
# Get the adapter and fetch columns
398+
adapter = self.context.engine_adapter
399+
# Get columns for the model
400+
new_columns = get_columns(
401+
adapter=adapter,
402+
dialect=self.context.config.model_defaults.dialect,
403+
table=model_name,
404+
strict=True,
405+
)
406+
# Compare existing columns and matching types and if they are the same, do not update
407+
if existing_model_columns is not None:
408+
if existing_model_columns == new_columns:
409+
ls.show_message("Columns already up to date")
410+
return False
411+
412+
# Model index to update
413+
model_index = next(
414+
(i for i, model in enumerate(models) if model.get("name") == model_name), None
415+
)
416+
if model_index is None:
417+
raise ValueError(f"Could not find model {model_name} in {uri.to_path()}")
418+
419+
# Get end of the file to set the edit range
420+
with open(uri.to_path(), "r", encoding="utf-8") as file:
421+
read_file = file.read()
422+
423+
end_line = read_file.count("\n")
424+
end_character = len(read_file.splitlines()[-1]) if end_line > 0 else 0
425+
426+
models[model_index]["columns"] = new_columns
427+
edit = types.TextDocumentEdit(
428+
text_document=types.OptionalVersionedTextDocumentIdentifier(
429+
uri=uri.value,
430+
version=None,
431+
),
432+
edits=[
433+
types.TextEdit(
434+
range=types.Range(
435+
start=types.Position(line=0, character=0),
436+
end=types.Position(
437+
line=end_line,
438+
character=end_character,
439+
),
440+
),
441+
new_text=yaml.dump(models),
442+
)
443+
],
444+
)
445+
ls.apply_edit(types.WorkspaceEdit(document_changes=[edit]))
446+
return True

0 commit comments

Comments
 (0)