Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions src/assets/__tests__/__snapshots__/assets.snapshot.test.ts.snap
Original file line number Diff line number Diff line change
Expand Up @@ -1726,6 +1726,23 @@ logger = logging.getLogger(__name__)
import httpx
from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session
{{/if}}
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
from bedrock_agentcore.identity import requires_access_token
{{/if}}

{{#each gatewayProviders}}
{{#if (eq authType "CUSTOM_JWT")}}
@requires_access_token(
provider_name="{{credentialProviderName}}",
scopes=[{{#if scopes}}"{{scopes}}"{{/if}}],
auth_flow="M2M",
)
def _get_bearer_token_{{snakeCase name}}(*, access_token: str):
"""Obtain OAuth access token via AgentCore Identity for {{name}}."""
return access_token

{{/if}}
{{/each}}

def get_all_gateway_mcp_toolsets() -> list[MCPToolset]:
"""Returns MCP Toolsets for all configured gateways."""
Expand All @@ -1740,6 +1757,10 @@ def get_all_gateway_mcp_toolsets() -> list[MCPToolset]:
url=url,
httpx_client_factory=lambda **kwargs: httpx.AsyncClient(auth=auth, **kwargs)
)))
{{else if (eq authType "CUSTOM_JWT")}}
token = _get_bearer_token_{{snakeCase name}}()
headers = {"Authorization": f"Bearer {token}"} if token else None
toolsets.append(MCPToolset(connection_params=StreamableHTTPConnectionParams(url=url, headers=headers)))
{{else}}
toolsets.append(MCPToolset(connection_params=StreamableHTTPConnectionParams(url=url)))
{{/if}}
Expand Down Expand Up @@ -2012,6 +2033,23 @@ logger = logging.getLogger(__name__)
{{#if (includes gatewayAuthTypes "AWS_IAM")}}
from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session
{{/if}}
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
from bedrock_agentcore.identity import requires_access_token
{{/if}}

{{#each gatewayProviders}}
{{#if (eq authType "CUSTOM_JWT")}}
@requires_access_token(
provider_name="{{credentialProviderName}}",
scopes=[{{#if scopes}}"{{scopes}}"{{/if}}],
auth_flow="M2M",
)
def _get_bearer_token_{{snakeCase name}}(*, access_token: str):
"""Obtain OAuth access token via AgentCore Identity for {{name}}."""
return access_token

{{/if}}
{{/each}}

def get_all_gateway_mcp_client() -> MultiServerMCPClient | None:
"""Returns an MCP Client connected to all configured gateways."""
Expand All @@ -2023,6 +2061,10 @@ def get_all_gateway_mcp_client() -> MultiServerMCPClient | None:
session = create_aws_session()
auth = SigV4HTTPXAuth(session.get_credentials(), "bedrock-agentcore", session.region_name)
servers["{{name}}"] = {"transport": "streamable_http", "url": url, "auth": auth}
{{else if (eq authType "CUSTOM_JWT")}}
token = _get_bearer_token_{{snakeCase name}}()
headers = {"Authorization": f"Bearer {token}"} if token else None
servers["{{name}}"] = {"transport": "streamable_http", "url": url, "headers": headers}
{{else}}
servers["{{name}}"] = {"transport": "streamable_http", "url": url}
{{/if}}
Expand Down Expand Up @@ -2438,6 +2480,23 @@ logger = logging.getLogger(__name__)
import httpx
from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session
{{/if}}
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
from bedrock_agentcore.identity import requires_access_token
{{/if}}

{{#each gatewayProviders}}
{{#if (eq authType "CUSTOM_JWT")}}
@requires_access_token(
provider_name="{{credentialProviderName}}",
scopes=[{{#if scopes}}"{{scopes}}"{{/if}}],
auth_flow="M2M",
)
def _get_bearer_token_{{snakeCase name}}(*, access_token: str):
"""Obtain OAuth access token via AgentCore Identity for {{name}}."""
return access_token

{{/if}}
{{/each}}

def get_all_gateway_mcp_servers() -> list[MCPServerStreamableHttp]:
"""Returns MCP servers for all configured gateways."""
Expand All @@ -2452,6 +2511,10 @@ def get_all_gateway_mcp_servers() -> list[MCPServerStreamableHttp]:
name="{{name}}",
params={"url": url, "httpx_client_factory": lambda **kwargs: httpx.AsyncClient(auth=auth, **kwargs)}
))
{{else if (eq authType "CUSTOM_JWT")}}
token = _get_bearer_token_{{snakeCase name}}()
headers = {"Authorization": f"Bearer {token}"} if token else {}
servers.append(MCPServerStreamableHttp(name="{{name}}", params={"url": url, "headers": headers}))
{{else}}
servers.append(MCPServerStreamableHttp(name="{{name}}", params={"url": url}))
{{/if}}
Expand Down Expand Up @@ -2749,7 +2812,23 @@ logger = logging.getLogger(__name__)
{{#if (includes gatewayAuthTypes "AWS_IAM")}}
from mcp_proxy_for_aws.client import aws_iam_streamablehttp_client
{{/if}}
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
from bedrock_agentcore.identity import requires_access_token
{{/if}}

{{#each gatewayProviders}}
{{#if (eq authType "CUSTOM_JWT")}}
@requires_access_token(
provider_name="{{credentialProviderName}}",
scopes=[{{#if scopes}}"{{scopes}}"{{/if}}],
auth_flow="M2M",
)
def _get_bearer_token_{{snakeCase name}}(*, access_token: str):
"""Obtain OAuth access token via AgentCore Identity for {{name}}."""
return access_token

{{/if}}
{{/each}}
{{#each gatewayProviders}}
def get_{{snakeCase name}}_mcp_client() -> MCPClient | None:
"""Returns an MCP Client connected to the {{name}} gateway."""
Expand All @@ -2759,6 +2838,10 @@ def get_{{snakeCase name}}_mcp_client() -> MCPClient | None:
return None
{{#if (eq authType "AWS_IAM")}}
return MCPClient(lambda: aws_iam_streamablehttp_client(url, aws_service="bedrock-agentcore", aws_region=os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION"))))
{{else if (eq authType "CUSTOM_JWT")}}
token = _get_bearer_token_{{snakeCase name}}()
headers = {"Authorization": f"Bearer {token}"} if token else {}
return MCPClient(lambda: streamablehttp_client(url, headers=headers))
{{else}}
return MCPClient(lambda: streamablehttp_client(url))
{{/if}}
Expand Down
21 changes: 21 additions & 0 deletions src/assets/python/googleadk/base/mcp_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,23 @@
import httpx
from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session
{{/if}}
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
from bedrock_agentcore.identity import requires_access_token
{{/if}}

{{#each gatewayProviders}}
{{#if (eq authType "CUSTOM_JWT")}}
@requires_access_token(
provider_name="{{credentialProviderName}}",
scopes=[{{#if scopes}}"{{scopes}}"{{/if}}],
auth_flow="M2M",
)
def _get_bearer_token_{{snakeCase name}}(*, access_token: str):
"""Obtain OAuth access token via AgentCore Identity for {{name}}."""
return access_token

{{/if}}
{{/each}}

def get_all_gateway_mcp_toolsets() -> list[MCPToolset]:
"""Returns MCP Toolsets for all configured gateways."""
Expand All @@ -24,6 +41,10 @@ def get_all_gateway_mcp_toolsets() -> list[MCPToolset]:
url=url,
httpx_client_factory=lambda **kwargs: httpx.AsyncClient(auth=auth, **kwargs)
)))
{{else if (eq authType "CUSTOM_JWT")}}
token = _get_bearer_token_{{snakeCase name}}()
headers = {"Authorization": f"Bearer {token}"} if token else None
toolsets.append(MCPToolset(connection_params=StreamableHTTPConnectionParams(url=url, headers=headers)))
{{else}}
toolsets.append(MCPToolset(connection_params=StreamableHTTPConnectionParams(url=url)))
{{/if}}
Expand Down
21 changes: 21 additions & 0 deletions src/assets/python/langchain_langgraph/base/mcp_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,23 @@
{{#if (includes gatewayAuthTypes "AWS_IAM")}}
from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session
{{/if}}
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
from bedrock_agentcore.identity import requires_access_token
{{/if}}

{{#each gatewayProviders}}
{{#if (eq authType "CUSTOM_JWT")}}
@requires_access_token(
provider_name="{{credentialProviderName}}",
scopes=[{{#if scopes}}"{{scopes}}"{{/if}}],
auth_flow="M2M",
)
def _get_bearer_token_{{snakeCase name}}(*, access_token: str):
"""Obtain OAuth access token via AgentCore Identity for {{name}}."""
return access_token

{{/if}}
{{/each}}

def get_all_gateway_mcp_client() -> MultiServerMCPClient | None:
"""Returns an MCP Client connected to all configured gateways."""
Expand All @@ -19,6 +36,10 @@ def get_all_gateway_mcp_client() -> MultiServerMCPClient | None:
session = create_aws_session()
auth = SigV4HTTPXAuth(session.get_credentials(), "bedrock-agentcore", session.region_name)
servers["{{name}}"] = {"transport": "streamable_http", "url": url, "auth": auth}
{{else if (eq authType "CUSTOM_JWT")}}
token = _get_bearer_token_{{snakeCase name}}()
headers = {"Authorization": f"Bearer {token}"} if token else None
servers["{{name}}"] = {"transport": "streamable_http", "url": url, "headers": headers}
{{else}}
servers["{{name}}"] = {"transport": "streamable_http", "url": url}
{{/if}}
Expand Down
21 changes: 21 additions & 0 deletions src/assets/python/openaiagents/base/mcp_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,23 @@
import httpx
from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth, create_aws_session
{{/if}}
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
from bedrock_agentcore.identity import requires_access_token
{{/if}}

{{#each gatewayProviders}}
{{#if (eq authType "CUSTOM_JWT")}}
@requires_access_token(
provider_name="{{credentialProviderName}}",
scopes=[{{#if scopes}}"{{scopes}}"{{/if}}],
auth_flow="M2M",
)
def _get_bearer_token_{{snakeCase name}}(*, access_token: str):
"""Obtain OAuth access token via AgentCore Identity for {{name}}."""
return access_token

{{/if}}
{{/each}}

def get_all_gateway_mcp_servers() -> list[MCPServerStreamableHttp]:
"""Returns MCP servers for all configured gateways."""
Expand All @@ -23,6 +40,10 @@ def get_all_gateway_mcp_servers() -> list[MCPServerStreamableHttp]:
name="{{name}}",
params={"url": url, "httpx_client_factory": lambda **kwargs: httpx.AsyncClient(auth=auth, **kwargs)}
))
{{else if (eq authType "CUSTOM_JWT")}}
token = _get_bearer_token_{{snakeCase name}}()
headers = {"Authorization": f"Bearer {token}"} if token else {}
servers.append(MCPServerStreamableHttp(name="{{name}}", params={"url": url, "headers": headers}))
{{else}}
servers.append(MCPServerStreamableHttp(name="{{name}}", params={"url": url}))
{{/if}}
Expand Down
20 changes: 20 additions & 0 deletions src/assets/python/strands/base/mcp_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,23 @@
{{#if (includes gatewayAuthTypes "AWS_IAM")}}
from mcp_proxy_for_aws.client import aws_iam_streamablehttp_client
{{/if}}
{{#if (includes gatewayAuthTypes "CUSTOM_JWT")}}
from bedrock_agentcore.identity import requires_access_token
{{/if}}

{{#each gatewayProviders}}
{{#if (eq authType "CUSTOM_JWT")}}
@requires_access_token(
provider_name="{{credentialProviderName}}",
scopes=[{{#if scopes}}"{{scopes}}"{{/if}}],
auth_flow="M2M",
)
def _get_bearer_token_{{snakeCase name}}(*, access_token: str):
"""Obtain OAuth access token via AgentCore Identity for {{name}}."""
return access_token

{{/if}}
{{/each}}
{{#each gatewayProviders}}
def get_{{snakeCase name}}_mcp_client() -> MCPClient | None:
"""Returns an MCP Client connected to the {{name}} gateway."""
Expand All @@ -19,6 +35,10 @@ def get_{{snakeCase name}}_mcp_client() -> MCPClient | None:
return None
{{#if (eq authType "AWS_IAM")}}
return MCPClient(lambda: aws_iam_streamablehttp_client(url, aws_service="bedrock-agentcore", aws_region=os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION"))))
{{else if (eq authType "CUSTOM_JWT")}}
token = _get_bearer_token_{{snakeCase name}}()
headers = {"Authorization": f"Bearer {token}"} if token else {}
return MCPClient(lambda: streamablehttp_client(url, headers=headers))
{{else}}
return MCPClient(lambda: streamablehttp_client(url))
{{/if}}
Expand Down
41 changes: 41 additions & 0 deletions src/cli/commands/add/__tests__/validate.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,47 @@ describe('validate', () => {
expect(validateAddGatewayOptions(validGatewayOptionsNone)).toEqual({ valid: true });
expect(validateAddGatewayOptions(validGatewayOptionsJwt)).toEqual({ valid: true });
});

// AC15: agentClientId and agentClientSecret must be provided together
it('returns error when agentClientId provided without agentClientSecret', () => {
const result = validateAddGatewayOptions({
...validGatewayOptionsJwt,
agentClientId: 'my-client-id',
});
expect(result.valid).toBe(false);
expect(result.error).toBe('Both --agent-client-id and --agent-client-secret must be provided together');
});

it('returns error when agentClientSecret provided without agentClientId', () => {
const result = validateAddGatewayOptions({
...validGatewayOptionsJwt,
agentClientSecret: 'my-secret',
});
expect(result.valid).toBe(false);
expect(result.error).toBe('Both --agent-client-id and --agent-client-secret must be provided together');
});

// AC16: agent credentials only valid with CUSTOM_JWT
it('returns error when agent credentials used with non-CUSTOM_JWT authorizer', () => {
const result = validateAddGatewayOptions({
...validGatewayOptionsNone,
agentClientId: 'my-client-id',
agentClientSecret: 'my-secret',
});
expect(result.valid).toBe(false);
expect(result.error).toBe('Agent OAuth credentials are only valid with CUSTOM_JWT authorizer');
});

// AC17: valid CUSTOM_JWT with agent credentials passes
it('passes for CUSTOM_JWT with agent credentials', () => {
const result = validateAddGatewayOptions({
...validGatewayOptionsJwt,
agentClientId: 'my-client-id',
agentClientSecret: 'my-secret',
allowedScopes: 'scope1,scope2',
});
expect(result.valid).toBe(true);
});
});

describe('validateAddGatewayTargetOptions', () => {
Expand Down
11 changes: 11 additions & 0 deletions src/cli/commands/add/actions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ export interface ValidatedAddGatewayOptions {
discoveryUrl?: string;
allowedAudience?: string;
allowedClients?: string;
allowedScopes?: string;
agentClientId?: string;
agentClientSecret?: string;
agents?: string;
}

Expand Down Expand Up @@ -267,6 +270,14 @@ function buildGatewayConfig(options: ValidatedAddGatewayOptions): AddGatewayConf
.allowedClients!.split(',')
.map(s => s.trim())
.filter(Boolean),
allowedScopes: options.allowedScopes
? options.allowedScopes
.split(',')
.map(s => s.trim())
.filter(Boolean)
: undefined,
agentClientId: options.agentClientId,
agentClientSecret: options.agentClientSecret,
};
}

Expand Down
6 changes: 6 additions & 0 deletions src/cli/commands/add/command.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ async function handleAddGatewayCLI(options: AddGatewayOptions): Promise<void> {
discoveryUrl: options.discoveryUrl,
allowedAudience: options.allowedAudience,
allowedClients: options.allowedClients,
allowedScopes: options.allowedScopes,
agentClientId: options.agentClientId,
agentClientSecret: options.agentClientSecret,
agents: options.agents,
});

Expand Down Expand Up @@ -272,6 +275,9 @@ export function registerAdd(program: Command) {
.option('--discovery-url <url>', 'OIDC discovery URL (required for CUSTOM_JWT)')
.option('--allowed-audience <values>', 'Comma-separated allowed audience values (required for CUSTOM_JWT)')
.option('--allowed-clients <values>', 'Comma-separated allowed client IDs (required for CUSTOM_JWT)')
.option('--allowed-scopes <scopes>', 'Comma-separated allowed scopes (optional for CUSTOM_JWT)')
.option('--agent-client-id <id>', 'Agent OAuth client ID for Bearer token auth (CUSTOM_JWT)')
.option('--agent-client-secret <secret>', 'Agent OAuth client secret (CUSTOM_JWT)')
.option('--json', 'Output as JSON')
.action(async options => {
requireProject();
Expand Down
3 changes: 3 additions & 0 deletions src/cli/commands/add/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ export interface AddGatewayOptions {
discoveryUrl?: string;
allowedAudience?: string;
allowedClients?: string;
allowedScopes?: string;
agentClientId?: string;
agentClientSecret?: string;
agents?: string;
json?: boolean;
}
Expand Down
Loading
Loading