Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 68 additions & 1 deletion sqlmesh/lsp/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from sqlmesh.core.context import Context
import typing as t

from sqlmesh.core.linter.rule import Range
from sqlmesh.core.model.definition import SqlModel
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
from sqlmesh.lsp.custom import ModelForRendering
from sqlmesh.lsp.custom import ModelForRendering, TestEntry, RunTestResponse
from sqlmesh.lsp.custom import AllModelsResponse, RenderModelEntry
from sqlmesh.lsp.tests_ranges import get_test_ranges
from sqlmesh.lsp.uri import URI
from lsprotocol import types

Expand Down Expand Up @@ -63,6 +65,71 @@ def __init__(self, context: Context) -> None:
**audit_map,
}

def list_workspace_tests(self) -> t.List[TestEntry]:
"""List all tests in the workspace."""
tests = self.context.load_model_tests()

# Use a set to ensure unique URIs
unique_test_uris = {URI.from_path(test.path).value for test in tests}
test_uris: t.Dict[str, t.Dict[str, Range]] = {}
for uri in unique_test_uris:
test_ranges = get_test_ranges(URI(uri).to_path())
if uri not in test_uris:
test_uris[uri] = {}
test_uris[uri].update(test_ranges)
return [
TestEntry(
name=test.test_name,
uri=URI.from_path(test.path).value,
range=test_uris.get(URI.from_path(test.path).value, {}).get(test.test_name),
)
for test in tests
]

def get_document_tests(self, uri: URI) -> t.List[TestEntry]:
"""Get tests for a specific document.

Args:
uri: The URI of the file to get tests for.

Returns:
List of TestEntry objects for the specified document.
"""
tests = self.context.load_model_tests(tests=[str(uri.to_path())])
test_ranges = get_test_ranges(uri.to_path())
return [
TestEntry(
name=test.test_name,
uri=URI.from_path(test.path).value,
range=test_ranges.get(test.test_name),
)
for test in tests
]

def run_test(self, uri: URI, test_name: str) -> RunTestResponse:
"""Run a specific test for a model.

Args:
uri: The URI of the file containing the test.
test_name: The name of the test to run.

Returns:
List of annotated rule violations from the test run.
"""
path = uri.to_path()
results = self.context.test(
tests=[str(path)],
match_patterns=[test_name],
)
if results.testsRun != 1:
raise ValueError(f"Expected to run 1 test, but ran {results.testsRun} tests.")
if len(results.successes) == 1:
return RunTestResponse(success=True)
return RunTestResponse(
success=False,
error_message=str(results.failures[0][1]),
)

def render_model(self, uri: URI) -> t.List[RenderModelEntry]:
"""Get rendered models for a file, using cache when available.

Expand Down
51 changes: 51 additions & 0 deletions sqlmesh/lsp/custom.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from lsprotocol import types
import typing as t

from sqlmesh.core.linter.rule import Range
from sqlmesh.utils.pydantic import PydanticModel


Expand Down Expand Up @@ -143,3 +145,52 @@ class FormatProjectResponse(CustomMethodResponseBaseClass):
"""

pass


LIST_WORKSPACE_TESTS_FEATURE = "sqlmesh/list_workspace_tests"


class ListWorkspaceTestsRequest(CustomMethodRequestBaseClass):
"""
Request to list all tests in the current project.
"""

pass


class TestEntry(PydanticModel):
"""
An entry representing a test in the workspace.
"""

name: str
uri: str
range: Range


class ListWorkspaceTestsResponse(CustomMethodResponseBaseClass):
tests: t.List[TestEntry]


LIST_DOCUMENT_TESTS_FEATURE = "sqlmesh/list_document_tests"


class ListDocumentTestsRequest(CustomMethodRequestBaseClass):
textDocument: types.TextDocumentIdentifier


class ListDocumentTestsResponse(CustomMethodResponseBaseClass):
tests: t.List[TestEntry]


RUN_TEST_FEATURE = "sqlmesh/run_test"


