Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 63 additions & 16 deletions python/packages/claude/agent_framework_claude/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@
PermissionMode,
SandboxSettings,
SdkBeta,
SdkPluginConfig,
SettingSource,
)
from claude_agent_sdk.types import ThinkingConfig


logger = logging.getLogger("agent_framework.claude")
Expand Down Expand Up @@ -118,9 +121,6 @@ class ClaudeAgentOptions(TypedDict, total=False):
fallback_model: str
"""Fallback model if primary fails."""

max_thinking_tokens: int
"""Maximum tokens for thinking blocks."""

allowed_tools: list[str]
"""Allowlist of tools. If set, Claude can ONLY use tools in this list."""

Expand Down Expand Up @@ -163,6 +163,18 @@ class ClaudeAgentOptions(TypedDict, total=False):
betas: list[SdkBeta]
"""Beta features to enable."""

plugins: list[SdkPluginConfig]
"""Plugin configurations for custom commands and capabilities."""

setting_sources: list[SettingSource]
"""Which Claude settings files to load ("user", "project", "local")."""

thinking: ThinkingConfig
"""Extended thinking configuration (adaptive, enabled, or disabled)."""

effort: Literal["low", "medium", "high", "max"]
"""Effort level for thinking depth."""


OptionsT = TypeVar(
"OptionsT",
Expand Down Expand Up @@ -213,7 +225,11 @@ def __init__(
description: str | None = None,
context_providers: Sequence[BaseContextProvider] | None = None,
middleware: Sequence[AgentMiddlewareTypes] | None = None,
tools: ToolTypes | Callable[..., Any] | str | Sequence[ToolTypes | Callable[..., Any] | str] | None = None,
tools: ToolTypes
| Callable[..., Any]
| str
| Sequence[ToolTypes | Callable[..., Any] | str]
| None = None,
default_options: OptionsT | MutableMapping[str, Any] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
Expand Down Expand Up @@ -289,7 +305,11 @@ def __init__(

def _normalize_tools(
self,
tools: ToolTypes | Callable[..., Any] | str | Sequence[ToolTypes | Callable[..., Any] | str] | None,
tools: ToolTypes
| Callable[..., Any]
| str
| Sequence[ToolTypes | Callable[..., Any] | str]
| None,
) -> None:
"""Separate built-in tools (strings) from custom tools.

Expand Down Expand Up @@ -358,7 +378,9 @@ async def _ensure_session(self, session_id: str | None = None) -> None:
session_id: The session ID to use, or None for a new session.
"""
needs_new_client = (
not self._started or self._client is None or (session_id and session_id != self._current_session_id)
not self._started
or self._client is None
or (session_id and session_id != self._current_session_id)
)

if needs_new_client:
Expand All @@ -381,7 +403,9 @@ async def _ensure_session(self, session_id: str | None = None) -> None:
self._client = None
raise AgentException(f"Failed to start Claude SDK client: {ex}") from ex

def _prepare_client_options(self, resume_session_id: str | None = None) -> SDKOptions:
def _prepare_client_options(
self, resume_session_id: str | None = None
) -> SDKOptions:
"""Prepare SDK options for client initialization.

Args:
Expand Down Expand Up @@ -421,7 +445,9 @@ def _prepare_client_options(self, resume_session_id: str | None = None) -> SDKOp

# Prepare custom tools (FunctionTool instances)
custom_tools_server, custom_tool_names = (
self._prepare_tools(self._custom_tools) if self._custom_tools else (None, [])
self._prepare_tools(self._custom_tools)
if self._custom_tools
else (None, [])
)

# MCP servers - merge user-provided servers with custom tools server
Expand Down Expand Up @@ -468,9 +494,13 @@ def _prepare_tools(
if not sdk_tools:
return None, []

return create_sdk_mcp_server(name=TOOLS_MCP_SERVER_NAME, tools=sdk_tools), tool_names
return create_sdk_mcp_server(
name=TOOLS_MCP_SERVER_NAME, tools=sdk_tools
), tool_names

def _function_tool_to_sdk_mcp_tool(self, func_tool: FunctionTool) -> SdkMcpTool[Any]:
def _function_tool_to_sdk_mcp_tool(
self, func_tool: FunctionTool
) -> SdkMcpTool[Any]:
"""Convert a FunctionTool to an SDK MCP tool.

Args:
Expand All @@ -493,7 +523,9 @@ async def handler(args: dict[str, Any]) -> dict[str, Any]:
return {"content": [{"type": "text", "text": f"Error: {e}"}]}

# Get JSON schema from pydantic model
schema: dict[str, Any] = func_tool.input_model.model_json_schema() if func_tool.input_model else {}
schema: dict[str, Any] = (
func_tool.input_model.model_json_schema() if func_tool.input_model else {}
)
input_schema: dict[str, Any] = {
"type": "object",
"properties": schema.get("properties", {}),
Expand Down Expand Up @@ -554,7 +586,9 @@ def default_options(self) -> dict[str, Any]:
opts["instructions"] = system_prompt
return opts

def _finalize_response(self, updates: Sequence[AgentResponseUpdate]) -> AgentResponse[Any]:
def _finalize_response(
self, updates: Sequence[AgentResponseUpdate]
) -> AgentResponse[Any]:
"""Build AgentResponse and propagate structured_output as value.

Args:
Expand Down Expand Up @@ -593,7 +627,10 @@ def run(
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
) -> (
Awaitable[AgentResponse[Any]]
| ResponseStream[AgentResponseUpdate, AgentResponse[Any]]
):
"""Run the agent with the given messages.

Args:
Expand Down Expand Up @@ -659,15 +696,23 @@ async def _get_stream(
if text:
yield AgentResponseUpdate(
role="assistant",
contents=[Content.from_text(text=text, raw_representation=message)],
contents=[
Content.from_text(
text=text, raw_representation=message
)
],
raw_representation=message,
)
elif delta_type == "thinking_delta":
thinking = delta.get("thinking", "")
if thinking:
yield AgentResponseUpdate(
role="assistant",
contents=[Content.from_text_reasoning(text=thinking, raw_representation=message)],
contents=[
Content.from_text_reasoning(
text=thinking, raw_representation=message
)
],
raw_representation=message,
)
elif isinstance(message, AssistantMessage):
Expand All @@ -684,7 +729,9 @@ async def _get_stream(
"server_error": "Claude API server error",
"unknown": "Unknown error from Claude API",
}
error_msg = error_messages.get(message.error, f"Claude API error: {message.error}")
error_msg = error_messages.get(
message.error, f"Claude API error: {message.error}"
)
# Extract any error details from content blocks
if message.content:
for block in message.content:
Expand Down