From d1dce743f25f59c81ba94f65a98388309e0d1244 Mon Sep 17 00:00:00 2001 From: Ben King <9087625+benfdking@users.noreply.github.com> Date: Fri, 25 Jul 2025 10:45:57 +0100 Subject: [PATCH] feat(vscode): add test recognition to vscode --- sqlmesh/lsp/context.py | 69 +++++++++++- sqlmesh/lsp/custom.py | 51 +++++++++ sqlmesh/lsp/main.py | 56 ++++++++++ sqlmesh/lsp/tests_ranges.py | 65 +++++++++++ tests/lsp/test_context.py | 43 ++++++++ vscode/extension/src/extension.ts | 8 ++ vscode/extension/src/lsp/custom.ts | 65 +++++++++++ vscode/extension/src/tests/tests.ts | 155 +++++++++++++++++++++++++++ vscode/extension/tests/tests.spec.ts | 42 ++++++++ 9 files changed, 553 insertions(+), 1 deletion(-) create mode 100644 sqlmesh/lsp/tests_ranges.py create mode 100644 vscode/extension/src/tests/tests.ts create mode 100644 vscode/extension/tests/tests.spec.ts diff --git a/sqlmesh/lsp/context.py b/sqlmesh/lsp/context.py index 30adfce5a2..43eb9c8f16 100644 --- a/sqlmesh/lsp/context.py +++ b/sqlmesh/lsp/context.py @@ -3,10 +3,12 @@ from sqlmesh.core.context import Context import typing as t +from sqlmesh.core.linter.rule import Range from sqlmesh.core.model.definition import SqlModel from sqlmesh.core.linter.definition import AnnotatedRuleViolation -from sqlmesh.lsp.custom import ModelForRendering +from sqlmesh.lsp.custom import ModelForRendering, TestEntry, RunTestResponse from sqlmesh.lsp.custom import AllModelsResponse, RenderModelEntry +from sqlmesh.lsp.tests_ranges import get_test_ranges from sqlmesh.lsp.uri import URI from lsprotocol import types @@ -63,6 +65,71 @@ def __init__(self, context: Context) -> None: **audit_map, } + def list_workspace_tests(self) -> t.List[TestEntry]: + """List all tests in the workspace.""" + tests = self.context.load_model_tests() + + # Use a set to ensure unique URIs + unique_test_uris = {URI.from_path(test.path).value for test in tests} + test_uris: t.Dict[str, t.Dict[str, Range]] = {} + for uri in unique_test_uris: + test_ranges = get_test_ranges(URI(uri).to_path()) + if uri not in test_uris: + test_uris[uri] = {} + test_uris[uri].update(test_ranges) + return [ + TestEntry( + name=test.test_name, + uri=URI.from_path(test.path).value, + range=test_uris.get(URI.from_path(test.path).value, {}).get(test.test_name), + ) + for test in tests + ] + + def get_document_tests(self, uri: URI) -> t.List[TestEntry]: + """Get tests for a specific document. + + Args: + uri: The URI of the file to get tests for. + + Returns: + List of TestEntry objects for the specified document. + """ + tests = self.context.load_model_tests(tests=[str(uri.to_path())]) + test_ranges = get_test_ranges(uri.to_path()) + return [ + TestEntry( + name=test.test_name, + uri=URI.from_path(test.path).value, + range=test_ranges.get(test.test_name), + ) + for test in tests + ] + + def run_test(self, uri: URI, test_name: str) -> RunTestResponse: + """Run a specific test for a model. + + Args: + uri: The URI of the file containing the test. + test_name: The name of the test to run. + + Returns: + List of annotated rule violations from the test run. + """ + path = uri.to_path() + results = self.context.test( + tests=[str(path)], + match_patterns=[test_name], + ) + if results.testsRun != 1: + raise ValueError(f"Expected to run 1 test, but ran {results.testsRun} tests.") + if len(results.successes) == 1: + return RunTestResponse(success=True) + return RunTestResponse( + success=False, + error_message=str(results.failures[0][1]), + ) + def render_model(self, uri: URI) -> t.List[RenderModelEntry]: """Get rendered models for a file, using cache when available. diff --git a/sqlmesh/lsp/custom.py b/sqlmesh/lsp/custom.py index 618b4a44bc..8ad6418401 100644 --- a/sqlmesh/lsp/custom.py +++ b/sqlmesh/lsp/custom.py @@ -1,5 +1,7 @@ from lsprotocol import types import typing as t + +from sqlmesh.core.linter.rule import Range from sqlmesh.utils.pydantic import PydanticModel @@ -143,3 +145,52 @@ class FormatProjectResponse(CustomMethodResponseBaseClass): """ pass + + +LIST_WORKSPACE_TESTS_FEATURE = "sqlmesh/list_workspace_tests" + + +class ListWorkspaceTestsRequest(CustomMethodRequestBaseClass): + """ + Request to list all tests in the current project. + """ + + pass + + +class TestEntry(PydanticModel): + """ + An entry representing a test in the workspace. + """ + + name: str + uri: str + range: Range + + +class ListWorkspaceTestsResponse(CustomMethodResponseBaseClass): + tests: t.List[TestEntry] + + +LIST_DOCUMENT_TESTS_FEATURE = "sqlmesh/list_document_tests" + + +class ListDocumentTestsRequest(CustomMethodRequestBaseClass): + textDocument: types.TextDocumentIdentifier + + +class ListDocumentTestsResponse(CustomMethodResponseBaseClass): + tests: t.List[TestEntry] + + +RUN_TEST_FEATURE = "sqlmesh/run_test" + + +class RunTestRequest(CustomMethodRequestBaseClass): + textDocument: types.TextDocumentIdentifier + testName: str + + +class RunTestResponse(CustomMethodResponseBaseClass): + success: bool + error_message: t.Optional[str] = None diff --git a/sqlmesh/lsp/main.py b/sqlmesh/lsp/main.py index 26100c1092..3839245a08 100755 --- a/sqlmesh/lsp/main.py +++ b/sqlmesh/lsp/main.py @@ -46,6 +46,15 @@ FormatProjectRequest, FormatProjectResponse, CustomMethod, + LIST_WORKSPACE_TESTS_FEATURE, + ListWorkspaceTestsRequest, + ListWorkspaceTestsResponse, + LIST_DOCUMENT_TESTS_FEATURE, + ListDocumentTestsRequest, + ListDocumentTestsResponse, + RUN_TEST_FEATURE, + RunTestRequest, + RunTestResponse, ) from sqlmesh.lsp.errors import ContextFailedError, context_error_to_diagnostic from sqlmesh.lsp.helpers import to_lsp_range, to_sqlmesh_position @@ -127,11 +136,58 @@ def __init__( API_FEATURE: self._custom_api, SUPPORTED_METHODS_FEATURE: self._custom_supported_methods, FORMAT_PROJECT_FEATURE: self._custom_format_project, + LIST_WORKSPACE_TESTS_FEATURE: self._list_workspace_tests, + LIST_DOCUMENT_TESTS_FEATURE: self._list_document_tests, + RUN_TEST_FEATURE: self._run_test, } # Register LSP features (e.g., formatting, hover, etc.) self._register_features() + def _list_workspace_tests( + self, + ls: LanguageServer, + params: ListWorkspaceTestsRequest, + ) -> ListWorkspaceTestsResponse: + """List all tests in the current workspace.""" + try: + context = self._context_get_or_load() + tests = context.list_workspace_tests() + return ListWorkspaceTestsResponse(tests=tests) + except Exception as e: + ls.log_trace(f"Error listing workspace tests: {e}") + return ListWorkspaceTestsResponse(tests=[]) + + def _list_document_tests( + self, + ls: LanguageServer, + params: ListDocumentTestsRequest, + ) -> ListDocumentTestsResponse: + """List tests for a specific document.""" + try: + uri = URI(params.textDocument.uri) + context = self._context_get_or_load(uri) + tests = context.get_document_tests(uri) + return ListDocumentTestsResponse(tests=tests) + except Exception as e: + ls.log_trace(f"Error listing document tests: {e}") + return ListDocumentTestsResponse(tests=[]) + + def _run_test( + self, + ls: LanguageServer, + params: RunTestRequest, + ) -> RunTestResponse: + """Run a specific test.""" + try: + uri = URI(params.textDocument.uri) + context = self._context_get_or_load(uri) + result = context.run_test(uri, params.testName) + return result + except Exception as e: + ls.log_trace(f"Error running test: {e}") + return RunTestResponse(success=False, response_error=str(e)) + # All the custom LSP methods are registered here and prefixed with _custom def _custom_all_models(self, ls: LanguageServer, params: AllModelsRequest) -> AllModelsResponse: uri = URI(params.textDocument.uri) diff --git a/sqlmesh/lsp/tests_ranges.py b/sqlmesh/lsp/tests_ranges.py new file mode 100644 index 0000000000..cbcb33d8b6 --- /dev/null +++ b/sqlmesh/lsp/tests_ranges.py @@ -0,0 +1,65 @@ +""" +Provides helper functions to get ranges of tests in SQLMesh LSP. +""" + +from pathlib import Path + +from sqlmesh.core.linter.rule import Range, Position +from ruamel import yaml +from ruamel.yaml.comments import CommentedMap +import typing as t + + +def get_test_ranges( + path: Path, +) -> t.Dict[str, Range]: + """ + Test files are yaml files with a stucture of dict to test information. This returns a dictionary + with the test name as the key and the range of the test in the file as the value. + """ + test_ranges: t.Dict[str, Range] = {} + + with open(path, "r", encoding="utf-8") as file: + content = file.read() + + # Parse YAML to get line numbers + yaml_obj = yaml.YAML() + yaml_obj.preserve_quotes = True + data = yaml_obj.load(content) + + if not isinstance(data, dict): + raise ValueError("Invalid test file format: expected a dictionary at the top level.") + + # For each top-level key (test name), find its range + for test_name in data: + if isinstance(data, CommentedMap) and test_name in data.lc.data: + # Get line and column info from ruamel yaml + line_info = data.lc.data[test_name] + start_line = line_info[0] # 0-based line number + start_col = line_info[1] # 0-based column number + + # Find the end of this test by looking for the next test or end of file + lines = content.splitlines() + end_line = start_line + + # Find where this test ends by looking for the next top-level key + # or the end of the file + for i in range(start_line + 1, len(lines)): + line = lines[i] + # Check if this line starts a new top-level key (no leading spaces) + if line and not line[0].isspace() and ":" in line: + end_line = i - 1 + break + else: + # This test goes to the end of the file + end_line = len(lines) - 1 + + # Create the range + test_ranges[test_name] = Range( + start=Position(line=start_line, character=start_col), + end=Position( + line=end_line, character=len(lines[end_line]) if end_line < len(lines) else 0 + ), + ) + + return test_ranges diff --git a/tests/lsp/test_context.py b/tests/lsp/test_context.py index c26e8f35d5..b463a17139 100644 --- a/tests/lsp/test_context.py +++ b/tests/lsp/test_context.py @@ -1,5 +1,8 @@ +from pathlib import Path + from sqlmesh.core.context import Context from sqlmesh.lsp.context import LSPContext, ModelTarget +from sqlmesh.lsp.uri import URI def test_lsp_context(): @@ -18,3 +21,43 @@ def test_lsp_context(): # Check that the value is a ModelInfo with the expected model name assert isinstance(lsp_context.map[active_customers_key], ModelTarget) assert "sushi.active_customers" in lsp_context.map[active_customers_key].names + + +def test_lsp_context_list_workspace_tests(): + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + # List workspace tests + tests = lsp_context.list_workspace_tests() + + # Check that the tests are returned correctly + assert len(tests) == 3 + assert any(test.name == "test_order_items" for test in tests) + + +def test_lsp_context_get_document_tests(): + test_path = Path.cwd() / "examples/sushi/tests/test_order_items.yaml" + uri = URI.from_path(test_path) + + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + tests = lsp_context.get_document_tests(uri) + + assert len(tests) == 1 + assert tests[0].uri == uri.value + assert tests[0].name == "test_order_items" + + +def test_lsp_context_run_test(): + test_path = Path.cwd() / "examples/sushi/tests/test_order_items.yaml" + uri = URI.from_path(test_path) + + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + # Run the test + result = lsp_context.run_test(uri, "test_order_items") + + # Check that the result is not None and has the expected properties + assert result is not None + assert result.success is True diff --git a/vscode/extension/src/extension.ts b/vscode/extension/src/extension.ts index de5d35d706..74454f8fdb 100644 --- a/vscode/extension/src/extension.ts +++ b/vscode/extension/src/extension.ts @@ -21,6 +21,10 @@ import { selector, completionProvider } from './completion/completion' import { LineagePanel } from './webviews/lineagePanel' import { RenderedModelProvider } from './providers/renderedModelProvider' import { sleep } from './utilities/sleep' +import { + controller as testController, + setupTestController, +} from './tests/tests' let lspClient: LSPClient | undefined @@ -128,6 +132,7 @@ export async function activate(context: vscode.ExtensionContext) { ) } context.subscriptions.push(lspClient) + context.subscriptions.push(setupTestController(lspClient)) } else { lspClient = new LSPClient() const result = await lspClient.start(invokedByUser) @@ -140,6 +145,7 @@ export async function activate(context: vscode.ExtensionContext) { ) } else { context.subscriptions.push(lspClient) + context.subscriptions.push(setupTestController(lspClient)) } } } @@ -175,6 +181,8 @@ export async function activate(context: vscode.ExtensionContext) { ) } else { context.subscriptions.push(lspClient) + context.subscriptions.push(setupTestController(lspClient)) + context.subscriptions.push(testController) } if (lspClient && !lspClient.hasCompletionCapability()) { diff --git a/vscode/extension/src/lsp/custom.ts b/vscode/extension/src/lsp/custom.ts index 7a9de4ca6f..8113cd86ae 100644 --- a/vscode/extension/src/lsp/custom.ts +++ b/vscode/extension/src/lsp/custom.ts @@ -32,6 +32,9 @@ export type CustomLSPMethods = | AllModelsForRenderMethod | SupportedMethodsMethod | FormatProjectMethod + | ListWorkspaceTests + | ListDocumentTests + | RunTest interface AllModelsRequest { textDocument: { @@ -111,3 +114,65 @@ interface FormatProjectResponse extends BaseResponse {} interface BaseResponse { response_error?: string } + +export interface ListWorkspaceTests { + method: 'sqlmesh/list_workspace_tests' + request: ListWorkspaceTestsRequest + response: ListWorkspaceTestsResponse +} + +type ListWorkspaceTestsRequest = object + +interface Position { + line: number + character: number +} + +interface Range { + start: Position + end: Position +} + +interface TestEntry { + name: string + uri: string + range: Range +} + +interface ListWorkspaceTestsResponse extends BaseResponse { + tests: TestEntry[] +} + +export interface ListDocumentTests { + method: 'sqlmesh/list_document_tests' + request: ListDocumentTestsRequest + response: ListDocumentTestsResponse +} + +export interface DocumentIdentifier { + uri: string +} + +export interface ListDocumentTestsRequest { + textDocument: DocumentIdentifier +} + +export interface ListDocumentTestsResponse extends BaseResponse { + tests: TestEntry[] +} + +export interface RunTest { + method: 'sqlmesh/run_test' + request: RunTestRequest + response: RunTestResponse +} + +export interface RunTestRequest { + textDocument: DocumentIdentifier + testName: string +} + +export interface RunTestResponse extends BaseResponse { + success: boolean + error_message?: string +} diff --git a/vscode/extension/src/tests/tests.ts b/vscode/extension/src/tests/tests.ts new file mode 100644 index 0000000000..dd3503165c --- /dev/null +++ b/vscode/extension/src/tests/tests.ts @@ -0,0 +1,155 @@ +import * as vscode from 'vscode' +import path from 'path' +import { LSPClient } from '../lsp/lsp' +import { isErr } from '@bus/result' +import { Disposable } from 'vscode' + +export const controller = vscode.tests.createTestController( + 'sqlmeshTests', + 'SQLMesh Tests', +) + +export const setupTestController = (lsp: LSPClient): Disposable => { + controller.resolveHandler = async test => { + console.log('Resolving test:', test?.id) + const uri = test?.uri + if (uri) { + await discoverDocumentTests(uri.toString()) + } else { + await discoverWorkspaceTests() + } + } + + // Discover tests immediately when the controller is set up + // This is useful for the initial load of tests in the workspace + // eslint-disable-next-line @typescript-eslint/no-floating-promises + discoverWorkspaceTests() + + controller.createRunProfile( + 'Run', + vscode.TestRunProfileKind.Run, + request => runTests(request), + true, + ) + + async function discoverDocumentTests(uri: string) { + const result = await lsp.call_custom_method('sqlmesh/list_document_tests', { + textDocument: { uri }, + }) + if (isErr(result)) { + vscode.window.showErrorMessage( + `Failed to list SQLMesh tests: ${result.error.message}`, + ) + return + } + const fileItem = controller.items.get(uri) + if (!fileItem) { + vscode.window.showErrorMessage(`No test item found for document: ${uri}`) + return + } + fileItem.children.replace([]) + for (const test of result.value.tests) { + const testItem = controller.createTestItem( + test.name, + test.name, + vscode.Uri.parse(test.uri), + ) + const range = test.range + testItem.range = new vscode.Range( + new vscode.Position(range.start.line, range.start.character), + new vscode.Position(range.end.line, range.end.character), + ) + fileItem.children.add(testItem) + } + } + + async function discoverWorkspaceTests() { + const result = await lsp.call_custom_method( + 'sqlmesh/list_workspace_tests', + {}, + ) + if (isErr(result)) { + vscode.window.showErrorMessage( + `Failed to list SQLMesh tests: ${result.error.message}`, + ) + return + } + controller.items.replace([]) + const files = new Map() + for (const entry of result.value.tests) { + const uri = vscode.Uri.parse(entry.uri) + let fileItem = files.get(uri.toString()) + if (!fileItem) { + fileItem = controller.createTestItem( + uri.toString(), + path.basename(uri.fsPath), + uri, + ) + // THIS IS WHERE YOU RESOLVE THE RANGE + fileItem.canResolveChildren = true + files.set(uri.toString(), fileItem) + controller.items.add(fileItem) + } + const testId = `${uri.toString()}::${entry.name}` + const testItem = controller.createTestItem(testId, entry.name, uri) + fileItem.children.add(testItem) + } + } + + async function runTests(request: vscode.TestRunRequest) { + const run = controller.createTestRun(request) + + const tests: vscode.TestItem[] = [] + const collect = (item: vscode.TestItem) => { + if (item.children.size === 0) tests.push(item) + item.children.forEach(collect) + } + + if (request.include) request.include.forEach(collect) + else controller.items.forEach(collect) + + for (const test of tests) { + run.started(test) + const startTime = Date.now() + const uri = test.uri + if (uri === undefined) { + run.failed(test, new vscode.TestMessage('Test item has no URI')) + continue + } + const response = await lsp.call_custom_method('sqlmesh/run_test', { + textDocument: { uri: uri.toString() }, + testName: test.id, + }) + if (isErr(response)) { + run.failed(test, new vscode.TestMessage(response.error.message)) + continue + } else { + const result = response.value + const duration = Date.now() - startTime + if (result.success) { + run.passed(test, duration) + } else { + run.failed( + test, + new vscode.TestMessage(result.error_message ?? 'Test failed'), + duration, + ) + } + } + } + run.end() + } + + // onChangeFile of yaml file reload the tests + return vscode.workspace.onDidChangeTextDocument(async event => { + if (event.document.languageId === 'yaml') { + const uri = event.document.uri.toString() + const testItem = controller.items.get(uri) + if (testItem) { + await discoverDocumentTests(uri) + } else { + await discoverWorkspaceTests() + } + } + }) +} diff --git a/vscode/extension/tests/tests.spec.ts b/vscode/extension/tests/tests.spec.ts new file mode 100644 index 0000000000..415eddb543 --- /dev/null +++ b/vscode/extension/tests/tests.spec.ts @@ -0,0 +1,42 @@ +import { test } from './fixtures' +import path from 'path' +import fs from 'fs-extra' +import os from 'os' +import { + openServerPage, + runCommand, + SUSHI_SOURCE_PATH, + waitForLoadedSQLMesh, +} from './utils' +import { createPythonInterpreterSettingsSpecifier } from './utils_code_server' + +test('Format project works correctly', async ({ page, sharedCodeServer }) => { + const tempDir = await fs.mkdtemp(path.join(os.tmpdir(), 'vscode-test-sushi-')) + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + + await createPythonInterpreterSettingsSpecifier(tempDir) + await openServerPage(page, tempDir, sharedCodeServer) + + // Wait for the models folder to be visible + await page.waitForSelector('text=models') + + // Click on the models folder, excluding external_models + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + + // Open the customer_revenue_lifetime model + await page + .getByRole('treeitem', { name: 'customers.sql', exact: true }) + .locator('a') + .click() + + await page.waitForSelector('text=grain') + await waitForLoadedSQLMesh(page) + + // Format the project + await runCommand(page, 'Test: Run All Tests') + + await page.waitForSelector('text=test_order_items') +})