class RunTestRequest(CustomMethodRequestBaseClass):
textDocument: types.TextDocumentIdentifier
testName: str


class RunTestResponse(CustomMethodResponseBaseClass):
success: bool
error_message: t.Optional[str] = None
56 changes: 56 additions & 0 deletions sqlmesh/lsp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@
FormatProjectRequest,
FormatProjectResponse,
CustomMethod,
LIST_WORKSPACE_TESTS_FEATURE,
ListWorkspaceTestsRequest,
ListWorkspaceTestsResponse,
LIST_DOCUMENT_TESTS_FEATURE,
ListDocumentTestsRequest,
ListDocumentTestsResponse,
RUN_TEST_FEATURE,
RunTestRequest,
RunTestResponse,
)
from sqlmesh.lsp.errors import ContextFailedError, context_error_to_diagnostic
from sqlmesh.lsp.helpers import to_lsp_range, to_sqlmesh_position
Expand Down Expand Up @@ -127,11 +136,58 @@ def __init__(
API_FEATURE: self._custom_api,
SUPPORTED_METHODS_FEATURE: self._custom_supported_methods,
FORMAT_PROJECT_FEATURE: self._custom_format_project,
LIST_WORKSPACE_TESTS_FEATURE: self._list_workspace_tests,
LIST_DOCUMENT_TESTS_FEATURE: self._list_document_tests,
RUN_TEST_FEATURE: self._run_test,
}

# Register LSP features (e.g., formatting, hover, etc.)
self._register_features()

def _list_workspace_tests(
self,
ls: LanguageServer,
params: ListWorkspaceTestsRequest,
) -> ListWorkspaceTestsResponse:
"""List all tests in the current workspace."""
try:
context = self._context_get_or_load()
tests = context.list_workspace_tests()
return ListWorkspaceTestsResponse(tests=tests)
except Exception as e:
ls.log_trace(f"Error listing workspace tests: {e}")
return ListWorkspaceTestsResponse(tests=[])

def _list_document_tests(
self,
ls: LanguageServer,
params: ListDocumentTestsRequest,
) -> ListDocumentTestsResponse:
"""List tests for a specific document."""
try:
uri = URI(params.textDocument.uri)
context = self._context_get_or_load(uri)
tests = context.get_document_tests(uri)
return ListDocumentTestsResponse(tests=tests)
except Exception as e:
ls.log_trace(f"Error listing document tests: {e}")
return ListDocumentTestsResponse(tests=[])

def _run_test(
self,
ls: LanguageServer,
params: RunTestRequest,
) -> RunTestResponse:
"""Run a specific test."""
try:
uri = URI(params.textDocument.uri)
context = self._context_get_or_load(uri)
result = context.run_test(uri, params.testName)
return result
except Exception as e:
ls.log_trace(f"Error running test: {e}")
return RunTestResponse(success=False, response_error=str(e))

