diff --git a/server/internal/attr/conventions.go b/server/internal/attr/conventions.go index 0d2ae347e..749f2209f 100644 --- a/server/internal/attr/conventions.go +++ b/server/internal/attr/conventions.go @@ -127,6 +127,7 @@ const ( HTTPStatusCodePatternKey = attribute.Key("gram.http.status_code_pattern") IngressNameKey = attribute.Key("gram.ingress.name") McpMethodKey = attribute.Key("gram.mcp.method") + McpSessionIDKey = attribute.Key("gram.mcp.session_id") McpURLKey = attribute.Key("gram.mcp.url") MetricNameKey = attribute.Key("gram.metric.name") MimeTypeKey = attribute.Key("mime.type") @@ -776,6 +777,9 @@ func SlogMcpURL(v string) slog.Attr { return slog.String(string(McpURLKey), func McpMethod(v string) attribute.KeyValue { return McpMethodKey.String(v) } func SlogMcpMethod(v string) slog.Attr { return slog.String(string(McpMethodKey), v) } +func McpSessionID(v string) attribute.KeyValue { return McpSessionIDKey.String(v) } +func SlogMcpSessionID(v string) slog.Attr { return slog.String(string(McpSessionIDKey), v) } + func MimeType(v string) attribute.KeyValue { return MimeTypeKey.String(v) } func SlogMimeType(v string) slog.Attr { return slog.String(string(MimeTypeKey), v) } diff --git a/server/internal/mcp/helpers_test.go b/server/internal/mcp/helpers_test.go index af6833486..798dbb037 100644 --- a/server/internal/mcp/helpers_test.go +++ b/server/internal/mcp/helpers_test.go @@ -434,3 +434,87 @@ func TestMsgID_UnmarshalJSON_Null(t *testing.T) { require.Equal(t, int64(0), id.Number) }) } + +func TestRpcError_MarshalJSON_ZeroID(t *testing.T) { + t.Parallel() + + t.Run("zero_value_id_serializes_as_null", func(t *testing.T) { + t.Parallel() + rpcErr := &rpcError{ + ID: msgID{format: 0, String: "", Number: 0}, + Code: internalError, + Message: "something went wrong", + Data: nil, + } + data, err := json.Marshal(rpcErr) + require.NoError(t, err) + + var parsed map[string]any + err = json.Unmarshal(data, &parsed) + require.NoError(t, err) + require.Nil(t, parsed["id"], "id must be null when request ID is unknown") + require.Equal(t, "2.0", parsed["jsonrpc"]) + require.NotNil(t, parsed["error"]) + }) + + t.Run("zero_int_id_serializes_as_null", func(t *testing.T) { + t.Parallel() + rpcErr := &rpcError{ + ID: msgID{format: 1, Number: 0}, + Code: parseError, + Message: "parse error", + Data: nil, + } + data, err := json.Marshal(rpcErr) + require.NoError(t, err) + + var parsed map[string]any + err = json.Unmarshal(data, &parsed) + require.NoError(t, err) + require.Nil(t, parsed["id"], "id must be null when request ID is zero") + }) + + t.Run("nonzero_int_id_serializes_normally", func(t *testing.T) { + t.Parallel() + rpcErr := &rpcError{ + ID: msgID{format: 1, Number: 42}, + Code: methodNotFound, + Message: "not found", + Data: nil, + } + data, err := json.Marshal(rpcErr) + require.NoError(t, err) + + var parsed map[string]any + err = json.Unmarshal(data, &parsed) + require.NoError(t, err) + require.Equal(t, float64(42), parsed["id"]) + }) +} + +func TestBatchContainsMethod(t *testing.T) { + t.Parallel() + + t.Run("returns_true_when_method_present", func(t *testing.T) { + t.Parallel() + batch := batchedRawRequest{ + {Method: "initialize"}, + {Method: "tools/list"}, + } + require.True(t, batchContainsMethod(batch, "initialize")) + }) + + t.Run("returns_false_when_method_absent", func(t *testing.T) { + t.Parallel() + batch := batchedRawRequest{ + {Method: "tools/list"}, + {Method: "ping"}, + } + require.False(t, batchContainsMethod(batch, "initialize")) + }) + + t.Run("returns_false_for_empty_batch", func(t *testing.T) { + t.Parallel() + require.False(t, batchContainsMethod(batchedRawRequest{}, "initialize")) + }) +} diff --git a/server/internal/mcp/impl.go b/server/internal/mcp/impl.go index 07345fa56..e1ebe0dd4 100644 --- a/server/internal/mcp/impl.go +++ b/server/internal/mcp/impl.go @@ -90,6 +90,7 @@ type Service struct { externalmcpRepo *externalmcp_repo.Queries deploymentsRepo *deployments_repo.Queries enc *encryption.Client + sessionStore *mcpSessionStore } type oauthTokenInputs struct { @@ -182,6 +183,7 @@ func NewService( sessions: sessions, chatSessionsManager: chatSessionsManager, enc: enc, + sessionStore: newMCPSessionStore(cacheImpl), } } @@ -192,6 +194,8 @@ func Attach(mux goahttp.Muxer, service *Service, metadataService *mcpmetadata.Se }).ServeHTTP) o11y.AttachHandler(mux, "GET", "/mcp/{mcpSlug}/install", oops.ErrHandle(service.logger, metadataService.ServeInstallPage).ServeHTTP) o11y.AttachHandler(mux, "POST", "/mcp/{project}/{toolset}/{environment}", oops.ErrHandle(service.logger, service.ServeAuthenticated).ServeHTTP) + o11y.AttachHandler(mux, "DELETE", "/mcp/{mcpSlug}", oops.ErrHandle(service.logger, service.HandleDeleteSession).ServeHTTP) + o11y.AttachHandler(mux, "DELETE", "/mcp/{project}/{toolset}/{environment}", oops.ErrHandle(service.logger, service.HandleDeleteSession).ServeHTTP) // OAuth 2.1 Authorization Server Metadata o11y.AttachHandler(mux, "GET", "/.well-known/oauth-authorization-server/mcp/{mcpSlug}", oops.ErrHandle(service.logger, service.HandleWellKnownOAuthServerMetadata).ServeHTTP) @@ -234,6 +238,24 @@ func (s *Service) HandleGetServer(w http.ResponseWriter, r *http.Request, metada return nil } +// HandleDeleteSession handles DELETE requests to terminate an MCP session. +// Per the MCP spec, clients send DELETE with Mcp-Session-Id to end a session. +func (s *Service) HandleDeleteSession(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + + sessionID := r.Header.Get("Mcp-Session-Id") + if sessionID == "" { + return oops.E(oops.CodeBadRequest, nil, "Mcp-Session-Id header is required") + } + + if err := s.sessionStore.Delete(ctx, sessionID); err != nil { + s.logger.WarnContext(ctx, "failed to delete MCP session", attr.SlogError(err)) + } + + w.WriteHeader(http.StatusOK) + return nil +} + // handleWellKnownMetadata handles OAuth 2.1 authorization server metadata discovery func (s *Service) HandleWellKnownOAuthServerMetadata(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() @@ -592,6 +614,9 @@ func (s *Service) ServePublic(w http.ResponseWriter, r *http.Request) error { sessionID := parseMcpSessionID(r.Header) w.Header().Set("Mcp-Session-Id", sessionID) + hasInitialize := batchContainsMethod(batch, "initialize") + s.validateMCPSession(ctx, r.Header, hasInitialize) + // Load header display names for remapping headerDisplayNames := s.loadHeaderDisplayNames(ctx, toolset.ID) @@ -626,6 +651,16 @@ func (s *Service) ServePublic(w http.ResponseWriter, r *http.Request) error { return NewErrorFromCause(batch[0].ID, err) } + if hasInitialize { + if createErr := s.sessionStore.Create(ctx, sessionID); createErr != nil { + s.logger.WarnContext(ctx, "failed to create MCP session", attr.SlogError(createErr)) + } + } else { + if touchErr := s.sessionStore.Touch(ctx, sessionID); touchErr != nil { + s.logger.WarnContext(ctx, "failed to touch MCP session", attr.SlogError(touchErr)) + } + } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _, writeErr := w.Write(body) @@ -746,6 +781,9 @@ func (s *Service) ServeAuthenticated(w http.ResponseWriter, r *http.Request) err sessionID := parseMcpSessionID(r.Header) w.Header().Set("Mcp-Session-Id", sessionID) + hasInitialize := batchContainsMethod(batch, "initialize") + s.validateMCPSession(ctx, r.Header, hasInitialize) + toolset, err := s.toolsetsRepo.GetToolset(ctx, toolsets_repo.GetToolsetParams{ Slug: toolsetSlug, ProjectID: *authCtx.ProjectID, @@ -780,6 +818,16 @@ func (s *Service) ServeAuthenticated(w http.ResponseWriter, r *http.Request) err return NewErrorFromCause(batch[0].ID, err) } + if hasInitialize { + if createErr := s.sessionStore.Create(ctx, sessionID); createErr != nil { + s.logger.WarnContext(ctx, "failed to create MCP session", attr.SlogError(createErr)) + } + } else { + if touchErr := s.sessionStore.Touch(ctx, sessionID); touchErr != nil { + s.logger.WarnContext(ctx, "failed to touch MCP session", attr.SlogError(touchErr)) + } + } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _, writeErr := w.Write(body) @@ -811,7 +859,8 @@ func (s *Service) handleBatch(ctx context.Context, payload *mcpInputs, batch bat result, err := s.handleRequest(ctx, payload, req) switch { case result == nil && err == nil: - return nil, nil + // Notifications return nil, nil — skip them in the response per JSON-RPC 2.0. + continue case err != nil: bs, merr := json.Marshal(NewErrorFromCause(req.ID, err)) if merr != nil { @@ -824,16 +873,21 @@ func (s *Service) handleBatch(ctx context.Context, payload *mcpInputs, batch bat results = append(results, result) } + // If no results (notification-only batch), return nil to signal 202 No Content. + if len(results) == 0 { + return nil, nil + } + if len(results) == 1 { return results[0], nil - } else { - m, err := json.Marshal(results) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "failed to serialize results").Log(ctx, s.logger) - } + } - return m, nil + m, err := json.Marshal(results) + if err != nil { + return nil, oops.E(oops.CodeUnexpected, err, "failed to serialize results").Log(ctx, s.logger) } + + return m, nil } // parseMcpEnvVariables: Map potential user provided mcp variables into inputs @@ -905,6 +959,12 @@ func (s *Service) handleRequest(ctx context.Context, payload *mcpInputs, req *ra return handleResourcesList(ctx, s.logger, s.db, payload, req, &s.toolsetCache) case "resources/read": return handleResourcesRead(ctx, s.logger, s.db, payload, req, s.toolProxy, s.env, s.billingTracker, s.billingRepository, s.telemetryService) + case "completion/complete": + return handleCompletionComplete(ctx, s.logger, req.ID) + case "logging/setLevel": + return handleLoggingSetLevel(ctx, s.logger, req.ID) + case "resources/subscribe", "resources/unsubscribe": + return handlePing(ctx, s.logger, req.ID) // no-op, return empty result default: return nil, &rpcError{ ID: req.ID, @@ -923,6 +983,34 @@ func parseMcpSessionID(headers http.Header) string { return session } +// batchContainsMethod checks whether any request in a batch uses the given method. +func batchContainsMethod(batch batchedRawRequest, method string) bool { + for _, req := range batch { + if req.Method == method { + return true + } + } + return false +} + +// validateMCPSession logs warnings for missing or unknown session IDs on non-initialize requests. +// Per backward-compat policy, we warn but still allow the request to proceed. +func (s *Service) validateMCPSession(ctx context.Context, headers http.Header, hasInitialize bool) { + if hasInitialize { + return + } + + headerVal := headers.Get("Mcp-Session-Id") + if headerVal == "" { + s.logger.WarnContext(ctx, "MCP request missing Mcp-Session-Id header") + return + } + + if !s.sessionStore.Validate(ctx, headerVal) { + s.logger.WarnContext(ctx, "MCP request with unknown session ID", attr.SlogMcpSessionID(headerVal)) + } +} + func (s *Service) authenticateToken(ctx context.Context, token string, toolsetID uuid.UUID, isOAuthCapable bool) (context.Context, error) { if token == "" { return ctx, oops.C(oops.CodeUnauthorized) diff --git a/server/internal/mcp/rpc.go b/server/internal/mcp/rpc.go index 3a3725266..b4764b47e 100644 --- a/server/internal/mcp/rpc.go +++ b/server/internal/mcp/rpc.go @@ -237,8 +237,12 @@ func (e *rpcError) MarshalJSON() ([]byte, error) { }, } + // JSON-RPC 2.0 requires "id" to always be present in error responses. + // When the ID is unknown (zero-value), it must be set to null. if (e.ID.format == 1 && e.ID.Number != 0) || (e.ID.format != 1 && e.ID.String != "") { payload["id"] = e.ID + } else { + payload["id"] = nil } bs, err := json.Marshal(payload) diff --git a/server/internal/mcp/rpc_completion.go b/server/internal/mcp/rpc_completion.go new file mode 100644 index 000000000..13fe05add --- /dev/null +++ b/server/internal/mcp/rpc_completion.go @@ -0,0 +1,37 @@ +package mcp + +import ( + "context" + "encoding/json" + "log/slog" + + "github.com/speakeasy-api/gram/server/internal/oops" +) + +type completionResult struct { + Completion completionValues `json:"completion"` +} + +type completionValues struct { + Values []string `json:"values"` + HasMore bool `json:"hasMore"` + Total int `json:"total"` +} + +func handleCompletionComplete(ctx context.Context, logger *slog.Logger, id msgID) (json.RawMessage, error) { + bs, err := json.Marshal(&result[completionResult]{ + ID: id, + Result: completionResult{ + Completion: completionValues{ + Values: []string{}, + HasMore: false, + Total: 0, + }, + }, + }) + if err != nil { + return nil, oops.E(oops.CodeUnexpected, err, "failed to serialize completion/complete response").Log(ctx, logger) + } + + return bs, nil +} diff --git a/server/internal/mcp/rpc_logging.go b/server/internal/mcp/rpc_logging.go new file mode 100644 index 000000000..31a3d11b8 --- /dev/null +++ b/server/internal/mcp/rpc_logging.go @@ -0,0 +1,21 @@ +package mcp + +import ( + "context" + "encoding/json" + "log/slog" + + "github.com/speakeasy-api/gram/server/internal/oops" +) + +func handleLoggingSetLevel(ctx context.Context, logger *slog.Logger, id msgID) (json.RawMessage, error) { + bs, err := json.Marshal(&result[struct{}]{ + ID: id, + Result: struct{}{}, + }) + if err != nil { + return nil, oops.E(oops.CodeUnexpected, err, "failed to serialize logging/setLevel response").Log(ctx, logger) + } + + return bs, nil +} diff --git a/server/internal/mcp/serveauthenticated_test.go b/server/internal/mcp/serveauthenticated_test.go index c62e99ff7..c9f650d19 100644 --- a/server/internal/mcp/serveauthenticated_test.go +++ b/server/internal/mcp/serveauthenticated_test.go @@ -1086,6 +1086,287 @@ func TestService_ServeAuthenticated(t *testing.T) { require.NotNil(t, response["error"]) }) + t.Run("handles completion/complete request", func(t *testing.T) { + t.Parallel() + + ctx, ti := newTestMCPService(t) + toolsetsRepo := toolsets_repo.New(ti.conn) + + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) + require.NotNil(t, authCtx.ProjectID) + require.NotNil(t, authCtx.ProjectSlug) + + toolset, err := toolsetsRepo.CreateToolset(ctx, toolsets_repo.CreateToolsetParams{ + OrganizationID: authCtx.ActiveOrganizationID, + ProjectID: *authCtx.ProjectID, + Name: "Completion Test MCP", + Slug: "completion-test-mcp", + Description: conv.ToPGText("A test MCP for completion/complete"), + DefaultEnvironmentSlug: pgtype.Text{String: "production", Valid: true}, + McpSlug: conv.ToPGText("completion-test-mcp"), + McpEnabled: true, + }) + require.NoError(t, err) + + apiKey := ti.createTestAPIKey(ctx, t) + + reqBody := []map[string]any{ + { + "jsonrpc": "2.0", + "id": 1, + "method": "completion/complete", + "params": map[string]any{ + "ref": map[string]any{"type": "ref/prompt", "name": "test"}, + "argument": map[string]any{"name": "arg", "value": "val"}, + }, + }, + } + bodyBytes, err := json.Marshal(reqBody) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/mcp/"+*authCtx.ProjectSlug+"/"+toolset.Slug+"/production", bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("project", *authCtx.ProjectSlug) + rctx.URLParams.Add("toolset", toolset.Slug) + rctx.URLParams.Add("environment", "production") + req = req.WithContext(context.WithValue(t.Context(), chi.RouteCtxKey, rctx)) + + w := httptest.NewRecorder() + err = ti.service.ServeAuthenticated(w, req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, w.Code) + + var response map[string]any + err = json.Unmarshal(w.Body.Bytes(), &response) + require.NoError(t, err) + require.Equal(t, "2.0", response["jsonrpc"]) + + result, ok := response["result"].(map[string]any) + require.True(t, ok) + completion, ok := result["completion"].(map[string]any) + require.True(t, ok) + require.NotNil(t, completion["values"]) + require.Equal(t, false, completion["hasMore"]) + }) + + t.Run("handles logging/setLevel request", func(t *testing.T) { + t.Parallel() + + ctx, ti := newTestMCPService(t) + toolsetsRepo := toolsets_repo.New(ti.conn) + + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) + require.NotNil(t, authCtx.ProjectID) + require.NotNil(t, authCtx.ProjectSlug) + + toolset, err := toolsetsRepo.CreateToolset(ctx, toolsets_repo.CreateToolsetParams{ + OrganizationID: authCtx.ActiveOrganizationID, + ProjectID: *authCtx.ProjectID, + Name: "Logging Test MCP", + Slug: "logging-test-mcp", + Description: conv.ToPGText("A test MCP for logging/setLevel"), + DefaultEnvironmentSlug: pgtype.Text{String: "production", Valid: true}, + McpSlug: conv.ToPGText("logging-test-mcp"), + McpEnabled: true, + }) + require.NoError(t, err) + + apiKey := ti.createTestAPIKey(ctx, t) + + reqBody := []map[string]any{ + { + "jsonrpc": "2.0", + "id": 1, + "method": "logging/setLevel", + "params": map[string]any{"level": "debug"}, + }, + } + bodyBytes, err := json.Marshal(reqBody) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/mcp/"+*authCtx.ProjectSlug+"/"+toolset.Slug+"/production", bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("project", *authCtx.ProjectSlug) + rctx.URLParams.Add("toolset", toolset.Slug) + rctx.URLParams.Add("environment", "production") + req = req.WithContext(context.WithValue(t.Context(), chi.RouteCtxKey, rctx)) + + w := httptest.NewRecorder() + err = ti.service.ServeAuthenticated(w, req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, w.Code) + + var response map[string]any + err = json.Unmarshal(w.Body.Bytes(), &response) + require.NoError(t, err) + require.Equal(t, "2.0", response["jsonrpc"]) + require.NotNil(t, response["id"]) + }) + + t.Run("handles resources/subscribe request", func(t *testing.T) { + t.Parallel() + + ctx, ti := newTestMCPService(t) + toolsetsRepo := toolsets_repo.New(ti.conn) + + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) + require.NotNil(t, authCtx.ProjectID) + require.NotNil(t, authCtx.ProjectSlug) + + toolset, err := toolsetsRepo.CreateToolset(ctx, toolsets_repo.CreateToolsetParams{ + OrganizationID: authCtx.ActiveOrganizationID, + ProjectID: *authCtx.ProjectID, + Name: "Subscribe Test MCP", + Slug: "subscribe-test-mcp", + Description: conv.ToPGText("A test MCP for resources/subscribe"), + DefaultEnvironmentSlug: pgtype.Text{String: "production", Valid: true}, + McpSlug: conv.ToPGText("subscribe-test-mcp"), + McpEnabled: true, + }) + require.NoError(t, err) + + apiKey := ti.createTestAPIKey(ctx, t) + + reqBody := []map[string]any{ + { + "jsonrpc": "2.0", + "id": 1, + "method": "resources/subscribe", + "params": map[string]any{"uri": "file:///test.txt"}, + }, + } + bodyBytes, err := json.Marshal(reqBody) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/mcp/"+*authCtx.ProjectSlug+"/"+toolset.Slug+"/production", bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("project", *authCtx.ProjectSlug) + rctx.URLParams.Add("toolset", toolset.Slug) + rctx.URLParams.Add("environment", "production") + req = req.WithContext(context.WithValue(t.Context(), chi.RouteCtxKey, rctx)) + + w := httptest.NewRecorder() + err = ti.service.ServeAuthenticated(w, req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, w.Code) + + var response map[string]any + err = json.Unmarshal(w.Body.Bytes(), &response) + require.NoError(t, err) + require.Equal(t, "2.0", response["jsonrpc"]) + }) + + t.Run("mixed batch with notification and call", func(t *testing.T) { + t.Parallel() + + ctx, ti := newTestMCPService(t) + toolsetsRepo := toolsets_repo.New(ti.conn) + + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) + require.NotNil(t, authCtx.ProjectID) + require.NotNil(t, authCtx.ProjectSlug) + + toolset, err := toolsetsRepo.CreateToolset(ctx, toolsets_repo.CreateToolsetParams{ + OrganizationID: authCtx.ActiveOrganizationID, + ProjectID: *authCtx.ProjectID, + Name: "Mixed Batch Test MCP", + Slug: "mixed-batch-test-mcp", + Description: conv.ToPGText("A test MCP for mixed batches"), + DefaultEnvironmentSlug: pgtype.Text{String: "production", Valid: true}, + McpSlug: conv.ToPGText("mixed-batch-test-mcp"), + McpEnabled: true, + }) + require.NoError(t, err) + + apiKey := ti.createTestAPIKey(ctx, t) + + // Notification followed by a regular call + reqBody := []map[string]any{ + { + "jsonrpc": "2.0", + "method": "notifications/initialized", + }, + { + "jsonrpc": "2.0", + "id": 1, + "method": "ping", + }, + } + bodyBytes, err := json.Marshal(reqBody) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/mcp/"+*authCtx.ProjectSlug+"/"+toolset.Slug+"/production", bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("project", *authCtx.ProjectSlug) + rctx.URLParams.Add("toolset", toolset.Slug) + rctx.URLParams.Add("environment", "production") + req = req.WithContext(context.WithValue(t.Context(), chi.RouteCtxKey, rctx)) + + w := httptest.NewRecorder() + err = ti.service.ServeAuthenticated(w, req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, w.Code) + + // Only the ping response should be returned (notification is skipped) + var response map[string]any + err = json.Unmarshal(w.Body.Bytes(), &response) + require.NoError(t, err) + require.Equal(t, "2.0", response["jsonrpc"]) + require.NotNil(t, response["id"]) + }) + + t.Run("handles DELETE session with valid session ID", func(t *testing.T) { + t.Parallel() + + _, ti := newTestMCPService(t) + + req := httptest.NewRequest(http.MethodDelete, "/mcp/test-slug", nil) + req.Header.Set("Mcp-Session-Id", "test-session-id") + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("mcpSlug", "test-slug") + req = req.WithContext(context.WithValue(t.Context(), chi.RouteCtxKey, rctx)) + + w := httptest.NewRecorder() + err := ti.service.HandleDeleteSession(w, req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("DELETE session returns error without session ID header", func(t *testing.T) { + t.Parallel() + + _, ti := newTestMCPService(t) + + req := httptest.NewRequest(http.MethodDelete, "/mcp/test-slug", nil) + // No Mcp-Session-Id header + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("mcpSlug", "test-slug") + req = req.WithContext(context.WithValue(t.Context(), chi.RouteCtxKey, rctx)) + + w := httptest.NewRecorder() + err := ti.service.HandleDeleteSession(w, req) + require.Error(t, err) + require.Contains(t, err.Error(), "Mcp-Session-Id header is required") + }) + t.Run("handles multiple requests in batch", func(t *testing.T) { t.Parallel() diff --git a/server/internal/mcp/session_store.go b/server/internal/mcp/session_store.go new file mode 100644 index 000000000..35a6c3d66 --- /dev/null +++ b/server/internal/mcp/session_store.go @@ -0,0 +1,62 @@ +package mcp + +import ( + "context" + "fmt" + "time" + + "github.com/speakeasy-api/gram/server/internal/cache" +) + +const ( + mcpSessionPrefix = "mcp-session:" + mcpSessionTTL = 30 * time.Minute +) + +// mcpSessionStore tracks initialized MCP sessions using the cache layer. +// Sessions are created on successful initialize requests and validated on +// subsequent requests to ensure the client is using a known session. +type mcpSessionStore struct { + cache cache.Cache +} + +func newMCPSessionStore(c cache.Cache) *mcpSessionStore { + return &mcpSessionStore{cache: c} +} + +func (s *mcpSessionStore) key(sessionID string) string { + return mcpSessionPrefix + sessionID +} + +// Create stores a new session ID with a 30-minute TTL. +func (s *mcpSessionStore) Create(ctx context.Context, sessionID string) error { + if err := s.cache.Set(ctx, s.key(sessionID), true, mcpSessionTTL); err != nil { + return fmt.Errorf("create mcp session: %w", err) + } + return nil +} + +// Validate checks whether the session ID exists in the store. +func (s *mcpSessionStore) Validate(ctx context.Context, sessionID string) bool { + var exists bool + if err := s.cache.Get(ctx, s.key(sessionID), &exists); err != nil { + return false + } + return exists +} + +// Delete removes a session ID from the store. +func (s *mcpSessionStore) Delete(ctx context.Context, sessionID string) error { + if err := s.cache.Delete(ctx, s.key(sessionID)); err != nil { + return fmt.Errorf("delete mcp session: %w", err) + } + return nil +} + +// Touch refreshes the TTL on an existing session. +func (s *mcpSessionStore) Touch(ctx context.Context, sessionID string) error { + if err := s.cache.Set(ctx, s.key(sessionID), true, mcpSessionTTL); err != nil { + return fmt.Errorf("touch mcp session: %w", err) + } + return nil +}