Skip to content

Commit 217e04c

Browse files
committed
feat(vscode): add test recognition to vscode
1 parent 57d2e9f commit 217e04c

File tree

9 files changed

+451
-3
lines changed

9 files changed

+451
-3
lines changed

sqlmesh/lsp/context.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from sqlmesh.core.model.definition import SqlModel
77
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
8-
from sqlmesh.lsp.custom import ModelForRendering
8+
from sqlmesh.lsp.custom import ModelForRendering, TestEntry, RunTestResponse
99
from sqlmesh.lsp.custom import AllModelsResponse, RenderModelEntry
1010
from sqlmesh.lsp.uri import URI
1111
from lsprotocol import types
@@ -63,6 +63,56 @@ def __init__(self, context: Context) -> None:
6363
**audit_map,
6464
}
6565

66+
def list_workspace_tests(self) -> t.List[TestEntry]:
67+
"""List all tests in the workspace."""
68+
tests = self.context.load_model_tests()
69+
# TODO Probably want to get all the positions for the tests
70+
return [
71+
TestEntry(
72+
name=test.test_name,
73+
uri=URI.from_path(test.path).value,
74+
)
75+
for test in tests
76+
]
77+
78+
def get_document_tests(self, uri: URI) -> t.List[TestEntry]:
79+
"""Get tests for a specific document.
80+
81+
Args:
82+
uri: The URI of the file to get tests for.
83+
84+
Returns:
85+
List of TestEntry objects for the specified document.
86+
"""
87+
tests = self.context.load_model_tests(tests=[str(uri.to_path())])
88+
return [
89+
TestEntry(
90+
name=test.test_name,
91+
# TODO NEED TO ADD RANGE
92+
uri=URI.from_path(test.path).value,
93+
)
94+
for test in tests
95+
]
96+
97+
def run_test(self, uri: URI, test_name: str) -> RunTestResponse:
98+
"""Run a specific test for a model.
99+
100+
Args:
101+
uri: The URI of the file containing the test.
102+
test_name: The name of the test to run.
103+
104+
Returns:
105+
List of annotated rule violations from the test run.
106+
"""
107+
path = uri.to_path()
108+
results = self.context.test(
109+
tests=[str(path)],
110+
match_patterns=[test_name],
111+
)
112+
if results.testsRun != 1:
113+
raise ValueError(f"Expected to run 1 test, but ran {results.testsRun} tests.")
114+
return RunTestResponse(success=len(results.successes) == 1)
115+
66116
def render_model(self, uri: URI) -> t.List[RenderModelEntry]:
67117
"""Get rendered models for a file, using cache when available.
68118

sqlmesh/lsp/custom.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,50 @@ class FormatProjectResponse(CustomMethodResponseBaseClass):
143143
"""
144144

145145
pass
146+
147+
148+
LIST_WORKSPACE_TESTS_FEATURE = "sqlmesh/list_workspace_tests"
149+
150+
151+
class ListWorkspaceTestsRequest(CustomMethodRequestBaseClass):
152+
"""
153+
Request to list all tests in the current project.
154+
"""
155+
156+
pass
157+
158+
159+
class TestEntry(PydanticModel):
160+
"""
161+
An entry representing a test in the workspace.
162+
"""
163+
164+
name: str
165+
uri: str
166+
167+
168+
class ListWorkspaceTestsResponse(CustomMethodResponseBaseClass):
169+
tests: t.List[TestEntry]
170+
171+
172+
LIST_DOCUMENT_TESTS_FEATURE = "sqlmesh/list_document_tests"
173+
174+
175+
class ListDocumentTestsRequest(CustomMethodRequestBaseClass):
176+
textDocument: types.TextDocumentIdentifier
177+
178+
179+
class ListDocumentTestsResponse(CustomMethodResponseBaseClass):
180+
tests: t.List[TestEntry]
181+
182+
183+
RUN_TEST_FEATURE = "sqlmesh/run_test"
184+
185+
186+
class RunTestRequest(CustomMethodRequestBaseClass):
187+
textDocument: types.TextDocumentIdentifier
188+
testName: str
189+
190+
191+
class RunTestResponse(CustomMethodResponseBaseClass):
192+
success: bool