# All the custom LSP methods are registered here and prefixed with _custom
def _custom_all_models(self, ls: LanguageServer, params: AllModelsRequest) -> AllModelsResponse:
uri = URI(params.textDocument.uri)
Expand Down
65 changes: 65 additions & 0 deletions sqlmesh/lsp/tests_ranges.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""
Provides helper functions to get ranges of tests in SQLMesh LSP.
"""

from pathlib import Path

from sqlmesh.core.linter.rule import Range, Position
from ruamel import yaml
from ruamel.yaml.comments import CommentedMap
import typing as t


def get_test_ranges(
path: Path,
) -> t.Dict[str, Range]:
"""
Test files are yaml files with a stucture of dict to test information. This returns a dictionary
with the test name as the key and the range of the test in the file as the value.
"""
test_ranges: t.Dict[str, Range] = {}

with open(path, "r", encoding="utf-8") as file:
content = file.read()

# Parse YAML to get line numbers
yaml_obj = yaml.YAML()
yaml_obj.preserve_quotes = True
data = yaml_obj.load(content)

if not isinstance(data, dict):
raise ValueError("Invalid test file format: expected a dictionary at the top level.")

# For each top-level key (test name), find its range
for test_name in data:
if isinstance(data, CommentedMap) and test_name in data.lc.data:
# Get line and column info from ruamel yaml
line_info = data.lc.data[test_name]
start_line = line_info[0] # 0-based line number
start_col = line_info[1] # 0-based column number

# Find the end of this test by looking for the next test or end of file
lines = content.splitlines()
end_line = start_line

# Find where this test ends by looking for the next top-level key
# or the end of the file
for i in range(start_line + 1, len(lines)):
line = lines[i]
# Check if this line starts a new top-level key (no leading spaces)
if line and not line[0].isspace() and ":" in line:
end_line = i - 1
break
else:
# This test goes to the end of the file
end_line = len(lines) - 1

# Create the range
test_ranges[test_name] = Range(
start=Position(line=start_line, character=start_col),
end=Position(
line=end_line, character=len(lines[end_line]) if end_line < len(lines) else 0
),
)

return test_ranges
43 changes: 43 additions & 0 deletions tests/lsp/test_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from pathlib import Path

from sqlmesh.core.context import Context
from sqlmesh.lsp.context import LSPContext, ModelTarget
from sqlmesh.lsp.uri import URI


def test_lsp_context():
Expand All @@ -18,3 +21,43 @@ def test_lsp_context():
# Check that the value is a ModelInfo with the expected model name
assert isinstance(lsp_context.map[active_customers_key], ModelTarget)
assert "sushi.active_customers" in lsp_context.map[active_customers_key].names


def test_lsp_context_list_workspace_tests():
context = Context(paths=["examples/sushi"])
lsp_context = LSPContext(context)

# List workspace tests
tests = lsp_context.list_workspace_tests()

# Check that the tests are returned correctly
assert len(tests) == 3
assert any(test.name == "test_order_items" for test in tests)


def test_lsp_context_get_document_tests():
test_path = Path.cwd() / "examples/sushi/tests/test_order_items.yaml"
uri = URI.from_path(test_path)

context = Context(paths=["examples/sushi"])
lsp_context = LSPContext(context)
tests = lsp_context.get_document_tests(uri)

assert len(tests) == 1
assert tests[0].uri == uri.value
assert tests[0].name == "test_order_items"


def test_lsp_context_run_test():
test_path = Path.cwd() / "examples/sushi/tests/test_order_items.yaml"
uri = URI.from_path(test_path)

context = Context(paths=["examples/sushi"])
lsp_context = LSPContext(context)

# Run the test
result = lsp_context.run_test(uri, "test_order_items")

# Check that the result is not None and has the expected properties
assert result is not None
assert result.success is True
8 changes: 8 additions & 0 deletions vscode/extension/src/extension.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ import { selector, completionProvider } from './completion/completion'
import { LineagePanel } from './webviews/lineagePanel'
import { RenderedModelProvider } from './providers/renderedModelProvider'
import { sleep } from './utilities/sleep'
import {
controller as testController,
setupTestController,
} from './tests/tests'

let lspClient: LSPClient | undefined

Expand Down Expand Up @@ -128,6 +132,7 @@ export async function activate(context: vscode.ExtensionContext) {
)
}
context.subscriptions.push(lspClient)
context.subscriptions.push(setupTestController(lspClient))
} else {
lspClient = new LSPClient()
const result = await lspClient.start(invokedByUser)
Expand All @@ -140,6 +145,7 @@ export async function activate(context: vscode.ExtensionContext) {
)
} else {
context.subscriptions.push(lspClient)
context.subscriptions.push(setupTestController(lspClient))
}
}
}
Expand Down Expand Up @@ -175,6 +181,8 @@ export async function activate(context: vscode.ExtensionContext) {
)
} else {
context.subscriptions.push(lspClient)
context.subscriptions.push(setupTestController(lspClient))
context.subscriptions.push(testController)
}

if (lspClient && !lspClient.hasCompletionCapability()) {
Expand Down
Loading