Skip to content

Commit 8f2947b

Browse files
authored
feat(vscode): add test recognition to vscode (#5019)
1 parent 683a373 commit 8f2947b

File tree

9 files changed

+553
-1
lines changed

9 files changed

+553
-1
lines changed

sqlmesh/lsp/context.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
from sqlmesh.core.context import Context
44
import typing as t
55

6+
from sqlmesh.core.linter.rule import Range
67
from sqlmesh.core.model.definition import SqlModel
78
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
8-
from sqlmesh.lsp.custom import ModelForRendering
9+
from sqlmesh.lsp.custom import ModelForRendering, TestEntry, RunTestResponse
910
from sqlmesh.lsp.custom import AllModelsResponse, RenderModelEntry
11+
from sqlmesh.lsp.tests_ranges import get_test_ranges
1012
from sqlmesh.lsp.uri import URI
1113
from lsprotocol import types
1214

@@ -63,6 +65,71 @@ def __init__(self, context: Context) -> None:
6365
**audit_map,
6466
}
6567

68+
def list_workspace_tests(self) -> t.List[TestEntry]:
69+
"""List all tests in the workspace."""
70+
tests = self.context.load_model_tests()
71+
72+
# Use a set to ensure unique URIs
73+
unique_test_uris = {URI.from_path(test.path).value for test in tests}
74+
test_uris: t.Dict[str, t.Dict[str, Range]] = {}
75+
for uri in unique_test_uris:
76+
test_ranges = get_test_ranges(URI(uri).to_path())
77+
if uri not in test_uris:
78+
test_uris[uri] = {}
79+
test_uris[uri].update(test_ranges)
80+
return [
81+
TestEntry(
82+
name=test.test_name,
83+
uri=URI.from_path(test.path).value,
84+
range=test_uris.get(URI.from_path(test.path).value, {}).get(test.test_name),
85+
)
86+
for test in tests
87+
]
88+
89+
def get_document_tests(self, uri: URI) -> t.List[TestEntry]:
90+
"""Get tests for a specific document.
91+
92+
Args:
93+
uri: The URI of the file to get tests for.
94+
95+
Returns:
96+
List of TestEntry objects for the specified document.
97+
"""
98+
tests = self.context.load_model_tests(tests=[str(uri.to_path())])
99+
test_ranges = get_test_ranges(uri.to_path())
100+
return [
101+
TestEntry(
102+
name=test.test_name,
103+
uri=URI.from_path(test.path).value,
104+
range=test_ranges.get(test.test_name),
105+
)
106+
for test in tests
107+
]
108+
109+
def run_test(self, uri: URI, test_name: str) -> RunTestResponse:
110+
"""Run a specific test for a model.
111+
112+
Args:
113+
uri: The URI of the file containing the test.
114+
test_name: The name of the test to run.
115+
116+
Returns:
117+
List of annotated rule violations from the test run.
118+
"""
119+
path = uri.to_path()
120+
results = self.context.test(
121+
tests=[str(path)],
122+
match_patterns=[test_name],
123+
)
124+
if results.testsRun != 1:
125+
raise ValueError(f"Expected to run 1 test, but ran {results.testsRun} tests.")
126+
if len(results.successes) == 1:
127+
return RunTestResponse(success=True)
128+
return RunTestResponse(
129+
success=False,
130+
error_message=str(results.failures[0][1]),
131+
)
132+
66133
def render_model(self, uri: URI) -> t.List[RenderModelEntry]:
67134
"""Get rendered models for a file, using cache when available.
68135

sqlmesh/lsp/custom.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from lsprotocol import types
22
import typing as t
3+
4+
from sqlmesh.core.linter.rule import Range
35
from sqlmesh.utils.pydantic import PydanticModel
46

57

@@ -143,3 +145,52 @@ class FormatProjectResponse(CustomMethodResponseBaseClass):
143145
"""
144146

145147
pass
148+
149+
150+
LIST_WORKSPACE_TESTS_FEATURE = "sqlmesh/list_workspace_tests"
151+
152+
153+
class ListWorkspaceTestsRequest(CustomMethodRequestBaseClass):
154+
"""
155+
Request to list all tests in the current project.
156+
"""
157+
158+
pass
159+
160+
161+
class TestEntry(PydanticModel):
162+
"""
163+
An entry representing a test in the workspace.
164+
"""
165+
166+
name: str
167+
uri: str
168+
range: Range
169+
170+
171+
class ListWorkspaceTestsResponse(CustomMethodResponseBaseClass):
172+
tests: t.List[TestEntry]
173+
174+
175+
LIST_DOCUMENT_TESTS_FEATURE = "sqlmesh/list_document_tests"
176+
177+
178+
class ListDocumentTestsRequest(CustomMethodRequestBaseClass):
179+
textDocument: types.TextDocumentIdentifier
180+
181+
182+
class ListDocumentTestsResponse(CustomMethodResponseBaseClass):
183+
tests: t.List[TestEntry]
184+
185+
186+
RUN_TEST_FEATURE = "sqlmesh/run_test"
187+
188+
189+
class RunTestRequest(CustomMethodRequestBaseClass):
190+
textDocument: types.TextDocumentIdentifier
191+
testName: str
192+
193+
194+
class RunTestResponse(CustomMethodResponseBaseClass):
195+
success: bool
196+
error_message: t.Optional[str] = None

