Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import software.amazon.smithy.java.client.http.HttpContext;
import software.amazon.smithy.java.client.http.JavaHttpClientTransport;
import software.amazon.smithy.java.context.Context;
import software.amazon.smithy.java.core.serde.document.Document;
import software.amazon.smithy.java.http.api.HttpRequest;
import software.amazon.smithy.java.http.api.HttpResponse;
import software.amazon.smithy.java.io.ByteBufferUtils;
Expand All @@ -24,6 +25,7 @@
import software.amazon.smithy.java.mcp.model.JsonRpcErrorResponse;
import software.amazon.smithy.java.mcp.model.JsonRpcRequest;
import software.amazon.smithy.java.mcp.model.JsonRpcResponse;
import software.amazon.smithy.model.shapes.ShapeType;
import software.amazon.smithy.utils.SmithyUnstableApi;

@SmithyUnstableApi
Expand All @@ -38,6 +40,7 @@ public final class HttpMcpProxy extends McpServerProxy {
private final String name;
private final Signer<HttpRequest, ?> signer;
private final Duration timeout;
private volatile String sessionId;

private HttpMcpProxy(Builder builder) {
this.transport = builder.transport != null ? builder.transport : new JavaHttpClientTransport();
Expand Down Expand Up @@ -106,12 +109,25 @@ public CompletableFuture<JsonRpcResponse> rpc(JsonRpcRequest request) {

String protocolVersionHeader = getProtocolVersion().identifier();

HttpRequest httpRequest = HttpRequest.builder()
HttpRequest.Builder requestBuilder = HttpRequest.builder()
.uri(endpoint)
.method("POST")
.withAddedHeader("Content-Type", "application/json")
.withAddedHeader("Accept", "application/json, text/event-stream")
.withAddedHeader("MCP-Protocol-Version", protocolVersionHeader)
.withAddedHeader("MCP-Protocol-Version", protocolVersionHeader);

// Include session ID if we have one
String currentSessionId = sessionId;
if (currentSessionId != null) {
requestBuilder.withAddedHeader("Mcp-Session-Id", currentSessionId);
LOG.debug("Including session ID in request: method={}, sessionId={}",
request.getMethod(),
currentSessionId);
} else {
LOG.debug("No session ID available for request: method={}", request.getMethod());
}

HttpRequest httpRequest = requestBuilder
.body(DataStream.ofBytes(body, "application/json"))
.build();

Expand All @@ -125,10 +141,32 @@ public CompletableFuture<JsonRpcResponse> rpc(JsonRpcRequest request) {
HttpResponse response = transport.send(context, httpRequest);
LOG.trace("Received HTTP response with status: {}", response.statusCode());

// Extract and store session ID from response only during initialize
if ("initialize".equals(request.getMethod())) {
String responseSessionId = response.headers().firstValue("Mcp-Session-Id");
if (responseSessionId != null) {
sessionId = responseSessionId;
LOG.debug("Stored session ID from initialize response: {}", responseSessionId);
}
}

// "When a client receives HTTP 404 in response to a request containing an Mcp-Session-Id,
// it MUST start a new session by sending a new InitializeRequest without a session ID attached."
if (response.statusCode() == 404 && sessionId != null) {
LOG.debug("Received 404 with active session ID. Clearing session to force restart.");
sessionId = null;
}

if (response.statusCode() < 200 || response.statusCode() >= 300) {
return CompletableFuture.completedFuture(handleErrorResponse(response));
}

// Check if response is SSE
String contentType = response.body().contentType();
if ("text/event-stream".equals(contentType)) {
return CompletableFuture.completedFuture(parseSseResponse(response, request));
}

return CompletableFuture.completedFuture(JsonRpcResponse.builder()
.deserialize(JSON_CODEC.createDeserializer(response.body().asByteBuffer()))
.build());
Expand All @@ -137,6 +175,134 @@ public CompletableFuture<JsonRpcResponse> rpc(JsonRpcRequest request) {
}
}

private JsonRpcResponse parseSseResponse(HttpResponse response, JsonRpcRequest request) {
try {
byte[] bodyBytes = ByteBufferUtils.getBytes(response.body().asByteBuffer());
String sseContent = new String(bodyBytes, StandardCharsets.UTF_8);

JsonRpcResponse finalResponse = null;
Iterable<String> lines = sseContent.lines()::iterator;
StringBuilder dataBuffer = new StringBuilder();

for (String line : lines) {
if (line.startsWith("data:")) {
var value = line.substring(5);
dataBuffer.append(value.startsWith(" ") ? value.substring(1) : value);
} else if (line.trim().isEmpty() && !dataBuffer.isEmpty()) {
// End of an SSE event
String jsonData = dataBuffer.toString().trim();
dataBuffer.setLength(0);

if (jsonData.isEmpty()) {
continue;
}

try {
// Parse JSON once into Document
Document jsonDocument = JSON_CODEC.createDeserializer(jsonData.getBytes(StandardCharsets.UTF_8))
.readDocument();

// Check if it's a notification by checking for top-level "id" field
// Notifications have "method" but no "id", responses have "id"
if (isNotification(jsonDocument)) {
// This is a notification - convert Document to JsonRpcRequest and forward
JsonRpcRequest notification = jsonDocument.asShape(JsonRpcRequest.builder());
LOG.debug("Received notification from SSE stream: method={}", notification.getMethod());
notifyRequest(notification);
} else {
// This is a response - convert Document to JsonRpcResponse
finalResponse = jsonDocument.asShape(JsonRpcResponse.builder());
}
} catch (Exception e) {
LOG.warn("Failed to parse SSE message: {}", jsonData, e);
}
}
}

// Process any remaining data in buffer (in case stream doesn't end with empty line)
if (!dataBuffer.isEmpty()) {
String jsonData = dataBuffer.toString().trim();
if (!jsonData.isEmpty()) {
try {
// Parse JSON once into Document
Document jsonDocument = JSON_CODEC.createDeserializer(jsonData.getBytes(StandardCharsets.UTF_8))
.readDocument();

// Check if it's a notification by checking for top-level "id" field
// Notifications have "method" but no "id", responses have "id"
if (isNotification(jsonDocument)) {
JsonRpcRequest notification = JsonRpcRequest.builder()
.deserialize(jsonDocument.createDeserializer())
.build();
LOG.debug("Received notification from remaining SSE buffer: method={}",
notification.getMethod());
notifyRequest(notification);
} else {
JsonRpcResponse message = JsonRpcResponse.builder()
.deserialize(jsonDocument.createDeserializer())
.build();

if (message.getId() == null) {
notifyRequest(JsonRpcRequest.builder()
.jsonrpc("2.0")
.method("notifications/unknown")
.build());
} else {
finalResponse = message;
}
}
} catch (Exception e) {
LOG.warn("Failed to parse remaining SSE message: {}", jsonData, e);
}
}
}

if (finalResponse == null) {
return JsonRpcResponse.builder()
.jsonrpc("2.0")
.id(request.getId())
.error(JsonRpcErrorResponse.builder()
.code(-32001)
.message("SSE parsing error: No final response found in stream")
.build())
.build();
}

return finalResponse;
} catch (Exception e) {
LOG.error("Error parsing SSE response", e);
return JsonRpcResponse.builder()
.jsonrpc("2.0")
.id(request.getId())
.error(JsonRpcErrorResponse.builder()
.code(-32001)
.message("SSE parsing error: " + e.getMessage())
.build())
.build();
}
}

/**
* Determines if a Document represents a notification (has "method" but no "id")
* rather than a response (has "id").
*
* - Responses have an "id" field at the top level
* - Notifications have a "method" field but no "id" field at the top level
*/
private boolean isNotification(Document doc) {
try {
if (!doc.isType(ShapeType.STRUCTURE) && !doc.isType(ShapeType.MAP)) {
return false;
}

// If it has a "method" field but no "id", it's a notification
return doc.getMember("id") == null && doc.getMember("method") != null;
} catch (Exception e) {
LOG.warn("Failed to determine if notification from Document", e);
return false;
}
}

private JsonRpcResponse handleErrorResponse(HttpResponse response) {
long contentLength = response.body().contentLength();
String errorMessage = "HTTP " + response.statusCode();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,24 @@ private void writeResponse(JsonRpcResponse response) {
}
}

private void writeNotification(JsonRpcRequest notification) {
synchronized (os) {
try {
LOG.debug("Writing notification to stdout: method={}", notification.getMethod());
os.write(CODEC.serializeToString(notification).getBytes(StandardCharsets.UTF_8));
os.write('\n');
os.flush();
} catch (Exception e) {
LOG.error("Error encoding notification", e);
}
}
}

@Override
public void start() {
// Set up notification writer for proxies
mcpService.setNotificationWriter(this::writeNotification);

// Initialize proxies
mcpService.startProxies();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import software.amazon.smithy.java.core.schema.SerializableStruct;
import software.amazon.smithy.java.core.schema.ShapeBuilder;
import software.amazon.smithy.java.core.serde.document.Document;
import software.amazon.smithy.java.logging.InternalLogger;
import software.amazon.smithy.java.mcp.model.JsonRpcRequest;
import software.amazon.smithy.java.mcp.model.JsonRpcResponse;
import software.amazon.smithy.java.mcp.model.ListToolsResult;
Expand All @@ -24,9 +25,11 @@

public abstract class McpServerProxy {

private static final InternalLogger LOG = InternalLogger.getLogger(McpServerProxy.class);
private static final AtomicInteger ID_GENERATOR = new AtomicInteger(0);

private final AtomicReference<Consumer<JsonRpcResponse>> notificationConsumer = new AtomicReference<>();
private final AtomicReference<Consumer<JsonRpcRequest>> requestNotificationConsumer = new AtomicReference<>();
private final AtomicReference<ProtocolVersion> protocolVersion =
new AtomicReference<>(ProtocolVersion.defaultVersion());

Expand Down Expand Up @@ -69,6 +72,7 @@ public List<PromptInfo> listPrompts() {

public void initialize(
Consumer<JsonRpcResponse> notificationConsumer,
Consumer<JsonRpcRequest> requestNotificationConsumer,
JsonRpcRequest initializeRequest,
ProtocolVersion protocolVersion
) {
Expand All @@ -78,6 +82,7 @@ public void initialize(
throw new RuntimeException("Error during initialization: " + result.getError().getMessage());
}
this.notificationConsumer.set(notificationConsumer);
this.requestNotificationConsumer.set(requestNotificationConsumer);
this.protocolVersion.set(protocolVersion);
}

Expand Down Expand Up @@ -118,5 +123,20 @@ protected void notify(JsonRpcResponse response) {
}
}

/**
* Forwards a notification request by converting it to a response format.
* Notifications have a method field but no id.
*/
protected void notifyRequest(JsonRpcRequest notification) {
var rnc = requestNotificationConsumer.get();
if (rnc != null) {
LOG.debug("Forwarding notification to consumer: method={}", notification.getMethod());
rnc.accept(notification);
} else {
LOG.warn("No request notification consumer set, dropping notification: method={}",
notification.getMethod());
}
}

public abstract String name();
}
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ public final class McpService {
private final AtomicReference<Boolean> proxiesInitialized = new AtomicReference<>(false);
private final McpMetricsObserver metricsObserver;
private final SchemaIndex schemaIndex;
private Consumer<JsonRpcRequest> notificationWriter;

McpService(
Map<String, Service> services,
Expand Down Expand Up @@ -311,6 +312,40 @@ private JsonRpcResponse handleToolsCall(
}
}

/**
* Sets the notification writer for forwarding notifications from proxies.
*/
public void setNotificationWriter(Consumer<JsonRpcRequest> notificationWriter) {
this.notificationWriter = notificationWriter;
}

/**
* Creates a notification writer for a specific proxy that handles cache invalidation
* for only that proxy's tools.
*/
private Consumer<JsonRpcRequest> createProxyNotificationWriter(
McpServerProxy proxy,
Consumer<JsonRpcRequest> baseNotificationWriter
) {
return notification -> {
// Check if this is a tools/list_changed notification
if ("notifications/tools/list_changed".equals(notification.getMethod())) {
LOG.debug("Received tools/list_changed notification from proxy: {}", proxy.name());
// Remove only this proxy's tools
tools.entrySet().removeIf(entry -> entry.getValue().proxy() == proxy);
// Re-fetch tools from only this proxy
List<ToolInfo> proxyTools = proxy.listTools();
for (var toolInfo : proxyTools) {
tools.put(toolInfo.getName(), new Tool(toolInfo, proxy.name(), proxy));
}
}
// Forward the notification
if (baseNotificationWriter != null) {
baseNotificationWriter.accept(notification);
}
};
}

/**
* Starts proxies without initializing them.
*/
Expand Down Expand Up @@ -343,7 +378,8 @@ public void initializeProxies(Consumer<JsonRpcResponse> responseWriter) {

for (McpServerProxy proxy : proxies.values()) {
if (initRequest != null) {
proxy.initialize(responseWriter, initRequest, protocolVersion);
var proxyNotificationWriter = createProxyNotificationWriter(proxy, notificationWriter);
proxy.initialize(responseWriter, proxyNotificationWriter, initRequest, protocolVersion);
}

List<ToolInfo> proxyTools = proxy.listTools();
Expand Down Expand Up @@ -382,11 +418,12 @@ public void addNewService(String id, Service service) {
tools.putAll(createTools(Map.of(id, service)));
}

/**
* Adds a new proxy and initializes it.
*/
public void addNewProxy(McpServerProxy mcpServerProxy, Consumer<JsonRpcResponse> responseWriter) {
public void addNewProxy(
McpServerProxy mcpServerProxy,
Consumer<JsonRpcResponse> responseWriter
) {
proxies.put(mcpServerProxy.name(), mcpServerProxy);

mcpServerProxy.start();

try {
Expand Down
Loading