Skip to content

Commit 606dcb4

Browse files
committed
refactor(lsp): move references to internal types
[ci skip]
1 parent fbd2632 commit 606dcb4

File tree

8 files changed

+74
-52
lines changed

8 files changed

+74
-52
lines changed

sqlmesh/core/linter/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from pathlib import Path
22

3-
from sqlmesh.core.linter.rule import Position, Range
3+
from sqlmesh.core.linter.rule import Range, Position
44
from sqlmesh.utils.pydantic import PydanticModel
55
from sqlglot import tokenize, TokenType
66
import typing as t

sqlmesh/core/linter/rule.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,20 @@ class Position:
3030
line: int
3131
character: int
3232

33+
def __eq__(self, o: object) -> bool:
34+
if not isinstance(o, Position):
35+
return NotImplemented
36+
return (self.line, self.character) == (o.line, o.character)
37+
38+
def __gt__(self, o: "Position") -> bool:
39+
if not isinstance(o, Position):
40+
return NotImplemented
41+
return (self.line, self.character) > (o.line, o.character)
42+
43+
def __repr__(self) -> str:
44+
return f"{self.line}:{self.character}"
45+
46+
3347

3448
@dataclass(frozen=True)
3549
class Range:

sqlmesh/lsp/main.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
CustomMethod,
4949
)
5050
from sqlmesh.lsp.errors import ContextFailedError, context_error_to_diagnostic
51+
from sqlmesh.lsp.helpers import to_lsp_range, to_sqlmesh_position
5152
from sqlmesh.lsp.hints import get_hints
5253
from sqlmesh.lsp.reference import (
5354
LSPCteReference,
@@ -418,7 +419,7 @@ def hover(ls: LanguageServer, params: types.HoverParams) -> t.Optional[types.Hov
418419
context = self._context_get_or_load(uri)
419420
document = ls.workspace.get_text_document(params.text_document.uri)
420421

421-
references = get_references(context, uri, params.position)
422+
references = get_references(context, uri, to_sqlmesh_position(params.position))
422423
if not references:
423424
return None
424425
reference = references[0]
@@ -429,7 +430,7 @@ def hover(ls: LanguageServer, params: types.HoverParams) -> t.Optional[types.Hov
429430
kind=types.MarkupKind.Markdown,
430431
value=reference.markdown_description,
431432
),
432-
range=reference.range,
433+
range=to_lsp_range(reference.range),
433434
)
434435

435436
except Exception as e:
@@ -464,7 +465,7 @@ def goto_definition(
464465
uri = URI(params.text_document.uri)
465466
context = self._context_get_or_load(uri)
466467

467-
references = get_references(context, uri, params.position)
468+
references = get_references(context, uri, to_sqlmesh_position(params.position))
468469
location_links = []
469470
for reference in references:
470471
# Use target_range if available (CTEs, Macros, and external models in YAML)
@@ -489,20 +490,20 @@ def goto_definition(
489490
end=types.Position(line=0, character=0),
490491
)
491492
if reference.target_range is not None:
492-
target_range = reference.target_range
493-
target_selection_range = reference.target_range
493+
target_range = to_lsp_range(reference.target_range)
494+
target_selection_range = to_lsp_range(reference.target_range)
494495
else:
495496
# CTEs and Macros always have target_range
496-
target_range = reference.target_range
497-
target_selection_range = reference.target_range
497+
target_range = to_lsp_range(reference.target_range)
498+
target_selection_range = to_lsp_range(reference.target_range)
498499

499500
if reference.path is not None:
500501
location_links.append(
501502
types.LocationLink(
502503
target_uri=URI.from_path(reference.path).value,
503504
target_selection_range=target_selection_range,
504505
target_range=target_range,
505-
origin_selection_range=reference.range,
506+
origin_selection_range=to_lsp_range(reference.range),
506507
)
507508
)
508509
return location_links
@@ -519,11 +520,13 @@ def find_references(
519520
uri = URI(params.text_document.uri)
520521
context = self._context_get_or_load(uri)
521522

522-
all_references = get_all_references(context, uri, params.position)
523+
all_references = get_all_references(
524+
context, uri, to_sqlmesh_position(params.position)
525+
)
523526

524527
# Convert references to Location objects
525528
locations = [
526-
types.Location(uri=URI.from_path(ref.path).value, range=ref.range)
529+
types.Location(uri=URI.from_path(ref.path).value, range=to_lsp_range(ref.range))
527530
for ref in all_references
528531
if ref.path is not None
529532
]

sqlmesh/lsp/reference.py

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from lsprotocol.types import Range, Position
1+
import lsprotocol.types as types
22
import typing as t
33
from pathlib import Path
44
from pydantic import Field
@@ -8,6 +8,7 @@
88
from sqlmesh.core.linter.helpers import (
99
TokenPositionDetails,
1010
)
11+
from sqlmesh.core.linter.rule import Range, Position
1112
from sqlmesh.core.model.definition import SqlModel, ExternalModel
1213
from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget
1314
from sqlglot import exp
@@ -203,15 +204,11 @@ def get_model_definitions_for_a_path(
203204
table.this.meta
204205
).to_range(read_file)
205206

206-
# Convert SQLMesh Range to LSP Range
207-
target_range = to_lsp_range(target_range_sqlmesh)
208-
table_range = to_lsp_range(table_range_sqlmesh)
209-
210207
references.append(
211208
LSPCteReference(
212209
path=document_uri.to_path(), # Same file
213-
range=table_range,
214-
target_range=target_range,
210+
range=table_range_sqlmesh,
211+
target_range=target_range_sqlmesh,
215212
)
216213
)
217214

@@ -222,7 +219,7 @@ def get_model_definitions_for_a_path(
222219
referenced_model_path=document_uri.to_path(),
223220
description="",
224221
reference_type="cte",
225-
cte_target_range=target_range,
222+
cte_target_range=target_range_sqlmesh,
226223
)
227224
references.extend(column_references)
228225
continue
@@ -257,8 +254,8 @@ def get_model_definitions_for_a_path(
257254
references.append(
258255
LSPExternalModelReference(
259256
range=Range(
260-
start=to_lsp_position(start_pos_sqlmesh),
261-
end=to_lsp_position(end_pos_sqlmesh),
257+
start=start_pos_sqlmesh,
258+
end=end_pos_sqlmesh,
262259
),
263260
markdown_description="Unregistered external model",
264261
)
@@ -288,7 +285,7 @@ def get_model_definitions_for_a_path(
288285

289286
# For external models in YAML files, find the specific model block
290287
if isinstance(referenced_model, ExternalModel):
291-
yaml_target_range: t.Optional[Range] = None
288+
yaml_target_range: t.Optional[types.Range] = None
292289
if (
293290
referenced_model_path.suffix in (".yaml", ".yml")
294291
and referenced_model_path.is_file()
@@ -300,8 +297,8 @@ def get_model_definitions_for_a_path(
300297
LSPExternalModelReference(
301298
path=referenced_model_path,
302299
range=Range(
303-
start=to_lsp_position(start_pos_sqlmesh),
304-
end=to_lsp_position(end_pos_sqlmesh),
300+
start=start_pos_sqlmesh,
301+
end=end_pos_sqlmesh,
305302
),
306303
markdown_description=description,
307304
target_range=yaml_target_range,
@@ -325,8 +322,8 @@ def get_model_definitions_for_a_path(
325322
LSPModelReference(
326323
path=referenced_model_path,
327324
range=Range(
328-
start=to_lsp_position(start_pos_sqlmesh),
329-
end=to_lsp_position(end_pos_sqlmesh),
325+
start=start_pos_sqlmesh,
326+
end=end_pos_sqlmesh,
330327
),
331328
markdown_description=description,
332329
)
@@ -484,17 +481,19 @@ def get_macro_reference(
484481
return LSPMacroReference(
485482
path=path,
486483
range=to_lsp_range(macro_range),
487-
target_range=Range(
488-
start=Position(line=start_line - 1, character=0),
489-
end=Position(line=end_line - 1, character=get_length_of_end_line),
484+
target_range=types.Range(
485+
start=types.Position(line=start_line - 1, character=0),
486+
end=types.Position(line=end_line - 1, character=get_length_of_end_line),
490487
),
491488
markdown_description=docstring,
492489
)
493490
except Exception:
494491
return None
495492

496493

497-
def get_built_in_macro_reference(macro_name: str, macro_range: Range) -> t.Optional[Reference]:
494+
def get_built_in_macro_reference(
495+
macro_name: str, macro_range: types.Range
496+
) -> t.Optional[Reference]:
498497
"""
499498
Get a reference to a built-in macro by its name.
500499
@@ -517,9 +516,9 @@ def get_built_in_macro_reference(macro_name: str, macro_range: Range) -> t.Optio
517516
return LSPMacroReference(
518517
path=Path(filename),
519518
range=macro_range,
520-
target_range=Range(
521-
start=Position(line=line_number - 1, character=0),
522-
end=Position(line=end_line_number - 1, character=0),
519+
target_range=types.Range(
520+
start=types.Position(line=line_number - 1, character=0),
521+
end=types.Position(line=end_line_number - 1, character=0),
523522
),
524523
markdown_description=func.__doc__ if func.__doc__ else None,
525524
)
@@ -811,8 +810,8 @@ def _get_column_table_range(column: exp.Column, read_file: t.List[str]) -> Range
811810
end_range = TokenPositionDetails.from_meta(table_parts[-1].meta).to_range(read_file)
812811

813812
return Range(
814-
start=to_lsp_position(start_range.start),
815-
end=to_lsp_position(end_range.end),
813+
start=start_range.start,
814+
end=end_range.end,
816815
)
817816

818817

@@ -822,7 +821,7 @@ def _process_column_references(
822821
read_file: t.List[str],
823822
referenced_model_path: Path,
824823
description: t.Optional[str] = None,
825-
yaml_target_range: t.Optional[Range] = None,
824+
yaml_target_range: t.Optional[types.Range] = None,
826825
reference_type: t.Literal["model", "external_model", "cte"] = "model",
827826
default_catalog: t.Optional[str] = None,
828827
dialect: t.Optional[str] = None,
@@ -914,6 +913,8 @@ def _get_yaml_model_range(path: Path, model_name: str) -> t.Optional[Range]:
914913
# Get size of block by taking the earliest line/col in the items block and the last line/col of the block
915914
position_data = item.lc.data["name"] # type: ignore
916915
start = Position(line=position_data[2], character=position_data[3])
917-
end = Position(line=position_data[2], character=position_data[3] + len(item["name"]))
916+
end = Position(
917+
line=position_data[2], character=position_data[3] + len(item["name"])
918+
)
918919
return Range(start=start, end=end)
919920
return None

sqlmesh/lsp/rename.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
)
1010

1111
from sqlmesh.lsp.context import LSPContext
12+
from sqlmesh.lsp.helpers import to_sqlmesh_position, to_lsp_range
1213
from sqlmesh.lsp.reference import (
1314
_position_within_range,
1415
get_cte_references,
@@ -18,7 +19,7 @@
1819

1920

2021
def prepare_rename(
21-
lsp_context: LSPContext, document_uri: URI, position: Position
22+
lsp_context: LSPContext, document_uri: URI, lsp_position: Position
2223
) -> t.Optional[PrepareRenameResult_Type1]:
2324
"""
2425
Prepare for rename operation by checking if the symbol at the position can be renamed.
@@ -32,6 +33,7 @@ def prepare_rename(
3233
PrepareRenameResult if the symbol can be renamed, None otherwise
3334
"""
3435
# Check if there's a CTE at this position
36+
position = to_sqlmesh_position(lsp_position)
3537
cte_references = get_cte_references(lsp_context, document_uri, position)
3638
if cte_references:
3739
# Find the target CTE definition to get its range
@@ -46,14 +48,16 @@ def prepare_rename(
4648
target_range = ref.target_range
4749
break
4850
if target_range:
49-
return PrepareRenameResult_Type1(range=target_range, placeholder="cte_name")
51+
return PrepareRenameResult_Type1(
52+
range=to_lsp_range(target_range), placeholder="cte_name"
53+
)
5054

5155
# For now, only CTEs are supported
5256
return None
5357

5458

5559
def rename_symbol(
56-
lsp_context: LSPContext, document_uri: URI, position: Position, new_name: str
60+
lsp_context: LSPContext, document_uri: URI, lsp_position: Position, new_name: str
5761
) -> t.Optional[WorkspaceEdit]:
5862
"""
5963
Perform rename operation on the symbol at the given position.
@@ -68,7 +72,9 @@ def rename_symbol(
6872
WorkspaceEdit with the changes, or None if no symbol to rename
6973
"""
7074
# Check if there's a CTE at this position
71-
cte_references = get_cte_references(lsp_context, document_uri, position)
75+
cte_references = get_cte_references(
76+
lsp_context, document_uri, to_sqlmesh_position(lsp_position)
77+
)
7278
if cte_references:
7379
return _rename_cte(cte_references, new_name)
7480

@@ -95,7 +101,7 @@ def _rename_cte(cte_references: t.List[LSPCteReference], new_name: str) -> Works
95101
changes[uri] = []
96102

97103
# Create a text edit for this reference
98-
text_edit = TextEdit(range=ref.range, new_text=new_name)
104+
text_edit = TextEdit(range=to_lsp_range(ref.range), new_text=new_name)
99105
changes[uri].append(text_edit)
100106

101107
return WorkspaceEdit(changes=changes)
@@ -119,7 +125,7 @@ def get_document_highlights(
119125
List of DocumentHighlight objects or None if no symbol found
120126
"""
121127
# Check if there's a CTE at this position
122-
cte_references = get_cte_references(lsp_context, document_uri, position)
128+
cte_references = get_cte_references(lsp_context, document_uri, to_sqlmesh_position(position))
123129
if cte_references:
124130
highlights = []
125131
for ref in cte_references:
@@ -130,7 +136,7 @@ def get_document_highlights(
130136
else DocumentHighlightKind.Read
131137
)
132138

133-
highlights.append(DocumentHighlight(range=ref.range, kind=kind))
139+
highlights.append(DocumentHighlight(range=to_lsp_range(ref.range), kind=kind))
134140
return highlights
135141

136142
# For now, only CTEs are supported

tests/lsp/test_reference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from lsprotocol.types import Position
21
from sqlmesh.core.context import Context
2+
from sqlmesh.core.linter.rule import Position
33
from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget
44
from sqlmesh.lsp.reference import LSPModelReference, get_model_definitions_for_a_path, by_position
55
from sqlmesh.lsp.uri import URI

tests/lsp/test_reference_external_model.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
from pathlib import Path
22

3-
from lsprotocol.types import Position
4-
53
from sqlmesh import Config
64
from sqlmesh.core.context import Context
75
from sqlmesh.core.linter.helpers import read_range_from_file
6+
from sqlmesh.core.linter.rule import Position
87
from sqlmesh.lsp.context import LSPContext, ModelTarget
9-
from sqlmesh.lsp.helpers import to_sqlmesh_range
108
from sqlmesh.lsp.reference import get_references, LSPExternalModelReference
119
from sqlmesh.lsp.uri import URI
1210
from tests.utils.test_filesystem import create_temp_file
@@ -34,14 +32,14 @@ def test_reference() -> None:
3432
assert path is not None
3533
assert str(path).endswith("external_models.yaml")
3634

37-
source_range = read_range_from_file(customers, to_sqlmesh_range(reference.range))
35+
source_range = read_range_from_file(customers, reference.range)
3836
assert source_range == "raw.demographics"
3937

4038
if reference.target_range is None:
4139
raise AssertionError("Reference target range should not be None")
4240
path = reference.path
4341
assert path is not None
44-
target_range = read_range_from_file(path, to_sqlmesh_range(reference.target_range))
42+
target_range = read_range_from_file(path, reference.target_range)
4543
assert target_range == "raw.demographics"
4644

4745

@@ -61,4 +59,4 @@ def test_unregistered_external_model(tmp_path: Path):
6159
assert reference.path is None
6260
assert reference.target_range is None
6361
assert reference.markdown_description == "Unregistered external model"
64-
assert read_range_from_file(model_path, to_sqlmesh_range(reference.range)) == "external_model"
62+
assert read_range_from_file(model_path, reference.range) == "external_model"

tests/lsp/test_reference_model_column_prefix.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from pathlib import Path
22

3-
from lsprotocol.types import Position
43
from sqlmesh.cli.project_init import init_example_project
54
from sqlmesh.core.context import Context
5+
from sqlmesh.core.linter.rule import Position
66
from sqlmesh.lsp.context import LSPContext, ModelTarget
77
from sqlmesh.lsp.reference import get_all_references
88
from sqlmesh.lsp.uri import URI

0 commit comments

Comments
 (0)