|
1 | 1 | from dataclasses import dataclass |
2 | 2 | from pathlib import Path |
| 3 | + |
| 4 | +from pygls.server import LanguageServer |
| 5 | + |
3 | 6 | from sqlmesh.core.context import Context |
4 | 7 | import typing as t |
5 | 8 |
|
6 | | -from sqlmesh.core.model.definition import SqlModel |
| 9 | +from sqlmesh.core.model.definition import SqlModel, ExternalModel |
7 | 10 | from sqlmesh.core.linter.definition import AnnotatedRuleViolation |
| 11 | +from sqlmesh.core.schema_loader import get_columns |
| 12 | +from sqlmesh.lsp.commands import EXTERNAL_MODEL_UPDATE_COLUMNS |
8 | 13 | from sqlmesh.lsp.custom import ModelForRendering |
9 | 14 | from sqlmesh.lsp.custom import AllModelsResponse, RenderModelEntry |
| 15 | +from sqlmesh.lsp.helpers import to_lsp_range |
10 | 16 | from sqlmesh.lsp.uri import URI |
11 | 17 | from lsprotocol import types |
| 18 | +from pathlib import Path |
| 19 | + |
| 20 | +from sqlmesh.utils import yaml |
12 | 21 |
|
13 | 22 |
|
14 | 23 | @dataclass |
@@ -228,9 +237,44 @@ def get_code_actions( |
228 | 237 | edit=types.WorkspaceEdit(changes={params.text_document.uri: text_edits}), |
229 | 238 | ) |
230 | 239 | code_actions.append(code_action) |
231 | | - |
232 | 240 | return code_actions if code_actions else None |
233 | 241 |
|
| 242 | + def get_code_lenses(self, uri: URI) -> t.Optional[t.List[types.CodeLens]]: |
| 243 | + models_in_file = self.map.get(uri.to_path()) |
| 244 | + if isinstance(models_in_file, ModelTarget): |
| 245 | + models = [self.context.get_model(model) for model in models_in_file.names] |
| 246 | + if any(isinstance(model, ExternalModel) for model in models): |
| 247 | + code_lenses = self._get_external_model_code_lenses(uri) |
| 248 | + if code_lenses: |
| 249 | + return code_lenses |
| 250 | + |
| 251 | + return None |
| 252 | + |
| 253 | + def _get_external_model_code_lenses(self, uri: URI) -> t.List[types.CodeLens]: |
| 254 | + from sqlmesh.lsp.reference import get_yaml_model_name_ranges |
| 255 | + |
| 256 | + """Get code actions for external models YAML files.""" |
| 257 | + ranges = get_yaml_model_name_ranges(uri.to_path()) |
| 258 | + |
| 259 | + code_lenses: t.List[types.CodeLens] = [] |
| 260 | + for name, range in ranges.items(): |
| 261 | + # Create a code action to update columns for external models |
| 262 | + command = types.Command( |
| 263 | + title="Update Columns", |
| 264 | + command=EXTERNAL_MODEL_UPDATE_COLUMNS, |
| 265 | + arguments=[ |
| 266 | + name, |
| 267 | + ], |
| 268 | + ) |
| 269 | + code_lenses.append( |
| 270 | + types.CodeLens( |
| 271 | + range=to_lsp_range(range), |
| 272 | + command=command, |
| 273 | + ) |
| 274 | + ) |
| 275 | + |
| 276 | + return code_lenses if code_lenses else [] |
| 277 | + |
234 | 278 | def list_of_models_for_rendering(self) -> t.List[ModelForRendering]: |
235 | 279 | """Get a list of models for rendering. |
236 | 280 |
|
@@ -332,3 +376,73 @@ def diagnostic_to_lsp_diagnostic( |
332 | 376 | code=diagnostic.rule.name, |
333 | 377 | code_description=types.CodeDescription(href=rule_uri), |
334 | 378 | ) |
| 379 | + |
| 380 | + def update_external_model_columns(self, ls: LanguageServer, uri: URI, model_name: str) -> bool: |
| 381 | + """ |
| 382 | + Update the columns for an external model in the YAML file. Returns True if changed, False if didn't because |
| 383 | + of the columns already being up to date. |
| 384 | +
|
| 385 | + Errors still throw exceptions to be handled by the caller. |
| 386 | + """ |
| 387 | + models = yaml.load(uri.to_path()) |
| 388 | + if not isinstance(models, list): |
| 389 | + raise ValueError( |
| 390 | + f"Expected a list of models in {uri.to_path()}, but got {type(models).__name__}" |
| 391 | + ) |
| 392 | + |
| 393 | + existing_model = next((model for model in models if model.get("name") == model_name), None) |
| 394 | + if existing_model is None: |
| 395 | + raise ValueError(f"Could not find model {model_name} in {uri.to_path()}") |
| 396 | + |
| 397 | + existing_model_columns = existing_model.get("columns") |
| 398 | + |
| 399 | + # Get the adapter and fetch columns |
| 400 | + adapter = self.context.engine_adapter |
| 401 | + # Get columns for the model |
| 402 | + new_columns = get_columns( |
| 403 | + adapter=adapter, |
| 404 | + dialect=self.context.config.model_defaults.dialect, |
| 405 | + table=model_name, |
| 406 | + strict=True, |
| 407 | + ) |
| 408 | + # Compare existing columns and matching types and if they are the same, do not update |
| 409 | + if existing_model_columns is not None: |
| 410 | + if existing_model_columns == new_columns: |
| 411 | + ls.show_message("Columns already up to date") |
| 412 | + return False |
| 413 | + |
| 414 | + # Model index to update |
| 415 | + model_index = next( |
| 416 | + (i for i, model in enumerate(models) if model.get("name") == model_name), None |
| 417 | + ) |
| 418 | + if model_index is None: |
| 419 | + raise ValueError(f"Could not find model {model_name} in {uri.to_path()}") |
| 420 | + |
| 421 | + # Get end of the file to set the edit range |
| 422 | + with open(uri.to_path(), "r", encoding="utf-8") as file: |
| 423 | + read_file = file.read() |
| 424 | + |
| 425 | + end_line = read_file.count("\n") |
| 426 | + end_character = len(read_file.splitlines()[-1]) if end_line > 0 else 0 |
| 427 | + |
| 428 | + models[model_index]["columns"] = new_columns |
| 429 | + edit = types.TextDocumentEdit( |
| 430 | + text_document=types.OptionalVersionedTextDocumentIdentifier( |
| 431 | + uri=uri.value, |
| 432 | + version=None, |
| 433 | + ), |
| 434 | + edits=[ |
| 435 | + types.TextEdit( |
| 436 | + range=types.Range( |
| 437 | + start=types.Position(line=0, character=0), |
| 438 | + end=types.Position( |
| 439 | + line=end_line, |
| 440 | + character=end_character, |
| 441 | + ), |
| 442 | + ), |
| 443 | + new_text=yaml.dump(models), |
| 444 | + ) |
| 445 | + ], |
| 446 | + ) |
| 447 | + ls.apply_edit(types.WorkspaceEdit(document_changes=[edit])) |
| 448 | + return True |
0 commit comments