diff --git a/crates/squawk_server/src/lib.rs b/crates/squawk_server/src/lib.rs index a64ddcd6..d0ff0786 100644 --- a/crates/squawk_server/src/lib.rs +++ b/crates/squawk_server/src/lib.rs @@ -15,6 +15,13 @@ use lsp_types::{ }; use squawk_linter::Linter; use squawk_syntax::{Parse, SourceFile}; +use std::collections::HashMap; +mod lsp_utils; + +struct DocumentState { + content: String, + version: i32, +} pub fn run() -> Result<()> { info!("Starting Squawk LSP server"); @@ -22,7 +29,9 @@ pub fn run() -> Result<()> { let (connection, io_threads) = Connection::stdio(); let server_capabilities = serde_json::to_value(&ServerCapabilities { - text_document_sync: Some(TextDocumentSyncCapability::Kind(TextDocumentSyncKind::FULL)), + text_document_sync: Some(TextDocumentSyncCapability::Kind( + TextDocumentSyncKind::INCREMENTAL, + )), // definition_provider: Some(OneOf::Left(true)), ..Default::default() }) @@ -48,6 +57,8 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> { let client_name = init_params.client_info.map(|x| x.name); info!("Client name: {client_name:?}"); + let mut documents: HashMap = HashMap::new(); + for msg in &connection.receiver { match msg { Message::Request(req) => { @@ -63,10 +74,10 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> { handle_goto_definition(&connection, req)?; } "squawk/syntaxTree" => { - handle_syntax_tree(&connection, req)?; + handle_syntax_tree(&connection, req, &documents)?; } "squawk/tokens" => { - handle_tokens(&connection, req)?; + handle_tokens(&connection, req, &documents)?; } _ => { info!("Ignoring unhandled request: {}", req.method); @@ -78,13 +89,19 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> { } Message::Notification(notif) => { info!("Received notification: method={}", notif.method); - - if notif.method == DidOpenTextDocument::METHOD { - handle_did_open(&connection, notif)?; - } else if notif.method == DidChangeTextDocument::METHOD { - handle_did_change(&connection, notif)?; - } else if notif.method == DidCloseTextDocument::METHOD { - handle_did_close(&connection, notif)?; + match notif.method.as_ref() { + DidOpenTextDocument::METHOD => { + handle_did_open(&connection, notif, &mut documents)?; + } + DidChangeTextDocument::METHOD => { + handle_did_change(&connection, notif, &mut documents)?; + } + DidCloseTextDocument::METHOD => { + handle_did_close(&connection, notif, &mut documents)?; + } + _ => { + info!("Ignoring unhandled notification: {}", notif.method); + } } } } @@ -97,10 +114,7 @@ fn handle_goto_definition(connection: &Connection, req: lsp_server::Request) -> let location = Location { uri: params.text_document_position_params.text_document.uri, - range: Range { - start: Position::new(1, 2), - end: Position::new(1, 3), - }, + range: Range::new(Position::new(1, 2), Position::new(1, 3)), }; let result = GotoDefinitionResponse::Scalar(location); @@ -114,33 +128,83 @@ fn handle_goto_definition(connection: &Connection, req: lsp_server::Request) -> Ok(()) } -fn handle_did_open(connection: &Connection, notif: lsp_server::Notification) -> Result<()> { +fn publish_diagnostics( + connection: &Connection, + uri: Url, + version: i32, + diagnostics: Vec, +) -> Result<()> { + let publish_params = PublishDiagnosticsParams { + uri, + diagnostics, + version: Some(version), + }; + + let notification = Notification { + method: PublishDiagnostics::METHOD.to_owned(), + params: serde_json::to_value(publish_params)?, + }; + + connection + .sender + .send(Message::Notification(notification))?; + Ok(()) +} + +fn handle_did_open( + connection: &Connection, + notif: lsp_server::Notification, + documents: &mut HashMap, +) -> Result<()> { let params: DidOpenTextDocumentParams = serde_json::from_value(notif.params)?; let uri = params.text_document.uri; let content = params.text_document.text; let version = params.text_document.version; - lint(connection, uri, &content, version)?; + documents.insert(uri.clone(), DocumentState { content, version }); + + let content = documents.get(&uri).map_or("", |doc| &doc.content); + + // TODO: we need a better setup for "run func when input changed" + let diagnostics = lint(content); + publish_diagnostics(connection, uri, version, diagnostics)?; Ok(()) } -fn handle_did_change(connection: &Connection, notif: lsp_server::Notification) -> Result<()> { +fn handle_did_change( + connection: &Connection, + notif: lsp_server::Notification, + documents: &mut HashMap, +) -> Result<()> { let params: DidChangeTextDocumentParams = serde_json::from_value(notif.params)?; let uri = params.text_document.uri; let version = params.text_document.version; - if let Some(change) = params.content_changes.last() { - lint(connection, uri, &change.text, version)?; - } + let Some(doc_state) = documents.get_mut(&uri) else { + return Ok(()); + }; + + doc_state.content = + lsp_utils::apply_incremental_changes(&doc_state.content, params.content_changes); + doc_state.version = version; + + let diagnostics = lint(&doc_state.content); + publish_diagnostics(connection, uri, version, diagnostics)?; Ok(()) } -fn handle_did_close(connection: &Connection, notif: lsp_server::Notification) -> Result<()> { +fn handle_did_close( + connection: &Connection, + notif: lsp_server::Notification, + documents: &mut HashMap, +) -> Result<()> { let params: DidCloseTextDocumentParams = serde_json::from_value(notif.params)?; let uri = params.text_document.uri; + documents.remove(&uri); + let publish_params = PublishDiagnosticsParams { uri, diagnostics: vec![], @@ -159,14 +223,14 @@ fn handle_did_close(connection: &Connection, notif: lsp_server::Notification) -> Ok(()) } -fn lint(connection: &Connection, uri: lsp_types::Url, content: &str, version: i32) -> Result<()> { +fn lint(content: &str) -> Vec { let parse: Parse = SourceFile::parse(content); let parse_errors = parse.errors(); let mut linter = Linter::with_all_rules(); let violations = linter.lint(parse, content); let line_index = LineIndex::new(content); - let mut diagnostics = vec![]; + let mut diagnostics = Vec::with_capacity(violations.len() + parse_errors.len()); for error in parse_errors { let range_start = error.range().start(); @@ -179,10 +243,10 @@ fn lint(connection: &Connection, uri: lsp_types::Url, content: &str, version: i3 } let diagnostic = Diagnostic { - range: Range { - start: Position::new(start_line_col.line, start_line_col.col), - end: Position::new(end_line_col.line, end_line_col.col), - }, + range: Range::new( + Position::new(start_line_col.line, start_line_col.col), + Position::new(end_line_col.line, end_line_col.col), + ), severity: Some(DiagnosticSeverity::ERROR), code: Some(lsp_types::NumberOrString::String( "syntax-error".to_string(), @@ -208,10 +272,10 @@ fn lint(connection: &Connection, uri: lsp_types::Url, content: &str, version: i3 } let diagnostic = Diagnostic { - range: Range { - start: Position::new(start_line_col.line, start_line_col.col), - end: Position::new(end_line_col.line, end_line_col.col), - }, + range: Range::new( + Position::new(start_line_col.line, start_line_col.col), + Position::new(end_line_col.line, end_line_col.col), + ), severity: Some(DiagnosticSeverity::WARNING), code: Some(lsp_types::NumberOrString::String( violation.code.to_string(), @@ -225,42 +289,28 @@ fn lint(connection: &Connection, uri: lsp_types::Url, content: &str, version: i3 }; diagnostics.push(diagnostic); } - - let publish_params = PublishDiagnosticsParams { - uri, - diagnostics, - version: Some(version), - }; - - let notification = Notification { - method: PublishDiagnostics::METHOD.to_owned(), - params: serde_json::to_value(publish_params)?, - }; - - connection - .sender - .send(Message::Notification(notification))?; - - Ok(()) + diagnostics } #[derive(serde::Deserialize)] struct SyntaxTreeParams { #[serde(rename = "textDocument")] text_document: lsp_types::TextDocumentIdentifier, - // TODO: once we start storing the text doc on the server we won't need to - // send the content across the wire - text: String, } -fn handle_syntax_tree(connection: &Connection, req: lsp_server::Request) -> Result<()> { +fn handle_syntax_tree( + connection: &Connection, + req: lsp_server::Request, + documents: &HashMap, +) -> Result<()> { let params: SyntaxTreeParams = serde_json::from_value(req.params)?; let uri = params.text_document.uri; - let content = params.text; info!("Generating syntax tree for: {}", uri); - let parse: Parse = SourceFile::parse(&content); + let content = documents.get(&uri).map_or("", |doc| &doc.content); + + let parse: Parse = SourceFile::parse(content); let syntax_tree = format!("{:#?}", parse.syntax_node()); let resp = Response { @@ -277,19 +327,21 @@ fn handle_syntax_tree(connection: &Connection, req: lsp_server::Request) -> Resu struct TokensParams { #[serde(rename = "textDocument")] text_document: lsp_types::TextDocumentIdentifier, - // TODO: once we start storing the text doc on the server we won't need to - // send the content across the wire - text: String, } -fn handle_tokens(connection: &Connection, req: lsp_server::Request) -> Result<()> { +fn handle_tokens( + connection: &Connection, + req: lsp_server::Request, + documents: &HashMap, +) -> Result<()> { let params: TokensParams = serde_json::from_value(req.params)?; let uri = params.text_document.uri; - let content = params.text; info!("Generating tokens for: {}", uri); - let tokens = squawk_lexer::tokenize(&content); + let content = documents.get(&uri).map_or("", |doc| &doc.content); + + let tokens = squawk_lexer::tokenize(content); let mut output = Vec::new(); let mut char_pos = 0; diff --git a/crates/squawk_server/src/lsp_utils.rs b/crates/squawk_server/src/lsp_utils.rs new file mode 100644 index 00000000..fc881a9b --- /dev/null +++ b/crates/squawk_server/src/lsp_utils.rs @@ -0,0 +1,230 @@ +use std::ops::Range; + +use line_index::{LineIndex, TextRange, TextSize}; +use log::warn; + +fn text_range(index: &LineIndex, range: lsp_types::Range) -> Option { + let start = offset(index, range.start)?; + let end = offset(index, range.end)?; + if end >= start { + Some(TextRange::new(start, end)) + } else { + warn!( + "Invalid range: start {} > end {}", + u32::from(start), + u32::from(end) + ); + None + } +} +fn offset(index: &LineIndex, position: lsp_types::Position) -> Option { + let line_range = index.line(position.line)?; + + let col = TextSize::from(position.character); + let clamped_len = col.min(line_range.len()); + + if clamped_len < col { + warn!( + "Position line {}, col {} exceeds line length {}, clamping it", + position.line, + position.character, + u32::from(line_range.len()) + ); + } + + Some(line_range.start() + clamped_len) +} + +// base on rust-analyzer's +// see: https://github.com/rust-lang/rust-analyzer/blob/3816d0ae53c19fe75532a8b41d8c546d94246b53/crates/rust-analyzer/src/lsp/utils.rs#L168C1-L168C1 +pub(crate) fn apply_incremental_changes( + content: &str, + mut content_changes: Vec, +) -> String { + // If at least one of the changes is a full document change, use the last + // of them as the starting point and ignore all previous changes. + let (mut text, content_changes) = match content_changes + .iter() + .rposition(|change| change.range.is_none()) + { + Some(idx) => { + let text = std::mem::take(&mut content_changes[idx].text); + (text, &content_changes[idx + 1..]) + } + None => (content.to_owned(), &content_changes[..]), + }; + + if content_changes.is_empty() { + return text; + } + + let mut line_index = LineIndex::new(&text); + + // The changes we got must be applied sequentially, but can cross lines so we + // have to keep our line index updated. + // Some clients (e.g. Code) sort the ranges in reverse. As an optimization, we + // remember the last valid line in the index and only rebuild it if needed. + let mut index_valid = !0u32; + for change in content_changes { + // The None case can't happen as we have handled it above already + if let Some(range) = change.range { + if index_valid <= range.end.line { + line_index = LineIndex::new(&text); + } + index_valid = range.start.line; + if let Some(range) = text_range(&line_index, range) { + text.replace_range(Range::::from(range), &change.text); + } + } + } + + text +} + +#[cfg(test)] +mod tests { + use super::*; + use lsp_types::{Position, Range, TextDocumentContentChangeEvent}; + + #[test] + fn apply_incremental_changes_no_changes() { + let content = "hello world"; + let changes = vec![]; + let result = apply_incremental_changes(content, changes); + assert_eq!(result, "hello world"); + } + + #[test] + fn apply_incremental_changes_full_document_change() { + let content = "old content"; + let changes = vec![TextDocumentContentChangeEvent { + range: None, + range_length: None, + text: "new content".to_string(), + }]; + let result = apply_incremental_changes(content, changes); + assert_eq!(result, "new content"); + } + + #[test] + fn apply_incremental_changes_single_line_edit() { + let content = "hello world"; + let changes = vec![TextDocumentContentChangeEvent { + range: Some(Range::new(Position::new(0, 6), Position::new(0, 11))), + range_length: None, + text: "rust".to_string(), + }]; + let result = apply_incremental_changes(content, changes); + assert_eq!(result, "hello rust"); + } + + #[test] + fn apply_incremental_changes_multiple_edits() { + let content = "line 1\nline 2\nline 3"; + let changes = vec![ + TextDocumentContentChangeEvent { + range: Some(Range::new(Position::new(0, 4), Position::new(0, 6))), + range_length: None, + text: " updated".to_string(), + }, + TextDocumentContentChangeEvent { + range: Some(Range::new(Position::new(2, 4), Position::new(2, 6))), + range_length: None, + text: " also updated".to_string(), + }, + ]; + let result = apply_incremental_changes(content, changes); + assert_eq!(result, "line updated\nline 2\nline also updated"); + } + + #[test] + fn apply_incremental_changes_insertion() { + let content = "hello world"; + let changes = vec![TextDocumentContentChangeEvent { + range: Some(Range::new(Position::new(0, 5), Position::new(0, 5))), + range_length: None, + text: " foo".to_string(), + }]; + let result = apply_incremental_changes(content, changes); + assert_eq!(result, "hello foo world"); + } + + #[test] + fn apply_incremental_changes_deletion() { + let content = "hello foo world"; + let changes = vec![TextDocumentContentChangeEvent { + range: Some(Range::new(Position::new(0, 5), Position::new(0, 9))), + range_length: None, + text: "".to_string(), + }]; + let result = apply_incremental_changes(content, changes); + assert_eq!(result, "hello world"); + } + + #[test] + fn apply_incremental_changes_multiline_edit() { + let content = "line 1\nline 2\nline 3"; + let changes = vec![TextDocumentContentChangeEvent { + range: Some(Range::new(Position::new(0, 6), Position::new(1, 6))), + range_length: None, + text: " and\nreplaced".to_string(), + }]; + let result = apply_incremental_changes(content, changes); + assert_eq!(result, "line 1 and\nreplaced\nline 3"); + } + + #[test] + fn apply_incremental_changes_full_then_incremental() { + let content = "original"; + let changes = vec![ + TextDocumentContentChangeEvent { + range: None, + range_length: None, + text: "hello world".to_string(), + }, + TextDocumentContentChangeEvent { + range: Some(Range::new(Position::new(0, 6), Position::new(0, 11))), + range_length: None, + text: "rust".to_string(), + }, + ]; + let result = apply_incremental_changes(content, changes); + assert_eq!(result, "hello rust"); + } + + #[test] + fn apply_incremental_changes_invalid_range_ignored() { + let content = "hello"; + let changes = vec![TextDocumentContentChangeEvent { + range: Some(Range::new(Position::new(10, 0), Position::new(10, 5))), + range_length: None, + text: "invalid".to_string(), + }]; + let result = apply_incremental_changes(content, changes); + assert_eq!(result, "hello"); + } + + #[test] + fn apply_incremental_changes_with_invalid_line_no() { + let content = "hello world"; + let changes = vec![TextDocumentContentChangeEvent { + range: Some(Range::new(Position::new(10, 0), Position::new(10, 5))), + range_length: None, + text: "invalid".to_string(), + }]; + let result = apply_incremental_changes(content, changes); + assert_eq!(result, "hello world"); + } + + #[test] + fn apply_incremental_changes_column_clamping() { + let content = "short\nlong line"; + let changes = vec![TextDocumentContentChangeEvent { + range: Some(Range::new(Position::new(0, 3), Position::new(0, 100))), + range_length: None, + text: " extended".to_string(), + }]; + let result = apply_incremental_changes(content, changes); + assert_eq!(result, "sho extendedlong line"); + } +} diff --git a/squawk-vscode/src/extension.ts b/squawk-vscode/src/extension.ts index 848f4c33..d8c688ae 100644 --- a/squawk-vscode/src/extension.ts +++ b/squawk-vscode/src/extension.ts @@ -268,13 +268,16 @@ class SyntaxTreeProvider implements vscode.TextDocumentContentProvider { ) context.subscriptions.push( vscode.workspace.onDidChangeTextDocument((event) => { - this._onDidChangeTextDocument(event.document) + this._onDidChangeTextDocument(event) }), ) context.subscriptions.push( vscode.commands.registerCommand("squawk.showSyntaxTree", async () => { const doc = await vscode.workspace.openTextDocument(this._uri) - await vscode.window.showTextDocument(doc, vscode.ViewColumn.Beside) + await vscode.window.showTextDocument(doc, { + viewColumn: vscode.ViewColumn.Beside, + preserveFocus: true, + }) }), ) @@ -291,13 +294,12 @@ class SyntaxTreeProvider implements vscode.TextDocumentContentProvider { } } - _onDidChangeTextDocument(document: vscode.TextDocument) { - if ( - isSqlDocument(document) && - this._activeEditor && - document === this._activeEditor.document - ) { - this._eventEmitter.fire(this._uri) + _onDidChangeTextDocument(event: vscode.TextDocumentChangeEvent) { + if (isSqlDocument(event.document)) { + // via rust-analzyer: + // We need to order this after language server updates, but there's no API for that. + // Hence, good old sleep(). + void sleep(10).then(() => this._eventEmitter.fire(this._uri)) } } @@ -310,12 +312,10 @@ class SyntaxTreeProvider implements vscode.TextDocumentContentProvider { if (!client) { return "Error: no client found" } - const text = document.getText() const uri = document.uri.toString() log.info(`Requesting syntax tree for: ${uri}`) const response = await client.sendRequest("squawk/syntaxTree", { textDocument: { uri }, - text, }) log.info("Syntax tree received") return response @@ -339,13 +339,16 @@ class TokensProvider implements vscode.TextDocumentContentProvider { ) context.subscriptions.push( vscode.workspace.onDidChangeTextDocument((event) => { - this._onDidChangeTextDocument(event.document) + this._onDidChangeTextDocument(event) }), ) context.subscriptions.push( vscode.commands.registerCommand("squawk.showTokens", async () => { const doc = await vscode.workspace.openTextDocument(this._uri) - await vscode.window.showTextDocument(doc, vscode.ViewColumn.Beside) + await vscode.window.showTextDocument(doc, { + viewColumn: vscode.ViewColumn.Beside, + preserveFocus: true, + }) }), ) @@ -362,13 +365,12 @@ class TokensProvider implements vscode.TextDocumentContentProvider { } } - _onDidChangeTextDocument(document: vscode.TextDocument) { - if ( - isSqlDocument(document) && - this._activeEditor && - document === this._activeEditor.document - ) { - this._eventEmitter.fire(this._uri) + _onDidChangeTextDocument(event: vscode.TextDocumentChangeEvent) { + if (isSqlDocument(event.document)) { + // via rust-analzyer: + // We need to order this after language server updates, but there's no API for that. + // Hence, good old sleep(). + void sleep(10).then(() => this._eventEmitter.fire(this._uri)) } } @@ -381,12 +383,10 @@ class TokensProvider implements vscode.TextDocumentContentProvider { if (!client) { return "Error: no client found" } - const text = document.getText() const uri = document.uri.toString() log.info(`Requesting tokens for: ${uri}`) const response = await client.sendRequest("squawk/tokens", { textDocument: { uri }, - text, }) log.info("Tokens received") return response @@ -397,6 +397,10 @@ class TokensProvider implements vscode.TextDocumentContentProvider { } } +function sleep(ms: number): Promise { + return new Promise((resolve) => setTimeout(resolve, ms)) +} + function assertNever(param: never): never { throw new Error(`should never get here, but got ${String(param)}`) }