diff --git a/README.md b/README.md index 9775e8fa7..33497b0f8 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,7 @@ The service includes comprehensive user data collection capabilities for various * [1. Static Tokens from Files (Recommended for Service Credentials)](#1-static-tokens-from-files-recommended-for-service-credentials) * [2. Kubernetes Service Account Tokens (For K8s Deployments)](#2-kubernetes-service-account-tokens-for-k8s-deployments) * [3. Client-Provided Tokens (For Per-User Authentication)](#3-client-provided-tokens-for-per-user-authentication) + * [4. OAuth (For MCP Servers Requiring OAuth)](#4-oauth-for-mcp-servers-requiring-oauth) * [Client-Authenticated MCP Servers Discovery](#client-authenticated-mcp-servers-discovery) * [Combining Authentication Methods](#combining-authentication-methods) * [Authentication Method Comparison](#authentication-method-comparison) @@ -355,7 +356,7 @@ In addition to the basic configuration above, you can configure authentication h #### Configuring MCP Server Authentication -Lightspeed Core Stack supports three methods for authenticating with MCP servers, each suited for different use cases: +Lightspeed Core Stack supports four methods for authenticating with MCP servers, each suited for different use cases: ##### 1. Static Tokens from Files (Recommended for Service Credentials) @@ -392,7 +393,7 @@ mcp_servers: Authorization: "kubernetes" # Uses user's k8s token from request auth ``` -**Note:** Kubernetes token-based MCP authorization only works when Lightspeed Core Stack is configured with Kubernetes authentication (`authentication.k8s`). For any other authentication types, MCP servers configured with `Authorization: "kubernetes"` are removed from the available MCP servers list. +**Note:** Kubernetes token-based MCP authorization only works when Lightspeed Core Stack is configured with Kubernetes authentication (`authentication.module` is `k8s`) or `noop-with-token`. For any other authentication types, MCP servers configured with `Authorization: "kubernetes"` are removed from the available MCP servers list. ##### 3. Client-Provided Tokens (For Per-User Authentication) @@ -420,6 +421,20 @@ curl -X POST "http://localhost:8080/v1/query" \ **Structure**: `MCP-HEADERS: {"": {"": "", ...}, ...}` +##### 4. OAuth (For MCP Servers Requiring OAuth) + +Use the special `"oauth"` keyword when the MCP server requires OAuth and the client will supply a token (e.g. via `MCP-HEADERS` after obtaining it from an OAuth flow): + +```yaml +mcp_servers: + - name: "oauth-protected-service" + url: "https://mcp.example.com" + authorization_headers: + Authorization: "oauth" # Token provided via MCP-HEADERS (from OAuth flow) +``` + +When no token is provided for an OAuth-configured server, the service may respond with **401 Unauthorized** and a **`WWW-Authenticate`** header (probed from the MCP server). Clients can use this to drive an OAuth flow and then retry with the token in `MCP-HEADERS`. + ##### Client-Authenticated MCP Servers Discovery To help clients determine which MCP servers require client-provided tokens, use the **MCP Client Auth Options** endpoint: @@ -481,6 +496,7 @@ mcp_servers: | **Static File** | Service tokens, API keys | File path in config | Global (all users) | `"/var/secrets/token"` | | **Kubernetes** | K8s service accounts | `"kubernetes"` keyword | Per-user (from auth) | `"kubernetes"` | | **Client** | User-specific tokens | `"client"` keyword + HTTP header | Per-request | `"client"` | +| **OAuth** | OAuth-protected MCP servers | `"oauth"` keyword + HTTP header | Per-request (from OAuth flow) | `"oauth"` | ##### Important: Automatic Server Skipping @@ -489,6 +505,7 @@ mcp_servers: **Examples:** - A server with `Authorization: "kubernetes"` will be skipped if the user's request doesn't include a Kubernetes token - A server with `Authorization: "client"` will be skipped if no `MCP-HEADERS` are provided in the request +- A server with `Authorization: "oauth"` and no token in `MCP-HEADERS` may cause the API to return **401 Unauthorized** with a **`WWW-Authenticate`** header (so the client can perform OAuth and retry) - A server with multiple headers will be skipped if **any** required header cannot be resolved Skipped servers are logged as warnings. Check Lightspeed Core logs to see which servers were skipped and why. diff --git a/docs/config.md b/docs/config.md index 7aaba0e39..569d4387d 100644 --- a/docs/config.md +++ b/docs/config.md @@ -371,7 +371,7 @@ Useful resources: | name | string | MCP server name that must be unique | | provider_id | string | MCP provider identification | | url | string | URL of the MCP server | -| authorization_headers | object | Headers to send to the MCP server. The map contains the header name and the path to a file containing the header value (secret). There are 2 special cases: 1. Usage of the kubernetes token in the header. To specify this use a string 'kubernetes' instead of the file path. 2. Usage of the client provided token in the header. To specify this use a string 'client' instead of the file path. | +| authorization_headers | object | Headers to send to the MCP server. The map contains the header name and the path to a file containing the header value (secret). There are 3 special cases: 1. Usage of the kubernetes token in the header — use the string 'kubernetes' instead of the file path. 2. Usage of the client provided token in the header — use the string 'client' instead of the file path. 3. Usage of OAuth token (resolved at request time or 401 with WWW-Authenticate) — use the string 'oauth' instead of the file path. | | timeout | integer | Timeout in seconds for requests to the MCP server. If not specified, the default timeout from Llama Stack will be used. Note: This field is reserved for future use when Llama Stack adds timeout support. | diff --git a/src/app/endpoints/tools.py b/src/app/endpoints/tools.py index 074a30d85..0c51b2cc3 100644 --- a/src/app/endpoints/tools.py +++ b/src/app/endpoints/tools.py @@ -2,9 +2,10 @@ import logging from typing import Annotated, Any +import requests from fastapi import APIRouter, Depends, HTTPException, Request -from llama_stack_client import APIConnectionError, BadRequestError +from llama_stack_client import APIConnectionError, BadRequestError, AuthenticationError from authentication import get_auth_dependency from authentication.interface import AuthTuple @@ -89,6 +90,18 @@ async def tools_endpoint_handler( # pylint: disable=too-many-locals except BadRequestError: logger.error("Toolgroup %s is not found", toolgroup.identifier) continue + except AuthenticationError as e: + logger.error("Authentication error: %s", e) + if toolgroup.mcp_endpoint: + resp = requests.get(toolgroup.mcp_endpoint.uri, timeout=10) + cause = f"MCP server at {toolgroup.mcp_endpoint.uri} requires OAuth" + error_response = UnauthorizedResponse(cause=cause) + raise HTTPException( + **error_response.model_dump(), + headers={"WWW-Authenticate": resp.headers["WWW-Authenticate"]}, + ) from e + error_response = UnauthorizedResponse(cause=str(e)) + raise HTTPException(**error_response.model_dump()) from e except APIConnectionError as e: logger.error("Unable to connect to Llama Stack: %s", e) response = ServiceUnavailableResponse( diff --git a/src/constants.py b/src/constants.py index 6b43a2ec2..a04abf597 100644 --- a/src/constants.py +++ b/src/constants.py @@ -125,6 +125,7 @@ # MCP authorization header special values MCP_AUTH_KUBERNETES = "kubernetes" MCP_AUTH_CLIENT = "client" +MCP_AUTH_OAUTH = "oauth" # default RAG tool value DEFAULT_RAG_TOOL = "file_search" diff --git a/src/models/config.py b/src/models/config.py index 2291bb402..07a570759 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -505,11 +505,13 @@ class ModelContextProtocolServer(ConfigurationBase): "Headers to send to the MCP server. " "The map contains the header name and the path to a file containing " "the header value (secret). " - "There are 2 special cases: " + "There are 3 special cases: " "1. Usage of the kubernetes token in the header. " "To specify this use a string 'kubernetes' instead of the file path. " "2. Usage of the client provided token in the header. " "To specify this use a string 'client' instead of the file path." + "3. Usage of the oauth token in the header. " + "To specify this use a string 'oauth' instead of the file path. " ), ) diff --git a/src/utils/mcp_auth_headers.py b/src/utils/mcp_auth_headers.py index 5236df703..31f10a9bd 100644 --- a/src/utils/mcp_auth_headers.py +++ b/src/utils/mcp_auth_headers.py @@ -16,8 +16,10 @@ def resolve_authorization_headers( Parameters: authorization_headers: Map of header names to secret locations or special keywords. - - If value is "kubernetes": leave is unchanged. We substitute it during request. - - If value is "client": leave it unchanged. . We substitute it during request. + - If value is "kubernetes": leave unchanged. We substitute it during request. + - If value is "client": leave unchanged. We substitute it during request. + - If value is "oauth": leave unchanged; if no token is provided, a 401 with + WWW-Authenticate may be forwarded from the MCP server. - Otherwise: Treat as file path and read the secret from that file Returns: @@ -55,6 +57,14 @@ def resolve_authorization_headers( "Header %s will use client-provided token (resolved at request time)", header_name, ) + elif value == constants.MCP_AUTH_OAUTH: + # Special case: Keep oauth keyword; if no token provided, probe endpoint + # and forward 401 WWW-Authenticate for client-driven OAuth flow + resolved[header_name] = constants.MCP_AUTH_OAUTH + logger.debug( + "Header %s will use OAuth token (resolved at request time or 401)", + header_name, + ) else: # Regular case: Read secret from file path secret_path = Path(value).expanduser() diff --git a/src/utils/responses.py b/src/utils/responses.py index 4055eea6d..5a0faa4c9 100644 --- a/src/utils/responses.py +++ b/src/utils/responses.py @@ -4,6 +4,7 @@ import logging from typing import Any, Optional, cast +import requests from fastapi import HTTPException from llama_stack_api.openai_responses import ( OpenAIResponseObject, @@ -28,6 +29,7 @@ from models.responses import ( InternalServerErrorResponse, ServiceUnavailableResponse, + UnauthorizedResponse, ) from utils.prompts import get_system_prompt, get_topic_summary_system_prompt from utils.query import ( @@ -313,7 +315,7 @@ def get_rag_tools(vector_store_ids: list[str]) -> Optional[list[dict[str, Any]]] ] -def get_mcp_tools( +def get_mcp_tools( # pylint: disable=too-many-return-statements mcp_servers: list[ModelContextProtocolServer], token: str | None = None, mcp_headers: Optional[McpHeaders] = None, @@ -327,6 +329,10 @@ def get_mcp_tools( Returns: List of MCP tool definitions with server details and optional auth headers + + Raises: + HTTPException: 401 with WWW-Authenticate header when an MCP server uses OAuth, + no headers are passed, and the server responds with 401 and WWW-Authenticate. """ def _get_token_value(original: str, header: str) -> str | None: @@ -345,6 +351,14 @@ def _get_token_value(original: str, header: str) -> str | None: if c_headers is None: return None return c_headers.get(header, None) + case constants.MCP_AUTH_OAUTH: + # use oauth token + if mcp_headers is None: + return None + c_headers = mcp_headers.get(mcp_server.name, None) + if c_headers is None: + return None + return c_headers.get(header, None) case _: # use provided return original @@ -372,6 +386,23 @@ def _get_token_value(original: str, header: str) -> str | None: if mcp_server.authorization_headers and len(headers) != len( mcp_server.authorization_headers ): + # If OAuth was required and no headers passed, probe endpoint and forward + # 401 with WWW-Authenticate so the client can perform OAuth + uses_oauth = ( + constants.MCP_AUTH_OAUTH + in mcp_server.resolved_authorization_headers.values() + ) + if uses_oauth and ( + mcp_headers is None or not mcp_headers.get(mcp_server.name) + ): + resp = requests.get(mcp_server.url, timeout=10) + error_response = UnauthorizedResponse( + cause=f"MCP server at {mcp_server.url} requires OAuth authentication", + ) + raise HTTPException( + **error_response.model_dump(), + headers={"WWW-Authenticate": resp.headers["WWW-Authenticate"]}, + ) logger.warning( "Skipping MCP server %s: required %d auth headers but only resolved %d", mcp_server.name, diff --git a/tests/unit/app/endpoints/test_tools.py b/tests/unit/app/endpoints/test_tools.py index 9ec8db3df..d0dbf3124 100644 --- a/tests/unit/app/endpoints/test_tools.py +++ b/tests/unit/app/endpoints/test_tools.py @@ -4,7 +4,7 @@ import pytest from fastapi import HTTPException -from llama_stack_client import APIConnectionError, BadRequestError +from llama_stack_client import APIConnectionError, AuthenticationError, BadRequestError from pytest_mock import MockerFixture, MockType # Import the function directly to bypass decorators @@ -628,3 +628,87 @@ async def test_tools_endpoint_general_exception( # Call the endpointt and expect the exception to propagate (not caught) with pytest.raises(Exception, match="Unexpected error"): await tools.tools_endpoint_handler.__wrapped__(mock_request, mock_auth) # type: ignore + + +@pytest.mark.asyncio +async def test_tools_endpoint_authentication_error_with_mcp_endpoint( + mocker: MockerFixture, + mock_configuration: Configuration, # pylint: disable=redefined-outer-name +) -> None: + """Test tools endpoint raises 401 with WWW-Authenticate when MCP server requires OAuth.""" + app_config = AppConfig() + app_config._configuration = mock_configuration + mocker.patch("app.endpoints.tools.configuration", app_config) + mocker.patch("app.endpoints.tools.authorize", lambda _: lambda func: func) + + mock_client_holder = mocker.patch("app.endpoints.tools.AsyncLlamaStackClientHolder") + mock_client = mocker.AsyncMock() + mock_client_holder.return_value.get_client.return_value = mock_client + + mock_toolgroup = mocker.Mock() + mock_toolgroup.identifier = "mcp-tools" + mock_toolgroup.mcp_endpoint = mocker.Mock() + mock_toolgroup.mcp_endpoint.uri = "http://localhost:3000" + mock_client.toolgroups.list.return_value = [mock_toolgroup] + + auth_error = AuthenticationError( + message="MCP server requires OAuth", + response=mocker.Mock(request=None), + body=None, + ) + mock_client.tools.list.side_effect = auth_error + + mock_resp = mocker.Mock() + mock_resp.headers = {"WWW-Authenticate": 'Bearer error="invalid_token"'} + mocker.patch("app.endpoints.tools.requests.get", return_value=mock_resp) + + mock_request = mocker.Mock() + mock_auth = MOCK_AUTH + + with pytest.raises(HTTPException) as exc_info: + await tools.tools_endpoint_handler.__wrapped__(mock_request, mock_auth) + + assert exc_info.value.status_code == 401 + assert exc_info.value.headers is not None + assert ( + exc_info.value.headers.get("WWW-Authenticate") == 'Bearer error="invalid_token"' + ) + + +@pytest.mark.asyncio +async def test_tools_endpoint_authentication_error_without_mcp_endpoint( + mocker: MockerFixture, + mock_configuration: Configuration, # pylint: disable=redefined-outer-name +) -> None: + """Test tools endpoint raises 401 without WWW-Authenticate when no mcp_endpoint.""" + app_config = AppConfig() + app_config._configuration = mock_configuration + mocker.patch("app.endpoints.tools.configuration", app_config) + mocker.patch("app.endpoints.tools.authorize", lambda _: lambda func: func) + + mock_client_holder = mocker.patch("app.endpoints.tools.AsyncLlamaStackClientHolder") + mock_client = mocker.AsyncMock() + mock_client_holder.return_value.get_client.return_value = mock_client + + mock_toolgroup = mocker.Mock() + mock_toolgroup.identifier = "mcp-tools" + mock_toolgroup.mcp_endpoint = None + mock_client.toolgroups.list.return_value = [mock_toolgroup] + + auth_error = AuthenticationError( + message="Authentication failed", + response=mocker.Mock(request=None), + body=None, + ) + mock_client.tools.list.side_effect = auth_error + + mock_request = mocker.Mock() + mock_auth = MOCK_AUTH + + with pytest.raises(HTTPException) as exc_info: + await tools.tools_endpoint_handler.__wrapped__(mock_request, mock_auth) + + assert exc_info.value.status_code == 401 + detail = exc_info.value.detail + assert isinstance(detail, dict) + assert detail.get("cause") == "Authentication failed" diff --git a/tests/unit/models/config/test_model_context_protocol_server.py b/tests/unit/models/config/test_model_context_protocol_server.py index 3c8fb3853..8b0263434 100644 --- a/tests/unit/models/config/test_model_context_protocol_server.py +++ b/tests/unit/models/config/test_model_context_protocol_server.py @@ -196,6 +196,17 @@ def test_model_context_protocol_server_client_special_case() -> None: assert mcp.authorization_headers == {"Authorization": "client"} +def test_model_context_protocol_server_oauth_special_case() -> None: + """Test ModelContextProtocolServer with oauth special case.""" + mcp = ModelContextProtocolServer( + name="oauth-server", + url="http://localhost:8080", + authorization_headers={"Authorization": "oauth"}, + ) + assert mcp is not None + assert mcp.authorization_headers == {"Authorization": "oauth"} + + def test_configuration_mcp_servers_with_mixed_auth_headers(tmp_path: Path) -> None: """ Test Configuration with MCP servers having mixed authorization headers. diff --git a/tests/unit/utils/test_mcp_auth_headers.py b/tests/unit/utils/test_mcp_auth_headers.py index fe70fe19f..307288689 100644 --- a/tests/unit/utils/test_mcp_auth_headers.py +++ b/tests/unit/utils/test_mcp_auth_headers.py @@ -64,6 +64,15 @@ def test_resolve_authorization_headers_kubernetes_token() -> None: assert result == {"Authorization": "kubernetes"} +def test_resolve_authorization_headers_oauth_token() -> None: + """Test that oauth keyword is preserved.""" + headers = {"Authorization": "oauth"} + result = resolve_authorization_headers(headers) + + # Should keep "oauth" as-is for later substitution or 401 WWW-Authenticate flow + assert result == {"Authorization": "oauth"} + + def test_resolve_authorization_headers_multiple_headers(tmp_path: Path) -> None: """Test resolving multiple authorization headers.""" # Create multiple secret files diff --git a/tests/unit/utils/test_responses.py b/tests/unit/utils/test_responses.py index a17dca05c..c91c2c93c 100644 --- a/tests/unit/utils/test_responses.py +++ b/tests/unit/utils/test_responses.py @@ -510,6 +510,32 @@ def test_get_mcp_tools_includes_server_without_auth(self) -> None: assert tools[0]["server_label"] == "public-server" assert "headers" not in tools[0] + def test_get_mcp_tools_oauth_no_headers_raises_401_with_www_authenticate( + self, mocker: MockerFixture + ) -> None: + """Test get_mcp_tools raises 401 with WWW-Authenticate when OAuth required and no headers.""" + servers = [ + ModelContextProtocolServer( + name="oauth-server", + url="http://localhost:3000", + authorization_headers={"Authorization": "oauth"}, + ), + ] + + mock_resp = mocker.Mock() + mock_resp.headers = {"WWW-Authenticate": 'Bearer error="invalid_token"'} + mocker.patch("utils.responses.requests.get", return_value=mock_resp) + + with pytest.raises(HTTPException) as exc_info: + get_mcp_tools(servers, token=None, mcp_headers=None) + + assert exc_info.value.status_code == 401 + assert exc_info.value.headers is not None + assert ( + exc_info.value.headers.get("WWW-Authenticate") + == 'Bearer error="invalid_token"' + ) + class TestGetTopicSummary: """Tests for get_topic_summary function."""