diff --git a/src/assets/__tests__/__snapshots__/assets.snapshot.test.ts.snap b/src/assets/__tests__/__snapshots__/assets.snapshot.test.ts.snap index d58f06b0..7de977a1 100644 --- a/src/assets/__tests__/__snapshots__/assets.snapshot.test.ts.snap +++ b/src/assets/__tests__/__snapshots__/assets.snapshot.test.ts.snap @@ -1726,6 +1726,23 @@ logger = logging.getLogger(__name__) import httpx from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session {{/if}} +{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}} +from bedrock_agentcore.identity import requires_access_token +{{/if}} + +{{#each gatewayProviders}} +{{#if (eq authType "CUSTOM_JWT")}} +@requires_access_token( + provider_name="{{credentialProviderName}}", + scopes=[{{#if scopes}}"{{scopes}}"{{/if}}], + auth_flow="M2M", +) +def _get_bearer_token_{{snakeCase name}}(*, access_token: str): + """Obtain OAuth access token via AgentCore Identity for {{name}}.""" + return access_token + +{{/if}} +{{/each}} def get_all_gateway_mcp_toolsets() -> list[MCPToolset]: """Returns MCP Toolsets for all configured gateways.""" @@ -1740,6 +1757,10 @@ def get_all_gateway_mcp_toolsets() -> list[MCPToolset]: url=url, httpx_client_factory=lambda **kwargs: httpx.AsyncClient(auth=auth, **kwargs) ))) + {{else if (eq authType "CUSTOM_JWT")}} + token = _get_bearer_token_{{snakeCase name}}() + headers = {"Authorization": f"Bearer {token}"} if token else None + toolsets.append(MCPToolset(connection_params=StreamableHTTPConnectionParams(url=url, headers=headers))) {{else}} toolsets.append(MCPToolset(connection_params=StreamableHTTPConnectionParams(url=url))) {{/if}} @@ -2012,6 +2033,23 @@ logger = logging.getLogger(__name__) {{#if (includes gatewayAuthTypes "AWS_IAM")}} from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session {{/if}} +{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}} +from bedrock_agentcore.identity import requires_access_token +{{/if}} + +{{#each gatewayProviders}} +{{#if (eq authType "CUSTOM_JWT")}} +@requires_access_token( + provider_name="{{credentialProviderName}}", + scopes=[{{#if scopes}}"{{scopes}}"{{/if}}], + auth_flow="M2M", +) +def _get_bearer_token_{{snakeCase name}}(*, access_token: str): + """Obtain OAuth access token via AgentCore Identity for {{name}}.""" + return access_token + +{{/if}} +{{/each}} def get_all_gateway_mcp_client() -> MultiServerMCPClient | None: """Returns an MCP Client connected to all configured gateways.""" @@ -2023,6 +2061,10 @@ def get_all_gateway_mcp_client() -> MultiServerMCPClient | None: session = create_aws_session() auth = SigV4HTTPXAuth(session.get_credentials(), "bedrock-agentcore", session.region_name) servers["{{name}}"] = {"transport": "streamable_http", "url": url, "auth": auth} + {{else if (eq authType "CUSTOM_JWT")}} + token = _get_bearer_token_{{snakeCase name}}() + headers = {"Authorization": f"Bearer {token}"} if token else None + servers["{{name}}"] = {"transport": "streamable_http", "url": url, "headers": headers} {{else}} servers["{{name}}"] = {"transport": "streamable_http", "url": url} {{/if}} @@ -2438,6 +2480,23 @@ logger = logging.getLogger(__name__) import httpx from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session {{/if}} +{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}} +from bedrock_agentcore.identity import requires_access_token +{{/if}} + +{{#each gatewayProviders}} +{{#if (eq authType "CUSTOM_JWT")}} +@requires_access_token( + provider_name="{{credentialProviderName}}", + scopes=[{{#if scopes}}"{{scopes}}"{{/if}}], + auth_flow="M2M", +) +def _get_bearer_token_{{snakeCase name}}(*, access_token: str): + """Obtain OAuth access token via AgentCore Identity for {{name}}.""" + return access_token + +{{/if}} +{{/each}} def get_all_gateway_mcp_servers() -> list[MCPServerStreamableHttp]: """Returns MCP servers for all configured gateways.""" @@ -2452,6 +2511,10 @@ def get_all_gateway_mcp_servers() -> list[MCPServerStreamableHttp]: name="{{name}}", params={"url": url, "httpx_client_factory": lambda **kwargs: httpx.AsyncClient(auth=auth, **kwargs)} )) + {{else if (eq authType "CUSTOM_JWT")}} + token = _get_bearer_token_{{snakeCase name}}() + headers = {"Authorization": f"Bearer {token}"} if token else {} + servers.append(MCPServerStreamableHttp(name="{{name}}", params={"url": url, "headers": headers})) {{else}} servers.append(MCPServerStreamableHttp(name="{{name}}", params={"url": url})) {{/if}} @@ -2749,7 +2812,23 @@ logger = logging.getLogger(__name__) {{#if (includes gatewayAuthTypes "AWS_IAM")}} from mcp_proxy_for_aws.client import aws_iam_streamablehttp_client {{/if}} +{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}} +from bedrock_agentcore.identity import requires_access_token +{{/if}} + +{{#each gatewayProviders}} +{{#if (eq authType "CUSTOM_JWT")}} +@requires_access_token( + provider_name="{{credentialProviderName}}", + scopes=[{{#if scopes}}"{{scopes}}"{{/if}}], + auth_flow="M2M", +) +def _get_bearer_token_{{snakeCase name}}(*, access_token: str): + """Obtain OAuth access token via AgentCore Identity for {{name}}.""" + return access_token +{{/if}} +{{/each}} {{#each gatewayProviders}} def get_{{snakeCase name}}_mcp_client() -> MCPClient | None: """Returns an MCP Client connected to the {{name}} gateway.""" @@ -2759,6 +2838,10 @@ def get_{{snakeCase name}}_mcp_client() -> MCPClient | None: return None {{#if (eq authType "AWS_IAM")}} return MCPClient(lambda: aws_iam_streamablehttp_client(url, aws_service="bedrock-agentcore", aws_region=os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION")))) + {{else if (eq authType "CUSTOM_JWT")}} + token = _get_bearer_token_{{snakeCase name}}() + headers = {"Authorization": f"Bearer {token}"} if token else {} + return MCPClient(lambda: streamablehttp_client(url, headers=headers)) {{else}} return MCPClient(lambda: streamablehttp_client(url)) {{/if}} diff --git a/src/assets/python/googleadk/base/mcp_client/client.py b/src/assets/python/googleadk/base/mcp_client/client.py index f2c1a39c..e6dddd62 100644 --- a/src/assets/python/googleadk/base/mcp_client/client.py +++ b/src/assets/python/googleadk/base/mcp_client/client.py @@ -10,6 +10,23 @@ import httpx from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session {{/if}} +{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}} +from bedrock_agentcore.identity import requires_access_token +{{/if}} + +{{#each gatewayProviders}} +{{#if (eq authType "CUSTOM_JWT")}} +@requires_access_token( + provider_name="{{credentialProviderName}}", + scopes=[{{#if scopes}}"{{scopes}}"{{/if}}], + auth_flow="M2M", +) +def _get_bearer_token_{{snakeCase name}}(*, access_token: str): + """Obtain OAuth access token via AgentCore Identity for {{name}}.""" + return access_token + +{{/if}} +{{/each}} def get_all_gateway_mcp_toolsets() -> list[MCPToolset]: """Returns MCP Toolsets for all configured gateways.""" @@ -24,6 +41,10 @@ def get_all_gateway_mcp_toolsets() -> list[MCPToolset]: url=url, httpx_client_factory=lambda **kwargs: httpx.AsyncClient(auth=auth, **kwargs) ))) + {{else if (eq authType "CUSTOM_JWT")}} + token = _get_bearer_token_{{snakeCase name}}() + headers = {"Authorization": f"Bearer {token}"} if token else None + toolsets.append(MCPToolset(connection_params=StreamableHTTPConnectionParams(url=url, headers=headers))) {{else}} toolsets.append(MCPToolset(connection_params=StreamableHTTPConnectionParams(url=url))) {{/if}} diff --git a/src/assets/python/langchain_langgraph/base/mcp_client/client.py b/src/assets/python/langchain_langgraph/base/mcp_client/client.py index adcb478a..71b336d2 100644 --- a/src/assets/python/langchain_langgraph/base/mcp_client/client.py +++ b/src/assets/python/langchain_langgraph/base/mcp_client/client.py @@ -8,6 +8,23 @@ {{#if (includes gatewayAuthTypes "AWS_IAM")}} from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session {{/if}} +{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}} +from bedrock_agentcore.identity import requires_access_token +{{/if}} + +{{#each gatewayProviders}} +{{#if (eq authType "CUSTOM_JWT")}} +@requires_access_token( + provider_name="{{credentialProviderName}}", + scopes=[{{#if scopes}}"{{scopes}}"{{/if}}], + auth_flow="M2M", +) +def _get_bearer_token_{{snakeCase name}}(*, access_token: str): + """Obtain OAuth access token via AgentCore Identity for {{name}}.""" + return access_token + +{{/if}} +{{/each}} def get_all_gateway_mcp_client() -> MultiServerMCPClient | None: """Returns an MCP Client connected to all configured gateways.""" @@ -19,6 +36,10 @@ def get_all_gateway_mcp_client() -> MultiServerMCPClient | None: session = create_aws_session() auth = SigV4HTTPXAuth(session.get_credentials(), "bedrock-agentcore", session.region_name) servers["{{name}}"] = {"transport": "streamable_http", "url": url, "auth": auth} + {{else if (eq authType "CUSTOM_JWT")}} + token = _get_bearer_token_{{snakeCase name}}() + headers = {"Authorization": f"Bearer {token}"} if token else None + servers["{{name}}"] = {"transport": "streamable_http", "url": url, "headers": headers} {{else}} servers["{{name}}"] = {"transport": "streamable_http", "url": url} {{/if}} diff --git a/src/assets/python/openaiagents/base/mcp_client/client.py b/src/assets/python/openaiagents/base/mcp_client/client.py index 39612c38..2fe91136 100644 --- a/src/assets/python/openaiagents/base/mcp_client/client.py +++ b/src/assets/python/openaiagents/base/mcp_client/client.py @@ -9,6 +9,23 @@ import httpx from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session {{/if}} +{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}} +from bedrock_agentcore.identity import requires_access_token +{{/if}} + +{{#each gatewayProviders}} +{{#if (eq authType "CUSTOM_JWT")}} +@requires_access_token( + provider_name="{{credentialProviderName}}", + scopes=[{{#if scopes}}"{{scopes}}"{{/if}}], + auth_flow="M2M", +) +def _get_bearer_token_{{snakeCase name}}(*, access_token: str): + """Obtain OAuth access token via AgentCore Identity for {{name}}.""" + return access_token + +{{/if}} +{{/each}} def get_all_gateway_mcp_servers() -> list[MCPServerStreamableHttp]: """Returns MCP servers for all configured gateways.""" @@ -23,6 +40,10 @@ def get_all_gateway_mcp_servers() -> list[MCPServerStreamableHttp]: name="{{name}}", params={"url": url, "httpx_client_factory": lambda **kwargs: httpx.AsyncClient(auth=auth, **kwargs)} )) + {{else if (eq authType "CUSTOM_JWT")}} + token = _get_bearer_token_{{snakeCase name}}() + headers = {"Authorization": f"Bearer {token}"} if token else {} + servers.append(MCPServerStreamableHttp(name="{{name}}", params={"url": url, "headers": headers})) {{else}} servers.append(MCPServerStreamableHttp(name="{{name}}", params={"url": url})) {{/if}} diff --git a/src/assets/python/strands/base/mcp_client/client.py b/src/assets/python/strands/base/mcp_client/client.py index 3b77cdac..01457de2 100644 --- a/src/assets/python/strands/base/mcp_client/client.py +++ b/src/assets/python/strands/base/mcp_client/client.py @@ -9,7 +9,23 @@ {{#if (includes gatewayAuthTypes "AWS_IAM")}} from mcp_proxy_for_aws.client import aws_iam_streamablehttp_client {{/if}} +{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}} +from bedrock_agentcore.identity import requires_access_token +{{/if}} + +{{#each gatewayProviders}} +{{#if (eq authType "CUSTOM_JWT")}} +@requires_access_token( + provider_name="{{credentialProviderName}}", + scopes=[{{#if scopes}}"{{scopes}}"{{/if}}], + auth_flow="M2M", +) +def _get_bearer_token_{{snakeCase name}}(*, access_token: str): + """Obtain OAuth access token via AgentCore Identity for {{name}}.""" + return access_token +{{/if}} +{{/each}} {{#each gatewayProviders}} def get_{{snakeCase name}}_mcp_client() -> MCPClient | None: """Returns an MCP Client connected to the {{name}} gateway.""" @@ -19,6 +35,10 @@ def get_{{snakeCase name}}_mcp_client() -> MCPClient | None: return None {{#if (eq authType "AWS_IAM")}} return MCPClient(lambda: aws_iam_streamablehttp_client(url, aws_service="bedrock-agentcore", aws_region=os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION")))) + {{else if (eq authType "CUSTOM_JWT")}} + token = _get_bearer_token_{{snakeCase name}}() + headers = {"Authorization": f"Bearer {token}"} if token else {} + return MCPClient(lambda: streamablehttp_client(url, headers=headers)) {{else}} return MCPClient(lambda: streamablehttp_client(url)) {{/if}} diff --git a/src/cli/commands/add/__tests__/validate.test.ts b/src/cli/commands/add/__tests__/validate.test.ts index 3b319c9e..0d4f7961 100644 --- a/src/cli/commands/add/__tests__/validate.test.ts +++ b/src/cli/commands/add/__tests__/validate.test.ts @@ -240,6 +240,47 @@ describe('validate', () => { expect(validateAddGatewayOptions(validGatewayOptionsNone)).toEqual({ valid: true }); expect(validateAddGatewayOptions(validGatewayOptionsJwt)).toEqual({ valid: true }); }); + + // AC15: agentClientId and agentClientSecret must be provided together + it('returns error when agentClientId provided without agentClientSecret', () => { + const result = validateAddGatewayOptions({ + ...validGatewayOptionsJwt, + agentClientId: 'my-client-id', + }); + expect(result.valid).toBe(false); + expect(result.error).toBe('Both --agent-client-id and --agent-client-secret must be provided together'); + }); + + it('returns error when agentClientSecret provided without agentClientId', () => { + const result = validateAddGatewayOptions({ + ...validGatewayOptionsJwt, + agentClientSecret: 'my-secret', + }); + expect(result.valid).toBe(false); + expect(result.error).toBe('Both --agent-client-id and --agent-client-secret must be provided together'); + }); + + // AC16: agent credentials only valid with CUSTOM_JWT + it('returns error when agent credentials used with non-CUSTOM_JWT authorizer', () => { + const result = validateAddGatewayOptions({ + ...validGatewayOptionsNone, + agentClientId: 'my-client-id', + agentClientSecret: 'my-secret', + }); + expect(result.valid).toBe(false); + expect(result.error).toBe('Agent OAuth credentials are only valid with CUSTOM_JWT authorizer'); + }); + + // AC17: valid CUSTOM_JWT with agent credentials passes + it('passes for CUSTOM_JWT with agent credentials', () => { + const result = validateAddGatewayOptions({ + ...validGatewayOptionsJwt, + agentClientId: 'my-client-id', + agentClientSecret: 'my-secret', + allowedScopes: 'scope1,scope2', + }); + expect(result.valid).toBe(true); + }); }); describe('validateAddGatewayTargetOptions', () => { diff --git a/src/cli/commands/add/actions.ts b/src/cli/commands/add/actions.ts index 675c52e8..7232f7c7 100644 --- a/src/cli/commands/add/actions.ts +++ b/src/cli/commands/add/actions.ts @@ -64,6 +64,9 @@ export interface ValidatedAddGatewayOptions { discoveryUrl?: string; allowedAudience?: string; allowedClients?: string; + allowedScopes?: string; + agentClientId?: string; + agentClientSecret?: string; agents?: string; } @@ -267,6 +270,14 @@ function buildGatewayConfig(options: ValidatedAddGatewayOptions): AddGatewayConf .allowedClients!.split(',') .map(s => s.trim()) .filter(Boolean), + allowedScopes: options.allowedScopes + ? options.allowedScopes + .split(',') + .map(s => s.trim()) + .filter(Boolean) + : undefined, + agentClientId: options.agentClientId, + agentClientSecret: options.agentClientSecret, }; } diff --git a/src/cli/commands/add/command.tsx b/src/cli/commands/add/command.tsx index 6a9370bb..22e89dc5 100644 --- a/src/cli/commands/add/command.tsx +++ b/src/cli/commands/add/command.tsx @@ -82,6 +82,9 @@ async function handleAddGatewayCLI(options: AddGatewayOptions): Promise { discoveryUrl: options.discoveryUrl, allowedAudience: options.allowedAudience, allowedClients: options.allowedClients, + allowedScopes: options.allowedScopes, + agentClientId: options.agentClientId, + agentClientSecret: options.agentClientSecret, agents: options.agents, }); @@ -272,6 +275,9 @@ export function registerAdd(program: Command) { .option('--discovery-url ', 'OIDC discovery URL (required for CUSTOM_JWT)') .option('--allowed-audience ', 'Comma-separated allowed audience values (required for CUSTOM_JWT)') .option('--allowed-clients ', 'Comma-separated allowed client IDs (required for CUSTOM_JWT)') + .option('--allowed-scopes ', 'Comma-separated allowed scopes (optional for CUSTOM_JWT)') + .option('--agent-client-id ', 'Agent OAuth client ID for Bearer token auth (CUSTOM_JWT)') + .option('--agent-client-secret ', 'Agent OAuth client secret (CUSTOM_JWT)') .option('--json', 'Output as JSON') .action(async options => { requireProject(); diff --git a/src/cli/commands/add/types.ts b/src/cli/commands/add/types.ts index c83db76d..46757121 100644 --- a/src/cli/commands/add/types.ts +++ b/src/cli/commands/add/types.ts @@ -31,6 +31,9 @@ export interface AddGatewayOptions { discoveryUrl?: string; allowedAudience?: string; allowedClients?: string; + allowedScopes?: string; + agentClientId?: string; + agentClientSecret?: string; agents?: string; json?: boolean; } diff --git a/src/cli/commands/add/validate.ts b/src/cli/commands/add/validate.ts index 9a4bc4df..0aac0a21 100644 --- a/src/cli/commands/add/validate.ts +++ b/src/cli/commands/add/validate.ts @@ -181,6 +181,17 @@ export function validateAddGatewayOptions(options: AddGatewayOptions): Validatio } } + // Validate agent OAuth credentials + if (options.agentClientId && !options.agentClientSecret) { + return { valid: false, error: 'Both --agent-client-id and --agent-client-secret must be provided together' }; + } + if (options.agentClientSecret && !options.agentClientId) { + return { valid: false, error: 'Both --agent-client-id and --agent-client-secret must be provided together' }; + } + if (options.agentClientId && options.authorizerType !== 'CUSTOM_JWT') { + return { valid: false, error: 'Agent OAuth credentials are only valid with CUSTOM_JWT authorizer' }; + } + return { valid: true }; } diff --git a/src/cli/commands/remove/actions.ts b/src/cli/commands/remove/actions.ts index 35681c69..74604ea2 100644 --- a/src/cli/commands/remove/actions.ts +++ b/src/cli/commands/remove/actions.ts @@ -72,7 +72,7 @@ export async function handleRemove(options: ValidatedRemoveOptions): Promise ({ - name: gateway.name, - envVarName: computeDefaultGatewayEnvVarName(gateway.name), - authType: gateway.authorizerType, - })); + const project = await configIO.readProjectSpec(); + + return mcpSpec.agentCoreGateways.map(gateway => { + const config: GatewayProviderRenderConfig = { + name: gateway.name, + envVarName: computeDefaultGatewayEnvVarName(gateway.name), + authType: gateway.authorizerType, + }; + + if (gateway.authorizerType === 'CUSTOM_JWT' && gateway.authorizerConfiguration?.customJwtAuthorizer) { + const jwtConfig = gateway.authorizerConfiguration.customJwtAuthorizer; + const credName = `${gateway.name}-agent-oauth`; + const credential = project.credentials.find(c => c.name === credName); + + if (credential) { + config.credentialProviderName = credName; + config.discoveryUrl = jwtConfig.discoveryUrl; + const scopes = 'allowedScopes' in jwtConfig ? (jwtConfig as { allowedScopes?: string[] }).allowedScopes : undefined; + if (scopes?.length) { + config.scopes = scopes.join(' '); + } + } + } + + return config; + }); } catch { return []; } diff --git a/src/cli/operations/identity/create-identity.ts b/src/cli/operations/identity/create-identity.ts index f42bee61..26a0c672 100644 --- a/src/cli/operations/identity/create-identity.ts +++ b/src/cli/operations/identity/create-identity.ts @@ -14,6 +14,7 @@ export type CreateCredentialConfig = clientSecret: string; scopes?: string[]; vendor?: string; + managed?: boolean; }; /** @@ -143,6 +144,7 @@ export async function createCredential(config: CreateCredentialConfig): Promise< discoveryUrl: config.discoveryUrl, vendor: config.vendor ?? 'CustomOauth2', ...(config.scopes && config.scopes.length > 0 ? { scopes: config.scopes } : {}), + ...(config.managed ? { managed: true } : {}), }; project.credentials.push(credential); await configIO.writeProjectSpec(project); diff --git a/src/cli/operations/mcp/create-mcp.ts b/src/cli/operations/mcp/create-mcp.ts index 1f554642..f8bb6e63 100644 --- a/src/cli/operations/mcp/create-mcp.ts +++ b/src/cli/operations/mcp/create-mcp.ts @@ -11,6 +11,7 @@ import { AgentCoreCliMcpDefsSchema, ToolDefinitionSchema } from '../../../schema import { getTemplateToolDefinitions, renderGatewayTargetTemplate } from '../../templates/GatewayTargetRenderer'; import type { AddGatewayConfig, AddGatewayTargetConfig } from '../../tui/screens/mcp/types'; import { DEFAULT_HANDLER, DEFAULT_NODE_VERSION, DEFAULT_PYTHON_VERSION } from '../../tui/screens/mcp/types'; +import { createCredential } from '../identity/create-identity'; import { existsSync } from 'fs'; import { mkdir, readFile, writeFile } from 'fs/promises'; import { dirname, join } from 'path'; @@ -71,6 +72,7 @@ function buildAuthorizerConfiguration(config: AddGatewayConfig): AgentCoreGatewa discoveryUrl: config.jwtConfig.discoveryUrl, allowedAudience: config.jwtConfig.allowedAudience, allowedClients: config.jwtConfig.allowedClients, + ...(config.jwtConfig.allowedScopes?.length && { allowedScopes: config.jwtConfig.allowedScopes }), }, }; } @@ -201,6 +203,20 @@ export async function createGatewayFromWizard(config: AddGatewayConfig): Promise await configIO.writeMcpSpec(mcpSpec); + // Auto-create managed credential if agent OAuth credentials provided + if (config.jwtConfig?.agentClientId && config.jwtConfig?.agentClientSecret) { + const credName = `${config.name}-agent-oauth`; + await createCredential({ + type: 'OAuthCredentialProvider', + name: credName, + discoveryUrl: config.jwtConfig.discoveryUrl, + clientId: config.jwtConfig.agentClientId, + clientSecret: config.jwtConfig.agentClientSecret, + vendor: 'CustomOauth2', + managed: true, + }); + } + return { name: config.name }; } diff --git a/src/cli/operations/remove/__tests__/remove-identity.test.ts b/src/cli/operations/remove/__tests__/remove-identity.test.ts index b6172a33..2426b345 100644 --- a/src/cli/operations/remove/__tests__/remove-identity.test.ts +++ b/src/cli/operations/remove/__tests__/remove-identity.test.ts @@ -1,8 +1,9 @@ -import { previewRemoveCredential } from '../remove-identity.js'; +import { previewRemoveCredential, removeCredential } from '../remove-identity.js'; import { describe, expect, it, vi } from 'vitest'; -const { mockReadProjectSpec, mockConfigExists, mockReadMcpSpec } = vi.hoisted(() => ({ +const { mockReadProjectSpec, mockWriteProjectSpec, mockConfigExists, mockReadMcpSpec } = vi.hoisted(() => ({ mockReadProjectSpec: vi.fn(), + mockWriteProjectSpec: vi.fn(), mockConfigExists: vi.fn(), mockReadMcpSpec: vi.fn(), })); @@ -10,6 +11,7 @@ const { mockReadProjectSpec, mockConfigExists, mockReadMcpSpec } = vi.hoisted(() vi.mock('../../../../lib/index.js', () => ({ ConfigIO: class { readProjectSpec = mockReadProjectSpec; + writeProjectSpec = mockWriteProjectSpec; configExists = mockConfigExists; readMcpSpec = mockReadMcpSpec; }, @@ -118,4 +120,57 @@ describe('previewRemoveCredential', () => { 'Warning: Credential "test-cred" is referenced by gateway targets: gateway2/target2. Removing it may break these targets.' ); }); + + it('shows managed credential warning in preview', async () => { + mockReadProjectSpec.mockResolvedValue({ + credentials: [{ name: 'gw-agent-oauth', type: 'OAuthCredentialProvider', managed: true, usage: 'inbound' }], + }); + mockConfigExists.mockReturnValue(false); + + const result = await previewRemoveCredential('gw-agent-oauth'); + + const warning = result.summary.find(s => s.includes('auto-created')); + expect(warning).toBeTruthy(); + }); +}); + +describe('removeCredential', () => { + it('blocks removal of managed credential without force', async () => { + mockReadProjectSpec.mockResolvedValue({ + credentials: [{ name: 'gw-agent-oauth', type: 'OAuthCredentialProvider', managed: true, usage: 'inbound' }], + }); + mockConfigExists.mockReturnValue(false); + + const result = await removeCredential('gw-agent-oauth'); + + expect(result.ok).toBe(false); + if (!result.ok) { + expect(result.error).toContain('auto-created'); + expect(result.error).toContain('--force'); + } + }); + + it('allows removal of managed credential with force', async () => { + mockReadProjectSpec.mockResolvedValue({ + credentials: [{ name: 'gw-agent-oauth', type: 'OAuthCredentialProvider', managed: true, usage: 'inbound' }], + }); + mockConfigExists.mockReturnValue(false); + mockWriteProjectSpec.mockResolvedValue(undefined); + + const result = await removeCredential('gw-agent-oauth', { force: true }); + + expect(result.ok).toBe(true); + }); + + it('allows removal of non-managed credential without force', async () => { + mockReadProjectSpec.mockResolvedValue({ + credentials: [{ name: 'regular-cred', type: 'OAuthCredentialProvider' }], + }); + mockConfigExists.mockReturnValue(false); + mockWriteProjectSpec.mockResolvedValue(undefined); + + const result = await removeCredential('regular-cred'); + + expect(result.ok).toBe(true); + }); }); diff --git a/src/cli/operations/remove/remove-identity.ts b/src/cli/operations/remove/remove-identity.ts index 68c9e417..6c560c64 100644 --- a/src/cli/operations/remove/remove-identity.ts +++ b/src/cli/operations/remove/remove-identity.ts @@ -43,6 +43,12 @@ export async function previewRemoveCredential(credentialName: string): Promise { +export async function removeCredential(credentialName: string, options?: { force?: boolean }): Promise { try { const configIO = new ConfigIO(); const project = await configIO.readProjectSpec(); @@ -95,6 +101,16 @@ export async function removeCredential(credentialName: string): Promise => { setState({ isLoading: true, result: null }); - const result = await removeIdentity(identityName); + const result = await removeIdentity(identityName, { force: true }); setState({ isLoading: false, result }); let logPath: string | undefined; diff --git a/src/cli/tui/screens/mcp/AddGatewayScreen.tsx b/src/cli/tui/screens/mcp/AddGatewayScreen.tsx index 13269eef..dca25086 100644 --- a/src/cli/tui/screens/mcp/AddGatewayScreen.tsx +++ b/src/cli/tui/screens/mcp/AddGatewayScreen.tsx @@ -4,6 +4,7 @@ import { ConfirmReview, Panel, Screen, + SecretInput, StepIndicator, TextInput, WizardMultiSelect, @@ -29,10 +30,13 @@ interface AddGatewayScreenProps { export function AddGatewayScreen({ onComplete, onExit, existingGateways, unassignedTargets }: AddGatewayScreenProps) { const wizard = useAddGatewayWizard(unassignedTargets.length); - // JWT config sub-step tracking (0 = discoveryUrl, 1 = audience, 2 = clients) + // JWT config sub-step tracking (0=discoveryUrl, 1=audience, 2=clients, 3=scopes, 4=agentClientId, 5=agentClientSecret) const [jwtSubStep, setJwtSubStep] = useState(0); const [jwtDiscoveryUrl, setJwtDiscoveryUrl] = useState(''); const [jwtAudience, setJwtAudience] = useState(''); + const [jwtClients, setJwtClients] = useState(''); + const [jwtScopes, setJwtScopes] = useState(''); + const [jwtAgentClientId, setJwtAgentClientId] = useState(''); const unassignedTargetItems: SelectableItem[] = useMemo( () => unassignedTargets.map(name => ({ id: name, title: name })), @@ -85,12 +89,30 @@ export function AddGatewayScreen({ onComplete, onExit, existingGateways, unassig }; const handleJwtClients = (clients: string) => { - // Parse comma-separated values + setJwtClients(clients); + setJwtSubStep(3); + }; + + const handleJwtScopes = (scopes: string) => { + setJwtScopes(scopes); + setJwtSubStep(4); + }; + + const handleJwtAgentClientId = (clientId: string) => { + setJwtAgentClientId(clientId); + setJwtSubStep(5); + }; + + const handleJwtAgentClientSecret = (clientSecret: string) => { const audienceList = jwtAudience .split(',') .map(s => s.trim()) .filter(Boolean); - const clientsList = clients + const clientsList = jwtClients + .split(',') + .map(s => s.trim()) + .filter(Boolean); + const scopesList = jwtScopes .split(',') .map(s => s.trim()) .filter(Boolean); @@ -99,9 +121,10 @@ export function AddGatewayScreen({ onComplete, onExit, existingGateways, unassig discoveryUrl: jwtDiscoveryUrl, allowedAudience: audienceList, allowedClients: clientsList, + ...(scopesList.length > 0 ? { allowedScopes: scopesList } : {}), + ...(jwtAgentClientId ? { agentClientId: jwtAgentClientId, agentClientSecret: clientSecret } : {}), }); - // Reset sub-step counter only - preserve values for potential back navigation setJwtSubStep(0); }; @@ -160,6 +183,9 @@ export function AddGatewayScreen({ onComplete, onExit, existingGateways, unassig onDiscoveryUrl={handleJwtDiscoveryUrl} onAudience={handleJwtAudience} onClients={handleJwtClients} + onScopes={handleJwtScopes} + onAgentClientId={handleJwtAgentClientId} + onAgentClientSecret={handleJwtAgentClientSecret} onCancel={handleJwtCancel} /> )} @@ -187,6 +213,12 @@ export function AddGatewayScreen({ onComplete, onExit, existingGateways, unassig { label: 'Discovery URL', value: wizard.config.jwtConfig.discoveryUrl }, { label: 'Allowed Audience', value: wizard.config.jwtConfig.allowedAudience.join(', ') }, { label: 'Allowed Clients', value: wizard.config.jwtConfig.allowedClients.join(', ') }, + ...(wizard.config.jwtConfig.allowedScopes?.length + ? [{ label: 'Allowed Scopes', value: wizard.config.jwtConfig.allowedScopes.join(', ') }] + : []), + ...(wizard.config.jwtConfig.agentClientId + ? [{ label: 'Agent Credential', value: `${wizard.config.name}-agent-oauth` }] + : []), ] : []), { @@ -209,6 +241,9 @@ interface JwtConfigInputProps { onDiscoveryUrl: (url: string) => void; onAudience: (audience: string) => void; onClients: (clients: string) => void; + onScopes: (scopes: string) => void; + onAgentClientId: (clientId: string) => void; + onAgentClientSecret: (clientSecret: string) => void; onCancel: () => void; } @@ -227,16 +262,28 @@ function validateCommaSeparatedList(value: string, fieldName: string): true | st return true; } -function JwtConfigInput({ subStep, onDiscoveryUrl, onAudience, onClients, onCancel }: JwtConfigInputProps) { +function JwtConfigInput({ + subStep, + onDiscoveryUrl, + onAudience, + onClients, + onScopes, + onAgentClientId, + onAgentClientSecret, + onCancel, +}: JwtConfigInputProps) { + const totalSteps = 6; return ( Configure Custom JWT Authorizer - Step {subStep + 1} of 3 + + Step {subStep + 1} of {totalSteps} + {subStep === 0 && ( { @@ -271,6 +318,33 @@ function JwtConfigInput({ subStep, onDiscoveryUrl, onAudience, onClients, onCanc customValidation={value => validateCommaSeparatedList(value, 'client')} /> )} + {subStep === 3 && ( + + )} + {subStep === 4 && ( + + )} + {subStep === 5 && ( + value.trim().length > 0 || 'Client secret is required'} + revealChars={4} + /> + )} ); diff --git a/src/cli/tui/screens/mcp/types.ts b/src/cli/tui/screens/mcp/types.ts index fcf7d593..f24aeed5 100644 --- a/src/cli/tui/screens/mcp/types.ts +++ b/src/cli/tui/screens/mcp/types.ts @@ -16,6 +16,9 @@ export interface AddGatewayConfig { discoveryUrl: string; allowedAudience: string[]; allowedClients: string[]; + allowedScopes?: string[]; + agentClientId?: string; + agentClientSecret?: string; }; /** Selected unassigned targets to include in this gateway */ selectedTargets?: string[]; diff --git a/src/cli/tui/screens/mcp/useAddGatewayWizard.ts b/src/cli/tui/screens/mcp/useAddGatewayWizard.ts index 2bd24b75..90265bca 100644 --- a/src/cli/tui/screens/mcp/useAddGatewayWizard.ts +++ b/src/cli/tui/screens/mcp/useAddGatewayWizard.ts @@ -68,7 +68,14 @@ export function useAddGatewayWizard(unassignedTargetsCount = 0) { }, []); const setJwtConfig = useCallback( - (jwtConfig: { discoveryUrl: string; allowedAudience: string[]; allowedClients: string[] }) => { + (jwtConfig: { + discoveryUrl: string; + allowedAudience: string[]; + allowedClients: string[]; + allowedScopes?: string[]; + agentClientId?: string; + agentClientSecret?: string; + }) => { setConfig(c => ({ ...c, jwtConfig, diff --git a/src/schema/schemas/agentcore-project.ts b/src/schema/schemas/agentcore-project.ts index 13f8241f..fda34160 100644 --- a/src/schema/schemas/agentcore-project.ts +++ b/src/schema/schemas/agentcore-project.ts @@ -101,6 +101,8 @@ export const OAuthCredentialSchema = z.object({ vendor: z.string().default('CustomOauth2'), /** Whether this credential was auto-created by the CLI (e.g., for CUSTOM_JWT inbound auth) */ managed: z.boolean().optional(), + /** Whether this credential is used for inbound or outbound auth */ + usage: z.enum(['inbound', 'outbound']).optional(), }); export type OAuthCredential = z.infer;