sqlmesh/lsp/main.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,15 @@
4646
FormatProjectRequest,
4747
FormatProjectResponse,
4848
CustomMethod,
49+
LIST_WORKSPACE_TESTS_FEATURE,
50+
ListWorkspaceTestsRequest,
51+
ListWorkspaceTestsResponse,
52+
LIST_DOCUMENT_TESTS_FEATURE,
53+
ListDocumentTestsRequest,
54+
ListDocumentTestsResponse,
55+
RUN_TEST_FEATURE,
56+
RunTestRequest,
57+
RunTestResponse,
4958
)
5059
from sqlmesh.lsp.errors import ContextFailedError, context_error_to_diagnostic
5160
from sqlmesh.lsp.helpers import to_lsp_range, to_sqlmesh_position
@@ -127,11 +136,58 @@ def __init__(
127136
API_FEATURE: self._custom_api,
128137
SUPPORTED_METHODS_FEATURE: self._custom_supported_methods,
129138
FORMAT_PROJECT_FEATURE: self._custom_format_project,
139+
LIST_WORKSPACE_TESTS_FEATURE: self._list_workspace_tests,
140+
LIST_DOCUMENT_TESTS_FEATURE: self._list_document_tests,
141+
RUN_TEST_FEATURE: self._run_test,
130142
}
131143

132144
# Register LSP features (e.g., formatting, hover, etc.)
133145
self._register_features()
134146

147+
def _list_workspace_tests(
148+
self,
149+
ls: LanguageServer,
150+
params: ListWorkspaceTestsRequest,
151+
) -> ListWorkspaceTestsResponse:
152+
"""List all tests in the current workspace."""
153+
try:
154+
context = self._context_get_or_load()
155+
tests = context.list_workspace_tests()
156+
return ListWorkspaceTestsResponse(tests=tests)
157+
except Exception as e:
158+
ls.log_trace(f"Error listing workspace tests: {e}")
159+
return ListWorkspaceTestsResponse(tests=[])
160+
161+
def _list_document_tests(
162+
self,
163+
ls: LanguageServer,
164+
params: ListDocumentTestsRequest,
165+
) -> ListDocumentTestsResponse:
166+
"""List tests for a specific document."""
167+
try:
168+
uri = URI(params.textDocument.uri)
169+
context = self._context_get_or_load(uri)
170+
tests = context.get_document_tests(uri)
171+
return ListDocumentTestsResponse(tests=tests)
172+
except Exception as e:
173+
ls.log_trace(f"Error listing document tests: {e}")
174+
return ListDocumentTestsResponse(tests=[])
175+
176+
def _run_test(
177+
self,
178+
ls: LanguageServer,
179+
params: RunTestRequest,
180+
) -> RunTestResponse:
181+
"""Run a specific test."""
182+
try:
183+
uri = URI(params.textDocument.uri)
184+
context = self._context_get_or_load(uri)
185+
result = context.run_test(uri, params.testName)
186+
return result
187+
except Exception as e:
188+
ls.log_trace(f"Error running test: {e}")
189+
return RunTestResponse(success=False, response_error=str(e))
190+
135191
# All the custom LSP methods are registered here and prefixed with _custom
136192
def _custom_all_models(self, ls: LanguageServer, params: AllModelsRequest) -> AllModelsResponse:
137193
uri = URI(params.textDocument.uri)

tests/lsp/test_context.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
from pathlib import Path
2+
13
from sqlmesh.core.context import Context
24
from sqlmesh.lsp.context import LSPContext, ModelTarget
5+
from sqlmesh.lsp.uri import URI
36

47

58
def test_lsp_context():
@@ -18,3 +21,43 @@ def test_lsp_context():
1821
# Check that the value is a ModelInfo with the expected model name
1922
assert isinstance(lsp_context.map[active_customers_key], ModelTarget)
2023
assert "sushi.active_customers" in lsp_context.map[active_customers_key].names
24+
25+
26+
def test_lsp_context_list_workspace_tests():
27+
context = Context(paths=["examples/sushi"])
28+
lsp_context = LSPContext(context)
29+
30+
# List workspace tests
31+
tests = lsp_context.list_workspace_tests()
32+
33+
# Check that the tests are returned correctly
34+
assert len(tests) == 0
35+
assert any(test.name == "test_order_items" for test in tests)
36+
37+
38+
def test_lsp_context_get_deocument_tests():
39+
test_path = Path.cwd() / "examples/sushi/tests/test_order_items.yaml"
40+
uri = URI.from_path(test_path)
41+
42+
context = Context(paths=["examples/sushi"])
43+
lsp_context = LSPContext(context)
44+
tests = lsp_context.get_document_tests(uri)
45+
46+
assert len(tests) == 1
47+
assert tests[0].uri == uri.value
48+
assert tests[0].name == "test_order_items"
49+
50+
51+
def test_lsp_context_run_test():
52+
test_path = Path.cwd() / "examples/sushi/tests/test_order_items.yaml"
53+
uri = URI.from_path(test_path)
54+
55+
context = Context(paths=["examples/sushi"])
56+
lsp_context = LSPContext(context)
57+
58+
# Run the test
59+
result = lsp_context.run_test(uri, "test_order_items")
60+
61+
# Check that the result is not None and has the expected properties
62+
assert result is not None
63+
assert result.success is True

