Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions server/internal/attr/conventions.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ const (
HTTPStatusCodePatternKey = attribute.Key("gram.http.status_code_pattern")
IngressNameKey = attribute.Key("gram.ingress.name")
McpMethodKey = attribute.Key("gram.mcp.method")
McpSessionIDKey = attribute.Key("gram.mcp.session_id")
McpURLKey = attribute.Key("gram.mcp.url")
MetricNameKey = attribute.Key("gram.metric.name")
MimeTypeKey = attribute.Key("mime.type")
Expand Down Expand Up @@ -776,6 +777,9 @@ func SlogMcpURL(v string) slog.Attr { return slog.String(string(McpURLKey),
func McpMethod(v string) attribute.KeyValue { return McpMethodKey.String(v) }
func SlogMcpMethod(v string) slog.Attr { return slog.String(string(McpMethodKey), v) }

func McpSessionID(v string) attribute.KeyValue { return McpSessionIDKey.String(v) }
func SlogMcpSessionID(v string) slog.Attr { return slog.String(string(McpSessionIDKey), v) }

func MimeType(v string) attribute.KeyValue { return MimeTypeKey.String(v) }
func SlogMimeType(v string) slog.Attr { return slog.String(string(MimeTypeKey), v) }

Expand Down
84 changes: 84 additions & 0 deletions server/internal/mcp/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -434,3 +434,87 @@ func TestMsgID_UnmarshalJSON_Null(t *testing.T) {
require.Equal(t, int64(0), id.Number)
})
}

func TestRpcError_MarshalJSON_ZeroID(t *testing.T) {
t.Parallel()

t.Run("zero_value_id_serializes_as_null", func(t *testing.T) {
t.Parallel()
rpcErr := &rpcError{
ID: msgID{format: 0, String: "", Number: 0},
Code: internalError,
Message: "something went wrong",
Data: nil,
}
data, err := json.Marshal(rpcErr)
require.NoError(t, err)

var parsed map[string]any
err = json.Unmarshal(data, &parsed)
require.NoError(t, err)
require.Nil(t, parsed["id"], "id must be null when request ID is unknown")
require.Equal(t, "2.0", parsed["jsonrpc"])
require.NotNil(t, parsed["error"])
})

t.Run("zero_int_id_serializes_as_null", func(t *testing.T) {
t.Parallel()
rpcErr := &rpcError{
ID: msgID{format: 1, Number: 0},
Code: parseError,
Message: "parse error",
Data: nil,
}
data, err := json.Marshal(rpcErr)
require.NoError(t, err)

var parsed map[string]any
err = json.Unmarshal(data, &parsed)
require.NoError(t, err)
require.Nil(t, parsed["id"], "id must be null when request ID is zero")
})

t.Run("nonzero_int_id_serializes_normally", func(t *testing.T) {
t.Parallel()
rpcErr := &rpcError{
ID: msgID{format: 1, Number: 42},
Code: methodNotFound,
Message: "not found",
Data: nil,
}
data, err := json.Marshal(rpcErr)
require.NoError(t, err)

var parsed map[string]any
err = json.Unmarshal(data, &parsed)
require.NoError(t, err)
require.Equal(t, float64(42), parsed["id"])
})
}

func TestBatchContainsMethod(t *testing.T) {
t.Parallel()

t.Run("returns_true_when_method_present", func(t *testing.T) {
t.Parallel()
batch := batchedRawRequest{
{Method: "initialize"},
{Method: "tools/list"},
}
require.True(t, batchContainsMethod(batch, "initialize"))
})

t.Run("returns_false_when_method_absent", func(t *testing.T) {
t.Parallel()
batch := batchedRawRequest{
{Method: "tools/list"},
{Method: "ping"},
}
require.False(t, batchContainsMethod(batch, "initialize"))
})

t.Run("returns_false_for_empty_batch", func(t *testing.T) {
t.Parallel()
require.False(t, batchContainsMethod(batchedRawRequest{}, "initialize"))
})
}
102 changes: 95 additions & 7 deletions server/internal/mcp/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ type Service struct {
externalmcpRepo *externalmcp_repo.Queries
deploymentsRepo *deployments_repo.Queries
enc *encryption.Client
sessionStore *mcpSessionStore
}

type oauthTokenInputs struct {
Expand Down Expand Up @@ -182,6 +183,7 @@ func NewService(
sessions: sessions,
chatSessionsManager: chatSessionsManager,
enc: enc,
sessionStore: newMCPSessionStore(cacheImpl),
}
}

Expand All @@ -192,6 +194,8 @@ func Attach(mux goahttp.Muxer, service *Service, metadataService *mcpmetadata.Se
}).ServeHTTP)
o11y.AttachHandler(mux, "GET", "/mcp/{mcpSlug}/install", oops.ErrHandle(service.logger, metadataService.ServeInstallPage).ServeHTTP)
o11y.AttachHandler(mux, "POST", "/mcp/{project}/{toolset}/{environment}", oops.ErrHandle(service.logger, service.ServeAuthenticated).ServeHTTP)
o11y.AttachHandler(mux, "DELETE", "/mcp/{mcpSlug}", oops.ErrHandle(service.logger, service.HandleDeleteSession).ServeHTTP)
o11y.AttachHandler(mux, "DELETE", "/mcp/{project}/{toolset}/{environment}", oops.ErrHandle(service.logger, service.HandleDeleteSession).ServeHTTP)

