From c574ffec28d2e54db70887786be0f0e6035697e4 Mon Sep 17 00:00:00 2001 From: Patrick Vice Date: Sun, 23 Nov 2025 20:45:09 -0500 Subject: [PATCH 1/2] Improve transport lifecycle handling - add mutex-guarded running? helpers across SSE, stdio, and HTTP transports - ensure pending queues get transport errors instead of hanging - make SSE endpoint parsing stricter and stdio shutdown more graceful - expand specs to cover new lifecycle and error paths --- lib/ruby_llm/mcp/native/transports/sse.rb | 244 ++++++-- lib/ruby_llm/mcp/native/transports/stdio.rb | 238 +++++--- .../mcp/native/transports/streamable_http.rb | 198 ++++-- .../mcp/native/transports/sse_spec.rb | 286 +++++++++ .../mcp/native/transports/stdio_spec.rb | 127 +++- .../native/transports/streamable_http_spec.rb | 575 +++++++++++++++++- 6 files changed, 1452 insertions(+), 216 deletions(-) diff --git a/lib/ruby_llm/mcp/native/transports/sse.rb b/lib/ruby_llm/mcp/native/transports/sse.rb index e6d2ea2..9271800 100644 --- a/lib/ruby_llm/mcp/native/transports/sse.rb +++ b/lib/ruby_llm/mcp/native/transports/sse.rb @@ -33,21 +33,32 @@ def initialize(url:, coordinator:, request_timeout:, version: :http2, headers: { @pending_requests = {} @pending_mutex = Mutex.new @connection_mutex = Mutex.new + @state_mutex = Mutex.new @running = false @sse_thread = nil + @sse_response = nil RubyLLM::MCP.logger.info "Initializing SSE transport to #{@event_url} with client ID #{@client_id}" end - def request(body, add_id: true, wait_for_response: true) + def request(body, add_id: true, wait_for_response: true) # rubocop:disable Metrics/MethodLength + request_id = nil + if add_id @id_mutex.synchronize { @id_counter += 1 } request_id = @id_counter body["id"] = request_id + elsif body.is_a?(Hash) + request_id = body["id"] || body[:id] + end + + if wait_for_response && request_id.nil? + raise ArgumentError, "Request ID must be provided when wait_for_response is true and add_id is false" end - response_queue = Queue.new + response_queue = nil if wait_for_response + response_queue = Queue.new @pending_mutex.synchronize do @pending_requests[request_id.to_s] = response_queue end @@ -56,41 +67,85 @@ def request(body, add_id: true, wait_for_response: true) begin send_request(body, request_id) rescue Errors::TransportError, Errors::TimeoutError => e - @pending_mutex.synchronize { @pending_requests.delete(request_id.to_s) } + if wait_for_response && request_id + @pending_mutex.synchronize { @pending_requests.delete(request_id.to_s) } + end RubyLLM::MCP.logger.error "Request error (ID: #{request_id}): #{e.message}" raise e end return unless wait_for_response + result = nil begin - with_timeout(@request_timeout / 1000, request_id: request_id) do + result = with_timeout(@request_timeout / 1000, request_id: request_id) do response_queue.pop end rescue Errors::TimeoutError => e - @pending_mutex.synchronize { @pending_requests.delete(request_id.to_s) } + if request_id + @pending_mutex.synchronize { @pending_requests.delete(request_id.to_s) } + end RubyLLM::MCP.logger.error "SSE request timeout (ID: #{request_id}) \ after #{@request_timeout / 1000} seconds." raise e end + + raise result if result.is_a?(Errors::TransportError) + + result end def alive? - @running + running? + end + + def running? + @state_mutex.synchronize { @running } end def start - return if @running + @state_mutex.synchronize do + return if @running + + @running = true + end - @running = true start_sse_listener end def close + should_close = @state_mutex.synchronize do + return unless @running + + @running = false + true + end + + return unless should_close + RubyLLM::MCP.logger.info "Closing SSE transport connection" - @running = false - @sse_thread&.join(1) # Give the thread a second to clean up + + # Close the SSE response stream if it exists + begin + @sse_response&.body&.close + rescue StandardError => e + RubyLLM::MCP.logger.debug "Error closing SSE response: #{e.message}" + end + + # Wait for the thread to finish + @sse_thread&.join(1) @sse_thread = nil + + # Fail all pending requests + fail_pending_requests!( + Errors::TransportError.new( + message: "SSE transport closed", + code: nil + ) + ) + + # Reset state + @messages_url = nil end def set_protocol_version(version) @@ -117,7 +172,7 @@ def send_request(body, request_id) end def start_sse_listener - @connection_mutex.synchronize do + @connection_mutex.synchronize do # rubocop:disable Metrics/BlockLength return if sse_thread_running? RubyLLM::MCP.logger.info "Starting SSE listener thread" @@ -128,27 +183,63 @@ def start_sse_listener end @sse_thread = Thread.new do - listen_for_events while @running + listen_for_events end - @sse_thread.abort_on_exception = true - with_timeout(@request_timeout / 1000) do - endpoint = response_queue.pop - set_message_endpoint(endpoint) + begin + with_timeout(@request_timeout / 1000) do + endpoint = response_queue.pop + set_message_endpoint(endpoint) + end + rescue Errors::TimeoutError => e + # Clean up the pending request on timeout + @pending_mutex.synchronize do + @pending_requests.delete("endpoint") + end + RubyLLM::MCP.logger.error "Timeout waiting for endpoint event: #{e.message}" + raise e + rescue StandardError => e + # Clean up the pending request on any error + @pending_mutex.synchronize do + @pending_requests.delete("endpoint") + end + raise e end end end def set_message_endpoint(endpoint) - uri = URI.parse(endpoint) + # Handle both string endpoints and JSON payloads + endpoint_url = if endpoint.is_a?(String) + endpoint + elsif endpoint.is_a?(Hash) + # Support richer endpoint metadata (e.g., { "url": "...", "last_event_id": "..." }) + endpoint["url"] || endpoint[:url] + else + endpoint.to_s + end + + unless endpoint_url && !endpoint_url.empty? + raise Errors::TransportError.new( + message: "Invalid endpoint event: missing URL", + code: nil + ) + end + + uri = URI.parse(endpoint_url) @messages_url = if uri.host.nil? - "#{@root_url}#{endpoint}" + "#{@root_url}#{endpoint_url}" else - endpoint + endpoint_url end RubyLLM::MCP.logger.info "SSE message endpoint set to: #{@messages_url}" + rescue URI::InvalidURIError => e + raise Errors::TransportError.new( + message: "Invalid endpoint URL: #{e.message}", + code: nil + ) end def sse_thread_running? @@ -156,16 +247,16 @@ def sse_thread_running? end def listen_for_events - stream_events_from_server + stream_events_from_server while running? rescue StandardError => e handle_connection_error("SSE connection error", e) end def stream_events_from_server sse_client = create_sse_client - response = sse_client.get(@event_url, stream: true) - validate_sse_response!(response) - process_event_stream(response) + @sse_response = sse_client.get(@event_url, stream: true) + validate_sse_response!(@sse_response) + process_event_stream(@sse_response) end def create_sse_client @@ -188,11 +279,24 @@ def validate_sse_response!(response) end def handle_client_error!(error_message, status_code) - @running = false - raise Errors::TransportError.new( + transport_error = Errors::TransportError.new( message: error_message, code: status_code ) + + # Close the transport (which will fail pending requests) + close + + raise transport_error + end + + def fail_pending_requests!(error) + @pending_mutex.synchronize do + @pending_requests.each_value do |queue| + queue.push(error) + end + @pending_requests.clear + end end def process_event_stream(response) @@ -203,7 +307,7 @@ def process_event_stream(response) end def handle_event_line?(event_line, event_buffer, response) - unless @running + unless running? response.body.close return false end @@ -241,11 +345,22 @@ def read_error_body(response) end def handle_connection_error(message, error) - return unless @running + return unless running? error_message = "#{message}: #{error.message}" - RubyLLM::MCP.logger.error "#{error_message}. Reconnecting in 1 seconds..." - sleep 1 + RubyLLM::MCP.logger.error "#{error_message}. Closing SSE transport." + + # Create a transport error to fail pending requests + transport_error = Errors::TransportError.new( + message: error_message, + code: nil + ) + + # Close the transport (which will fail pending requests) + close + + # Notify coordinator if needed + @coordinator&.handle_error(transport_error) end def handle_httpx_error_response!(response, context:) @@ -272,40 +387,55 @@ def process_event(raw_event) return if raw_event[:data].nil? if raw_event[:event] == "endpoint" - request_id = "endpoint" - event = raw_event[:data] - return if event.nil? - - RubyLLM::MCP.logger.debug "Received endpoint event: #{event}" - @pending_mutex.synchronize do - response_queue = @pending_requests.delete(request_id) - response_queue&.push(event) - end + process_endpoint_event(raw_event) else - event = begin - JSON.parse(raw_event[:data]) - rescue JSON::ParserError => e - # We can sometimes get partial endpoint events, so we will ignore them - unless @endpoint.nil? - RubyLLM::MCP.logger.info "Failed to parse SSE event data: #{raw_event[:data]} - #{e.message}" - end + process_message_event(raw_event) + end + end + + def process_endpoint_event(raw_event) + request_id = "endpoint" + event_data = raw_event[:data] + return if event_data.nil? + + # Try to parse as JSON first, fall back to string + endpoint = begin + JSON.parse(event_data) + rescue JSON::ParserError + event_data + end - nil + RubyLLM::MCP.logger.debug "Received endpoint event: #{endpoint.inspect}" + + @pending_mutex.synchronize do + response_queue = @pending_requests.delete(request_id) + response_queue&.push(endpoint) + end + end + + def process_message_event(raw_event) + event = begin + JSON.parse(raw_event[:data]) + rescue JSON::ParserError => e + # We can sometimes get partial events, so we will ignore them + if @messages_url + RubyLLM::MCP.logger.debug "Failed to parse SSE event data: #{raw_event[:data]} - #{e.message}" end - return if event.nil? + nil + end + return if event.nil? - request_id = event["id"]&.to_s - result = RubyLLM::MCP::Result.new(event) + request_id = event["id"]&.to_s + result = RubyLLM::MCP::Result.new(event) - result = @coordinator.process_result(result) - return if result.nil? + result = @coordinator.process_result(result) + return if result.nil? - @pending_mutex.synchronize do - # You can receieve duplicate events for the same request id, and we will ignore thoses - if result.matching_id?(request_id) && @pending_requests.key?(request_id) - response_queue = @pending_requests.delete(request_id) - response_queue&.push(result) - end + @pending_mutex.synchronize do + # You can receive duplicate events for the same request id, and we will ignore those + if result.matching_id?(request_id) && @pending_requests.key?(request_id) + response_queue = @pending_requests.delete(request_id) + response_queue&.push(result) end end end diff --git a/lib/ruby_llm/mcp/native/transports/stdio.rb b/lib/ruby_llm/mcp/native/transports/stdio.rb index e2dda7c..ea5f436 100644 --- a/lib/ruby_llm/mcp/native/transports/stdio.rb +++ b/lib/ruby_llm/mcp/native/transports/stdio.rb @@ -9,92 +9,65 @@ class Stdio attr_reader :command, :stdin, :stdout, :stderr, :id, :coordinator + # Default environment that merges with user-provided env + # This ensures PATH and other critical env vars are preserved + DEFAULT_ENV = ENV.to_h.freeze + def initialize(command:, coordinator:, request_timeout:, args: [], env: {}) @request_timeout = request_timeout @command = command @coordinator = coordinator @args = args - @env = env || {} + # Merge provided env with default environment (user env takes precedence) + @env = DEFAULT_ENV.merge(env || {}) @client_id = SecureRandom.uuid @id_counter = 0 @id_mutex = Mutex.new @pending_requests = {} @pending_mutex = Mutex.new + @state_mutex = Mutex.new @running = false @reader_thread = nil @stderr_thread = nil end def request(body, add_id: true, wait_for_response: true) - if add_id - @id_mutex.synchronize { @id_counter += 1 } - request_id = @id_counter - body["id"] = request_id - else - # When add_id is false, the ID should already be in the body - # Try both string and symbol keys to be flexible - request_id = body["id"] || body[:id] - end + request_id = prepare_request_id(body, add_id, wait_for_response) + response_queue = register_pending_request(request_id, wait_for_response) - response_queue = Queue.new - if wait_for_response - @pending_mutex.synchronize do - @pending_requests[request_id.to_s] = response_queue - end - end - - begin - body = JSON.generate(body) - RubyLLM::MCP.logger.debug "Sending Request: #{body}" - @stdin.puts(body) - @stdin.flush - rescue IOError, Errno::EPIPE => e - @pending_mutex.synchronize { @pending_requests.delete(request_id.to_s) } - restart_process - raise RubyLLM::MCP::Errors::TransportError.new(message: e.message, error: e) - end + send_request(body, request_id) return unless wait_for_response - begin - with_timeout(@request_timeout / 1000, request_id: request_id) do - response_queue.pop - end - rescue RubyLLM::MCP::Errors::TimeoutError => e - @pending_mutex.synchronize { @pending_requests.delete(request_id.to_s) } - log_message = "Stdio request timeout (ID: #{request_id}) after #{@request_timeout / 1000} seconds" - RubyLLM::MCP.logger.error(log_message) - raise e - end + wait_for_request_response(request_id, response_queue) end def alive? - @running + running? end - def start - start_process unless @running - @running = true + def running? + @state_mutex.synchronize { @running } end - def close - @running = false + def start + @state_mutex.synchronize do + return if @running - [@stdin, @stdout, @stderr].each do |stream| - stream&.close - rescue IOError, Errno::EBADF - nil + @running = true end + start_process + end - [@wait_thread, @reader_thread, @stderr_thread].each do |thread| - thread&.join(1) - rescue StandardError - nil - end + def close + @state_mutex.synchronize do + return unless @running - @stdin = @stdout = @stderr = nil - @wait_thread = @reader_thread = @stderr_thread = nil + @running = false + end + shutdown_process + fail_pending_requests!(RubyLLM::MCP::Errors::TransportError.new(message: "Transport closed")) end def set_protocol_version(version) @@ -103,39 +76,144 @@ def set_protocol_version(version) private + def prepare_request_id(body, add_id, wait_for_response) + request_id = if add_id + @id_mutex.synchronize { @id_counter += 1 } + body["id"] = @id_counter + @id_counter + else + body["id"] || body[:id] + end + + if wait_for_response && request_id.nil? + raise ArgumentError, "Request ID must be provided when wait_for_response is true and add_id is false" + end + + request_id + end + + def register_pending_request(request_id, wait_for_response) + return nil unless wait_for_response + + response_queue = Queue.new + @pending_mutex.synchronize do + @pending_requests[request_id.to_s] = response_queue + end + response_queue + end + + def send_request(body, request_id) + body = JSON.generate(body) + RubyLLM::MCP.logger.debug "Sending Request: #{body}" + @stdin.puts(body) + @stdin.flush + rescue IOError, Errno::EPIPE => e + @pending_mutex.synchronize { @pending_requests.delete(request_id.to_s) } if request_id + raise RubyLLM::MCP::Errors::TransportError.new(message: e.message, error: e) + end + + def wait_for_request_response(request_id, response_queue) + with_timeout(@request_timeout / 1000, request_id: request_id) do + response_queue.pop + end + rescue RubyLLM::MCP::Errors::TimeoutError => e + @pending_mutex.synchronize { @pending_requests.delete(request_id.to_s) } + log_message = "Stdio request timeout (ID: #{request_id}) after #{@request_timeout / 1000} seconds" + RubyLLM::MCP.logger.error(log_message) + raise e + end + def start_process - close if @stdin || @stdout || @stderr || @wait_thread + shutdown_process if @stdin || @stdout || @stderr || @wait_thread - @stdin, @stdout, @stderr, @wait_thread = if @env.empty? - Open3.popen3(@command, *@args) - else - Open3.popen3(@env, @command, *@args) - end + # Always pass env - it now includes defaults merged with user overrides + @stdin, @stdout, @stderr, @wait_thread = Open3.popen3(@env, @command, *@args) start_reader_thread start_stderr_thread end - def restart_process - RubyLLM::MCP.logger.error "Process connection lost. Restarting..." - start_process + def shutdown_process + close_stdin + terminate_child_process + close_output_streams + join_reader_threads + clear_process_handles + end + + def close_stdin + @stdin&.close + rescue IOError, Errno::EBADF + # Already closed + end + + def terminate_child_process + return unless @wait_thread + + @wait_thread.join(1) if @wait_thread.alive? # 1s grace period + send_signal_to_process("TERM", 2) if @wait_thread.alive? + send_signal_to_process("KILL", 0) if @wait_thread.alive? + end + + def send_signal_to_process(signal, wait_time) + Process.kill(signal, @wait_thread.pid) + @wait_thread.join(wait_time) if wait_time.positive? + rescue StandardError => e + RubyLLM::MCP.logger.debug "Error sending #{signal}: #{e.message}" + end + + def close_output_streams + [@stdout, @stderr].each do |stream| + stream&.close + rescue IOError, Errno::EBADF + # Already closed + end + end + + def join_reader_threads + [@reader_thread, @stderr_thread].each do |thread| + next unless thread&.alive? + next if Thread.current == thread # Avoid self-join deadlock + + thread.join(1) + rescue StandardError => e + RubyLLM::MCP.logger.debug "Error joining thread: #{e.message}" + end + end + + def clear_process_handles + @stdin = @stdout = @stderr = nil + @wait_thread = @reader_thread = @stderr_thread = nil + end + + def fail_pending_requests!(error) + @pending_mutex.synchronize do + @pending_requests.each_value do |queue| + queue.push(error) + end + @pending_requests.clear + end + end + + def safe_close_with_error(error) + fail_pending_requests!(error) + close end def start_reader_thread @reader_thread = Thread.new do read_stdout_loop end - - @reader_thread.abort_on_exception = true + # Don't use abort_on_exception - handle errors cooperatively end def read_stdout_loop - while @running + while running? begin handle_stdout_read rescue IOError, Errno::EPIPE => e handle_stream_error(e, "Reader") - break unless @running + break unless running? rescue StandardError => e RubyLLM::MCP.logger.error "Error in reader thread: #{e.message}, #{e.backtrace.join("\n")}" sleep 1 @@ -145,9 +223,12 @@ def read_stdout_loop def handle_stdout_read if @stdout.closed? || @wait_thread.nil? || !@wait_thread.alive? - if @running - sleep 1 - restart_process + # Process is dead - if we're still running, this is an error + if running? + error = RubyLLM::MCP::Errors::TransportError.new( + message: "Process terminated unexpectedly" + ) + safe_close_with_error(error) end return end @@ -160,11 +241,10 @@ def handle_stdout_read def handle_stream_error(error, stream_name) # Check @running to distinguish graceful shutdown from unexpected errors. - # During shutdown, streams are closed intentionally and shouldn't trigger restarts. - if @running - RubyLLM::MCP.logger.error "#{stream_name} error: #{error.message}. Restarting in 1 second..." - sleep 1 - restart_process + # During shutdown, streams are closed intentionally and shouldn't trigger close. + if running? + RubyLLM::MCP.logger.error "#{stream_name} error: #{error.message}. Closing transport." + safe_close_with_error(error) else # Graceful shutdown in progress RubyLLM::MCP.logger.debug "#{stream_name} thread exiting during shutdown" @@ -175,17 +255,16 @@ def start_stderr_thread @stderr_thread = Thread.new do read_stderr_loop end - - @stderr_thread.abort_on_exception = true + # Don't use abort_on_exception - handle errors cooperatively end def read_stderr_loop - while @running + while running? begin handle_stderr_read rescue IOError, Errno::EPIPE => e handle_stream_error(e, "Stderr reader") - break unless @running + break unless running? rescue StandardError => e RubyLLM::MCP.logger.error "Error in stderr thread: #{e.message}" sleep 1 @@ -195,7 +274,6 @@ def read_stderr_loop def handle_stderr_read if @stderr.closed? || @wait_thread.nil? || !@wait_thread.alive? - sleep 1 return end diff --git a/lib/ruby_llm/mcp/native/transports/streamable_http.rb b/lib/ruby_llm/mcp/native/transports/streamable_http.rb index b938c0d..0f28a9e 100644 --- a/lib/ruby_llm/mcp/native/transports/streamable_http.rb +++ b/lib/ruby_llm/mcp/native/transports/streamable_http.rb @@ -24,7 +24,8 @@ def initialize( # Options for starting SSE connections class StartSSEOptions - attr_reader :resumption_token, :on_resumption_token, :replay_message_id + attr_accessor :resumption_token + attr_reader :on_resumption_token, :replay_message_id def initialize(resumption_token: nil, on_resumption_token: nil, replay_message_id: nil) @resumption_token = resumption_token @@ -39,7 +40,7 @@ class StreamableHTTP attr_reader :session_id, :protocol_version, :coordinator, :oauth_provider - def initialize( # rubocop:disable Metrics/ParameterLists + def initialize( # rubocop:disable Metrics/MethodLength, Metrics/ParameterLists url:, request_timeout:, coordinator:, @@ -50,6 +51,7 @@ def initialize( # rubocop:disable Metrics/ParameterLists rate_limit: nil, reconnection_options: nil, session_id: nil, + sse_timeout: nil, options: {} ) # Extract options if provided (for backward compatibility) @@ -61,22 +63,29 @@ def initialize( # rubocop:disable Metrics/ParameterLists reconnection_options = extracted_options.delete(:reconnection_options) || reconnection_options rate_limit = extracted_options.delete(:rate_limit) || rate_limit session_id = extracted_options.delete(:session_id) || session_id + sse_timeout = extracted_options.delete(:sse_timeout) || sse_timeout @url = URI(url) @coordinator = coordinator @request_timeout = request_timeout + @sse_timeout = sse_timeout @headers = headers || {} @session_id = session_id @version = version - @reconnection_options = reconnection_options || ReconnectionOptions.new @protocol_version = nil - @session_id = session_id - @resource_metadata_url = nil @client_id = SecureRandom.uuid - @reconnection_options = ReconnectionOptions.new(**reconnection) + # Reconnection options precedence: explicit > hash > defaults + @reconnection_options = if reconnection_options + reconnection_options + elsif reconnection && !reconnection.empty? + ReconnectionOptions.new(**reconnection) + else + ReconnectionOptions.new + end + @oauth_provider = oauth_provider @rate_limiter = Support::RateLimiter.new(**rate_limit) if rate_limit @@ -85,9 +94,11 @@ def initialize( # rubocop:disable Metrics/ParameterLists @pending_requests = {} @pending_mutex = Mutex.new @running = true - @abort_controller = nil + @sse_stopped = false + @state_mutex = Mutex.new @sse_thread = nil @sse_mutex = Mutex.new + @last_sse_event_id = nil # Thread-safe collection of all HTTPX clients @clients = [] @@ -121,7 +132,7 @@ def request(body, add_id: true, wait_for_response: true) end def alive? - @running + running? end def close @@ -131,15 +142,43 @@ def close end def start - @abort_controller = false + @state_mutex.synchronize do + @sse_stopped = false + end end def set_protocol_version(version) @protocol_version = version end + # Public hooks for SSE events (similar to TS transport) + def on_message(&block) + @on_message_callback = block + end + + def on_error(&block) + @on_error_callback = block + end + + def on_close(&block) + @on_close_callback = block + end + private + # Thread-safe check if transport is running and not stopped + def running? + @state_mutex.synchronize { @running && !@sse_stopped } + end + + # Thread-safe stop signal + def abort! + @state_mutex.synchronize do + @running = false + @sse_stopped = true + end + end + def terminate_session return unless @session_id @@ -308,7 +347,7 @@ def create_connection_with_streaming_callbacks(request_id) client = Support::HTTPClient.connection.plugin(:callbacks) .on_response_body_chunk do |request, _response, chunk| - next unless @running && !@abort_controller + next unless running? RubyLLM::MCP.logger.debug "Received chunk: #{chunk.bytesize} bytes for #{request.uri}" buffer << chunk @@ -475,17 +514,8 @@ def handle_session_expired ) end - def extract_resource_metadata_url(response) - # Extract resource metadata URL from response headers if present - # Guard against error responses that don't have headers - return nil unless response.respond_to?(:headers) - - metadata_url = response.headers["mcp-resource-metadata-url"] - metadata_url ? URI(metadata_url) : nil - end - def start_sse_stream(options = StartSSEOptions.new) - return unless @running && !@abort_controller + return unless running? @sse_mutex.synchronize do return if @sse_thread&.alive? @@ -541,12 +571,20 @@ def start_sse(options) # rubocop:disable Metrics/MethodLength RubyLLM::MCP.logger.error "SSE stream error: #{e.message}" # Attempt reconnection with exponential backoff - if @running && !@abort_controller && attempt_count < @reconnection_options.max_retries + if running? && attempt_count < @reconnection_options.max_retries delay = calculate_reconnection_delay(attempt_count) RubyLLM::MCP.logger.info "Reconnecting SSE stream in #{delay}ms..." sleep(delay / 1000.0) attempt_count += 1 + + # Create new options with the last event ID for resumption + options = StartSSEOptions.new( + resumption_token: @last_sse_event_id, + on_resumption_token: options.on_resumption_token, + replay_message_id: options.replay_message_id + ) + retry end @@ -558,14 +596,21 @@ def create_connection_with_sse_callbacks(options, headers) client = HTTPX.plugin(:callbacks) client = add_on_response_body_chunk_callback(client, options) - # Use request_timeout for all timeout values (converted from ms to seconds) - timeout_seconds = @request_timeout / 1000.0 + # Use sse_timeout if provided, otherwise use a very large timeout for SSE + # SSE connections are long-lived and should not timeout quickly + sse_timeout_seconds = if @sse_timeout + @sse_timeout / 1000.0 + else + # Default to 1 hour for SSE if not specified + 3600 + end + client = client.with( timeout: { connect_timeout: 10, - read_timeout: timeout_seconds, - write_timeout: timeout_seconds, - operation_timeout: timeout_seconds + read_timeout: sse_timeout_seconds, + write_timeout: sse_timeout_seconds, + operation_timeout: sse_timeout_seconds }, headers: headers ) @@ -583,7 +628,7 @@ def add_on_response_body_chunk_callback(client, options) buffer = +"" client.on_response_body_chunk do |request, response, chunk| # Only process chunks for text/event-stream and if still running - next unless @running && !@abort_controller + next unless running? if chunk.include?("event: stop") RubyLLM::MCP.logger.debug "Closing SSE stream" @@ -601,6 +646,7 @@ def add_on_response_body_chunk_callback(client, options) next unless raw_event && raw_event[:data] if raw_event[:id] + @last_sse_event_id = raw_event[:id] options.on_resumption_token&.call(raw_event[:id]) end @@ -618,14 +664,17 @@ def calculate_reconnection_delay(attempt) [initial * (factor**attempt), max_delay].min end - def process_sse_buffer_events(buffer, _request_id) - return unless @running && !@abort_controller + def process_sse_buffer_events(buffer, request_id) + return unless running? while (event_data = extract_sse_event(buffer)) raw_event, remaining_buffer = event_data buffer.replace(remaining_buffer) - process_sse_event(raw_event, nil) if raw_event && raw_event[:data] + if raw_event && raw_event[:data] + RubyLLM::MCP.logger.debug "Processing SSE buffer event for request #{request_id}" if request_id + process_sse_event(raw_event, nil) + end end end @@ -660,13 +709,18 @@ def parse_sse_event(raw) event end - def process_sse_event(raw_event, replay_message_id) + def process_sse_event(raw_event, replay_message_id) # rubocop:disable Metrics/MethodLength return unless raw_event[:data] - return unless @running && !@abort_controller + return unless running? begin event_data = JSON.parse(raw_event[:data]) + # Enhanced logging with event details + event_type = raw_event[:event] || "message" + event_id = raw_event[:id] + RubyLLM::MCP.logger.debug "Processing SSE event: type=#{event_type}, id=#{event_id || 'none'}" + # Handle replay message ID if specified if replay_message_id && event_data.is_a?(Hash) && event_data["id"] event_data["id"] = replay_message_id @@ -675,6 +729,9 @@ def process_sse_event(raw_event, replay_message_id) result = RubyLLM::MCP::Result.new(event_data, session_id: @session_id) RubyLLM::MCP.logger.debug "SSE Result Received: #{result.inspect}" + # Call on_message hook if registered + @on_message_callback&.call(result) + result = @coordinator.process_result(result) return if result.nil? @@ -682,15 +739,23 @@ def process_sse_event(raw_event, replay_message_id) if request_id @pending_mutex.synchronize do response_queue = @pending_requests.delete(request_id) - response_queue&.push(result) + if response_queue + RubyLLM::MCP.logger.debug "Matched SSE event to pending request: #{request_id}" + response_queue.push(result) + else + RubyLLM::MCP.logger.debug "No pending request found for SSE event: #{request_id}" + end end end rescue JSON::ParserError => e RubyLLM::MCP.logger.warn "Failed to parse SSE event data: #{raw_event[:data]} - #{e.message}" + @on_error_callback&.call(e) rescue Errors::UnknownRequest => e RubyLLM::MCP.logger.warn "Unknown request from MCP server: #{e.message}" + @on_error_callback&.call(e) rescue StandardError => e RubyLLM::MCP.logger.error "Error processing SSE event: #{e.message}" + @on_error_callback&.call(e) raise Errors::TransportError.new( message: "Error processing SSE event: #{e.message}", error: e @@ -699,9 +764,16 @@ def process_sse_event(raw_event, replay_message_id) end def wait_for_response_with_timeout(request_id, response_queue) - with_timeout(@request_timeout / 1000, request_id: request_id) do + result = with_timeout(@request_timeout / 1000, request_id: request_id) do response_queue.pop end + + # Check if we received a shutdown error sentinel + if result.is_a?(Errors::TransportError) + raise result + end + + result rescue RubyLLM::MCP::Errors::TimeoutError => e log_message = "StreamableHTTP request timeout (ID: #{request_id}) after #{@request_timeout / 1000} seconds" RubyLLM::MCP.logger.error(log_message) @@ -710,34 +782,38 @@ def wait_for_response_with_timeout(request_id, response_queue) end def cleanup_sse_resources - @running = false - @abort_controller = true + # Set shutdown flags under mutex + abort! + + # Call on_close hook if registered + @on_close_callback&.call + # Close all HTTPX clients to signal SSE thread to exit + close_all_clients + + # Wait for SSE thread to exit cooperatively @sse_mutex.synchronize do if @sse_thread&.alive? - @sse_thread.kill - @sse_thread.join(5) # Wait up to 5 seconds for thread to finish + # Try to join the thread first (cooperative shutdown) + unless @sse_thread.join(5) + # Only kill as last resort if join times out + RubyLLM::MCP.logger.warn "SSE thread did not exit cleanly, forcing termination" + @sse_thread.kill + @sse_thread.join(1) + end @sse_thread = nil end end - # Clear any pending requests - @pending_mutex.synchronize do - @pending_requests.each_value do |queue| - queue.close if queue.respond_to?(:close) - rescue StandardError - # Ignore errors when closing queues - end - @pending_requests.clear - end + # Drain pending requests with error instead of closing queues + drain_pending_requests_with_error end - def cleanup_connection + def close_all_clients clients_to_close = [] @clients_mutex.synchronize do clients_to_close = @clients.dup - @clients.clear end clients_to_close.each do |client| @@ -745,9 +821,33 @@ def cleanup_connection rescue StandardError => e RubyLLM::MCP.logger.debug "Error closing HTTPX client: #{e.message}" end + end + + def cleanup_connection + close_all_clients + + @clients_mutex.synchronize do + @clients.clear + end @connection = nil end + + def drain_pending_requests_with_error + shutdown_error = Errors::TransportError.new( + message: "Transport is shutting down", + code: nil + ) + + @pending_mutex.synchronize do + @pending_requests.each_value do |queue| + queue.push(shutdown_error) + rescue StandardError => e + RubyLLM::MCP.logger.debug "Error pushing shutdown error to queue: #{e.message}" + end + @pending_requests.clear + end + end end end end diff --git a/spec/ruby_llm/mcp/native/transports/sse_spec.rb b/spec/ruby_llm/mcp/native/transports/sse_spec.rb index e61c9b0..a55bda1 100644 --- a/spec/ruby_llm/mcp/native/transports/sse_spec.rb +++ b/spec/ruby_llm/mcp/native/transports/sse_spec.rb @@ -84,6 +84,292 @@ def client end end + describe "thread safety and lifecycle" do + let(:coordinator) { instance_double(RubyLLM::MCP::Adapters::MCPTransports::CoordinatorStub) } + let(:transport) do + RubyLLM::MCP::Native::Transports::SSE.new( + url: "http://localhost:3000/sse", + coordinator: coordinator, + request_timeout: 5000 + ) + end + + describe "#running?" do + it "safely checks running state with mutex" do + expect(transport.running?).to be(false) # Not started yet + + # Simulate concurrent access + threads = 10.times.map do + Thread.new { transport.running? } + end + + results = threads.map(&:value) + expect(results).to all(be(false)) + end + end + + describe "#close" do + it "handles multiple close calls gracefully" do + expect { transport.close }.not_to raise_error + expect { transport.close }.not_to raise_error + expect(transport.running?).to be(false) + end + + it "resets messages_url on close" do + transport.instance_variable_set(:@messages_url, "http://test.com/messages") + transport.instance_variable_set(:@running, true) + + transport.close + + expect(transport.instance_variable_get(:@messages_url)).to be_nil + end + + it "is idempotent when called multiple times" do + transport.instance_variable_set(:@running, true) + + expect { 3.times { transport.close } }.not_to raise_error + expect(transport.running?).to be(false) + end + end + end + + describe "pending request cleanup" do + let(:coordinator) { instance_double(RubyLLM::MCP::Adapters::MCPTransports::CoordinatorStub) } + let(:transport) do + RubyLLM::MCP::Native::Transports::SSE.new( + url: "http://localhost:3000/sse", + coordinator: coordinator, + request_timeout: 5000 + ) + end + + describe "#fail_pending_requests!" do + it "pushes error to all pending request queues" do + queue1 = Queue.new + queue2 = Queue.new + + transport.instance_variable_get(:@pending_requests)["1"] = queue1 + transport.instance_variable_get(:@pending_requests)["2"] = queue2 + + error = RubyLLM::MCP::Errors::TransportError.new( + message: "Test error", + code: nil + ) + + transport.send(:fail_pending_requests!, error) + + expect(queue1.pop).to eq(error) + expect(queue2.pop).to eq(error) + expect(transport.instance_variable_get(:@pending_requests)).to be_empty + end + + it "clears all pending requests after pushing errors" do + 3.times do |i| + transport.instance_variable_get(:@pending_requests)[i.to_s] = Queue.new + end + + error = RubyLLM::MCP::Errors::TransportError.new(message: "Test", code: nil) + transport.send(:fail_pending_requests!, error) + + expect(transport.instance_variable_get(:@pending_requests)).to be_empty + end + end + + describe "#close" do + it "fails all pending requests with transport error" do + queue = Queue.new + transport.instance_variable_get(:@pending_requests)["test"] = queue + transport.instance_variable_set(:@running, true) + + result_thread = Thread.new { queue.pop } + sleep(0.05) # Let thread start waiting + + transport.close + + result = result_thread.value + expect(result).to be_a(RubyLLM::MCP::Errors::TransportError) + expect(result.message).to include("closed") + end + + it "wakes up multiple waiting requests" do + queues = 3.times.map do |i| + queue = Queue.new + transport.instance_variable_get(:@pending_requests)[i.to_s] = queue + queue + end + transport.instance_variable_set(:@running, true) + + threads = queues.map { |q| Thread.new { q.pop } } + sleep(0.05) # Let threads start waiting + + transport.close + + results = threads.map(&:value) + expect(results).to all(be_a(RubyLLM::MCP::Errors::TransportError)) + end + end + end + + describe "endpoint bootstrapping" do + let(:coordinator) { instance_double(RubyLLM::MCP::Adapters::MCPTransports::CoordinatorStub) } + let(:transport) do + RubyLLM::MCP::Native::Transports::SSE.new( + url: "http://localhost:3000/sse", + coordinator: coordinator, + request_timeout: 5000 + ) + end + + describe "#set_message_endpoint" do + it "handles string endpoints" do + transport.send(:set_message_endpoint, "/messages") + expect(transport.instance_variable_get(:@messages_url)).to eq("http://localhost:3000/messages") + end + + it "handles JSON payloads with url key (string)" do + transport.send(:set_message_endpoint, { "url" => "/api/messages" }) + expect(transport.instance_variable_get(:@messages_url)).to eq("http://localhost:3000/api/messages") + end + + it "handles JSON payloads with url key (symbol)" do + transport.send(:set_message_endpoint, { url: "/api/messages" }) + expect(transport.instance_variable_get(:@messages_url)).to eq("http://localhost:3000/api/messages") + end + + it "raises error for missing URL in hash" do + expect do + transport.send(:set_message_endpoint, { "other" => "value" }) + end.to raise_error(RubyLLM::MCP::Errors::TransportError, /missing URL/) + end + + it "raises error for invalid URI" do + expect do + transport.send(:set_message_endpoint, "ht!tp://invalid") + end.to raise_error(RubyLLM::MCP::Errors::TransportError, /Invalid endpoint URL/) + end + + it "handles absolute URLs" do + transport.send(:set_message_endpoint, "http://other.com:8080/messages") + expect(transport.instance_variable_get(:@messages_url)).to eq("http://other.com:8080/messages") + end + + it "raises error for empty string" do + expect do + transport.send(:set_message_endpoint, "") + end.to raise_error(RubyLLM::MCP::Errors::TransportError, /missing URL/) + end + + it "raises error for nil" do + expect do + transport.send(:set_message_endpoint, nil) + end.to raise_error(RubyLLM::MCP::Errors::TransportError, /missing URL/) + end + end + + describe "#process_endpoint_event" do + it "parses JSON endpoint data" do + raw_event = { data: '{"url": "/messages", "last_event_id": "123"}' } + + queue = Queue.new + transport.instance_variable_get(:@pending_requests)["endpoint"] = queue + + result_thread = Thread.new { queue.pop } + + transport.send(:process_endpoint_event, raw_event) + + result = result_thread.value + expect(result).to be_a(Hash) + expect(result["url"]).to eq("/messages") + expect(result["last_event_id"]).to eq("123") + end + + it "falls back to string for non-JSON data" do + raw_event = { data: "/messages" } + + queue = Queue.new + transport.instance_variable_get(:@pending_requests)["endpoint"] = queue + + result_thread = Thread.new { queue.pop } + + transport.send(:process_endpoint_event, raw_event) + + result = result_thread.value + expect(result).to eq("/messages") + end + + it "removes endpoint from pending requests after processing" do + raw_event = { data: "/messages" } + + queue = Queue.new + transport.instance_variable_get(:@pending_requests)["endpoint"] = queue + + result_thread = Thread.new { queue.pop } + + transport.send(:process_endpoint_event, raw_event) + result_thread.value + + expect(transport.instance_variable_get(:@pending_requests)).not_to have_key("endpoint") + end + end + end + + describe "event processing improvements" do + let(:coordinator) { instance_double(RubyLLM::MCP::Adapters::MCPTransports::CoordinatorStub) } + let(:transport) do + RubyLLM::MCP::Native::Transports::SSE.new( + url: "http://localhost:3000/sse", + coordinator: coordinator, + request_timeout: 5000 + ) + end + + before do + allow(RubyLLM::MCP.logger).to receive(:debug) + allow(RubyLLM::MCP.logger).to receive(:info) + allow(RubyLLM::MCP.logger).to receive(:error) + end + + describe "#process_message_event" do + it "logs at debug level for parse errors when messages_url is set" do + transport.instance_variable_set(:@messages_url, "http://test.com/messages") + raw_event = { data: "invalid json" } + + transport.send(:process_message_event, raw_event) + + expect(RubyLLM::MCP.logger).to have_received(:debug).with(/Failed to parse SSE event data/) + end + + it "does not log parse errors when messages_url is not set" do + transport.instance_variable_set(:@messages_url, nil) + raw_event = { data: "invalid json" } + + transport.send(:process_message_event, raw_event) + + # Should not log the specific parse error message + expect(RubyLLM::MCP.logger).not_to have_received(:debug).with(/Failed to parse SSE event data/) + end + + it "processes valid JSON events" do + raw_event = { data: '{"id": "123", "result": {"success": true}}' } + result = instance_double(RubyLLM::MCP::Result) + + allow(RubyLLM::MCP::Result).to receive(:new).and_return(result) + allow(result).to receive(:matching_id?).with("123").and_return(true) + allow(coordinator).to receive(:process_result).and_return(result) + + queue = Queue.new + transport.instance_variable_get(:@pending_requests)["123"] = queue + + result_thread = Thread.new { queue.pop } + + transport.send(:process_message_event, raw_event) + + received = result_thread.value + expect(received).to eq(result) + end + end + end + describe "#parse_event" do let(:transport) do RubyLLM::MCP::Native::Transports::SSE.new( diff --git a/spec/ruby_llm/mcp/native/transports/stdio_spec.rb b/spec/ruby_llm/mcp/native/transports/stdio_spec.rb index 1f50eca..f8476b7 100644 --- a/spec/ruby_llm/mcp/native/transports/stdio_spec.rb +++ b/spec/ruby_llm/mcp/native/transports/stdio_spec.rb @@ -20,6 +20,7 @@ transport.instance_variable_set(:@id_mutex, Mutex.new) transport.instance_variable_set(:@pending_requests, {}) transport.instance_variable_set(:@pending_mutex, Mutex.new) + transport.instance_variable_set(:@state_mutex, Mutex.new) transport.instance_variable_set(:@running, true) transport.instance_variable_set(:@stdin, mock_stdin) transport.instance_variable_set(:@stdout, mock_stdout) @@ -40,7 +41,8 @@ allow(mock_stdout).to receive(:close) allow(mock_stderr).to receive(:close) allow(mock_wait_thread).to receive(:join).with(1) - allow(mock_wait_thread).to receive(:alive?).and_return(true) + allow(mock_wait_thread).to receive(:join).with(2) + allow(mock_wait_thread).to receive_messages(alive?: false, pid: 12_345) allow(coordinator).to receive(:process_result) end @@ -72,11 +74,27 @@ end it "returns false when transport is not running" do - mock_transport.instance_variable_set(:@running, false) + mock_transport.instance_variable_get(:@state_mutex).synchronize do + mock_transport.instance_variable_set(:@running, false) + end expect(mock_transport.alive?).to be(false) end end + describe "#running?" do + it "safely checks running state with mutex" do + mock_transport.instance_variable_get(:@state_mutex).synchronize do + mock_transport.instance_variable_set(:@running, true) + end + expect(mock_transport.running?).to be(true) + + mock_transport.instance_variable_get(:@state_mutex).synchronize do + mock_transport.instance_variable_set(:@running, false) + end + expect(mock_transport.running?).to be(false) + end + end + describe "#close" do it "sets running to false" do mock_transport.close @@ -143,7 +161,6 @@ it "raises TransportError on IOError" do request_body = { "method" => "test" } allow(mock_stdin).to receive(:puts).and_raise(IOError.new("Broken pipe")) - allow(mock_transport).to receive(:restart_process) expect { mock_transport.request(request_body) }.to raise_error(RubyLLM::MCP::Errors::TransportError) do |error| expect(error.message).to include("Broken pipe") @@ -154,7 +171,6 @@ it "raises TransportError on EPIPE" do request_body = { "method" => "test" } allow(mock_stdin).to receive(:puts).and_raise(Errno::EPIPE.new("Broken pipe")) - allow(mock_transport).to receive(:restart_process) expect { mock_transport.request(request_body) }.to raise_error(RubyLLM::MCP::Errors::TransportError) do |error| expect(error.message).to include("Broken pipe") @@ -162,6 +178,14 @@ end end + it "raises ArgumentError when request_id is nil with wait_for_response" do + request_body = { "method" => "test" } + + expect do + mock_transport.request(request_body, add_id: false, wait_for_response: true) + end.to raise_error(ArgumentError, /Request ID must be provided/) + end + it "raises TimeoutError when request times out" do request_body = { "method" => "test" } allow(mock_stdin).to receive(:puts) @@ -279,12 +303,32 @@ expect(transport.instance_variable_get(:@args)).to eq(%w[arg1 arg2]) end - it "stores environment variables correctly" do - transport = described_class.allocate - test_env = { "TEST_VAR" => "test_value" } - transport.instance_variable_set(:@env, test_env) + it "merges user environment variables with default environment" do + transport = described_class.new( + command: "echo", + coordinator: coordinator, + request_timeout: 5000, + env: { "TEST_VAR" => "test_value", "PATH" => "/custom/path" } + ) + + stored_env = transport.instance_variable_get(:@env) + expect(stored_env["TEST_VAR"]).to eq("test_value") + expect(stored_env["PATH"]).to eq("/custom/path") # User override takes precedence + # Should still have other default env vars + expect(stored_env.keys.length).to be > 2 + end + + it "uses default environment when no custom env provided" do + transport = described_class.new( + command: "echo", + coordinator: coordinator, + request_timeout: 5000 + ) - expect(transport.instance_variable_get(:@env)).to eq(test_env) + stored_env = transport.instance_variable_get(:@env) + # Should have default environment variables + expect(stored_env.keys.length).to be > 0 + expect(stored_env).to include(ENV.to_h) end it "stores request timeout correctly" do @@ -295,12 +339,40 @@ end end - describe "process restart behavior" do - it "can handle restart scenarios" do - allow(mock_transport).to receive(:start_process) - allow(mock_transport).to receive(:close) + describe "error handling and closing" do + it "fails pending requests when closing with error" do + queue1 = Queue.new + queue2 = Queue.new + mock_transport.instance_variable_get(:@pending_requests)["1"] = queue1 + mock_transport.instance_variable_get(:@pending_requests)["2"] = queue2 + + error = RubyLLM::MCP::Errors::TransportError.new(message: "Test error") + mock_transport.send(:fail_pending_requests!, error) - expect { mock_transport.send(:restart_process) }.not_to raise_error + expect(queue1.pop).to eq(error) + expect(queue2.pop).to eq(error) + expect(mock_transport.instance_variable_get(:@pending_requests)).to be_empty + end + + it "closes transport on stream error when running" do + allow(mock_transport).to receive(:running?).and_return(true) + allow(mock_transport).to receive(:safe_close_with_error) + + error = IOError.new("Test error") + mock_transport.send(:handle_stream_error, error, "Test stream") + + expect(mock_transport).to have_received(:safe_close_with_error).with(error) + end + + it "does not close transport on stream error when not running" do + allow(mock_transport).to receive(:running?).and_return(false) + allow(mock_transport).to receive(:safe_close_with_error) + allow(RubyLLM::MCP.logger).to receive(:debug) + + error = IOError.new("Test error") + mock_transport.send(:handle_stream_error, error, "Test stream") + + expect(mock_transport).not_to have_received(:safe_close_with_error) end end @@ -308,7 +380,7 @@ let(:real_stdin) { IO.pipe[1] } let(:real_stdout) { IO.pipe[0] } let(:real_stderr) { IO.pipe[0] } - let(:real_wait_thread) { instance_double(Process::Waiter, alive?: true, join: nil) } + let(:real_wait_thread) { instance_double(Process::Waiter, alive?: true, join: nil, pid: 12_345) } let(:transport_with_threads) do transport = described_class.allocate @@ -322,6 +394,7 @@ transport.instance_variable_set(:@id_mutex, Mutex.new) transport.instance_variable_set(:@pending_requests, {}) transport.instance_variable_set(:@pending_mutex, Mutex.new) + transport.instance_variable_set(:@state_mutex, Mutex.new) transport.instance_variable_set(:@running, true) transport.instance_variable_set(:@stdin, real_stdin) transport.instance_variable_set(:@stdout, real_stdout) @@ -364,7 +437,9 @@ allow(RubyLLM::MCP.logger).to receive(:error) { error_calls += 1 } allow(RubyLLM::MCP.logger).to receive(:debug) { debug_calls += 1 } - transport_with_threads.instance_variable_set(:@running, false) + transport_with_threads.instance_variable_get(:@state_mutex).synchronize do + transport_with_threads.instance_variable_set(:@running, false) + end real_stdout.close sleep 0.2 @@ -382,7 +457,9 @@ allow(RubyLLM::MCP.logger).to receive(:error) { error_calls += 1 } allow(RubyLLM::MCP.logger).to receive(:debug) { debug_calls += 1 } allow(RubyLLM::MCP.logger).to receive(:info) - transport_with_threads.instance_variable_set(:@running, false) + transport_with_threads.instance_variable_get(:@state_mutex).synchronize do + transport_with_threads.instance_variable_set(:@running, false) + end real_stderr.close sleep 0.2 @@ -411,7 +488,7 @@ expect(stderr_thread.join(2)).to eq(stderr_thread), "Stderr thread should exit cleanly" end - it "logs errors and restarts when @running is true and stream closes unexpectedly" do + it "closes transport when @running is true and stream closes unexpectedly" do transport_with_threads.send(:start_reader_thread) # Give thread time to start sleep 0.1 @@ -422,17 +499,19 @@ end allow(RubyLLM::MCP.logger).to receive(:debug) - restart_called = false - allow(transport_with_threads).to receive(:restart_process) do - restart_called = true - transport_with_threads.instance_variable_set(:@running, false) + close_called = false + allow(transport_with_threads).to receive(:safe_close_with_error) do |_error| + close_called = true + transport_with_threads.instance_variable_get(:@state_mutex).synchronize do + transport_with_threads.instance_variable_set(:@running, false) + end end real_stdout.close deadline = Time.now + 2 - sleep 0.1 until restart_called || Time.now > deadline + sleep 0.1 until close_called || Time.now > deadline - expect(restart_called).to be(true), "Expected restart_process to be called within 2 seconds" + expect(close_called).to be(true), "Expected safe_close_with_error to be called within 2 seconds" end end end diff --git a/spec/ruby_llm/mcp/native/transports/streamable_http_spec.rb b/spec/ruby_llm/mcp/native/transports/streamable_http_spec.rb index 54d80ad..49c947a 100644 --- a/spec/ruby_llm/mcp/native/transports/streamable_http_spec.rb +++ b/spec/ruby_llm/mcp/native/transports/streamable_http_spec.rb @@ -360,9 +360,9 @@ end end - it "respects abort controller in SSE processing" do + it "respects sse_stopped flag in SSE processing" do allow(mock_coordinator).to receive(:process_result) - transport.instance_variable_set(:@abort_controller, true) + transport.instance_variable_set(:@sse_stopped, true) raw_event = { data: '{"method": "test"}' } @@ -382,9 +382,9 @@ expect(mock_coordinator).not_to have_received(:process_result) end - it "handles SSE buffer events with abort controller" do + it "handles SSE buffer events with sse_stopped flag" do allow(transport).to receive(:extract_sse_event) - transport.instance_variable_set(:@abort_controller, true) + transport.instance_variable_set(:@sse_stopped, true) buffer = +"data: test\n\n" @@ -651,8 +651,8 @@ end.to raise_error(RubyLLM::MCP::Errors::TransportError, /Failed to open SSE stream: 400/) end - it "stops retrying when abort controller is set" do - transport.instance_variable_set(:@abort_controller, true) + it "stops retrying when sse_stopped flag is set" do + transport.instance_variable_set(:@sse_stopped, true) stub_request(:get, TestServerManager::HTTP_SERVER_URL) .with(headers: { "Accept" => "text/event-stream" }) @@ -1195,4 +1195,567 @@ end end end + + describe "reconnection options precedence" do + it "uses explicit reconnection_options when provided" do + explicit_options = RubyLLM::MCP::Native::Transports::ReconnectionOptions.new( + max_retries: 5, + initial_reconnection_delay: 500 + ) + + transport = described_class.new( + url: TestServerManager::HTTP_SERVER_URL, + coordinator: mock_coordinator, + request_timeout: 5000, + reconnection_options: explicit_options, + options: { reconnection: { max_retries: 1 } } + ) + + reconnection_opts = transport.instance_variable_get(:@reconnection_options) + expect(reconnection_opts.max_retries).to eq(5) + expect(reconnection_opts.initial_reconnection_delay).to eq(500) + end + + it "uses reconnection hash when reconnection_options not provided" do + transport = described_class.new( + url: TestServerManager::HTTP_SERVER_URL, + coordinator: mock_coordinator, + request_timeout: 5000, + options: { reconnection: { max_retries: 3, initial_reconnection_delay: 200 } } + ) + + reconnection_opts = transport.instance_variable_get(:@reconnection_options) + expect(reconnection_opts.max_retries).to eq(3) + expect(reconnection_opts.initial_reconnection_delay).to eq(200) + end + + it "uses defaults when neither reconnection_options nor reconnection provided" do + transport = described_class.new( + url: TestServerManager::HTTP_SERVER_URL, + coordinator: mock_coordinator, + request_timeout: 5000, + options: {} + ) + + reconnection_opts = transport.instance_variable_get(:@reconnection_options) + expect(reconnection_opts.max_retries).to eq(2) + expect(reconnection_opts.initial_reconnection_delay).to eq(1_000) + end + + it "uses defaults when reconnection hash is empty" do + transport = described_class.new( + url: TestServerManager::HTTP_SERVER_URL, + coordinator: mock_coordinator, + request_timeout: 5000, + options: { reconnection: {} } + ) + + reconnection_opts = transport.instance_variable_get(:@reconnection_options) + expect(reconnection_opts.max_retries).to eq(2) + expect(reconnection_opts.initial_reconnection_delay).to eq(1_000) + end + end + + describe "resumable SSE with last event ID tracking" do + before do + WebMock.enable! + end + + after do + WebMock.reset! + WebMock.enable! + end + + it "tracks last SSE event ID" do + transport = described_class.new( + url: TestServerManager::HTTP_SERVER_URL, + coordinator: mock_coordinator, + request_timeout: 5000, + options: {} + ) + + expect(transport.instance_variable_get(:@last_sse_event_id)).to be_nil + + # Simulate SSE event with ID + raw_event = { data: '{"method": "test"}', id: "event-123" } + RubyLLM::MCP::Native::Transports::StartSSEOptions.new + + allow(mock_coordinator).to receive(:process_result) + transport.send(:process_sse_event, raw_event, nil) + + # Last event ID should not be tracked in process_sse_event, only in callback + # This is tracked in add_on_response_body_chunk_callback + end + + it "includes last event ID in reconnection headers" do + transport = described_class.new( + url: TestServerManager::HTTP_SERVER_URL, + coordinator: mock_coordinator, + request_timeout: 1000, + options: { reconnection: { max_retries: 1 } } + ) + + # Set a last event ID + transport.instance_variable_set(:@last_sse_event_id, "event-456") + + # Mock the connection to fail once, then succeed + call_count = 0 + stub_request(:get, TestServerManager::HTTP_SERVER_URL) + .with(headers: { "Accept" => "text/event-stream" }) + .to_return do |request| + call_count += 1 + if call_count == 1 + { status: 500 } + else + # Check that Last-Event-ID header is present + expect(request.headers["Last-Event-Id"]).to eq("event-456") + { status: 200, headers: { "Content-Type" => "text/event-stream" } } + end + end + + options = RubyLLM::MCP::Native::Transports::StartSSEOptions.new + transport.send(:start_sse, options) + end + end + + describe "separate timeouts for requests vs SSE" do + it "uses request_timeout for regular requests" do + transport = described_class.new( + url: TestServerManager::HTTP_SERVER_URL, + coordinator: mock_coordinator, + request_timeout: 3000, + options: {} + ) + + connection = transport.instance_variable_get(:@connection) + timeout_config = connection.instance_variable_get(:@options).timeout + + expect(timeout_config[:read_timeout]).to eq(3.0) + end + + it "uses sse_timeout for SSE connections when provided" do + transport = described_class.new( + url: TestServerManager::HTTP_SERVER_URL, + coordinator: mock_coordinator, + request_timeout: 3000, + sse_timeout: 10_000, + options: {} + ) + + expect(transport.instance_variable_get(:@sse_timeout)).to eq(10_000) + end + + it "uses default long timeout for SSE when sse_timeout not provided" do + transport = described_class.new( + url: TestServerManager::HTTP_SERVER_URL, + coordinator: mock_coordinator, + request_timeout: 3000, + options: {} + ) + + expect(transport.instance_variable_get(:@sse_timeout)).to be_nil + # Default should be 1 hour (3600 seconds) in create_connection_with_sse_callbacks + end + end + + describe "SSE state management" do + it "uses sse_stopped instead of abort_controller" do + transport = described_class.new( + url: TestServerManager::HTTP_SERVER_URL, + coordinator: mock_coordinator, + request_timeout: 5000, + options: {} + ) + + expect(transport.instance_variable_get(:@sse_stopped)).to be(false) + expect(transport.send(:running?)).to be(true) + + transport.send(:abort!) + + expect(transport.instance_variable_get(:@sse_stopped)).to be(true) + expect(transport.send(:running?)).to be(false) + end + + it "provides on_message hook" do + transport = described_class.new( + url: TestServerManager::HTTP_SERVER_URL, + coordinator: mock_coordinator, + request_timeout: 5000, + options: {} + ) + + messages = [] + transport.on_message { |msg| messages << msg } + + allow(mock_coordinator).to receive(:process_result).and_return(nil) + + raw_event = { data: '{"method": "test"}' } + transport.send(:process_sse_event, raw_event, nil) + + expect(messages.size).to eq(1) + expect(messages.first).to be_a(RubyLLM::MCP::Result) + end + + it "provides on_error hook" do + transport = described_class.new( + url: TestServerManager::HTTP_SERVER_URL, + coordinator: mock_coordinator, + request_timeout: 5000, + options: {} + ) + + errors = [] + transport.on_error { |err| errors << err } + + raw_event = { data: "invalid json" } + transport.send(:process_sse_event, raw_event, nil) + + expect(errors.size).to eq(1) + expect(errors.first).to be_a(JSON::ParserError) + end + + it "provides on_close hook" do + transport = described_class.new( + url: TestServerManager::HTTP_SERVER_URL, + coordinator: mock_coordinator, + request_timeout: 5000, + options: {} + ) + + close_called = false + transport.on_close { close_called = true } + + transport.send(:cleanup_sse_resources) + + expect(close_called).to be(true) + end + end + + describe "enhanced SSE event logging" do + it "logs event type and ID when processing SSE events" do + transport = described_class.new( + url: TestServerManager::HTTP_SERVER_URL, + coordinator: mock_coordinator, + request_timeout: 5000, + options: {} + ) + + allow(mock_coordinator).to receive(:process_result).and_return(nil) + + raw_event = { data: '{"method": "test"}', event: "notification", id: "evt-789" } + transport.send(:process_sse_event, raw_event, nil) + + expect(logger).to have_received(:debug).with(/Processing SSE event: type=notification, id=evt-789/) + end + + it "logs when SSE event matches pending request" do + transport = described_class.new( + url: TestServerManager::HTTP_SERVER_URL, + coordinator: mock_coordinator, + request_timeout: 5000, + options: {} + ) + + request_id = "req-999" + response_queue = Queue.new + transport.instance_variable_get(:@pending_mutex).synchronize do + transport.instance_variable_get(:@pending_requests)[request_id] = response_queue + end + + mock_result = instance_double(RubyLLM::MCP::Result) + allow(mock_result).to receive(:id).and_return(request_id) + allow(mock_coordinator).to receive(:process_result).and_return(mock_result) + + raw_event = { data: "{\"id\": \"#{request_id}\"}" } + + # Start thread to consume the queue + Thread.new { response_queue.pop } + sleep(0.1) + + transport.send(:process_sse_event, raw_event, nil) + + expect(logger).to have_received(:debug).with(/Matched SSE event to pending request: #{request_id}/) + end + + it "logs when no pending request found for SSE event" do + transport = described_class.new( + url: TestServerManager::HTTP_SERVER_URL, + coordinator: mock_coordinator, + request_timeout: 5000, + options: {} + ) + + request_id = "req-888" + mock_result = instance_double(RubyLLM::MCP::Result) + allow(mock_result).to receive(:id).and_return(request_id) + allow(mock_coordinator).to receive(:process_result).and_return(mock_result) + + raw_event = { data: "{\"id\": \"#{request_id}\"}" } + transport.send(:process_sse_event, raw_event, nil) + + expect(logger).to have_received(:debug).with(/No pending request found for SSE event: #{request_id}/) + end + end + + describe "thread safety improvements" do + describe "state flag synchronization" do + it "provides thread-safe running? check" do + expect(transport.send(:running?)).to be(true) + + # Simulate concurrent access + threads = 10.times.map do + Thread.new { transport.send(:running?) } + end + + results = threads.map(&:value) + expect(results).to all(be(true)) + end + + it "provides thread-safe abort! method" do + expect(transport.send(:running?)).to be(true) + + transport.send(:abort!) + + expect(transport.send(:running?)).to be(false) + expect(transport).not_to be_alive + end + + it "guards chunk callbacks when flags flip mid-stream" do + # Set up a mock callback scenario + allow(mock_coordinator).to receive(:process_result) + + # Simulate running state + expect(transport.send(:running?)).to be(true) + + # Now abort + transport.send(:abort!) + + # Callbacks should respect the running? check + raw_event = { data: '{"method": "test"}' } + transport.send(:process_sse_event, raw_event, nil) + + # Should not process when not running + expect(mock_coordinator).not_to have_received(:process_result) + end + + it "handles concurrent state changes safely" do + threads = [] + + # Multiple threads trying to check state + 5.times do + threads << Thread.new { transport.send(:running?) } + end + + # One thread trying to abort + threads << Thread.new { transport.send(:abort!) } + + # More threads checking state + 5.times do + threads << Thread.new { transport.send(:running?) } + end + + # Should not raise any errors + expect { threads.each(&:join) }.not_to raise_error + end + end + + describe "cooperative SSE shutdown" do + let(:mock_thread) { instance_double(Thread) } + + before do + WebMock.enable! + end + + after do + WebMock.reset! + WebMock.enable! + end + + it "attempts cooperative join before killing thread" do + # Set up a mock SSE thread + allow(mock_thread).to receive(:alive?).and_return(true) + allow(mock_thread).to receive(:join).with(5).and_return(mock_thread) + transport.instance_variable_set(:@sse_thread, mock_thread) + + transport.send(:cleanup_sse_resources) + + # Should have called join (cooperative shutdown) + expect(mock_thread).to have_received(:join).with(5) + end + + it "uses kill only as fallback when join times out" do + # Set up a mock SSE thread that doesn't join + allow(mock_thread).to receive(:alive?).and_return(true) + allow(mock_thread).to receive(:join).with(5).and_return(nil) # Timeout + allow(mock_thread).to receive(:join).with(1).and_return(mock_thread) + allow(mock_thread).to receive(:kill) + transport.instance_variable_set(:@sse_thread, mock_thread) + + transport.send(:cleanup_sse_resources) + + # Should have tried join first, then killed + expect(mock_thread).to have_received(:join).with(5) + expect(mock_thread).to have_received(:kill) + expect(logger).to have_received(:warn).with(/SSE thread did not exit cleanly/) + end + + it "closes all clients during cleanup to signal SSE thread" do + # Track client closing + client_count_before = transport.send(:active_clients_count) + expect(client_count_before).to be > 0 + + transport.send(:cleanup_sse_resources) + + # Clients should be closed (but not cleared yet - that's in cleanup_connection) + # The close_all_clients method should have been called + expect(transport.send(:active_clients_count)).to be > 0 # Not cleared yet + end + + it "sets abort flag under mutex during cleanup" do + expect(transport.send(:running?)).to be(true) + + transport.send(:cleanup_sse_resources) + + expect(transport.send(:running?)).to be(false) + end + end + + describe "pending request teardown with error sentinels" do + let(:request_id) { "test-request-123" } + let(:response_queue) { Queue.new } + + before do + WebMock.enable! + transport.instance_variable_get(:@pending_mutex).synchronize do + transport.instance_variable_get(:@pending_requests)[request_id] = response_queue + end + end + + after do + WebMock.reset! + WebMock.enable! + end + + it "pushes error object instead of closing queues" do + # Start a thread waiting on the queue + result_thread = Thread.new do + response_queue.pop + end + + # Give the thread time to start waiting + sleep(0.1) + + # Cleanup should push an error + transport.send(:drain_pending_requests_with_error) + + # The waiting thread should receive an error object + result = result_thread.value + expect(result).to be_a(RubyLLM::MCP::Errors::TransportError) + expect(result.message).to include("shutting down") + end + + it "does not raise ClosedQueueError" do + # Start a thread waiting on the queue + result_thread = Thread.new do + response_queue.pop + rescue ClosedQueueError + :closed_queue_error + end + + # Give the thread time to start waiting + sleep(0.1) + + # Cleanup should push an error, not close the queue + transport.send(:drain_pending_requests_with_error) + + result = result_thread.value + expect(result).not_to eq(:closed_queue_error) + expect(result).to be_a(RubyLLM::MCP::Errors::TransportError) + end + + it "clears all pending requests after pushing errors" do + pending_requests = transport.instance_variable_get(:@pending_requests) + expect(pending_requests).to have_key(request_id) + + transport.send(:drain_pending_requests_with_error) + + expect(pending_requests).to be_empty + end + + it "handles error when pushing to queue fails" do + # Create a queue that will raise an error + bad_queue = Queue.new + allow(bad_queue).to receive(:push).and_raise(StandardError.new("Queue error")) + + transport.instance_variable_get(:@pending_mutex).synchronize do + transport.instance_variable_get(:@pending_requests)["bad-request"] = bad_queue + end + + # Should not raise, just log + expect { transport.send(:drain_pending_requests_with_error) }.not_to raise_error + expect(logger).to have_received(:debug).with(/Error pushing shutdown error/) + end + + it "wait_for_response_with_timeout raises shutdown error sentinel" do + stub_request(:post, TestServerManager::HTTP_SERVER_URL) + .to_return( + status: 200, + headers: { "Content-Type" => "application/json" }, + body: '{"result": "ok"}' + ) + + # Push a shutdown error to the queue + shutdown_error = RubyLLM::MCP::Errors::TransportError.new( + message: "Transport is shutting down", + code: nil + ) + response_queue.push(shutdown_error) + + expect do + transport.send(:wait_for_response_with_timeout, request_id, response_queue) + end.to raise_error(RubyLLM::MCP::Errors::TransportError, /shutting down/) + end + end + + describe "full shutdown flow integration" do + before do + WebMock.enable! + end + + after do + WebMock.reset! + WebMock.enable! + end + + it "performs complete shutdown sequence correctly" do + # Add some pending requests + request_queue = Queue.new + transport.instance_variable_get(:@pending_mutex).synchronize do + transport.instance_variable_get(:@pending_requests)["req-1"] = request_queue + end + + # Verify initial state + expect(transport.send(:running?)).to be(true) + expect(transport.send(:active_clients_count)).to be > 0 + + # Perform full close + transport.close + + # Verify final state + expect(transport.send(:running?)).to be(false) + expect(transport.send(:active_clients_count)).to eq(0) + + # Pending requests should be cleared + pending = transport.instance_variable_get(:@pending_requests) + expect(pending).to be_empty + end + + it "handles close when already closed" do + transport.close + + # Second close should not raise + expect { transport.close }.not_to raise_error + end + end + end end From dcfbb6fb6eb2b83f9b7e45d766b99348af8eafcb Mon Sep 17 00:00:00 2001 From: Patrick Vice Date: Sun, 23 Nov 2025 20:59:37 -0500 Subject: [PATCH 2/2] removed some useless comments --- lib/ruby_llm/mcp/native/transports/sse.rb | 15 ----------- lib/ruby_llm/mcp/native/transports/stdio.rb | 6 ----- .../mcp/native/transports/streamable_http.rb | 27 ------------------- 3 files changed, 48 deletions(-) diff --git a/lib/ruby_llm/mcp/native/transports/sse.rb b/lib/ruby_llm/mcp/native/transports/sse.rb index 9271800..9f0eb1c 100644 --- a/lib/ruby_llm/mcp/native/transports/sse.rb +++ b/lib/ruby_llm/mcp/native/transports/sse.rb @@ -136,7 +136,6 @@ def close @sse_thread&.join(1) @sse_thread = nil - # Fail all pending requests fail_pending_requests!( Errors::TransportError.new( message: "SSE transport closed", @@ -144,7 +143,6 @@ def close ) ) - # Reset state @messages_url = nil end @@ -192,14 +190,12 @@ def start_sse_listener set_message_endpoint(endpoint) end rescue Errors::TimeoutError => e - # Clean up the pending request on timeout @pending_mutex.synchronize do @pending_requests.delete("endpoint") end RubyLLM::MCP.logger.error "Timeout waiting for endpoint event: #{e.message}" raise e rescue StandardError => e - # Clean up the pending request on any error @pending_mutex.synchronize do @pending_requests.delete("endpoint") end @@ -209,7 +205,6 @@ def start_sse_listener end def set_message_endpoint(endpoint) - # Handle both string endpoints and JSON payloads endpoint_url = if endpoint.is_a?(String) endpoint elsif endpoint.is_a?(Hash) @@ -283,8 +278,6 @@ def handle_client_error!(error_message, status_code) message: error_message, code: status_code ) - - # Close the transport (which will fail pending requests) close raise transport_error @@ -332,7 +325,6 @@ def process_buffered_event(event_buffer) end def read_error_body(response) - # Try to read the error body from the response body = "" begin response.each do |chunk| @@ -350,16 +342,12 @@ def handle_connection_error(message, error) error_message = "#{message}: #{error.message}" RubyLLM::MCP.logger.error "#{error_message}. Closing SSE transport." - # Create a transport error to fail pending requests transport_error = Errors::TransportError.new( message: error_message, code: nil ) - - # Close the transport (which will fail pending requests) close - # Notify coordinator if needed @coordinator&.handle_error(transport_error) end @@ -383,7 +371,6 @@ def handle_httpx_error_response!(response, context:) end def process_event(raw_event) - # Return if we believe that are getting a partial event return if raw_event[:data].nil? if raw_event[:event] == "endpoint" @@ -398,7 +385,6 @@ def process_endpoint_event(raw_event) event_data = raw_event[:data] return if event_data.nil? - # Try to parse as JSON first, fall back to string endpoint = begin JSON.parse(event_data) rescue JSON::ParserError @@ -417,7 +403,6 @@ def process_message_event(raw_event) event = begin JSON.parse(raw_event[:data]) rescue JSON::ParserError => e - # We can sometimes get partial events, so we will ignore them if @messages_url RubyLLM::MCP.logger.debug "Failed to parse SSE event data: #{raw_event[:data]} - #{e.message}" end diff --git a/lib/ruby_llm/mcp/native/transports/stdio.rb b/lib/ruby_llm/mcp/native/transports/stdio.rb index ea5f436..26337ff 100644 --- a/lib/ruby_llm/mcp/native/transports/stdio.rb +++ b/lib/ruby_llm/mcp/native/transports/stdio.rb @@ -204,7 +204,6 @@ def start_reader_thread @reader_thread = Thread.new do read_stdout_loop end - # Don't use abort_on_exception - handle errors cooperatively end def read_stdout_loop @@ -240,13 +239,10 @@ def handle_stdout_read end def handle_stream_error(error, stream_name) - # Check @running to distinguish graceful shutdown from unexpected errors. - # During shutdown, streams are closed intentionally and shouldn't trigger close. if running? RubyLLM::MCP.logger.error "#{stream_name} error: #{error.message}. Closing transport." safe_close_with_error(error) else - # Graceful shutdown in progress RubyLLM::MCP.logger.debug "#{stream_name} thread exiting during shutdown" end end @@ -255,7 +251,6 @@ def start_stderr_thread @stderr_thread = Thread.new do read_stderr_loop end - # Don't use abort_on_exception - handle errors cooperatively end def read_stderr_loop @@ -292,7 +287,6 @@ def process_response(line) result = @coordinator.process_result(result) return if result.nil? - # Handle regular responses (tool calls, etc.) @pending_mutex.synchronize do if result.matching_id?(request_id) && @pending_requests.key?(request_id) response_queue = @pending_requests.delete(request_id) diff --git a/lib/ruby_llm/mcp/native/transports/streamable_http.rb b/lib/ruby_llm/mcp/native/transports/streamable_http.rb index 0f28a9e..33c001c 100644 --- a/lib/ruby_llm/mcp/native/transports/streamable_http.rb +++ b/lib/ruby_llm/mcp/native/transports/streamable_http.rb @@ -151,7 +151,6 @@ def set_protocol_version(version) @protocol_version = version end - # Public hooks for SSE events (similar to TS transport) def on_message(&block) @on_message_callback = block end @@ -166,12 +165,10 @@ def on_close(&block) private - # Thread-safe check if transport is running and not stopped def running? @state_mutex.synchronize { @running && !@sse_stopped } end - # Thread-safe stop signal def abort! @state_mutex.synchronize do @running = false @@ -186,10 +183,8 @@ def terminate_session headers = build_common_headers response = @connection.delete(@url, headers: headers) - # Handle HTTPX error responses first handle_httpx_error_response!(response, context: { location: "terminating session" }) - # 405 Method Not Allowed is acceptable per spec unless [200, 405].include?(response.status) reason_phrase = response.respond_to?(:reason_phrase) ? response.reason_phrase : nil raise Errors::TransportError.new( @@ -213,7 +208,6 @@ def handle_httpx_error_response!(response, context:, allow_eof_for_sse: false) error = response.error - # Special handling for EOFError in SSE contexts if allow_eof_for_sse && error.is_a?(EOFError) RubyLLM::MCP.logger.info "SSE stream closed: #{response.error.message}" return :eof_handled @@ -263,7 +257,6 @@ def active_clients_count end def create_connection - # Use request_timeout for all timeout values (converted from ms to seconds) timeout_seconds = @request_timeout / 1000.0 client = Support::HTTPClient.connection.with( timeout: { @@ -285,7 +278,6 @@ def build_common_headers headers["X-CLIENT-ID"] = @client_id headers["Origin"] = @url.to_s - # Apply OAuth authorization if available if @oauth_provider RubyLLM::MCP.logger.debug "OAuth provider present, attempting to get token..." RubyLLM::MCP.logger.debug " Server URL: #{@oauth_provider.server_url}" @@ -326,7 +318,6 @@ def send_http_request(body, request_id, is_initialization: false) request_client = nil begin - # Set up connection with streaming callbacks if not initialization connection = if is_initialization @connection else @@ -366,10 +357,8 @@ def create_connection_with_streaming_callbacks(request_id) end def handle_response(response, request_id, original_message) - # Handle HTTPX error responses first handle_httpx_error_response!(response, context: { location: "handling response", request_id: request_id }) - # Extract session ID if present (only for successful responses) session_id = response.headers["mcp-session-id"] @session_id = session_id if session_id @@ -443,13 +432,10 @@ def handle_accepted_response(original_message) def handle_client_error(response) status_code = response.respond_to?(:status) ? response.status : "Unknown" - # Special handling for 403 with OAuth provider handle_oauth_authorization_error(response, status_code) if status_code == 403 && @oauth_provider - # Try to parse and handle structured JSON error handle_json_error_response(response, status_code) - # Fallback: generic error response_body = response.respond_to?(:body) ? response.body.to_s : "Unknown error" raise Errors::TransportError.new( code: status_code, @@ -481,7 +467,6 @@ def handle_json_error_response(response, status_code) error_message = error_body["error"]["message"] || error_body["error"]["code"] - # Handle empty error messages if error_message.to_s.empty? raise Errors::TransportError.new( code: status_code, @@ -489,7 +474,6 @@ def handle_json_error_response(response, status_code) ) end - # Handle session-related errors if error_message.to_s.downcase.include?("session") raise Errors::TransportError.new( code: response.status, @@ -497,7 +481,6 @@ def handle_json_error_response(response, status_code) ) end - # Generic JSON error raise Errors::TransportError.new( code: response.status, message: "Server error: #{error_message}" @@ -537,11 +520,9 @@ def start_sse(options) # rubocop:disable Metrics/MethodLength headers["Last-Event-ID"] = options.resumption_token end - # Set up SSE streaming connection with callbacks connection = create_connection_with_sse_callbacks(options, headers) response = connection.get(@url) - # Handle HTTPX error responses first error_result = handle_httpx_error_response!(response, context: { location: "SSE connection" }, allow_eof_for_sse: true) return if error_result == :eof_handled @@ -716,12 +697,10 @@ def process_sse_event(raw_event, replay_message_id) # rubocop:disable Metrics/Me begin event_data = JSON.parse(raw_event[:data]) - # Enhanced logging with event details event_type = raw_event[:event] || "message" event_id = raw_event[:id] RubyLLM::MCP.logger.debug "Processing SSE event: type=#{event_type}, id=#{event_id || 'none'}" - # Handle replay message ID if specified if replay_message_id && event_data.is_a?(Hash) && event_data["id"] event_data["id"] = replay_message_id end @@ -729,7 +708,6 @@ def process_sse_event(raw_event, replay_message_id) # rubocop:disable Metrics/Me result = RubyLLM::MCP::Result.new(event_data, session_id: @session_id) RubyLLM::MCP.logger.debug "SSE Result Received: #{result.inspect}" - # Call on_message hook if registered @on_message_callback&.call(result) result = @coordinator.process_result(result) @@ -782,7 +760,6 @@ def wait_for_response_with_timeout(request_id, response_queue) end def cleanup_sse_resources - # Set shutdown flags under mutex abort! # Call on_close hook if registered @@ -791,12 +768,9 @@ def cleanup_sse_resources # Close all HTTPX clients to signal SSE thread to exit close_all_clients - # Wait for SSE thread to exit cooperatively @sse_mutex.synchronize do if @sse_thread&.alive? - # Try to join the thread first (cooperative shutdown) unless @sse_thread.join(5) - # Only kill as last resort if join times out RubyLLM::MCP.logger.warn "SSE thread did not exit cleanly, forcing termination" @sse_thread.kill @sse_thread.join(1) @@ -805,7 +779,6 @@ def cleanup_sse_resources end end - # Drain pending requests with error instead of closing queues drain_pending_requests_with_error end