sqlmesh/lsp/main.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,15 @@
4646
FormatProjectRequest,
4747
FormatProjectResponse,
4848
CustomMethod,
49+
LIST_WORKSPACE_TESTS_FEATURE,
50+
ListWorkspaceTestsRequest,
51+
ListWorkspaceTestsResponse,
52+
LIST_DOCUMENT_TESTS_FEATURE,
53+
ListDocumentTestsRequest,
54+
ListDocumentTestsResponse,
55+
RUN_TEST_FEATURE,
56+
RunTestRequest,
57+
RunTestResponse,
4958
)
5059
from sqlmesh.lsp.errors import ContextFailedError, context_error_to_diagnostic
5160
from sqlmesh.lsp.helpers import to_lsp_range, to_sqlmesh_position
@@ -127,11 +136,58 @@ def __init__(
127136
API_FEATURE: self._custom_api,
128137
SUPPORTED_METHODS_FEATURE: self._custom_supported_methods,
129138
FORMAT_PROJECT_FEATURE: self._custom_format_project,
139+
LIST_WORKSPACE_TESTS_FEATURE: self._list_workspace_tests,
140+
LIST_DOCUMENT_TESTS_FEATURE: self._list_document_tests,
141+
RUN_TEST_FEATURE: self._run_test,
130142
}
131143

132144
# Register LSP features (e.g., formatting, hover, etc.)
133145
self._register_features()
134146

147+
def _list_workspace_tests(
148+
self,
149+
ls: LanguageServer,
150+
params: ListWorkspaceTestsRequest,
151+
) -> ListWorkspaceTestsResponse:
152+
"""List all tests in the current workspace."""
153+
try:
154+
context = self._context_get_or_load()
155+
tests = context.list_workspace_tests()
156+
return ListWorkspaceTestsResponse(tests=tests)
157+
except Exception as e:
158+
ls.log_trace(f"Error listing workspace tests: {e}")
159+
return ListWorkspaceTestsResponse(tests=[])
160+
161+
def _list_document_tests(
162+
self,
163+
ls: LanguageServer,
164+
params: ListDocumentTestsRequest,
165+
) -> ListDocumentTestsResponse:
166+
"""List tests for a specific document."""
167+
try:
168+
uri = URI(params.textDocument.uri)
169+
context = self._context_get_or_load(uri)
170+
tests = context.get_document_tests(uri)
171+
return ListDocumentTestsResponse(tests=tests)
172+
except Exception as e:
173+
ls.log_trace(f"Error listing document tests: {e}")
174+
return ListDocumentTestsResponse(tests=[])
175+
176+
def _run_test(
177+
self,
178+
ls: LanguageServer,
179+
params: RunTestRequest,
180+
) -> RunTestResponse:
181+
"""Run a specific test."""
182+
try:
183+
uri = URI(params.textDocument.uri)
184+
context = self._context_get_or_load(uri)
185+
result = context.run_test(uri, params.testName)
186+
return result
187+
except Exception as e:
188+
ls.log_trace(f"Error running test: {e}")
189+
return RunTestResponse(success=False, response_error=str(e))
190+
135191
# All the custom LSP methods are registered here and prefixed with _custom
136192
def _custom_all_models(self, ls: LanguageServer, params: AllModelsRequest) -> AllModelsResponse:
137193
uri = URI(params.textDocument.uri)

