diff --git a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/HttpMcpProxy.java b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/HttpMcpProxy.java index 5f065ab8c..5c19500b8 100644 --- a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/HttpMcpProxy.java +++ b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/HttpMcpProxy.java @@ -14,6 +14,7 @@ import software.amazon.smithy.java.client.http.HttpContext; import software.amazon.smithy.java.client.http.JavaHttpClientTransport; import software.amazon.smithy.java.context.Context; +import software.amazon.smithy.java.core.serde.document.Document; import software.amazon.smithy.java.http.api.HttpRequest; import software.amazon.smithy.java.http.api.HttpResponse; import software.amazon.smithy.java.io.ByteBufferUtils; @@ -24,6 +25,7 @@ import software.amazon.smithy.java.mcp.model.JsonRpcErrorResponse; import software.amazon.smithy.java.mcp.model.JsonRpcRequest; import software.amazon.smithy.java.mcp.model.JsonRpcResponse; +import software.amazon.smithy.model.shapes.ShapeType; import software.amazon.smithy.utils.SmithyUnstableApi; @SmithyUnstableApi @@ -38,6 +40,7 @@ public final class HttpMcpProxy extends McpServerProxy { private final String name; private final Signer signer; private final Duration timeout; + private volatile String sessionId; private HttpMcpProxy(Builder builder) { this.transport = builder.transport != null ? builder.transport : new JavaHttpClientTransport(); @@ -106,12 +109,25 @@ public CompletableFuture rpc(JsonRpcRequest request) { String protocolVersionHeader = getProtocolVersion().identifier(); - HttpRequest httpRequest = HttpRequest.builder() + HttpRequest.Builder requestBuilder = HttpRequest.builder() .uri(endpoint) .method("POST") .withAddedHeader("Content-Type", "application/json") .withAddedHeader("Accept", "application/json, text/event-stream") - .withAddedHeader("MCP-Protocol-Version", protocolVersionHeader) + .withAddedHeader("MCP-Protocol-Version", protocolVersionHeader); + + // Include session ID if we have one + String currentSessionId = sessionId; + if (currentSessionId != null) { + requestBuilder.withAddedHeader("Mcp-Session-Id", currentSessionId); + LOG.debug("Including session ID in request: method={}, sessionId={}", + request.getMethod(), + currentSessionId); + } else { + LOG.debug("No session ID available for request: method={}", request.getMethod()); + } + + HttpRequest httpRequest = requestBuilder .body(DataStream.ofBytes(body, "application/json")) .build(); @@ -125,10 +141,32 @@ public CompletableFuture rpc(JsonRpcRequest request) { HttpResponse response = transport.send(context, httpRequest); LOG.trace("Received HTTP response with status: {}", response.statusCode()); + // Extract and store session ID from response only during initialize + if ("initialize".equals(request.getMethod())) { + String responseSessionId = response.headers().firstValue("Mcp-Session-Id"); + if (responseSessionId != null) { + sessionId = responseSessionId; + LOG.debug("Stored session ID from initialize response: {}", responseSessionId); + } + } + + // "When a client receives HTTP 404 in response to a request containing an Mcp-Session-Id, + // it MUST start a new session by sending a new InitializeRequest without a session ID attached." + if (response.statusCode() == 404 && sessionId != null) { + LOG.debug("Received 404 with active session ID. Clearing session to force restart."); + sessionId = null; + } + if (response.statusCode() < 200 || response.statusCode() >= 300) { return CompletableFuture.completedFuture(handleErrorResponse(response)); } + // Check if response is SSE + String contentType = response.body().contentType(); + if ("text/event-stream".equals(contentType)) { + return CompletableFuture.completedFuture(parseSseResponse(response, request)); + } + return CompletableFuture.completedFuture(JsonRpcResponse.builder() .deserialize(JSON_CODEC.createDeserializer(response.body().asByteBuffer())) .build()); @@ -137,6 +175,134 @@ public CompletableFuture rpc(JsonRpcRequest request) { } } + private JsonRpcResponse parseSseResponse(HttpResponse response, JsonRpcRequest request) { + try { + byte[] bodyBytes = ByteBufferUtils.getBytes(response.body().asByteBuffer()); + String sseContent = new String(bodyBytes, StandardCharsets.UTF_8); + + JsonRpcResponse finalResponse = null; + Iterable lines = sseContent.lines()::iterator; + StringBuilder dataBuffer = new StringBuilder(); + + for (String line : lines) { + if (line.startsWith("data:")) { + var value = line.substring(5); + dataBuffer.append(value.startsWith(" ") ? value.substring(1) : value); + } else if (line.trim().isEmpty() && !dataBuffer.isEmpty()) { + // End of an SSE event + String jsonData = dataBuffer.toString().trim(); + dataBuffer.setLength(0); + + if (jsonData.isEmpty()) { + continue; + } + + try { + // Parse JSON once into Document + Document jsonDocument = JSON_CODEC.createDeserializer(jsonData.getBytes(StandardCharsets.UTF_8)) + .readDocument(); + + // Check if it's a notification by checking for top-level "id" field + // Notifications have "method" but no "id", responses have "id" + if (isNotification(jsonDocument)) { + // This is a notification - convert Document to JsonRpcRequest and forward + JsonRpcRequest notification = jsonDocument.asShape(JsonRpcRequest.builder()); + LOG.debug("Received notification from SSE stream: method={}", notification.getMethod()); + notifyRequest(notification); + } else { + // This is a response - convert Document to JsonRpcResponse + finalResponse = jsonDocument.asShape(JsonRpcResponse.builder()); + } + } catch (Exception e) { + LOG.warn("Failed to parse SSE message: {}", jsonData, e); + } + } + } + + // Process any remaining data in buffer (in case stream doesn't end with empty line) + if (!dataBuffer.isEmpty()) { + String jsonData = dataBuffer.toString().trim(); + if (!jsonData.isEmpty()) { + try { + // Parse JSON once into Document + Document jsonDocument = JSON_CODEC.createDeserializer(jsonData.getBytes(StandardCharsets.UTF_8)) + .readDocument(); + + // Check if it's a notification by checking for top-level "id" field + // Notifications have "method" but no "id", responses have "id" + if (isNotification(jsonDocument)) { + JsonRpcRequest notification = JsonRpcRequest.builder() + .deserialize(jsonDocument.createDeserializer()) + .build(); + LOG.debug("Received notification from remaining SSE buffer: method={}", + notification.getMethod()); + notifyRequest(notification); + } else { + JsonRpcResponse message = JsonRpcResponse.builder() + .deserialize(jsonDocument.createDeserializer()) + .build(); + + if (message.getId() == null) { + notifyRequest(JsonRpcRequest.builder() + .jsonrpc("2.0") + .method("notifications/unknown") + .build()); + } else { + finalResponse = message; + } + } + } catch (Exception e) { + LOG.warn("Failed to parse remaining SSE message: {}", jsonData, e); + } + } + } + + if (finalResponse == null) { + return JsonRpcResponse.builder() + .jsonrpc("2.0") + .id(request.getId()) + .error(JsonRpcErrorResponse.builder() + .code(-32001) + .message("SSE parsing error: No final response found in stream") + .build()) + .build(); + } + + return finalResponse; + } catch (Exception e) { + LOG.error("Error parsing SSE response", e); + return JsonRpcResponse.builder() + .jsonrpc("2.0") + .id(request.getId()) + .error(JsonRpcErrorResponse.builder() + .code(-32001) + .message("SSE parsing error: " + e.getMessage()) + .build()) + .build(); + } + } + + /** + * Determines if a Document represents a notification (has "method" but no "id") + * rather than a response (has "id"). + * + * - Responses have an "id" field at the top level + * - Notifications have a "method" field but no "id" field at the top level + */ + private boolean isNotification(Document doc) { + try { + if (!doc.isType(ShapeType.STRUCTURE) && !doc.isType(ShapeType.MAP)) { + return false; + } + + // If it has a "method" field but no "id", it's a notification + return doc.getMember("id") == null && doc.getMember("method") != null; + } catch (Exception e) { + LOG.warn("Failed to determine if notification from Document", e); + return false; + } + } + private JsonRpcResponse handleErrorResponse(HttpResponse response) { long contentLength = response.body().contentLength(); String errorMessage = "HTTP " + response.statusCode(); diff --git a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServer.java b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServer.java index dfb4e4b5f..1393f27f5 100644 --- a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServer.java +++ b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServer.java @@ -128,8 +128,24 @@ private void writeResponse(JsonRpcResponse response) { } } + private void writeNotification(JsonRpcRequest notification) { + synchronized (os) { + try { + LOG.debug("Writing notification to stdout: method={}", notification.getMethod()); + os.write(CODEC.serializeToString(notification).getBytes(StandardCharsets.UTF_8)); + os.write('\n'); + os.flush(); + } catch (Exception e) { + LOG.error("Error encoding notification", e); + } + } + } + @Override public void start() { + // Set up notification writer for proxies + mcpService.setNotificationWriter(this::writeNotification); + // Initialize proxies mcpService.startProxies(); diff --git a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerProxy.java b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerProxy.java index 5997c6d93..1a9c10d90 100644 --- a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerProxy.java +++ b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpServerProxy.java @@ -16,6 +16,7 @@ import software.amazon.smithy.java.core.schema.SerializableStruct; import software.amazon.smithy.java.core.schema.ShapeBuilder; import software.amazon.smithy.java.core.serde.document.Document; +import software.amazon.smithy.java.logging.InternalLogger; import software.amazon.smithy.java.mcp.model.JsonRpcRequest; import software.amazon.smithy.java.mcp.model.JsonRpcResponse; import software.amazon.smithy.java.mcp.model.ListToolsResult; @@ -24,9 +25,11 @@ public abstract class McpServerProxy { + private static final InternalLogger LOG = InternalLogger.getLogger(McpServerProxy.class); private static final AtomicInteger ID_GENERATOR = new AtomicInteger(0); private final AtomicReference> notificationConsumer = new AtomicReference<>(); + private final AtomicReference> requestNotificationConsumer = new AtomicReference<>(); private final AtomicReference protocolVersion = new AtomicReference<>(ProtocolVersion.defaultVersion()); @@ -69,6 +72,7 @@ public List listPrompts() { public void initialize( Consumer notificationConsumer, + Consumer requestNotificationConsumer, JsonRpcRequest initializeRequest, ProtocolVersion protocolVersion ) { @@ -78,6 +82,7 @@ public void initialize( throw new RuntimeException("Error during initialization: " + result.getError().getMessage()); } this.notificationConsumer.set(notificationConsumer); + this.requestNotificationConsumer.set(requestNotificationConsumer); this.protocolVersion.set(protocolVersion); } @@ -118,5 +123,20 @@ protected void notify(JsonRpcResponse response) { } } + /** + * Forwards a notification request by converting it to a response format. + * Notifications have a method field but no id. + */ + protected void notifyRequest(JsonRpcRequest notification) { + var rnc = requestNotificationConsumer.get(); + if (rnc != null) { + LOG.debug("Forwarding notification to consumer: method={}", notification.getMethod()); + rnc.accept(notification); + } else { + LOG.warn("No request notification consumer set, dropping notification: method={}", + notification.getMethod()); + } + } + public abstract String name(); } diff --git a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java index 176d9d2a1..a184ed0fc 100644 --- a/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java +++ b/mcp/mcp-server/src/main/java/software/amazon/smithy/java/mcp/server/McpService.java @@ -94,6 +94,7 @@ public final class McpService { private final AtomicReference proxiesInitialized = new AtomicReference<>(false); private final McpMetricsObserver metricsObserver; private final SchemaIndex schemaIndex; + private Consumer notificationWriter; McpService( Map services, @@ -311,6 +312,40 @@ private JsonRpcResponse handleToolsCall( } } + /** + * Sets the notification writer for forwarding notifications from proxies. + */ + public void setNotificationWriter(Consumer notificationWriter) { + this.notificationWriter = notificationWriter; + } + + /** + * Creates a notification writer for a specific proxy that handles cache invalidation + * for only that proxy's tools. + */ + private Consumer createProxyNotificationWriter( + McpServerProxy proxy, + Consumer baseNotificationWriter + ) { + return notification -> { + // Check if this is a tools/list_changed notification + if ("notifications/tools/list_changed".equals(notification.getMethod())) { + LOG.debug("Received tools/list_changed notification from proxy: {}", proxy.name()); + // Remove only this proxy's tools + tools.entrySet().removeIf(entry -> entry.getValue().proxy() == proxy); + // Re-fetch tools from only this proxy + List proxyTools = proxy.listTools(); + for (var toolInfo : proxyTools) { + tools.put(toolInfo.getName(), new Tool(toolInfo, proxy.name(), proxy)); + } + } + // Forward the notification + if (baseNotificationWriter != null) { + baseNotificationWriter.accept(notification); + } + }; + } + /** * Starts proxies without initializing them. */ @@ -343,7 +378,8 @@ public void initializeProxies(Consumer responseWriter) { for (McpServerProxy proxy : proxies.values()) { if (initRequest != null) { - proxy.initialize(responseWriter, initRequest, protocolVersion); + var proxyNotificationWriter = createProxyNotificationWriter(proxy, notificationWriter); + proxy.initialize(responseWriter, proxyNotificationWriter, initRequest, protocolVersion); } List proxyTools = proxy.listTools(); @@ -382,11 +418,12 @@ public void addNewService(String id, Service service) { tools.putAll(createTools(Map.of(id, service))); } - /** - * Adds a new proxy and initializes it. - */ - public void addNewProxy(McpServerProxy mcpServerProxy, Consumer responseWriter) { + public void addNewProxy( + McpServerProxy mcpServerProxy, + Consumer responseWriter + ) { proxies.put(mcpServerProxy.name(), mcpServerProxy); + mcpServerProxy.start(); try { diff --git a/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/HttpMcpProxyTest.java b/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/HttpMcpProxyTest.java index 5f6892142..b242d553f 100644 --- a/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/HttpMcpProxyTest.java +++ b/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/HttpMcpProxyTest.java @@ -24,6 +24,7 @@ import software.amazon.smithy.java.json.JsonCodec; import software.amazon.smithy.java.mcp.model.JsonRpcRequest; import software.amazon.smithy.java.mcp.model.JsonRpcResponse; +import software.amazon.smithy.model.shapes.ShapeType; class HttpMcpProxyTest { private static final JsonCodec JSON_CODEC = JsonCodec.builder().build(); @@ -187,6 +188,279 @@ void testStartAndShutdown() { }); } + @Test + void testSseStreamingResponse() throws IOException { + // Set up SSE handler + mockServer.removeContext("/mcp"); + mockServer.createContext("/mcp", new SseStreamingHandler()); + + JsonRpcRequest request = JsonRpcRequest.builder() + .method("test/streaming") + .id(Document.of(1)) + .jsonrpc("2.0") + .build(); + + CompletableFuture future = proxy.rpc(request); + JsonRpcResponse response = future.join(); + + assertNotNull(response); + assertEquals("2.0", response.getJsonrpc()); + assertEquals(1, response.getId().asInteger()); + assertEquals("final result", response.getResult().asString()); + } + + @Test + void testSseStreamingWithNotifications() throws IOException { + // Track notifications + final JsonRpcRequest[] capturedNotification = {null}; + + // Set up SSE handler + mockServer.removeContext("/mcp"); + mockServer.createContext("/mcp", new SseStreamingWithNotificationsHandler()); + + // Initialize proxy with notification consumer + JsonRpcRequest initRequest = JsonRpcRequest.builder() + .method("initialize") + .id(Document.of(0)) + .jsonrpc("2.0") + .build(); + + proxy.initialize( + notification -> {}, // Old-style consumer (not used) + notification -> capturedNotification[0] = notification, // Request notification consumer + initRequest, + ProtocolVersion.defaultVersion()); + + JsonRpcRequest request = JsonRpcRequest.builder() + .method("test/streaming") + .id(Document.of(1)) + .jsonrpc("2.0") + .build(); + + CompletableFuture future = proxy.rpc(request); + JsonRpcResponse response = future.join(); + + // Verify final response + assertNotNull(response); + assertEquals("2.0", response.getJsonrpc()); + assertEquals(1, response.getId().asInteger()); + assertEquals("final result", response.getResult().asString()); + + // Verify notification was captured (notifications don't have an id field) + assertNotNull(capturedNotification[0]); + assertNull(capturedNotification[0].getId()); + } + + @Test + void testSseStreamingWithoutFinalResponse() throws IOException { + // Set up SSE handler that doesn't send a final response + mockServer.removeContext("/mcp"); + mockServer.createContext("/mcp", new SseStreamingNoFinalResponseHandler()); + + JsonRpcRequest request = JsonRpcRequest.builder() + .method("test/streaming") + .id(Document.of(1)) + .jsonrpc("2.0") + .build(); + + CompletableFuture future = proxy.rpc(request); + JsonRpcResponse response = future.join(); + + // Should return an error response + assertNotNull(response); + assertNotNull(response.getError()); + assertEquals(-32001, response.getError().getCode()); + assertTrue(response.getError().getMessage().contains("SSE parsing error")); + } + + @Test + void testSseStreamingMalformedJson() throws IOException { + // Set up SSE handler with malformed JSON + mockServer.removeContext("/mcp"); + mockServer.createContext("/mcp", new SseMalformedJsonHandler()); + + JsonRpcRequest request = JsonRpcRequest.builder() + .method("test/streaming") + .id(Document.of(1)) + .jsonrpc("2.0") + .build(); + + CompletableFuture future = proxy.rpc(request); + JsonRpcResponse response = future.join(); + + // Should return an error response + assertNotNull(response); + assertNotNull(response.getError()); + assertEquals(-32001, response.getError().getCode()); + } + + @Test + void testSseStreamingWithMethodInToolResponse() throws IOException { + // This tests the bug fix where tool responses containing "method" in their data + // were incorrectly classified as notifications + mockServer.removeContext("/mcp"); + mockServer.createContext("/mcp", new SseToolResponseWithMethodHandler()); + + JsonRpcRequest request = JsonRpcRequest.builder() + .method("tools/call") + .id(Document.of(1)) + .jsonrpc("2.0") + .build(); + + CompletableFuture future = proxy.rpc(request); + JsonRpcResponse response = future.join(); + + // Should correctly parse as a response, not a notification + assertNotNull(response); + assertEquals("2.0", response.getJsonrpc()); + assertEquals(1, response.getId().asInteger()); + assertNotNull(response.getResult()); + + // The result should contain the tool response with "method" in it + Document result = response.getResult(); + assertTrue(result.isType(ShapeType.STRUCTURE) || result.isType(ShapeType.MAP)); + Document content = result.asStringMap().get("content"); + assertNotNull(content); + assertTrue(content.asString().contains("method")); + } + + @Test + void testSessionIdHandling() throws IOException { + // Set up handler that returns and expects session ID + mockServer.removeContext("/mcp"); + mockServer.createContext("/mcp", new SessionIdHandler()); + + // First request - should be initialize to receive session ID + JsonRpcRequest request1 = JsonRpcRequest.builder() + .method("initialize") + .id(Document.of(1)) + .jsonrpc("2.0") + .params(Document.of(Map.of( + "protocolVersion", + Document.of("2024-11-05"), + "capabilities", + Document.of(Map.of()), + "clientInfo", + Document.of(Map.of( + "name", + Document.of("test-client"), + "version", + Document.of("1.0.0")))))) + .build(); + + CompletableFuture future1 = proxy.rpc(request1); + JsonRpcResponse response1 = future1.join(); + + assertNotNull(response1); + assertEquals("session-created", response1.getResult().asString()); + + // Second request - should include session ID + JsonRpcRequest request2 = JsonRpcRequest.builder() + .method("test/method") + .id(Document.of(2)) + .jsonrpc("2.0") + .build(); + + CompletableFuture future2 = proxy.rpc(request2); + JsonRpcResponse response2 = future2.join(); + + assertNotNull(response2); + assertEquals("session-valid", response2.getResult().asString()); + } + + private static class SseStreamingHandler implements HttpHandler { + @Override + public void handle(HttpExchange exchange) throws IOException { + String sseResponse = "data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":\"final result\"}\n\n"; + + exchange.getResponseHeaders().set("Content-Type", "text/event-stream"); + exchange.sendResponseHeaders(200, sseResponse.getBytes(StandardCharsets.UTF_8).length); + + try (OutputStream os = exchange.getResponseBody()) { + os.write(sseResponse.getBytes(StandardCharsets.UTF_8)); + } finally { + exchange.close(); + } + } + } + + private static class SseStreamingWithNotificationsHandler implements HttpHandler { + @Override + public void handle(HttpExchange exchange) throws IOException { + StringBuilder sseResponse = new StringBuilder(); + + // Send a notification first + sseResponse.append( + "data: {\"jsonrpc\":\"2.0\",\"method\":\"notifications/progress\",\"params\":{\"progress\":50}}\n\n"); + + // Then send the final response + sseResponse.append("data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":\"final result\"}\n\n"); + + exchange.getResponseHeaders().set("Content-Type", "text/event-stream"); + byte[] responseBytes = sseResponse.toString().getBytes(StandardCharsets.UTF_8); + exchange.sendResponseHeaders(200, responseBytes.length); + + try (OutputStream os = exchange.getResponseBody()) { + os.write(responseBytes); + } finally { + exchange.close(); + } + } + } + + private static class SseStreamingNoFinalResponseHandler implements HttpHandler { + @Override + public void handle(HttpExchange exchange) throws IOException { + // Only send notifications, no final response + String sseResponse = + "data: {\"jsonrpc\":\"2.0\",\"method\":\"notifications/progress\",\"params\":{\"progress\":100}}\n\n"; + + exchange.getResponseHeaders().set("Content-Type", "text/event-stream"); + exchange.sendResponseHeaders(200, sseResponse.getBytes(StandardCharsets.UTF_8).length); + + try (OutputStream os = exchange.getResponseBody()) { + os.write(sseResponse.getBytes(StandardCharsets.UTF_8)); + } finally { + exchange.close(); + } + } + } + + private static class SseMalformedJsonHandler implements HttpHandler { + @Override + public void handle(HttpExchange exchange) throws IOException { + String sseResponse = "data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":malformed\n\n"; + + exchange.getResponseHeaders().set("Content-Type", "text/event-stream"); + exchange.sendResponseHeaders(200, sseResponse.getBytes(StandardCharsets.UTF_8).length); + + try (OutputStream os = exchange.getResponseBody()) { + os.write(sseResponse.getBytes(StandardCharsets.UTF_8)); + } finally { + exchange.close(); + } + } + } + + private static class SseToolResponseWithMethodHandler implements HttpHandler { + @Override + public void handle(HttpExchange exchange) throws IOException { + // Simulate a tool response that contains "method" in its content + // This should NOT be classified as a notification because it has an "id" field + String sseResponse = + "data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"content\":\"The HTTP method used was POST\",\"isError\":false}}\n\n"; + + exchange.getResponseHeaders().set("Content-Type", "text/event-stream"); + exchange.sendResponseHeaders(200, sseResponse.getBytes(StandardCharsets.UTF_8).length); + + try (OutputStream os = exchange.getResponseBody()) { + os.write(sseResponse.getBytes(StandardCharsets.UTF_8)); + } finally { + exchange.close(); + } + } + } + private static class MockMcpHandler implements HttpHandler { @Override public void handle(HttpExchange exchange) throws IOException { @@ -224,4 +498,60 @@ public void handle(HttpExchange exchange) throws IOException { } } } + + private static class SessionIdHandler implements HttpHandler { + private static final String SESSION_ID = "test-session-123"; + + @Override + public void handle(HttpExchange exchange) throws IOException { + if (!"POST".equals(exchange.getRequestMethod())) { + exchange.sendResponseHeaders(405, 0); + exchange.close(); + return; + } + + String requestBody = new String(exchange.getRequestBody().readAllBytes(), StandardCharsets.UTF_8); + + try { + JsonRpcRequest request = JsonRpcRequest.builder() + .deserialize(JSON_CODEC.createDeserializer(requestBody.getBytes(StandardCharsets.UTF_8))) + .build(); + + // Check if session ID is present in request + String sessionIdHeader = exchange.getRequestHeaders().getFirst("Mcp-Session-Id"); + String resultMessage; + + if (sessionIdHeader == null) { + // First request - return session ID + resultMessage = "session-created"; + exchange.getResponseHeaders().set("Mcp-Session-Id", SESSION_ID); + } else if (SESSION_ID.equals(sessionIdHeader)) { + // Subsequent request with valid session ID + resultMessage = "session-valid"; + } else { + // Invalid session ID + resultMessage = "session-invalid"; + } + + JsonRpcResponse response = JsonRpcResponse.builder() + .jsonrpc("2.0") + .id(request.getId()) + .result(Document.of(resultMessage)) + .build(); + + String responseBody = JSON_CODEC.serializeToString(response); + + exchange.getResponseHeaders().set("Content-Type", "application/json"); + exchange.sendResponseHeaders(200, responseBody.getBytes(StandardCharsets.UTF_8).length); + + try (OutputStream os = exchange.getResponseBody()) { + os.write(responseBody.getBytes(StandardCharsets.UTF_8)); + } + } catch (Exception e) { + exchange.sendResponseHeaders(500, 0); + } finally { + exchange.close(); + } + } + } } diff --git a/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/McpServerTest.java b/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/McpServerTest.java index aab97592a..9f55ad9f3 100644 --- a/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/McpServerTest.java +++ b/mcp/mcp-server/src/test/java/software/amazon/smithy/java/mcp/server/McpServerTest.java @@ -16,9 +16,12 @@ import java.math.BigInteger; import java.nio.charset.StandardCharsets; import java.time.Duration; +import java.util.ArrayList; import java.util.Base64; import java.util.List; import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -1405,4 +1408,160 @@ void testUnionWithOneOfTraitSchemaAlsoGeneratesOneOf() { var oneOf = shapeProperty.get("oneOf").asList(); assertEquals(2, oneOf.size(), "Document with @oneOf should have 2 oneOf variants"); } + + @Test + void testToolsListChangedNotificationInvalidatesCache() { + var callCounter = new AtomicInteger(0); + var mockProxy = new CacheTestProxy(callCounter); + + var service = McpService.builder() + .name("test") + .proxyList(List.of(mockProxy)) + .build(); + + var notifications = new ArrayList(); + service.setNotificationWriter(notifications::add); + + // Initialize to set up proxies + var initRequest = JsonRpcRequest.builder() + .method("initialize") + .id(Document.of(1)) + .params(Document.of(Map.of("protocolVersion", Document.of("2024-11-05")))) + .jsonrpc("2.0") + .build(); + service.handleRequest(initRequest, r -> {}, ProtocolVersion.defaultVersion()); + + // First tools/list - fetches from proxy + var toolsRequest = JsonRpcRequest.builder() + .method("tools/list") + .id(Document.of(2)) + .params(Document.of(Map.of())) + .jsonrpc("2.0") + .build(); + service.handleRequest(toolsRequest, r -> {}, ProtocolVersion.defaultVersion()); + assertEquals(1, callCounter.get(), "First call should fetch from proxy"); + + // Second tools/list - uses cache + service.handleRequest(toolsRequest, r -> {}, ProtocolVersion.defaultVersion()); + assertEquals(1, callCounter.get(), "Second call should use cache"); + + // Send tools/list_changed notification + var notification = JsonRpcRequest.builder() + .method("notifications/tools/list_changed") + .params(Document.of(Map.of())) + .jsonrpc("2.0") + .build(); + mockProxy.sendNotification(notification); + + // Verify notification was forwarded + assertEquals(1, notifications.size()); + assertEquals("notifications/tools/list_changed", notifications.get(0).getMethod()); + + // Third tools/list - should refresh from proxy + service.handleRequest(toolsRequest, r -> {}, ProtocolVersion.defaultVersion()); + assertEquals(2, callCounter.get(), "Third call should refresh after notification"); + + // Fourth tools/list - uses cache again (counter should NOT increment) + service.handleRequest(toolsRequest, r -> {}, ProtocolVersion.defaultVersion()); + assertEquals(2, callCounter.get(), "Fourth call should use cache (not increment to 3)"); + } + + @Test + void testOtherNotificationsDoNotInvalidateCache() { + var callCounter = new AtomicInteger(0); + var mockProxy = new CacheTestProxy(callCounter); + + var service = McpService.builder() + .name("test") + .proxyList(List.of(mockProxy)) + .build(); + + var notifications = new ArrayList(); + service.setNotificationWriter(notifications::add); + + // Initialize + var initRequest = JsonRpcRequest.builder() + .method("initialize") + .id(Document.of(1)) + .params(Document.of(Map.of("protocolVersion", Document.of("2024-11-05")))) + .jsonrpc("2.0") + .build(); + service.handleRequest(initRequest, r -> {}, ProtocolVersion.defaultVersion()); + + // First tools/list + var toolsRequest = JsonRpcRequest.builder() + .method("tools/list") + .id(Document.of(2)) + .params(Document.of(Map.of())) + .jsonrpc("2.0") + .build(); + service.handleRequest(toolsRequest, r -> {}, ProtocolVersion.defaultVersion()); + assertEquals(1, callCounter.get()); + + // Send different notification + var notification = JsonRpcRequest.builder() + .method("notifications/prompts/list_changed") + .params(Document.of(Map.of())) + .jsonrpc("2.0") + .build(); + mockProxy.sendNotification(notification); + + assertEquals(1, notifications.size()); + assertEquals("notifications/prompts/list_changed", notifications.get(0).getMethod()); + + // Second tools/list - should still use cache + service.handleRequest(toolsRequest, r -> {}, ProtocolVersion.defaultVersion()); + assertEquals(1, callCounter.get(), "Cache should not be invalidated by other notifications"); + } + + private static class CacheTestProxy extends McpServerProxy { + private final AtomicInteger callCounter; + + CacheTestProxy(AtomicInteger callCounter) { + this.callCounter = callCounter; + } + + @Override + public List listTools() { + callCounter.incrementAndGet(); + return List.of( + software.amazon.smithy.java.mcp.model.ToolInfo.builder() + .name("test-tool") + .description("Test") + .inputSchema(software.amazon.smithy.java.mcp.model.JsonObjectSchema.builder().build()) + .build()); + } + + @Override + public List listPrompts() { + return List.of(); + } + + @Override + CompletableFuture rpc(JsonRpcRequest request) { + return CompletableFuture.completedFuture( + JsonRpcResponse.builder() + .id(request.getId()) + .result(Document.of(Map.of())) + .jsonrpc("2.0") + .build()); + } + + @Override + void start() {} + + @Override + CompletableFuture shutdown() { + return CompletableFuture.completedFuture(null); + } + + @Override + public String name() { + return "cache-test"; + } + + void sendNotification(JsonRpcRequest notification) { + notifyRequest(notification); + } + } }