|
1 | 1 | from dataclasses import dataclass |
2 | 2 | from pathlib import Path |
| 3 | +from pygls.server import LanguageServer |
3 | 4 | from sqlmesh.core.context import Context |
4 | 5 | import typing as t |
5 | | - |
6 | 6 | from sqlmesh.core.linter.rule import Range |
7 | | -from sqlmesh.core.model.definition import SqlModel |
| 7 | +from sqlmesh.core.model.definition import SqlModel, ExternalModel |
8 | 8 | from sqlmesh.core.linter.definition import AnnotatedRuleViolation |
| 9 | +from sqlmesh.core.schema_loader import get_columns |
| 10 | +from sqlmesh.lsp.commands import EXTERNAL_MODEL_UPDATE_COLUMNS |
9 | 11 | from sqlmesh.lsp.custom import ModelForRendering, TestEntry, RunTestResponse |
10 | 12 | from sqlmesh.lsp.custom import AllModelsResponse, RenderModelEntry |
11 | 13 | from sqlmesh.lsp.tests_ranges import get_test_ranges |
| 14 | +from sqlmesh.lsp.helpers import to_lsp_range |
12 | 15 | from sqlmesh.lsp.uri import URI |
13 | 16 | from lsprotocol import types |
| 17 | +from sqlmesh.utils import yaml |
| 18 | +from sqlmesh.utils.lineage import get_yaml_model_name_ranges |
14 | 19 |
|
15 | 20 |
|
16 | 21 | @dataclass |
@@ -298,6 +303,36 @@ def get_code_actions( |
298 | 303 |
|
299 | 304 | return code_actions if code_actions else None |
300 | 305 |
|
| 306 | + def get_code_lenses(self, uri: URI) -> t.Optional[t.List[types.CodeLens]]: |
| 307 | + models_in_file = self.map.get(uri.to_path()) |
| 308 | + if isinstance(models_in_file, ModelTarget): |
| 309 | + models = [self.context.get_model(model) for model in models_in_file.names] |
| 310 | + if any(isinstance(model, ExternalModel) for model in models): |
| 311 | + code_lenses = self._get_external_model_code_lenses(uri) |
| 312 | + if code_lenses: |
| 313 | + return code_lenses |
| 314 | + |
| 315 | + return None |
| 316 | + |
| 317 | + def _get_external_model_code_lenses(self, uri: URI) -> t.List[types.CodeLens]: |
| 318 | + """Get code lenses for external models YAML files.""" |
| 319 | + ranges = get_yaml_model_name_ranges(uri.to_path()) |
| 320 | + if ranges is None: |
| 321 | + return [] |
| 322 | + return [ |
| 323 | + types.CodeLens( |
| 324 | + range=to_lsp_range(range), |
| 325 | + command=types.Command( |
| 326 | + title="Update Columns", |
| 327 | + command=EXTERNAL_MODEL_UPDATE_COLUMNS, |
| 328 | + arguments=[ |
| 329 | + name, |
| 330 | + ], |
| 331 | + ), |
| 332 | + ) |
| 333 | + for name, range in ranges.items() |
| 334 | + ] |
| 335 | + |
301 | 336 | def list_of_models_for_rendering(self) -> t.List[ModelForRendering]: |
302 | 337 | """Get a list of models for rendering. |
303 | 338 |
|
@@ -399,3 +434,72 @@ def diagnostic_to_lsp_diagnostic( |
399 | 434 | code=diagnostic.rule.name, |
400 | 435 | code_description=types.CodeDescription(href=rule_uri), |
401 | 436 | ) |
| 437 | + |
| 438 | + def update_external_model_columns(self, ls: LanguageServer, uri: URI, model_name: str) -> bool: |
| 439 | + """ |
| 440 | + Update the columns for an external model in the YAML file. Returns True if changed, False if didn't because |
| 441 | + of the columns already being up to date. |
| 442 | +
|
| 443 | + Errors still throw exceptions to be handled by the caller. |
| 444 | + """ |
| 445 | + models = yaml.load(uri.to_path()) |
| 446 | + if not isinstance(models, list): |
| 447 | + raise ValueError( |
| 448 | + f"Expected a list of models in {uri.to_path()}, but got {type(models).__name__}" |
| 449 | + ) |
| 450 | + |
| 451 | + existing_model = next((model for model in models if model.get("name") == model_name), None) |
| 452 | + if existing_model is None: |
| 453 | + raise ValueError(f"Could not find model {model_name} in {uri.to_path()}") |
| 454 | + |
| 455 | + existing_model_columns = existing_model.get("columns") |
| 456 | + |
| 457 | + # Get the adapter and fetch columns |
| 458 | + adapter = self.context.engine_adapter |
| 459 | + # Get columns for the model |
| 460 | + new_columns = get_columns( |
| 461 | + adapter=adapter, |
| 462 | + dialect=self.context.config.model_defaults.dialect, |
| 463 | + table=model_name, |
| 464 | + strict=True, |
| 465 | + ) |
| 466 | + # Compare existing columns and matching types and if they are the same, do not update |
| 467 | + if existing_model_columns is not None: |
| 468 | + if existing_model_columns == new_columns: |
| 469 | + return False |
| 470 | + |
| 471 | + # Model index to update |
| 472 | + model_index = next( |
| 473 | + (i for i, model in enumerate(models) if model.get("name") == model_name), None |
| 474 | + ) |
| 475 | + if model_index is None: |
| 476 | + raise ValueError(f"Could not find model {model_name} in {uri.to_path()}") |
| 477 | + |
| 478 | + # Get end of the file to set the edit range |
| 479 | + with open(uri.to_path(), "r", encoding="utf-8") as file: |
| 480 | + read_file = file.read() |
| 481 | + |
| 482 | + end_line = read_file.count("\n") |
| 483 | + end_character = len(read_file.splitlines()[-1]) if end_line > 0 else 0 |
| 484 | + |
| 485 | + models[model_index]["columns"] = new_columns |
| 486 | + edit = types.TextDocumentEdit( |
| 487 | + text_document=types.OptionalVersionedTextDocumentIdentifier( |
| 488 | + uri=uri.value, |
| 489 | + version=None, |
| 490 | + ), |
| 491 | + edits=[ |
| 492 | + types.TextEdit( |
| 493 | + range=types.Range( |
| 494 | + start=types.Position(line=0, character=0), |
| 495 | + end=types.Position( |
| 496 | + line=end_line, |
| 497 | + character=end_character, |
| 498 | + ), |
| 499 | + ), |
| 500 | + new_text=yaml.dump(models), |
| 501 | + ) |
| 502 | + ], |
| 503 | + ) |
| 504 | + ls.apply_edit(types.WorkspaceEdit(document_changes=[edit])) |
| 505 | + return True |
0 commit comments