diff --git a/sqlmesh/core/linter/helpers.py b/sqlmesh/core/linter/helpers.py index 59fd478f78..707b0a9159 100644 --- a/sqlmesh/core/linter/helpers.py +++ b/sqlmesh/core/linter/helpers.py @@ -1,6 +1,6 @@ from pathlib import Path -from sqlmesh.core.linter.rule import Position, Range +from sqlmesh.core.linter.rule import Range, Position from sqlmesh.utils.pydantic import PydanticModel from sqlglot import tokenize, TokenType import typing as t diff --git a/sqlmesh/lsp/main.py b/sqlmesh/lsp/main.py index 24b287d74c..26100c1092 100755 --- a/sqlmesh/lsp/main.py +++ b/sqlmesh/lsp/main.py @@ -48,6 +48,7 @@ CustomMethod, ) from sqlmesh.lsp.errors import ContextFailedError, context_error_to_diagnostic +from sqlmesh.lsp.helpers import to_lsp_range, to_sqlmesh_position from sqlmesh.lsp.hints import get_hints from sqlmesh.lsp.reference import ( LSPCteReference, @@ -418,7 +419,7 @@ def hover(ls: LanguageServer, params: types.HoverParams) -> t.Optional[types.Hov context = self._context_get_or_load(uri) document = ls.workspace.get_text_document(params.text_document.uri) - references = get_references(context, uri, params.position) + references = get_references(context, uri, to_sqlmesh_position(params.position)) if not references: return None reference = references[0] @@ -429,7 +430,7 @@ def hover(ls: LanguageServer, params: types.HoverParams) -> t.Optional[types.Hov kind=types.MarkupKind.Markdown, value=reference.markdown_description, ), - range=reference.range, + range=to_lsp_range(reference.range), ) except Exception as e: @@ -464,7 +465,7 @@ def goto_definition( uri = URI(params.text_document.uri) context = self._context_get_or_load(uri) - references = get_references(context, uri, params.position) + references = get_references(context, uri, to_sqlmesh_position(params.position)) location_links = [] for reference in references: # Use target_range if available (CTEs, Macros, and external models in YAML) @@ -489,12 +490,12 @@ def goto_definition( end=types.Position(line=0, character=0), ) if reference.target_range is not None: - target_range = reference.target_range - target_selection_range = reference.target_range + target_range = to_lsp_range(reference.target_range) + target_selection_range = to_lsp_range(reference.target_range) else: # CTEs and Macros always have target_range - target_range = reference.target_range - target_selection_range = reference.target_range + target_range = to_lsp_range(reference.target_range) + target_selection_range = to_lsp_range(reference.target_range) if reference.path is not None: location_links.append( @@ -502,7 +503,7 @@ def goto_definition( target_uri=URI.from_path(reference.path).value, target_selection_range=target_selection_range, target_range=target_range, - origin_selection_range=reference.range, + origin_selection_range=to_lsp_range(reference.range), ) ) return location_links @@ -519,11 +520,13 @@ def find_references( uri = URI(params.text_document.uri) context = self._context_get_or_load(uri) - all_references = get_all_references(context, uri, params.position) + all_references = get_all_references( + context, uri, to_sqlmesh_position(params.position) + ) # Convert references to Location objects locations = [ - types.Location(uri=URI.from_path(ref.path).value, range=ref.range) + types.Location(uri=URI.from_path(ref.path).value, range=to_lsp_range(ref.range)) for ref in all_references if ref.path is not None ] diff --git a/sqlmesh/lsp/reference.py b/sqlmesh/lsp/reference.py index 6849129a7e..6aee3e10da 100644 --- a/sqlmesh/lsp/reference.py +++ b/sqlmesh/lsp/reference.py @@ -1,4 +1,3 @@ -from lsprotocol.types import Range, Position import typing as t from pathlib import Path from pydantic import Field @@ -8,13 +7,13 @@ from sqlmesh.core.linter.helpers import ( TokenPositionDetails, ) +from sqlmesh.core.linter.rule import Range, Position from sqlmesh.core.model.definition import SqlModel, ExternalModel from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget from sqlglot import exp from sqlmesh.lsp.description import generate_markdown_description from sqlglot.optimizer.scope import build_scope -from sqlmesh.lsp.helpers import to_lsp_range, to_lsp_position from sqlmesh.lsp.uri import URI from sqlmesh.utils.pydantic import PydanticModel from sqlglot.optimizer.normalize_identifiers import normalize_identifiers @@ -203,15 +202,11 @@ def get_model_definitions_for_a_path( table.this.meta ).to_range(read_file) - # Convert SQLMesh Range to LSP Range - target_range = to_lsp_range(target_range_sqlmesh) - table_range = to_lsp_range(table_range_sqlmesh) - references.append( LSPCteReference( path=document_uri.to_path(), # Same file - range=table_range, - target_range=target_range, + range=table_range_sqlmesh, + target_range=target_range_sqlmesh, ) ) @@ -222,7 +217,7 @@ def get_model_definitions_for_a_path( referenced_model_path=document_uri.to_path(), description="", reference_type="cte", - cte_target_range=target_range, + cte_target_range=target_range_sqlmesh, ) references.extend(column_references) continue @@ -257,8 +252,8 @@ def get_model_definitions_for_a_path( references.append( LSPExternalModelReference( range=Range( - start=to_lsp_position(start_pos_sqlmesh), - end=to_lsp_position(end_pos_sqlmesh), + start=start_pos_sqlmesh, + end=end_pos_sqlmesh, ), markdown_description="Unregistered external model", ) @@ -300,8 +295,8 @@ def get_model_definitions_for_a_path( LSPExternalModelReference( path=referenced_model_path, range=Range( - start=to_lsp_position(start_pos_sqlmesh), - end=to_lsp_position(end_pos_sqlmesh), + start=start_pos_sqlmesh, + end=end_pos_sqlmesh, ), markdown_description=description, target_range=yaml_target_range, @@ -325,8 +320,8 @@ def get_model_definitions_for_a_path( LSPModelReference( path=referenced_model_path, range=Range( - start=to_lsp_position(start_pos_sqlmesh), - end=to_lsp_position(end_pos_sqlmesh), + start=start_pos_sqlmesh, + end=end_pos_sqlmesh, ), markdown_description=description, ) @@ -432,7 +427,7 @@ def get_macro_reference( macro_range = TokenPositionDetails.from_meta(node.meta).to_range(read_file) # Check if it's a built-in method - if builtin := get_built_in_macro_reference(macro_name, to_lsp_range(macro_range)): + if builtin := get_built_in_macro_reference(macro_name, macro_range): return builtin else: # Skip if we can't get the position @@ -483,7 +478,7 @@ def get_macro_reference( return LSPMacroReference( path=path, - range=to_lsp_range(macro_range), + range=macro_range, target_range=Range( start=Position(line=start_line - 1, character=0), end=Position(line=end_line - 1, character=get_length_of_end_line), @@ -811,8 +806,8 @@ def _get_column_table_range(column: exp.Column, read_file: t.List[str]) -> Range end_range = TokenPositionDetails.from_meta(table_parts[-1].meta).to_range(read_file) return Range( - start=to_lsp_position(start_range.start), - end=to_lsp_position(end_range.end), + start=start_range.start, + end=end_range.end, ) diff --git a/sqlmesh/lsp/rename.py b/sqlmesh/lsp/rename.py index c388b7a305..31f7eb3200 100644 --- a/sqlmesh/lsp/rename.py +++ b/sqlmesh/lsp/rename.py @@ -9,6 +9,7 @@ ) from sqlmesh.lsp.context import LSPContext +from sqlmesh.lsp.helpers import to_sqlmesh_position, to_lsp_range from sqlmesh.lsp.reference import ( _position_within_range, get_cte_references, @@ -18,7 +19,7 @@ def prepare_rename( - lsp_context: LSPContext, document_uri: URI, position: Position + lsp_context: LSPContext, document_uri: URI, lsp_position: Position ) -> t.Optional[PrepareRenameResult_Type1]: """ Prepare for rename operation by checking if the symbol at the position can be renamed. @@ -32,6 +33,7 @@ def prepare_rename( PrepareRenameResult if the symbol can be renamed, None otherwise """ # Check if there's a CTE at this position + position = to_sqlmesh_position(lsp_position) cte_references = get_cte_references(lsp_context, document_uri, position) if cte_references: # Find the target CTE definition to get its range @@ -46,14 +48,16 @@ def prepare_rename( target_range = ref.target_range break if target_range: - return PrepareRenameResult_Type1(range=target_range, placeholder="cte_name") + return PrepareRenameResult_Type1( + range=to_lsp_range(target_range), placeholder="cte_name" + ) # For now, only CTEs are supported return None def rename_symbol( - lsp_context: LSPContext, document_uri: URI, position: Position, new_name: str + lsp_context: LSPContext, document_uri: URI, lsp_position: Position, new_name: str ) -> t.Optional[WorkspaceEdit]: """ Perform rename operation on the symbol at the given position. @@ -68,7 +72,9 @@ def rename_symbol( WorkspaceEdit with the changes, or None if no symbol to rename """ # Check if there's a CTE at this position - cte_references = get_cte_references(lsp_context, document_uri, position) + cte_references = get_cte_references( + lsp_context, document_uri, to_sqlmesh_position(lsp_position) + ) if cte_references: return _rename_cte(cte_references, new_name) @@ -95,7 +101,7 @@ def _rename_cte(cte_references: t.List[LSPCteReference], new_name: str) -> Works changes[uri] = [] # Create a text edit for this reference - text_edit = TextEdit(range=ref.range, new_text=new_name) + text_edit = TextEdit(range=to_lsp_range(ref.range), new_text=new_name) changes[uri].append(text_edit) return WorkspaceEdit(changes=changes) @@ -119,7 +125,7 @@ def get_document_highlights( List of DocumentHighlight objects or None if no symbol found """ # Check if there's a CTE at this position - cte_references = get_cte_references(lsp_context, document_uri, position) + cte_references = get_cte_references(lsp_context, document_uri, to_sqlmesh_position(position)) if cte_references: highlights = [] for ref in cte_references: @@ -130,7 +136,7 @@ def get_document_highlights( else DocumentHighlightKind.Read ) - highlights.append(DocumentHighlight(range=ref.range, kind=kind)) + highlights.append(DocumentHighlight(range=to_lsp_range(ref.range), kind=kind)) return highlights # For now, only CTEs are supported diff --git a/tests/lsp/test_reference.py b/tests/lsp/test_reference.py index b78ed8145c..3d1e19f3cc 100644 --- a/tests/lsp/test_reference.py +++ b/tests/lsp/test_reference.py @@ -1,5 +1,5 @@ -from lsprotocol.types import Position from sqlmesh.core.context import Context +from sqlmesh.core.linter.rule import Position from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget from sqlmesh.lsp.reference import LSPModelReference, get_model_definitions_for_a_path, by_position from sqlmesh.lsp.uri import URI diff --git a/tests/lsp/test_reference_external_model.py b/tests/lsp/test_reference_external_model.py index 8f2c2f7c1b..36c64fe277 100644 --- a/tests/lsp/test_reference_external_model.py +++ b/tests/lsp/test_reference_external_model.py @@ -1,12 +1,10 @@ from pathlib import Path -from lsprotocol.types import Position - from sqlmesh import Config from sqlmesh.core.context import Context from sqlmesh.core.linter.helpers import read_range_from_file +from sqlmesh.core.linter.rule import Position from sqlmesh.lsp.context import LSPContext, ModelTarget -from sqlmesh.lsp.helpers import to_sqlmesh_range from sqlmesh.lsp.reference import get_references, LSPExternalModelReference from sqlmesh.lsp.uri import URI from tests.utils.test_filesystem import create_temp_file @@ -34,14 +32,14 @@ def test_reference() -> None: assert path is not None assert str(path).endswith("external_models.yaml") - source_range = read_range_from_file(customers, to_sqlmesh_range(reference.range)) + source_range = read_range_from_file(customers, reference.range) assert source_range == "raw.demographics" if reference.target_range is None: raise AssertionError("Reference target range should not be None") path = reference.path assert path is not None - target_range = read_range_from_file(path, to_sqlmesh_range(reference.target_range)) + target_range = read_range_from_file(path, reference.target_range) assert target_range == "raw.demographics" @@ -61,4 +59,4 @@ def test_unregistered_external_model(tmp_path: Path): assert reference.path is None assert reference.target_range is None assert reference.markdown_description == "Unregistered external model" - assert read_range_from_file(model_path, to_sqlmesh_range(reference.range)) == "external_model" + assert read_range_from_file(model_path, reference.range) == "external_model" diff --git a/tests/lsp/test_reference_model_column_prefix.py b/tests/lsp/test_reference_model_column_prefix.py index 01b91de570..3cd25a080e 100644 --- a/tests/lsp/test_reference_model_column_prefix.py +++ b/tests/lsp/test_reference_model_column_prefix.py @@ -1,8 +1,8 @@ from pathlib import Path -from lsprotocol.types import Position from sqlmesh.cli.project_init import init_example_project from sqlmesh.core.context import Context +from sqlmesh.core.linter.rule import Position from sqlmesh.lsp.context import LSPContext, ModelTarget from sqlmesh.lsp.reference import get_all_references from sqlmesh.lsp.uri import URI