From 0cac06f9ba6ad16be683da7578d5666054c72682 Mon Sep 17 00:00:00 2001 From: Patrick Vice Date: Sun, 23 Nov 2025 17:00:03 -0500 Subject: [PATCH 1/2] add OAuth challenge handling with automatic retry - Add handle_authentication_challenge hooks to OAuth providers - Implement automatic retry logic in StreamableHTTP and SSE transports - Parse WWW-Authenticate headers for scope and resource metadata - Add infinite loop prevention guards - Add comprehensive test coverage Brings Ruby OAuth handling closer to parity with TypeScript client. --- examples/oauth/browser_oauth.rb | 3 + examples/oauth/custom_storage.rb | 3 + examples/oauth/standard_oauth.rb | 3 + .../mcp/auth/browser_oauth_provider.rb | 32 +++ lib/ruby_llm/mcp/auth/memory_storage.rb | 18 ++ lib/ruby_llm/mcp/auth/oauth_provider.rb | 81 ++++++++ lib/ruby_llm/mcp/native/transports/sse.rb | 138 ++++++++++++- .../mcp/native/transports/streamable_http.rb | 106 ++++++++-- .../mcp/auth/browser_oauth_provider_spec.rb | 95 +++++++++ spec/ruby_llm/mcp/auth/oauth_provider_spec.rb | 191 ++++++++++++++++++ .../mcp/native/transports/sse_spec.rb | 119 +++++++++++ .../native/transports/streamable_http_spec.rb | 140 +++++++++++++ 12 files changed, 907 insertions(+), 22 deletions(-) diff --git a/examples/oauth/browser_oauth.rb b/examples/oauth/browser_oauth.rb index 37fc7af..8cc5848 100755 --- a/examples/oauth/browser_oauth.rb +++ b/examples/oauth/browser_oauth.rb @@ -82,3 +82,6 @@ puts "4. Uncomment the authentication code above" puts "5. Run: ruby #{__FILE__}" puts "=" * 60 + + + diff --git a/examples/oauth/custom_storage.rb b/examples/oauth/custom_storage.rb index 81ed0ff..c286f11 100755 --- a/examples/oauth/custom_storage.rb +++ b/examples/oauth/custom_storage.rb @@ -275,3 +275,6 @@ def get_token(server_url) EXAMPLES puts "=" * 60 + + + diff --git a/examples/oauth/standard_oauth.rb b/examples/oauth/standard_oauth.rb index 9fbeb1d..07e9264 100755 --- a/examples/oauth/standard_oauth.rb +++ b/examples/oauth/standard_oauth.rb @@ -126,3 +126,6 @@ puts "4. Uncomment the interactive code above" puts "5. Run: ruby #{__FILE__}" puts "=" * 60 + + + diff --git a/lib/ruby_llm/mcp/auth/browser_oauth_provider.rb b/lib/ruby_llm/mcp/auth/browser_oauth_provider.rb index 669e598..c228b07 100644 --- a/lib/ruby_llm/mcp/auth/browser_oauth_provider.rb +++ b/lib/ruby_llm/mcp/auth/browser_oauth_provider.rb @@ -155,6 +155,38 @@ def complete_authorization_flow(code, state) @oauth_provider.complete_authorization_flow(code, state) end + # Handle authentication challenge with browser-based auth + # @param www_authenticate [String, nil] WWW-Authenticate header value + # @param resource_metadata_url [String, nil] Resource metadata URL from response + # @param requested_scope [String, nil] Scope from WWW-Authenticate challenge + # @return [Boolean] true if authentication was completed successfully + def handle_authentication_challenge(www_authenticate: nil, resource_metadata_url: nil, requested_scope: nil) + @logger.debug("BrowserOAuthProvider handling authentication challenge") + + # Try standard provider's automatic handling first (token refresh, client credentials) + begin + return @oauth_provider.handle_authentication_challenge( + www_authenticate: www_authenticate, + resource_metadata_url: resource_metadata_url, + requested_scope: requested_scope + ) + rescue Errors::AuthenticationRequiredError + # Standard provider couldn't handle it - need interactive auth + @logger.info("Automatic authentication failed, starting browser-based OAuth flow") + end + + # Perform full browser-based authentication + authenticate(auto_open_browser: true) + true + end + + # Parse WWW-Authenticate header (delegate to oauth_provider) + # @param header [String] WWW-Authenticate header value + # @return [Hash] parsed challenge information + def parse_www_authenticate(header) + @oauth_provider.parse_www_authenticate(header) + end + private # Validate and synchronize redirect_uri between this provider and oauth_provider diff --git a/lib/ruby_llm/mcp/auth/memory_storage.rb b/lib/ruby_llm/mcp/auth/memory_storage.rb index e439627..29d4694 100644 --- a/lib/ruby_llm/mcp/auth/memory_storage.rb +++ b/lib/ruby_llm/mcp/auth/memory_storage.rb @@ -12,6 +12,7 @@ def initialize @server_metadata = {} @pkce_data = {} @state_data = {} + @resource_metadata = {} end # Token storage @@ -23,6 +24,10 @@ def set_token(server_url, token) @tokens[server_url] = token end + def delete_token(server_url) + @tokens.delete(server_url) + end + # Client registration storage def get_client_info(server_url) @client_infos[server_url] @@ -66,6 +71,19 @@ def set_state(server_url, state) def delete_state(server_url) @state_data.delete(server_url) end + + # Resource metadata management + def get_resource_metadata(server_url) + @resource_metadata[server_url] + end + + def set_resource_metadata(server_url, metadata) + @resource_metadata[server_url] = metadata + end + + def delete_resource_metadata(server_url) + @resource_metadata.delete(server_url) + end end end end diff --git a/lib/ruby_llm/mcp/auth/oauth_provider.rb b/lib/ruby_llm/mcp/auth/oauth_provider.rb index 6536ec8..5b11ac2 100644 --- a/lib/ruby_llm/mcp/auth/oauth_provider.rb +++ b/lib/ruby_llm/mcp/auth/oauth_provider.rb @@ -150,6 +150,87 @@ def apply_authorization(request) request.headers["Authorization"] = token.to_header end + # Handle authentication challenge from server (401 response) + # Attempts to refresh token or raises error if interactive auth required + # @param www_authenticate [String, nil] WWW-Authenticate header value + # @param resource_metadata_url [String, nil] Resource metadata URL from response + # @param requested_scope [String, nil] Scope from WWW-Authenticate challenge + # @return [Boolean] true if authentication was refreshed successfully + # @raise [Errors::AuthenticationRequiredError] if interactive auth is required + def handle_authentication_challenge(www_authenticate: nil, resource_metadata_url: nil, requested_scope: nil) + logger.debug("Handling authentication challenge") + logger.debug(" WWW-Authenticate: #{www_authenticate}") if www_authenticate + logger.debug(" Resource metadata URL: #{resource_metadata_url}") if resource_metadata_url + logger.debug(" Requested scope: #{requested_scope}") if requested_scope + + # Parse WWW-Authenticate header if provided + if www_authenticate + challenge_info = parse_www_authenticate(www_authenticate) + resource_metadata_url ||= challenge_info[:resource_metadata_url] + requested_scope ||= challenge_info[:scope] + end + + # Update scope if server requested different scope + if requested_scope && requested_scope != scope + logger.debug("Updating scope from '#{scope}' to '#{requested_scope}'") + self.scope = requested_scope + end + + # Try to refresh existing token + token = storage.get_token(server_url) + if token&.refresh_token + logger.debug("Attempting token refresh with existing refresh token") + refreshed_token = refresh_token(token) + return true if refreshed_token + end + + # If we have client credentials, try that flow + if grant_type == :client_credentials + logger.debug("Attempting client credentials flow") + begin + new_token = client_credentials_flow(scope: requested_scope) + return true if new_token + rescue StandardError => e + logger.warn("Client credentials flow failed: #{e.message}") + end + end + + # Cannot automatically authenticate - interactive auth required + logger.warn("Cannot automatically authenticate - interactive authorization required") + raise Errors::AuthenticationRequiredError.new( + message: "OAuth authentication required. Token refresh failed and interactive authorization is needed." + ) + end + + # Parse WWW-Authenticate header to extract challenge parameters + # @param header [String] WWW-Authenticate header value + # @return [Hash] parsed challenge information + def parse_www_authenticate(header) + result = {} + + # Example: Bearer realm="example", scope="mcp:read mcp:write", resource_metadata_url="https://..." + if header =~ /Bearer\s+(.+)/i + params = ::Regexp.last_match(1) + + # Extract scope + if params =~ /scope="([^"]+)"/ + result[:scope] = ::Regexp.last_match(1) + end + + # Extract resource metadata URL + if params =~ /resource_metadata_url="([^"]+)"/ + result[:resource_metadata_url] = ::Regexp.last_match(1) + end + + # Extract realm + if params =~ /realm="([^"]+)"/ + result[:realm] = ::Regexp.last_match(1) + end + end + + result + end + private # Create HTTP client for OAuth requests diff --git a/lib/ruby_llm/mcp/native/transports/sse.rb b/lib/ruby_llm/mcp/native/transports/sse.rb index e6d2ea2..d2fb5b1 100644 --- a/lib/ruby_llm/mcp/native/transports/sse.rb +++ b/lib/ruby_llm/mcp/native/transports/sse.rb @@ -9,13 +9,17 @@ class SSE attr_reader :headers, :id, :coordinator - def initialize(url:, coordinator:, request_timeout:, version: :http2, headers: {}) + def initialize(url:, coordinator:, request_timeout:, version: :http2, headers: {}, oauth_provider: nil, options: {}) @event_url = url @messages_url = nil @coordinator = coordinator @request_timeout = request_timeout @version = version + # Extract oauth_provider from options if present + extracted_options = options.dup + oauth_provider = extracted_options.delete(:oauth_provider) || oauth_provider + uri = URI.parse(url) @root_url = "#{uri.scheme}://#{uri.host}" @root_url += ":#{uri.port}" if uri.port != uri.default_port @@ -28,6 +32,10 @@ def initialize(url:, coordinator:, request_timeout:, version: :http2, headers: { "X-CLIENT-ID" => @client_id }) + @oauth_provider = oauth_provider + @resource_metadata_url = nil + @auth_retry_attempted = false + @id_counter = 0 @id_mutex = Mutex.new @pending_requests = {} @@ -100,13 +108,20 @@ def set_protocol_version(version) private def send_request(body, request_id) + headers = build_request_headers http_client = Support::HTTPClient.connection.with(timeout: { request_timeout: @request_timeout / 1000 }, - headers: @headers) + headers: headers) response = http_client.post(@messages_url, body: JSON.generate(body)) handle_httpx_error_response!(response, context: { location: "message endpoint request", request_id: request_id }) - unless [200, 202].include?(response.status) + case response.status + when 200, 202 + # Success + return + when 401 + handle_authentication_challenge(response, body, request_id) + else message = "Failed to have a successful request to #{@messages_url}: #{response.status} - #{response.body}" RubyLLM::MCP.logger.error(message) raise Errors::TransportError.new( @@ -116,6 +131,82 @@ def send_request(body, request_id) end end + def build_request_headers + headers = @headers.dup + + # Apply OAuth authorization if available + if @oauth_provider + RubyLLM::MCP.logger.debug "OAuth provider present, attempting to get token..." + token = @oauth_provider.access_token + if token + headers["Authorization"] = token.to_header + RubyLLM::MCP.logger.debug "Applied OAuth authorization header" + else + RubyLLM::MCP.logger.warn "OAuth provider present but no valid token available!" + end + end + + headers + end + + def handle_authentication_challenge(response, original_body, request_id) + # If we've already attempted auth retry, don't try again + if @auth_retry_attempted + RubyLLM::MCP.logger.warn("Authentication retry already attempted, raising error") + @auth_retry_attempted = false + raise Errors::AuthenticationRequiredError.new( + message: "OAuth authentication required (401 Unauthorized) - retry failed" + ) + end + + unless @oauth_provider + raise Errors::AuthenticationRequiredError.new( + message: "OAuth authentication required (401 Unauthorized) but no OAuth provider configured" + ) + end + + RubyLLM::MCP.logger.info("Received 401 Unauthorized, attempting automatic authentication") + + www_authenticate = response.headers["www-authenticate"] + resource_metadata_url = response.headers["mcp-resource-metadata-url"] + @resource_metadata_url = resource_metadata_url if resource_metadata_url + + begin + @auth_retry_attempted = true + + success = @oauth_provider.handle_authentication_challenge( + www_authenticate: www_authenticate, + resource_metadata_url: resource_metadata_url, + requested_scope: nil + ) + + if success + RubyLLM::MCP.logger.info("Authentication challenge handled successfully, retrying request") + + # Retry the original request (flag stays true to prevent loop) + send_request(original_body, request_id) + + # Only reset flag after successful retry + @auth_retry_attempted = false + return + end + rescue Errors::AuthenticationRequiredError => e + @auth_retry_attempted = false + raise e + rescue StandardError => e + @auth_retry_attempted = false + RubyLLM::MCP.logger.error("Authentication challenge handling failed: #{e.message}") + raise Errors::AuthenticationRequiredError.new( + message: "OAuth authentication failed: #{e.message}" + ) + end + + @auth_retry_attempted = false + raise Errors::AuthenticationRequiredError.new( + message: "OAuth authentication required (401 Unauthorized)" + ) + end + def start_sse_listener @connection_mutex.synchronize do return if sse_thread_running? @@ -178,6 +269,12 @@ def create_sse_client def validate_sse_response!(response) return unless response.status >= 400 + # Handle 401 specially for OAuth + if response.status == 401 + handle_sse_authentication_challenge(response) + return + end + error_body = read_error_body(response) error_message = "HTTP #{response.status} error from SSE endpoint: #{error_body}" RubyLLM::MCP.logger.error error_message @@ -187,6 +284,41 @@ def validate_sse_response!(response) raise StandardError, error_message end + def handle_sse_authentication_challenge(response) + unless @oauth_provider + raise Errors::AuthenticationRequiredError.new( + message: "OAuth authentication required for SSE stream but no OAuth provider configured" + ) + end + + RubyLLM::MCP.logger.info("SSE stream received 401, attempting authentication") + + www_authenticate = response.headers["www-authenticate"] + resource_metadata_url = response.headers["mcp-resource-metadata-url"] + + begin + success = @oauth_provider.handle_authentication_challenge( + www_authenticate: www_authenticate, + resource_metadata_url: resource_metadata_url, + requested_scope: nil + ) + + if success + RubyLLM::MCP.logger.info("Authentication successful, SSE stream will reconnect") + # The caller will retry the SSE connection + return + end + rescue Errors::AuthenticationRequiredError + raise + rescue StandardError => e + RubyLLM::MCP.logger.error("SSE authentication failed: #{e.message}") + end + + raise Errors::AuthenticationRequiredError.new( + message: "OAuth authentication required for SSE stream" + ) + end + def handle_client_error!(error_message, status_code) @running = false raise Errors::TransportError.new( diff --git a/lib/ruby_llm/mcp/native/transports/streamable_http.rb b/lib/ruby_llm/mcp/native/transports/streamable_http.rb index b938c0d..03f38b3 100644 --- a/lib/ruby_llm/mcp/native/transports/streamable_http.rb +++ b/lib/ruby_llm/mcp/native/transports/streamable_http.rb @@ -73,21 +73,24 @@ def initialize( # rubocop:disable Metrics/ParameterLists @protocol_version = nil @session_id = session_id - @resource_metadata_url = nil - @client_id = SecureRandom.uuid - - @reconnection_options = ReconnectionOptions.new(**reconnection) - @oauth_provider = oauth_provider - @rate_limiter = Support::RateLimiter.new(**rate_limit) if rate_limit - - @id_counter = 0 - @id_mutex = Mutex.new - @pending_requests = {} - @pending_mutex = Mutex.new - @running = true - @abort_controller = nil - @sse_thread = nil - @sse_mutex = Mutex.new + @resource_metadata_url = nil + @client_id = SecureRandom.uuid + + @reconnection_options = ReconnectionOptions.new(**reconnection) + @oauth_provider = oauth_provider + @rate_limiter = Support::RateLimiter.new(**rate_limit) if rate_limit + + @id_counter = 0 + @id_mutex = Mutex.new + @pending_requests = {} + @pending_mutex = Mutex.new + @running = true + @abort_controller = nil + @sse_thread = nil + @sse_mutex = Mutex.new + + # Track if we've attempted auth flow to prevent infinite loops + @auth_retry_attempted = false # Thread-safe collection of all HTTPX clients @clients = [] @@ -342,10 +345,7 @@ def handle_response(response, request_id, original_message) when 404 handle_session_expired when 401 - # OAuth authentication required - raise Errors::AuthenticationRequiredError.new( - message: "OAuth authentication required (401 Unauthorized)" - ) + handle_authentication_challenge(response, request_id, original_message) when 405 # Method not allowed - acceptable for some endpoints nil @@ -481,9 +481,77 @@ def extract_resource_metadata_url(response) return nil unless response.respond_to?(:headers) metadata_url = response.headers["mcp-resource-metadata-url"] + if metadata_url + @resource_metadata_url = metadata_url + RubyLLM::MCP.logger.debug("Extracted resource metadata URL: #{metadata_url}") + end metadata_url ? URI(metadata_url) : nil end + def handle_authentication_challenge(response, request_id, original_message) + # If we've already attempted auth retry, don't try again (prevent infinite loop) + if @auth_retry_attempted + RubyLLM::MCP.logger.warn("Authentication retry already attempted, raising error") + @auth_retry_attempted = false # Reset for next request + raise Errors::AuthenticationRequiredError.new( + message: "OAuth authentication required (401 Unauthorized) - retry failed" + ) + end + + # No OAuth provider configured - can't handle challenge + unless @oauth_provider + raise Errors::AuthenticationRequiredError.new( + message: "OAuth authentication required (401 Unauthorized) but no OAuth provider configured" + ) + end + + RubyLLM::MCP.logger.info("Received 401 Unauthorized, attempting automatic authentication") + + # Extract challenge information from response + www_authenticate = response.headers["www-authenticate"] + resource_metadata_url = extract_resource_metadata_url(response) + + begin + # Set flag to prevent infinite retry loop + @auth_retry_attempted = true + + # Ask OAuth provider to handle the challenge + success = @oauth_provider.handle_authentication_challenge( + www_authenticate: www_authenticate, + resource_metadata_url: resource_metadata_url&.to_s, + requested_scope: nil + ) + + if success + RubyLLM::MCP.logger.info("Authentication challenge handled successfully, retrying request") + + # Retry the original request with new auth (flag stays true to prevent loop) + result = send_http_request(original_message, request_id, is_initialization: false) + + # Only reset flag after successful retry + @auth_retry_attempted = false + return result + end + rescue Errors::AuthenticationRequiredError => e + # Reset flag and re-raise + @auth_retry_attempted = false + raise e + rescue StandardError => e + # Reset flag and wrap error + @auth_retry_attempted = false + RubyLLM::MCP.logger.error("Authentication challenge handling failed: #{e.message}") + raise Errors::AuthenticationRequiredError.new( + message: "OAuth authentication failed: #{e.message}" + ) + end + + # If we get here, authentication didn't succeed + @auth_retry_attempted = false + raise Errors::AuthenticationRequiredError.new( + message: "OAuth authentication required (401 Unauthorized)" + ) + end + def start_sse_stream(options = StartSSEOptions.new) return unless @running && !@abort_controller diff --git a/spec/ruby_llm/mcp/auth/browser_oauth_provider_spec.rb b/spec/ruby_llm/mcp/auth/browser_oauth_provider_spec.rb index 1d71b21..51c970a 100644 --- a/spec/ruby_llm/mcp/auth/browser_oauth_provider_spec.rb +++ b/spec/ruby_llm/mcp/auth/browser_oauth_provider_spec.rb @@ -1190,4 +1190,99 @@ end end end + + describe "#handle_authentication_challenge" do + let(:browser_oauth) do + described_class.new(oauth_provider: oauth_provider, callback_port: callback_port, callback_path: callback_path) + end + + context "when standard provider can handle challenge" do + before do + allow(oauth_provider).to receive(:handle_authentication_challenge).and_return(true) + end + + it "delegates to oauth_provider" do + result = browser_oauth.handle_authentication_challenge( + www_authenticate: 'Bearer scope="test"' + ) + + expect(result).to be true + expect(oauth_provider).to have_received(:handle_authentication_challenge) + end + + it "passes all parameters to oauth_provider" do + browser_oauth.handle_authentication_challenge( + www_authenticate: 'Bearer scope="test"', + resource_metadata_url: "https://example.com/meta", + requested_scope: "custom:scope" + ) + + expect(oauth_provider).to have_received(:handle_authentication_challenge).with( + www_authenticate: 'Bearer scope="test"', + resource_metadata_url: "https://example.com/meta", + requested_scope: "custom:scope" + ) + end + end + + context "when interactive auth is required" do + let(:tcp_server) { instance_double(TCPServer) } + let(:client_socket) { instance_double(TCPSocket) } + let(:token) do + RubyLLM::MCP::Auth::Token.new( + access_token: "new_token", + expires_in: 3600 + ) + end + + before do + allow(oauth_provider).to receive(:handle_authentication_challenge) + .and_raise(RubyLLM::MCP::Errors::AuthenticationRequiredError.new(message: "Interactive auth required")) + allow(oauth_provider).to receive(:start_authorization_flow).and_return(auth_url) + allow(TCPServer).to receive(:new).and_return(tcp_server) + allow(tcp_server).to receive(:close) + allow(tcp_server).to receive(:closed?).and_return(false) + allow(tcp_server).to receive(:wait_readable).and_return(true, false) + allow(tcp_server).to receive(:accept).and_return(client_socket) + allow(client_socket).to receive(:setsockopt) + allow(client_socket).to receive(:gets).and_return( + "GET /callback?code=test&state=test HTTP/1.1\r\n", + "\r\n" + ) + allow(client_socket).to receive(:write) + allow(client_socket).to receive(:close) + allow(oauth_provider).to receive(:complete_authorization_flow).and_return(token) + end + + it "falls back to browser-based authentication" do + result = browser_oauth.handle_authentication_challenge + + expect(result).to be true + expect(oauth_provider).to have_received(:start_authorization_flow) + expect(oauth_provider).to have_received(:complete_authorization_flow) + end + + it "logs info about falling back to browser auth" do + browser_oauth.handle_authentication_challenge + + expect(logger).to have_received(:info).with(/starting browser-based OAuth flow/) + end + end + end + + describe "#parse_www_authenticate" do + let(:browser_oauth) do + described_class.new(oauth_provider: oauth_provider) + end + + it "delegates to oauth_provider" do + header = 'Bearer scope="test"' + allow(oauth_provider).to receive(:parse_www_authenticate).and_return({ scope: "test" }) + + result = browser_oauth.parse_www_authenticate(header) + + expect(result).to eq({ scope: "test" }) + expect(oauth_provider).to have_received(:parse_www_authenticate).with(header) + end + end end diff --git a/spec/ruby_llm/mcp/auth/oauth_provider_spec.rb b/spec/ruby_llm/mcp/auth/oauth_provider_spec.rb index c2ace09..eeea736 100644 --- a/spec/ruby_llm/mcp/auth/oauth_provider_spec.rb +++ b/spec/ruby_llm/mcp/auth/oauth_provider_spec.rb @@ -397,4 +397,195 @@ # NOTE: State parameter validation is now tested in SessionManager specs # These tests were testing internal implementation details end + + describe "#handle_authentication_challenge" do + let(:provider) do + described_class.new( + server_url: server_url, + storage: storage, + logger: logger, + grant_type: :authorization_code + ) + end + + context "when token can be refreshed" do + let(:expired_token) do + token = RubyLLM::MCP::Auth::Token.new( + access_token: "expired_token", + refresh_token: "refresh_token_123", + expires_in: 1 + ) + token.instance_variable_set(:@expires_at, Time.now - 3600) + token + end + + let(:new_token) do + RubyLLM::MCP::Auth::Token.new( + access_token: "new_token", + expires_in: 3600 + ) + end + + before do + storage.set_token(server_url, expired_token) + allow(provider).to receive(:refresh_token).with(expired_token).and_return(new_token) + end + + it "refreshes token and returns true" do + result = provider.handle_authentication_challenge + + expect(result).to be true + expect(provider).to have_received(:refresh_token) + end + + it "logs debug information" do + provider.handle_authentication_challenge + + expect(logger).to have_received(:debug).with(/Handling authentication challenge/) + expect(logger).to have_received(:debug).with(/Attempting token refresh/) + end + end + + context "when using client credentials grant" do + let(:provider) do + described_class.new( + server_url: server_url, + storage: storage, + logger: logger, + grant_type: :client_credentials + ) + end + + let(:new_token) do + RubyLLM::MCP::Auth::Token.new( + access_token: "client_creds_token", + expires_in: 3600 + ) + end + + before do + allow(provider).to receive(:client_credentials_flow).and_return(new_token) + end + + it "attempts client credentials flow" do + result = provider.handle_authentication_challenge + + expect(result).to be true + expect(provider).to have_received(:client_credentials_flow) + end + + it "passes requested scope to client credentials flow" do + provider.handle_authentication_challenge(requested_scope: "custom:scope") + + expect(provider).to have_received(:client_credentials_flow).with(scope: "custom:scope") + end + end + + context "when interactive auth is required" do + it "raises AuthenticationRequiredError" do + expect do + provider.handle_authentication_challenge + end.to raise_error(RubyLLM::MCP::Errors::AuthenticationRequiredError, /interactive authorization is needed/) + end + + it "logs warning about interactive auth requirement" do + begin + provider.handle_authentication_challenge + rescue RubyLLM::MCP::Errors::AuthenticationRequiredError + # Expected + end + + expect(logger).to have_received(:warn).with(/Cannot automatically authenticate/) + end + end + + context "with WWW-Authenticate header" do + let(:www_authenticate) { 'Bearer realm="example", scope="mcp:read mcp:write"' } + + it "parses and updates scope" do + expect do + provider.handle_authentication_challenge(www_authenticate: www_authenticate) + end.to raise_error(RubyLLM::MCP::Errors::AuthenticationRequiredError) + + expect(provider.scope).to eq("mcp:read mcp:write") + end + + it "logs WWW-Authenticate header" do + begin + provider.handle_authentication_challenge(www_authenticate: www_authenticate) + rescue RubyLLM::MCP::Errors::AuthenticationRequiredError + # Expected + end + + expect(logger).to have_received(:debug).with(/WWW-Authenticate:/) + end + end + + context "with resource metadata URL" do + let(:metadata_url) { "https://example.com/.well-known/oauth-protected-resource" } + + it "logs resource metadata URL" do + begin + provider.handle_authentication_challenge(resource_metadata_url: metadata_url) + rescue RubyLLM::MCP::Errors::AuthenticationRequiredError + # Expected + end + + expect(logger).to have_received(:debug).with(/Resource metadata URL:/) + end + end + end + + describe "#parse_www_authenticate" do + let(:provider) do + described_class.new( + server_url: server_url, + storage: storage + ) + end + + it "parses scope from header" do + header = 'Bearer realm="example", scope="mcp:read mcp:write"' + result = provider.parse_www_authenticate(header) + + expect(result[:scope]).to eq("mcp:read mcp:write") + end + + it "parses resource_metadata_url from header" do + header = 'Bearer resource_metadata_url="https://example.com/.well-known/oauth"' + result = provider.parse_www_authenticate(header) + + expect(result[:resource_metadata_url]).to eq("https://example.com/.well-known/oauth") + end + + it "parses realm from header" do + header = 'Bearer realm="example.com"' + result = provider.parse_www_authenticate(header) + + expect(result[:realm]).to eq("example.com") + end + + it "parses all parameters together" do + header = 'Bearer realm="example", scope="mcp:read", resource_metadata_url="https://example.com/meta"' + result = provider.parse_www_authenticate(header) + + expect(result[:realm]).to eq("example") + expect(result[:scope]).to eq("mcp:read") + expect(result[:resource_metadata_url]).to eq("https://example.com/meta") + end + + it "returns empty hash for non-Bearer header" do + header = 'Basic realm="example"' + result = provider.parse_www_authenticate(header) + + expect(result).to eq({}) + end + + it "handles case-insensitive Bearer" do + header = 'bearer scope="test"' + result = provider.parse_www_authenticate(header) + + expect(result[:scope]).to eq("test") + end + end end diff --git a/spec/ruby_llm/mcp/native/transports/sse_spec.rb b/spec/ruby_llm/mcp/native/transports/sse_spec.rb index e61c9b0..3f2be12 100644 --- a/spec/ruby_llm/mcp/native/transports/sse_spec.rb +++ b/spec/ruby_llm/mcp/native/transports/sse_spec.rb @@ -219,4 +219,123 @@ def client expect(events).to eq([]) end end + + describe "OAuth integration" do + let(:server_url) { "http://localhost:3000/sse" } + let(:storage) { RubyLLM::MCP::Auth::MemoryStorage.new } + let(:oauth_provider) do + RubyLLM::MCP::Auth::OAuthProvider.new( + server_url: server_url, + storage: storage + ) + end + let(:coordinator) { instance_double(RubyLLM::MCP::Adapters::MCPTransports::CoordinatorStub) } + let(:transport_with_oauth) do + RubyLLM::MCP::Native::Transports::SSE.new( + url: server_url, + coordinator: coordinator, + request_timeout: 5000, + oauth_provider: oauth_provider + ) + end + + it "accepts OAuth provider in initialization" do + expect(transport_with_oauth.instance_variable_get(:@oauth_provider)).to eq(oauth_provider) + end + + it "applies OAuth authorization header to requests" do + token = RubyLLM::MCP::Auth::Token.new( + access_token: "test_token", + expires_in: 3600 + ) + storage.set_token(server_url, token) + + headers = transport_with_oauth.send(:build_request_headers) + + expect(headers["Authorization"]).to eq("Bearer test_token") + end + + it "does not apply OAuth header when no token available" do + headers = transport_with_oauth.send(:build_request_headers) + + expect(headers["Authorization"]).to be_nil + end + + context "with authentication challenges" do + let(:mock_response) { instance_double(HTTPX::Response) } + + before do + allow(mock_response).to receive(:headers).and_return({ + "www-authenticate" => 'Bearer scope="mcp:read"', + "mcp-resource-metadata-url" => "https://example.com/meta" + }) + allow(mock_response).to receive(:status).and_return(401) + end + + it "handles 401 during message POST" do + transport_with_oauth.instance_variable_set(:@messages_url, "http://localhost:3000/messages") + + allow(oauth_provider).to receive(:handle_authentication_challenge).and_raise( + RubyLLM::MCP::Errors::AuthenticationRequiredError.new(message: "Auth required") + ) + + expect do + transport_with_oauth.send(:handle_authentication_challenge, mock_response, {}, 1) + end.to raise_error(RubyLLM::MCP::Errors::AuthenticationRequiredError) + + expect(oauth_provider).to have_received(:handle_authentication_challenge) + end + + it "retries request after successful authentication" do + transport_with_oauth.instance_variable_set(:@messages_url, "http://localhost:3000/messages") + + new_token = RubyLLM::MCP::Auth::Token.new( + access_token: "new_token", + expires_in: 3600 + ) + storage.set_token(server_url, new_token) + + allow(oauth_provider).to receive(:handle_authentication_challenge).and_return(true) + allow(transport_with_oauth).to receive(:send_request).and_call_original + + # Mock the retry to succeed + allow(RubyLLM::MCP::Native::Transports::Support::HTTPClient).to receive(:connection).and_return( + double(with: double(post: double(status: 200, headers: {}))) + ) + + expect do + transport_with_oauth.send(:handle_authentication_challenge, mock_response, { "method" => "test" }, 1) + end.not_to raise_error + end + + it "prevents infinite retry loop" do + transport_with_oauth.instance_variable_set(:@messages_url, "http://localhost:3000/messages") + transport_with_oauth.instance_variable_set(:@auth_retry_attempted, true) + + expect do + transport_with_oauth.send(:handle_authentication_challenge, mock_response, {}, 1) + end.to raise_error(RubyLLM::MCP::Errors::AuthenticationRequiredError, /retry failed/) + end + + it "handles SSE stream 401 authentication" do + allow(oauth_provider).to receive(:handle_authentication_challenge).and_return(true) + + expect do + transport_with_oauth.send(:handle_sse_authentication_challenge, mock_response) + end.not_to raise_error + + expect(oauth_provider).to have_received(:handle_authentication_challenge) + end + + it "raises error when SSE auth fails" do + allow(oauth_provider).to receive(:handle_authentication_challenge).and_raise( + RubyLLM::MCP::Errors::AuthenticationRequiredError.new(message: "Auth failed") + ) + + expect do + transport_with_oauth.send(:handle_sse_authentication_challenge, mock_response) + end.to raise_error(RubyLLM::MCP::Errors::AuthenticationRequiredError) + end + end + end end diff --git a/spec/ruby_llm/mcp/native/transports/streamable_http_spec.rb b/spec/ruby_llm/mcp/native/transports/streamable_http_spec.rb index 54d80ad..1bedbbf 100644 --- a/spec/ruby_llm/mcp/native/transports/streamable_http_spec.rb +++ b/spec/ruby_llm/mcp/native/transports/streamable_http_spec.rb @@ -1194,5 +1194,145 @@ end.to raise_error(RubyLLM::MCP::Errors::TransportError, /Plain text error message/) end end + + context "with OAuth challenge and retry" do + let(:oauth_provider) do + RubyLLM::MCP::Auth::OAuthProvider.new( + server_url: server_url, + storage: storage + ) + end + let(:transport_with_oauth) do + described_class.new( + url: server_url, + coordinator: mock_coordinator, + request_timeout: 5000, + options: { oauth_provider: oauth_provider } + ) + end + + before do + # Set up initial token + token = RubyLLM::MCP::Auth::Token.new( + access_token: "initial_token", + refresh_token: "refresh_token_123", + expires_in: 3600 + ) + storage.set_token(server_url, token) + end + + it "handles 401 with WWW-Authenticate and retries request" do + # First request returns 401 + stub_request(:post, server_url) + .with(headers: { "Authorization" => "Bearer initial_token" }) + .to_return( + status: 401, + headers: { + "WWW-Authenticate" => 'Bearer scope="mcp:read mcp:write"', + "mcp-resource-metadata-url" => "https://example.com/.well-known/oauth" + } + ) + + # After refresh, second request succeeds + new_token = RubyLLM::MCP::Auth::Token.new( + access_token: "refreshed_token", + expires_in: 3600 + ) + + # Mock the OAuth provider to update the token and return success + allow(oauth_provider).to receive(:handle_authentication_challenge) do + storage.set_token(server_url, new_token) + true + end + + stub_request(:post, server_url) + .with(headers: { "Authorization" => "Bearer refreshed_token" }) + .to_return( + status: 200, + headers: { "Content-Type" => "application/json" }, + body: '{"result": {"content": [{"type": "text", "text": "success"}]}}' + ) + + result = transport_with_oauth.request({ "method" => "test", "id" => 1 }, wait_for_response: false) + + expect(result).to be_a(RubyLLM::MCP::Result) + expect(oauth_provider).to have_received(:handle_authentication_challenge) + end + + it "prevents infinite retry loop on repeated 401" do + # Both requests return 401 + stub_request(:post, server_url) + .to_return(status: 401) + + allow(oauth_provider).to receive(:handle_authentication_challenge).and_return(true) + + expect do + transport_with_oauth.request({ "method" => "test", "id" => 1 }, wait_for_response: false) + end.to raise_error(RubyLLM::MCP::Errors::AuthenticationRequiredError, /retry failed/) + end + + it "raises error when no OAuth provider configured" do + transport_without_oauth = described_class.new( + url: server_url, + coordinator: mock_coordinator, + request_timeout: 5000, + options: {} + ) + + stub_request(:post, server_url).to_return(status: 401) + + expect do + transport_without_oauth.request({ "method" => "test", "id" => 1 }, wait_for_response: false) + end.to raise_error(RubyLLM::MCP::Errors::AuthenticationRequiredError, /no OAuth provider configured/) + end + + it "extracts and caches resource metadata URL" do + stub_request(:post, server_url) + .to_return( + status: 401, + headers: { "mcp-resource-metadata-url" => "https://example.com/.well-known/oauth" } + ) + + allow(oauth_provider).to receive(:handle_authentication_challenge).and_raise( + RubyLLM::MCP::Errors::AuthenticationRequiredError.new(message: "Auth required") + ) + + begin + transport_with_oauth.request({ "method" => "test", "id" => 1 }, wait_for_response: false) + rescue RubyLLM::MCP::Errors::AuthenticationRequiredError + # Expected + end + + expect(logger).to have_received(:debug).with(/Extracted resource metadata URL/) + end + + it "logs authentication challenge handling" do + stub_request(:post, server_url).to_return(status: 401) + + allow(oauth_provider).to receive(:handle_authentication_challenge).and_raise( + RubyLLM::MCP::Errors::AuthenticationRequiredError.new(message: "Auth required") + ) + + begin + transport_with_oauth.request({ "method" => "test", "id" => 1 }, wait_for_response: false) + rescue RubyLLM::MCP::Errors::AuthenticationRequiredError + # Expected + end + + expect(logger).to have_received(:info).with(/Received 401 Unauthorized, attempting automatic authentication/) + end + + it "handles authentication challenge failure gracefully" do + stub_request(:post, server_url).to_return(status: 401) + + allow(oauth_provider).to receive(:handle_authentication_challenge).and_raise( + StandardError.new("Network error") + ) + + expect do + transport_with_oauth.request({ "method" => "test", "id" => 1 }, wait_for_response: false) + end.to raise_error(RubyLLM::MCP::Errors::AuthenticationRequiredError, /Network error/) + end + end end end From a3738477b0925cf752781937def2dcc6580ba797 Mon Sep 17 00:00:00 2001 From: Patrick Vice Date: Sun, 23 Nov 2025 19:29:31 -0500 Subject: [PATCH 2/2] fixed lint issues --- examples/oauth/browser_oauth.rb | 3 - examples/oauth/custom_storage.rb | 3 - examples/oauth/standard_oauth.rb | 3 - .../mcp/adapters/mcp_transports/sse.rb | 8 +- lib/ruby_llm/mcp/auth/oauth_provider.rb | 11 +- lib/ruby_llm/mcp/native/transport.rb | 16 ++- lib/ruby_llm/mcp/native/transports/sse.rb | 89 +++++++------ .../mcp/native/transports/streamable_http.rb | 123 +++++++++--------- .../mcp/auth/browser_oauth_provider_spec.rb | 7 +- spec/ruby_llm/mcp/native/transport_spec.rb | 7 +- .../mcp/native/transports/sse_spec.rb | 18 +-- 11 files changed, 148 insertions(+), 140 deletions(-) diff --git a/examples/oauth/browser_oauth.rb b/examples/oauth/browser_oauth.rb index 8cc5848..37fc7af 100755 --- a/examples/oauth/browser_oauth.rb +++ b/examples/oauth/browser_oauth.rb @@ -82,6 +82,3 @@ puts "4. Uncomment the authentication code above" puts "5. Run: ruby #{__FILE__}" puts "=" * 60 - - - diff --git a/examples/oauth/custom_storage.rb b/examples/oauth/custom_storage.rb index c286f11..81ed0ff 100755 --- a/examples/oauth/custom_storage.rb +++ b/examples/oauth/custom_storage.rb @@ -275,6 +275,3 @@ def get_token(server_url) EXAMPLES puts "=" * 60 - - - diff --git a/examples/oauth/standard_oauth.rb b/examples/oauth/standard_oauth.rb index 07e9264..9fbeb1d 100755 --- a/examples/oauth/standard_oauth.rb +++ b/examples/oauth/standard_oauth.rb @@ -126,6 +126,3 @@ puts "4. Uncomment the interactive code above" puts "5. Run: ruby #{__FILE__}" puts "=" * 60 - - - diff --git a/lib/ruby_llm/mcp/adapters/mcp_transports/sse.rb b/lib/ruby_llm/mcp/adapters/mcp_transports/sse.rb index e55525b..fa8b4a7 100644 --- a/lib/ruby_llm/mcp/adapters/mcp_transports/sse.rb +++ b/lib/ruby_llm/mcp/adapters/mcp_transports/sse.rb @@ -12,10 +12,12 @@ def initialize(url:, headers: {}, version: :http2, request_timeout: 10_000) @native_transport = RubyLLM::MCP::Native::Transports::SSE.new( url: url, - headers: headers, - version: version, coordinator: @coordinator, - request_timeout: request_timeout + request_timeout: request_timeout, + options: { + headers: headers, + version: version + } ) end diff --git a/lib/ruby_llm/mcp/auth/oauth_provider.rb b/lib/ruby_llm/mcp/auth/oauth_provider.rb index 5b11ac2..35de86d 100644 --- a/lib/ruby_llm/mcp/auth/oauth_provider.rb +++ b/lib/ruby_llm/mcp/auth/oauth_provider.rb @@ -164,16 +164,17 @@ def handle_authentication_challenge(www_authenticate: nil, resource_metadata_url logger.debug(" Requested scope: #{requested_scope}") if requested_scope # Parse WWW-Authenticate header if provided + final_requested_scope = requested_scope if www_authenticate challenge_info = parse_www_authenticate(www_authenticate) - resource_metadata_url ||= challenge_info[:resource_metadata_url] - requested_scope ||= challenge_info[:scope] + final_requested_scope ||= challenge_info[:scope] + # NOTE: resource_metadata_url from challenge_info could be used for future discovery end # Update scope if server requested different scope - if requested_scope && requested_scope != scope - logger.debug("Updating scope from '#{scope}' to '#{requested_scope}'") - self.scope = requested_scope + if final_requested_scope && final_requested_scope != scope + logger.debug("Updating scope from '#{scope}' to '#{final_requested_scope}'") + self.scope = final_requested_scope end # Try to refresh existing token diff --git a/lib/ruby_llm/mcp/native/transport.rb b/lib/ruby_llm/mcp/native/transport.rb index ce0eb05..4c8d09e 100644 --- a/lib/ruby_llm/mcp/native/transport.rb +++ b/lib/ruby_llm/mcp/native/transport.rb @@ -58,9 +58,23 @@ def build_transport transport_config.merge!(options) end + # Handle SSE transport specially - it uses options hash pattern + if transport_type == :sse + url = transport_config.delete(:url) || transport_config.delete("url") + request_timeout = transport_config.delete(:request_timeout) || + transport_config.delete("request_timeout") || + MCP.config.request_timeout + # Everything else goes into options + options_hash = transport_config.dup + transport_config.clear + transport_config[:url] = url + transport_config[:request_timeout] = request_timeout + transport_config[:options] = options_hash + end + # Remove OAuth-specific params from transports that don't support them # This allows other arbitrary params (like timeout) to pass through for testing - unless %i[streamable streamable_http].include?(transport_type) + unless %i[streamable streamable_http sse].include?(transport_type) transport_config.delete(:oauth_provider) transport_config.delete(:oauth) end diff --git a/lib/ruby_llm/mcp/native/transports/sse.rb b/lib/ruby_llm/mcp/native/transports/sse.rb index d2fb5b1..69edd21 100644 --- a/lib/ruby_llm/mcp/native/transports/sse.rb +++ b/lib/ruby_llm/mcp/native/transports/sse.rb @@ -9,16 +9,17 @@ class SSE attr_reader :headers, :id, :coordinator - def initialize(url:, coordinator:, request_timeout:, version: :http2, headers: {}, oauth_provider: nil, options: {}) + def initialize(url:, coordinator:, request_timeout:, options: {}) @event_url = url @messages_url = nil @coordinator = coordinator @request_timeout = request_timeout - @version = version - # Extract oauth_provider from options if present + # Extract options extracted_options = options.dup - oauth_provider = extracted_options.delete(:oauth_provider) || oauth_provider + @version = extracted_options.delete(:version) || :http2 + headers = extracted_options.delete(:headers) || {} + oauth_provider = extracted_options.delete(:oauth_provider) uri = URI.parse(url) @root_url = "#{uri.scheme}://#{uri.host}" @@ -118,7 +119,7 @@ def send_request(body, request_id) case response.status when 200, 202 # Success - return + nil when 401 handle_authentication_challenge(response, body, request_id) else @@ -150,20 +151,8 @@ def build_request_headers end def handle_authentication_challenge(response, original_body, request_id) - # If we've already attempted auth retry, don't try again - if @auth_retry_attempted - RubyLLM::MCP.logger.warn("Authentication retry already attempted, raising error") - @auth_retry_attempted = false - raise Errors::AuthenticationRequiredError.new( - message: "OAuth authentication required (401 Unauthorized) - retry failed" - ) - end - - unless @oauth_provider - raise Errors::AuthenticationRequiredError.new( - message: "OAuth authentication required (401 Unauthorized) but no OAuth provider configured" - ) - end + check_retry_guard! + check_oauth_provider_configured! RubyLLM::MCP.logger.info("Received 401 Unauthorized, attempting automatic authentication") @@ -171,40 +160,56 @@ def handle_authentication_challenge(response, original_body, request_id) resource_metadata_url = response.headers["mcp-resource-metadata-url"] @resource_metadata_url = resource_metadata_url if resource_metadata_url - begin - @auth_retry_attempted = true + attempt_authentication_retry(www_authenticate, resource_metadata_url, original_body, request_id) + end - success = @oauth_provider.handle_authentication_challenge( - www_authenticate: www_authenticate, - resource_metadata_url: resource_metadata_url, - requested_scope: nil - ) + def check_retry_guard! + return unless @auth_retry_attempted - if success - RubyLLM::MCP.logger.info("Authentication challenge handled successfully, retrying request") + RubyLLM::MCP.logger.warn("Authentication retry already attempted, raising error") + @auth_retry_attempted = false + raise Errors::AuthenticationRequiredError.new( + message: "OAuth authentication required (401 Unauthorized) - retry failed" + ) + end - # Retry the original request (flag stays true to prevent loop) - send_request(original_body, request_id) + def check_oauth_provider_configured! + return if @oauth_provider - # Only reset flag after successful retry - @auth_retry_attempted = false - return - end - rescue Errors::AuthenticationRequiredError => e - @auth_retry_attempted = false - raise e - rescue StandardError => e + raise Errors::AuthenticationRequiredError.new( + message: "OAuth authentication required (401 Unauthorized) but no OAuth provider configured" + ) + end + + def attempt_authentication_retry(www_authenticate, resource_metadata_url, original_body, request_id) + @auth_retry_attempted = true + + success = @oauth_provider.handle_authentication_challenge( + www_authenticate: www_authenticate, + resource_metadata_url: resource_metadata_url, + requested_scope: nil + ) + + if success + RubyLLM::MCP.logger.info("Authentication challenge handled successfully, retrying request") + send_request(original_body, request_id) @auth_retry_attempted = false - RubyLLM::MCP.logger.error("Authentication challenge handling failed: #{e.message}") - raise Errors::AuthenticationRequiredError.new( - message: "OAuth authentication failed: #{e.message}" - ) + return end @auth_retry_attempted = false raise Errors::AuthenticationRequiredError.new( message: "OAuth authentication required (401 Unauthorized)" ) + rescue Errors::AuthenticationRequiredError => e + @auth_retry_attempted = false + raise e + rescue StandardError => e + @auth_retry_attempted = false + RubyLLM::MCP.logger.error("Authentication challenge handling failed: #{e.message}") + raise Errors::AuthenticationRequiredError.new( + message: "OAuth authentication failed: #{e.message}" + ) end def start_sse_listener diff --git a/lib/ruby_llm/mcp/native/transports/streamable_http.rb b/lib/ruby_llm/mcp/native/transports/streamable_http.rb index 03f38b3..f4f0b66 100644 --- a/lib/ruby_llm/mcp/native/transports/streamable_http.rb +++ b/lib/ruby_llm/mcp/native/transports/streamable_http.rb @@ -73,24 +73,24 @@ def initialize( # rubocop:disable Metrics/ParameterLists @protocol_version = nil @session_id = session_id - @resource_metadata_url = nil - @client_id = SecureRandom.uuid - - @reconnection_options = ReconnectionOptions.new(**reconnection) - @oauth_provider = oauth_provider - @rate_limiter = Support::RateLimiter.new(**rate_limit) if rate_limit - - @id_counter = 0 - @id_mutex = Mutex.new - @pending_requests = {} - @pending_mutex = Mutex.new - @running = true - @abort_controller = nil - @sse_thread = nil - @sse_mutex = Mutex.new - - # Track if we've attempted auth flow to prevent infinite loops - @auth_retry_attempted = false + @resource_metadata_url = nil + @client_id = SecureRandom.uuid + + @reconnection_options = ReconnectionOptions.new(**reconnection) + @oauth_provider = oauth_provider + @rate_limiter = Support::RateLimiter.new(**rate_limit) if rate_limit + + @id_counter = 0 + @id_mutex = Mutex.new + @pending_requests = {} + @pending_mutex = Mutex.new + @running = true + @abort_controller = nil + @sse_thread = nil + @sse_mutex = Mutex.new + + # Track if we've attempted auth flow to prevent infinite loops + @auth_retry_attempted = false # Thread-safe collection of all HTTPX clients @clients = [] @@ -489,67 +489,64 @@ def extract_resource_metadata_url(response) end def handle_authentication_challenge(response, request_id, original_message) - # If we've already attempted auth retry, don't try again (prevent infinite loop) - if @auth_retry_attempted - RubyLLM::MCP.logger.warn("Authentication retry already attempted, raising error") - @auth_retry_attempted = false # Reset for next request - raise Errors::AuthenticationRequiredError.new( - message: "OAuth authentication required (401 Unauthorized) - retry failed" - ) - end - - # No OAuth provider configured - can't handle challenge - unless @oauth_provider - raise Errors::AuthenticationRequiredError.new( - message: "OAuth authentication required (401 Unauthorized) but no OAuth provider configured" - ) - end + check_retry_guard! + check_oauth_provider_configured! RubyLLM::MCP.logger.info("Received 401 Unauthorized, attempting automatic authentication") - # Extract challenge information from response www_authenticate = response.headers["www-authenticate"] resource_metadata_url = extract_resource_metadata_url(response) - begin - # Set flag to prevent infinite retry loop - @auth_retry_attempted = true - - # Ask OAuth provider to handle the challenge - success = @oauth_provider.handle_authentication_challenge( - www_authenticate: www_authenticate, - resource_metadata_url: resource_metadata_url&.to_s, - requested_scope: nil - ) + attempt_authentication_retry(www_authenticate, resource_metadata_url, request_id, original_message) + end - if success - RubyLLM::MCP.logger.info("Authentication challenge handled successfully, retrying request") + def check_retry_guard! + return unless @auth_retry_attempted - # Retry the original request with new auth (flag stays true to prevent loop) - result = send_http_request(original_message, request_id, is_initialization: false) + RubyLLM::MCP.logger.warn("Authentication retry already attempted, raising error") + @auth_retry_attempted = false + raise Errors::AuthenticationRequiredError.new( + message: "OAuth authentication required (401 Unauthorized) - retry failed" + ) + end - # Only reset flag after successful retry - @auth_retry_attempted = false - return result - end - rescue Errors::AuthenticationRequiredError => e - # Reset flag and re-raise - @auth_retry_attempted = false - raise e - rescue StandardError => e - # Reset flag and wrap error + def check_oauth_provider_configured! + return if @oauth_provider + + raise Errors::AuthenticationRequiredError.new( + message: "OAuth authentication required (401 Unauthorized) but no OAuth provider configured" + ) + end + + def attempt_authentication_retry(www_authenticate, resource_metadata_url, request_id, original_message) + @auth_retry_attempted = true + + success = @oauth_provider.handle_authentication_challenge( + www_authenticate: www_authenticate, + resource_metadata_url: resource_metadata_url&.to_s, + requested_scope: nil + ) + + if success + RubyLLM::MCP.logger.info("Authentication challenge handled successfully, retrying request") + result = send_http_request(original_message, request_id, is_initialization: false) @auth_retry_attempted = false - RubyLLM::MCP.logger.error("Authentication challenge handling failed: #{e.message}") - raise Errors::AuthenticationRequiredError.new( - message: "OAuth authentication failed: #{e.message}" - ) + return result end - # If we get here, authentication didn't succeed @auth_retry_attempted = false raise Errors::AuthenticationRequiredError.new( message: "OAuth authentication required (401 Unauthorized)" ) + rescue Errors::AuthenticationRequiredError => e + @auth_retry_attempted = false + raise e + rescue StandardError => e + @auth_retry_attempted = false + RubyLLM::MCP.logger.error("Authentication challenge handling failed: #{e.message}") + raise Errors::AuthenticationRequiredError.new( + message: "OAuth authentication failed: #{e.message}" + ) end def start_sse_stream(options = StartSSEOptions.new) diff --git a/spec/ruby_llm/mcp/auth/browser_oauth_provider_spec.rb b/spec/ruby_llm/mcp/auth/browser_oauth_provider_spec.rb index 51c970a..090fd7d 100644 --- a/spec/ruby_llm/mcp/auth/browser_oauth_provider_spec.rb +++ b/spec/ruby_llm/mcp/auth/browser_oauth_provider_spec.rb @@ -1238,12 +1238,10 @@ before do allow(oauth_provider).to receive(:handle_authentication_challenge) .and_raise(RubyLLM::MCP::Errors::AuthenticationRequiredError.new(message: "Interactive auth required")) - allow(oauth_provider).to receive(:start_authorization_flow).and_return(auth_url) allow(TCPServer).to receive(:new).and_return(tcp_server) allow(tcp_server).to receive(:close) - allow(tcp_server).to receive(:closed?).and_return(false) allow(tcp_server).to receive(:wait_readable).and_return(true, false) - allow(tcp_server).to receive(:accept).and_return(client_socket) + allow(tcp_server).to receive_messages(closed?: false, accept: client_socket) allow(client_socket).to receive(:setsockopt) allow(client_socket).to receive(:gets).and_return( "GET /callback?code=test&state=test HTTP/1.1\r\n", @@ -1251,7 +1249,8 @@ ) allow(client_socket).to receive(:write) allow(client_socket).to receive(:close) - allow(oauth_provider).to receive(:complete_authorization_flow).and_return(token) + allow(oauth_provider).to receive_messages(start_authorization_flow: auth_url, + complete_authorization_flow: token) end it "falls back to browser-based authentication" do diff --git a/spec/ruby_llm/mcp/native/transport_spec.rb b/spec/ruby_llm/mcp/native/transport_spec.rb index 8d622ed..784d291 100644 --- a/spec/ruby_llm/mcp/native/transport_spec.rb +++ b/spec/ruby_llm/mcp/native/transport_spec.rb @@ -146,7 +146,8 @@ describe "#build_transport" do context "with valid transport type" do it "builds SSE transport" do - transport = described_class.new(:sse, coordinator, config: config) + sse_config = { url: "http://localhost:3000/sse", timeout: 30 } + transport = described_class.new(:sse, coordinator, config: sse_config) mock_sse = instance_double(RubyLLM::MCP::Native::Transports::SSE) allow(RubyLLM::MCP::Native::Transports::SSE).to receive(:new).and_return(mock_sse) @@ -154,8 +155,10 @@ expect(protocol).to eq(mock_sse) expect(RubyLLM::MCP::Native::Transports::SSE).to have_received(:new).with( + url: "http://localhost:3000/sse", coordinator: coordinator, - **config + request_timeout: RubyLLM::MCP.config.request_timeout, + options: { timeout: 30 } ) end diff --git a/spec/ruby_llm/mcp/native/transports/sse_spec.rb b/spec/ruby_llm/mcp/native/transports/sse_spec.rb index 3f2be12..02e5e36 100644 --- a/spec/ruby_llm/mcp/native/transports/sse_spec.rb +++ b/spec/ruby_llm/mcp/native/transports/sse_spec.rb @@ -235,7 +235,7 @@ def client url: server_url, coordinator: coordinator, request_timeout: 5000, - oauth_provider: oauth_provider + options: { oauth_provider: oauth_provider } ) end @@ -265,11 +265,10 @@ def client let(:mock_response) { instance_double(HTTPX::Response) } before do - allow(mock_response).to receive(:headers).and_return({ - "www-authenticate" => 'Bearer scope="mcp:read"', - "mcp-resource-metadata-url" => "https://example.com/meta" - }) - allow(mock_response).to receive(:status).and_return(401) + allow(mock_response).to receive_messages(headers: { + "www-authenticate" => 'Bearer scope="mcp:read"', + "mcp-resource-metadata-url" => "https://example.com/meta" + }, status: 401) end it "handles 401 during message POST" do @@ -296,12 +295,9 @@ def client storage.set_token(server_url, new_token) allow(oauth_provider).to receive(:handle_authentication_challenge).and_return(true) - allow(transport_with_oauth).to receive(:send_request).and_call_original - # Mock the retry to succeed - allow(RubyLLM::MCP::Native::Transports::Support::HTTPClient).to receive(:connection).and_return( - double(with: double(post: double(status: 200, headers: {}))) - ) + # Mock the retry to succeed by stubbing send_request + allow(transport_with_oauth).to receive(:send_request).and_return(nil) expect do transport_with_oauth.send(:handle_authentication_challenge, mock_response, { "method" => "test" }, 1)