// OAuth 2.1 Authorization Server Metadata
o11y.AttachHandler(mux, "GET", "/.well-known/oauth-authorization-server/mcp/{mcpSlug}", oops.ErrHandle(service.logger, service.HandleWellKnownOAuthServerMetadata).ServeHTTP)
Expand Down Expand Up @@ -234,6 +238,24 @@ func (s *Service) HandleGetServer(w http.ResponseWriter, r *http.Request, metada
return nil
}

// HandleDeleteSession handles DELETE requests to terminate an MCP session.
// Per the MCP spec, clients send DELETE with Mcp-Session-Id to end a session.
func (s *Service) HandleDeleteSession(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()

sessionID := r.Header.Get("Mcp-Session-Id")
if sessionID == "" {
return oops.E(oops.CodeBadRequest, nil, "Mcp-Session-Id header is required")
}

if err := s.sessionStore.Delete(ctx, sessionID); err != nil {
s.logger.WarnContext(ctx, "failed to delete MCP session", attr.SlogError(err))
}

w.WriteHeader(http.StatusOK)
return nil
}

// handleWellKnownMetadata handles OAuth 2.1 authorization server metadata discovery
func (s *Service) HandleWellKnownOAuthServerMetadata(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
Expand Down Expand Up @@ -592,6 +614,9 @@ func (s *Service) ServePublic(w http.ResponseWriter, r *http.Request) error {
sessionID := parseMcpSessionID(r.Header)
w.Header().Set("Mcp-Session-Id", sessionID)

hasInitialize := batchContainsMethod(batch, "initialize")
s.validateMCPSession(ctx, r.Header, hasInitialize)

// Load header display names for remapping
headerDisplayNames := s.loadHeaderDisplayNames(ctx, toolset.ID)

Expand Down Expand Up @@ -626,6 +651,16 @@ func (s *Service) ServePublic(w http.ResponseWriter, r *http.Request) error {
return NewErrorFromCause(batch[0].ID, err)
}

if hasInitialize {
if createErr := s.sessionStore.Create(ctx, sessionID); createErr != nil {
s.logger.WarnContext(ctx, "failed to create MCP session", attr.SlogError(createErr))
}
} else {
if touchErr := s.sessionStore.Touch(ctx, sessionID); touchErr != nil {
s.logger.WarnContext(ctx, "failed to touch MCP session", attr.SlogError(touchErr))
}
}

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, writeErr := w.Write(body)
Expand Down Expand Up @@ -746,6 +781,9 @@ func (s *Service) ServeAuthenticated(w http.ResponseWriter, r *http.Request) err
sessionID := parseMcpSessionID(r.Header)
w.Header().Set("Mcp-Session-Id", sessionID)

hasInitialize := batchContainsMethod(batch, "initialize")
s.validateMCPSession(ctx, r.Header, hasInitialize)

toolset, err := s.toolsetsRepo.GetToolset(ctx, toolsets_repo.GetToolsetParams{
Slug: toolsetSlug,
ProjectID: *authCtx.ProjectID,
Expand Down Expand Up @@ -780,6 +818,16 @@ func (s *Service) ServeAuthenticated(w http.ResponseWriter, r *http.Request) err
return NewErrorFromCause(batch[0].ID, err)
}

if hasInitialize {
if createErr := s.sessionStore.Create(ctx, sessionID); createErr != nil {
s.logger.WarnContext(ctx, "failed to create MCP session", attr.SlogError(createErr))
}
} else {
if touchErr := s.sessionStore.Touch(ctx, sessionID); touchErr != nil {
s.logger.WarnContext(ctx, "failed to touch MCP session", attr.SlogError(touchErr))
}
}

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, writeErr := w.Write(body)
Expand Down Expand Up @@ -811,7 +859,8 @@ func (s *Service) handleBatch(ctx context.Context, payload *mcpInputs, batch bat
result, err := s.handleRequest(ctx, payload, req)
switch {
case result == nil && err == nil:
return nil, nil
// Notifications return nil, nil — skip them in the response per JSON-RPC 2.0.
continue
case err != nil:
bs, merr := json.Marshal(NewErrorFromCause(req.ID, err))
if merr != nil {
Expand All @@ -824,16 +873,21 @@ func (s *Service) handleBatch(ctx context.Context, payload *mcpInputs, batch bat
results = append(results, result)
}

// If no results (notification-only batch), return nil to signal 202 No Content.
if len(results) == 0 {
return nil, nil
}

if len(results) == 1 {
return results[0], nil
} else {
m, err := json.Marshal(results)
if err != nil {
return nil, oops.E(oops.CodeUnexpected, err, "failed to serialize results").Log(ctx, s.logger)
}
}

return m, nil
m, err := json.Marshal(results)
if err != nil {
return nil, oops.E(oops.CodeUnexpected, err, "failed to serialize results").Log(ctx, s.logger)
}

return m, nil
}

// parseMcpEnvVariables: Map potential user provided mcp variables into inputs
Expand Down Expand Up @@ -905,6 +959,12 @@ func (s *Service) handleRequest(ctx context.Context, payload *mcpInputs, req *ra
return handleResourcesList(ctx, s.logger, s.db, payload, req, &s.toolsetCache)
case "resources/read":
return handleResourcesRead(ctx, s.logger, s.db, payload, req, s.toolProxy, s.env, s.billingTracker, s.billingRepository, s.telemetryService)
case "completion/complete":
return handleCompletionComplete(ctx, s.logger, req.ID)
case "logging/setLevel":
return handleLoggingSetLevel(ctx, s.logger, req.ID)
case "resources/subscribe", "resources/unsubscribe":
return handlePing(ctx, s.logger, req.ID) // no-op, return empty result
default:
return nil, &rpcError{
ID: req.ID,
Expand All @@ -923,6 +983,34 @@ func parseMcpSessionID(headers http.Header) string {
return session
}

// batchContainsMethod checks whether any request in a batch uses the given method.
func batchContainsMethod(batch batchedRawRequest, method string) bool {
for _, req := range batch {
if req.Method == method {
return true
}
}
return false
}

// validateMCPSession logs warnings for missing or unknown session IDs on non-initialize requests.
// Per backward-compat policy, we warn but still allow the request to proceed.
func (s *Service) validateMCPSession(ctx context.Context, headers http.Header, hasInitialize bool) {
if hasInitialize {
return
}

headerVal := headers.Get("Mcp-Session-Id")
if headerVal == "" {
s.logger.WarnContext(ctx, "MCP request missing Mcp-Session-Id header")
return
}

if !s.sessionStore.Validate(ctx, headerVal) {
s.logger.WarnContext(ctx, "MCP request with unknown session ID", attr.SlogMcpSessionID(headerVal))
}
}

func (s *Service) authenticateToken(ctx context.Context, token string, toolsetID uuid.UUID, isOAuthCapable bool) (context.Context, error) {
if token == "" {
return ctx, oops.C(oops.CodeUnauthorized)
Expand Down
4 changes: 4 additions & 0 deletions server/internal/mcp/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,12 @@ func (e *rpcError) MarshalJSON() ([]byte, error) {
},
}

// JSON-RPC 2.0 requires "id" to always be present in error responses.
// When the ID is unknown (zero-value), it must be set to null.
if (e.ID.format == 1 && e.ID.Number != 0) || (e.ID.format != 1 && e.ID.String != "") {
payload["id"] = e.ID
} else {
payload["id"] = nil
}

bs, err := json.Marshal(payload)
Expand Down
37 changes: 37 additions & 0 deletions server/internal/mcp/rpc_completion.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package mcp

import (
"context"
"encoding/json"
"log/slog"

"github.com/speakeasy-api/gram/server/internal/oops"
)

type completionResult struct {
Completion completionValues `json:"completion"`
}

type completionValues struct {
Values []string `json:"values"`
HasMore bool `json:"hasMore"`
Total int `json:"total"`
}

func handleCompletionComplete(ctx context.Context, logger *slog.Logger, id msgID) (json.RawMessage, error) {
bs, err := json.Marshal(&result[completionResult]{
ID: id,
Result: completionResult{
Completion: completionValues{
Values: []string{},
HasMore: false,
Total: 0,
},
},
})
if err != nil {
return nil, oops.E(oops.CodeUnexpected, err, "failed to serialize completion/complete response").Log(ctx, logger)
}

return bs, nil
}
21 changes: 21 additions & 0 deletions server/internal/mcp/rpc_logging.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package mcp

import (
"context"
"encoding/json"
"log/slog"

"github.com/speakeasy-api/gram/server/internal/oops"
)

func handleLoggingSetLevel(ctx context.Context, logger *slog.Logger, id msgID) (json.RawMessage, error) {
bs, err := json.Marshal(&result[struct{}]{
ID: id,
Result: struct{}{},
})
if err != nil {
return nil, oops.E(oops.CodeUnexpected, err, "failed to serialize logging/setLevel response").Log(ctx, logger)
}

return bs, nil
}
Loading
Loading