vscode/extension/src/extension.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ import { selector, completionProvider } from './completion/completion'
2121
import { LineagePanel } from './webviews/lineagePanel'
2222
import { RenderedModelProvider } from './providers/renderedModelProvider'
2323
import { sleep } from './utilities/sleep'
24+
import {
25+
controller as testController,
26+
setupTestController,
27+
} from './tests/tests'
2428

2529
let lspClient: LSPClient | undefined
2630

@@ -128,6 +132,7 @@ export async function activate(context: vscode.ExtensionContext) {
128132
)
129133
}
130134
context.subscriptions.push(lspClient)
135+
context.subscriptions.push(setupTestController(lspClient))
131136
} else {
132137
lspClient = new LSPClient()
133138
const result = await lspClient.start()
@@ -140,6 +145,7 @@ export async function activate(context: vscode.ExtensionContext) {
140145
)
141146
} else {
142147
context.subscriptions.push(lspClient)
148+
context.subscriptions.push(setupTestController(lspClient))
143149
}
144150
}
145151
}
@@ -175,6 +181,8 @@ export async function activate(context: vscode.ExtensionContext) {
175181
)
176182
} else {
177183
context.subscriptions.push(lspClient)
184+
setupTestController(lspClient)
185+
context.subscriptions.push(testController)
178186
}
179187

180188
if (lspClient && !lspClient.hasCompletionCapability()) {

vscode/extension/src/lsp/custom.ts

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ export type CustomLSPMethods =
3232
| AllModelsForRenderMethod
3333
| SupportedMethodsMethod
3434
| FormatProjectMethod
35+
| ListWorkspaceTests
36+
| ListDocumentTests
37+
| RunTest
3538

3639
interface AllModelsRequest {
3740
textDocument: {
@@ -111,3 +114,55 @@ interface FormatProjectResponse extends BaseResponse {}
111114
interface BaseResponse {
112115
response_error?: string
113116
}
117+
118+
export interface ListWorkspaceTests {
119+
method: 'sqlmesh/list_workspace_tests'
120+
request: ListWorkspaceTestsRequest
121+
response: ListWorkspaceTestsResponse
122+
}
123+
124+
type ListWorkspaceTestsRequest = object
125+
126+
interface TestEntry {
127+
name: string
128+
uri: string
129+
// TODO Probably want to add position at some point
130+
}
131+
132+
interface ListWorkspaceTestsResponse extends BaseResponse {
133+
tests: TestEntry[]
134+
}
135+
136+
export interface ListDocumentTests {
137+
method: 'sqlmesh/list_document_tests'
138+
request: ListDocumentTestsRequest
139+
response: ListDocumentTestsResponse
140+
}
141+
142+
export interface DocumentIdentifier {
143+
uri: string
144+
}
145+
146+
export interface ListDocumentTestsRequest {
147+
textDocument: DocumentIdentifier
148+
}
149+
150+
export interface ListDocumentTestsResponse extends BaseResponse {
151+
tests: TestEntry[]
152+
}
153+
154+
export interface RunTest {
155+
method: 'sqlmesh/run_test'
156+
request: RunTestRequest
157+
response: RunTestResponse
158+
}
159+
160+
export interface RunTestRequest {
161+
textDocument: DocumentIdentifier
162+
testName: string
163+
}
164+
165+
export interface RunTestResponse extends BaseResponse {
166+
success: boolean
167+
// TODO Add message of why not passed
168+
}

vscode/extension/src/lsp/lsp.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,8 @@ export class LSPClient implements Disposable {
252252

253253
try {
254254
const result = await this.client.sendRequest<Response>(method, request)
255-
if (result.response_error) {
256-
return err(result.response_error)
255+
if ((result as any).response_error) {
256+
return err((result as any).response_error)
257257
}
258258
return ok(result)
259259
} catch (error) {

0 commit comments

Comments
 (0)