sqlmesh/lsp/tests_ranges.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""
2+
Provides helper functions to get ranges of tests in SQLMesh LSP.
3+
"""
4+
5+
from pathlib import Path
6+
7+
from sqlmesh.core.linter.rule import Range, Position
8+
from ruamel import yaml
9+
from ruamel.yaml.comments import CommentedMap
10+
import typing as t
11+
12+
13+
def get_test_ranges(
14+
path: Path,
15+
) -> t.Dict[str, Range]:
16+
"""
17+
Test files are yaml files with a stucture of dict to test information. This returns a dictionary
18+
with the test name as the key and the range of the test in the file as the value.
19+
"""
20+
test_ranges: t.Dict[str, Range] = {}
21+
22+
with open(path, "r", encoding="utf-8") as file:
23+
content = file.read()
24+
25+
# Parse YAML to get line numbers
26+
yaml_obj = yaml.YAML()
27+
yaml_obj.preserve_quotes = True
28+
data = yaml_obj.load(content)
29+
30+
if not isinstance(data, dict):
31+
raise ValueError("Invalid test file format: expected a dictionary at the top level.")
32+
33+
# For each top-level key (test name), find its range
34+
for test_name in data:
35+
if isinstance(data, CommentedMap) and test_name in data.lc.data:
36+
# Get line and column info from ruamel yaml
37+
line_info = data.lc.data[test_name]
38+
start_line = line_info[0] # 0-based line number
39+
start_col = line_info[1] # 0-based column number
40+
41+
# Find the end of this test by looking for the next test or end of file
42+
lines = content.splitlines()
43+
end_line = start_line
44+
45+
# Find where this test ends by looking for the next top-level key
46+
# or the end of the file
47+
for i in range(start_line + 1, len(lines)):
48+
line = lines[i]
49+
# Check if this line starts a new top-level key (no leading spaces)
50+
if line and not line[0].isspace() and ":" in line:
51+
end_line = i - 1
52+
break
53+
else:
54+
# This test goes to the end of the file
55+
end_line = len(lines) - 1
56+
57+
# Create the range
58+
test_ranges[test_name] = Range(
59+
start=Position(line=start_line, character=start_col),
60+
end=Position(
61+
line=end_line, character=len(lines[end_line]) if end_line < len(lines) else 0
62+
),
63+
)
64+
65+
return test_ranges

tests/lsp/test_context.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
from pathlib import Path
2+
13
from sqlmesh.core.context import Context
24
from sqlmesh.lsp.context import LSPContext, ModelTarget
5+
from sqlmesh.lsp.uri import URI
36

47

58
def test_lsp_context():
@@ -18,3 +21,43 @@ def test_lsp_context():
1821
# Check that the value is a ModelInfo with the expected model name
1922
assert isinstance(lsp_context.map[active_customers_key], ModelTarget)
2023
assert "sushi.active_customers" in lsp_context.map[active_customers_key].names
24+
25+
26+
def test_lsp_context_list_workspace_tests():
27+
context = Context(paths=["examples/sushi"])
28+
lsp_context = LSPContext(context)
29+
30+
# List workspace tests
31+
tests = lsp_context.list_workspace_tests()
32+
33+
# Check that the tests are returned correctly
34+
assert len(tests) == 3
35+
assert any(test.name == "test_order_items" for test in tests)
36+
37+
38+
def test_lsp_context_get_document_tests():
39+
test_path = Path.cwd() / "examples/sushi/tests/test_order_items.yaml"
40+
uri = URI.from_path(test_path)
41+
42+
context = Context(paths=["examples/sushi"])
43+
lsp_context = LSPContext(context)
44+
tests = lsp_context.get_document_tests(uri)
45+
46+
assert len(tests) == 1
47+
assert tests[0].uri == uri.value
48+
assert tests[0].name == "test_order_items"
49+
50+
51+
def test_lsp_context_run_test():
52+
test_path = Path.cwd() / "examples/sushi/tests/test_order_items.yaml"
53+
uri = URI.from_path(test_path)
54+
55+
context = Context(paths=["examples/sushi"])
56+
lsp_context = LSPContext(context)
57+
58+
# Run the test
59+
result = lsp_context.run_test(uri, "test_order_items")
60+
61+
# Check that the result is not None and has the expected properties
62+
assert result is not None
63+
assert result.success is True

vscode/extension/src/extension.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ import { selector, completionProvider } from './completion/completion'
2121
import { LineagePanel } from './webviews/lineagePanel'
2222
import { RenderedModelProvider } from './providers/renderedModelProvider'
2323
import { sleep } from './utilities/sleep'
24+
import {
25+
controller as testController,
26+
setupTestController,
27+
} from './tests/tests'
2428

2529
let lspClient: LSPClient | undefined
2630

@@ -128,6 +132,7 @@ export async function activate(context: vscode.ExtensionContext) {
128132
)
129133
}
130134
context.subscriptions.push(lspClient)
135+
context.subscriptions.push(setupTestController(lspClient))
131136
} else {
132137
lspClient = new LSPClient()
133138
const result = await lspClient.start(invokedByUser)
@@ -140,6 +145,7 @@ export async function activate(context: vscode.ExtensionContext) {
140145
)
141146
} else {
142147
context.subscriptions.push(lspClient)
148+
context.subscriptions.push(setupTestController(lspClient))
143149
}
144150
}
145151
}
@@ -175,6 +181,8 @@ export async function activate(context: vscode.ExtensionContext) {
175181
)
176182
} else {
177183
context.subscriptions.push(lspClient)
184+
context.subscriptions.push(setupTestController(lspClient))
185+
context.subscriptions.push(testController)
178186
}
179187

180188
if (lspClient && !lspClient.hasCompletionCapability()) {

0 commit comments

Comments
 (0)