From bdfa3dd7c11604be1053e6dab482636f2760da96 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Tue, 3 Mar 2026 17:46:23 +0100 Subject: [PATCH 01/15] Fix Python pyright package scoping and typing remediation Implements issue #4407 by removing the root pyright include, adding package-level pyright includes, and resolving pyright/mypy typing issues across Python packages. Also cleans unnecessary casts and applies line-level, rule-specific ignores where external libraries are too dynamic. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- python/CODING_STANDARD.md | 15 + .../a2a/agent_framework_a2a/_agent.py | 33 +- python/packages/a2a/pyproject.toml | 1 + python/packages/ag-ui/pyproject.toml | 1 + .../agent_framework_anthropic/_chat_client.py | 46 +-- .../_context_provider.py | 22 +- .../packages/azure-ai-search/pyproject.toml | 1 + .../agent_framework_azure_ai/_chat_client.py | 110 ++++--- .../agent_framework_azure_ai/_client.py | 92 ++++-- .../_project_provider.py | 4 +- .../agent_framework_azure_ai/_shared.py | 28 +- python/packages/azure-ai/pyproject.toml | 1 + .../_history_provider.py | 16 +- python/packages/azure-cosmos/pyproject.toml | 1 + .../agent_framework_azurefunctions/_app.py | 91 ++++-- .../_context.py | 6 +- .../_serialization.py | 14 +- .../_workflow.py | 31 +- python/packages/azurefunctions/pyproject.toml | 1 + .../agent_framework_bedrock/_chat_client.py | 159 +++++++--- .../_embedding_client.py | 82 +++-- python/packages/bedrock/pyproject.toml | 1 + python/packages/chatkit/pyproject.toml | 1 + .../claude/agent_framework_claude/_agent.py | 26 +- python/packages/claude/pyproject.toml | 1 + .../agent_framework_copilotstudio/_agent.py | 20 +- python/packages/copilotstudio/pyproject.toml | 1 + .../packages/core/agent_framework/_agents.py | 56 ++-- .../packages/core/agent_framework/_clients.py | 7 +- .../core/agent_framework/_middleware.py | 56 ++-- .../core/agent_framework/_serialization.py | 20 +- 
.../core/agent_framework/_sessions.py | 11 +- .../core/agent_framework/_settings.py | 4 +- .../packages/core/agent_framework/_skills.py | 5 +- .../packages/core/agent_framework/_tools.py | 282 +++++++++++------- .../packages/core/agent_framework/_types.py | 251 ++++++++++------ .../_workflows/_agent_executor.py | 15 +- .../_workflows/_function_executor.py | 2 +- .../_workflows/_typing_utils.py | 26 +- .../azure/_assistants_client.py | 33 +- .../agent_framework/azure/_chat_client.py | 47 +-- .../azure/_embedding_client.py | 19 +- .../azure/_responses_client.py | 26 +- .../core/agent_framework/azure/_shared.py | 3 + .../agent_framework/declarative/__init__.pyi | 2 - .../core/agent_framework/observability.py | 124 +++++--- .../openai/_assistant_provider.py | 43 +-- .../openai/_assistants_client.py | 38 ++- .../agent_framework/openai/_chat_client.py | 51 +++- .../openai/_embedding_client.py | 23 +- .../openai/_responses_client.py | 14 +- .../core/agent_framework/openai/_shared.py | 8 +- python/packages/core/pyproject.toml | 1 + .../agent_framework_declarative/_loader.py | 5 +- .../_workflows/_declarative_base.py | 49 +-- .../_workflows/_declarative_builder.py | 7 +- .../_workflows/_executors_agents.py | 2 +- .../_workflows/_executors_basic.py | 87 ++++-- .../_workflows/_executors_tools.py | 32 +- .../_workflows/_powerfx_functions.py | 27 +- .../_workflows/_state.py | 11 +- .../devui/agent_framework_devui/__init__.py | 9 +- .../agent_framework_devui/_conversations.py | 32 +- .../agent_framework_devui/_deployment.py | 19 +- .../devui/agent_framework_devui/_discovery.py | 80 ++--- .../devui/agent_framework_devui/_executor.py | 183 ++++++++---- .../devui/agent_framework_devui/_mapper.py | 133 +++++---- .../_openai/_executor.py | 60 ++-- .../devui/agent_framework_devui/_server.py | 121 ++++++-- .../devui/agent_framework_devui/_session.py | 65 ++-- .../devui/agent_framework_devui/_utils.py | 46 +-- .../models/_discovery_models.py | 6 +- 
.../agent_framework_durabletask/_entities.py | 4 +- .../_response_utils.py | 4 +- python/packages/durabletask/pyproject.toml | 1 + .../_foundry_local_client.py | 8 +- python/packages/foundry_local/pyproject.toml | 1 + .../agent_framework_github_copilot/_agent.py | 24 +- python/packages/github_copilot/pyproject.toml | 1 + .../lab/gaia/agent_framework_lab_gaia/gaia.py | 116 ++++--- python/packages/lab/pyproject.toml | 1 + .../_message_utils.py | 2 +- .../agent_framework_lab_tau2/_tau2_utils.py | 77 +++-- .../tau2/agent_framework_lab_tau2/runner.py | 13 +- .../agent_framework_mem0/_context_provider.py | 2 +- python/packages/mem0/pyproject.toml | 1 + .../agent_framework_ollama/_chat_client.py | 2 +- .../_embedding_client.py | 17 +- python/packages/ollama/pyproject.toml | 1 + .../_handoff.py | 111 +++---- python/packages/orchestrations/pyproject.toml | 1 + .../agent_framework_purview/_client.py | 68 +++-- .../agent_framework_purview/_middleware.py | 4 +- .../agent_framework_purview/_models.py | 78 +++-- .../agent_framework_purview/_processor.py | 9 +- python/packages/purview/pyproject.toml | 1 + .../_context_provider.py | 23 +- .../_history_provider.py | 2 +- python/packages/redis/pyproject.toml | 1 + python/pyproject.toml | 2 +- 100 files changed, 2278 insertions(+), 1243 deletions(-) diff --git a/python/CODING_STANDARD.md b/python/CODING_STANDARD.md index 21d87e5b8c..3eb84e7ba0 100644 --- a/python/CODING_STANDARD.md +++ b/python/CODING_STANDARD.md @@ -79,6 +79,21 @@ def process_config(config: MutableMapping[str, Any]) -> None: ... ``` +### Pyright Ignore and Cast Policy + +Use typing as a helper first and suppressions as a last resort: + +- **Prefer explicit typing before suppression**: Start with clearer type annotations, helper types, overloads, + protocols, or refactoring dynamic code into typed helpers. 
+- **Line-level pyright ignores only**: If suppression is still required, use a line-level rule-specific ignore + (`# pyright: ignore[reportGeneralTypeIssues]`), never file-level or global suppression for this workflow. +- **Private usage boundary**: Accessing private members across `agent_framework*` packages can be acceptable for this + codebase, but private member usage for non-Agent Framework dependencies should remain flagged. +- **Avoid redundant casts**: Do not add `cast(...)` if the type already matches; casts should be reserved for + unavoidable narrowing where the runtime contract is known. +- **Uncertainty handoff**: If you are still unsure after best-effort typing, leave a targeted TODO note + (`TODO(): ...`) that explains what reviewer guidance is needed. + ## Function Parameter Guidelines To make the code easier to use and maintain: diff --git a/python/packages/a2a/agent_framework_a2a/_agent.py b/python/packages/a2a/agent_framework_a2a/_agent.py index 2eec8a41db..e6b0f49a14 100644 --- a/python/packages/a2a/agent_framework_a2a/_agent.py +++ b/python/packages/a2a/agent_framework_a2a/_agent.py @@ -7,7 +7,7 @@ import re import uuid from collections.abc import AsyncIterable, Awaitable, Sequence -from typing import Any, Final, Literal, overload +from typing import Any, Final, Literal, TypeAlias, overload import httpx from a2a.client import Client, ClientConfig, ClientFactory, minimal_agent_card @@ -21,7 +21,9 @@ Task, TaskIdParams, TaskQueryParams, + TaskArtifactUpdateEvent, TaskState, + TaskStatusUpdateEvent, TextPart, TransportProtocol, ) @@ -70,6 +72,9 @@ class A2AContinuationToken(ContinuationToken): TaskState.auth_required, ] +A2AClientEvent: TypeAlias = tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] +A2AStreamItem: TypeAlias = A2AMessage | A2AClientEvent + def _get_uri_data(uri: str) -> str: match = URI_PATTERN.match(uri) @@ -260,7 +265,9 @@ def run( When stream=True: A ResponseStream of AgentResponseUpdate items. 
""" if continuation_token is not None: - a2a_stream: AsyncIterable[Any] = self.client.resubscribe(TaskIdParams(id=continuation_token["task_id"])) + a2a_stream: AsyncIterable[A2AStreamItem] = self.client.resubscribe( + TaskIdParams(id=continuation_token["task_id"]) + ) else: normalized_messages = normalize_messages(messages) a2a_message = self._prepare_message_for_a2a(normalized_messages[-1]) @@ -276,7 +283,7 @@ def run( async def _map_a2a_stream( self, - a2a_stream: AsyncIterable[Any], + a2a_stream: AsyncIterable[A2AStreamItem], *, background: bool = False, ) -> AsyncIterable[AgentResponseUpdate]: @@ -300,14 +307,10 @@ async def _map_a2a_stream( response_id=str(getattr(item, "message_id", uuid.uuid4())), raw_representation=item, ) - elif isinstance(item, tuple) and len(item) == 2: # ClientEvent = (Task, UpdateEvent) - task, _update_event = item - if isinstance(task, Task): - for update in self._updates_from_task(task, background=background): - yield update else: - msg = f"Only Message and Task responses are supported from A2A agents. 
Received: {type(item)}" - raise NotImplementedError(msg) + task, _update_event = item + for update in self._updates_from_task(task, background=background): + yield update # ------------------------------------------------------------------ # Task helpers @@ -396,6 +399,8 @@ def _prepare_message_for_a2a(self, message: Message) -> A2AMessage: for content in message.contents: match content.type: case "text": + if content.text is None: + raise ValueError("Text content requires a non-null text value") parts.append( A2APart( root=TextPart( @@ -414,6 +419,8 @@ def _prepare_message_for_a2a(self, message: Message) -> A2AMessage: ) ) case "uri": + if content.uri is None: + raise ValueError("URI content requires a non-null uri value") parts.append( A2APart( root=FilePart( @@ -426,11 +433,13 @@ def _prepare_message_for_a2a(self, message: Message) -> A2AMessage: ) ) case "data": + if content.uri is None: + raise ValueError("Data content requires a non-null uri value") parts.append( A2APart( root=FilePart( file=FileWithBytes( - bytes=_get_uri_data(content.uri), # type: ignore[arg-type] + bytes=_get_uri_data(content.uri), mime_type=content.media_type, ), metadata=content.additional_properties, @@ -438,6 +447,8 @@ def _prepare_message_for_a2a(self, message: Message) -> A2AMessage: ) ) case "hosted_file": + if content.file_id is None: + raise ValueError("Hosted file content requires a non-null file_id value") parts.append( A2APart( root=FilePart( diff --git a/python/packages/a2a/pyproject.toml b/python/packages/a2a/pyproject.toml index b537b0a30d..43e0df726b 100644 --- a/python/packages/a2a/pyproject.toml +++ b/python/packages/a2a/pyproject.toml @@ -61,6 +61,7 @@ omit = [ [tool.pyright] extends = "../../pyproject.toml" +include = ["agent_framework_a2a"] [tool.mypy] plugins = ['pydantic.mypy'] diff --git a/python/packages/ag-ui/pyproject.toml b/python/packages/ag-ui/pyproject.toml index 74d9fcbd2e..cc5c081c44 100644 --- a/python/packages/ag-ui/pyproject.toml +++ 
b/python/packages/ag-ui/pyproject.toml @@ -64,6 +64,7 @@ warn_unused_configs = true disallow_untyped_defs = false [tool.pyright] +include = ["agent_framework_ag_ui"] exclude = ["tests", "tests/ag_ui", "examples"] typeCheckingMode = "basic" diff --git a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py index 8ec2943181..cbe7c51c28 100644 --- a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py +++ b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py @@ -5,7 +5,7 @@ import logging import sys from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, MutableMapping, Sequence -from typing import Any, ClassVar, Final, Generic, Literal, TypedDict +from typing import Any, ClassVar, Final, Generic, Literal, TypedDict, cast from agent_framework import ( AGENT_FRAMEWORK_USER_AGENT, @@ -302,15 +302,18 @@ class MyOptions(AnthropicChatOptions, total=False): env_file_encoding=env_file_encoding, ) + api_key_secret = anthropic_settings.get("api_key") + model_id_setting = anthropic_settings.get("chat_model_id") + if anthropic_client is None: - if not anthropic_settings["api_key"]: + if api_key_secret is None: raise ValueError( "Anthropic API key is required. Set via 'api_key' parameter " "or 'ANTHROPIC_API_KEY' environment variable." 
) anthropic_client = AsyncAnthropic( - api_key=anthropic_settings["api_key"].get_secret_value(), + api_key=api_key_secret.get_secret_value(), default_headers={"User-Agent": AGENT_FRAMEWORK_USER_AGENT}, ) @@ -324,7 +327,7 @@ class MyOptions(AnthropicChatOptions, total=False): # Initialize instance variables self.anthropic_client = anthropic_client self.additional_beta_flags = additional_beta_flags or [] - self.model_id = anthropic_settings["chat_model_id"] + self.model_id = model_id_setting # streaming requires tracking the last function call ID, name, and content type self._last_call_id_name: tuple[str, str] | None = None self._last_call_content_type: str | None = None @@ -785,19 +788,28 @@ def _prepare_tools_for_anthropic(self, options: Mapping[str, Any]) -> dict[str, "description": tool.description, "input_schema": tool.parameters(), }) - elif isinstance(tool, MutableMapping) and tool.get("type") == "mcp": - # MCP servers must be routed to separate mcp_servers parameter - server_def: dict[str, Any] = { - "type": "url", - "name": tool.get("server_label", ""), - "url": tool.get("server_url", ""), - } - if allowed_tools := tool.get("allowed_tools"): - server_def["tool_configuration"] = {"allowed_tools": list(allowed_tools)} - headers = tool.get("headers") - if isinstance(headers, dict) and (auth := headers.get("authorization")): - server_def["authorization_token"] = auth - mcp_server_list.append(server_def) + elif isinstance(tool, MutableMapping): + tool_data = cast(MutableMapping[str, Any], tool) + if tool_data.get("type") == "mcp": + # MCP servers must be routed to separate mcp_servers parameter + server_def: dict[str, Any] = { + "type": "url", + "name": tool_data.get("server_label", ""), + "url": tool_data.get("server_url", ""), + } + allowed_tools = tool_data.get("allowed_tools") + if isinstance(allowed_tools, Sequence) and not isinstance(allowed_tools, str): + server_def["tool_configuration"] = { + "allowed_tools": [str(item) for item in allowed_tools] # 
pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType] + } + headers = tool_data.get("headers") + if isinstance(headers, Mapping): + if isinstance(auth := headers.get("authorization"), str): # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] + server_def["authorization_token"] = auth + mcp_server_list.append(server_def) + else: + # Pass through all other tools (dicts, SDK types) unchanged + tool_list.append(tool) else: # Pass through all other tools (dicts, SDK types) unchanged tool_list.append(tool) diff --git a/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py b/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py index ff245817b7..3e6f48a572 100644 --- a/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py +++ b/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py @@ -11,7 +11,7 @@ import logging import sys from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypedDict +from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypedDict, cast from agent_framework import AGENT_FRAMEWORK_USER_AGENT, Annotation, Content, Message, SupportsGetEmbeddings from agent_framework._sessions import AgentSession, BaseContextProvider, SessionContext @@ -456,9 +456,9 @@ async def _semantic_search(self, query: str) -> list[Message]: elif self.embedding_function: if isinstance(self.embedding_function, SupportsGetEmbeddings): embeddings = await self.embedding_function.get_embeddings([query]) # type: ignore[reportUnknownVariableType] - query_vector: list[float] = embeddings[0].vector # type: ignore[reportUnknownVariableType] + query_vector = self._normalize_query_vector(embeddings[0].vector) # type: ignore[reportUnknownVariableType] else: - query_vector = await self.embedding_function(query) + query_vector = self._normalize_query_vector(await 
self.embedding_function(query)) vector_queries = [VectorizedQuery(vector=query_vector, k=vector_k, fields=self.vector_field_name)] search_params: dict[str, Any] = {"search_text": query, "top": self.top_k} @@ -603,6 +603,20 @@ async def _agentic_search(self, messages: list[Message]) -> list[Message]: return self._parse_messages_from_kb_response(retrieval_result) + @staticmethod + def _normalize_query_vector(vector: object) -> list[float]: + """Normalize query vector values to floats for Azure Search vector query.""" + if not isinstance(vector, list): + raise TypeError("embedding_function must return list[float]") + + vector_values = cast(list[object], vector) + normalized: list[float] = [] + for value in vector_values: + if not isinstance(value, int | float): + raise TypeError("embedding_function must return list[float]") + normalized.append(float(value)) + return normalized + @staticmethod def _prepare_messages_for_kb_search(messages: list[Message]) -> list[KnowledgeBaseMessage]: """Convert framework Messages to KnowledgeBaseMessages for agentic retrieval. 
@@ -632,6 +646,8 @@ def _prepare_messages_for_kb_search(messages: list[Message]) -> list[KnowledgeBa image=KnowledgeBaseMessageImageContentImage(url=content.uri), ) ) + case _: + pass elif msg.text: kb_content.append(KnowledgeBaseMessageTextContent(text=msg.text)) if kb_content: diff --git a/python/packages/azure-ai-search/pyproject.toml b/python/packages/azure-ai-search/pyproject.toml index a4bdc5e978..6af0688f3f 100644 --- a/python/packages/azure-ai-search/pyproject.toml +++ b/python/packages/azure-ai-search/pyproject.toml @@ -62,6 +62,7 @@ omit = [ [tool.pyright] extends = "../../pyproject.toml" +include = ["agent_framework_azure_ai_search"] exclude = ['tests'] [tool.mypy] diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py index 2c0498b1e4..ffad93eaa5 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py @@ -9,7 +9,7 @@ import re import sys from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, MutableMapping, Sequence -from typing import Any, ClassVar, Generic, TypedDict +from typing import Any, ClassVar, Generic, TypedDict, cast from agent_framework import ( AGENT_FRAMEWORK_USER_AGENT, @@ -997,39 +997,40 @@ async def _process_stream( role="assistant", ) case RunStepDeltaChunk(): # type: ignore - if ( - event_data.delta.step_details is not None - and event_data.delta.step_details.type == "tool_calls" - and event_data.delta.step_details.tool_calls is not None # type: ignore[attr-defined] - ): - for tool_call in event_data.delta.step_details.tool_calls: # type: ignore[attr-defined] - if tool_call.type == "code_interpreter" and isinstance( - tool_call.code_interpreter, - RunStepDeltaCodeInterpreterDetailItemObject, - ): - code_contents: list[Content] = [] - if tool_call.code_interpreter.input is not None: - logger.debug(f"Code Interpreter Input: 
{tool_call.code_interpreter.input}") - if tool_call.code_interpreter.outputs is not None: - for output in tool_call.code_interpreter.outputs: - if isinstance(output, RunStepDeltaCodeInterpreterLogOutput) and output.logs: - code_contents.append(Content.from_text(text=output.logs)) - if ( - isinstance(output, RunStepDeltaCodeInterpreterImageOutput) - and output.image is not None - and output.image.file_id is not None - ): - code_contents.append( - Content.from_hosted_file(file_id=output.image.file_id) - ) - yield ChatResponseUpdate( - role="assistant", - contents=code_contents, - conversation_id=thread_id, - message_id=response_id, - raw_representation=tool_call.code_interpreter, - response_id=response_id, - ) + step_details: Any = event_data.delta.step_details + if step_details is not None and getattr(step_details, "type", None) == "tool_calls": + tool_calls = getattr(step_details, "tool_calls", None) + if isinstance(tool_calls, list): + for tool_call in cast(list[object], tool_calls): + tool_type = getattr(tool_call, "type", None) + code_interpreter = getattr(tool_call, "code_interpreter", None) + if tool_type == "code_interpreter" and isinstance( + code_interpreter, + RunStepDeltaCodeInterpreterDetailItemObject, + ): + code_contents: list[Content] = [] + if code_interpreter.input is not None: + logger.debug(f"Code Interpreter Input: {code_interpreter.input}") + if code_interpreter.outputs is not None: + for output in code_interpreter.outputs: + if isinstance(output, RunStepDeltaCodeInterpreterLogOutput) and output.logs: + code_contents.append(Content.from_text(text=output.logs)) + if ( + isinstance(output, RunStepDeltaCodeInterpreterImageOutput) + and output.image is not None + and output.image.file_id is not None + ): + code_contents.append( + Content.from_hosted_file(file_id=output.image.file_id) + ) + yield ChatResponseUpdate( + role="assistant", + contents=code_contents, + conversation_id=thread_id, + message_id=response_id, + 
raw_representation=code_interpreter, + response_id=response_id, + ) case _: # ThreadMessage or string # possible event_types for ThreadMessage: # AgentStreamEvent.THREAD_MESSAGE_CREATED @@ -1056,17 +1057,15 @@ def _capture_azure_search_tool_calls( ) -> None: """Capture Azure AI Search tool call data from completed steps.""" try: - if ( - hasattr(step_data, "step_details") - and hasattr(step_data.step_details, "tool_calls") - and step_data.step_details.tool_calls - ): - for tool_call in step_data.step_details.tool_calls: - if hasattr(tool_call, "type") and tool_call.type == "azure_ai_search": + step_details: Any = getattr(step_data, "step_details", None) + tool_calls = getattr(step_details, "tool_calls", None) if step_details is not None else None + if isinstance(tool_calls, list): + for tool_call in cast(list[object], tool_calls): + if getattr(tool_call, "type", None) == "azure_ai_search": # Store the complete tool call as a dictionary tool_call_dict = { "id": getattr(tool_call, "id", None), - "type": tool_call.type, + "type": getattr(tool_call, "type", None), "azure_ai_search": getattr(tool_call, "azure_ai_search", None), } azure_search_tool_calls.append(tool_call_dict) @@ -1226,13 +1225,15 @@ def _prepare_tool_choice_mode( return AgentsToolChoiceOptionMode.NONE if tool_choice == "auto": return AgentsToolChoiceOptionMode.AUTO - if isinstance(tool_choice, Mapping) and tool_choice.get("mode") == "required": - req_fn = tool_choice.get("required_function_name") - if req_fn: - return AgentsNamedToolChoice( - type=AgentsNamedToolChoiceType.FUNCTION, - function=FunctionName(name=str(req_fn)), - ) + if isinstance(tool_choice, Mapping): + tool_choice_mapping = cast(Mapping[str, Any], tool_choice) + if tool_choice_mapping.get("mode") == "required": + req_fn = tool_choice_mapping.get("required_function_name") + if req_fn: + return AgentsNamedToolChoice( + type=AgentsNamedToolChoiceType.FUNCTION, + function=FunctionName(name=str(req_fn)), + ) return None async def 
_prepare_tool_definitions_and_resources( @@ -1369,15 +1370,12 @@ async def _prepare_tools_for_azure_ai( # SDK Tool wrappers (McpTool, FileSearchTool, BingGroundingTool, etc.) tool_definitions.extend(tool.definitions) # Handle tool resources (MCP resources handled separately by _prepare_mcp_resources) - if ( - run_options is not None - and hasattr(tool, "resources") - and tool.resources - and "mcp" not in tool.resources - ): + resources = getattr(tool, "resources", None) + if run_options is not None and isinstance(resources, Mapping) and resources and "mcp" not in resources: if "tool_resources" not in run_options: run_options["tool_resources"] = {} - run_options["tool_resources"].update(tool.resources) + tool_resources = cast(MutableMapping[str, Any], run_options["tool_resources"]) + tool_resources.update(dict(cast(Mapping[str, Any], resources))) else: # Pass through ToolDefinition, dict, and other types unchanged tool_definitions.append(tool) diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_client.py index 61c4a09e94..4b31c99b61 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_client.py @@ -2,6 +2,7 @@ from __future__ import annotations +import importlib import json import logging import re @@ -26,6 +27,8 @@ Message, MiddlewareTypes, ResponseStream, + Role, + RoleLiteral, TextSpanRegion, ) from agent_framework._settings import load_settings @@ -70,6 +73,15 @@ logger = logging.getLogger("agent_framework.azure") +AzureMonitorConfigurator = Callable[..., Any] + + +def _normalize_chat_role(role: Role | str | None) -> RoleLiteral | Role | None: + if role in {"system", "user", "assistant", "tool"}: + return cast(RoleLiteral, role) + return None + + class AzureAIProjectAgentOptions(OpenAIResponsesOptions, total=False): """Azure AI Project Agent options.""" @@ -304,13 +316,18 @@ async def configure_azure_monitor( # 
Import Azure Monitor with proper error handling try: - from azure.monitor.opentelemetry import configure_azure_monitor + monitor_module = importlib.import_module("azure.monitor.opentelemetry") except ImportError as exc: raise ImportError( "azure-monitor-opentelemetry is required for Azure Monitor integration. " "Install it with: pip install azure-monitor-opentelemetry" ) from exc + configure_azure_monitor_attr = getattr(monitor_module, "configure_azure_monitor", None) + if not callable(configure_azure_monitor_attr): + raise ImportError("azure-monitor-opentelemetry does not expose configure_azure_monitor") + configure_azure_monitor: AzureMonitorConfigurator = configure_azure_monitor_attr + from agent_framework.observability import create_metric_views, create_resource, enable_instrumentation # Create resource if not provided in kwargs @@ -433,31 +450,41 @@ def _extract_tool_names(self, tools: Any) -> set[str]: """Extract comparable tool names from runtime tool payloads.""" if not isinstance(tools, Sequence) or isinstance(tools, str | bytes): return set() - return {self._get_tool_name(tool) for tool in tools} + tool_names: set[str] = set() + for tool_item in cast(Sequence[object], tools): + tool_names.add(self._get_tool_name(tool_item)) + return tool_names def _get_tool_name(self, tool: Any) -> str: """Get a stable name for a tool for runtime comparison.""" if isinstance(tool, FunctionTool): return tool.name + if isinstance(tool, Mapping): - tool_type = tool.get("type") + tool_mapping = cast(Mapping[str, Any], tool) + tool_type = tool_mapping.get("type") if tool_type == "function": - if isinstance(function_data := tool.get("function"), Mapping) and function_data.get("name"): - return str(function_data["name"]) - if tool.get("name"): - return str(tool["name"]) - if tool.get("name"): - return str(tool["name"]) - if tool.get("server_label"): - return f"mcp:{tool['server_label']}" + function_data = tool_mapping.get("function") + if isinstance(function_data, Mapping): + 
function_mapping = cast(Mapping[str, Any], function_data) + if function_name := function_mapping.get("name"): + return str(function_name) + if tool_name := tool_mapping.get("name"): + return str(tool_name) + if tool_name := tool_mapping.get("name"): + return str(tool_name) + if server_label := tool_mapping.get("server_label"): + return f"mcp:{server_label}" if tool_type: return str(tool_type) - if getattr(tool, "name", None): - return str(tool.name) - if getattr(tool, "server_label", None): - return f"mcp:{tool.server_label}" - if getattr(tool, "type", None): - return str(tool.type) + return type(cast(Any, tool)).__name__ + + if name_value := getattr(tool, "name", None): + return str(name_value) + if server_label_value := getattr(tool, "server_label", None): + return f"mcp:{server_label_value}" + if tool_type_value := getattr(tool, "type", None): + return str(tool_type_value) return type(tool).__name__ def _get_structured_output_signature(self, chat_options: Mapping[str, Any] | None) -> str | None: @@ -545,14 +572,14 @@ async def _prepare_options( return run_options @override - def _check_model_presence(self, run_options: dict[str, Any]) -> None: + def _check_model_presence(self, options: dict[str, Any]) -> None: # Skip model check for application endpoints - model is pre-configured on server if self._is_application_endpoint: return - if not run_options.get("model"): + if not options.get("model"): if not self.model_id: raise ValueError("model_deployment_name must be a non-empty string") - run_options["model"] = self.model_id + options["model"] = self.model_id def _transform_input_for_azure_ai(self, input_items: list[dict[str, Any]]) -> list[dict[str, Any]]: """Transform input items to match Azure AI Projects expected schema. 
@@ -576,10 +603,10 @@ def _transform_input_for_azure_ai(self, input_items: list[dict[str, Any]]) -> li # Add 'annotations' only to output_text content items (assistant messages) # User messages (input_text) do NOT support annotations in Azure AI if "content" in new_item and isinstance(new_item["content"], list): - new_content: list[dict[str, Any] | Any] = [] - for content_item in new_item["content"]: - if isinstance(content_item, dict): - new_content_item: dict[str, Any] = dict(content_item) + new_content: list[Any] = [] + for content_item in cast(list[object], new_item["content"]): + if isinstance(content_item, Mapping): + new_content_item: dict[str, Any] = dict(cast(Mapping[str, Any], content_item)) # Only add annotations to output_text (assistant content) if new_content_item.get("type") == "output_text" and "annotations" not in new_content_item: new_content_item["annotations"] = [] @@ -721,9 +748,18 @@ def _extract_azure_search_urls(self, output_items: Any) -> list[str]: # Streaming "added" events send output as an empty list; skip. 
continue if output is not None: - urls = output.get("get_urls") if isinstance(output, dict) else output.get_urls - if urls and isinstance(urls, list): - get_urls.extend(urls) + urls: Any + if isinstance(output, Mapping): + output_mapping = cast(Mapping[str, Any], output) + urls = output_mapping.get("get_urls") + else: + urls = getattr(output, "get_urls", None) + if isinstance(urls, list): + string_urls: list[str] = [] + for url_item in cast(list[object], urls): + if isinstance(url_item, str): + string_urls.append(url_item) + get_urls.extend(string_urls) return get_urls def _get_search_doc_url(self, citation_title: str | None, get_urls: list[str]) -> str | None: @@ -878,7 +914,7 @@ def _enrich_update(update: ChatResponseUpdate) -> ChatResponseUpdate: contents=contents_list, conversation_id=update.conversation_id, response_id=update.response_id, - role=update.role, + role=_normalize_chat_role(update.role), model_id=update.model_id, continuation_token=update.continuation_token, additional_properties=update.additional_properties, diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_project_provider.py b/python/packages/azure-ai/agent_framework_azure_ai/_project_provider.py index d6b922db91..8274ab473f 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_project_provider.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_project_provider.py @@ -223,8 +223,8 @@ async def create_agent( for tool in normalized_tools: if isinstance(tool, MCPTool): mcp_tools.append(tool) - elif isinstance(tool, (FunctionTool, MutableMapping)): - non_mcp_tools.append(tool) + elif isinstance(tool, FunctionTool) or isinstance(tool, MutableMapping): + non_mcp_tools.append(tool) # type: ignore[reportUnknownArgumentType] # Connect MCP tools and discover their functions BEFORE creating the agent # This is required because Azure AI Responses API doesn't accept tools at request time diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_shared.py 
b/python/packages/azure-ai/agent_framework_azure_ai/_shared.py index 6f7d39c3be..4630280bb5 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_shared.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_shared.py @@ -102,8 +102,9 @@ def _extract_project_connection_id(additional_properties: dict[str, Any] | None) # Check for connection.name structure (declarative/YAML usage) if "connection" in additional_properties: conn = additional_properties["connection"] - if isinstance(conn, dict): - name = conn.get("name") + if isinstance(conn, Mapping): + conn_mapping = cast(Mapping[str, Any], conn) + name = conn_mapping.get("name") if isinstance(name, str): return name @@ -191,7 +192,9 @@ def to_azure_ai_agent_tools( ): if "tool_resources" not in run_options: run_options["tool_resources"] = {} - run_options["tool_resources"].update(tool.resources) + tool_resources = cast(MutableMapping[str, Any], run_options["tool_resources"]) + if isinstance(tool.resources, Mapping): + tool_resources.update(dict(cast(Mapping[str, Any], tool.resources))) elif isinstance(tool, (dict, MutableMapping)): # Handle dict-based tools - pass through directly tool_dict = tool if isinstance(tool, dict) else dict(tool) @@ -422,9 +425,16 @@ def to_azure_ai_tools( elif isinstance(tool, Tool): # Pass through SDK Tool types directly (CodeInterpreterTool, FileSearchTool, etc.) azure_tools.append(tool) + elif isinstance(tool, MutableMapping): + # Convert mutable mappings into plain dicts for stable typing. + tool_dict: dict[str, Any] = dict(tool) + if tool_dict.get("type") == "mcp": + azure_tools.append(_prepare_mcp_tool_dict_for_azure_ai(tool_dict)) + else: + azure_tools.append(tool_dict) else: - # Pass through dict-based tools directly - azure_tools.append(dict(tool) if isinstance(tool, MutableMapping) else tool) # type: ignore[arg-type] + # Pass through any other supported tool objects unchanged. 
+ azure_tools.append(tool) return azure_tools @@ -446,7 +456,13 @@ def _prepare_mcp_tool_dict_for_azure_ai(tool_dict: dict[str, Any]) -> MCPTool: mcp["server_description"] = description # Check for project_connection_id - if project_connection_id := tool_dict.get("project_connection_id"): + additional_properties = tool_dict.get("additional_properties") + extracted_project_connection_id = ( + _extract_project_connection_id(dict(cast(Mapping[str, Any], additional_properties))) + if isinstance(additional_properties, Mapping) + else None + ) + if project_connection_id := tool_dict.get("project_connection_id") or extracted_project_connection_id: mcp["project_connection_id"] = project_connection_id elif headers := tool_dict.get("headers"): mcp["headers"] = headers diff --git a/python/packages/azure-ai/pyproject.toml b/python/packages/azure-ai/pyproject.toml index bdc898af8c..9dca3ea0f0 100644 --- a/python/packages/azure-ai/pyproject.toml +++ b/python/packages/azure-ai/pyproject.toml @@ -61,6 +61,7 @@ omit = [ [tool.pyright] extends = "../../pyproject.toml" +include = ["agent_framework_azure_ai"] [tool.mypy] plugins = ['pydantic.mypy'] diff --git a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py index 5b802bde9f..3b27823332 100644 --- a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py +++ b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py @@ -8,7 +8,7 @@ import time import uuid from collections.abc import Sequence -from typing import Any, ClassVar, TypedDict +from typing import Any, ClassVar, TypeGuard, TypedDict, cast from agent_framework import AGENT_FRAMEWORK_USER_AGENT, Message from agent_framework._sessions import BaseHistoryProvider @@ -20,6 +20,18 @@ logger = logging.getLogger(__name__) +def _is_str_key_dict(value: object) -> TypeGuard[dict[str, Any]]: + if not isinstance(value, dict): + return 
False + + candidate_dict = cast(dict[object, Any], value) + for key_obj in candidate_dict: + if not isinstance(key_obj, str): + return False + + return True + + class AzureCosmosHistorySettings(TypedDict, total=False): """Settings for CosmosHistoryProvider resolved from args and environment.""" @@ -146,7 +158,7 @@ async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Mess messages: list[Message] = [] async for item in items: message_payload = item.get("message") - if isinstance(message_payload, dict): + if _is_str_key_dict(message_payload): messages.append(Message.from_dict(message_payload)) return messages diff --git a/python/packages/azure-cosmos/pyproject.toml b/python/packages/azure-cosmos/pyproject.toml index d053465fb1..c05d3cd939 100644 --- a/python/packages/azure-cosmos/pyproject.toml +++ b/python/packages/azure-cosmos/pyproject.toml @@ -61,6 +61,7 @@ omit = [ [tool.pyright] extends = "../../pyproject.toml" +include = ["agent_framework_azure_cosmos"] [tool.mypy] plugins = ['pydantic.mypy'] diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index c7d8552b24..3cb8d688af 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -274,10 +274,27 @@ def executor_activity(inputData: str) -> str: """ from agent_framework._workflows._state import State - data = json.loads(inputData) - message_data = data["message"] - shared_state_snapshot = data.get("shared_state_snapshot", {}) - source_executor_ids = data.get("source_executor_ids", [SOURCE_ORCHESTRATOR]) + data_obj = json.loads(inputData) + if not isinstance(data_obj, dict): + raise ValueError("Activity inputData must decode to a JSON object") + data = cast(dict[str, Any], data_obj) + + message_data = data.get("message") + shared_state_raw = data.get("shared_state_snapshot", {}) + 
source_executor_ids_raw = data.get("source_executor_ids", [SOURCE_ORCHESTRATOR]) + + shared_state_snapshot: dict[str, Any] + if isinstance(shared_state_raw, dict): + shared_state_snapshot = cast(dict[str, Any], shared_state_raw) + else: + shared_state_snapshot = {} + + source_executor_ids: list[str] + if isinstance(source_executor_ids_raw, list): + source_executor_ids_values = cast(list[object], source_executor_ids_raw) + source_executor_ids = [str(source_executor_id) for source_executor_id in source_executor_ids_values] + else: + source_executor_ids = [SOURCE_ORCHESTRATOR] if not self.workflow: raise RuntimeError("Workflow not initialized in AgentFunctionApp") @@ -299,15 +316,20 @@ async def run() -> dict[str, Any]: shared_state = State() # Deserialize shared state values to reconstruct dataclasses/Pydantic models - deserialized_state = {k: deserialize_value(v) for k, v in (shared_state_snapshot or {}).items()} - original_snapshot = dict(deserialized_state) + deserialized_state: dict[str, Any] = { + str(k): deserialize_value(v) for k, v in shared_state_snapshot.items() + } + original_snapshot: dict[str, Any] = dict(deserialized_state) shared_state.import_state(deserialized_state) if is_hitl_response: # Handle HITL response by calling the executor's @response_handler + if not isinstance(message_data, dict): + raise ValueError("HITL message payload must be a JSON object") + await execute_hitl_response_handler( executor=executor, - hitl_message=message_data, + hitl_message=cast(dict[str, Any], message_data), shared_state=shared_state, runner_context=runner_context, ) @@ -323,11 +345,11 @@ async def run() -> dict[str, Any]: # Commit pending state changes and export shared_state.commit() current_state = shared_state.export_state() - original_keys = set(original_snapshot.keys()) - current_keys = set(current_state.keys()) + original_keys: set[str] = set(original_snapshot.keys()) + current_keys: set[str] = set(current_state.keys()) # Deleted = was in original, not in 
current - deletes = original_keys - current_keys + deletes: set[str] = original_keys - current_keys # Updates = keys in current that are new or have different values updates = { @@ -348,7 +370,7 @@ async def run() -> dict[str, Any]: pending_request_info_events = await runner_context.get_pending_request_info_events() # Serialize pending request info events for orchestrator - serialized_pending_requests = [] + serialized_pending_requests: list[dict[str, Any]] = [] for _request_id, event in pending_request_info_events.items(): serialized_pending_requests.append({ "request_id": event.request_id, @@ -361,7 +383,7 @@ async def run() -> dict[str, Any]: }) # Serialize messages for JSON compatibility - serialized_sent_messages = [] + serialized_sent_messages: list[dict[str, Any]] = [] for _source_id, msg_list in sent_messages.items(): for msg in msg_list: serialized_sent_messages.append({ @@ -441,6 +463,9 @@ async def get_workflow_status( ) -> func.HttpResponse: """HTTP endpoint to get workflow status.""" instance_id = req.route_params.get("instanceId") + if not instance_id: + return self._build_error_response("Instance ID is required", status_code=400) + status = await client.get_status(instance_id) if not status: @@ -457,20 +482,29 @@ async def get_workflow_status( } # Add pending HITL requests info if available - custom_status = status.custom_status or {} - if isinstance(custom_status, dict) and custom_status.get("pending_requests"): - base_url = self._build_base_url(req.url) - pending_requests = [] - for req_id, req_data in custom_status["pending_requests"].items(): - pending_requests.append({ - "requestId": req_id, - "sourceExecutor": req_data.get("source_executor_id"), - "requestData": req_data.get("data"), - "requestType": req_data.get("request_type"), - "responseType": req_data.get("response_type"), - "respondUrl": f"{base_url}/api/workflow/respond/{instance_id}/{req_id}", - }) - response["pendingHumanInputRequests"] = pending_requests + custom_status = 
status.custom_status + if isinstance(custom_status, dict): + custom_status_typed = cast(dict[str, Any], custom_status) + pending_requests_raw = custom_status_typed.get("pending_requests") + if isinstance(pending_requests_raw, dict): + base_url = self._build_base_url(req.url) + pending_requests: list[dict[str, Any]] = [] + pending_requests_dict = cast(dict[str, Any], pending_requests_raw) + for req_id_raw, req_data_raw in pending_requests_dict.items(): + if not isinstance(req_data_raw, dict): + continue + + req_id = str(req_id_raw) + req_data = cast(dict[str, Any], req_data_raw) + pending_requests.append({ + "requestId": req_id, + "sourceExecutor": req_data.get("source_executor_id"), + "requestData": req_data.get("data"), + "requestType": req_data.get("request_type"), + "responseType": req_data.get("response_type"), + "respondUrl": f"{base_url}/api/workflow/respond/{instance_id}/{req_id}", + }) + response["pendingHumanInputRequests"] = pending_requests return func.HttpResponse( json.dumps(response, default=str), @@ -515,6 +549,11 @@ async def send_hitl_response(req: func.HttpRequest, client: df.DurableOrchestrat mimetype="application/json", ) + # Ensure route handlers are registered (prevents unused function warnings) + _ = start_workflow_orchestration + _ = get_workflow_status + _ = send_hitl_response + def _build_status_url(self, request_url: str, instance_id: str) -> str: """Build the status URL for a workflow instance.""" base_url = self._build_base_url(request_url) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py index a45dcf81fc..561e05bee4 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py @@ -44,10 +44,10 @@ def __init__(self) -> None: # region Messaging - async def send_message(self, message: WorkflowMessage) -> None: + async def 
send_message(self, WorkflowMessage: WorkflowMessage) -> None: """Capture a message sent by an executor.""" - self._messages.setdefault(message.source_id, []) - self._messages[message.source_id].append(message) + self._messages.setdefault(WorkflowMessage.source_id, []) + self._messages[WorkflowMessage.source_id].append(WorkflowMessage) async def drain_messages(self) -> dict[str, list[WorkflowMessage]]: """Drain and return all captured messages.""" diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py index 94263fa4ef..ad62ba7a06 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py @@ -13,7 +13,7 @@ - serialize_value / deserialize_value: convenience aliases for encode/decode - reconstruct_to_type: for HITL responses where external data (without type markers) needs to be reconstructed to a known type -- _resolve_type: resolves 'module:class' type keys to Python types +- resolve_type: resolves 'module:class' type keys to Python types """ from __future__ import annotations @@ -21,14 +21,14 @@ import importlib import logging from dataclasses import is_dataclass -from typing import Any +from typing import Any, Callable, cast from agent_framework._workflows._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value logger = logging.getLogger(__name__) -def _resolve_type(type_key: str) -> type | None: +def resolve_type(type_key: str) -> type | None: """Resolve a 'module:class' type key to its Python type. 
Args: @@ -123,9 +123,11 @@ def reconstruct_to_type(value: Any, target_type: type) -> Any: return decoded # Try Pydantic model validation (for unmarked dicts, e.g., external HITL data) - if hasattr(target_type, "model_validate"): + model_validate = getattr(target_type, "model_validate", None) + if callable(model_validate): try: - return target_type.model_validate(value) + model_validate_fn = cast(Callable[[Any], Any], model_validate) + return model_validate_fn(value) except Exception: logger.debug("Could not validate Pydantic model %s", target_type) @@ -136,4 +138,4 @@ def reconstruct_to_type(value: Any, target_type: type) -> Any: except Exception: logger.debug("Could not construct dataclass %s", target_type) - return value + return cast(Any, value) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py index a0e0f04185..0c24752905 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py @@ -26,7 +26,7 @@ from dataclasses import dataclass from datetime import timedelta from enum import Enum -from typing import Any +from typing import Any, cast from agent_framework import ( AgentExecutor, @@ -44,12 +44,13 @@ SingleEdgeGroup, SwitchCaseEdgeGroup, ) +from agent_framework._workflows._state import State from agent_framework_durabletask import AgentSessionId, DurableAgentSession, DurableAIAgent from azure.durable_functions import DurableOrchestrationContext from ._context import CapturingRunnerContext from ._orchestration import AzureFunctionsAgentExecutor -from ._serialization import _resolve_type, deserialize_value, reconstruct_to_type, serialize_value +from ._serialization import deserialize_value, reconstruct_to_type, resolve_type, serialize_value logger = logging.getLogger(__name__) @@ -148,7 +149,7 @@ def _evaluate_edge_condition_sync(edge: Edge, 
message: Any) -> bool: True if the edge should be traversed, False otherwise """ # Access the internal condition directly since should_route is async - condition = edge._condition + condition = edge._condition # pyright: ignore[reportPrivateUsage] if condition is None: return True result = condition(message) @@ -322,7 +323,8 @@ def _prepare_activity_task( activity_input_json = json.dumps(activity_input) # Use the prefixed activity name that matches the registered function activity_name = f"dafx-{executor_id}" - return context.call_activity(activity_name, activity_input_json) + orchestration_context: Any = context + return orchestration_context.call_activity(activity_name, activity_input_json) # ============================================================================ @@ -349,8 +351,11 @@ def _process_agent_response( structured_response = None if agent_response and agent_response.value is not None: - if hasattr(agent_response.value, "model_dump"): - structured_response = agent_response.value.model_dump() + model_dump = getattr(agent_response.value, "model_dump", None) + if callable(model_dump): + dumped = model_dump() + if isinstance(dumped, dict): + structured_response = cast(dict[str, Any], dumped) elif isinstance(agent_response.value, dict): structured_response = agent_response.value @@ -726,7 +731,7 @@ def run_workflow_orchestrator( if winner == approval_task: # Cancel the timeout - timeout_task.cancel() + timeout_task.cancel() # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue] # Get the response raw_response = approval_task.result @@ -756,7 +761,7 @@ def run_workflow_orchestrator( ) else: # Timeout occurred — cancel the dangling external event listener - approval_task.cancel() + approval_task.cancel() # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue] logger.warning("HITL request %s timed out after %s hours", request_id, hitl_timeout_hours) raise TimeoutError( f"Human-in-the-loop request '{request_id}' timed out after 
{hitl_timeout_hours} hours." @@ -864,7 +869,9 @@ def _extract_message_content(message: Any) -> str: # Extract text from the last message in the request message_content = message.messages[-1].text or "" elif isinstance(message, dict): - logger.warning("Unexpected dict message in _extract_message_content. Keys: %s", list(message.keys())) + message_dict = cast(dict[str, Any], message) + key_names = list(message_dict.keys()) + logger.warning("Unexpected dict message in _extract_message_content. Keys: %s", key_names) elif isinstance(message, str): message_content = message @@ -879,7 +886,7 @@ def _extract_message_content(message: Any) -> str: async def execute_hitl_response_handler( executor: Any, hitl_message: dict[str, Any], - shared_state: Any, + shared_state: State, runner_context: CapturingRunnerContext, ) -> None: """Execute a HITL response handler on an executor. @@ -910,7 +917,7 @@ async def execute_hitl_response_handler( response = _deserialize_hitl_response(response_data, response_type_str) # Find the matching response handler - handler = executor._find_response_handler(original_request, response) + handler = executor._find_response_handler(original_request, response) # pyright: ignore[reportPrivateUsage] if handler is None: logger.warning( @@ -965,7 +972,7 @@ def _deserialize_hitl_response(response_data: Any, response_type_str: str | None # Try to deserialize using the type hint if response_type_str: - response_type = _resolve_type(response_type_str) + response_type = resolve_type(response_type_str) if response_type: logger.debug("Found response type %s, attempting reconstruction", response_type) result = reconstruct_to_type(response_data, response_type) diff --git a/python/packages/azurefunctions/pyproject.toml b/python/packages/azurefunctions/pyproject.toml index 82fe4f32b5..c55bd86785 100644 --- a/python/packages/azurefunctions/pyproject.toml +++ b/python/packages/azurefunctions/pyproject.toml @@ -67,6 +67,7 @@ omit = [ [tool.pyright] extends = 
"../../pyproject.toml" +include = ["agent_framework_azurefunctions"] [tool.mypy] plugins = ['pydantic.mypy'] diff --git a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py index b0d87fe8cc..c85a7ad836 100644 --- a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py +++ b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py @@ -8,7 +8,7 @@ import sys from collections import deque from collections.abc import AsyncIterable, Awaitable, Mapping, MutableMapping, Sequence -from typing import Any, ClassVar, Generic, Literal, TypedDict +from typing import Any, ClassVar, Generic, Literal, Protocol, TypeGuard, TypedDict, cast from uuid import uuid4 from agent_framework import ( @@ -214,6 +214,10 @@ class BedrockSettings(TypedDict, total=False): session_token: SecretString | None +class BedrockRuntimeClient(Protocol): + def converse(self, **kwargs: Any) -> Mapping[str, object]: ... + + class BedrockChatClient( ChatMiddlewareLayer[BedrockChatOptionsT], FunctionInvocationLayer[BedrockChatOptionsT], @@ -224,6 +228,37 @@ class BedrockChatClient( """Async chat client for Amazon Bedrock's Converse API with middleware, telemetry, and function invocation.""" OTEL_PROVIDER_NAME: ClassVar[str] = "aws.bedrock" # type: ignore[reportIncompatibleVariableOverride, misc] + _bedrock_client: BedrockRuntimeClient + + @staticmethod + def _is_runtime_client(value: object) -> TypeGuard[BedrockRuntimeClient]: + converse = getattr(value, "converse", None) + return callable(converse) + + @staticmethod + def _get_str(value: object) -> str | None: + return value if isinstance(value, str) else None + + @staticmethod + def _get_dict(value: object) -> dict[str, Any] | None: + if not isinstance(value, dict): + return None + return cast(dict[str, Any], value) + + @staticmethod + def _is_nonstring_sequence(value: object) -> TypeGuard[Sequence[object]]: + return isinstance(value, Sequence) and not 
isinstance(value, (str, bytes, bytearray)) + + @staticmethod + def _get_content_blocks(value: object) -> list[dict[str, Any]]: + if not BedrockChatClient._is_nonstring_sequence(value): + return [] + blocks: list[dict[str, Any]] = [] + for item in value: + block = BedrockChatClient._get_dict(item) + if block is not None: + blocks.append(block) + return blocks def __init__( self, @@ -288,36 +323,55 @@ class MyOptions(BedrockChatOptions, total=False): env_file_path=env_file_path, env_file_encoding=env_file_encoding, ) - if not settings.get("region"): - settings["region"] = DEFAULT_REGION + region = settings.get("region") or DEFAULT_REGION + chat_model_id = settings.get("chat_model_id") if client is None: session = boto3_session or self._create_session(settings) - client = session.client( + client_factory = getattr(session, "client", None) + if not callable(client_factory): + raise TypeError("Boto3 session does not provide a callable client factory.") + created_client: object = client_factory( "bedrock-runtime", - region_name=settings["region"], + region_name=region, config=BotoConfig(user_agent_extra=AGENT_FRAMEWORK_USER_AGENT), ) + if not self._is_runtime_client(created_client): + raise TypeError("Boto3 session did not create a compatible Bedrock runtime client.") + runtime_client = created_client + elif not self._is_runtime_client(client): + raise TypeError("Provided client must expose a callable 'converse' method.") + else: + runtime_client = client super().__init__( middleware=middleware, function_invocation_configuration=function_invocation_configuration, **kwargs, ) - self._bedrock_client = client - self.model_id = settings["chat_model_id"] - self.region = settings["region"] + self._bedrock_client = runtime_client + self.model_id = chat_model_id + self.region = region @staticmethod def _create_session(settings: BedrockSettings) -> Boto3Session: session_kwargs: dict[str, Any] = {"region_name": settings.get("region") or DEFAULT_REGION} - if 
settings.get("access_key") and settings.get("secret_key"): - session_kwargs["aws_access_key_id"] = settings["access_key"].get_secret_value() # type: ignore[union-attr] - session_kwargs["aws_secret_access_key"] = settings["secret_key"].get_secret_value() # type: ignore[union-attr] - if settings.get("session_token"): - session_kwargs["aws_session_token"] = settings["session_token"].get_secret_value() # type: ignore[union-attr] + access_key = settings.get("access_key") + secret_key = settings.get("secret_key") + session_token = settings.get("session_token") + if access_key is not None and secret_key is not None: + session_kwargs["aws_access_key_id"] = access_key.get_secret_value() + session_kwargs["aws_secret_access_key"] = secret_key.get_secret_value() + if session_token is not None: + session_kwargs["aws_session_token"] = session_token.get_secret_value() return Boto3Session(**session_kwargs) + def _invoke_converse(self, request: Mapping[str, Any]) -> dict[str, Any]: + response = self._bedrock_client.converse(**request) + if not isinstance(response, Mapping): + raise ChatClientInvalidResponseException("Bedrock converse response must be a mapping.") + return dict(response) + @override def _inner_get_response( self, @@ -332,16 +386,18 @@ def _inner_get_response( if stream: # Streaming mode - simulate streaming by yielding a single update async def _stream() -> AsyncIterable[ChatResponseUpdate]: - response = await asyncio.to_thread(self._bedrock_client.converse, **request) + response = await asyncio.to_thread(self._invoke_converse, request) parsed_response = self._process_converse_response(response) contents = list(parsed_response.messages[0].contents if parsed_response.messages else []) if parsed_response.usage_details: contents.append(Content.from_usage(usage_details=parsed_response.usage_details)) # type: ignore[arg-type] + raw_finish_reason = self._get_str(parsed_response.finish_reason) + finish_reason = self._map_finish_reason(raw_finish_reason) yield 
ChatResponseUpdate( response_id=parsed_response.response_id, contents=contents, model_id=parsed_response.model_id, - finish_reason=parsed_response.finish_reason, + finish_reason=finish_reason, raw_representation=parsed_response.raw_representation, ) @@ -349,7 +405,7 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: # Non-streaming mode async def _get_response() -> ChatResponse: - raw_response = await asyncio.to_thread(self._bedrock_client.converse, **request) + raw_response = await asyncio.to_thread(self._invoke_converse, request) return self._process_converse_response(raw_response) return _get_response() @@ -529,25 +585,25 @@ def _convert_content_to_bedrock_block(self, content: Content) -> dict[str, Any] def _convert_tool_result_to_blocks(self, result: Any) -> list[dict[str, Any]]: prepared_result = result if isinstance(result, str) else FunctionTool.parse_result(result) try: - parsed_result = json.loads(prepared_result) + parsed_result: object = json.loads(prepared_result) except json.JSONDecodeError: return [{"text": prepared_result}] return self._convert_prepared_tool_result_to_blocks(parsed_result) - def _convert_prepared_tool_result_to_blocks(self, value: Any) -> list[dict[str, Any]]: - if isinstance(value, list): + def _convert_prepared_tool_result_to_blocks(self, value: object) -> list[dict[str, Any]]: + if self._is_nonstring_sequence(value): blocks: list[dict[str, Any]] = [] for item in value: blocks.extend(self._convert_prepared_tool_result_to_blocks(item)) return blocks or [{"text": ""}] return [self._normalize_tool_result_value(value)] - def _normalize_tool_result_value(self, value: Any) -> dict[str, Any]: + def _normalize_tool_result_value(self, value: object) -> dict[str, Any]: if isinstance(value, dict): return {"json": value} - if isinstance(value, (list, tuple)): - return {"json": list(value)} + if self._is_nonstring_sequence(value): + return {"json": [item for item in value]} if isinstance(value, str): return {"text": value} if 
isinstance(value, (int, float, bool)) or value is None: @@ -586,15 +642,18 @@ def _generate_tool_call_id() -> str: return f"tool-call-{uuid4().hex}" def _process_converse_response(self, response: dict[str, Any]) -> ChatResponse: - output = response.get("output", {}) - message = output.get("message", {}) - content_blocks = message.get("content", []) or [] + output = self._get_dict(response.get("output")) or {} + message = self._get_dict(output.get("message")) or {} + content_blocks = self._get_content_blocks(message.get("content")) contents = self._parse_message_contents(content_blocks) chat_message = Message(role="assistant", contents=contents, raw_representation=message) - usage_details = self._parse_usage(response.get("usage") or output.get("usage")) - finish_reason = self._map_finish_reason(output.get("completionReason") or response.get("stopReason")) - response_id = response.get("responseId") or message.get("id") - model_id = response.get("modelId") or output.get("modelId") or self.model_id + usage_source = self._get_dict(response.get("usage")) or self._get_dict(output.get("usage")) + usage_details = self._parse_usage(usage_source) + finish_reason = self._map_finish_reason( + self._get_str(output.get("completionReason")) or self._get_str(response.get("stopReason")) + ) + response_id = self._get_str(response.get("responseId")) or self._get_str(message.get("id")) + model_id = self._get_str(response.get("modelId")) or self._get_str(output.get("modelId")) or self.model_id return ChatResponse( response_id=response_id, messages=[chat_message], @@ -616,7 +675,7 @@ def _parse_usage(self, usage: dict[str, Any] | None) -> UsageDetails | None: details["total_token_count"] = total_tokens return details - def _parse_message_contents(self, content_blocks: Sequence[MutableMapping[str, Any]]) -> list[Any]: + def _parse_message_contents(self, content_blocks: Sequence[dict[str, Any]]) -> list[Any]: contents: list[Any] = [] for block in content_blocks: if text_value := 
block.get("text"): @@ -625,32 +684,32 @@ def _parse_message_contents(self, content_blocks: Sequence[MutableMapping[str, A if (json_value := block.get("json")) is not None: contents.append(Content.from_text(text=json.dumps(json_value), raw_representation=block)) continue - tool_use = block.get("toolUse") - if isinstance(tool_use, MutableMapping): - tool_name = tool_use.get("name") + tool_use = self._get_dict(block.get("toolUse")) + if tool_use is not None: + tool_name = self._get_str(tool_use.get("name")) if not tool_name: raise ChatClientInvalidResponseException( "Bedrock response missing required tool name in toolUse block." ) contents.append( Content.from_function_call( - call_id=tool_use.get("toolUseId") or self._generate_tool_call_id(), + call_id=self._get_str(tool_use.get("toolUseId")) or self._generate_tool_call_id(), name=tool_name, arguments=tool_use.get("input"), raw_representation=block, ) ) continue - tool_result = block.get("toolResult") - if isinstance(tool_result, MutableMapping): - status = (tool_result.get("status") or "success").lower() + tool_result = self._get_dict(block.get("toolResult")) + if tool_result is not None: + status = (self._get_str(tool_result.get("status")) or "success").lower() exception = None if status not in {"success", "ok"}: exception = RuntimeError(f"Bedrock tool result status: {status}") result_value = self._convert_bedrock_tool_result_to_value(tool_result.get("content")) contents.append( Content.from_function_result( - call_id=tool_result.get("toolUseId") or self._generate_tool_call_id(), + call_id=self._get_str(tool_result.get("toolUseId")) or self._generate_tool_call_id(), result=result_value, exception=str(exception) if exception else None, # type: ignore[arg-type] raw_representation=block, @@ -673,24 +732,26 @@ def service_url(self) -> str: """ return f"https://bedrock-runtime.{self.region}.amazonaws.com" - def _convert_bedrock_tool_result_to_value(self, content: Any) -> Any: + def 
_convert_bedrock_tool_result_to_value(self, content: object) -> object: if not content: return None - if isinstance(content, Sequence) and not isinstance(content, (str, bytes, bytearray)): - values: list[Any] = [] + if self._is_nonstring_sequence(content): + values: list[object] = [] for item in content: - if isinstance(item, MutableMapping): - if (text_value := item.get("text")) is not None: + item_dict = self._get_dict(item) + if item_dict is not None: + if (text_value := self._get_str(item_dict.get("text"))) is not None: values.append(text_value) continue - if "json" in item: - values.append(item["json"]) + if "json" in item_dict: + values.append(item_dict["json"]) continue values.append(item) return values[0] if len(values) == 1 else values - if isinstance(content, MutableMapping): - if (text_value := content.get("text")) is not None: + content_dict = self._get_dict(content) + if content_dict is not None: + if (text_value := self._get_str(content_dict.get("text"))) is not None: return text_value - if "json" in content: - return content["json"] + if "json" in content_dict: + return content_dict["json"] return content diff --git a/python/packages/bedrock/agent_framework_bedrock/_embedding_client.py b/python/packages/bedrock/agent_framework_bedrock/_embedding_client.py index 30be74eed9..5aac1dcc74 100644 --- a/python/packages/bedrock/agent_framework_bedrock/_embedding_client.py +++ b/python/packages/bedrock/agent_framework_bedrock/_embedding_client.py @@ -7,7 +7,7 @@ import logging import sys from collections.abc import Sequence -from typing import Any, ClassVar, Generic, TypedDict +from typing import Any, ClassVar, Generic, Protocol, TypedDict, cast from agent_framework import ( AGENT_FRAMEWORK_USER_AGENT, @@ -30,6 +30,33 @@ from typing_extensions import TypeVar # type: ignore # pragma: no cover + + +class BedrockRuntimeMeta(Protocol): + endpoint_url: str + + +class BedrockResponseBody(Protocol): + def read(self) -> bytes | bytearray | str: ... 
+ + +class BedrockInvokeModelResponse(TypedDict): + body: BedrockResponseBody + + +class BedrockRuntimeClient(Protocol): + meta: BedrockRuntimeMeta + + def invoke_model( + self, + *, + modelId: str, + contentType: str, + accept: str, + body: str, + ) -> BedrockInvokeModelResponse: ... + + logger = logging.getLogger("agent_framework.bedrock") DEFAULT_REGION = "us-east-1" @@ -101,7 +128,7 @@ def __init__( access_key: str | None = None, secret_key: str | None = None, session_token: str | None = None, - client: BaseClient | None = None, + client: BaseClient | BedrockRuntimeClient | None = None, boto3_session: Boto3Session | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, @@ -128,19 +155,24 @@ def __init__( if region := settings.get("region"): session_kwargs["region_name"] = region if (access_key := settings.get("access_key")) and (secret_key := settings.get("secret_key")): - session_kwargs["aws_access_key_id"] = access_key.get_secret_value() # type: ignore[union-attr] - session_kwargs["aws_secret_access_key"] = secret_key.get_secret_value() # type: ignore[union-attr] + session_kwargs["aws_access_key_id"] = access_key.get_secret_value() + session_kwargs["aws_secret_access_key"] = secret_key.get_secret_value() if session_token := settings.get("session_token"): - session_kwargs["aws_session_token"] = session_token.get_secret_value() # type: ignore[union-attr] + session_kwargs["aws_session_token"] = session_token.get_secret_value() boto3_session = Boto3Session(**session_kwargs) - client = boto3_session.client( - "bedrock-runtime", - region_name=boto3_session.region_name or resolved_region, - config=BotoConfig(user_agent_extra=AGENT_FRAMEWORK_USER_AGENT), + region_name = cast(str | None, getattr(boto3_session, "region_name", None)) + client_factory = cast(Any, boto3_session.client) # pyright: ignore[reportUnknownMemberType] + client = cast( + BedrockRuntimeClient, + client_factory( + "bedrock-runtime", + region_name=region_name if 
isinstance(region_name, str) else resolved_region, + config=BotoConfig(user_agent_extra=AGENT_FRAMEWORK_USER_AGENT), + ), ) - self._bedrock_client = client - self.model_id = settings["embedding_model_id"] # type: ignore[assignment] + self._bedrock_client: BedrockRuntimeClient = cast(BedrockRuntimeClient, client) + self.model_id: str = settings["embedding_model_id"] # type: ignore[assignment] # pyright: ignore[reportTypedDictNotRequiredAccess] self.region = resolved_region super().__init__(**kwargs) @@ -169,8 +201,9 @@ async def get_embeddings( Raises: ValueError: If model_id is not provided or values is empty. """ + resolved_options = cast(EmbeddingGenerationOptions | None, options) if not values: - return GeneratedEmbeddings([], options=options) + return GeneratedEmbeddings([], options=resolved_options) opts: dict[str, Any] = dict(options) if options else {} model = opts.get("model_id") or self.model_id @@ -190,7 +223,7 @@ async def get_embeddings( if total_input_tokens > 0: usage_dict = {"input_token_count": total_input_tokens} - return GeneratedEmbeddings(embeddings, options=options, usage=usage_dict) + return GeneratedEmbeddings(embeddings, options=resolved_options, usage=usage_dict) async def _generate_embedding_for_text( self, @@ -212,10 +245,25 @@ async def _generate_embedding_for_text( body=json.dumps(body), ) - response_body = json.loads(response["body"].read()) + response_body_raw = response["body"] + response_payload = response_body_raw.read() + payload_text = response_payload.decode() if isinstance(response_payload, (bytes, bytearray)) else response_payload + response_body_raw_map: object = json.loads(payload_text) + if not isinstance(response_body_raw_map, dict): + raise ValueError("Bedrock embedding response body must be a JSON object") + response_body = cast(dict[str, Any], response_body_raw_map) + embedding_values = response_body.get("embedding") + if not isinstance(embedding_values, list): + raise ValueError("Bedrock embedding response missing 
'embedding' list") + vector: list[float] = [] + for value in cast(list[object], embedding_values): + if isinstance(value, (int, float, str)): + vector.append(float(value)) + continue + raise ValueError("Bedrock embedding response contains non-numeric embedding value") embedding = Embedding( - vector=response_body["embedding"], - dimensions=len(response_body["embedding"]), + vector=vector, + dimensions=len(vector), model_id=model, ) input_tokens = int(response_body.get("inputTextTokenCount", 0)) @@ -269,7 +317,7 @@ def __init__( access_key: str | None = None, secret_key: str | None = None, session_token: str | None = None, - client: BaseClient | None = None, + client: BaseClient | BedrockRuntimeClient | None = None, boto3_session: Boto3Session | None = None, otel_provider_name: str | None = None, env_file_path: str | None = None, diff --git a/python/packages/bedrock/pyproject.toml b/python/packages/bedrock/pyproject.toml index 5cff0f4c69..8fec38093f 100644 --- a/python/packages/bedrock/pyproject.toml +++ b/python/packages/bedrock/pyproject.toml @@ -60,6 +60,7 @@ omit = [ [tool.pyright] extends = "../../pyproject.toml" +include = ["agent_framework_bedrock"] [tool.mypy] plugins = ['pydantic.mypy'] diff --git a/python/packages/chatkit/pyproject.toml b/python/packages/chatkit/pyproject.toml index b4ecd81dff..91ba8347b8 100644 --- a/python/packages/chatkit/pyproject.toml +++ b/python/packages/chatkit/pyproject.toml @@ -61,6 +61,7 @@ omit = [ [tool.pyright] extends = "../../pyproject.toml" +include = ["agent_framework_chatkit"] exclude = ['tests', 'chatkit-python', 'openai-chatkit-advanced-samples'] [tool.mypy] diff --git a/python/packages/claude/agent_framework_claude/_agent.py b/python/packages/claude/agent_framework_claude/_agent.py index d764419214..e903627637 100644 --- a/python/packages/claude/agent_framework_claude/_agent.py +++ b/python/packages/claude/agent_framework_claude/_agent.py @@ -319,21 +319,29 @@ def _normalize_tools( if tools is None: return - # 
Normalize to sequence + non_builtin_tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None if isinstance(tools, str): - tools_list: Sequence[Any] = [tools] - elif isinstance(tools, Sequence): - tools_list = list(tools) + self._builtin_tools.append(tools) + return + if isinstance(tools, Sequence) and not isinstance(tools, MutableMapping): + sequence_tools: list[ToolTypes | Callable[..., Any]] = [] + for tool in tools: # pyright: ignore[reportUnknownVariableType] + if isinstance(tool, str): + self._builtin_tools.append(tool) + else: + sequence_tools.append(tool) # pyright: ignore[reportUnknownArgumentType] + non_builtin_tools = sequence_tools else: - tools_list = [tools] + non_builtin_tools = tools + + if not non_builtin_tools: + return - for tool in tools_list: + for tool in normalize_tools(non_builtin_tools): if isinstance(tool, str): self._builtin_tools.append(tool) else: - # Use normalize_tools for custom tools - normalized = normalize_tools(tool) - self._custom_tools.extend(normalized) + self._custom_tools.append(tool) async def __aenter__(self) -> RawClaudeAgent[OptionsT]: """Start the agent when entering async context.""" diff --git a/python/packages/claude/pyproject.toml b/python/packages/claude/pyproject.toml index a3b009dcd5..74c9a6f358 100644 --- a/python/packages/claude/pyproject.toml +++ b/python/packages/claude/pyproject.toml @@ -61,6 +61,7 @@ omit = [ [tool.pyright] extends = "../../pyproject.toml" +include = ["agent_framework_claude"] exclude = ['tests'] [tool.mypy] diff --git a/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py b/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py index 91a07b58ff..edacb614a5 100644 --- a/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py +++ b/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py @@ -133,43 +133,47 @@ def __init__( env_file_path=env_file_path, env_file_encoding=env_file_encoding, ) + 
resolved_environment_id = copilot_studio_settings.get("environmentid") + resolved_agent_identifier = copilot_studio_settings.get("schemaname") + resolved_client_id = copilot_studio_settings.get("agentappid") + resolved_tenant_id = copilot_studio_settings.get("tenantid") if not settings: - if not copilot_studio_settings["environmentid"]: + if not resolved_environment_id: raise ValueError( "Copilot Studio environment ID is required. Set via 'environment_id' parameter " "or 'COPILOTSTUDIOAGENT__ENVIRONMENTID' environment variable." ) - if not copilot_studio_settings["schemaname"]: + if not resolved_agent_identifier: raise ValueError( "Copilot Studio agent identifier/schema name is required. Set via 'agent_identifier' parameter " "or 'COPILOTSTUDIOAGENT__SCHEMANAME' environment variable." ) settings = ConnectionSettings( - environment_id=copilot_studio_settings["environmentid"], - agent_identifier=copilot_studio_settings["schemaname"], + environment_id=resolved_environment_id, + agent_identifier=resolved_agent_identifier, cloud=cloud, copilot_agent_type=agent_type, custom_power_platform_cloud=custom_power_platform_cloud, ) if not token: - if not copilot_studio_settings["agentappid"]: + if not resolved_client_id: raise ValueError( "Copilot Studio client ID is required. Set via 'client_id' parameter " "or 'COPILOTSTUDIOAGENT__AGENTAPPID' environment variable." ) - if not copilot_studio_settings["tenantid"]: + if not resolved_tenant_id: raise ValueError( "Copilot Studio tenant ID is required. Set via 'tenant_id' parameter " "or 'COPILOTSTUDIOAGENT__TENANTID' environment variable." 
) token = acquire_token( - client_id=copilot_studio_settings["agentappid"], - tenant_id=copilot_studio_settings["tenantid"], + client_id=resolved_client_id, + tenant_id=resolved_tenant_id, username=username, token_cache=token_cache, scopes=scopes, diff --git a/python/packages/copilotstudio/pyproject.toml b/python/packages/copilotstudio/pyproject.toml index 02fa708f20..df6531b623 100644 --- a/python/packages/copilotstudio/pyproject.toml +++ b/python/packages/copilotstudio/pyproject.toml @@ -61,6 +61,7 @@ omit = [ [tool.pyright] extends = "../../pyproject.toml" +include = ["agent_framework_copilotstudio"] [tool.mypy] plugins = ['pydantic.mypy'] diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index a0c998757c..878edc1312 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -83,10 +83,13 @@ def _get_tool_name(tool: Any) -> str | None: """Extract a tool's name from either an object with a .name attribute or a dict tool definition.""" - if isinstance(tool, dict): - func = tool.get("function") - if isinstance(func, dict): - return func.get("name") + if isinstance(tool, Mapping): + tool_mapping = cast(Mapping[str, Any], tool) + func = tool_mapping.get("function") + if isinstance(func, Mapping): + func_mapping = cast(Mapping[str, Any], func) + name = func_mapping.get("name") + return name if isinstance(name, str) else None return None return getattr(tool, "name", None) @@ -770,10 +773,9 @@ def _update_agent_name_and_description(self) -> None: should check if there is already an agent name defined, and if not set it to this value. 
""" - if hasattr(self.client, "_update_agent_name_and_description") and callable( - self.client._update_agent_name_and_description - ): # type: ignore[reportAttributeAccessIssue, attr-defined] - self.client._update_agent_name_and_description(self.name, self.description) # type: ignore[reportAttributeAccessIssue, attr-defined] + update_fn = getattr(self.client, "_update_agent_name_and_description", None) + if callable(update_fn): + update_fn(self.name, self.description) @overload def run( @@ -860,11 +862,14 @@ async def _run_non_streaming() -> AgentResponse[Any]: options=options, kwargs=kwargs, ) - response = await self.client.get_response( # type: ignore[call-overload] - messages=ctx["session_messages"], - stream=False, - options=ctx["chat_options"], - **ctx["filtered_kwargs"], + response = cast( + ChatResponse[Any] | None, + await self.client.get_response( # type: ignore[call-overload] + messages=ctx["session_messages"], + stream=False, + options=cast(Any, ctx["chat_options"]), + **ctx["filtered_kwargs"], + ), ) if not response: @@ -930,7 +935,7 @@ async def _post_hook(response: AgentResponse) -> None: ) await self._run_after_providers(session=ctx["session"], context=session_context) - async def _get_stream() -> ResponseStream[ChatResponseUpdate, ChatResponse]: + async def _get_stream() -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ctx_holder["ctx"] = await self._prepare_run_context( messages=messages, session=session, @@ -942,7 +947,7 @@ async def _get_stream() -> ResponseStream[ChatResponseUpdate, ChatResponse]: return self.client.get_response( # type: ignore[call-overload, no-any-return] messages=ctx["session_messages"], stream=True, - options=ctx["chat_options"], + options=cast(Any, ctx["chat_options"]), **ctx["filtered_kwargs"], ) @@ -965,13 +970,16 @@ def _finalizer(updates: Sequence[AgentResponseUpdate]) -> AgentResponse[Any]: rf = ( ctx.get("chat_options", {}).get("response_format") if ctx - else (options.get("response_format") if options else 
None) + else (options.get("response_format") if options else None) # type: ignore[union-attr] ) return self._finalize_response_updates(updates, response_format=rf) + stream_response = cast( + ResponseStream[ChatResponseUpdate, ChatResponse[Any]], + cast(Any, ResponseStream).from_awaitable(_get_stream()), + ) return ( - ResponseStream - .from_awaitable(_get_stream()) + stream_response .map( transform=partial( map_chat_to_agent_update, @@ -988,10 +996,13 @@ def _finalize_response_updates( updates: Sequence[AgentResponseUpdate], *, response_format: Any | None = None, - ) -> AgentResponse: + ) -> AgentResponse[Any]: """Finalize response updates into a single AgentResponse.""" output_format_type = response_format if isinstance(response_format, type) else None - return AgentResponse.from_updates(updates, output_format_type=output_format_type) + return AgentResponse.from_updates( # pyright: ignore[reportUnknownVariableType] + updates, + output_format_type=output_format_type, + ) @staticmethod def _extract_conversation_id_from_streaming_response(response: AgentResponse[Any]) -> str | None: @@ -1000,10 +1011,11 @@ def _extract_conversation_id_from_streaming_response(response: AgentResponse[Any if raw is None: return None - raw_items: list[Any] = raw if isinstance(raw, list) else [raw] + raw_items: list[Any] = list(cast(Any, raw)) if isinstance(raw, list) else [raw] for item in reversed(raw_items): if isinstance(item, Mapping): - value = item.get("conversation_id") + mapped_item = cast(Mapping[str, Any], item) + value = mapped_item.get("conversation_id") if isinstance(value, str) and value: return value continue diff --git a/python/packages/core/agent_framework/_clients.py b/python/packages/core/agent_framework/_clients.py index 278657a154..bc842bbfbd 100644 --- a/python/packages/core/agent_framework/_clients.py +++ b/python/packages/core/agent_framework/_clients.py @@ -317,10 +317,13 @@ def _finalize_response_updates( updates: Sequence[ChatResponseUpdate], *, 
response_format: Any | None = None, - ) -> ChatResponse: + ) -> ChatResponse[Any]: """Finalize response updates into a single ChatResponse.""" output_format_type = response_format if isinstance(response_format, type) else None - return ChatResponse.from_updates(updates, output_format_type=output_format_type) + return ChatResponse.from_updates( # pyright: ignore[reportUnknownVariableType] + updates, + output_format_type=output_format_type, + ) def _build_response_stream( self, diff --git a/python/packages/core/agent_framework/_middleware.py b/python/packages/core/agent_framework/_middleware.py index 1f0f9e3338..ceece1a410 100644 --- a/python/packages/core/agent_framework/_middleware.py +++ b/python/packages/core/agent_framework/_middleware.py @@ -8,7 +8,7 @@ from abc import ABC, abstractmethod from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, Sequence from enum import Enum -from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, overload +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, cast, overload from ._clients import SupportsChatGetResponse from ._types import ( @@ -170,9 +170,9 @@ def __init__( self.session = session self.options = options self.stream = stream - self.metadata = metadata if metadata is not None else {} + self.metadata: dict[str, Any] = dict(metadata) if metadata is not None else {} self.result = result - self.kwargs = kwargs if kwargs is not None else {} + self.kwargs: dict[str, Any] = dict(kwargs) if kwargs is not None else {} self.stream_transform_hooks = list(stream_transform_hooks or []) self.stream_result_hooks = list(stream_result_hooks or []) self.stream_cleanup_hooks = list(stream_cleanup_hooks or []) @@ -231,9 +231,9 @@ def __init__( """ self.function = function self.arguments = arguments - self.metadata = metadata if metadata is not None else {} + self.metadata: dict[str, Any] = dict(metadata) if metadata is not None else {} self.result = result - self.kwargs = kwargs if kwargs is 
not None else {} + self.kwargs: dict[str, Any] = dict(kwargs) if kwargs is not None else {} class ChatContext: @@ -314,9 +314,9 @@ def __init__( self.messages = messages self.options = options self.stream = stream - self.metadata = metadata if metadata is not None else {} + self.metadata: dict[str, Any] = dict(metadata) if metadata is not None else {} self.result = result - self.kwargs = kwargs if kwargs is not None else {} + self.kwargs: dict[str, Any] = dict(kwargs) if kwargs is not None else {} self.stream_transform_hooks = list(stream_transform_hooks or []) self.stream_result_hooks = list(stream_result_hooks or []) self.stream_cleanup_hooks = list(stream_cleanup_hooks or []) @@ -754,9 +754,11 @@ def create_next_handler(index: int) -> Callable[[], Awaitable[None]]: if index >= len(self._middleware): async def final_wrapper() -> None: - context.result = final_handler(context) # type: ignore[assignment] - if inspect.isawaitable(context.result): - context.result = await context.result + result = final_handler(context) + if inspect.isawaitable(result): + context.result = await cast(Awaitable[AgentResponse], result) + else: + context.result = result return final_wrapper @@ -893,12 +895,17 @@ async def execute( The chat response after processing through all middleware. 
""" if not self._middleware: - context.result = final_handler(context) # type: ignore[assignment] - if isinstance(context.result, Awaitable): - context.result = await context.result - if context.stream and not isinstance(context.result, ResponseStream): + result = final_handler(context) + if inspect.isawaitable(result): + resolved_result: ChatResponse | ResponseStream[ChatResponseUpdate, ChatResponse] = await cast( + Awaitable[ChatResponse], result + ) + else: + resolved_result = result + context.result = resolved_result + if context.stream and not isinstance(resolved_result, ResponseStream): raise ValueError("Streaming agent middleware requires a ResponseStream result.") - return context.result + return resolved_result def create_next_handler(index: int) -> Callable[[], Awaitable[None]]: if index >= len(self._middleware): @@ -1038,7 +1045,11 @@ async def _execute_stream() -> ResponseStream[ChatResponseUpdate, ChatResponse]: # If result is ChatResponse (shouldn't happen for streaming), raise error raise ValueError("Expected ResponseStream for streaming, got ChatResponse") - return ResponseStream.from_awaitable(_execute_stream()) + stream_result = cast( + ResponseStream[ChatResponseUpdate, ChatResponse[Any]], + cast(Any, ResponseStream).from_awaitable(_execute_stream()), + ) + return stream_result # For non-streaming, return the coroutine directly return _execute() # type: ignore[return-value] @@ -1120,7 +1131,12 @@ def run( ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: """MiddlewareTypes-enabled unified run method.""" # Re-categorize self.middleware at runtime to support dynamic changes - base_middleware = getattr(self, "middleware", None) or [] + base_middleware_attr = getattr(self, "middleware", None) + base_middleware: Sequence[MiddlewareTypes] = ( + cast(Sequence[MiddlewareTypes], base_middleware_attr) + if isinstance(base_middleware_attr, Sequence) + else [] + ) base_middleware_list = 
categorize_middleware(base_middleware) run_middleware_list = categorize_middleware(middleware) pipeline = AgentMiddlewarePipeline(*base_middleware_list["agent"], *run_middleware_list["agent"]) @@ -1166,7 +1182,11 @@ async def _execute_stream() -> ResponseStream[AgentResponseUpdate, AgentResponse # If result is AgentResponse (shouldn't happen for streaming), convert to stream raise ValueError("Expected ResponseStream for streaming, got AgentResponse") - return ResponseStream.from_awaitable(_execute_stream()) + stream_result = cast( + ResponseStream[AgentResponseUpdate, AgentResponse[Any]], + cast(Any, ResponseStream).from_awaitable(_execute_stream()), + ) + return stream_result # For non-streaming, return the coroutine directly return _execute() # type: ignore[return-value] diff --git a/python/packages/core/agent_framework/_serialization.py b/python/packages/core/agent_framework/_serialization.py index 7934477298..8dffdc0ce6 100644 --- a/python/packages/core/agent_framework/_serialization.py +++ b/python/packages/core/agent_framework/_serialization.py @@ -303,7 +303,7 @@ def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) # Handle lists containing SerializationProtocol objects if isinstance(value, list): value_as_list: list[Any] = [] - for item in value: + for item in value: # pyright: ignore[reportUnknownVariableType] if isinstance(item, SerializationProtocol): value_as_list.append(item.to_dict(exclude=exclude, exclude_none=exclude_none)) continue @@ -311,7 +311,7 @@ def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) value_as_list.append(item) continue logger.debug( - f"Skipping non-serializable item in list attribute '{key}' of type {type(item).__name__}" + f"Skipping non-serializable item in list attribute '{key}' of type {type(item).__name__}" # pyright: ignore[reportUnknownArgumentType] ) result[key] = value_as_list continue @@ -320,21 +320,22 @@ def to_dict(self, *, exclude: set[str] | None = None, 
exclude_none: bool = True) from datetime import date, datetime, time serialized_dict: dict[str, Any] = {} - for k, v in value.items(): + for raw_key, v in value.items(): # pyright: ignore[reportUnknownVariableType] + dict_key = str(raw_key) # pyright: ignore[reportUnknownArgumentType] if isinstance(v, SerializationProtocol): - serialized_dict[k] = v.to_dict(exclude=exclude, exclude_none=exclude_none) + serialized_dict[dict_key] = v.to_dict(exclude=exclude, exclude_none=exclude_none) continue # Convert datetime objects to strings if isinstance(v, (datetime, date, time)): - serialized_dict[k] = str(v) + serialized_dict[dict_key] = str(v) continue # Check if the value is JSON serializable if is_serializable(v): - serialized_dict[k] = v + serialized_dict[dict_key] = v continue logger.debug( - f"Skipping non-serializable value for key '{k}' in dict attribute '{key}' " - f"of type {type(v).__name__}" + f"Skipping non-serializable value for key '{dict_key}' in dict attribute '{key}' " + f"of type {type(v).__name__}" # pyright: ignore[reportUnknownArgumentType] ) result[key] = serialized_dict continue @@ -505,7 +506,8 @@ async def get_current_weather(location: Annotated[str, "The city name"]) -> str: # Only apply if the instance matches if kwargs.get(field) == name and isinstance(dep_value, dict): # Apply instance-specific dependencies - for param_name, param_value in dep_value.items(): + for raw_param_name, param_value in dep_value.items(): # pyright: ignore[reportUnknownVariableType] + param_name = str(raw_param_name) # pyright: ignore[reportUnknownArgumentType] if param_name not in cls.INJECTABLE: logger.debug( f"Dependency '{param_name}' for type '{type_id}' is not in INJECTABLE set. 
" diff --git a/python/packages/core/agent_framework/_sessions.py b/python/packages/core/agent_framework/_sessions.py index aba90bc6e5..26016f68cc 100644 --- a/python/packages/core/agent_framework/_sessions.py +++ b/python/packages/core/agent_framework/_sessions.py @@ -16,7 +16,7 @@ import uuid from abc import abstractmethod from collections.abc import Sequence -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, cast from ._types import AgentResponse, Message @@ -92,7 +92,7 @@ def _deserialize_value(value: Any) -> Any: from pydantic import BaseModel if issubclass(cls, BaseModel): - data = {k: v for k, v in value.items() if k != "type"} + data: dict[str, Any] = {str(k): v for k, v in value.items() if k != "type"} # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType] return cls.model_validate(data) except ImportError: pass @@ -229,8 +229,11 @@ def extend_tools(self, source_id: str, tools: Sequence[Any]) -> None: tools: The tools to add. 
""" for tool in tools: - if hasattr(tool, "additional_properties") and isinstance(tool.additional_properties, dict): - tool.additional_properties["context_source"] = source_id + if hasattr(tool, "additional_properties"): + additional_properties_obj = getattr(tool, "additional_properties") + if isinstance(additional_properties_obj, dict): + additional_properties = cast(dict[str, Any], additional_properties_obj) + additional_properties["context_source"] = source_id self.tools.extend(tools) def get_messages( diff --git a/python/packages/core/agent_framework/_settings.py b/python/packages/core/agent_framework/_settings.py index e2b6af428c..4eecf3434d 100644 --- a/python/packages/core/agent_framework/_settings.py +++ b/python/packages/core/agent_framework/_settings.py @@ -215,9 +215,7 @@ def load_settings( raise FileNotFoundError(env_file_path) raw_dotenv_values = dotenv_values(dotenv_path=env_file_path, encoding=encoding) - loaded_dotenv_values = { - key: value for key, value in raw_dotenv_values.items() if key is not None and value is not None - } + loaded_dotenv_values = {key: value for key, value in raw_dotenv_values.items() if value is not None} # Filter out None overrides so defaults / env vars are preserved overrides = {k: v for k, v in overrides.items() if v is not None} diff --git a/python/packages/core/agent_framework/_skills.py b/python/packages/core/agent_framework/_skills.py index 9e11ecbe96..49695c89e6 100644 --- a/python/packages/core/agent_framework/_skills.py +++ b/python/packages/core/agent_framework/_skills.py @@ -151,6 +151,7 @@ class Skill: content="Use this skill for DB tasks.", ) + @skill.resource def get_schema() -> str: return "CREATE TABLE ..." 
@@ -972,9 +973,7 @@ def _load_skills( if skills: for code_skill in skills: - error = _validate_skill_metadata( - code_skill.name, code_skill.description, "code skill" - ) + error = _validate_skill_metadata(code_skill.name, code_skill.description, "code skill") if error: logger.warning(error) continue diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 303699572c..91e2ea1c75 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -26,8 +26,10 @@ Generic, Literal, TypeAlias, + TypeGuard, TypedDict, Union, + cast, get_args, get_origin, overload, @@ -77,6 +79,14 @@ logger = logging.getLogger("agent_framework") + +def _is_str_key_mapping(value: object) -> TypeGuard[Mapping[str, Any]]: + if not isinstance(value, Mapping): + return False + keys = cast(Mapping[object, object], value).keys() + return all(isinstance(key, str) for key in keys) + + DEFAULT_MAX_ITERATIONS: Final[int] = 40 DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST: Final[int] = 3 SHELL_TOOL_KIND_VALUE: Final[str] = "shell" @@ -84,7 +94,7 @@ # region Helpers -def _parse_inputs( +def _parse_inputs( # pyright: ignore[reportUnusedFunction] inputs: Content | dict[str, Any] | str | list[Content | dict[str, Any] | str] | None, ) -> list[Content]: """Parse the inputs for a tool, ensuring they are of type Content. 
@@ -352,7 +362,8 @@ def __str__(self) -> str: def declaration_only(self) -> bool: """Indicate whether the function is declaration only (i.e., has no implementation).""" # Check for explicit _declaration_only attribute first (used in tests) - if hasattr(self, "_declaration_only") and self._declaration_only: + declaration_flag = getattr(self, "_declaration_only", False) + if isinstance(declaration_flag, bool) and declaration_flag: return True return self.func is None @@ -430,10 +441,13 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: ) self.invocation_count += 1 try: + func = self.func + if func is None: + raise ToolException(f"Function '{self.name}' has no implementation.") # If we have a bound instance, call the function with self if self._instance is not None: - return self.func(self._instance, *args, **kwargs) - return self.func(*args, **kwargs) # type:ignore[misc] + return func(self._instance, *args, **kwargs) + return func(*args, **kwargs) except Exception: self.invocation_exception_count += 1 raise @@ -600,9 +614,11 @@ def _make_dumpable(value: Any) -> Any: from ._types import Content if isinstance(value, list): - return [FunctionTool._make_dumpable(item) for item in value] + list_value = cast(list[object], value) + return [FunctionTool._make_dumpable(item) for item in list_value] if isinstance(value, dict): - return {k: FunctionTool._make_dumpable(v) for k, v in value.items()} + dict_value = cast(dict[object, object], value) + return {key: FunctionTool._make_dumpable(item) for key, item in dict_value.items()} if isinstance(value, Content): return value.to_dict(exclude={"raw_representation", "additional_properties"}) if isinstance(value, BaseModel): @@ -661,7 +677,7 @@ def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) return as_dict -ToolTypes: TypeAlias = FunctionTool | MCPTool | Mapping[str, Any] | Any +ToolTypes: TypeAlias = FunctionTool | MCPTool | Mapping[str, Any] | object def normalize_tools( @@ -679,17 +695,24 @@ 
def normalize_tools( if not tools: return [] - tool_items = ( - list(tools) - if isinstance(tools, Sequence) and not isinstance(tools, (str, bytes, bytearray, Mapping)) - else [tools] - ) + tool_items: list[object] + if isinstance(tools, Sequence) and not isinstance(tools, (str, bytes, bytearray, Mapping)): + sequence_tools = cast(Sequence[object], tools) + tool_items = list(sequence_tools) + else: + tool_items = [tools] from ._mcp import MCPTool normalized: list[ToolTypes] = [] for tool_item in tool_items: # check known types, these are also callable, so we need to do that first - if isinstance(tool_item, (FunctionTool, Mapping, MCPTool)): + if isinstance(tool_item, FunctionTool): + normalized.append(tool_item) + continue + if _is_str_key_mapping(tool_item): + normalized.append(tool_item) + continue + if isinstance(tool_item, MCPTool): normalized.append(tool_item) continue if callable(tool_item): @@ -699,7 +722,7 @@ def normalize_tools( return normalized -def _tools_to_dict( +def _tools_to_dict( # pyright: ignore[reportUnusedFunction] tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None, ) -> list[str | dict[str, Any]] | None: """Parse the tools to a dict. 
@@ -722,7 +745,7 @@ def _tools_to_dict( if isinstance(tool_item, SerializationMixin): results.append(tool_item.to_dict()) continue - if isinstance(tool_item, Mapping): + if _is_str_key_mapping(tool_item): results.append(dict(tool_item)) continue logger.warning("Can't parse tool.") @@ -802,7 +825,7 @@ def _validate_arguments_against_schema( raise TypeError(f"Missing required argument(s) for '{tool_name}': {', '.join(sorted(missing_fields))}") properties_raw = schema.get("properties") - properties = properties_raw if isinstance(properties_raw, Mapping) else {} + properties: Mapping[str, Any] = properties_raw if _is_str_key_mapping(properties_raw) else {} if schema.get("additionalProperties") is False: unexpected_fields = sorted(field for field in parsed_arguments if field not in properties) @@ -810,9 +833,10 @@ def _validate_arguments_against_schema( raise TypeError(f"Unexpected argument(s) for '{tool_name}': {', '.join(unexpected_fields)}") for field_name, field_value in parsed_arguments.items(): - field_schema = properties.get(field_name) - if not isinstance(field_schema, Mapping): + field_schema_raw = properties.get(field_name) + if not _is_str_key_mapping(field_schema_raw): continue + field_schema = field_schema_raw enum_values = field_schema.get("enum") if isinstance(enum_values, list) and enum_values and field_value not in enum_values: @@ -829,8 +853,10 @@ def _validate_arguments_against_schema( ) continue - if isinstance(schema_type, list): - allowed_types = [item for item in schema_type if isinstance(item, str)] + schema_type_obj: object = schema_type + if isinstance(schema_type_obj, list): + schema_type_list = cast(list[object], schema_type_obj) + allowed_types: list[str] = [item for item in schema_type_list if isinstance(item, str)] if allowed_types and not any(_matches_json_schema_type(field_value, item) for item in allowed_types): raise TypeError( f"Invalid type for '{field_name}' in '{tool_name}': expected one of " @@ -865,15 +891,21 @@ def 
_build_pydantic_model_from_json_schema( Returns: The dynamically created Pydantic model class. """ - properties = schema.get("properties") - required = schema.get("required", []) - definitions = schema.get("$defs", {}) + properties_raw = schema.get("properties") + properties = properties_raw if _is_str_key_mapping(properties_raw) else None + required_raw = schema.get("required", []) + required_obj: object = required_raw + required: list[str] = [ + item for item in cast(list[object], required_obj) if isinstance(item, str) + ] if isinstance(required_obj, list) else [] + defs_raw = schema.get("$defs", {}) + definitions: Mapping[str, Any] = defs_raw if _is_str_key_mapping(defs_raw) else {} # Check if 'properties' is missing or not a dictionary if not properties: return create_model(f"{model_name}_input") - def _resolve_literal_type(prop_details: dict[str, Any]) -> type | None: + def _resolve_literal_type(prop_details: Mapping[str, Any]) -> type | None: """Check if property should be a Literal type (const or enum). Args: @@ -887,14 +919,15 @@ def _resolve_literal_type(prop_details: dict[str, Any]) -> type | None: return Literal[prop_details["const"]] # type: ignore # enum → Literal["a", "b", ...] - if "enum" in prop_details and isinstance(prop_details["enum"], list): - enum_values = prop_details["enum"] + enum_raw: object = prop_details.get("enum") + if isinstance(enum_raw, list): + enum_values = cast(list[object], enum_raw) if enum_values: return Literal[tuple(enum_values)] # type: ignore return None - def _resolve_type(prop_details: dict[str, Any], parent_name: str = "") -> type: + def _resolve_type(prop_details: Mapping[str, Any], parent_name: str = "") -> type: """Resolve JSON Schema type to Python type, handling $ref, nested objects, and typed arrays. 
Args: @@ -906,13 +939,23 @@ def _resolve_type(prop_details: dict[str, Any], parent_name: str = "") -> type: """ # Handle oneOf + discriminator (polymorphic objects) if "oneOf" in prop_details and "discriminator" in prop_details: - discriminator = prop_details["discriminator"] - disc_field = discriminator.get("propertyName") - - variants = [] - for variant in prop_details["oneOf"]: + discriminator_raw = prop_details["discriminator"] + discriminator: Mapping[str, Any] = discriminator_raw if _is_str_key_mapping(discriminator_raw) else {} + disc_field_raw = discriminator.get("propertyName") + disc_field = disc_field_raw if isinstance(disc_field_raw, str) else None + + variants: list[type] = [] + one_of_raw = prop_details["oneOf"] + one_of: list[object] = cast(list[object], one_of_raw) if isinstance(one_of_raw, list) else [] + for variant_raw in one_of: + if not _is_str_key_mapping(variant_raw): + continue + variant = variant_raw if "$ref" in variant: - ref = variant["$ref"] + ref_raw = variant["$ref"] + if not isinstance(ref_raw, str): + continue + ref = ref_raw if ref.startswith("#/$defs/"): def_name = ref.split("/")[-1] resolved = definitions.get(def_name) @@ -954,7 +997,7 @@ def _resolve_type(prop_details: dict[str, Any], parent_name: str = "") -> type: case "array": # Handle typed arrays items_schema = prop_details.get("items") - if items_schema and isinstance(items_schema, dict): + if _is_str_key_mapping(items_schema): # Recursively resolve the item type item_type = _resolve_type(items_schema, f"{parent_name}_item") # Return list[ItemType] instead of bare list @@ -963,21 +1006,29 @@ def _resolve_type(prop_details: dict[str, Any], parent_name: str = "") -> type: return list case "object": # Handle nested objects by creating a nested Pydantic model - nested_properties = prop_details.get("properties") - nested_required = prop_details.get("required", []) - - if nested_properties and isinstance(nested_properties, dict): + nested_properties_raw = 
prop_details.get("properties") + nested_properties = nested_properties_raw if _is_str_key_mapping(nested_properties_raw) else None + nested_required_raw = prop_details.get("required", []) + nested_required_obj: object = nested_required_raw + nested_required: set[str] = { + item for item in cast(list[object], nested_required_obj) if isinstance(item, str) + } if isinstance(nested_required_obj, list) else set() + + if nested_properties: # Create the name for the nested model nested_model_name = f"{parent_name}_nested" if parent_name else "NestedModel" # Recursively build field definitions for the nested model nested_field_definitions: dict[str, Any] = {} - for nested_prop_name, nested_prop_details in nested_properties.items(): - nested_prop_details = ( - json.loads(nested_prop_details) - if isinstance(nested_prop_details, str) - else nested_prop_details + for nested_prop_name, nested_prop_details_raw in nested_properties.items(): + nested_prop_details_candidate = ( + json.loads(nested_prop_details_raw) + if isinstance(nested_prop_details_raw, str) + else nested_prop_details_raw ) + if not _is_str_key_mapping(nested_prop_details_candidate): + continue + nested_prop_details = nested_prop_details_candidate # Check for Literal types first (const/enum) literal_type = _resolve_literal_type(nested_prop_details) @@ -1021,8 +1072,12 @@ def _resolve_type(prop_details: dict[str, Any], parent_name: str = "") -> type: return str # default field_definitions: dict[str, Any] = {} - for prop_name, prop_details in properties.items(): - prop_details = json.loads(prop_details) if isinstance(prop_details, str) else prop_details + + for prop_name, prop_details_raw in properties.items(): + prop_details_candidate = json.loads(prop_details_raw) if isinstance(prop_details_raw, str) else prop_details_raw + if not _is_str_key_mapping(prop_details_candidate): + continue + prop_details = prop_details_candidate # Check for Literal types first (const/enum) literal_type = 
_resolve_literal_type(prop_details) @@ -1054,7 +1109,9 @@ def _resolve_type(prop_details: dict[str, Any], parent_name: str = "") -> type: return create_model(f"{model_name}_input", **field_definitions) -def _create_model_from_json_schema(tool_name: str, schema_json: Mapping[str, Any]) -> type[BaseModel]: +def _create_model_from_json_schema( # pyright: ignore[reportUnusedFunction] + tool_name: str, schema_json: Mapping[str, Any] +) -> type[BaseModel]: """Creates a Pydantic model from a given JSON Schema. Args: @@ -1348,8 +1405,6 @@ def normalize_function_invocation_configuration( raise ValueError("max_function_calls must be at least 1 or None.") if normalized["max_consecutive_errors_per_request"] < 0: raise ValueError("max_consecutive_errors_per_request must be 0 or more.") - if normalized["additional_tools"] is None: - normalized["additional_tools"] = [] return normalized @@ -1424,7 +1479,7 @@ async def _auto_invoke_function( if key not in {"_function_middleware_pipeline", "middleware", "conversation_id"} } try: - if not tool._schema_supplied and tool.input_model is not None: + if not cast(bool, getattr(tool, "_schema_supplied", False)) and tool.input_model is not None: args = tool.input_model.model_validate(parsed_args).model_dump(exclude_none=True) else: args = dict(parsed_args) @@ -1435,7 +1490,7 @@ async def _auto_invoke_function( ) except (TypeError, ValidationError) as exc: message = "Error: Argument parsing failed." - if config["include_detailed_errors"]: + if config.get("include_detailed_errors", False): message = f"{message} Exception: {exc}" return Content.from_function_result( call_id=function_call_content.call_id, # type: ignore[arg-type] @@ -1459,7 +1514,7 @@ async def _auto_invoke_function( ) except Exception as exc: message = "Error: Function failed." 
- if config["include_detailed_errors"]: + if config.get("include_detailed_errors", False): message = f"{message} Exception: {exc}" return Content.from_function_result( call_id=function_call_content.call_id, # type: ignore[arg-type] @@ -1505,7 +1560,7 @@ async def final_function_handler(context_obj: Any) -> Any: raise except Exception as exc: message = "Error: Function failed." - if config["include_detailed_errors"]: + if config.get("include_detailed_errors", False): message = f"{message} Exception: {exc}" return Content.from_function_result( call_id=function_call_content.call_id, # type: ignore[arg-type] @@ -1560,7 +1615,8 @@ async def _try_execute_function_calls( approval_tools, ) declaration_only = [tool_name for tool_name, tool in tool_map.items() if tool.declaration_only] - additional_tool_names = [tool.name for tool in config["additional_tools"]] if config["additional_tools"] else [] + configured_additional_tools = config.get("additional_tools") or [] + additional_tool_names = [tool.name for tool in configured_additional_tools] # check if any are calling functions that need approval # if so, we return approval request for all approval_needed = False @@ -1581,7 +1637,9 @@ async def _try_execute_function_calls( declaration_only_flag = True break if ( - config["terminate_on_unknown_calls"] and fcc.type == "function_call" and fcc.name not in tool_map # type: ignore[attr-defined] + config.get("terminate_on_unknown_calls", False) + and fcc.type == "function_call" + and fcc.name not in tool_map # type: ignore[attr-defined] ): raise KeyError(f'Error: Requested function "{fcc.name}" not found.') # type: ignore[attr-defined] if approval_needed: @@ -1598,7 +1656,7 @@ async def _try_execute_function_calls( if declaration_only_flag: # return the declaration only tools to the user, since we cannot execute them. # Mark as user_input_request so AgentExecutor emits request_info events and pauses the workflow. 
- declaration_only_calls = [] + declaration_only_calls: list[Content] = [] for fcc in function_calls: if fcc.type == "function_call": fcc.user_input_request = True @@ -1696,16 +1754,16 @@ def _update_conversation_id( async def _ensure_response_stream( - stream_like: ResponseStream[Any, Any] | Awaitable[ResponseStream[Any, Any]], -) -> ResponseStream[Any, Any]: + stream_like: ResponseStream[ChatResponseUpdate, ChatResponse[Any]] + | Awaitable[ResponseStream[ChatResponseUpdate, ChatResponse[Any]]], +) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: from ._types import ResponseStream stream = await stream_like if isinstance(stream_like, Awaitable) else stream_like if not isinstance(stream, ResponseStream): raise ValueError("Streaming function invocation requires a ResponseStream result.") - if getattr(stream, "_stream", None) is None: - await stream - return stream + await stream + return cast(ResponseStream[ChatResponseUpdate, ChatResponse[Any]], stream) def _extract_tools( @@ -1776,7 +1834,7 @@ def _replace_approval_contents_with_results( } # Track approval requests that should be removed (duplicates) - contents_to_remove = [] + contents_to_remove: list[int] = [] for content_idx, content in enumerate(msg.contents): if content.type == "function_approval_request": @@ -2097,7 +2155,9 @@ def get_response( function_middleware_pipeline = FunctionMiddlewarePipeline( *(self.function_middleware), *(function_middleware or []) ) - max_errors: int = self.function_invocation_configuration["max_consecutive_errors_per_request"] # type: ignore[assignment] + max_errors = self.function_invocation_configuration.get( + "max_consecutive_errors_per_request", DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST + ) additional_function_arguments: dict[str, Any] = {} if options and (additional_opts := options.get("additional_function_arguments")): # type: ignore[attr-defined] additional_function_arguments = additional_opts # type: ignore @@ -2122,7 +2182,7 @@ def get_response( if not 
stream: - async def _get_response() -> ChatResponse: + async def _get_response() -> ChatResponse[Any]: nonlocal mutable_options nonlocal filtered_kwargs errors_in_a_row: int = 0 @@ -2130,13 +2190,11 @@ async def _get_response() -> ChatResponse: max_function_calls: int | None = self.function_invocation_configuration.get("max_function_calls") prepped_messages = list(messages) fcc_messages: list[Message] = [] - response: ChatResponse | None = None + response: ChatResponse[Any] | None = None - for attempt_idx in range( - self.function_invocation_configuration["max_iterations"] - if self.function_invocation_configuration["enabled"] - else 0 - ): + loop_enabled = self.function_invocation_configuration.get("enabled", True) + max_iterations = self.function_invocation_configuration.get("max_iterations", DEFAULT_MAX_ITERATIONS) + for attempt_idx in range(max_iterations if loop_enabled else 0): approval_result = await _process_function_requests( response=None, prepped_messages=prepped_messages, @@ -2147,17 +2205,20 @@ async def _get_response() -> ChatResponse: max_errors=max_errors, execute_function_calls=execute_function_calls, ) - if approval_result["action"] == "stop": + if approval_result.get("action") == "stop": response = ChatResponse(messages=prepped_messages) break - errors_in_a_row = approval_result["errors_in_a_row"] + errors_in_a_row = approval_result.get("errors_in_a_row", errors_in_a_row) total_function_calls += approval_result.get("function_call_count", 0) - response = await super_get_response( - messages=prepped_messages, - stream=False, - options=mutable_options, - **filtered_kwargs, + response = cast( + ChatResponse[Any], + await super_get_response( + messages=prepped_messages, + stream=False, + options=mutable_options, + **filtered_kwargs, + ), ) if response.conversation_id is not None: @@ -2174,10 +2235,10 @@ async def _get_response() -> ChatResponse: max_errors=max_errors, execute_function_calls=execute_function_calls, ) - if result["action"] == "return": 
+ if result.get("action") == "return": return response total_function_calls += result.get("function_call_count", 0) - if result["action"] == "stop": + if result.get("action") == "stop": # Error threshold reached: force a final non-tool turn so # function_call_output items are submitted before exit. mutable_options["tool_choice"] = "none" @@ -2190,7 +2251,7 @@ async def _get_response() -> ChatResponse: max_function_calls, ) mutable_options["tool_choice"] = "none" - errors_in_a_row = result["errors_in_a_row"] + errors_in_a_row = result.get("errors_in_a_row", errors_in_a_row) # When tool_choice is 'required', reset tool_choice after one iteration to avoid infinite loops if mutable_options.get("tool_choice") == "required" or ( @@ -2213,17 +2274,20 @@ async def _get_response() -> ChatResponse: # Make a final model call with tool_choice="none" so the model # produces a plain text answer instead of leaving orphaned # function_call items without matching results. - if response is not None and self.function_invocation_configuration["enabled"]: + if response is not None and self.function_invocation_configuration.get("enabled", True): logger.info( "Maximum iterations reached (%d). 
Requesting final response without tools.", - self.function_invocation_configuration["max_iterations"], + self.function_invocation_configuration.get("max_iterations", DEFAULT_MAX_ITERATIONS), ) mutable_options["tool_choice"] = "none" - response = await super_get_response( - messages=prepped_messages, - stream=False, - options=mutable_options, - **filtered_kwargs, + response = cast( + ChatResponse[Any], + await super_get_response( + messages=prepped_messages, + stream=False, + options=mutable_options, + **filtered_kwargs, + ), ) if fcc_messages: for msg in reversed(fcc_messages): @@ -2233,7 +2297,7 @@ async def _get_response() -> ChatResponse: return _get_response() response_format = mutable_options.get("response_format") if mutable_options else None - output_format_type = response_format if isinstance(response_format, type) else None + output_format_type: type[BaseModel] | None = response_format if isinstance(response_format, type) else None stream_result_hooks: list[Callable[[ChatResponse], Any]] = [] async def _stream() -> AsyncIterable[ChatResponseUpdate]: @@ -2245,13 +2309,11 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: max_function_calls: int | None = self.function_invocation_configuration.get("max_function_calls") prepped_messages = list(messages) fcc_messages: list[Message] = [] - response: ChatResponse | None = None + response: ChatResponse[Any] | None = None - for attempt_idx in range( - self.function_invocation_configuration["max_iterations"] - if self.function_invocation_configuration["enabled"] - else 0 - ): + loop_enabled = self.function_invocation_configuration.get("enabled", True) + max_iterations = self.function_invocation_configuration.get("max_iterations", DEFAULT_MAX_ITERATIONS) + for attempt_idx in range(max_iterations if loop_enabled else 0): approval_result = await _process_function_requests( response=None, prepped_messages=prepped_messages, @@ -2262,20 +2324,23 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: 
max_errors=max_errors, execute_function_calls=execute_function_calls, ) - errors_in_a_row = approval_result["errors_in_a_row"] + errors_in_a_row = approval_result.get("errors_in_a_row", errors_in_a_row) total_function_calls += approval_result.get("function_call_count", 0) - if approval_result["action"] == "stop": + if approval_result.get("action") == "stop": mutable_options["tool_choice"] = "none" return - inner_stream = await _ensure_response_stream( + stream_like = cast( + ResponseStream[ChatResponseUpdate, ChatResponse[Any]] + | Awaitable[ResponseStream[ChatResponseUpdate, ChatResponse[Any]]], super_get_response( messages=prepped_messages, stream=True, options=mutable_options, **filtered_kwargs, - ) + ), ) + inner_stream = await _ensure_response_stream(stream_like) # Collect result hooks from the inner stream to run later stream_result_hooks[:] = _get_result_hooks_from_stream(inner_stream) @@ -2308,18 +2373,18 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: max_errors=max_errors, execute_function_calls=execute_function_calls, ) - errors_in_a_row = result["errors_in_a_row"] + errors_in_a_row = result.get("errors_in_a_row", errors_in_a_row) total_function_calls += result.get("function_call_count", 0) - if role := result["update_role"]: + if role := result.get("update_role"): yield ChatResponseUpdate( - contents=result["function_call_results"] or [], + contents=result.get("function_call_results") or [], role=role, ) - if result["action"] == "stop": + if result.get("action") == "stop": # Error threshold reached: submit collected function_call_output # items once more with tools disabled. 
mutable_options["tool_choice"] = "none" - elif result["action"] != "continue": + elif result.get("action") != "continue": return elif max_function_calls is not None and total_function_calls >= max_function_calls: # Best-effort limit: checked after each batch of parallel calls completes, @@ -2352,26 +2417,29 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: # Make a final model call with tool_choice="none" so the model # produces a plain text answer instead of leaving orphaned # function_call items without matching results. - if response is not None and self.function_invocation_configuration["enabled"]: + if response is not None and self.function_invocation_configuration.get("enabled", True): logger.info( "Maximum iterations reached (%d). Requesting final response without tools.", - self.function_invocation_configuration["max_iterations"], + self.function_invocation_configuration.get("max_iterations", DEFAULT_MAX_ITERATIONS), ) mutable_options["tool_choice"] = "none" - inner_stream = await _ensure_response_stream( + stream_like = cast( + ResponseStream[ChatResponseUpdate, ChatResponse[Any]] + | Awaitable[ResponseStream[ChatResponseUpdate, ChatResponse[Any]]], super_get_response( messages=prepped_messages, stream=True, options=mutable_options, **filtered_kwargs, - ) + ), ) + inner_stream = await _ensure_response_stream(stream_like) async for update in inner_stream: yield update # Finalize the inner stream to trigger its hooks await inner_stream.get_final_response() - def _finalize(updates: Sequence[ChatResponseUpdate]) -> ChatResponse: + def _finalize(updates: Sequence[ChatResponseUpdate]) -> ChatResponse[Any]: # Note: stream_result_hooks are already run via inner stream's get_final_response() # We don't need to run them again here return ChatResponse.from_updates(updates, output_format_type=output_format_type) diff --git a/python/packages/core/agent_framework/_types.py b/python/packages/core/agent_framework/_types.py index ee0e813d27..e0797f64a2 100644 
--- a/python/packages/core/agent_framework/_types.py +++ b/python/packages/core/agent_framework/_types.py @@ -4,6 +4,7 @@ import base64 import json +import inspect import logging import re import sys @@ -17,10 +18,11 @@ Mapping, MutableMapping, Sequence, + Sized, ) from copy import deepcopy from datetime import datetime -from typing import TYPE_CHECKING, Any, ClassVar, Final, Generic, Literal, NewType, cast, overload +from typing import TYPE_CHECKING, Any, ClassVar, Final, Generic, Literal, NewType, TypeGuard, cast, overload from pydantic import BaseModel @@ -194,7 +196,7 @@ def _get_data_bytes_as_str(content: Content) -> str | None: return data # type: ignore[return-value, no-any-return] -def _get_data_bytes(content: Content) -> bytes | None: +def _get_data_bytes(content: Content) -> bytes | None: # pyright: ignore[reportUnusedFunction] """Extract and decode binary data from data URI. Args: @@ -219,6 +221,20 @@ def _get_data_bytes(content: Content) -> bytes | None: KNOWN_URI_SCHEMAS: Final[set[str]] = {"http", "https", "ftp", "ftps", "file", "s3", "gs", "azure", "blob"} +def _is_legacy_value_mapping(value: object) -> TypeGuard[Mapping[str, str]]: + if not isinstance(value, Mapping): + return False + mapping = cast(Mapping[object, object], value) + return isinstance(mapping.get("value"), str) + + +def _is_str_key_mapping(value: object) -> TypeGuard[Mapping[str, Any]]: + if not isinstance(value, Mapping): + return False + mapping = cast(Mapping[object, object], value) + return all(isinstance(key, str) for key in mapping.keys()) + + def _validate_uri(uri: str, media_type: str | None) -> dict[str, Any]: """Validate URI format and return validation result. 
@@ -270,9 +286,18 @@ def _serialize_value(value: Any, exclude_none: bool) -> Any: if isinstance(value, Content): return value.to_dict(exclude_none=exclude_none) if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): - return [_serialize_value(item, exclude_none) for item in value] + serialized_items: list[Any] = [] + for item in cast(Iterable[Any], value): + item_any: Any = item + serialized_items.append(_serialize_value(item_any, exclude_none)) + return serialized_items if isinstance(value, Mapping): - return {k: _serialize_value(v, exclude_none) for k, v in value.items()} + serialized_mapping: dict[Any, Any] = {} + for key, map_value in cast(Mapping[object, object], value).items(): + key_any: Any = key + map_value_any: Any = map_value + serialized_mapping[key_any] = _serialize_value(map_value_any, exclude_none) + return serialized_mapping if hasattr(value, "to_dict"): return value.to_dict() # type: ignore[call-arg] return value @@ -1264,18 +1289,23 @@ def from_dict(cls: type[ContentT], data: Mapping[str, Any]) -> ContentT: return cls.from_data(remaining["data"], remaining["media_type"]) # Handle nested Content objects (e.g., function_call in function_approval_request) - if "function_call" in remaining and isinstance(remaining["function_call"], dict): - remaining["function_call"] = cls.from_dict(remaining["function_call"]) + function_call_raw = remaining.get("function_call") + if _is_str_key_mapping(function_call_raw): + remaining["function_call"] = cls.from_dict(function_call_raw) # Handle list of Content objects (e.g., inputs in code_interpreter_tool_call) - if "inputs" in remaining and isinstance(remaining["inputs"], list): + input_items_obj: object = remaining.get("inputs") + if isinstance(input_items_obj, list): + input_items: list[Any] = list(cast(Iterable[Any], input_items_obj)) remaining["inputs"] = [ - cls.from_dict(item) if isinstance(item, dict) else item for item in remaining["inputs"] + cls.from_dict(item) if 
_is_str_key_mapping(item) else item for item in input_items ] - if "outputs" in remaining and isinstance(remaining["outputs"], list): + output_items_obj: object = remaining.get("outputs") + if isinstance(output_items_obj, list): + output_items: list[Any] = list(cast(Iterable[Any], output_items_obj)) remaining["outputs"] = [ - cls.from_dict(item) if isinstance(item, dict) else item for item in remaining["outputs"] + cls.from_dict(item) if _is_str_key_mapping(item) else item for item in output_items ] return cls( @@ -1307,22 +1337,26 @@ def __add__(self, other: Content) -> Content: def _add_text_content(self, other: Content) -> Content: """Add two TextContent instances.""" # Merge raw representations + raw_representation: Any if self.raw_representation is None: raw_representation = other.raw_representation elif other.raw_representation is None: raw_representation = self.raw_representation else: - raw_representation = ( - self.raw_representation if isinstance(self.raw_representation, list) else [self.raw_representation] - ) + (other.raw_representation if isinstance(other.raw_representation, list) else [other.raw_representation]) + self_raw_repr: object = self.raw_representation + other_raw_repr: object = other.raw_representation + self_raw: list[object] = cast(list[object], self_raw_repr) if isinstance(self_raw_repr, list) else [self_raw_repr] + other_raw: list[object] = cast(list[object], other_raw_repr) if isinstance(other_raw_repr, list) else [other_raw_repr] + raw_representation = self_raw + other_raw # Merge annotations + annotations: Sequence[Annotation] | None if self.annotations is None: annotations = other.annotations elif other.annotations is None: annotations = self.annotations else: - annotations = self.annotations + other.annotations # type: ignore[operator] + annotations = [*self.annotations, *other.annotations] return Content( "text", @@ -1343,17 +1377,20 @@ def _add_text_reasoning_content(self, other: Content) -> Content: elif other.raw_representation 
is None: raw_representation = self.raw_representation else: - raw_representation = ( - self.raw_representation if isinstance(self.raw_representation, list) else [self.raw_representation] - ) + (other.raw_representation if isinstance(other.raw_representation, list) else [other.raw_representation]) + self_raw_repr: object = self.raw_representation + other_raw_repr: object = other.raw_representation + self_raw: list[object] = cast(list[object], self_raw_repr) if isinstance(self_raw_repr, list) else [self_raw_repr] + other_raw: list[object] = cast(list[object], other_raw_repr) if isinstance(other_raw_repr, list) else [other_raw_repr] + raw_representation = self_raw + other_raw # Merge annotations + annotations: Sequence[Annotation] | None if self.annotations is None: annotations = other.annotations elif other.annotations is None: annotations = self.annotations else: - annotations = self.annotations + other.annotations # type: ignore[operator] + annotations = [*self.annotations, *other.annotations] # Concatenate text, handling None values self_text = self.text or "" # type: ignore[attr-defined] @@ -1402,9 +1439,11 @@ def _add_function_call_content(self, other: Content) -> Content: elif other.raw_representation is None: raw_representation = self.raw_representation else: - raw_representation = ( - self.raw_representation if isinstance(self.raw_representation, list) else [self.raw_representation] - ) + (other.raw_representation if isinstance(other.raw_representation, list) else [other.raw_representation]) + self_raw_repr: object = self.raw_representation + other_raw_repr: object = other.raw_representation + self_raw: list[object] = cast(list[object], self_raw_repr) if isinstance(self_raw_repr, list) else [self_raw_repr] + other_raw: list[object] = cast(list[object], other_raw_repr) if isinstance(other_raw_repr, list) else [other_raw_repr] + raw_representation = self_raw + other_raw return Content( "function_call", @@ -1437,14 +1476,17 @@ def _add_usage_content(self, other: 
Content) -> Content: combined_details[key] = other_val # Merge raw representations + raw_representation: Any if self.raw_representation is None: raw_representation = other.raw_representation elif other.raw_representation is None: raw_representation = self.raw_representation else: - raw_representation = ( - self.raw_representation if isinstance(self.raw_representation, list) else [self.raw_representation] - ) + (other.raw_representation if isinstance(other.raw_representation, list) else [other.raw_representation]) + self_raw_repr: object = self.raw_representation + other_raw_repr: object = other.raw_representation + self_raw: list[object] = cast(list[object], self_raw_repr) if isinstance(self_raw_repr, list) else [self_raw_repr] + other_raw: list[object] = cast(list[object], other_raw_repr) if isinstance(other_raw_repr, list) else [other_raw_repr] + raw_representation = self_raw + other_raw return Content( "usage", @@ -1666,7 +1708,7 @@ def __init__( raw_representation: Optional raw representation of the chat message. 
""" # Handle role conversion from legacy dict format - if isinstance(role, dict) and "value" in role: + if _is_legacy_value_mapping(role): role = role["value"] # Handle contents conversion @@ -1836,14 +1878,14 @@ def _process_update(response: ChatResponse | AgentResponse, update: ChatResponse if update.created_at is not None: response.created_at = update.created_at if update.additional_properties is not None: - if response.additional_properties is None: - response.additional_properties = {} response.additional_properties.update(update.additional_properties) if response.raw_representation is None: response.raw_representation = [] if not isinstance(response.raw_representation, list): response.raw_representation = [response.raw_representation] - response.raw_representation.append(update.raw_representation) + raw_representation_value = cast(Any, getattr(response, "raw_representation", None)) + raw_representation_list = cast(list[Any], raw_representation_value) + raw_representation_list.append(update.raw_representation) if isinstance(response, ChatResponse) and isinstance(update, ChatResponseUpdate): if update.conversation_id is not None: response.conversation_id = update.conversation_id @@ -2027,8 +2069,8 @@ def __init__( self.model_id = model_id self.created_at = created_at # Handle legacy dict format for finish_reason - if isinstance(finish_reason, dict) and "value" in finish_reason: - finish_reason = finish_reason["value"] + if _is_legacy_value_mapping(finish_reason): + finish_reason = cast(FinishReasonLiteral | FinishReason, finish_reason["value"]) self.finish_reason = finish_reason self.usage_details = usage_details self._value: ResponseModelT | None = value @@ -2621,7 +2663,7 @@ def __init__( self.contents = processed_contents # Handle legacy dict format for role - if isinstance(role, dict) and "value" in role: + if _is_legacy_value_mapping(role): role = role["value"] self.role: str | None = role @@ -2672,6 +2714,12 @@ def map_chat_to_agent_update(update: 
ChatResponseUpdate, agent_name: str | None) OuterFinalT = TypeVar("OuterFinalT") +async def _await_if_needed(value: _T | Awaitable[_T]) -> _T: + if inspect.isawaitable(value): + return await cast(Awaitable[_T], value) + return value + + class ResponseStream(AsyncIterable[UpdateT], Generic[UpdateT, FinalT]): """Async stream wrapper that supports iteration and deferred finalization.""" @@ -2757,11 +2805,11 @@ def map( ... AgentResponse.from_updates, ... ) """ - stream: ResponseStream[Any, Any] = ResponseStream(self, finalizer=finalizer) + stream: ResponseStream[OuterUpdateT, OuterFinalT] = ResponseStream(self, finalizer=finalizer) stream._inner_stream_source = self stream._wrap_inner = True stream._map_update = transform - return stream # type: ignore[return-value] + return stream def with_finalizer( self, @@ -2785,10 +2833,10 @@ def with_finalizer( Example: >>> stream.with_finalizer(AgentResponse.from_updates) """ - stream: ResponseStream[Any, Any] = ResponseStream(self, finalizer=finalizer) + stream: ResponseStream[UpdateT, OuterFinalT] = ResponseStream(self, finalizer=finalizer) stream._inner_stream_source = self stream._wrap_inner = True - return stream # type: ignore[return-value] + return stream @classmethod def from_awaitable( @@ -2813,24 +2861,26 @@ def from_awaitable( >>> async def get_stream() -> ResponseStream[Update, Response]: ... 
>>> stream = ResponseStream.from_awaitable(get_stream()) """ - stream: ResponseStream[Any, Any] = cls(awaitable) # type: ignore[arg-type] - stream._inner_stream_source = awaitable # type: ignore[assignment] + stream: ResponseStream[UpdateT, FinalT] = cls(cast(Awaitable[AsyncIterable[UpdateT]], awaitable)) + stream._inner_stream_source = awaitable stream._wrap_inner = True - return stream # type: ignore[return-value] + return stream async def _get_stream(self) -> AsyncIterable[UpdateT]: if self._stream is None: if hasattr(self._stream_source, "__aiter__"): - self._stream = self._stream_source # type: ignore[assignment] + self._stream = cast(AsyncIterable[UpdateT], self._stream_source) else: if not iscoroutine(self._stream_source): - self._stream = self._stream_source # type: ignore[assignment] + self._stream = cast(AsyncIterable[UpdateT], self._stream_source) else: - self._stream = await self._stream_source # type: ignore[assignment] - if isinstance(self._stream, ResponseStream) and self._wrap_inner: - self._inner_stream = self._stream - return self._stream - return self._stream # type: ignore[return-value] + self._stream = await self._stream_source + stream_obj = cast(Any, self._stream) + if isinstance(stream_obj, ResponseStream) and self._wrap_inner: + inner_stream: Any = cast(Any, stream_obj) + self._inner_stream = inner_stream + return cast(AsyncIterable[UpdateT], inner_stream) + return cast(AsyncIterable[UpdateT], cast(Any, self._stream)) def __aiter__(self) -> ResponseStream[UpdateT, FinalT]: return self @@ -2840,7 +2890,7 @@ async def __anext__(self) -> UpdateT: stream = await self._get_stream() self._iterator = stream.__aiter__() try: - update = await self._iterator.__anext__() + update: UpdateT = await self._iterator.__anext__() except StopAsyncIteration: self._consumed = True await self._run_cleanup_hooks() @@ -2851,16 +2901,20 @@ async def __anext__(self) -> UpdateT: if self._map_update is not None: mapped = self._map_update(update) if isinstance(mapped, 
Awaitable): - update = await mapped + mapped_any: Any = cast(Any, await mapped) + update = cast(UpdateT, mapped_any) else: - update = mapped # type: ignore[assignment] + mapped_any = mapped + update = cast(UpdateT, mapped_any) self._updates.append(update) for hook in self._transform_hooks: hooked = hook(update) if isinstance(hooked, Awaitable): - update = await hooked + hooked_any: Any = cast(Any, await hooked) + update = cast(UpdateT, hooked_any) elif hooked is not None: - update = hooked # type: ignore[assignment] + hooked_any = cast(Any, hooked) + update = cast(UpdateT, hooked_any) return update def __await__(self) -> Any: @@ -2903,61 +2957,62 @@ async def get_final_response(self) -> FinalT: # First, finalize the inner stream and run its result hooks # This ensures inner post-processing (e.g., context provider notifications) runs - if self._inner_stream._finalizer is not None: - inner_result: Any = self._inner_stream._finalizer(self._inner_stream._updates) - if isinstance(inner_result, Awaitable): - inner_result = await inner_result + inner_stream = self._inner_stream + inner_result: Any + if inner_stream._finalizer is not None: + inner_finalizer = inner_stream._finalizer + inner_result = await _await_if_needed(inner_finalizer(inner_stream._updates)) else: - inner_result = self._inner_stream._updates + inner_result = list(inner_stream._updates) + # Run inner stream's result hooks - for hook in self._inner_stream._result_hooks: - hooked = hook(inner_result) - if isinstance(hooked, Awaitable): - hooked = await hooked - if hooked is not None: - inner_result = hooked - self._inner_stream._final_result = inner_result - self._inner_stream._finalized = True + inner_hooks = cast(list[Callable[[Any], Any | Awaitable[Any] | None]], inner_stream._result_hooks) + for hook in inner_hooks: + hooked_result = await _await_if_needed(hook(inner_result)) + if hooked_result is not None: + inner_result = hooked_result + inner_stream._final_result = inner_result + 
inner_stream._finalized = True # Now finalize the outer stream with its own finalizer # If outer has no finalizer, use inner's result (preserves from_awaitable behavior) + outer_result: Any if self._finalizer is not None: - result: Any = self._finalizer(self._updates) - if isinstance(result, Awaitable): - result = await result + outer_result = await _await_if_needed(self._finalizer(self._updates)) else: # No outer finalizer - use inner's finalized result - result = inner_result + outer_result = inner_result + # Apply outer's result_hooks - for hook in self._result_hooks: - hooked = hook(result) - if isinstance(hooked, Awaitable): - hooked = await hooked - if hooked is not None: - result = hooked - self._final_result = result + outer_hooks = cast(list[Callable[[Any], Any | Awaitable[Any] | None]], self._result_hooks) + for hook in outer_hooks: + outer_hooked_result = await _await_if_needed(hook(outer_result)) + if outer_hooked_result is not None: + outer_result = outer_hooked_result + self._final_result = cast(FinalT, outer_result) self._finalized = True - return self._final_result # type: ignore[return-value] + return cast(FinalT, self._final_result) + if not self._finalized: if not self._consumed: async for _ in self: pass + # Use finalizer if configured, otherwise return collected updates + result: Any if self._finalizer is not None: - result = self._finalizer(self._updates) - if isinstance(result, Awaitable): - result = await result + result = await _await_if_needed(self._finalizer(self._updates)) else: - result = self._updates - for hook in self._result_hooks: - hooked = hook(result) - if isinstance(hooked, Awaitable): - hooked = await hooked - if hooked is not None: - result = hooked - self._final_result = result + result = list(self._updates) + + final_hooks = cast(list[Callable[[Any], Any | Awaitable[Any] | None]], self._result_hooks) + for hook in final_hooks: + final_hook_result = await _await_if_needed(hook(result)) + if final_hook_result is not None: + 
result = final_hook_result + self._final_result = cast(FinalT, result) self._finalized = True - return self._final_result # type: ignore[return-value] + return cast(FinalT, self._final_result) def with_transform_hook( self, @@ -3302,9 +3357,11 @@ def merge_chat_options( # Copy base values (shallow copy for simple values, dict copy for dicts) for key, value in base.items(): if isinstance(value, dict): - result[key] = dict(value) + dict_value = cast(Mapping[Any, Any], value) + result[key] = dict(dict_value) elif isinstance(value, list): - result[key] = list(value) + list_value: list[Any] = list(cast(Iterable[Any], value)) + result[key] = list(list_value) else: result[key] = value @@ -3325,20 +3382,22 @@ def merge_chat_options( base_tools = result.get("tools") if base_tools and value: # Add tools that aren't already present - merged_tools = list(base_tools) - for tool in value if isinstance(value, list) else [value]: + base_tool_values: list[Any] = list(cast(Iterable[Any], base_tools)) if isinstance(base_tools, list) else [base_tools] + merged_tools = list(base_tool_values) + tool_values: list[Any] = list(cast(Iterable[Any], value)) if isinstance(value, list) else [value] + for tool in tool_values: if tool not in merged_tools: merged_tools.append(tool) result["tools"] = merged_tools elif value: - result["tools"] = list(value) if isinstance(value, list) else [value] + result["tools"] = value if isinstance(value, list) else [value] elif key in ("logit_bias", "metadata", "additional_properties"): # Merge dicts base_dict = result.get(key) - if base_dict and isinstance(value, dict): + if base_dict and isinstance(base_dict, dict) and isinstance(value, dict): result[key] = {**base_dict, **value} elif value: - result[key] = dict(value) if isinstance(value, dict) else value + result[key] = dict(cast(Mapping[Any, Any], value)) if isinstance(value, dict) else value elif key == "tool_choice": # tool_choice from override takes precedence result["tool_choice"] = value if value else 
result.get("tool_choice") @@ -3424,8 +3483,8 @@ def dimensions(self) -> int | None: """ if self._dimensions is not None: return self._dimensions - if isinstance(self.vector, (list, tuple, bytes)): - return len(self.vector) + if isinstance(self.vector, Sized): + return len(cast(Sized, self.vector)) return None diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index 3d8024a35e..ac2ebcf56f 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -450,9 +450,9 @@ def _prepare_agent_run_args(raw_run_kwargs: dict[str, Any]) -> tuple[dict[str, A options: dict[str, Any] = {} if options_from_workflow is not None: if isinstance(options_from_workflow, Mapping): - for key, value in options_from_workflow.items(): - if isinstance(key, str): - options[key] = value + options_from_workflow_map = cast(Mapping[str, Any], options_from_workflow) + for key, value in options_from_workflow_map.items(): + options[key] = value else: logger.warning( "Ignoring non-mapping workflow 'options' kwarg of type %s for AgentExecutor %s.", @@ -461,16 +461,17 @@ def _prepare_agent_run_args(raw_run_kwargs: dict[str, Any]) -> tuple[dict[str, A ) existing_additional_args = options.get("additional_function_arguments") + additional_args: dict[str, Any] if isinstance(existing_additional_args, Mapping): - additional_args = {key: value for key, value in existing_additional_args.items() if isinstance(key, str)} + existing_additional_args_map = cast(Mapping[str, Any], existing_additional_args) + additional_args = {key: value for key, value in existing_additional_args_map.items()} else: additional_args = {} if workflow_additional_args is not None: if isinstance(workflow_additional_args, Mapping): - additional_args.update({ - key: value for key, value in workflow_additional_args.items() if isinstance(key, str) - }) + 
workflow_additional_args_map = cast(Mapping[str, Any], workflow_additional_args) + additional_args.update({key: value for key, value in workflow_additional_args_map.items()}) else: logger.warning( "Ignoring non-mapping workflow 'additional_function_arguments' kwarg of type %s for AgentExecutor %s.", # noqa: E501 diff --git a/python/packages/core/agent_framework/_workflows/_function_executor.py b/python/packages/core/agent_framework/_workflows/_function_executor.py index a27e250690..326145b6c4 100644 --- a/python/packages/core/agent_framework/_workflows/_function_executor.py +++ b/python/packages/core/agent_framework/_workflows/_function_executor.py @@ -119,7 +119,7 @@ def __init__( # Determine if function has WorkflowContext parameter self._has_context = ctx_annotation is not None # Determine if the function is an async function - self._is_async = asyncio.iscoroutinefunction(func) + self._is_async = inspect.iscoroutinefunction(func) # Initialize parent WITHOUT calling _discover_handlers yet # We'll manually set up the attributes first diff --git a/python/packages/core/agent_framework/_workflows/_typing_utils.py b/python/packages/core/agent_framework/_workflows/_typing_utils.py index 41ed071f0a..07b6d15bca 100644 --- a/python/packages/core/agent_framework/_workflows/_typing_utils.py +++ b/python/packages/core/agent_framework/_workflows/_typing_utils.py @@ -193,36 +193,40 @@ def try_coerce_to_type(data: Any, target_type: type | UnionType | Any) -> Any: Returns: The coerced value, or the original value if coercion fails. """ + original_data = data + # If already the right type, return as-is if is_instance_of(data, target_type): return data # Can't coerce to non-concrete targets (Union, generic, etc.) 
if not isinstance(target_type, type): - return data + return original_data + + target_cls: type[Any] = target_type # int -> float (JSON integers for float fields) - if isinstance(data, int) and target_type is float: + if isinstance(data, int) and target_cls is float: return float(data) - # dict -> dataclass + # dict -> dataclass or pydantic model if isinstance(data, dict): from dataclasses import is_dataclass - if is_dataclass(target_type): + if is_dataclass(target_cls): try: - return target_type(**data) + return target_cls(**data) except (TypeError, ValueError): - return data + return original_data - # dict -> Pydantic model - if hasattr(target_type, "model_validate"): + model_validate = getattr(target_cls, "model_validate", None) + if callable(model_validate): try: - return target_type.model_validate(data) + return model_validate(data) except Exception: - return data + return original_data - return data + return original_data def serialize_type(t: type) -> str: diff --git a/python/packages/core/agent_framework/azure/_assistants_client.py b/python/packages/core/agent_framework/azure/_assistants_client.py index 015a1dcc82..aae89d562d 100644 --- a/python/packages/core/agent_framework/azure/_assistants_client.py +++ b/python/packages/core/agent_framework/azure/_assistants_client.py @@ -12,7 +12,7 @@ from ..openai import OpenAIAssistantsClient from ..openai._assistants_client import OpenAIAssistantsOptions from ._entra_id_authentication import AzureCredentialTypes, AzureTokenProvider, resolve_credential_to_token_provider -from ._shared import AzureOpenAISettings, _apply_azure_defaults +from ._shared import AzureOpenAISettings, _apply_azure_defaults # pyright: ignore[reportPrivateUsage] if sys.version_info >= (3, 13): from typing import TypeVar # type: ignore # pragma: no cover @@ -145,43 +145,46 @@ class MyOptions(AzureOpenAIAssistantsOptions, total=False): ) _apply_azure_defaults(azure_openai_settings, default_api_version=self.DEFAULT_AZURE_API_VERSION) - if not 
azure_openai_settings["chat_deployment_name"]: + chat_deployment_name = azure_openai_settings.get("chat_deployment_name") + if not chat_deployment_name: raise ValueError( "Azure OpenAI deployment name is required. Set via 'deployment_name' parameter " "or 'AZURE_OPENAI_CHAT_DEPLOYMENT_NAME' environment variable." ) + api_key_secret = azure_openai_settings.get("api_key") + token_scope = azure_openai_settings.get("token_endpoint") + # Resolve credential to token provider ad_token_provider = None - if not async_client and not azure_openai_settings["api_key"] and credential: - ad_token_provider = resolve_credential_to_token_provider( - credential, azure_openai_settings["token_endpoint"] - ) + if not async_client and not api_key_secret and credential: + ad_token_provider = resolve_credential_to_token_provider(credential, token_scope) - if not async_client and not azure_openai_settings["api_key"] and not ad_token_provider: + if not async_client and not api_key_secret and not ad_token_provider: raise ValueError("Please provide either api_key, credential, or a client.") # Create Azure client if not provided if not async_client: client_params: dict[str, Any] = { - "api_version": azure_openai_settings["api_version"], "default_headers": default_headers, } + if resolved_api_version := azure_openai_settings.get("api_version"): + client_params["api_version"] = resolved_api_version - if azure_openai_settings["api_key"]: - client_params["api_key"] = azure_openai_settings["api_key"].get_secret_value() + if api_key_secret: + client_params["api_key"] = api_key_secret.get_secret_value() elif ad_token_provider: client_params["azure_ad_token_provider"] = ad_token_provider - if azure_openai_settings["base_url"]: - client_params["base_url"] = str(azure_openai_settings["base_url"]) - elif azure_openai_settings["endpoint"]: - client_params["azure_endpoint"] = str(azure_openai_settings["endpoint"]) + if resolved_base_url := azure_openai_settings.get("base_url"): + client_params["base_url"] = 
str(resolved_base_url) + elif resolved_endpoint := azure_openai_settings.get("endpoint"): + client_params["azure_endpoint"] = str(resolved_endpoint) async_client = AsyncAzureOpenAI(**client_params) super().__init__( - model_id=azure_openai_settings["chat_deployment_name"], + model_id=chat_deployment_name, assistant_id=assistant_id, assistant_name=assistant_name, assistant_description=assistant_description, diff --git a/python/packages/core/agent_framework/azure/_chat_client.py b/python/packages/core/agent_framework/azure/_chat_client.py index b4bd3659ed..b57abd6faf 100644 --- a/python/packages/core/agent_framework/azure/_chat_client.py +++ b/python/packages/core/agent_framework/azure/_chat_client.py @@ -6,7 +6,7 @@ import logging import sys from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any, Generic +from typing import TYPE_CHECKING, Any, Generic, cast from openai.lib.azure import AsyncAzureOpenAI from openai.types.chat.chat_completion import Choice @@ -31,7 +31,7 @@ from ._shared import ( AzureOpenAIConfigMixin, AzureOpenAISettings, - _apply_azure_defaults, + _apply_azure_defaults, # pyright: ignore[reportPrivateUsage] ) if sys.version_info >= (3, 13): @@ -260,19 +260,26 @@ class MyOptions(AzureOpenAIChatOptions, total=False): ) _apply_azure_defaults(azure_openai_settings) - if not azure_openai_settings["chat_deployment_name"]: + chat_deployment_name = azure_openai_settings.get("chat_deployment_name") + if not chat_deployment_name: raise ValueError( "Azure OpenAI deployment name is required. Set via 'deployment_name' parameter " "or 'AZURE_OPENAI_CHAT_DEPLOYMENT_NAME' environment variable." 
) + endpoint_value = azure_openai_settings.get("endpoint") + base_url_value = azure_openai_settings.get("base_url") + api_version_value = cast(str, azure_openai_settings.get("api_version")) + api_key_value = azure_openai_settings.get("api_key") + token_endpoint_value = azure_openai_settings.get("token_endpoint") + super().__init__( - deployment_name=azure_openai_settings["chat_deployment_name"], - endpoint=azure_openai_settings["endpoint"], - base_url=azure_openai_settings["base_url"], - api_version=azure_openai_settings["api_version"], # type: ignore - api_key=azure_openai_settings["api_key"].get_secret_value() if azure_openai_settings["api_key"] else None, - token_endpoint=azure_openai_settings["token_endpoint"], + deployment_name=chat_deployment_name, + endpoint=endpoint_value, + base_url=base_url_value, + api_version=api_version_value, + api_key=api_key_value.get_secret_value() if api_key_value else None, + token_endpoint=token_endpoint_value, credential=credential, default_headers=default_headers, client=async_client, @@ -302,24 +309,29 @@ def _parse_text_from_openai(self, choice: Choice | ChunkChoice) -> Content | Non if not message.model_extra or "context" not in message.model_extra: return text_content - context: dict[str, Any] | str = message.context # type: ignore[assignment, union-attr] - if isinstance(context, str): + context_raw: object = cast(object, message.context) # type: ignore[union-attr] + if isinstance(context_raw, str): try: - context = json.loads(context) + context_raw = json.loads(context_raw) except json.JSONDecodeError: logger.warning("Context is not a valid JSON string, ignoring context.") return text_content - if not isinstance(context, dict): + if not isinstance(context_raw, dict): logger.warning("Context is not a valid dictionary, ignoring context.") return text_content + context = cast(dict[str, Any], context_raw) # `all_retrieved_documents` is currently not used, but can be retrieved # through the raw_representation in the text 
content. if intent := context.get("intent"): text_content.additional_properties = {"intent": intent} - if citations := context.get("citations"): - text_content.annotations = [] - for citation in citations: - text_content.annotations.append( + citations = context.get("citations") + if isinstance(citations, list) and citations: + annotations: list[Annotation] = [] + for citation_raw in cast(list[object], citations): + if not isinstance(citation_raw, dict): + continue + citation = cast(dict[str, Any], citation_raw) + annotations.append( Annotation( type="citation", title=citation.get("title", ""), @@ -331,4 +343,5 @@ def _parse_text_from_openai(self, choice: Choice | ChunkChoice) -> Content | Non raw_representation=citation, ) ) + text_content.annotations = annotations return text_content diff --git a/python/packages/core/agent_framework/azure/_embedding_client.py b/python/packages/core/agent_framework/azure/_embedding_client.py index 13455e78a4..7003a4611f 100644 --- a/python/packages/core/agent_framework/azure/_embedding_client.py +++ b/python/packages/core/agent_framework/azure/_embedding_client.py @@ -17,7 +17,7 @@ from ._shared import ( AzureOpenAIConfigMixin, AzureOpenAISettings, - _apply_azure_defaults, + _apply_azure_defaults, # pyright: ignore[reportPrivateUsage] ) if sys.version_info >= (3, 13): @@ -118,19 +118,22 @@ def __init__( ) _apply_azure_defaults(azure_openai_settings) - if not azure_openai_settings.get("embedding_deployment_name"): + embedding_deployment_name = azure_openai_settings.get("embedding_deployment_name") + if not embedding_deployment_name: raise ValueError( "Azure OpenAI embedding deployment name is required. Set via 'deployment_name' parameter " "or 'AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME' environment variable." 
) + api_key_secret = azure_openai_settings.get("api_key") + super().__init__( - deployment_name=azure_openai_settings["embedding_deployment_name"], # type: ignore[arg-type] - endpoint=azure_openai_settings["endpoint"], - base_url=azure_openai_settings["base_url"], - api_version=azure_openai_settings["api_version"], # type: ignore - api_key=azure_openai_settings["api_key"].get_secret_value() if azure_openai_settings["api_key"] else None, - token_endpoint=azure_openai_settings["token_endpoint"], + deployment_name=embedding_deployment_name, + endpoint=azure_openai_settings.get("endpoint"), + base_url=azure_openai_settings.get("base_url"), + api_version=azure_openai_settings.get("api_version") or "", + api_key=api_key_secret.get_secret_value() if api_key_secret else None, + token_endpoint=azure_openai_settings.get("token_endpoint"), credential=credential, default_headers=default_headers, client=async_client, diff --git a/python/packages/core/agent_framework/azure/_responses_client.py b/python/packages/core/agent_framework/azure/_responses_client.py index 2debbd7b21..a420108ce0 100644 --- a/python/packages/core/agent_framework/azure/_responses_client.py +++ b/python/packages/core/agent_framework/azure/_responses_client.py @@ -20,7 +20,7 @@ from ._shared import ( AzureOpenAIConfigMixin, AzureOpenAISettings, - _apply_azure_defaults, + _apply_azure_defaults, # pyright: ignore[reportPrivateUsage] ) if sys.version_info >= (3, 13): @@ -207,27 +207,31 @@ class MyOptions(AzureOpenAIResponsesOptions, total=False): # TODO(peterychang): This is a temporary hack to ensure that the base_url is set correctly # while this feature is in preview. # But we should only do this if we're on azure. Private deployments may not need this. 
+ endpoint_value = azure_openai_settings.get("endpoint") if ( not azure_openai_settings.get("base_url") - and azure_openai_settings.get("endpoint") - and (hostname := urlparse(str(azure_openai_settings["endpoint"])).hostname) + and endpoint_value + and (hostname := urlparse(str(endpoint_value)).hostname) and hostname.endswith(".openai.azure.com") ): - azure_openai_settings["base_url"] = urljoin(str(azure_openai_settings["endpoint"]), "/openai/v1/") + azure_openai_settings["base_url"] = urljoin(str(endpoint_value), "/openai/v1/") - if not azure_openai_settings["responses_deployment_name"]: + responses_deployment_name = azure_openai_settings.get("responses_deployment_name") + if not responses_deployment_name: raise ValueError( "Azure OpenAI deployment name is required. Set via 'deployment_name' parameter " "or 'AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME' environment variable." ) + api_key_secret = azure_openai_settings.get("api_key") + super().__init__( - deployment_name=azure_openai_settings["responses_deployment_name"], - endpoint=azure_openai_settings["endpoint"], - base_url=azure_openai_settings["base_url"], - api_version=azure_openai_settings["api_version"], # type: ignore - api_key=azure_openai_settings["api_key"].get_secret_value() if azure_openai_settings["api_key"] else None, - token_endpoint=azure_openai_settings["token_endpoint"], + deployment_name=responses_deployment_name, + endpoint=azure_openai_settings.get("endpoint"), + base_url=azure_openai_settings.get("base_url"), + api_version=azure_openai_settings.get("api_version") or "", + api_key=api_key_secret.get_secret_value() if api_key_secret else None, + token_endpoint=azure_openai_settings.get("token_endpoint"), credential=credential, default_headers=default_headers, client=async_client, diff --git a/python/packages/core/agent_framework/azure/_shared.py b/python/packages/core/agent_framework/azure/_shared.py index dce116a242..5e06fbbe74 100644 --- a/python/packages/core/agent_framework/azure/_shared.py +++ 
b/python/packages/core/agent_framework/azure/_shared.py @@ -123,6 +123,9 @@ def _apply_azure_defaults( settings["token_endpoint"] = default_token_endpoint +_AZURE_DEFAULTS_APPLIER = _apply_azure_defaults + + class AzureOpenAIConfigMixin(OpenAIBase): """Internal class for configuring a connection to an Azure OpenAI service.""" diff --git a/python/packages/core/agent_framework/declarative/__init__.pyi b/python/packages/core/agent_framework/declarative/__init__.pyi index 214bb132ab..92da0da682 100644 --- a/python/packages/core/agent_framework/declarative/__init__.pyi +++ b/python/packages/core/agent_framework/declarative/__init__.pyi @@ -4,7 +4,6 @@ from agent_framework_declarative import ( AgentExternalInputRequest, AgentExternalInputResponse, AgentFactory, - AgentInvocationError, DeclarativeLoaderError, DeclarativeWorkflowError, ExternalInputRequest, @@ -19,7 +18,6 @@ __all__ = [ "AgentExternalInputRequest", "AgentExternalInputResponse", "AgentFactory", - "AgentInvocationError", "DeclarativeLoaderError", "DeclarativeWorkflowError", "ExternalInputRequest", diff --git a/python/packages/core/agent_framework/observability.py b/python/packages/core/agent_framework/observability.py index 9a60053068..0e7cc2a8ee 100644 --- a/python/packages/core/agent_framework/observability.py +++ b/python/packages/core/agent_framework/observability.py @@ -22,7 +22,7 @@ from collections.abc import Awaitable, Callable, Generator, Mapping, Sequence from enum import Enum from time import perf_counter, time_ns -from typing import TYPE_CHECKING, Any, ClassVar, Final, Generic, Literal, TypedDict, overload +from typing import TYPE_CHECKING, Any, ClassVar, Final, Generic, Literal, TypedDict, cast, overload from dotenv import load_dotenv from opentelemetry import metrics, trace @@ -894,7 +894,6 @@ def get_meter( return metrics.get_meter(name=name, version=version, schema_url=schema_url) -global OBSERVABILITY_SETTINGS OBSERVABILITY_SETTINGS: ObservabilitySettings = ObservabilitySettings() @@ -1053,7 
+1052,15 @@ def configure_otel_providers( if vs_code_extension_port is not None: settings_kwargs["vs_code_extension_port"] = vs_code_extension_port - OBSERVABILITY_SETTINGS = ObservabilitySettings(**settings_kwargs) + updated_settings = ObservabilitySettings(**settings_kwargs) + OBSERVABILITY_SETTINGS.enable_instrumentation = updated_settings.enable_instrumentation + OBSERVABILITY_SETTINGS.enable_sensitive_data = updated_settings.enable_sensitive_data + OBSERVABILITY_SETTINGS.enable_console_exporters = updated_settings.enable_console_exporters + OBSERVABILITY_SETTINGS.vs_code_extension_port = updated_settings.vs_code_extension_port + OBSERVABILITY_SETTINGS.env_file_path = updated_settings.env_file_path + OBSERVABILITY_SETTINGS.env_file_encoding = updated_settings.env_file_encoding + OBSERVABILITY_SETTINGS._resource = updated_settings._resource # pyright: ignore[reportPrivateUsage] + OBSERVABILITY_SETTINGS._executed_setup = False # pyright: ignore[reportPrivateUsage] else: # Update the observability settings with the provided values OBSERVABILITY_SETTINGS.enable_instrumentation = True @@ -1147,7 +1154,10 @@ def get_response( ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: """Trace chat responses with OpenTelemetry spans and metrics.""" global OBSERVABILITY_SETTINGS - super_get_response = super().get_response # type: ignore[misc] + super_get_response = cast( + "Callable[..., Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]]", + super().get_response, # type: ignore[misc] + ) if not OBSERVABILITY_SETTINGS.ENABLED: return super_get_response(messages=messages, stream=stream, options=options, **kwargs) # type: ignore[no-any-return] @@ -1168,11 +1178,16 @@ def get_response( if stream: from ._types import ResponseStream - stream_result = super_get_response(messages=messages, stream=True, options=opts, **kwargs) + stream_result: object = super_get_response(messages=messages, stream=True, 
options=opts, **kwargs) if isinstance(stream_result, ResponseStream): - result_stream = stream_result + result_stream: ResponseStream[ChatResponseUpdate, ChatResponse[Any]] = stream_result # pyright: ignore[reportUnknownVariableType] elif isinstance(stream_result, Awaitable): - result_stream = ResponseStream.from_awaitable(stream_result) + result_stream = cast( + ResponseStream[ChatResponseUpdate, ChatResponse[Any]], + ResponseStream.from_awaitable( # pyright: ignore[reportUnknownMemberType] + cast(Awaitable[ResponseStream[ChatResponseUpdate, ChatResponse[Any]]], stream_result) + ), + ) else: raise RuntimeError("Streaming telemetry requires a ResponseStream result.") @@ -1209,7 +1224,7 @@ async def _finalize_stream() -> None: from ._types import ChatResponse try: - response = await result_stream.get_final_response() + response: ChatResponse[Any] = await result_stream.get_final_response() duration = duration_state.get("duration") response_attributes = _get_response_attributes(attributes, response) _capture_response( @@ -1238,7 +1253,9 @@ async def _finalize_stream() -> None: # Register a weak reference callback to close the span if stream is garbage collected # without being consumed. This ensures spans don't leak if users don't consume streams. 
- wrapped_stream = result_stream.with_cleanup_hook(_record_duration).with_cleanup_hook(_finalize_stream) + wrapped_stream: ResponseStream[ChatResponseUpdate, ChatResponse[Any]] = ( + result_stream.with_cleanup_hook(_record_duration).with_cleanup_hook(_finalize_stream) + ) weakref.finalize(wrapped_stream, _close_span) return wrapped_stream @@ -1253,7 +1270,12 @@ async def _get_response() -> ChatResponse: ) start_time_stamp = perf_counter() try: - response = await super_get_response(messages=messages, stream=False, options=opts, **kwargs) + response: ChatResponse[Any] = await super_get_response( + messages=messages, + stream=False, + options=opts, + **kwargs, + ) except Exception as exception: capture_exception(span=span, exception=exception, timestamp=time_ns()) raise @@ -1267,11 +1289,15 @@ async def _get_response() -> ChatResponse: duration=duration, ) if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and response.messages: + finish_reason = cast( + "FinishReason | None", + response.finish_reason if response.finish_reason in FINISH_REASON_MAP else None, + ) _capture_messages( span=span, provider_name=provider_name, messages=response.messages, - finish_reason=response.finish_reason, + finish_reason=finish_reason, output=True, ) return response # type: ignore[return-value,no-any-return] @@ -1305,7 +1331,10 @@ async def get_embeddings( ) -> GeneratedEmbeddings[EmbeddingT]: """Trace embedding generation with OpenTelemetry spans and metrics.""" global OBSERVABILITY_SETTINGS - super_get_embeddings = super().get_embeddings # type: ignore[misc] + super_get_embeddings = cast( + "Callable[..., Awaitable[GeneratedEmbeddings[EmbeddingT]]]", + super().get_embeddings, # type: ignore[misc] + ) if not OBSERVABILITY_SETTINGS.ENABLED: return await super_get_embeddings(values, options=options) # type: ignore[no-any-return] @@ -1325,14 +1354,16 @@ async def get_embeddings( with _get_span(attributes=attributes, span_name_attribute=OtelAttr.REQUEST_MODEL) as span: start_time_stamp = 
perf_counter() try: - result = await super_get_embeddings(values, options=options) + result: GeneratedEmbeddings[EmbeddingT] = await super_get_embeddings(values, options=options) except Exception as exception: capture_exception(span=span, exception=exception, timestamp=time_ns()) raise duration = perf_counter() - start_time_stamp response_attributes: dict[str, Any] = {**attributes} - if result.usage and "prompt_tokens" in result.usage: - response_attributes[OtelAttr.INPUT_TOKENS] = result.usage["prompt_tokens"] + usage = cast(Mapping[str, Any], result.usage) if result.usage else None + prompt_tokens = usage.get("prompt_tokens") if usage is not None else None + if prompt_tokens is not None: + response_attributes[OtelAttr.INPUT_TOKENS] = prompt_tokens _capture_response( span=span, attributes=response_attributes, @@ -1391,7 +1422,12 @@ def run( ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: """Trace agent runs with OpenTelemetry spans and metrics.""" global OBSERVABILITY_SETTINGS - super_run = super().run # type: ignore[misc] + from ._types import ResponseStream, merge_chat_options + + super_run = cast( + "Callable[..., Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]]", + super().run, # type: ignore[misc] + ) provider_name = str(self.otel_provider_name) capture_usage = bool(getattr(self, "_otel_capture_usage", True)) @@ -1403,8 +1439,6 @@ def run( **kwargs, ) - from ._types import ResponseStream, merge_chat_options - default_options = getattr(self, "default_options", {}) options = kwargs.get("options") merged_options: dict[str, Any] = merge_chat_options(default_options, options or {}) @@ -1420,16 +1454,21 @@ def run( ) if stream: - run_result = super_run( + run_result: object = super_run( messages=messages, stream=True, session=session, **kwargs, ) if isinstance(run_result, ResponseStream): - result_stream = run_result + result_stream: ResponseStream[AgentResponseUpdate, 
AgentResponse[Any]] = run_result # pyright: ignore[reportUnknownVariableType] elif isinstance(run_result, Awaitable): - result_stream = ResponseStream.from_awaitable(run_result) + result_stream = cast( + ResponseStream[AgentResponseUpdate, AgentResponse[Any]], + ResponseStream.from_awaitable( # pyright: ignore[reportUnknownMemberType] + cast(Awaitable[ResponseStream[AgentResponseUpdate, AgentResponse[Any]]], run_result) + ), + ) else: raise RuntimeError("Streaming telemetry requires a ResponseStream result.") @@ -1466,7 +1505,7 @@ async def _finalize_stream() -> None: from ._types import AgentResponse try: - response = await result_stream.get_final_response() + response: AgentResponse[Any] = await result_stream.get_final_response() duration = duration_state.get("duration") response_attributes = _get_response_attributes( attributes, @@ -1492,7 +1531,9 @@ async def _finalize_stream() -> None: # Register a weak reference callback to close the span if stream is garbage collected # without being consumed. This ensures spans don't leak if users don't consume streams. 
- wrapped_stream = result_stream.with_cleanup_hook(_record_duration).with_cleanup_hook(_finalize_stream) + wrapped_stream: ResponseStream[AgentResponseUpdate, AgentResponse[Any]] = ( + result_stream.with_cleanup_hook(_record_duration).with_cleanup_hook(_finalize_stream) + ) weakref.finalize(wrapped_stream, _close_span) return wrapped_stream @@ -1507,7 +1548,7 @@ async def _run() -> AgentResponse: ) start_time_stamp = perf_counter() try: - response = await super_run( + response: AgentResponse[Any] = await super_run( messages=messages, stream=False, session=session, @@ -1598,12 +1639,18 @@ def _get_span( yield current_span -def _get_instructions_from_options(options: Any) -> str | None: +def _get_instructions_from_options(options: Any) -> str | list[str] | None: """Extract instructions from options dict.""" if options is None: return None - if isinstance(options, dict): - return options.get("instructions") + if isinstance(options, Mapping): + instructions = cast(Mapping[str, Any], options).get("instructions") + if isinstance(instructions, str): + return instructions + if isinstance(instructions, list): + if all(isinstance(item, str) for item in instructions): # pyright: ignore[reportUnknownVariableType] + return cast("list[str]", instructions) + return None return None @@ -1662,8 +1709,11 @@ def _get_span_attributes(**kwargs: Any) -> dict[str, Any]: """Get the span attributes from a kwargs dictionary.""" attributes: dict[str, Any] = {} options = kwargs.get("all_options", kwargs.get("options")) - if options is not None and not isinstance(options, dict): - options = None + options_mapping: Mapping[str, Any] | None + if isinstance(options, Mapping): + options_mapping = cast(Mapping[str, Any], options) + else: + options_mapping = None for source_keys, (otel_key, transform_func, check_options, default_value) in OTEL_ATTR_MAP.items(): # Normalize to tuple of keys @@ -1671,8 +1721,8 @@ def _get_span_attributes(**kwargs: Any) -> dict[str, Any]: value = None for key in keys: 
- if check_options and options is not None: - value = options.get(key) + if check_options and options_mapping is not None: + value = options_mapping.get(key) if value is None: value = kwargs.get(key) if value is not None: @@ -1743,7 +1793,7 @@ def _to_otel_message(message: Message) -> dict[str, Any]: def _to_otel_part(content: Content) -> dict[str, Any] | None: """Create a otel representation of a Content.""" - from ._types import _get_data_bytes_as_str + from ._types import _get_data_bytes_as_str # pyright: ignore[reportPrivateUsage] match content.type: case "text": @@ -1798,10 +1848,12 @@ def _get_response_attributes( if model_id := getattr(response, "model_id", None): attributes[OtelAttr.RESPONSE_MODEL] = model_id if capture_usage and (usage := response.usage_details): - if usage.get("input_token_count"): - attributes[OtelAttr.INPUT_TOKENS] = usage["input_token_count"] - if usage.get("output_token_count"): - attributes[OtelAttr.OUTPUT_TOKENS] = usage["output_token_count"] + input_tokens = usage.get("input_token_count") + if input_tokens: + attributes[OtelAttr.INPUT_TOKENS] = input_tokens + output_tokens = usage.get("output_token_count") + if output_tokens: + attributes[OtelAttr.OUTPUT_TOKENS] = output_tokens return attributes diff --git a/python/packages/core/agent_framework/openai/_assistant_provider.py b/python/packages/core/agent_framework/openai/_assistant_provider.py index ecf27db316..9746725128 100644 --- a/python/packages/core/agent_framework/openai/_assistant_provider.py +++ b/python/packages/core/agent_framework/openai/_assistant_provider.py @@ -3,7 +3,7 @@ from __future__ import annotations import sys -from collections.abc import Awaitable, Callable, MutableMapping, Sequence +from collections.abc import Awaitable, Callable, Mapping, MutableMapping, Sequence from typing import TYPE_CHECKING, Any, Generic, cast from openai import AsyncOpenAI @@ -149,24 +149,25 @@ def __init__( env_file_encoding=env_file_encoding, ) - if not settings["api_key"]: + 
api_key_setting = settings.get("api_key") + if not api_key_setting: raise ValueError( "OpenAI API key is required. Set via 'api_key' parameter or 'OPENAI_API_KEY' environment variable." ) # Get API key value - api_key_value: str | Callable[[], str | Awaitable[str]] | None - if isinstance(settings["api_key"], SecretString): - api_key_value = settings["api_key"].get_secret_value() + api_key_value: str | Callable[[], str | Awaitable[str]] + if isinstance(api_key_setting, SecretString): + api_key_value = api_key_setting.get_secret_value() else: - api_key_value = settings["api_key"] + api_key_value = api_key_setting # Create client client_args: dict[str, Any] = {"api_key": api_key_value} - if settings["org_id"]: - client_args["organization"] = settings["org_id"] - if settings["base_url"]: - client_args["base_url"] = settings["base_url"] + if org_id_value := settings.get("org_id"): + client_args["organization"] = org_id_value + if base_url_value := settings.get("base_url"): + client_args["base_url"] = base_url_value self._client = AsyncOpenAI(**client_args) @@ -250,7 +251,9 @@ async def create_agent( """ # Normalize tools normalized_tools = normalize_tools(tools) - assistant_tools = [tool for tool in normalized_tools if isinstance(tool, (FunctionTool, MutableMapping))] + assistant_tools: list[FunctionTool | MutableMapping[str, Any]] = [ + tool for tool in normalized_tools if isinstance(tool, (FunctionTool, MutableMapping)) + ] api_tools = to_assistant_tools(assistant_tools) if assistant_tools else [] # Extract response_format from default_options if present @@ -287,7 +290,7 @@ async def create_agent( if not self._client: raise RuntimeError("OpenAI client is not initialized.") - assistant = await self._client.beta.assistants.create(**create_params) + assistant = await self._client.beta.assistants.create(**create_params) # type: ignore[reportDeprecated] # Create Agent - pass default_options which contains response_format return self._create_chat_agent_from_assistant( @@ 
-353,7 +356,7 @@ async def get_agent( if not self._client: raise RuntimeError("OpenAI client is not initialized.") - assistant = await self._client.beta.assistants.retrieve(assistant_id) + assistant = await self._client.beta.assistants.retrieve(assistant_id) # type: ignore[reportDeprecated] # Use as_agent to wrap it return self.as_agent( @@ -466,12 +469,14 @@ def _validate_function_tools( for tool in normalized: if isinstance(tool, FunctionTool): provided_functions.add(tool.name) - elif isinstance(tool, MutableMapping) and "function" in tool: - func_spec = tool.get("function", {}) - if isinstance(func_spec, dict): - func_dict = cast(dict[str, Any], func_spec) - if "name" in func_dict: - provided_functions.add(str(func_dict["name"])) + elif isinstance(tool, Mapping): + typed_tool = cast(Mapping[str, Any], tool) + raw_func_spec = typed_tool.get("function") + if isinstance(raw_func_spec, Mapping): + typed_func_spec = cast(Mapping[str, Any], raw_func_spec) + raw_name = typed_func_spec.get("name") + if isinstance(raw_name, str) and raw_name: + provided_functions.add(raw_name) # Check for missing functions missing = required_functions - provided_functions diff --git a/python/packages/core/agent_framework/openai/_assistants_client.py b/python/packages/core/agent_framework/openai/_assistants_client.py index 17b801a36a..b90935a33f 100644 --- a/python/packages/core/agent_framework/openai/_assistants_client.py +++ b/python/packages/core/agent_framework/openai/_assistants_client.py @@ -360,23 +360,26 @@ class MyOptions(OpenAIAssistantsOptions, total=False): env_file_encoding=env_file_encoding, ) - if not async_client and not openai_settings["api_key"]: + api_key_value = openai_settings.get("api_key") + if not async_client and not api_key_value: raise ValueError( "OpenAI API key is required. Set via 'api_key' parameter or 'OPENAI_API_KEY' environment variable." 
) - if not openai_settings["chat_model_id"]: + + chat_model_id = openai_settings.get("chat_model_id") + if not chat_model_id: raise ValueError( "OpenAI model ID is required. " "Set via 'model_id' parameter or 'OPENAI_CHAT_MODEL_ID' environment variable." ) super().__init__( - model_id=openai_settings["chat_model_id"], - api_key=self._get_api_key(openai_settings["api_key"]), - org_id=openai_settings["org_id"], + model_id=chat_model_id, + api_key=self._get_api_key(api_key_value), + org_id=openai_settings.get("org_id"), default_headers=default_headers, client=async_client, - base_url=openai_settings["base_url"], + base_url=openai_settings.get("base_url"), middleware=middleware, function_invocation_configuration=function_invocation_configuration, ) @@ -403,7 +406,7 @@ async def close(self) -> None: """Clean up any assistants we created.""" if self._should_delete_assistant and self.assistant_id is not None: client = await self._ensure_client() - await client.beta.assistants.delete(self.assistant_id) + await client.beta.assistants.delete(self.assistant_id) # type: ignore[reportDeprecated] object.__setattr__(self, "assistant_id", None) object.__setattr__(self, "_should_delete_assistant", False) @@ -466,7 +469,7 @@ async def _get_assistant_id_or_create(self) -> str: raise ValueError("Parameter 'model_id' is required for assistant creation.") client = await self._ensure_client() - created_assistant = await client.beta.assistants.create( + created_assistant = await client.beta.assistants.create( # type: ignore[reportDeprecated] model=self.model_id, description=self.assistant_description, name=self.assistant_name, @@ -568,7 +571,8 @@ async def _process_stream_events(self, stream: Any, thread_id: str) -> AsyncIter if isinstance(delta_block, TextDeltaBlock) and delta_block.text and delta_block.text.value: text_content = Content.from_text(delta_block.text.value) if delta_block.text.annotations: - text_content.annotations = [] + annotations: list[Annotation] = [] + 
text_content.annotations = annotations for annotation in delta_block.text.annotations: if isinstance(annotation, FileCitationDeltaAnnotation): ann: Annotation = Annotation( @@ -589,7 +593,7 @@ async def _process_stream_events(self, stream: Any, thread_id: str) -> AsyncIter end_index=annotation.end_index, ) ] - text_content.annotations.append(ann) + annotations.append(ann) elif isinstance(annotation, FilePathDeltaAnnotation): ann = Annotation( type="citation", @@ -609,7 +613,7 @@ async def _process_stream_events(self, stream: Any, thread_id: str) -> AsyncIter end_index=annotation.end_index, ) ] - text_content.annotations.append(ann) + annotations.append(ann) yield ChatResponseUpdate( role=role, # type: ignore[arg-type] contents=[text_content], @@ -628,7 +632,8 @@ async def _process_stream_events(self, stream: Any, thread_id: str) -> AsyncIter continue text_content = Content.from_text(block.text.value) if block.text.annotations: - text_content.annotations = [] + completed_annotations: list[Annotation] = [] + text_content.annotations = completed_annotations for completed_annotation in block.text.annotations: if isinstance(completed_annotation, FileCitationAnnotation): props: dict[str, Any] = { @@ -823,15 +828,16 @@ def _prepare_options( tool_definitions.append(tool.to_json_schema_spec()) # type: ignore[reportUnknownArgumentType] elif isinstance(tool, MutableMapping): # Pass through dict-based tools directly (from static factory methods) - tool_definitions.append(tool) + tool_definitions.append(cast(MutableMapping[str, Any], tool)) if len(tool_definitions) > 0: run_options["tools"] = tool_definitions if tool_mode is not None: - if (mode := tool_mode["mode"]) == "required" and ( - func_name := tool_mode.get("required_function_name") - ) is not None: + mode = tool_mode.get("mode") + if mode is None: + raise ValueError("tool_choice mode is required") + if mode == "required" and (func_name := tool_mode.get("required_function_name")) is not None: run_options["tool_choice"] 
= { "type": "function", "function": {"name": func_name}, diff --git a/python/packages/core/agent_framework/openai/_chat_client.py b/python/packages/core/agent_framework/openai/_chat_client.py index 0c3d346129..0214c8df20 100644 --- a/python/packages/core/agent_framework/openai/_chat_client.py +++ b/python/packages/core/agent_framework/openai/_chat_client.py @@ -15,7 +15,7 @@ ) from datetime import datetime, timezone from itertools import chain -from typing import Any, Generic, Literal +from typing import Any, Generic, Literal, cast from openai import AsyncOpenAI, BadRequestError from openai.lib._parsing._completions import type_to_response_format_param @@ -301,11 +301,16 @@ def _prepare_tools_for_openai( for tool in normalize_tools(tools): if isinstance(tool, FunctionTool): chat_tools.append(tool.to_json_schema_spec()) - elif isinstance(tool, MutableMapping) and tool.get("type") == "web_search": - # Web search is handled via web_search_options, not tools array - web_search_options = {k: v for k, v in tool.items() if k != "type"} + elif isinstance(tool, MutableMapping): + typed_tool = cast(MutableMapping[str, Any], tool) + if typed_tool.get("type") == "web_search": + # Web search is handled via web_search_options, not tools array + web_search_options = {k: v for k, v in typed_tool.items() if k != "type"} + else: + # Pass through all other dict-based tools unchanged + chat_tools.append(typed_tool) else: - # Pass through all other tools (dicts, SDK types) unchanged + # Pass through all other tools (SDK types) unchanged chat_tools.append(tool) result: dict[str, Any] = {} if chat_tools: @@ -608,10 +613,21 @@ def _prepare_message_for_openai(self, message: Message) -> list[dict[str, Any]]: # See https://github.com/microsoft/agent-framework/issues/4084 for msg in all_messages: msg_content: Any = msg.get("content") - if isinstance(msg_content, list) and all( - isinstance(c, dict) and c.get("type") == "text" for c in msg_content - ): - msg["content"] = 
"\n".join(c.get("text", "") for c in msg_content) + if isinstance(msg_content, list): + typed_msg_content = cast(list[object], msg_content) + text_items: list[Mapping[str, Any]] = [] + for item in typed_msg_content: + if not isinstance(item, Mapping): + break + text_item = cast(Mapping[str, Any], item) + if text_item.get("type") != "text": + break + text_items.append(text_item) + else: + msg["content"] = "\n".join( + text_item.get("text", "") if isinstance(text_item.get("text", ""), str) else "" + for text_item in text_items + ) return all_messages @@ -775,21 +791,26 @@ class MyOptions(OpenAIChatOptions, total=False): env_file_encoding=env_file_encoding, ) - if not async_client and not openai_settings["api_key"]: + api_key_value = openai_settings.get("api_key") + if not async_client and not api_key_value: raise ValueError( "OpenAI API key is required. Set via 'api_key' parameter or 'OPENAI_API_KEY' environment variable." ) - if not openai_settings["chat_model_id"]: + + chat_model_id = openai_settings.get("chat_model_id") + if not chat_model_id: raise ValueError( "OpenAI model ID is required. " "Set via 'model_id' parameter or 'OPENAI_CHAT_MODEL_ID' environment variable." 
) + base_url_value = openai_settings.get("base_url") + super().__init__( - model_id=openai_settings["chat_model_id"], - api_key=self._get_api_key(openai_settings["api_key"]), - base_url=openai_settings["base_url"] if openai_settings["base_url"] else None, - org_id=openai_settings["org_id"], + model_id=chat_model_id, + api_key=self._get_api_key(api_key_value), + base_url=base_url_value if base_url_value else None, + org_id=openai_settings.get("org_id"), default_headers=default_headers, client=async_client, instruction_role=instruction_role, diff --git a/python/packages/core/agent_framework/openai/_embedding_client.py b/python/packages/core/agent_framework/openai/_embedding_client.py index fb479c181c..e730bf62d3 100644 --- a/python/packages/core/agent_framework/openai/_embedding_client.py +++ b/python/packages/core/agent_framework/openai/_embedding_client.py @@ -6,7 +6,7 @@ import struct import sys from collections.abc import Awaitable, Callable, Mapping, Sequence -from typing import Any, Generic, Literal, TypedDict +from typing import Any, Generic, Literal, TypedDict, cast from openai import AsyncOpenAI @@ -81,7 +81,7 @@ async def get_embeddings( ValueError: If model_id is not provided or values is empty. 
""" if not values: - return GeneratedEmbeddings([], options=options) + return cast(GeneratedEmbeddings[list[float]], GeneratedEmbeddings([], options=options)) opts: dict[str, Any] = dict(options) if options else {} model = opts.get("model_id") or self.model_id @@ -123,7 +123,7 @@ async def get_embeddings( "total_token_count": response.usage.total_tokens, } - return GeneratedEmbeddings(embeddings, options=options, usage=usage_dict) + return cast(GeneratedEmbeddings[list[float]], GeneratedEmbeddings(embeddings, options=options, usage=usage_dict)) class OpenAIEmbeddingClient( @@ -193,21 +193,26 @@ def __init__( env_file_encoding=env_file_encoding, ) - if not async_client and not openai_settings["api_key"]: + api_key_value = openai_settings.get("api_key") + if not async_client and not api_key_value: raise ValueError( "OpenAI API key is required. Set via 'api_key' parameter or 'OPENAI_API_KEY' environment variable." ) - if not openai_settings["embedding_model_id"]: + + embedding_model_id = openai_settings.get("embedding_model_id") + if not embedding_model_id: raise ValueError( "OpenAI embedding model ID is required. " "Set via 'model_id' parameter or 'OPENAI_EMBEDDING_MODEL_ID' environment variable." 
) + base_url_value = openai_settings.get("base_url") + super().__init__( - model_id=openai_settings["embedding_model_id"], - api_key=self._get_api_key(openai_settings["api_key"]), - base_url=openai_settings["base_url"] if openai_settings["base_url"] else None, - org_id=openai_settings["org_id"], + model_id=embedding_model_id, + api_key=self._get_api_key(api_key_value), + base_url=base_url_value if base_url_value else None, + org_id=openai_settings.get("org_id"), default_headers=default_headers, client=async_client, otel_provider_name=otel_provider_name, diff --git a/python/packages/core/agent_framework/openai/_responses_client.py b/python/packages/core/agent_framework/openai/_responses_client.py index f11b60b767..b4fb1cbe1c 100644 --- a/python/packages/core/agent_framework/openai/_responses_client.py +++ b/python/packages/core/agent_framework/openai/_responses_client.py @@ -2293,24 +2293,26 @@ class MyOptions(OpenAIResponsesOptions, total=False): env_file_encoding=env_file_encoding, ) - if not async_client and not openai_settings["api_key"]: + api_key_setting = openai_settings.get("api_key") + if not async_client and not api_key_setting: raise ValueError( "OpenAI API key is required. Set via 'api_key' parameter or 'OPENAI_API_KEY' environment variable." ) - if not openai_settings["responses_model_id"]: + responses_model_id = openai_settings.get("responses_model_id") + if not responses_model_id: raise ValueError( "OpenAI model ID is required. " "Set via 'model_id' parameter or 'OPENAI_RESPONSES_MODEL_ID' environment variable." 
) super().__init__( - model_id=openai_settings["responses_model_id"], - api_key=self._get_api_key(openai_settings["api_key"]), - org_id=openai_settings["org_id"], + model_id=responses_model_id, + api_key=self._get_api_key(api_key_setting), + org_id=openai_settings.get("org_id"), default_headers=default_headers, client=async_client, instruction_role=instruction_role, - base_url=openai_settings["base_url"], + base_url=openai_settings.get("base_url"), middleware=middleware, function_invocation_configuration=function_invocation_configuration, **kwargs, diff --git a/python/packages/core/agent_framework/openai/_shared.py b/python/packages/core/agent_framework/openai/_shared.py index 67f0e91818..9817b7fb11 100644 --- a/python/packages/core/agent_framework/openai/_shared.py +++ b/python/packages/core/agent_framework/openai/_shared.py @@ -6,7 +6,7 @@ import sys from collections.abc import Awaitable, Callable, Mapping, MutableMapping, Sequence from copy import copy -from typing import Any, ClassVar, Union +from typing import Any, ClassVar, Union, cast import openai from openai import ( @@ -332,8 +332,10 @@ def from_assistant_tools( for tool in assistant_tools: if hasattr(tool, "type"): tool_type = tool.type - elif isinstance(tool, dict): - tool_type = tool.get("type") + elif isinstance(tool, Mapping): + typed_tool = cast(Mapping[str, Any], tool) + tool_type_value: Any = typed_tool.get("type") + tool_type = tool_type_value if isinstance(tool_type_value, str) else None else: tool_type = None diff --git a/python/packages/core/pyproject.toml b/python/packages/core/pyproject.toml index 5a0b3d8c2d..c9708a084b 100644 --- a/python/packages/core/pyproject.toml +++ b/python/packages/core/pyproject.toml @@ -109,6 +109,7 @@ include = ["tests/workflow"] [tool.mypy] plugins = ['pydantic.mypy'] strict = true +incremental = false python_version = "3.10" ignore_missing_imports = true disallow_untyped_defs = true diff --git a/python/packages/declarative/agent_framework_declarative/_loader.py 
b/python/packages/declarative/agent_framework_declarative/_loader.py index 79bedb657d..ebbc61280d 100644 --- a/python/packages/declarative/agent_framework_declarative/_loader.py +++ b/python/packages/declarative/agent_framework_declarative/_loader.py @@ -34,7 +34,7 @@ RemoteConnection, Tool, WebSearchTool, - _safe_mode_context, + _safe_mode_context, # pyright: ignore[reportPrivateUsage] agent_schema_dispatch, ) @@ -598,6 +598,9 @@ async def _create_agent_with_provider(self, prompt_agent: PromptAgent, mapping: case ApiKeyConnection(): if prompt_agent.model.connection.endpoint: provider_kwargs["project_endpoint"] = prompt_agent.model.connection.endpoint + case ReferenceConnection(): + # Reference connections are resolved by concrete providers when supported. + pass # Create the provider and use it to create the agent provider = provider_class(**provider_kwargs) diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py index 01a68e6a8e..87cdb7dac7 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py @@ -121,10 +121,12 @@ def _make_powerfx_safe(value: Any) -> Any: return value if isinstance(value, dict): - return {k: _make_powerfx_safe(v) for k, v in value.items()} + value_dict = cast(Mapping[Any, Any], value) + return {str(k): _make_powerfx_safe(v) for k, v in value_dict.items()} if isinstance(value, list): - return [_make_powerfx_safe(item) for item in value] + value_list = cast(list[Any], value) # type: ignore[redundant-cast] + return [_make_powerfx_safe(item) for item in value_list] # Try to convert objects with __dict__ or dataclass-style attributes if hasattr(value, "__dict__"): @@ -385,7 +387,9 @@ def eval(self, expression: str) -> Any: engine = Engine() symbols = self._to_powerfx_symbols() 
try: - from System.Globalization import CultureInfo + from System.Globalization import ( + CultureInfo, # pyright: ignore[reportMissingImports, reportUnknownVariableType] + ) original_culture = CultureInfo.CurrentCulture original_ui_culture = CultureInfo.CurrentUICulture @@ -424,7 +428,7 @@ def _eval_custom_function(self, formula: str) -> Any | None: args_str = match.group(1) # Parse comma-separated arguments (handling nested parentheses) args = self._parse_function_args(args_str) - evaluated_args = [] + evaluated_args: list[str] = [] for arg in args: arg = arg.strip() if arg.startswith('"') and arg.endswith('"'): @@ -576,37 +580,44 @@ def _eval_and_replace_message_text(self, inner_expr: str) -> str: """ messages: Any = self.eval(f"={inner_expr}") if isinstance(messages, list) and messages: - last_msg: Any = messages[-1] + message_list = cast(list[Any], messages) # type: ignore[redundant-cast] + last_msg: Any = message_list[-1] if isinstance(last_msg, dict): + last_msg_dict = cast(dict[str, Any], last_msg) # Try "text" key first (simple dict format) - if "text" in last_msg: - return str(last_msg["text"]) + if "text" in last_msg_dict: + return str(last_msg_dict["text"]) # Try extracting from "contents" (Message dict format) # Message.text concatenates text from all TextContent items - contents = last_msg.get("contents", []) - if isinstance(contents, list): - text_parts = [] + contents_obj = last_msg_dict.get("contents", []) + if isinstance(contents_obj, list): + contents = cast(list[Any], contents_obj) # type: ignore[redundant-cast] + text_parts: list[str] = [] for content in contents: if isinstance(content, dict): + content_dict = cast(dict[str, Any], content) # TextContent has a "text" key - if content.get("type") == "text" or "text" in content: - text_parts.append(str(content.get("text", ""))) - elif hasattr(content, "text"): - text_parts.append(str(getattr(content, "text", ""))) + if content_dict.get("type") == "text" or "text" in content_dict: + 
text_parts.append(str(content_dict.get("text", ""))) + else: + content_obj: object = content + if hasattr(content_obj, "text"): + text_parts.append(str(getattr(content_obj, "text", ""))) if text_parts: return " ".join(text_parts) return "" - if hasattr(last_msg, "text"): - return str(getattr(last_msg, "text", "")) + last_msg_obj: object = last_msg + if hasattr(last_msg_obj, "text"): + return str(getattr(last_msg_obj, "text", "")) return "" def _parse_function_args(self, args_str: str) -> list[str]: """Parse comma-separated function arguments, handling nested parentheses and strings.""" - args = [] - current = [] + args: list[str] = [] + current: list[str] = [] depth = 0 in_string = False - string_char = None + string_char: str | None = None for char in args_str: if char in ('"', "'") and not in_string: diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_builder.py b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_builder.py index 65e129d921..6843c5bd92 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_builder.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_builder.py @@ -14,7 +14,7 @@ from __future__ import annotations import logging -from typing import Any +from typing import Any, cast from agent_framework import ( Workflow, @@ -983,8 +983,9 @@ def _get_branch_exit(self, branch_entry: Any) -> Any | None: last_executor = chain[-1] # Skip terminators — they handle their own control flow - action_def = getattr(last_executor, "_action_def", {}) - if isinstance(action_def, dict) and action_def.get("kind", "") in TERMINATOR_ACTIONS: + action_def_obj = getattr(last_executor, "_action_def", {}) + action_def = cast(dict[str, Any], action_def_obj) if isinstance(action_def_obj, dict) else {} + if action_def.get("kind", "") in TERMINATOR_ACTIONS: return None # Check if last executor is a structure with branch_exits diff 
--git a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_agents.py b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_agents.py index c2fded5fb8..20345bb750 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_agents.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_agents.py @@ -188,7 +188,7 @@ def _validate_conversation_history(messages: list[Message], agent_name: str) -> tool_result_ids: set[str] = set() for i, msg in enumerate(messages): - if not hasattr(msg, "contents") or msg.contents is None: + if not hasattr(msg, "contents"): continue for content in msg.contents: if content.type == "function_call" and content.call_id: diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_basic.py b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_basic.py index 4643cfd34b..1e9b4a8bc9 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_basic.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_basic.py @@ -7,7 +7,8 @@ """ import uuid -from typing import Any +from collections.abc import Mapping +from typing import Any, cast from agent_framework import ( WorkflowContext, @@ -28,9 +29,13 @@ def _get_variable_path(action_def: dict[str, Any], key: str = "variable") -> str variable = action_def.get(key) if isinstance(variable, str): return variable - if isinstance(variable, dict): - return variable.get("path") - return action_def.get("path") + if isinstance(variable, Mapping): + variable_map = cast(Mapping[str, Any], variable) + path = variable_map.get("path") + return path if isinstance(path, str) else None + + fallback_path = action_def.get("path") + return fallback_path if isinstance(fallback_path, str) else None class SetValueExecutor(DeclarativeActionExecutor): @@ -150,16 +155,23 @@ async def handle_action( 
"""Handle the SetMultipleVariables action.""" state = await self._ensure_state_initialized(ctx, trigger) - assignments = self._action_def.get("assignments", []) - for assignment in assignments: + assignments_obj = self._action_def.get("assignments", []) + assignments = cast(list[Any], assignments_obj) if isinstance(assignments_obj, list) else [] # type: ignore[redundant-cast] + for assignment_obj in assignments: + if not isinstance(assignment_obj, Mapping): + continue + assignment = cast(Mapping[str, Any], assignment_obj) variable = assignment.get("variable") path: str | None if isinstance(variable, str): path = variable - elif isinstance(variable, dict): - path = variable.get("path") + elif isinstance(variable, Mapping): + variable_map = cast(Mapping[str, Any], variable) + path_value = variable_map.get("path") + path = path_value if isinstance(path_value, str) else None else: - path = assignment.get("path") + fallback_path = assignment.get("path") + path = fallback_path if isinstance(fallback_path, str) else None value = assignment.get("value") if path: evaluated_value = state.eval_if_expression(value) @@ -249,7 +261,11 @@ async def handle_action( activity = self._action_def.get("activity", "") # Activity can be a string directly or a dict with a "text" field - text = activity.get("text", "") if isinstance(activity, dict) else activity + if isinstance(activity, Mapping): + activity_map = cast(Mapping[str, Any], activity) + text: Any = activity_map.get("text", "") + else: + text = activity if isinstance(text, str): # First evaluate any =expression syntax @@ -336,11 +352,14 @@ async def handle_action( if table_path: # Get current table value - current_table = state.get(table_path) - if current_table is None: + current_table_value = state.get(table_path) + current_table: list[Any] + if current_table_value is None: current_table = [] - elif not isinstance(current_table, list): - current_table = [current_table] + elif isinstance(current_table_value, list): + 
current_table = list(cast(list[Any], current_table_value)) # type: ignore[redundant-cast] + else: + current_table = [current_table_value] if operation == "add" or operation == "insert": evaluated_value = state.eval_if_expression(value) @@ -413,11 +432,14 @@ async def handle_action( if table_path: # Get current table value - current_table = state.get(table_path) - if current_table is None: + current_table_value = state.get(table_path) + current_table: list[Any] + if current_table_value is None: current_table = [] - elif not isinstance(current_table, list): - current_table = [current_table] + elif isinstance(current_table_value, list): + current_table = list(cast(list[Any], current_table_value)) # type: ignore[redundant-cast] + else: + current_table = [current_table_value] if operation == "add": evaluated_item = state.eval_if_expression(item) @@ -433,9 +455,12 @@ async def handle_action( evaluated_item = state.eval_if_expression(item) if key_field and isinstance(evaluated_item, dict): # Remove by key match - key_value = evaluated_item.get(key_field) + evaluated_item_dict = cast(dict[str, Any], evaluated_item) + key_value = evaluated_item_dict.get(key_field) current_table = [ - r for r in current_table if not (isinstance(r, dict) and r.get(key_field) == key_value) + r + for r in current_table + if not (isinstance(r, dict) and cast(dict[str, Any], r).get(key_field) == key_value) ] elif evaluated_item in current_table: current_table.remove(evaluated_item) @@ -451,11 +476,12 @@ async def handle_action( elif operation == "addorupdate": evaluated_item = state.eval_if_expression(item) if key_field and isinstance(evaluated_item, dict): - key_value = evaluated_item.get(key_field) + evaluated_item_dict = cast(dict[str, Any], evaluated_item) + key_value = evaluated_item_dict.get(key_field) # Find existing item with same key found_idx = -1 for i, r in enumerate(current_table): - if isinstance(r, dict) and r.get(key_field) == key_value: + if isinstance(r, dict) and cast(dict[str, 
Any], r).get(key_field) == key_value: found_idx = i break if found_idx >= 0: @@ -476,9 +502,10 @@ async def handle_action( if 0 <= idx < len(current_table): current_table[idx] = evaluated_item elif key_field and isinstance(evaluated_item, dict): - key_value = evaluated_item.get(key_field) + evaluated_item_dict = cast(dict[str, Any], evaluated_item) + key_value = evaluated_item_dict.get(key_field) for i, r in enumerate(current_table): - if isinstance(r, dict) and r.get(key_field) == key_value: + if isinstance(r, dict) and cast(dict[str, Any], r).get(key_field) == key_value: current_table[i] = evaluated_item break @@ -568,11 +595,13 @@ def _convert_to_type(self, value: Any, target_type: str) -> Any: if value is None: return {} if isinstance(value, dict): - return value + return cast(dict[str, Any], value) if isinstance(value, str): try: parsed = json.loads(value) - return parsed if isinstance(parsed, dict) else {"value": parsed} + if isinstance(parsed, dict): + return cast(dict[str, Any], parsed) + return {"value": parsed} except json.JSONDecodeError: return {"value": value} return {"value": value} @@ -581,11 +610,13 @@ def _convert_to_type(self, value: Any, target_type: str) -> Any: if value is None: return [] if isinstance(value, list): - return value + return cast(list[Any], value) # type: ignore[redundant-cast] if isinstance(value, str): try: parsed = json.loads(value) - return parsed if isinstance(parsed, list) else [parsed] + if isinstance(parsed, list): + return cast(list[Any], parsed) # type: ignore[redundant-cast] + return [parsed] except json.JSONDecodeError: return [value] return [value] diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_tools.py b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_tools.py index 829d48103f..6ef171fce5 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_tools.py +++ 
b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_tools.py @@ -17,7 +17,8 @@ from abc import abstractmethod from dataclasses import dataclass, field from inspect import isawaitable -from typing import Any +from collections.abc import Mapping +from typing import Any, cast from agent_framework import ( Content, @@ -111,6 +112,12 @@ class ToolApprovalState: # ============================================================================ + + +def _empty_messages() -> list[Message]: + return [] + + @dataclass class ToolInvocationResult: """Result from a tool invocation. @@ -127,7 +134,7 @@ class ToolInvocationResult: success: bool result: Any = None error: str | None = None - messages: list[Message] = field(default_factory=list) + messages: list[Message] = field(default_factory=_empty_messages) rejected: bool = False rejection_reason: str | None = None @@ -267,20 +274,20 @@ def _get_output_config(self) -> tuple[str | None, str | None, bool]: Returns: Tuple of (messages_var, result_var, auto_send) """ - output_config = self._action_def.get("output", {}) + output_config_obj = self._action_def.get("output", {}) - if not isinstance(output_config, dict): + if not isinstance(output_config_obj, Mapping): return None, None, True - messages_var = output_config.get("messages") - result_var = output_config.get("result") + output_config = cast(Mapping[str, Any], output_config_obj) + messages_var_obj = output_config.get("messages") + result_var_obj = output_config.get("result") auto_send = bool(output_config.get("autoSend", True)) - return ( - str(messages_var) if messages_var else None, - str(result_var) if result_var else None, - auto_send, - ) + messages_var = messages_var_obj if isinstance(messages_var_obj, str) else None + result_var = result_var_obj if isinstance(result_var_obj, str) else None + + return (messages_var, result_var, auto_send) def _store_result( self, @@ -494,7 +501,8 @@ async def handle_action( type(arguments_def).__name__, ) elif 
isinstance(arguments_def, dict): - for key, value in arguments_def.items(): + arguments_map = cast(dict[str, Any], arguments_def) + for key, value in arguments_map.items(): arguments[key] = state.eval_if_expression(value) # Check if approval is required diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_powerfx_functions.py b/python/packages/declarative/agent_framework_declarative/_workflows/_powerfx_functions.py index df66ef59fd..499e577c96 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_powerfx_functions.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_powerfx_functions.py @@ -51,7 +51,8 @@ def message_text(messages: Any) -> str: if isinstance(messages, list): # List of messages - concatenate all text texts: list[str] = [] - for msg in messages: + message_list = cast(list[Any], messages) # type: ignore[redundant-cast] + for msg in message_list: if isinstance(msg, str): texts.append(msg) elif isinstance(msg, dict): @@ -61,14 +62,16 @@ def message_text(messages: Any) -> str: texts.append(msg_content) elif msg_content: texts.append(str(msg_content)) - elif hasattr(msg, "content"): - msg_obj_content: Any = msg.content - if isinstance(msg_obj_content, str): - texts.append(msg_obj_content) - elif hasattr(msg_obj_content, "text"): - texts.append(str(msg_obj_content.text)) - elif msg_obj_content: - texts.append(str(msg_obj_content)) + else: + msg_obj: object = msg + if hasattr(msg_obj, "content"): + msg_obj_content: Any = getattr(msg_obj, "content") + if isinstance(msg_obj_content, str): + texts.append(msg_obj_content) + elif hasattr(msg_obj_content, "text"): + texts.append(str(getattr(msg_obj_content, "text"))) + elif msg_obj_content: + texts.append(str(msg_obj_content)) return " ".join(texts) # Try to get text attribute @@ -192,9 +195,11 @@ def is_blank(value: Any) -> bool: if isinstance(value, str) and not value.strip(): return True if isinstance(value, list): - return 
len(value) == 0 + value_list = cast(list[Any], value) # type: ignore[redundant-cast] + return len(value_list) == 0 if isinstance(value, dict): - return len(value) == 0 + value_dict = cast(dict[Any, Any], value) # type: ignore[redundant-cast] + return len(value_dict) == 0 return False diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_state.py b/python/packages/declarative/agent_framework_declarative/_workflows/_state.py index 7417fa26fe..fb0abc1086 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_state.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_state.py @@ -284,8 +284,9 @@ def append(self, path: str, value: Any) -> None: if existing is None: self.set(path, [value]) elif isinstance(existing, list): - existing.append(value) - self.set(path, existing) + existing_list = cast(list[Any], existing) # type: ignore[redundant-cast] + existing_list.append(value) + self.set(path, existing_list) else: raise ValueError(f"Cannot append to non-list at path '{path}'") @@ -614,9 +615,11 @@ def eval_if_expression(self, value: Any) -> Any: if isinstance(value, str): return self.eval(value) if isinstance(value, dict): - return {str(k): self.eval_if_expression(v) for k, v in value.items()} + value_dict = cast(dict[Any, Any], value) # type: ignore[redundant-cast] + return {str(k): self.eval_if_expression(v) for k, v in value_dict.items()} if isinstance(value, list): - return [self.eval_if_expression(item) for item in value] + value_list = cast(list[Any], value) # type: ignore[redundant-cast] + return [self.eval_if_expression(item) for item in value_list] return value def reset_local(self) -> None: diff --git a/python/packages/devui/agent_framework_devui/__init__.py b/python/packages/devui/agent_framework_devui/__init__.py index f703e85a63..4b1130506e 100644 --- a/python/packages/devui/agent_framework_devui/__init__.py +++ b/python/packages/devui/agent_framework_devui/__init__.py @@ -73,7 +73,7 
@@ def register_cleanup(entity: Any, *hooks: Callable[[], Any]) -> None: ) -def _get_registered_cleanup_hooks(entity: Any) -> list[Callable[[], Any]]: +def get_registered_cleanup_hooks(entity: Any) -> list[Callable[[], Any]]: """Get cleanup hooks registered for an entity (internal use). Args: @@ -86,6 +86,10 @@ def _get_registered_cleanup_hooks(entity: Any) -> list[Callable[[], Any]]: return _cleanup_registry.get(entity_id, []) +# Backward-compatible private alias +_get_registered_cleanup_hooks = get_registered_cleanup_hooks + + def serve( entities: list[Any] | None = None, entities_dir: str | None = None, @@ -193,7 +197,7 @@ def serve( if entities: logger.info(f"Registering {len(entities)} in-memory entities") # Store entities for later registration during server startup - server._pending_entities = entities + server.set_pending_entities(entities) app = server.get_app() @@ -263,5 +267,6 @@ def main() -> None: "ResponseStreamEvent", "main", "register_cleanup", + "get_registered_cleanup_hooks", "serve", ] diff --git a/python/packages/devui/agent_framework_devui/_conversations.py b/python/packages/devui/agent_framework_devui/_conversations.py index f0e91e0d87..ba2fa21586 100644 --- a/python/packages/devui/agent_framework_devui/_conversations.py +++ b/python/packages/devui/agent_framework_devui/_conversations.py @@ -300,12 +300,15 @@ async def add_items(self, conversation_id: str, items: list[dict[str, Any]]) -> stored_messages: list[Message] = conv_data["messages"] # Convert items to Messages and add to storage - chat_messages = [] + chat_messages: list[Message] = [] for item in items: # Simple conversion - assume text content for now role = item.get("role", "user") - content = item.get("content", []) - text = content[0].get("text", "") if content else "" + content_obj = item.get("content", []) + content = cast(list[dict[str, Any]], content_obj) if isinstance(content_obj, list) else [] + first_content = content[0] if content and isinstance(content[0], dict) else {} + 
text_obj = first_content.get("text", "") + text = text_obj if isinstance(text_obj, str) else str(text_obj) chat_msg = Message(role=role, text=text) # type: ignore[arg-type] chat_messages.append(chat_msg) @@ -319,15 +322,17 @@ async def add_items(self, conversation_id: str, items: list[dict[str, Any]]) -> item_id = f"item_{uuid.uuid4().hex}" # Extract role - handle both string and enum - role_str = msg.role if hasattr(msg.role, "value") else str(msg.role) + msg_role_obj: object = getattr(msg, "role", "user") + role_str = str(getattr(msg_role_obj, "value", msg_role_obj)) role = cast(MessageRole, role_str) # Safe: Agent Framework roles match OpenAI roles # Convert Message contents to OpenAI TextContent format - message_content = [] - for content_item in msg.contents: - if content_item.type == "text": + message_content: list[TextContent] = [] + for content_item in cast(list[Any], msg.contents): + if getattr(content_item, "type", None) == "text": # Extract text from TextContent object - text_value = getattr(content_item, "text", "") + text_value_obj = getattr(content_item, "text", "") + text_value = text_value_obj if isinstance(text_value_obj, str) else str(text_value_obj) message_content.append(TextContent(type="text", text=text_value)) # Create Message object (concrete type from ConversationItem union) @@ -335,7 +340,7 @@ async def add_items(self, conversation_id: str, items: list[dict[str, Any]]) -> id=item_id, type="message", # Required discriminator for union role=role, - content=message_content, + content=cast(Any, message_content), status="completed", # Required field ) conv_items.append(message) @@ -383,8 +388,8 @@ async def list_items( # A single Message may produce multiple ConversationItems # (e.g., a message with both text and a function call) message_contents: list[TextContent | ResponseInputImage | ResponseInputFile] = [] - function_calls = [] - function_results = [] + function_calls: list[ResponseFunctionToolCallItem] = [] + function_results: 
list[ResponseFunctionToolCallOutputItem] = [] for content in msg.contents: content_type = getattr(content, "type", None) @@ -628,7 +633,7 @@ def get_traces(self, conversation_id: str) -> list[dict[str, Any]]: async def list_conversations_by_metadata(self, metadata_filter: dict[str, str]) -> list[Conversation]: """Filter conversations by metadata (e.g., agent_id).""" - results = [] + results: list[Conversation] = [] for conv_data in self._conversations.values(): conv_meta = conv_data.get("metadata", {}).copy() # Copy to avoid mutating original @@ -704,7 +709,8 @@ def get_checkpoint_storage(self, conversation_id: str) -> InMemoryCheckpointStor ValueError: If conversation not found """ # Access internal conversations dict (we know it's InMemoryConversationStore) - conv_data = self._store._conversations.get(conversation_id) + conversations_dict = cast(dict[str, dict[str, Any]], getattr(self._store, "_conversations", {})) + conv_data = conversations_dict.get(conversation_id) if not conv_data: raise ValueError(f"Conversation {conversation_id} not found") diff --git a/python/packages/devui/agent_framework_devui/_deployment.py b/python/packages/devui/agent_framework_devui/_deployment.py index db2de27ecf..70d785b0bb 100644 --- a/python/packages/devui/agent_framework_devui/_deployment.py +++ b/python/packages/devui/agent_framework_devui/_deployment.py @@ -7,6 +7,7 @@ import re import secrets import uuid +from typing import cast from collections.abc import AsyncGenerator from datetime import datetime, timezone from pathlib import Path @@ -175,7 +176,7 @@ async def _validate_prerequisites(self) -> None: # Check required resource providers are registered required_providers = ["Microsoft.App", "Microsoft.ContainerRegistry", "Microsoft.OperationalInsights"] - unregistered_providers = [] + unregistered_providers: list[str] = [] # Get list of registered providers provider_check = await asyncio.create_subprocess_exec( @@ -195,7 +196,12 @@ async def _validate_prerequisites(self) -> 
None: import json try: - registered = json.loads(stdout.decode()) + registered_raw = json.loads(stdout.decode()) + registered: list[str] = [] + if isinstance(registered_raw, list): + for item_obj in cast(list[object], registered_raw): + if isinstance(item_obj, str): + registered.append(item_obj) for provider in required_providers: if provider not in registered: unregistered_providers.append(provider) @@ -385,7 +391,7 @@ async def _deploy_to_azure( ) # Stream output line by line - output_lines = [] + output_lines: list[str] = [] try: if not process.stdout: raise ValueError("Failed to capture process output") @@ -473,8 +479,11 @@ async def _deploy_to_azure( for url in urls: # Strip common trailing punctuation to ensure clean URL parsing url_clean = url.rstrip(".,;:!?'\")}]") - host = urlparse(url_clean).hostname - if host and (host == "azurecontainerapps.io" or host.endswith(".azurecontainerapps.io")): + parsed_url = urlparse(str(url_clean)) + host = parsed_url.hostname + if isinstance(host, str) and ( + host == "azurecontainerapps.io" or host.endswith(".azurecontainerapps.io") + ): await event_queue.put( DeploymentEvent(type="deploy.progress", message="Deployment URL generated!") ) diff --git a/python/packages/devui/agent_framework_devui/_discovery.py b/python/packages/devui/agent_framework_devui/_discovery.py index a5fada1ba9..5aad165571 100644 --- a/python/packages/devui/agent_framework_devui/_discovery.py +++ b/python/packages/devui/agent_framework_devui/_discovery.py @@ -11,7 +11,7 @@ import sys import uuid from pathlib import Path -from typing import Any +from typing import Any, cast from dotenv import load_dotenv @@ -141,9 +141,9 @@ async def load_entity(self, entity_id: str, checkpoint_manager: Any = None) -> A self._loaded_objects[entity_id] = entity_obj # Check module-level registry for cleanup hooks - from . import _get_registered_cleanup_hooks + from . 
import get_registered_cleanup_hooks - registered_hooks = _get_registered_cleanup_hooks(entity_obj) + registered_hooks = get_registered_cleanup_hooks(entity_obj) if registered_hooks: if entity_id not in self._cleanup_hooks: self._cleanup_hooks[entity_id] = [] @@ -299,9 +299,9 @@ def register_entity(self, entity_id: str, entity_info: EntityInfo, entity_object self._loaded_objects[entity_id] = entity_object # Check module-level registry for cleanup hooks - from . import _get_registered_cleanup_hooks + from . import get_registered_cleanup_hooks - registered_hooks = _get_registered_cleanup_hooks(entity_object) + registered_hooks = get_registered_cleanup_hooks(entity_object) if registered_hooks: if entity_id not in self._cleanup_hooks: self._cleanup_hooks[entity_id] = [] @@ -379,6 +379,8 @@ async def create_entity_info_from_object( deployment_supported = True deployment_reason = "Ready for deployment (pending path verification)" + class_name = type(entity_object).__name__ + # Create EntityInfo with Agent Framework specifics return EntityInfo( id=entity_id, @@ -400,9 +402,7 @@ async def create_entity_info_from_object( deployment_reason=deployment_reason, metadata={ "source": "agent_framework_object", - "class_name": entity_object.__class__.__name__ - if hasattr(entity_object, "__class__") - else str(type(entity_object)), + "class_name": class_name, }, ) @@ -854,7 +854,7 @@ async def _register_entity_from_object( "module_path": module_path, "entity_type": obj_type, "source": source, - "class_name": obj.__class__.__name__ if hasattr(obj, "__class__") else str(type(obj)), + "class_name": type(obj).__name__, }, ) @@ -874,47 +874,59 @@ async def _extract_tools_from_object(self, obj: Any, obj_type: str) -> list[str] Returns: List of tool/executor names """ - tools = [] + tools: list[str] = [] try: if obj_type == "agent": - # For agents, check default_options.get("tools") chat_options = getattr(obj, "default_options", None) - chat_options_tools = None - if chat_options: - 
chat_options_tools = chat_options.get("tools") - - if chat_options_tools: - for tool in chat_options_tools: - if hasattr(tool, "__name__"): - tools.append(tool.__name__) - elif hasattr(tool, "name"): - tools.append(tool.name) + chat_options_tools: object | None = None + if isinstance(chat_options, dict): + chat_options_dict = cast(dict[str, Any], chat_options) + chat_options_tools = chat_options_dict.get("tools") + + if chat_options_tools is not None: + tool_iterable: list[object] = cast(list[object], chat_options_tools) if isinstance(chat_options_tools, list) else [chat_options_tools] + for tool_obj in tool_iterable: + tool_name = getattr(tool_obj, "__name__", None) + if isinstance(tool_name, str): + tools.append(tool_name) + continue + + named_tool = getattr(tool_obj, "name", None) + if isinstance(named_tool, str): + tools.append(named_tool) else: - tools.append(str(tool)) + tools.append(str(tool_obj)) else: - # Fallback to direct tools attribute agent_tools = getattr(obj, "tools", None) - if agent_tools: - for tool in agent_tools: - if hasattr(tool, "__name__"): - tools.append(tool.__name__) - elif hasattr(tool, "name"): - tools.append(tool.name) + if isinstance(agent_tools, list): + for tool_obj in cast(list[object], agent_tools): + tool_name = getattr(tool_obj, "__name__", None) + if isinstance(tool_name, str): + tools.append(tool_name) + continue + + named_tool = getattr(tool_obj, "name", None) + if isinstance(named_tool, str): + tools.append(named_tool) else: - tools.append(str(tool)) + tools.append(str(tool_obj)) elif obj_type == "workflow": - # For workflows, extract executor names if hasattr(obj, "get_executors_list"): executor_objects = obj.get_executors_list() - tools = [getattr(ex, "id", str(ex)) for ex in executor_objects] + if isinstance(executor_objects, list): + for executor_obj in cast(list[object], executor_objects): + tools.append(str(getattr(executor_obj, "id", executor_obj))) elif hasattr(obj, "executors"): executors = obj.executors if 
isinstance(executors, list): - tools = [getattr(ex, "id", str(ex)) for ex in executors] + for executor_obj in cast(list[object], executors): + tools.append(str(getattr(executor_obj, "id", executor_obj))) elif isinstance(executors, dict): - tools = list(executors.keys()) + executors_dict = cast(dict[str, Any], executors) + for key_obj in executors_dict.keys(): + tools.append(str(key_obj)) except Exception as e: logger.debug(f"Error extracting tools from {obj_type} {type(obj)}: {e}") diff --git a/python/packages/devui/agent_framework_devui/_executor.py b/python/packages/devui/agent_framework_devui/_executor.py index 1b1b77162a..516ef6d8de 100644 --- a/python/packages/devui/agent_framework_devui/_executor.py +++ b/python/packages/devui/agent_framework_devui/_executor.py @@ -7,7 +7,7 @@ import json import logging from collections.abc import AsyncGenerator -from typing import Any +from typing import Any, cast from agent_framework import Content, SupportsAgentRun, Workflow @@ -24,7 +24,8 @@ def _get_event_type(event: Any) -> str | None: """Safely get the type of an event, handling both objects and dicts.""" if isinstance(event, dict): - return event.get("type") + event_type = cast(dict[str, Any], event).get("type") + return event_type if isinstance(event_type, str) else None return getattr(event, "type", None) @@ -71,7 +72,8 @@ def _setup_instrumentation_provider(self) -> None: from opentelemetry.sdk.trace import TracerProvider # Only set up if no provider exists yet - if not hasattr(trace, "_TRACER_PROVIDER") or trace._TRACER_PROVIDER is None: + current_provider = trace.get_tracer_provider() + if current_provider.__class__.__name__ == "ProxyTracerProvider": resource = Resource.create({ "service.name": "agent-framework-server", "service.version": "1.0.0", @@ -94,21 +96,29 @@ def _setup_agent_framework_instrumentation(self) -> None: # Configure if instrumentation is enabled (via enable_instrumentation() or env var) if OBSERVABILITY_SETTINGS.ENABLED: - # Only configure 
providers if not already executed - if not OBSERVABILITY_SETTINGS._executed_setup: - # Call configure_otel_providers to set up exporters. - # If OTEL_EXPORTER_OTLP_ENDPOINT is set, exporters will be created automatically. - # If not set, no exporters are created (no console spam), but DevUI's - # TracerProvider from _setup_instrumentation_provider() remains active for local capture. - configure_otel_providers(enable_sensitive_data=OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED) - logger.info("Enabled Agent Framework observability") - else: - logger.debug("Agent Framework observability already configured") + # Call configure_otel_providers to set up exporters. + # If OTEL_EXPORTER_OTLP_ENDPOINT is set, exporters will be created automatically. + # If not set, no exporters are created (no console spam), but DevUI's + # TracerProvider from _setup_instrumentation_provider() remains active for local capture. + configure_otel_providers(enable_sensitive_data=OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED) + logger.info("Enabled Agent Framework observability") else: logger.debug("Instrumentation not enabled, skipping observability setup") except Exception as e: logger.warning(f"Failed to enable Agent Framework observability: {e}") + def _get_request_conversation_id(self, request: AgentFrameworkRequest) -> str | None: + """Read conversation id using public request fields.""" + if isinstance(request.conversation, str): + return request.conversation + + if isinstance(request.conversation, dict): + conversation_id = request.conversation.get("id") + if isinstance(conversation_id, str): + return conversation_id + + return None + async def _ensure_mcp_connections(self, agent: Any) -> None: """Ensure MCP tool connections are healthy before agent execution. @@ -317,7 +327,7 @@ async def _execute_agent( # Get session from conversation parameter (OpenAI standard!) 
session = None - conversation_id = request._get_conversation_id() + conversation_id = self._get_request_conversation_id(request) if conversation_id: session = self.conversation_store.get_session(conversation_id) if session: @@ -344,7 +354,7 @@ async def _execute_agent( if session: run_kwargs["session"] = session - stream = agent.run(user_message, **run_kwargs) + stream = cast(Any, agent.run(user_message, **run_kwargs)) async for update in stream: for trace_event in trace_collector.get_pending_events(): yield trace_event @@ -388,7 +398,7 @@ async def _execute_workflow( entity_id = request.get_entity_id() or "unknown" # Get or create session conversation for checkpoint storage - conversation_id = request._get_conversation_id() + conversation_id = self._get_request_conversation_id(request) if not conversation_id: # Create default session if not provided import time @@ -463,11 +473,14 @@ async def _execute_workflow( logger.info(f"Resuming workflow with HIL responses for {len(hil_responses)} request(s)") # Unwrap primitive responses if they're wrapped in {response: value} format - unwrapped_responses = {} + unwrapped_responses: dict[str, Any] = {} for request_id, response_value in hil_responses.items(): - if isinstance(response_value, dict) and "response" in response_value: - response_value = response_value["response"] - unwrapped_responses[request_id] = response_value + normalized_response: Any = response_value + if isinstance(response_value, dict): + response_dict = cast(dict[str, Any], response_value) + if "response" in response_dict: + normalized_response = response_dict["response"] + unwrapped_responses[request_id] = normalized_response hil_responses = unwrapped_responses @@ -568,7 +581,8 @@ def _convert_input_to_chat_message(self, input_data: Any) -> Any: # Handle OpenAI ResponseInputParam (List[ResponseInputItemParam]) if isinstance(input_data, list): - return self._convert_openai_input_to_chat_message(input_data, Message, Role) + input_items: Any = cast(Any, 
input_data) + return self._convert_openai_input_to_chat_message(input_items, Message, Role) # Fallback for other formats return self._extract_user_message_fallback(input_data) @@ -593,27 +607,31 @@ def _convert_openai_input_to_chat_message(self, input_items: list[Any], Message: for item in input_items: # Handle dict format (from JSON) if isinstance(item, dict): - item_type = item.get("type") + item_dict = cast(dict[str, Any], item) + item_type = item_dict.get("type") if item_type == "message": # Extract content from OpenAI message - message_content = item.get("content", []) + message_content = item_dict.get("content", []) # Handle both string content and list content if isinstance(message_content, str): contents.append(Content.from_text(text=message_content)) elif isinstance(message_content, list): - for content_item in message_content: + message_content_items: Any = cast(Any, message_content) + for content_item in message_content_items: # Handle dict content items if isinstance(content_item, dict): - content_type = content_item.get("type") + content_dict = cast(dict[str, Any], content_item) + content_type = content_dict.get("type") if content_type == "input_text": - text = content_item.get("text", "") - contents.append(Content.from_text(text=text)) + text = content_dict.get("text", "") + if isinstance(text, str): + contents.append(Content.from_text(text=text)) elif content_type == "input_image": - image_url = content_item.get("image_url", "") - if image_url: + image_url = content_dict.get("image_url", "") + if isinstance(image_url, str) and image_url: # Extract media type from data URI if possible # Parse media type from data URL, fallback to image/png if image_url.startswith("data:"): @@ -631,9 +649,12 @@ def _convert_openai_input_to_chat_message(self, input_items: list[Any], Message: elif content_type == "input_file": # Handle file input - file_data = content_item.get("file_data") - file_url = content_item.get("file_url") - filename = 
content_item.get("filename", "") + file_data = content_dict.get("file_data") + file_url = content_dict.get("file_url") + filename = content_dict.get("filename", "") + + if not isinstance(filename, str): + filename = "" # Determine media type from filename media_type = "application/octet-stream" # default @@ -656,8 +677,8 @@ def _convert_openai_input_to_chat_message(self, input_items: list[Any], Message: # Use file_data or file_url # Include filename in additional_properties for OpenAI/Azure file handling - additional_props = {"filename": filename} if filename else None - if file_data: + additional_props: dict[str, Any] | None = {"filename": filename} if filename else None + if isinstance(file_data, str) and file_data: # Assume file_data is base64, create data URI data_uri = f"data:{media_type};base64,{file_data}" contents.append( @@ -667,7 +688,7 @@ def _convert_openai_input_to_chat_message(self, input_items: list[Any], Message: additional_properties=additional_props, ) ) - elif file_url: + elif isinstance(file_url, str) and file_url: contents.append( Content.from_uri( uri=file_url, @@ -679,15 +700,35 @@ def _convert_openai_input_to_chat_message(self, input_items: list[Any], Message: elif content_type == "function_approval_response": # Handle function approval response (DevUI extension) try: - request_id = content_item.get("request_id", "") - approved = content_item.get("approved", False) - function_call_data = content_item.get("function_call", {}) + request_id = content_dict.get("request_id", "") + approved = content_dict.get("approved", False) + function_call_data = content_dict.get("function_call", {}) + + if not isinstance(request_id, str): + request_id = "" + if not isinstance(approved, bool): + approved = False + if not isinstance(function_call_data, dict): + function_call_data = {} + + function_call_data_dict = cast(dict[str, Any], function_call_data) + + function_call_id = function_call_data_dict.get("id", "") + function_call_name = 
function_call_data_dict.get("name", "") + function_call_args = function_call_data_dict.get("arguments", {}) + + if not isinstance(function_call_id, str): + function_call_id = "" + if not isinstance(function_call_name, str): + function_call_name = "" + if not isinstance(function_call_args, dict): + function_call_args = {} # Create FunctionCallContent from the function_call data function_call = Content.from_function_call( - call_id=function_call_data.get("id", ""), - name=function_call_data.get("name", ""), - arguments=function_call_data.get("arguments", {}), + call_id=function_call_id, + name=function_call_name, + arguments=cast(dict[str, Any], function_call_args), ) # Create FunctionApprovalResponseContent with correct signature @@ -739,12 +780,14 @@ def _extract_user_message_fallback(self, input_data: Any) -> str: if isinstance(input_data, str): return input_data if isinstance(input_data, dict): + typed_input_data = cast(dict[str, Any], input_data) # Try common field names for field in ["message", "text", "input", "content", "query"]: - if field in input_data: - return str(input_data[field]) + if field in typed_input_data: + value = typed_input_data[field] + return value if isinstance(value, str) else str(value) # Fallback to JSON string - return json.dumps(input_data) + return json.dumps(typed_input_data) return str(input_data) def _is_openai_multimodal_format(self, input_data: Any) -> bool: @@ -758,8 +801,12 @@ def _is_openai_multimodal_format(self, input_data: Any) -> bool: """ if not isinstance(input_data, list) or not input_data: return False - first_item = input_data[0] - return isinstance(first_item, dict) and first_item.get("type") == "message" + input_data_items: Any = cast(Any, input_data) + first_item = input_data_items[0] + if not isinstance(first_item, dict): + return False + first_type = cast(dict[str, Any], first_item).get("type") + return isinstance(first_type, str) and first_type == "message" async def _parse_workflow_input(self, workflow: Any, 
raw_input: Any) -> Any: """Parse input based on workflow's expected input type. @@ -775,7 +822,7 @@ async def _parse_workflow_input(self, workflow: Any, raw_input: Any) -> Any: # Handle JSON string input (from frontend api.ts JSON.stringify) if isinstance(raw_input, str): try: - parsed = json.loads(raw_input) + parsed: Any = json.loads(raw_input) raw_input = parsed except (json.JSONDecodeError, TypeError): # Plain text string, continue with string handling @@ -789,14 +836,14 @@ async def _parse_workflow_input(self, workflow: Any, raw_input: Any) -> Any: # Handle structured input (dict) if isinstance(raw_input, dict): - return self._parse_structured_workflow_input(workflow, raw_input) + return self._parse_structured_workflow_input(workflow, cast(dict[str, Any], raw_input)) # Handle string input return self._parse_raw_workflow_input(workflow, str(raw_input)) except Exception as e: logger.warning(f"Error parsing workflow input: {e}") - return raw_input + return cast(Any, raw_input) def _get_start_executor_message_types(self, workflow: Any) -> tuple[Any | None, list[Any]]: """Return start executor and its declared input types.""" @@ -823,7 +870,8 @@ def _get_start_executor_message_types(self, workflow: Any) -> tuple[Any | None, try: handlers = start_executor._handlers if isinstance(handlers, dict): - message_types = list(handlers.keys()) + handlers_dict: Any = cast(Any, handlers) + message_types = list(handlers_dict.keys()) except Exception as exc: # pragma: no cover - defensive logging path logger.debug(f"Failed to read executor handlers: {exc}") @@ -847,7 +895,8 @@ def _extract_workflow_hil_responses(self, input_data: Any) -> dict[str, Any] | N parsed = json.loads(input_data) # Only use parsed value if it's a list (ResponseInputParam format expected for HIL) if isinstance(parsed, list): - input_data = parsed + parsed_list: Any = cast(Any, parsed) + input_data = parsed_list else: # Parsed to dict, string, or primitive - not HIL response format return None @@ -864,19 
+913,32 @@ def _extract_workflow_hil_responses(self, input_data: Any) -> dict[str, Any] | N if not isinstance(input_data, list): return None - for item in input_data: - if isinstance(item, dict) and item.get("type") == "message": - message_content = item.get("content", []) + input_items: Any = cast(Any, input_data) + for item in input_items: + if isinstance(item, dict): + item_dict = cast(dict[str, Any], item) + if item_dict.get("type") != "message": + continue + message_content = item_dict.get("content", []) if isinstance(message_content, list): - for content_item in message_content: + message_content_items: Any = cast(Any, message_content) + for content_item in message_content_items: if isinstance(content_item, dict): - content_type = content_item.get("type") + content_dict = cast(dict[str, Any], content_item) + content_type = content_dict.get("type") if content_type == "workflow_hil_response": # Extract responses dict - # dict.get() returns Any, so we explicitly type it - responses: dict[str, Any] = content_item.get("responses", {}) # type: ignore[assignment] + responses_raw = content_dict.get("responses", {}) + if not isinstance(responses_raw, dict): + continue + + responses_dict: Any = cast(Any, responses_raw) + responses = { + str(response_key): response_value + for response_key, response_value in responses_dict.items() + } logger.info(f"Found workflow HIL responses: {list(responses.keys())}") return responses @@ -1000,11 +1062,12 @@ def _enrich_request_info_event_with_response_schema(self, event: Any, workflow: return # Find the source executor in the workflow - if not hasattr(workflow, "executors") or not isinstance(workflow.executors, dict): + executors = getattr(workflow, "executors", None) + if not isinstance(executors, dict): logger.debug("Workflow doesn't have executors dict") return - source_executor = workflow.executors.get(source_executor_id) + source_executor = cast(dict[str, Any], executors).get(source_executor_id) if not source_executor: 
logger.debug(f"Could not find executor '{source_executor_id}' in workflow") return diff --git a/python/packages/devui/agent_framework_devui/_mapper.py b/python/packages/devui/agent_framework_devui/_mapper.py index bcb99634cb..59728dd03c 100644 --- a/python/packages/devui/agent_framework_devui/_mapper.py +++ b/python/packages/devui/agent_framework_devui/_mapper.py @@ -11,7 +11,7 @@ from collections import OrderedDict from collections.abc import Sequence from datetime import datetime -from typing import Any, Union +from typing import Any, Union, cast from uuid import uuid4 from agent_framework import Content, Message @@ -61,6 +61,21 @@ ] +def _to_str_dict(value: Any) -> dict[str, Any] | None: + """Convert arbitrary dict-like payload to a string-keyed dictionary.""" + if not isinstance(value, dict): + return None + source: dict[str, Any] = cast(dict[str, Any], value) + result: dict[str, Any] = {} + for key_obj, val_obj in source.items(): + result[str(key_obj)] = val_obj + return result + + +def _stringify_name(value: Any) -> str: + return value if isinstance(value, str) else str(value) + + def _serialize_content_recursive(value: Any) -> Any: """Recursively serialize Agent Framework Content objects to JSON-compatible values. 
@@ -88,16 +103,21 @@ def _serialize_content_recursive(value: Any) -> Any: # Handle dictionaries - recursively process values if isinstance(value, dict): - return {key: _serialize_content_recursive(val) for key, val in value.items()} + value_dict = cast(dict[str, Any], value) + return {str(key): _serialize_content_recursive(val) for key, val in value_dict.items()} # Handle lists and tuples - recursively process elements if isinstance(value, (list, tuple)): - serialized = [_serialize_content_recursive(item) for item in value] + sequence_items: Any = cast(Any, value) + serialized: list[Any] = [_serialize_content_recursive(item) for item in sequence_items] # For single-item lists containing text Content, extract just the text # This handles the MCP case where result = [Content.from_text(text="Hello")] # and we want output = "Hello" not output = '[{"type": "text", "text": "Hello"}]' - if len(serialized) == 1 and isinstance(serialized[0], dict) and serialized[0].get("type") == "text": - return serialized[0].get("text", "") + if len(serialized) == 1: + first_item = _to_str_dict(serialized[0]) + if first_item and first_item.get("type") == "text": + text_value = first_item.get("text", "") + return text_value if isinstance(text_value, str) else str(text_value) return serialized # For other objects with model_dump(), try that @@ -156,8 +176,10 @@ async def convert_event(self, raw_event: Any, request: AgentFrameworkRequest) -> context = self._get_or_create_context(request) # Handle error events - if isinstance(raw_event, dict) and raw_event.get("type") == "error": - return [await self._create_error_event(raw_event.get("message", "Unknown error"), context)] + raw_event_dict = _to_str_dict(raw_event) + if raw_event_dict and raw_event_dict.get("type") == "error": + message = raw_event_dict.get("message", "Unknown error") + return [await self._create_error_event(_stringify_name(message), context)] # Handle ResponseTraceEvent objects from our trace collector from .models import 
ResponseTraceEvent @@ -185,15 +207,12 @@ async def convert_event(self, raw_event: Any, request: AgentFrameworkRequest) -> # Handle WorkflowEvent with type='output' or 'data' wrapping AgentResponseUpdate # This must be checked BEFORE generic WorkflowEvent check # Note: AgentExecutor uses type='output' for streaming updates - if ( - isinstance(raw_event, WorkflowEvent) - and raw_event.type in ("output", "data") - and raw_event.data - and isinstance(raw_event.data, AgentResponseUpdate) - ): - # Preserve executor_id in context for proper output routing - context["current_executor_id"] = raw_event.executor_id - return await self._convert_agent_update(raw_event.data, context) + if isinstance(raw_event, WorkflowEvent) and raw_event.type in ("output", "data"): + event_data = getattr(cast(Any, raw_event), "data", None) + if isinstance(event_data, AgentResponseUpdate): + # Preserve executor_id in context for proper output routing + context["current_executor_id"] = getattr(cast(Any, raw_event), "executor_id", None) + return await self._convert_agent_update(event_data, context) # Handle complete agent response (AgentResponse) - for non-streaming agent execution if isinstance(raw_event, AgentResponse): @@ -210,10 +229,11 @@ async def convert_event(self, raw_event: Any, request: AgentFrameworkRequest) -> except ImportError as e: logger.warning(f"Could not import Agent Framework types: {e}") # Fallback to attribute-based detection - if hasattr(raw_event, "contents"): - return await self._convert_agent_update(raw_event, context) - if hasattr(raw_event, "__class__") and "Event" in raw_event.__class__.__name__: - return await self._convert_workflow_event(raw_event, context) + candidate_event = cast(Any, raw_event) + if hasattr(candidate_event, "contents"): + return await self._convert_agent_update(candidate_event, context) + if "Event" in type(candidate_event).__name__: + return await self._convert_workflow_event(candidate_event, context) # Unknown event type return [await 
self._create_unknown_event(raw_event, context)] @@ -256,32 +276,36 @@ async def aggregate_to_response(self, events: Sequence[Any], request: AgentFrame item = getattr(event, "item", None) if item: # Handle both object and dict formats - item_type = item.get("type") if isinstance(item, dict) else getattr(item, "type", None) + item_dict = _to_str_dict(item) + item_type = item_dict.get("type") if item_dict is not None else getattr(item, "type", None) # Track function calls to accumulate their arguments if item_type == "function_call": # Handle both object and dict formats - if isinstance(item, dict): - call_id = item.get("call_id") or item.get("id") - if call_id: + item_dict = _to_str_dict(item) + if item_dict is not None: + call_id_value = item_dict.get("call_id") or item_dict.get("id") + if call_id_value: + call_id = str(call_id_value) function_calls[call_id] = { - "id": item.get("id", call_id), + "id": str(item_dict.get("id", call_id)), "call_id": call_id, - "name": item.get("name", ""), - "arguments": item.get("arguments", ""), + "name": _stringify_name(item_dict.get("name", "")), + "arguments": _stringify_name(item_dict.get("arguments", "")), "type": "function_call", - "status": item.get("status", "completed"), + "status": _stringify_name(item_dict.get("status", "completed")), } else: - call_id = getattr(item, "call_id", None) or getattr(item, "id", None) - if call_id: + call_id_value = getattr(item, "call_id", None) or getattr(item, "id", None) + if call_id_value: + call_id = str(call_id_value) function_calls[call_id] = { - "id": getattr(item, "id", call_id), + "id": str(getattr(item, "id", call_id)), "call_id": call_id, - "name": getattr(item, "name", ""), - "arguments": getattr(item, "arguments", ""), + "name": _stringify_name(getattr(item, "name", "")), + "arguments": _stringify_name(getattr(item, "arguments", "")), "type": "function_call", - "status": getattr(item, "status", "completed"), + "status": _stringify_name(getattr(item, "status", "completed")), } # 
Other output items (message, etc.) - track for later @@ -299,8 +323,9 @@ async def aggregate_to_response(self, events: Sequence[Any], request: AgentFrame # Handle function result complete events elif event_type == "response.function_result.complete": - call_id = getattr(event, "call_id", None) - if call_id: + call_id_value = getattr(event, "call_id", None) + if call_id_value: + call_id = str(call_id_value) function_results[call_id] = { "type": "function_call_output", "call_id": call_id, @@ -322,7 +347,7 @@ async def aggregate_to_response(self, events: Sequence[Any], request: AgentFrame # Build final text message from accumulated deltas # Combine all text parts (usually there's just one message) - all_text_parts = [] + all_text_parts: list[str] = [] for _item_id, parts in text_parts_by_message.items(): all_text_parts.extend(parts) @@ -493,14 +518,14 @@ def _serialize_value(self, value: Any) -> Any: return value.value # Handle lists/tuples/sets - recursively serialize elements - if isinstance(value, (list, tuple)): - return [self._serialize_value(item) for item in value] - if isinstance(value, set): - return [self._serialize_value(item) for item in value] + if isinstance(value, (list, tuple, set)): + value_items: Any = cast(Any, value) + return [self._serialize_value(item) for item in value_items] # Handle dicts - recursively serialize values if isinstance(value, dict): - return {k: self._serialize_value(v) for k, v in value.items()} + value_dict = cast(dict[str, Any], value) + return {str(k): self._serialize_value(v) for k, v in value_dict.items()} # Handle SerializationMixin (like Message) - call to_dict() if hasattr(value, "to_dict") and callable(getattr(value, "to_dict", None)): @@ -551,14 +576,15 @@ def _serialize_request_data(self, request_data: Any) -> dict[str, Any]: # Handle dict first (most common) if isinstance(request_data, dict): - return {k: self._serialize_value(v) for k, v in request_data.items()} + request_dict = cast(dict[str, Any], request_data) + 
return {str(k): self._serialize_value(v) for k, v in request_dict.items()} # Handle dataclasses with nested SerializationMixin objects # We can't use asdict() directly because it doesn't handle Message if is_dataclass(request_data) and not isinstance(request_data, type): try: # Manually serialize each field to handle nested SerializationMixin - result = {} + result: dict[str, Any] = {} for field in fields(request_data): field_value = getattr(request_data, field.name) result[field.name] = self._serialize_value(field_value) @@ -900,8 +926,9 @@ async def _convert_workflow_event(self, event: Any, context: dict[str, Any]) -> text = str(output_data) elif isinstance(output_data, list): # Handle list of Message objects (from Magentic yield_output([final_answer])) - text_parts = [] - for item in output_data: + text_parts: list[str] = [] + output_items_list: Any = cast(Any, output_data) + for item in output_items_list: if isinstance(item, Message): item_text = getattr(item, "text", None) if item_text: @@ -912,17 +939,17 @@ async def _convert_workflow_event(self, event: Any, context: dict[str, Any]) -> text_parts.append(item) else: try: - text_parts.append(json.dumps(item, indent=2)) + text_parts.append(json.dumps(self._serialize_value(item), indent=2)) except (TypeError, ValueError): text_parts.append(str(item)) - text = "\n".join(text_parts) if text_parts else str(output_data) + text = "\n".join(text_parts) if text_parts else str(cast(Any, output_data)) elif isinstance(output_data, str): # String output text = output_data else: # Object/dict → JSON string try: - text = json.dumps(output_data, indent=2) + text = json.dumps(self._serialize_value(output_data), indent=2) except (TypeError, ValueError): # Fallback to string representation if not JSON serializable text = str(output_data) @@ -1420,10 +1447,10 @@ async def _map_usage_content(self, content: Any, context: dict[str, Any]) -> Non None - no event emitted (usage goes in final Response.usage) """ # Extract usage from 
UsageContent.usage_details (UsageDetails object) - details = content.usage_details or {} - total_tokens = details.get("total_token_count", 0) - prompt_tokens = details.get("input_token_count", 0) - completion_tokens = details.get("output_token_count", 0) + details = _to_str_dict(getattr(content, "usage_details", None)) or {} + total_tokens = int(details.get("total_token_count", 0) or 0) + prompt_tokens = int(details.get("input_token_count", 0) or 0) + completion_tokens = int(details.get("output_token_count", 0) or 0) # Accumulate for final Response.usage request_id = context.get("request_id", "default") diff --git a/python/packages/devui/agent_framework_devui/_openai/_executor.py b/python/packages/devui/agent_framework_devui/_openai/_executor.py index 986d2d3a84..c62e7acc98 100644 --- a/python/packages/devui/agent_framework_devui/_openai/_executor.py +++ b/python/packages/devui/agent_framework_devui/_openai/_executor.py @@ -11,7 +11,7 @@ import logging import os from collections.abc import AsyncGenerator -from typing import Any +from typing import Any, cast from openai import APIStatusError, AsyncOpenAI, AsyncStream, AuthenticationError, PermissionDeniedError, RateLimitError from openai.types.responses import Response, ResponseStreamEvent @@ -22,6 +22,28 @@ logger = logging.getLogger(__name__) +def _extract_error_details(body: object) -> tuple[str | None, str | None, str | None]: + """Extract typed OpenAI error fields from error body payload.""" + if not isinstance(body, dict): + return None, None, None + + body_dict = cast(dict[str, object], body) + error_obj = body_dict.get("error") + if not isinstance(error_obj, dict): + return None, None, None + + error_dict = cast(dict[str, object], error_obj) + message = error_dict.get("message") + error_type = error_dict.get("type") + code = error_dict.get("code") + + return ( + message if isinstance(message, str) else None, + error_type if isinstance(error_type, str) else None, + code if isinstance(code, str) else None, + ) 
+ + class OpenAIExecutor: """Executor for OpenAI Responses API - mirrors AgentFrameworkExecutor interface. @@ -138,68 +160,64 @@ async def execute_streaming(self, request: AgentFrameworkRequest) -> AsyncGenera except AuthenticationError as e: # 401 - Invalid API key or authentication issue logger.error(f"OpenAI authentication error: {e}", exc_info=True) - error_body = e.body if hasattr(e, "body") else {} - error_data = error_body.get("error", {}) if isinstance(error_body, dict) else {} + message, error_type, code = _extract_error_details(e.body if hasattr(e, "body") else None) yield { "type": "response.failed", "response": { "id": f"resp_{os.urandom(16).hex()}", "status": "failed", "error": { - "message": error_data.get("message", str(e)), - "type": error_data.get("type", "authentication_error"), - "code": error_data.get("code", "invalid_api_key"), + "message": message or str(e), + "type": error_type or "authentication_error", + "code": code or "invalid_api_key", }, }, } except PermissionDeniedError as e: # 403 - Permission denied logger.error(f"OpenAI permission denied: {e}", exc_info=True) - error_body = e.body if hasattr(e, "body") else {} - error_data = error_body.get("error", {}) if isinstance(error_body, dict) else {} + message, error_type, code = _extract_error_details(e.body if hasattr(e, "body") else None) yield { "type": "response.failed", "response": { "id": f"resp_{os.urandom(16).hex()}", "status": "failed", "error": { - "message": error_data.get("message", str(e)), - "type": error_data.get("type", "permission_denied"), - "code": error_data.get("code", "insufficient_permissions"), + "message": message or str(e), + "type": error_type or "permission_denied", + "code": code or "insufficient_permissions", }, }, } except RateLimitError as e: # 429 - Rate limit exceeded logger.error(f"OpenAI rate limit exceeded: {e}", exc_info=True) - error_body = e.body if hasattr(e, "body") else {} - error_data = error_body.get("error", {}) if isinstance(error_body, dict) 
else {} + message, error_type, code = _extract_error_details(e.body if hasattr(e, "body") else None) yield { "type": "response.failed", "response": { "id": f"resp_{os.urandom(16).hex()}", "status": "failed", "error": { - "message": error_data.get("message", str(e)), - "type": error_data.get("type", "rate_limit_error"), - "code": error_data.get("code", "rate_limit_exceeded"), + "message": message or str(e), + "type": error_type or "rate_limit_error", + "code": code or "rate_limit_exceeded", }, }, } except APIStatusError as e: # Other OpenAI API errors logger.error(f"OpenAI API error: {e}", exc_info=True) - error_body = e.body if hasattr(e, "body") else {} - error_data = error_body.get("error", {}) if isinstance(error_body, dict) else {} + message, error_type, code = _extract_error_details(e.body if hasattr(e, "body") else None) yield { "type": "response.failed", "response": { "id": f"resp_{os.urandom(16).hex()}", "status": "failed", "error": { - "message": error_data.get("message", str(e)), - "type": error_data.get("type", "api_error"), - "code": error_data.get("code", "unknown_error"), + "message": message or str(e), + "type": error_type or "api_error", + "code": code or "unknown_error", }, }, } diff --git a/python/packages/devui/agent_framework_devui/_server.py b/python/packages/devui/agent_framework_devui/_server.py index e7994d3d3b..f8d862b731 100644 --- a/python/packages/devui/agent_framework_devui/_server.py +++ b/python/packages/devui/agent_framework_devui/_server.py @@ -31,6 +31,28 @@ logger = logging.getLogger(__name__) + +def _extract_error_details(body: object) -> tuple[str | None, str | None, str | None]: + """Extract typed OpenAI-style error payload fields.""" + if not isinstance(body, dict): + return None, None, None + + body_dict = cast(dict[str, object], body) + error_obj = body_dict.get("error") + if not isinstance(error_obj, dict): + return None, None, None + + error_dict = cast(dict[str, object], error_obj) + message = error_dict.get("message") + 
error_type = error_dict.get("type") + code = error_dict.get("code") + + return ( + message if isinstance(message, str) else None, + error_type if isinstance(error_type, str) else None, + code if isinstance(code, str) else None, + ) + # Get package version try: __version__ = importlib.metadata.version("agent-framework-devui") @@ -83,6 +105,10 @@ def __init__( self._pending_entities: list[Any] | None = None self._running_tasks: dict[str, asyncio.Task[Any]] = {} # Track running response tasks for cancellation + def set_pending_entities(self, entities: list[Any]) -> None: + """Set in-memory entities to register on startup.""" + self._pending_entities = entities + def _is_dev_mode(self) -> bool: """Check if running in developer mode. @@ -378,6 +404,8 @@ async def auth_middleware(request: Request, call_next: Callable[[Request], Await # Token valid, proceed return await call_next(request) + _ = auth_middleware + self._register_routes(app) self._mount_ui(app) @@ -452,7 +480,7 @@ async def get_entity_info(entity_id: str) -> EntityInfo: if entity_info.type == "workflow" and entity_obj: # Entity object already loaded by load_entity() above # Get workflow structure - workflow_dump = None + workflow_dump: dict[str, Any] | str | None = None if hasattr(entity_obj, "to_dict") and callable(getattr(entity_obj, "to_dict", None)): try: workflow_dump = entity_obj.to_dict() # type: ignore[attr-defined] @@ -475,7 +503,11 @@ async def get_entity_info(entity_id: str) -> EntityInfo: except Exception: workflow_dump = raw_dump else: - workflow_dump = parsed_dump if isinstance(parsed_dump, dict) else raw_dump + if isinstance(parsed_dump, dict): + parsed_dump_dict = cast(dict[str, Any], parsed_dump) + workflow_dump = {str(k): v for k, v in parsed_dump_dict.items()} + else: + workflow_dump = raw_dump else: workflow_dump = raw_dump elif hasattr(entity_obj, "__dict__"): @@ -838,34 +870,31 @@ async def create_conversation(raw_request: Request) -> dict[str, Any] | JSONResp except AuthenticationError 
as e: # 401 - Invalid API key or authentication issue logger.error(f"OpenAI authentication error creating conversation: {e}") - error_body = e.body if hasattr(e, "body") else {} - error_data = error_body.get("error", {}) if isinstance(error_body, dict) else {} + message, error_type, code = _extract_error_details(e.body if hasattr(e, "body") else None) error = OpenAIError.create( - message=error_data.get("message", str(e)), - type=error_data.get("type", "authentication_error"), - code=error_data.get("code", "invalid_api_key"), + message=message or str(e), + type=error_type or "authentication_error", + code=code or "invalid_api_key", ) return JSONResponse(status_code=401, content=error.to_dict()) except PermissionDeniedError as e: # 403 - Permission denied logger.error(f"OpenAI permission denied creating conversation: {e}") - error_body = e.body if hasattr(e, "body") else {} - error_data = error_body.get("error", {}) if isinstance(error_body, dict) else {} + message, error_type, code = _extract_error_details(e.body if hasattr(e, "body") else None) error = OpenAIError.create( - message=error_data.get("message", str(e)), - type=error_data.get("type", "permission_denied"), - code=error_data.get("code", "insufficient_permissions"), + message=message or str(e), + type=error_type or "permission_denied", + code=code or "insufficient_permissions", ) return JSONResponse(status_code=403, content=error.to_dict()) except APIStatusError as e: # Other OpenAI API errors (rate limit, etc.) 
logger.error(f"OpenAI API error creating conversation: {e}") - error_body = e.body if hasattr(e, "body") else {} - error_data = error_body.get("error", {}) if isinstance(error_body, dict) else {} + message, error_type, code = _extract_error_details(e.body if hasattr(e, "body") else None) error = OpenAIError.create( - message=error_data.get("message", str(e)), - type=error_data.get("type", "api_error"), - code=error_data.get("code", "unknown_error"), + message=message or str(e), + type=error_type or "api_error", + code=code or "unknown_error", ) return JSONResponse( status_code=e.status_code if hasattr(e, "status_code") else 500, content=error.to_dict() @@ -902,7 +931,7 @@ async def list_conversations( executor = await self._ensure_executor() # Build filter criteria - filters = {} + filters: dict[str, str] = {} if agent_id: filters["agent_id"] = agent_id if entity_id: @@ -997,15 +1026,16 @@ async def list_conversation_items( conversation_id, limit=limit, after=after, order=order ) # Handle both Pydantic models and dicts (some stores return raw dicts) - serialized_items = [] + serialized_items: list[dict[str, Any]] = [] for item in items: if hasattr(item, "model_dump"): serialized_items.append(item.model_dump()) elif isinstance(item, dict): - serialized_items.append(item) + item_dict = cast(dict[str, Any], item) + serialized_items.append({str(k): v for k, v in item_dict.items()}) else: logger.warning(f"Unexpected item type: {type(item)}, converting to dict") - serialized_items.append(dict(item)) + serialized_items.append({str(k): v for k, v in dict(item).items()}) # Get stored traces for context inspection (DevUI extension) traces = executor.conversation_store.get_traces(conversation_id) @@ -1038,9 +1068,14 @@ async def retrieve_conversation_item(conversation_id: str, item_id: str) -> dict if not item: raise HTTPException(status_code=404, detail="Item not found") # Handle both Pydantic models and dicts - result: dict[str, Any] = ( - item.model_dump() if hasattr(item, 
"model_dump") else cast(dict[str, Any], item) - ) + result: dict[str, Any] + if hasattr(item, "model_dump"): + result = item.model_dump() + elif isinstance(item, dict): + item_dict = cast(dict[str, Any], item) + result = {str(k): v for k, v in item_dict.items()} + else: + result = {"value": item} return result except HTTPException: raise @@ -1085,16 +1120,42 @@ async def delete_conversation_item(conversation_id: str, item_id: str) -> dict[s # Checkpoints are exposed as conversation items with type="checkpoint" # ============================================================================ + _registered_route_handlers = ( + health_check, + get_meta, + discover_entities, + get_entity_info, + reload_entity, + create_deployment, + list_deployments, + get_deployment, + delete_deployment, + deploy_entity, + create_response, + cancel_response, + create_conversation, + list_conversations, + retrieve_conversation, + update_conversation, + delete_conversation, + create_conversation_items, + list_conversation_items, + retrieve_conversation_item, + delete_conversation_item, + ) + _ = _registered_route_handlers + async def _stream_execution( self, executor: AgentFrameworkExecutor, request: AgentFrameworkRequest ) -> AsyncGenerator[str]: """Stream execution directly through executor.""" try: # Collect events for final response.completed event - events = [] + events: list[Any] = [] # Get conversation_id for trace storage - conversation_id = request._get_conversation_id() + conversation_getter = getattr(request, "_get_conversation_id", None) + conversation_id = conversation_getter() if callable(conversation_getter) else None # Stream all events async for event in executor.execute_streaming(request): @@ -1105,7 +1166,8 @@ async def _stream_execution( try: trace_data = event.data if hasattr(event, "data") else None if trace_data: - executor.conversation_store.add_trace(conversation_id, trace_data) + if isinstance(conversation_id, str): + 
executor.conversation_store.add_trace(conversation_id, trace_data) except Exception as e: logger.debug(f"Failed to store trace event: {e}") @@ -1136,8 +1198,9 @@ async def _stream_execution( # We need to increment from that last_seq = 0 for event in reversed(events): - if hasattr(event, "sequence_number") and event.sequence_number is not None: - last_seq = event.sequence_number + sequence_number = getattr(event, "sequence_number", None) + if isinstance(sequence_number, int): + last_seq = sequence_number break completed_event = ResponseCompletedEvent( diff --git a/python/packages/devui/agent_framework_devui/_session.py b/python/packages/devui/agent_framework_devui/_session.py index 5cabeee072..0f53e09e3d 100644 --- a/python/packages/devui/agent_framework_devui/_session.py +++ b/python/packages/devui/agent_framework_devui/_session.py @@ -5,13 +5,37 @@ import logging import uuid from datetime import datetime -from typing import Any +from typing import Any, TypedDict, cast + +from typing_extensions import NotRequired logger = logging.getLogger(__name__) -# Type aliases for better readability -SessionData = dict[str, Any] -RequestRecord = dict[str, Any] + +class RequestRecord(TypedDict): + """Tracked execution request data.""" + + id: str + timestamp: datetime + entity_id: str + executor: str + input: Any + model_id: str + stream: bool + execution_time: NotRequired[float] + status: NotRequired[str] + + +class SessionData(TypedDict): + """Stored session state.""" + + id: str + created_at: datetime + requests: list[RequestRecord] + context: dict[str, Any] + active: bool + + SessionSummary = dict[str, Any] @@ -95,7 +119,7 @@ def add_request_record( "stream": True, } session["requests"].append(request_record) - return str(request_record["id"]) + return request_record["id"] def update_request_record(self, session_id: str, request_id: str, updates: dict[str, Any]) -> None: """Update a request record in a session. 
@@ -111,7 +135,8 @@ def update_request_record(self, session_id: str, request_id: str, updates: dict[ for request in session["requests"]: if request["id"] == request_id: - request.update(updates) + request_data = cast(dict[str, Any], request) + request_data.update(updates) break def get_session_history(self, session_id: str) -> SessionSummary | None: @@ -138,7 +163,7 @@ def get_session_history(self, session_id: str) -> SessionSummary | None: "timestamp": req["timestamp"].isoformat(), "entity_id": req["entity_id"], "executor": req["executor"], - "model": req["model"], + "model": req["model_id"], "input_length": len(str(req["input"])) if req["input"] else 0, "execution_time": req.get("execution_time"), "status": req.get("status", "unknown"), @@ -153,20 +178,22 @@ def get_active_sessions(self) -> list[SessionSummary]: Returns: List of active session summaries """ - active_sessions = [] + active_sessions: list[SessionSummary] = [] for session_id, session in self.sessions.items(): if session["active"]: - active_sessions.append({ - "session_id": session_id, - "created_at": session["created_at"].isoformat(), - "request_count": len(session["requests"]), - "last_activity": ( - session["requests"][-1]["timestamp"].isoformat() - if session["requests"] - else session["created_at"].isoformat() - ), - }) + active_sessions.append( + { + "session_id": session_id, + "created_at": session["created_at"].isoformat(), + "request_count": len(session["requests"]), + "last_activity": ( + session["requests"][-1]["timestamp"].isoformat() + if session["requests"] + else session["created_at"].isoformat() + ), + } + ) return active_sessions @@ -178,7 +205,7 @@ def cleanup_old_sessions(self, max_age_hours: int = 24) -> None: """ cutoff_time = datetime.now().timestamp() - (max_age_hours * 3600) - sessions_to_remove = [] + sessions_to_remove: list[str] = [] for session_id, session in self.sessions.items(): if session["created_at"].timestamp() < cutoff_time: sessions_to_remove.append(session_id) 
diff --git a/python/packages/devui/agent_framework_devui/_utils.py b/python/packages/devui/agent_framework_devui/_utils.py index 66886b8ea7..bdf76f1ec5 100644 --- a/python/packages/devui/agent_framework_devui/_utils.py +++ b/python/packages/devui/agent_framework_devui/_utils.py @@ -7,12 +7,20 @@ import logging from dataclasses import fields, is_dataclass from types import UnionType -from typing import Any, Union, get_args, get_origin, get_type_hints +from typing import Any, Union, cast, get_args, get_origin, get_type_hints from agent_framework import Message logger = logging.getLogger(__name__) + +def _string_key_dict(value: object) -> dict[str, Any] | None: + if not isinstance(value, dict): + return None + + source: dict[str, Any] = cast(dict[str, Any], value) + return {str(k): v for k, v in source.items()} + # ============================================================================ # Agent Metadata Extraction # ============================================================================ @@ -39,18 +47,21 @@ def extract_agent_metadata(entity_object: Any) -> dict[str, Any]: # Try to get instructions if hasattr(entity_object, "default_options"): chat_opts = entity_object.default_options - if isinstance(chat_opts, dict): - if "instructions" in chat_opts: - metadata["instructions"] = chat_opts.get("instructions") + chat_opts_dict = _string_key_dict(chat_opts) + if chat_opts_dict is not None: + if "instructions" in chat_opts_dict: + metadata["instructions"] = chat_opts_dict.get("instructions") elif hasattr(chat_opts, "instructions"): metadata["instructions"] = chat_opts.instructions # Try to get model - check both default_options and client if hasattr(entity_object, "default_options"): chat_opts = entity_object.default_options - if isinstance(chat_opts, dict): - if chat_opts.get("model_id"): - metadata["model"] = chat_opts.get("model_id") + chat_opts_dict = _string_key_dict(chat_opts) + if chat_opts_dict is not None: + model_id = chat_opts_dict.get("model_id") + if 
model_id: + metadata["model"] = model_id elif hasattr(chat_opts, "model_id") and chat_opts.model_id: metadata["model"] = chat_opts.model_id if metadata["model"] is None and hasattr(entity_object, "client") and hasattr(entity_object.client, "model_id"): @@ -111,8 +122,9 @@ def extract_executor_message_types(executor: Any) -> list[Any]: if not message_types and hasattr(executor, "_handlers"): try: handlers = executor._handlers - if isinstance(handlers, dict): - message_types = list(handlers.keys()) + handlers_dict = _string_key_dict(handlers) + if handlers_dict is not None: + message_types = list(handlers_dict.keys()) except Exception as exc: # pragma: no cover - defensive logging path logger.debug(f"Failed to read executor handlers: {exc}") @@ -366,11 +378,10 @@ async def handler(self, original_request: RequestType, response: ResponseType, c _, second_param_type = param_items[1] if len(param_items) > 1 else (None, None) # Check if first param matches request_type - first_matches_request = first_param_type == request_type or ( - hasattr(first_param_type, "__name__") - and hasattr(request_type, "__name__") - and first_param_type.__name__ == request_type.__name__ - ) + first_matches_request = first_param_type == request_type + if not first_matches_request and isinstance(first_param_type, type): + request_type_name = request_type.__name__ + first_matches_request = first_param_type.__name__ == request_type_name # Verify we have a matching request type and valid response type (must be a type class) if first_matches_request and second_param_type is not None and isinstance(second_param_type, type): @@ -432,7 +443,7 @@ def generate_input_schema(input_type: type) -> dict[str, Any]: return generate_schema_from_dataclass(input_type) # 5. 
Fallback to string - type_name = getattr(input_type, "__name__", str(input_type)) + type_name = input_type.__name__ if isinstance(input_type, type) else str(cast(Any, input_type)) return {"type": "string", "description": f"Input type: {type_name}"} @@ -466,8 +477,9 @@ def parse_input_for_type(input_data: Any, target_type: type) -> Any: return _parse_string_input(input_data, target_type) # Handle dict input - if isinstance(input_data, dict): - return _parse_dict_input(input_data, target_type) + parsed_dict = _string_key_dict(input_data) + if parsed_dict is not None: + return _parse_dict_input(parsed_dict, target_type) # Fallback: return original return input_data diff --git a/python/packages/devui/agent_framework_devui/models/_discovery_models.py b/python/packages/devui/agent_framework_devui/models/_discovery_models.py index ff217a48d2..e3fcccb5f9 100644 --- a/python/packages/devui/agent_framework_devui/models/_discovery_models.py +++ b/python/packages/devui/agent_framework_devui/models/_discovery_models.py @@ -8,6 +8,10 @@ from pydantic import BaseModel, Field, field_validator +def _default_entities() -> list["EntityInfo"]: + return [] + + class EnvVarRequirement(BaseModel): """Environment variable requirement for an entity.""" @@ -57,7 +61,7 @@ class EntityInfo(BaseModel): class DiscoveryResponse(BaseModel): """Response model for entity discovery.""" - entities: list[EntityInfo] = Field(default_factory=list) + entities: list[EntityInfo] = Field(default_factory=_default_entities) # ============================================================================ diff --git a/python/packages/durabletask/agent_framework_durabletask/_entities.py b/python/packages/durabletask/agent_framework_durabletask/_entities.py index 650e1b8013..460b6b0429 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_entities.py +++ b/python/packages/durabletask/agent_framework_durabletask/_entities.py @@ -206,9 +206,7 @@ async def _invoke_agent( 
request_message=request_message, ) - run_callable = getattr(self.agent, "run", None) - if run_callable is None or not callable(run_callable): - raise AttributeError("Agent does not implement run() method") + run_callable = self.agent.run # Try streaming first with run(stream=True) try: diff --git a/python/packages/durabletask/agent_framework_durabletask/_response_utils.py b/python/packages/durabletask/agent_framework_durabletask/_response_utils.py index fe371b592f..2d0ee84d3e 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_response_utils.py +++ b/python/packages/durabletask/agent_framework_durabletask/_response_utils.py @@ -58,8 +58,8 @@ def ensure_response_format( """ if response_format is not None: # Set the response format on the response so .value knows how to parse - response._response_format = response_format - response._value_parsed = False # Reset to allow re-parsing with new format + response._response_format = response_format # pyright: ignore[reportPrivateUsage] + response._value_parsed = False # pyright: ignore[reportPrivateUsage] # Reset to allow re-parsing with new format # Access response.value to trigger parsing (may raise ValidationError) # Validate that parsing succeeded diff --git a/python/packages/durabletask/pyproject.toml b/python/packages/durabletask/pyproject.toml index 95a00929a2..923bfc9b1d 100644 --- a/python/packages/durabletask/pyproject.toml +++ b/python/packages/durabletask/pyproject.toml @@ -73,6 +73,7 @@ omit = [ [tool.pyright] extends = "../../pyproject.toml" +include = ["agent_framework_durabletask"] [tool.mypy] plugins = ['pydantic.mypy'] diff --git a/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py b/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py index 9bccc60309..4709307299 100644 --- a/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py +++ 
b/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py @@ -248,17 +248,19 @@ class MyOptions(FoundryLocalChatOptions, total=False): env_file_path=env_file_path, env_file_encoding=env_file_encoding, ) + model_id_setting: str = settings["model_id"] # type: ignore[assignment] # pyright: ignore[reportTypedDictNotRequiredAccess] + manager = FoundryLocalManager(bootstrap=bootstrap, timeout=timeout) model_info = manager.get_model_info( - alias_or_model_id=settings["model_id"], + alias_or_model_id=model_id_setting, device=device, ) if model_info is None: message = ( - f"Model with ID or alias '{settings['model_id']}:{device.value}' not found in Foundry Local." + f"Model with ID or alias '{model_id_setting}:{device.value}' not found in Foundry Local." if device else ( - f"Model with ID or alias '{settings['model_id']}' for your current device " + f"Model with ID or alias '{model_id_setting}' for your current device " "not found in Foundry Local." ) ) diff --git a/python/packages/foundry_local/pyproject.toml b/python/packages/foundry_local/pyproject.toml index dd2af572f2..a04a21d82b 100644 --- a/python/packages/foundry_local/pyproject.toml +++ b/python/packages/foundry_local/pyproject.toml @@ -59,6 +59,7 @@ omit = [ [tool.pyright] extends = "../../pyproject.toml" +include = ["agent_framework_foundry_local"] exclude = ['tests'] [tool.mypy] diff --git a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py index 053e0d3de0..86ecf737f8 100644 --- a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py +++ b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py @@ -7,7 +7,7 @@ import logging import sys from collections.abc import AsyncIterable, Awaitable, Callable, MutableMapping, Sequence -from typing import Any, ClassVar, Generic, Literal, TypedDict, overload +from typing import Any, ClassVar, Generic, Literal, 
TypedDict, cast, overload from agent_framework import ( AgentMiddlewareTypes, @@ -32,6 +32,7 @@ MCPServerConfig, PermissionRequest, PermissionRequestResult, + MessageOptions, ResumeSessionConfig, SessionConfig, SystemMessageConfig, @@ -266,10 +267,13 @@ async def start(self) -> None: if self._client is None: client_options: CopilotClientOptions = {} - if self._settings["cli_path"]: - client_options["cli_path"] = self._settings["cli_path"] - if self._settings["log_level"]: - client_options["log_level"] = self._settings["log_level"] # type: ignore[typeddict-item] + cli_path = self._settings.get("cli_path") + if cli_path: + client_options["cli_path"] = cli_path + + log_level = self._settings.get("log_level") + if log_level: + client_options["log_level"] = log_level # type: ignore[typeddict-item] self._client = CopilotClient(client_options if client_options else None) @@ -372,14 +376,15 @@ async def _run_impl( session = self.create_session() opts: dict[str, Any] = dict(options) if options else {} - timeout = opts.pop("timeout", None) or self._settings["timeout"] or DEFAULT_TIMEOUT_SECONDS + timeout = opts.pop("timeout", None) or self._settings.get("timeout") or DEFAULT_TIMEOUT_SECONDS copilot_session = await self._get_or_create_session(session, streaming=False, runtime_options=opts) input_messages = normalize_messages(messages) prompt = "\n".join([message.text for message in input_messages]) + message_options = cast(MessageOptions, {"prompt": prompt}) try: - response_event = await copilot_session.send_and_wait({"prompt": prompt}, timeout=timeout) + response_event = await copilot_session.send_and_wait(message_options, timeout=timeout) except Exception as ex: raise AgentException(f"GitHub Copilot request failed: {ex}") from ex @@ -439,6 +444,7 @@ async def _stream_updates( copilot_session = await self._get_or_create_session(session, streaming=True, runtime_options=opts) input_messages = normalize_messages(messages) prompt = "\n".join([message.text for message in 
input_messages]) + message_options = cast(MessageOptions, {"prompt": prompt}) queue: asyncio.Queue[AgentResponseUpdate | Exception | None] = asyncio.Queue() @@ -462,7 +468,7 @@ def event_handler(event: SessionEvent) -> None: unsubscribe = copilot_session.on(event_handler) try: - await copilot_session.send({"prompt": prompt}) + await copilot_session.send(message_options) while (item := await queue.get()) is not None: if isinstance(item, Exception): @@ -597,7 +603,7 @@ async def _create_session( opts = runtime_options or {} config: SessionConfig = {"streaming": streaming} - model = opts.get("model") or self._settings["model"] + model = opts.get("model") or self._settings.get("model") if model: config["model"] = model # type: ignore[typeddict-item] diff --git a/python/packages/github_copilot/pyproject.toml b/python/packages/github_copilot/pyproject.toml index 1a60ff4298..940fbf5fa7 100644 --- a/python/packages/github_copilot/pyproject.toml +++ b/python/packages/github_copilot/pyproject.toml @@ -61,6 +61,7 @@ omit = [ [tool.pyright] extends = "../../pyproject.toml" +include = ["agent_framework_github_copilot"] [tool.mypy] plugins = ['pydantic.mypy'] diff --git a/python/packages/lab/gaia/agent_framework_lab_gaia/gaia.py b/python/packages/lab/gaia/agent_framework_lab_gaia/gaia.py index 08619b84bc..ec7a90640d 100644 --- a/python/packages/lab/gaia/agent_framework_lab_gaia/gaia.py +++ b/python/packages/lab/gaia/agent_framework_lab_gaia/gaia.py @@ -13,7 +13,7 @@ from collections.abc import Iterable from datetime import datetime from pathlib import Path -from typing import Any +from typing import Any, cast from opentelemetry.trace import NoOpTracer, SpanKind, get_tracer from tqdm import tqdm @@ -163,7 +163,7 @@ def _normalize_str(s: str, remove_punct: bool = True) -> str: return no_spaces.lower() -def gaia_scorer(model_answer: str, ground_truth: str) -> bool: +def gaia_scorer(model_answer: str | None, ground_truth: str) -> bool: """Official GAIA scoring function. 
Args: @@ -193,7 +193,7 @@ def is_float(x: Any) -> bool: ma_elems = _split_string(model_answer) if len(gt_elems) != len(ma_elems): return False - comparisons = [] + comparisons: list[bool] = [] for ma, gt in zip(ma_elems, gt_elems, strict=False): if is_float(gt): comparisons.append(abs(_normalize_number_str(ma) - float(gt)) < 1e-6) @@ -204,18 +204,39 @@ def is_float(x: Any) -> bool: return _normalize_str(model_answer) == _normalize_str(ground_truth) +def _coerce_record(raw: object) -> dict[str, Any] | None: + if isinstance(raw, dict): + raw_dict = cast(dict[object, Any], raw) + if all(isinstance(key, str) for key in raw_dict): + return cast(dict[str, Any], raw_dict) + return None + + +def _parse_level(level: object) -> int | None: + if isinstance(level, int): + return level + if isinstance(level, str) and level.isdigit(): + return int(level) + return None + + def _read_jsonl(path: Path) -> Iterable[dict[str, Any]]: """Read JSONL file and yield parsed records.""" with path.open("rb") as f: for line in f: if not line.strip(): continue + parsed: object try: import orjson - yield orjson.loads(line) + parsed = orjson.loads(line) except Exception: - yield json.loads(line) + parsed = json.loads(line) + + record = _coerce_record(parsed) + if record is not None: + yield record def _load_gaia_local(repo_dir: Path, wanted_levels: list[int] | None = None, max_n: int | None = None) -> list[Task]: @@ -232,41 +253,41 @@ def _load_gaia_local(repo_dir: Path, wanted_levels: list[int] | None = None, max try: import pyarrow.parquet as pq - table = pq.read_table(p) - for row in table.to_pylist(): + pq_any = cast(Any, pq) + table: Any = pq_any.read_table(p) + rows = cast(list[object], table.to_pylist()) + for row in rows: + record = _coerce_record(row) + if record is None: + continue + # Robustly extract fields used across variants - q = row.get("Question") or row.get("question") or row.get("query") or row.get("prompt") - ans = row.get("Final answer") or row.get("answer") or 
row.get("final_answer") + q_obj = record.get("Question") or record.get("question") or record.get("query") or record.get("prompt") + ans = record.get("Final answer") or record.get("answer") or record.get("final_answer") + if not isinstance(q_obj, str): + continue + q = q_obj + qid = str( - row.get("task_id") - or row.get("question_id") - or row.get("id") - or row.get("uuid") + record.get("task_id") + or record.get("question_id") + or record.get("id") + or record.get("uuid") or f"{p.stem}:{len(tasks)}" ) - lvl = row.get("Level") or row.get("level") - - # Convert level to int if it's a string - def _parse_level(lvl: Any) -> int | None: - """Parse level value to integer if possible.""" - if isinstance(lvl, int): - return lvl - if isinstance(lvl, str) and lvl.isdigit(): - return int(lvl) - return None - - lvl = _parse_level(lvl) - fname = row.get("file_name") or row.get("filename") or None + lvl = _parse_level(record.get("Level") or record.get("level")) + fname_obj = record.get("file_name") or record.get("filename") + fname = fname_obj if isinstance(fname_obj, str) else None # Only evaluate examples with public answers (dev/validation split) # Skip if no question, no answer, or answer is placeholder like "?" - if not q or ans is None or str(ans).strip() in ["?", ""]: + if ans is None or str(ans).strip() in ["?", ""]: continue if wanted_levels and (lvl not in wanted_levels): continue - tasks.append(Task(task_id=qid, question=q, answer=str(ans), level=lvl, file_name=fname, metadata=row)) + tasks.append(Task(task_id=qid, question=q, answer=str(ans), level=lvl, file_name=fname, metadata=record)) except ImportError: print("Warning: pyarrow not installed. 
Install with: pip install pyarrow") continue @@ -279,8 +300,12 @@ def _parse_level(lvl: Any) -> int | None: for p in repo_dir.rglob("metadata.jsonl"): for rec in _read_jsonl(p): # Robustly extract fields used across variants - q = rec.get("Question") or rec.get("question") or rec.get("query") or rec.get("prompt") + q_obj = rec.get("Question") or rec.get("question") or rec.get("query") or rec.get("prompt") ans = rec.get("Final answer") or rec.get("answer") or rec.get("final_answer") + if not isinstance(q_obj, str): + continue + q = q_obj + qid = str( rec.get("task_id") or rec.get("question_id") @@ -288,15 +313,13 @@ def _parse_level(lvl: Any) -> int | None: or rec.get("uuid") or f"{p.stem}:{len(tasks)}" ) - lvl = rec.get("Level") or rec.get("level") - # Convert level to int if it's a string - if isinstance(lvl, str) and lvl.isdigit(): - lvl = int(lvl) - fname = rec.get("file_name") or rec.get("filename") or None + lvl = _parse_level(rec.get("Level") or rec.get("level")) + fname_obj = rec.get("file_name") or rec.get("filename") + fname = fname_obj if isinstance(fname_obj, str) else None # Only evaluate examples with public answers (dev/validation split) # Skip if no question, no answer, or answer is placeholder like "?" - if not q or ans is None or str(ans).strip() in ["?", ""]: + if ans is None or str(ans).strip() in ["?", ""]: continue if wanted_levels and (lvl not in wanted_levels): @@ -366,9 +389,10 @@ def _ensure_data(self) -> Path: "with access to gaia-benchmark/GAIA." 
) - from huggingface_hub import snapshot_download + import huggingface_hub - local_dir = snapshot_download( # type: ignore + hf_hub = cast(Any, huggingface_hub) + local_dir = hf_hub.snapshot_download( repo_id="gaia-benchmark/GAIA", repo_type="dataset", revision="682dd723ee1e1697e00360edccf2366dc8418dd9", @@ -376,6 +400,8 @@ def _ensure_data(self) -> Path: local_dir=str(self.data_dir), force_download=False, ) + if not isinstance(local_dir, str): + raise TypeError("snapshot_download returned unexpected non-string path") return Path(local_dir) async def _run_single_task( @@ -522,7 +548,7 @@ async def run( # Run tasks semaphore = asyncio.Semaphore(parallel) - results = [] + results: list[TaskResult] = [] tasks_coroutines = [self._run_single_task(task, task_runner, semaphore, timeout) for task in tasks] @@ -561,7 +587,7 @@ def _save_results(self, results: list[TaskResult], output_path: str) -> None: with open(output_path, "w", encoding="utf-8") as f: for result in results: # Convert messages to serializable format - serializable_messages = [] + serializable_messages: list[dict[str, Any] | str] = [] if result.prediction.messages: for msg in result.prediction.messages: if hasattr(msg, "model_dump"): @@ -569,7 +595,7 @@ def _save_results(self, results: list[TaskResult], output_path: str) -> None: serializable_messages.append(msg.model_dump()) elif hasattr(msg, "__dict__"): # Regular object with attributes - serializable_messages.append(vars(msg)) + serializable_messages.append(cast(dict[str, Any], getattr(msg, "__dict__", {}))) else: # Fallback to string representation serializable_messages.append(str(msg)) @@ -614,16 +640,20 @@ def viewer_main() -> None: args = parser.parse_args() # Load results - results = [] + results: list[dict[str, Any]] = [] with open(args.results_file, encoding="utf-8") as f: for line in f: if line.strip(): try: import orjson - results.append(orjson.loads(line)) + parsed: object = orjson.loads(line) except ImportError: - 
results.append(json.loads(line)) + parsed = json.loads(line) + + record = _coerce_record(parsed) + if record is not None: + results.append(record) # Apply filters if args.level is not None: diff --git a/python/packages/lab/pyproject.toml b/python/packages/lab/pyproject.toml index 03d2ed9e55..d64c0b0593 100644 --- a/python/packages/lab/pyproject.toml +++ b/python/packages/lab/pyproject.toml @@ -122,6 +122,7 @@ omit = [ [tool.pyright] extends = "../../pyproject.toml" +include = ["gaia/agent_framework_lab_gaia", "lightning/agent_framework_lab_lightning", "tau2/agent_framework_lab_tau2"] exclude = ['gaia/tests', 'lightning/tests', 'tau2/tests', 'namespace', '**/samples'] [tool.mypy] diff --git a/python/packages/lab/tau2/agent_framework_lab_tau2/_message_utils.py b/python/packages/lab/tau2/agent_framework_lab_tau2/_message_utils.py index bd8d521e28..bb617e3ad9 100644 --- a/python/packages/lab/tau2/agent_framework_lab_tau2/_message_utils.py +++ b/python/packages/lab/tau2/agent_framework_lab_tau2/_message_utils.py @@ -23,7 +23,7 @@ def filter_out_function_calls(messages: list[Content]) -> list[Content]: """Remove function call content from message contents.""" return [content for content in messages if content.type != "function_call"] - flipped_messages = [] + flipped_messages: list[Message] = [] for msg in messages: role_value = _get_role_value(msg.role) if role_value == "assistant": diff --git a/python/packages/lab/tau2/agent_framework_lab_tau2/_tau2_utils.py b/python/packages/lab/tau2/agent_framework_lab_tau2/_tau2_utils.py index 75c0676cb6..9847a460bc 100644 --- a/python/packages/lab/tau2/agent_framework_lab_tau2/_tau2_utils.py +++ b/python/packages/lab/tau2/agent_framework_lab_tau2/_tau2_utils.py @@ -3,7 +3,7 @@ import json from collections.abc import Mapping from copy import deepcopy -from typing import Any +from typing import Any, TypeGuard, cast import numpy as np from agent_framework._tools import FunctionTool @@ -27,6 +27,27 @@ _original_set_state = 
Environment.set_state +def _to_str(value: object, default: str = "") -> str: + if isinstance(value, str): + return value + if value is None: + return default + return str(value) + + + + +def _is_any_list(value: Any) -> TypeGuard[list[Any]]: + return isinstance(value, list) + + +def _is_any_mapping(value: Any) -> TypeGuard[Mapping[Any, Any]]: + return isinstance(value, Mapping) + + +def _is_any_sequence(value: Any) -> TypeGuard[list[Any] | tuple[Any, ...] | set[Any]]: + return isinstance(value, (list, tuple, set)) + def convert_tau2_tool_to_function_tool(tau2_tool: Tool) -> FunctionTool: """Convert a tau2 Tool to a FunctionTool for agent framework compatibility. @@ -41,7 +62,7 @@ def wrapped_func(**kwargs: Any) -> Any: return FunctionTool( name=tau2_tool.name, - description=tau2_tool._get_description(), + description=tau2_tool._get_description(), # pyright: ignore[reportPrivateUsage] func=wrapped_func, input_model=tau2_tool.params, ) @@ -53,27 +74,26 @@ def convert_agent_framework_messages_to_tau2_messages(messages: list[Message]) - Handles role mapping, text extraction, function calls, and function results. Function results are converted to separate ToolMessage instances. 
""" - tau2_messages = [] + tau2_messages: list[Tau2Message] = [] for msg in messages: role_str = str(msg.role) # Extract text content from all text-type contents - text_content = None text_contents = [c for c in msg.contents if hasattr(c, "text") and hasattr(c, "type") and c.type == "text"] - if text_contents: - text_content = " ".join(c.text for c in text_contents) # type: ignore[misc] + content_parts: list[str] = [_to_str(getattr(c, "text", "")) for c in text_contents] + content_value = " ".join(content_parts) # Extract function calls and convert to ToolCall objects function_calls = [c for c in msg.contents if hasattr(c, "type") and c.type == "function_call"] - tool_calls = None + tool_calls: list[ToolCall] | None = None if function_calls: tool_calls = [] for fc in function_calls: arguments = fc.parse_arguments() or {} tool_call = ToolCall( - id=fc.call_id, - name=fc.name, + id=_to_str(fc.call_id), + name=_to_str(fc.name), arguments=arguments, requestor="assistant" if role_str == "assistant" else "user", ) @@ -84,11 +104,11 @@ def convert_agent_framework_messages_to_tau2_messages(messages: list[Message]) - # Create main message based on role if role_str == "system": - tau2_messages.append(SystemMessage(role="system", content=text_content)) + tau2_messages.append(SystemMessage(role="system", content=content_value)) elif role_str == "user": - tau2_messages.append(UserMessage(role="user", content=text_content, tool_calls=tool_calls)) + tau2_messages.append(UserMessage(role="user", content=content_value, tool_calls=tool_calls)) elif role_str == "assistant": - tau2_messages.append(AssistantMessage(role="assistant", content=text_content, tool_calls=tool_calls)) + tau2_messages.append(AssistantMessage(role="assistant", content=content_value, tool_calls=tool_calls)) elif role_str == "tool": # Tool messages are handled as function results below pass @@ -98,7 +118,7 @@ def convert_agent_framework_messages_to_tau2_messages(messages: list[Message]) - dumpable_content = 
_dump_function_result(fr.result) content = dumpable_content if isinstance(dumpable_content, str) else json.dumps(dumpable_content) tool_msg = ToolMessage( - id=fr.call_id, + id=_to_str(fr.call_id), role="tool", content=content, requestor="assistant", # Most tool calls originate from assistant @@ -126,12 +146,10 @@ def set_state( if self.solo_mode and any(isinstance(message, UserMessage) for message in message_history): raise ValueError("User messages are not allowed in solo mode") - def get_actions_from_messages( - messages: list[Tau2Message], - ) -> list[tuple[ToolCall, ToolMessage]]: + def get_actions_from_messages(messages: list[Tau2Message]) -> list[tuple[ToolCall, ToolMessage]]: """Get the actions from the messages.""" messages = deepcopy(messages)[::-1] - actions = [] + actions: list[tuple[ToolCall, ToolMessage]] = [] while messages: message = messages.pop() if isinstance(message, ToolMessage): @@ -153,10 +171,13 @@ def get_actions_from_messages( return actions if initialization_data is not None: - if initialization_data.agent_data is not None: - self.tools.update_db(initialization_data.agent_data) - if initialization_data.user_data is not None: - self.user_tools.update_db(initialization_data.user_data) + agent_data = cast(object, getattr(initialization_data, "agent_data", None)) + if isinstance(agent_data, dict): + self.tools.update_db(cast(dict[str, Any], agent_data)) + + user_data = cast(object, getattr(initialization_data, "user_data", None)) + if isinstance(user_data, dict): + self.user_tools.update_db(cast(dict[str, Any], user_data)) if initialization_actions is not None: for action in initialization_actions: @@ -188,10 +209,11 @@ def unpatch_env_set_state() -> None: def _dump_function_result(result: Any) -> Any: if isinstance(result, BaseModel): return result.model_dump_json() - if isinstance(result, list): + if _is_any_list(result): return [_dump_function_result(item) for item in result] if isinstance(result, dict): - return {k: 
_dump_function_result(v) for k, v in result.items()} + result_dict = cast(dict[str, Any], result) + return {k: _dump_function_result(v) for k, v in result_dict.items()} if result is None: return None return result @@ -208,11 +230,11 @@ def _to_native(obj: Any) -> Any: return _to_native(obj.item()) # 3) Dict-like -> dict - if isinstance(obj, Mapping): + if _is_any_mapping(obj): return {_to_native(k): _to_native(v) for k, v in obj.items()} # 4) Lists/Tuples/Sets -> list - if isinstance(obj, (list, tuple, set)): + if _is_any_sequence(obj): return [_to_native(x) for x in obj] # 5) Anything else: leave as-is @@ -227,9 +249,10 @@ def _recursive_json_deserialize(obj: Any) -> Any: return _recursive_json_deserialize(deserialized) except (json.JSONDecodeError, TypeError): return obj - elif isinstance(obj, list): + elif _is_any_list(obj): return [_recursive_json_deserialize(item) for item in obj] elif isinstance(obj, dict): - return {k: _recursive_json_deserialize(v) for k, v in obj.items()} + typed_obj = cast(dict[str, Any], obj) + return {k: _recursive_json_deserialize(v) for k, v in typed_obj.items()} else: return obj diff --git a/python/packages/lab/tau2/agent_framework_lab_tau2/runner.py b/python/packages/lab/tau2/agent_framework_lab_tau2/runner.py index 78a9496444..2a29e5b544 100644 --- a/python/packages/lab/tau2/agent_framework_lab_tau2/runner.py +++ b/python/packages/lab/tau2/agent_framework_lab_tau2/runner.py @@ -3,7 +3,7 @@ from __future__ import annotations import uuid -from typing import Any +from typing import Any, cast from agent_framework import ( Agent, @@ -38,6 +38,15 @@ __all__ = ["ASSISTANT_AGENT_ID", "ORCHESTRATOR_ID", "USER_SIMULATOR_ID", "TaskRunner"] + +def _get_openai_schema(tool: Any) -> dict[str, Any]: + schema = getattr(tool, "openai_schema", None) + if isinstance(schema, dict): + schema_dict = cast(dict[object, Any], schema) + if all(isinstance(key, str) for key in schema_dict): + return cast(dict[str, Any], schema_dict) + raise TypeError(f"Tool 
{tool} does not expose a dict openai_schema") + # Agent instructions matching tau2's LLMAgent ASSISTANT_AGENT_INSTRUCTION = """ You are a customer service agent that helps the user according to the provided below. @@ -205,7 +214,7 @@ def assistant_agent(self, assistant_chat_client: SupportsChatGetResponse) -> Age context_providers=[ SlidingWindowHistoryProvider( system_message=assistant_system_prompt, - tool_definitions=[tool.openai_schema for tool in tools], + tool_definitions=[_get_openai_schema(tool) for tool in tools], max_tokens=self.assistant_window_size, ) ], diff --git a/python/packages/mem0/agent_framework_mem0/_context_provider.py b/python/packages/mem0/agent_framework_mem0/_context_provider.py index 26ebca2d11..36b878e411 100644 --- a/python/packages/mem0/agent_framework_mem0/_context_provider.py +++ b/python/packages/mem0/agent_framework_mem0/_context_provider.py @@ -88,7 +88,7 @@ async def __aenter__(self) -> Self: async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None: """Async context manager exit.""" if self._should_close_client and self.mem0_client and isinstance(self.mem0_client, AbstractAsyncContextManager): - await self.mem0_client.__aexit__(exc_type, exc_val, exc_tb) + await self.mem0_client.__aexit__(exc_type, exc_val, exc_tb) # pyright: ignore[reportUnknownMemberType] # -- Hooks pattern --------------------------------------------------------- diff --git a/python/packages/mem0/pyproject.toml b/python/packages/mem0/pyproject.toml index dc20e77fb6..0d04fe802f 100644 --- a/python/packages/mem0/pyproject.toml +++ b/python/packages/mem0/pyproject.toml @@ -61,6 +61,7 @@ omit = [ [tool.pyright] extends = "../../pyproject.toml" +include = ["agent_framework_mem0"] [tool.mypy] plugins = ['pydantic.mypy'] diff --git a/python/packages/ollama/agent_framework_ollama/_chat_client.py b/python/packages/ollama/agent_framework_ollama/_chat_client.py index cc7fc0c9a7..796c17107e 100644 --- 
a/python/packages/ollama/agent_framework_ollama/_chat_client.py +++ b/python/packages/ollama/agent_framework_ollama/_chat_client.py @@ -329,7 +329,7 @@ def __init__( env_file_path=env_file_path, ) - self.model_id = ollama_settings["model_id"] + self.model_id = ollama_settings["model_id"] # type: ignore[assignment] # pyright: ignore[reportTypedDictNotRequiredAccess] # we can just pass in None for the host, the default is set by the Ollama package. self.client = client or AsyncClient(host=ollama_settings.get("host")) # Save Host URL for serialization with to_dict() diff --git a/python/packages/ollama/agent_framework_ollama/_embedding_client.py b/python/packages/ollama/agent_framework_ollama/_embedding_client.py index 4fcf75b465..f5063159ef 100644 --- a/python/packages/ollama/agent_framework_ollama/_embedding_client.py +++ b/python/packages/ollama/agent_framework_ollama/_embedding_client.py @@ -107,7 +107,7 @@ def __init__( env_file_encoding=env_file_encoding, ) - self.model_id = ollama_settings["embedding_model_id"] + self.model_id = ollama_settings["embedding_model_id"] # type: ignore[assignment] # pyright: ignore[reportTypedDictNotRequiredAccess] self.client = client or AsyncClient(host=ollama_settings.get("host")) self.host = str(self.client._client.base_url) # pyright: ignore[reportUnknownMemberType,reportPrivateUsage,reportUnknownArgumentType] super().__init__(**kwargs) @@ -134,10 +134,19 @@ async def get_embeddings( Raises: ValueError: If model_id is not provided or values is empty. 
""" + opts: dict[str, Any] = dict(options) if options else {} + if not values: - return GeneratedEmbeddings([], options=options) + return GeneratedEmbeddings([], options=None) + + response_options: EmbeddingGenerationOptions | None = None + if options: + response_options = {} + if (model_id := opts.get("model_id")) is not None: + response_options["model_id"] = model_id + if (dimensions := opts.get("dimensions")) is not None: + response_options["dimensions"] = dimensions - opts: dict[str, Any] = dict(options) if options else {} model = opts.get("model_id") or self.model_id if not model: raise ValueError("model_id is required") @@ -166,7 +175,7 @@ async def get_embeddings( if prompt_eval_count is not None: usage_dict = {"input_token_count": prompt_eval_count} - return GeneratedEmbeddings(embeddings, options=options, usage=usage_dict) + return GeneratedEmbeddings(embeddings, options=response_options, usage=usage_dict) class OllamaEmbeddingClient( diff --git a/python/packages/ollama/pyproject.toml b/python/packages/ollama/pyproject.toml index c8bd9052ad..57cbcd3b96 100644 --- a/python/packages/ollama/pyproject.toml +++ b/python/packages/ollama/pyproject.toml @@ -62,6 +62,7 @@ omit = [ [tool.pyright] extends = "../../pyproject.toml" +include = ["agent_framework_ollama"] exclude = ['tests'] [tool.mypy] diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py index 5d6e84ef05..e117b30aa9 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py @@ -33,25 +33,24 @@ import json import logging import sys -from collections.abc import Awaitable, Callable, Sequence +from collections.abc import Awaitable, Callable, Mapping, Sequence from dataclasses import dataclass -from typing import Any, cast +from typing import Any -from agent_framework import Agent, 
SupportsAgentRun +from agent_framework import Agent, ChatOptions, SupportsAgentRun from agent_framework._middleware import FunctionInvocationContext, FunctionMiddleware from agent_framework._sessions import AgentSession from agent_framework._tools import FunctionTool, tool -from agent_framework._types import AgentResponse, AgentResponseUpdate, Content, Message -from agent_framework._workflows._agent_executor import AgentExecutor, AgentExecutorRequest, AgentExecutorResponse +from agent_framework._types import AgentResponse, Content, Message +from agent_framework._workflows._agent_executor import AgentExecutor, AgentExecutorRequest from agent_framework._workflows._agent_utils import resolve_agent_id from agent_framework._workflows._checkpoint import CheckpointStorage from agent_framework._workflows._events import WorkflowEvent from agent_framework._workflows._request_info_mixin import response_handler +from agent_framework._workflows._typing_utils import is_chat_agent from agent_framework._workflows._workflow import Workflow from agent_framework._workflows._workflow_builder import WorkflowBuilder from agent_framework._workflows._workflow_context import WorkflowContext -from typing_extensions import Never - from ._base_group_chat_orchestrator import TerminationCondition from ._orchestrator_helpers import clean_conversation_for_handoff @@ -254,7 +253,7 @@ def _prepare_agent_with_handoffs( """ # Clone the agent to avoid mutating the original - cloned_agent = self._clone_chat_agent(agent) # type: ignore + cloned_agent = self._clone_chat_agent(agent) # Add handoff tools to the cloned agent self._apply_auto_tools(cloned_agent, handoffs) # Add middleware to handle handoff tool invocations @@ -347,7 +346,7 @@ def _persist_missing_approved_function_results( ) ) - def _clone_chat_agent(self, agent: Agent) -> Agent: + def _clone_chat_agent(self, agent: Agent[Any]) -> Agent[Any]: """Produce a deep copy of the Agent while preserving runtime configuration.""" options = 
agent.default_options middleware = list(agent.middleware or []) @@ -365,28 +364,43 @@ def _clone_chat_agent(self, agent: Agent) -> Agent: metadata = options.get("metadata") # Disable parallel tool calls to prevent the agent from invoking multiple handoff tools at once. - cloned_options: dict[str, Any] = { + cloned_options: ChatOptions[None] = { "allow_multiple_tool_calls": False, # Handoff workflows already manage full conversation context explicitly # across executors. Keep provider-side conversation storage disabled to # avoid stale tool-call state (Responses API previous_response chains). "store": False, - "frequency_penalty": options.get("frequency_penalty"), - "instructions": options.get("instructions"), - "logit_bias": dict(logit_bias) if logit_bias else None, - "max_tokens": options.get("max_tokens"), - "metadata": dict(metadata) if metadata else None, - "model_id": options.get("model_id"), - "presence_penalty": options.get("presence_penalty"), - "response_format": options.get("response_format"), - "seed": options.get("seed"), - "stop": options.get("stop"), - "temperature": options.get("temperature"), - "tool_choice": options.get("tool_choice"), - "tools": all_tools if all_tools else None, - "top_p": options.get("top_p"), - "user": options.get("user"), } + if (frequency_penalty := options.get("frequency_penalty")) is not None: + cloned_options["frequency_penalty"] = frequency_penalty + if (instructions := options.get("instructions")) is not None: + cloned_options["instructions"] = instructions + if logit_bias: + cloned_options["logit_bias"] = dict(logit_bias) + if (max_tokens := options.get("max_tokens")) is not None: + cloned_options["max_tokens"] = max_tokens + if metadata: + cloned_options["metadata"] = dict(metadata) + if (model_id := options.get("model_id")) is not None: + cloned_options["model_id"] = model_id + if (presence_penalty := options.get("presence_penalty")) is not None: + cloned_options["presence_penalty"] = presence_penalty + if 
(response_format := options.get("response_format")) is not None: + cloned_options["response_format"] = response_format + if (seed := options.get("seed")) is not None: + cloned_options["seed"] = seed + if (stop := options.get("stop")) is not None: + cloned_options["stop"] = stop + if (temperature := options.get("temperature")) is not None: + cloned_options["temperature"] = temperature + if (tool_choice := options.get("tool_choice")) is not None: + cloned_options["tool_choice"] = tool_choice + if all_tools: + cloned_options["tools"] = all_tools + if (top_p := options.get("top_p")) is not None: + cloned_options["top_p"] = top_p + if (user := options.get("user")) is not None: + cloned_options["user"] = user return Agent( client=agent.client, @@ -395,7 +409,7 @@ def _clone_chat_agent(self, agent: Agent) -> Agent: description=agent.description, context_providers=agent.context_providers, middleware=middleware, - default_options=cloned_options, # type: ignore[arg-type] + default_options=cloned_options, ) def _apply_auto_tools(self, agent: Agent, targets: Sequence[HandoffConfiguration]) -> None: @@ -446,7 +460,7 @@ def _handoff_tool() -> None: @override async def _run_agent_and_emit( - self, ctx: WorkflowContext[AgentExecutorResponse, AgentResponse | AgentResponseUpdate] + self, ctx: WorkflowContext[Any, Any] ) -> None: """Override to support handoff.""" incoming_messages = list(self._cache) @@ -469,7 +483,7 @@ async def _run_agent_and_emit( # Broadcast the initial cache to all other agents. Subsequent runs won't # need this since responses are broadcast after each agent run and user input. if self._is_start_agent and not self._full_conversation: - await self._broadcast_messages(cleaned_incoming_messages, cast(WorkflowContext[AgentExecutorRequest], ctx)) + await self._broadcast_messages(cleaned_incoming_messages, ctx) # Persist only cleaned chat history between turns to avoid replaying stale tool calls. 
self._full_conversation.extend(cleaned_incoming_messages) @@ -482,30 +496,27 @@ async def _run_agent_and_emit( # Handoff workflows are orchestrator-stateful and provider-stateless by design. # If an existing session still has a service conversation id, clear it to avoid # replaying stale unresolved tool calls across resumed turns. - if ( - cast(Agent, self._agent).default_options.get("store") is False - and self._session.service_session_id is not None - ): + if is_chat_agent(self._agent) and self._agent.default_options.get("store") is False and self._session.service_session_id is not None: self._session.service_session_id = None # Check termination condition before running the agent - if await self._check_terminate_and_yield(cast(WorkflowContext[Never, list[Message]], ctx)): + if await self._check_terminate_and_yield(ctx): return # Run the agent if ctx.is_streaming(): # Streaming mode: emit incremental updates - response = await self._run_agent_streaming(cast(WorkflowContext[Never, AgentResponseUpdate], ctx)) + response = await self._run_agent_streaming(ctx) else: # Non-streaming mode: use run() and emit single event - response = await self._run_agent(cast(WorkflowContext[Never, AgentResponse], ctx)) + response = await self._run_agent(ctx) # Clear the cache after running the agent self._cache.clear() # A function approval request is issued by the base AgentExecutor if response is None: - if cast(Agent, self._agent).default_options.get("store") is False: + if is_chat_agent(self._agent) and self._agent.default_options.get("store") is False: self._persist_pending_approval_function_calls() # Agent did not complete (e.g., waiting for user input); do not emit response logger.debug("AgentExecutor %s: Agent did not complete, awaiting user input", self.id) @@ -525,7 +536,7 @@ async def _run_agent_and_emit( ) # Broadcast only the cleaned response to other agents (without function_calls/results) - await self._broadcast_messages(cleaned_response, 
cast(WorkflowContext[AgentExecutorRequest], ctx)) + await self._broadcast_messages(cleaned_response, ctx) # Check if a handoff was requested if handoff_target := self._is_handoff_requested(response): @@ -535,7 +546,7 @@ async def _run_agent_and_emit( f"target '{handoff_target}'. Valid targets are: {', '.join(self._handoff_targets)}" ) - await cast(WorkflowContext[AgentExecutorRequest], ctx).send_message( + await ctx.send_message( AgentExecutorRequest(messages=[], should_respond=True), target_id=handoff_target, ) @@ -548,7 +559,7 @@ async def _run_agent_and_emit( # Re-evaluate termination after appending and broadcasting this response. # Without this check, workflows that become terminal due to the latest assistant # message would still emit request_info and require an unnecessary extra resume. - if await self._check_terminate_and_yield(cast(WorkflowContext[Never, list[Message]], ctx)): + if await self._check_terminate_and_yield(ctx): return # Handle case where no handoff was requested @@ -570,7 +581,7 @@ async def handle_response( self, original_request: HandoffAgentUserRequest, response: list[Message], - ctx: WorkflowContext[AgentExecutorResponse, AgentResponse], + ctx: WorkflowContext[Any, Any], ) -> None: """Handle user response for a request that is issued after agent runs. @@ -588,22 +599,20 @@ async def handle_response( If the response is empty, it indicates termination of the handoff workflow. 
""" if not response: - await cast(WorkflowContext[Never, list[Message]], ctx).yield_output(self._full_conversation) + await ctx.yield_output(self._full_conversation) return # Broadcast the user response to all other agents - await self._broadcast_messages(response, cast(WorkflowContext[AgentExecutorRequest], ctx)) + await self._broadcast_messages(response, ctx) # Append the user response messages to the cache self._cache.extend(response) - await self._run_agent_and_emit( - cast(WorkflowContext[AgentExecutorResponse, AgentResponse | AgentResponseUpdate], ctx) - ) + await self._run_agent_and_emit(ctx) async def _broadcast_messages( self, messages: list[Message], - ctx: WorkflowContext[AgentExecutorRequest], + ctx: WorkflowContext[Any, Any], ) -> None: """Broadcast the workflow cache to the agent before running.""" agent_executor_request = AgentExecutorRequest( @@ -628,15 +637,15 @@ def _is_handoff_requested(self, response: AgentResponse) -> str | None: if content.type == "function_result": payload = content.result parsed_payload: dict[str, Any] | None = None - if isinstance(payload, dict): - parsed_payload = payload + if isinstance(payload, Mapping): + parsed_payload = {key: value for key, value in payload.items() if isinstance(key, str)} # pyright: ignore[reportUnknownVariableType] elif isinstance(payload, str): try: maybe_payload = json.loads(payload) except json.JSONDecodeError: maybe_payload = None - if isinstance(maybe_payload, dict): - parsed_payload = maybe_payload + if isinstance(maybe_payload, Mapping): + parsed_payload = {key: value for key, value in maybe_payload.items() if isinstance(key, str)} # pyright: ignore[reportUnknownVariableType] if parsed_payload: handoff_target = parsed_payload.get(HANDOFF_FUNCTION_RESULT_KEY) @@ -647,7 +656,7 @@ def _is_handoff_requested(self, response: AgentResponse) -> str | None: return None - async def _check_terminate_and_yield(self, ctx: WorkflowContext[Never, list[Message]]) -> bool: + async def 
_check_terminate_and_yield(self, ctx: WorkflowContext[Any, Any]) -> bool: """Check termination conditions and yield completion if met. Args: diff --git a/python/packages/orchestrations/pyproject.toml b/python/packages/orchestrations/pyproject.toml index c670842715..52ac5424fb 100644 --- a/python/packages/orchestrations/pyproject.toml +++ b/python/packages/orchestrations/pyproject.toml @@ -58,6 +58,7 @@ omit = [ [tool.pyright] extends = "../../pyproject.toml" +include = ["agent_framework_orchestrations"] exclude = ['tests'] [tool.mypy] diff --git a/python/packages/purview/agent_framework_purview/_client.py b/python/packages/purview/agent_framework_purview/_client.py index a1f404849b..7dedf465ee 100644 --- a/python/packages/purview/agent_framework_purview/_client.py +++ b/python/packages/purview/agent_framework_purview/_client.py @@ -6,7 +6,7 @@ import inspect import json import logging -from typing import Any, cast +from typing import Any, Literal, TypeVar, overload from uuid import uuid4 import httpx @@ -36,6 +36,8 @@ logger = logging.getLogger("agent_framework.purview") +ResponseT = TypeVar("ResponseT") + class PurviewClient: """Async client for calling Graph Purview endpoints. 
@@ -98,7 +100,7 @@ async def process_content(self, request: ProcessContentRequest) -> ProcessConten with get_tracer().start_as_current_span("purview.process_content"): token = await self._get_token(tenant_id=request.tenant_id) url = f"{self._graph_uri}/users/{request.user_id}/dataSecurityAndGovernance/processContent" - headers = {} + headers: dict[str, str] = {} # Add If-None-Match header if scope_identifier is present if hasattr(request, "scope_identifier") and request.scope_identifier: headers["If-None-Match"] = request.scope_identifier @@ -106,21 +108,23 @@ async def process_content(self, request: ProcessContentRequest) -> ProcessConten if hasattr(request, "process_inline") and request.process_inline: headers["Prefer"] = "evaluateInline" - response = await self._post( + response: ProcessContentResponse | tuple[ProcessContentResponse, httpx.Headers] = await self._post( url, request, ProcessContentResponse, token, headers=headers, return_response=True ) if isinstance(response, tuple) and len(response) == 2: response_obj, _ = response - return cast(ProcessContentResponse, response_obj) + return response_obj - return cast(ProcessContentResponse, response) + return response async def get_protection_scopes(self, request: ProtectionScopesRequest) -> ProtectionScopesResponse: with get_tracer().start_as_current_span("purview.get_protection_scopes"): token = await self._get_token() url = f"{self._graph_uri}/users/{request.user_id}/dataSecurityAndGovernance/protectionScopes/compute" - response = await self._post(url, request, ProtectionScopesResponse, token, return_response=True) + response: ProtectionScopesResponse | tuple[ProtectionScopesResponse, httpx.Headers] = await self._post( + url, request, ProtectionScopesResponse, token, return_response=True + ) # Extract etag from response headers if isinstance(response, tuple) and len(response) == 2: @@ -128,25 +132,48 @@ async def get_protection_scopes(self, request: ProtectionScopesRequest) -> Prote if "etag" in headers: 
etag_value = headers["etag"].strip('"') response_obj.scope_identifier = etag_value - return cast(ProtectionScopesResponse, response_obj) + return response_obj - return cast(ProtectionScopesResponse, response) + return response async def send_content_activities(self, request: ContentActivitiesRequest) -> ContentActivitiesResponse: with get_tracer().start_as_current_span("purview.send_content_activities"): token = await self._get_token() url = f"{self._graph_uri}/users/{request.user_id}/dataSecurityAndGovernance/activities/contentActivities" - return cast(ContentActivitiesResponse, await self._post(url, request, ContentActivitiesResponse, token)) + return await self._post(url, request, ContentActivitiesResponse, token) + + + @overload + async def _post( + self, + url: str, + model: Any, + response_type: type[ResponseT], + token: str, + headers: dict[str, str] | None = None, + return_response: Literal[False] = False, + ) -> ResponseT: ... + + @overload + async def _post( + self, + url: str, + model: Any, + response_type: type[ResponseT], + token: str, + headers: dict[str, str] | None = None, + return_response: Literal[True] = True, + ) -> tuple[ResponseT, httpx.Headers]: ... 
async def _post( self, url: str, model: Any, - response_type: type[Any], + response_type: type[ResponseT], token: str, headers: dict[str, str] | None = None, return_response: bool = False, - ) -> Any: + ) -> ResponseT | tuple[ResponseT, httpx.Headers]: if hasattr(model, "correlation_id") and not model.correlation_id: model.correlation_id = str(uuid4()) @@ -174,7 +201,7 @@ async def _post( raise PurviewAuthenticationError(f"Auth failure {resp.status_code}: {resp.text}") if resp.status_code == 402: if self._settings.get("ignore_payment_required", False): - return response_type() # type: ignore[call-arg, no-any-return] + return response_type() # type: ignore[call-arg] raise PurviewPaymentRequiredError(f"Payment required {resp.status_code}: {resp.text}") if resp.status_code == 429: raise PurviewRateLimitError(f"Rate limited {resp.status_code}: {resp.text}") @@ -187,18 +214,21 @@ async def _post( try: # Prefer pydantic-style model_validate if present, else fall back to constructor. - if hasattr(response_type, "model_validate"): - response_obj = response_type.model_validate(data) # type: ignore[no-any-return] + model_validate = getattr(response_type, "model_validate", None) + if callable(model_validate): + response_obj = model_validate(data) else: - response_obj = response_type(**data) # type: ignore[call-arg, no-any-return] + response_obj = response_type(**data) # type: ignore[call-arg] # Extract correlation_id from response headers if response object supports it if "client-request-id" in resp.headers and hasattr(response_obj, "correlation_id"): - response_obj.correlation_id = resp.headers["client-request-id"] - logger.info(f"Purview response from {url} with correlation_id: {response_obj.correlation_id}") + response_correlation_id = resp.headers["client-request-id"] + setattr(response_obj, "correlation_id", response_correlation_id) + logger.info(f"Purview response from {url} with correlation_id: {response_correlation_id}") + typed_response_obj = response_obj if 
isinstance(response_obj, response_type) else response_type(**data) if return_response: - return (response_obj, resp.headers) - return response_obj + return (typed_response_obj, resp.headers) + return typed_response_obj except Exception as ex: raise PurviewServiceError(f"Failed to deserialize Purview response: {ex}") from ex diff --git a/python/packages/purview/agent_framework_purview/_middleware.py b/python/packages/purview/agent_framework_purview/_middleware.py index 55619d0a39..c0e89a04a5 100644 --- a/python/packages/purview/agent_framework_purview/_middleware.py +++ b/python/packages/purview/agent_framework_purview/_middleware.py @@ -67,6 +67,7 @@ async def process( call_next: Callable[[], Awaitable[None]], ) -> None: # type: ignore[override] resolved_user_id: str | None = None + session_id: str | None = None try: # Pre (prompt) check session_id = self._get_agent_session_id(context) @@ -107,7 +108,7 @@ async def process( should_block_response, _ = await self._processor.process_messages( context.result.messages, # type: ignore[union-attr] Activity.DOWNLOAD_TEXT, - session_id=session_id, + session_id=session_id_response, user_id=resolved_user_id, ) if should_block_response: @@ -173,6 +174,7 @@ async def process( call_next: Callable[[], Awaitable[None]], ) -> None: # type: ignore[override] resolved_user_id: str | None = None + session_id: str | None = None try: session_id = context.options.get("conversation_id") if context.options else None should_block_prompt, resolved_user_id = await self._processor.process_messages( diff --git a/python/packages/purview/agent_framework_purview/_models.py b/python/packages/purview/agent_framework_purview/_models.py index ad6cc5b331..503871deef 100644 --- a/python/packages/purview/agent_framework_purview/_models.py +++ b/python/packages/purview/agent_framework_purview/_models.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging -from collections.abc import Mapping, MutableMapping, Sequence +from collections.abc 
import Iterable, Mapping, MutableMapping, Sequence from datetime import datetime from enum import Enum, Flag, auto from typing import Any, ClassVar, TypeVar, cast @@ -60,6 +60,23 @@ def __int__(self) -> int: # pragma: no cover ] +def _as_object_list(value: object) -> list[object] | None: + if not isinstance(value, (list, tuple, set)): + return None + return list(cast(Iterable[object], value)) + + +def _as_str_dict(value: object) -> dict[str, str]: + if not isinstance(value, dict): + return {} + + aliases: dict[str, str] = {} + for raw_key, raw_value in cast(dict[object, object], value).items(): + if isinstance(raw_key, str) and isinstance(raw_value, str): + aliases[raw_key] = raw_value + return aliases + + def deserialize_flag( value: object, mapping: Mapping[str, FlagT], enum_cls: type[FlagT] ) -> FlagT | None: # pragma: no cover @@ -82,8 +99,11 @@ def deserialize_flag( if not raw: return enum_cls(0) parts.extend([p.strip() for p in raw.split(",") if p.strip()]) - elif isinstance(value, (list, tuple, set)): - for item in value: + else: + iterable_items = _as_object_list(value) + if iterable_items is None: + return None + for item in iterable_items: if isinstance(item, str): parts.extend([p.strip() for p in item.split(",") if p.strip()]) elif isinstance(item, enum_cls): @@ -93,8 +113,6 @@ def deserialize_flag( flag_value |= enum_cls(item) except Exception: logger.warning(f"Failed to convert int {item} to {enum_cls.__name__}") - else: - return None for part in parts: member = mapping.get(part) @@ -196,10 +214,10 @@ def __init__(self, **kwargs: Any) -> None: # Collect all aliases from parent classes too all_aliases: dict[str, str] = {} for cls in type(self).__mro__: - if hasattr(cls, "_ALIASES") and isinstance(cls._ALIASES, dict): - for internal, external in cls._ALIASES.items(): - if external not in all_aliases: - all_aliases[external] = internal + aliases_obj = _as_str_dict(getattr(cls, "_ALIASES", None)) + for internal, external in aliases_obj.items(): + if 
external not in all_aliases: + all_aliases[external] = internal # Normalize all aliased keys in kwargs for external, internal in all_aliases.items(): @@ -248,11 +266,11 @@ def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) # Collect all aliases from class hierarchy all_aliases: dict[str, str] = {} for cls in type(self).__mro__: - if hasattr(cls, "_ALIASES") and isinstance(cls._ALIASES, dict): - # Parent aliases first (will be overridden by child if same key) - for internal, external in cls._ALIASES.items(): - if internal not in all_aliases: - all_aliases[internal] = external + aliases_obj = _as_str_dict(getattr(cls, "_ALIASES", None)) + # Parent aliases first (will be overridden by child if same key) + for internal, external in aliases_obj.items(): + if internal not in all_aliases: + all_aliases[internal] = external if not all_aliases: return base @@ -836,17 +854,15 @@ def __init__( # Convert to objects converted_policy_actions: list[DlpActionInfo] | None = None if policy_actions is not None: - converted_policy_actions = cast( - list[DlpActionInfo], - [p if isinstance(p, DlpActionInfo) else DlpActionInfo(**p) for p in policy_actions], - ) + converted_policy_actions = [ + p if isinstance(p, DlpActionInfo) else DlpActionInfo(**p) for p in policy_actions + ] converted_processing_errors: list[ProcessingError] | None = None if processing_errors is not None: - converted_processing_errors = cast( - list[ProcessingError], - [pe if isinstance(pe, ProcessingError) else ProcessingError(**pe) for pe in processing_errors], - ) + converted_processing_errors = [ + pe if isinstance(pe, ProcessingError) else ProcessingError(**pe) for pe in processing_errors + ] super().__init__(**kwargs) self.id = id @@ -885,17 +901,15 @@ def __init__( # Convert nested objects converted_locations: list[PolicyLocation] | None = None if locations is not None: - converted_locations = cast( - list[PolicyLocation], - [loc if isinstance(loc, PolicyLocation) else 
PolicyLocation(**loc) for loc in locations], - ) + converted_locations = [ + loc if isinstance(loc, PolicyLocation) else PolicyLocation(**loc) for loc in locations + ] converted_policy_actions: list[DlpActionInfo] | None = None if policy_actions is not None: - converted_policy_actions = cast( - list[DlpActionInfo], - [p if isinstance(p, DlpActionInfo) else DlpActionInfo(**p) for p in policy_actions], - ) + converted_policy_actions = [ + p if isinstance(p, DlpActionInfo) else DlpActionInfo(**p) for p in policy_actions + ] # Call parent without explicit params with aliases super().__init__(**kwargs) @@ -947,9 +961,7 @@ def __init__( converted_scopes: list[PolicyScope] | None = None if scopes is not None: - converted_scopes = cast( - list[PolicyScope], [s if isinstance(s, PolicyScope) else PolicyScope(**s) for s in scopes] - ) + converted_scopes = [s if isinstance(s, PolicyScope) else PolicyScope(**s) for s in scopes] # Don't pass parameters that have aliases - let parent normalize them super().__init__(**kwargs) diff --git a/python/packages/purview/agent_framework_purview/_processor.py b/python/packages/purview/agent_framework_purview/_processor.py index e911fae7a5..c81785fbad 100644 --- a/python/packages/purview/agent_framework_purview/_processor.py +++ b/python/packages/purview/agent_framework_purview/_processor.py @@ -177,13 +177,14 @@ async def _map_messages( else: raise ValueError("App location not provided or inferable") + app_name = self._settings.get("app_name") or "Unknown" protected_app = ProtectedAppMetadata( - name=self._settings["app_name"], + name=app_name, version=self._settings.get("app_version", "Unknown"), application_location=policy_location, ) integrated_app = IntegratedAppMetadata( - name=self._settings["app_name"], version=self._settings.get("app_version", "Unknown") + name=app_name, version=self._settings.get("app_version", "Unknown") ) device_meta = DeviceMetadata( operating_system_specifications=OperatingSystemSpecifications( @@ -234,9 +235,9 
@@ async def _process_with_scopes(self, pc_request: ProcessContentRequest) -> Proce if cached_ps_resp is not None and isinstance(cached_ps_resp, ProtectionScopesResponse): ps_resp = cached_ps_resp else: + ttl = self._settings.get("cache_ttl_seconds") + ttl_seconds = ttl if ttl is not None else 14400 try: - ttl = self._settings.get("cache_ttl_seconds") - ttl_seconds = ttl if ttl is not None else 14400 ps_resp = await self._client.get_protection_scopes(ps_req) await self._cache.set(cache_key, ps_resp, ttl_seconds=ttl_seconds) except PurviewPaymentRequiredError as ex: diff --git a/python/packages/purview/pyproject.toml b/python/packages/purview/pyproject.toml index aed447580a..cb7819a36e 100644 --- a/python/packages/purview/pyproject.toml +++ b/python/packages/purview/pyproject.toml @@ -60,6 +60,7 @@ omit = [ [tool.pyright] extends = "../../pyproject.toml" +include = ["agent_framework_purview"] [tool.mypy] plugins = ['pydantic.mypy'] diff --git a/python/packages/redis/agent_framework_redis/_context_provider.py b/python/packages/redis/agent_framework_redis/_context_provider.py index 75886d25c3..82fad0b8b4 100644 --- a/python/packages/redis/agent_framework_redis/_context_provider.py +++ b/python/packages/redis/agent_framework_redis/_context_provider.py @@ -12,7 +12,7 @@ import sys from functools import reduce from operator import and_ -from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast +from typing import TYPE_CHECKING, Any, ClassVar, Literal import numpy as np from agent_framework import Message @@ -107,9 +107,10 @@ def __init__( self._token_escaper: TokenEscaper = TokenEscaper() self._index_initialized: bool = False self._schema_dict: dict[str, Any] | None = None - self.redis_index = redis_index or AsyncSearchIndex.from_dict( + index = redis_index or AsyncSearchIndex.from_dict( # pyright: ignore[reportUnknownMemberType] self.schema_dict, redis_url=self.redis_url, validate_on_load=True ) + self.redis_index: Any = index # -- Hooks pattern 
--------------------------------------------------------- @@ -189,7 +190,7 @@ def schema_dict(self) -> dict[str, Any]: def _build_filter_from_dict(self, filters: dict[str, str | None]) -> Any | None: """Builds a combined filter expression from simple equality tags.""" - parts = [Tag(k) == v for k, v in filters.items() if v] + parts: list[FilterExpression] = [Tag(k) == v for k, v in filters.items() if v] return reduce(and_, parts) if parts else None def _build_schema_dict( @@ -278,7 +279,9 @@ def _schema_signature(schema: dict[str, Any]) -> dict[str, Any]: sig["fields"][name] = {"type": ftype} return sig - existing_index = await AsyncSearchIndex.from_existing(self.index_name, redis_url=self.redis_url) + existing_index: Any = await AsyncSearchIndex.from_existing( # pyright: ignore[reportUnknownMemberType] + self.index_name, redis_url=self.redis_url + ) existing_schema = existing_index.schema.to_dict() current_schema = self.schema_dict existing_sig = _schema_signature(existing_schema) @@ -319,7 +322,9 @@ async def _add( if self.redis_vectorizer and self.vector_field_name: text_list = [d["content"] for d in prepared] - embeddings = await self.redis_vectorizer.aembed_many(text_list, batch_size=len(text_list)) + embeddings = await self.redis_vectorizer.aembed_many( # pyright: ignore[reportUnknownMemberType] + text_list, batch_size=len(text_list) + ) for i, d in enumerate(prepared): vec = np.asarray(embeddings[i], dtype=np.float32).tobytes() field_name: str = self.vector_field_name @@ -365,7 +370,7 @@ async def _redis_search( try: if self.redis_vectorizer and self.vector_field_name: - vector = await self.redis_vectorizer.aembed(q) + vector = await self.redis_vectorizer.aembed(q) # pyright: ignore[reportUnknownMemberType] query = HybridQuery( text=q, text_field_name="content", @@ -374,13 +379,13 @@ async def _redis_search( text_scorer=text_scorer, filter_expression=combined_filter, linear_alpha=linear_alpha, - dtype=self.redis_vectorizer.dtype, + 
dtype=self.redis_vectorizer.dtype, # pyright: ignore[reportUnknownMemberType] num_results=num_results, return_fields=return_fields, stopwords=None, ) hybrid_results = await self.redis_index.query(query) - return cast(list[dict[str, Any]], hybrid_results) + return hybrid_results # type: ignore[no-any-return] query = TextQuery( text=q, text_field_name="content", @@ -391,7 +396,7 @@ async def _redis_search( stopwords=None, ) text_results = await self.redis_index.query(query) - return cast(list[dict[str, Any]], text_results) + return text_results # type: ignore[no-any-return] except Exception as exc: # pragma: no cover raise IntegrationInvalidRequestException(f"Redis text search failed: {exc}") from exc diff --git a/python/packages/redis/agent_framework_redis/_history_provider.py b/python/packages/redis/agent_framework_redis/_history_provider.py index 7f246c885b..187af5a235 100644 --- a/python/packages/redis/agent_framework_redis/_history_provider.py +++ b/python/packages/redis/agent_framework_redis/_history_provider.py @@ -118,7 +118,7 @@ async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Mess List of stored Message objects in chronological order. 
""" key = self._redis_key(session_id) - redis_messages = await self._redis_client.lrange(key, 0, -1) # type: ignore[misc] + redis_messages: list[str] = await self._redis_client.lrange(key, 0, -1) # type: ignore[misc] messages: list[Message] = [] if redis_messages: for serialized in redis_messages: diff --git a/python/packages/redis/pyproject.toml b/python/packages/redis/pyproject.toml index 76b84ad600..c42b050115 100644 --- a/python/packages/redis/pyproject.toml +++ b/python/packages/redis/pyproject.toml @@ -63,6 +63,7 @@ omit = [ [tool.pyright] extends = "../../pyproject.toml" +include = ["agent_framework_redis"] [tool.mypy] plugins = ['pydantic.mypy'] diff --git a/python/pyproject.toml b/python/pyproject.toml index b8588b7b9d..10597ed68b 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -183,7 +183,7 @@ omit = [ ] [tool.pyright] -include = ["agent_framework*"] +exclude = ["**/tests/**", "**/.venv/**", "packages/devui/frontend/**"] typeCheckingMode = "strict" reportUnnecessaryIsInstance = false reportMissingTypeStubs = false From f076a155bb32066289dde428e17c0cff89e2b104 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Tue, 3 Mar 2026 18:17:35 +0100 Subject: [PATCH 02/15] Reduce pyright cost in handoff cloning Simplify cloned_options construction in HandoffAgentExecutor to avoid expensive TypedDict narrowing/inference in _handoff.py, which was causing pyright to spend a long time in orchestrations. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../_handoff.py | 51 +++++++------------ 1 file changed, 18 insertions(+), 33 deletions(-) diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py index e117b30aa9..a9075fe77c 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py @@ -37,7 +37,7 @@ from dataclasses import dataclass from typing import Any -from agent_framework import Agent, ChatOptions, SupportsAgentRun +from agent_framework import Agent, SupportsAgentRun from agent_framework._middleware import FunctionInvocationContext, FunctionMiddleware from agent_framework._sessions import AgentSession from agent_framework._tools import FunctionTool, tool @@ -364,43 +364,28 @@ def _clone_chat_agent(self, agent: Agent[Any]) -> Agent[Any]: metadata = options.get("metadata") # Disable parallel tool calls to prevent the agent from invoking multiple handoff tools at once. - cloned_options: ChatOptions[None] = { + cloned_options: dict[str, Any] = { "allow_multiple_tool_calls": False, # Handoff workflows already manage full conversation context explicitly # across executors. Keep provider-side conversation storage disabled to # avoid stale tool-call state (Responses API previous_response chains). 
"store": False, } - if (frequency_penalty := options.get("frequency_penalty")) is not None: - cloned_options["frequency_penalty"] = frequency_penalty - if (instructions := options.get("instructions")) is not None: - cloned_options["instructions"] = instructions - if logit_bias: - cloned_options["logit_bias"] = dict(logit_bias) - if (max_tokens := options.get("max_tokens")) is not None: - cloned_options["max_tokens"] = max_tokens - if metadata: - cloned_options["metadata"] = dict(metadata) - if (model_id := options.get("model_id")) is not None: - cloned_options["model_id"] = model_id - if (presence_penalty := options.get("presence_penalty")) is not None: - cloned_options["presence_penalty"] = presence_penalty - if (response_format := options.get("response_format")) is not None: - cloned_options["response_format"] = response_format - if (seed := options.get("seed")) is not None: - cloned_options["seed"] = seed - if (stop := options.get("stop")) is not None: - cloned_options["stop"] = stop - if (temperature := options.get("temperature")) is not None: - cloned_options["temperature"] = temperature - if (tool_choice := options.get("tool_choice")) is not None: - cloned_options["tool_choice"] = tool_choice - if all_tools: - cloned_options["tools"] = all_tools - if (top_p := options.get("top_p")) is not None: - cloned_options["top_p"] = top_p - if (user := options.get("user")) is not None: - cloned_options["user"] = user + cloned_options["frequency_penalty"] = options.get("frequency_penalty") + cloned_options["instructions"] = options.get("instructions") + cloned_options["logit_bias"] = dict(logit_bias) if logit_bias else None + cloned_options["max_tokens"] = options.get("max_tokens") + cloned_options["metadata"] = dict(metadata) if metadata else None + cloned_options["model_id"] = options.get("model_id") + cloned_options["presence_penalty"] = options.get("presence_penalty") + cloned_options["response_format"] = options.get("response_format") + cloned_options["seed"] = 
options.get("seed") + cloned_options["stop"] = options.get("stop") + cloned_options["temperature"] = options.get("temperature") + cloned_options["tool_choice"] = options.get("tool_choice") + cloned_options["tools"] = all_tools if all_tools else None + cloned_options["top_p"] = options.get("top_p") + cloned_options["user"] = options.get("user") return Agent( client=agent.client, @@ -409,7 +394,7 @@ def _clone_chat_agent(self, agent: Agent[Any]) -> Agent[Any]: description=agent.description, context_providers=agent.context_providers, middleware=middleware, - default_options=cloned_options, + default_options=cloned_options, # type: ignore[arg-type] ) def _apply_auto_tools(self, agent: Agent, targets: Sequence[HandoffConfiguration]) -> None: From 3b4b7e58cd3cc68149f88a6f5725e6b1f2f1fd77 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Tue, 3 Mar 2026 20:27:41 +0100 Subject: [PATCH 03/15] fix types --- .../openai/_assistants_client.py | 36 ++++++++----------- .../openai/_responses_client.py | 25 +++++++------ .../_history_provider.py | 4 +-- 3 files changed, 28 insertions(+), 37 deletions(-) diff --git a/python/packages/core/agent_framework/openai/_assistants_client.py b/python/packages/core/agent_framework/openai/_assistants_client.py index b90935a33f..b1d5e8795c 100644 --- a/python/packages/core/agent_framework/openai/_assistants_client.py +++ b/python/packages/core/agent_framework/openai/_assistants_client.py @@ -649,17 +649,13 @@ async def _process_stream_events(self, stream: Any, thread_id: str) -> AsyncIter and completed_annotation.file_citation.file_id ): ann["file_id"] = completed_annotation.file_citation.file_id - if ( - completed_annotation.start_index is not None - and completed_annotation.end_index is not None - ): - ann["annotated_regions"] = [ - TextSpanRegion( - type="text_span", - start_index=completed_annotation.start_index, - end_index=completed_annotation.end_index, - ) - ] + ann["annotated_regions"] = [ + TextSpanRegion( + type="text_span", + 
start_index=completed_annotation.start_index, + end_index=completed_annotation.end_index, + ) + ] text_content.annotations.append(ann) elif isinstance(completed_annotation, FilePathAnnotation): ann = Annotation( @@ -671,17 +667,13 @@ async def _process_stream_events(self, stream: Any, thread_id: str) -> AsyncIter ) if completed_annotation.file_path and completed_annotation.file_path.file_id: ann["file_id"] = completed_annotation.file_path.file_id - if ( - completed_annotation.start_index is not None - and completed_annotation.end_index is not None - ): - ann["annotated_regions"] = [ - TextSpanRegion( - type="text_span", - start_index=completed_annotation.start_index, - end_index=completed_annotation.end_index, - ) - ] + ann["annotated_regions"] = [ + TextSpanRegion( + type="text_span", + start_index=completed_annotation.start_index, + end_index=completed_annotation.end_index, + ) + ] text_content.annotations.append(ann) else: logger.debug("Unparsed annotation type: %s", completed_annotation.type) diff --git a/python/packages/core/agent_framework/openai/_responses_client.py b/python/packages/core/agent_framework/openai/_responses_client.py index b4fb1cbe1c..726616adbb 100644 --- a/python/packages/core/agent_framework/openai/_responses_client.py +++ b/python/packages/core/agent_framework/openai/_responses_client.py @@ -460,14 +460,13 @@ def _prepare_tools_for_openai( for tool_item in tools_list: if isinstance(tool_item, FunctionTool) and tool_item.kind == SHELL_TOOL_KIND_VALUE: shell_env = (tool_item.additional_properties or {}).get(OPENAI_SHELL_ENVIRONMENT_KEY) - if isinstance(shell_env, Mapping): - response_tools.append( - FunctionShellTool( - type="shell", - environment=dict(shell_env), - ) + response_tools.append( + FunctionShellTool( + type="shell", + environment=shell_env, # type: ignore[typeddict-item] ) - continue + ) + continue if isinstance(tool_item, FunctionTool): params = tool_item.parameters() params["additionalProperties"] = False @@ -496,7 +495,7 @@ 
def _get_local_shell_tool_name( if tool_item.kind != SHELL_TOOL_KIND_VALUE: continue shell_env = (tool_item.additional_properties or {}).get(OPENAI_SHELL_ENVIRONMENT_KEY) - if isinstance(shell_env, Mapping) and shell_env.get("type") == "local": + if isinstance(shell_env, Mapping) and shell_env.get("type") == "local": # type: ignore[typeddict-item] return tool_item.name return None @@ -714,7 +713,7 @@ def get_shell_tool( ) if env_config.get("type") == "local": raise ValueError("Local shell requires func. Provide func for local execution.") - return FunctionShellTool(type="shell", environment=env_config) + return FunctionShellTool(type="shell", environment=env_config) # type: ignore[typeddict-item] if isinstance(environment, dict): raise ValueError("When func is provided, environment config is not supported.") @@ -1226,7 +1225,7 @@ def _to_local_shell_output_payload(content: Content) -> str: """Convert function tool output to the local shell JSON payload format.""" payload: dict[str, Any] if isinstance(content.result, Mapping): - payload = dict(content.result) + payload = dict(content.result) # type: ignore[assignment] else: payload = { "stdout": "" if content.result is None else str(content.result), @@ -1242,7 +1241,7 @@ def _to_shell_call_output_payload(content: Content) -> list[dict[str, Any]]: """Convert function tool output to shell_call_output payload format.""" payload: dict[str, Any] if isinstance(content.result, Mapping): - payload = dict(content.result) + payload = dict(content.result) # type: ignore[assignment] else: payload = { "stdout": "" if content.result is None else str(content.result), @@ -1252,8 +1251,8 @@ def _to_shell_call_output_payload(content: Content) -> list[dict[str, Any]]: # Pass through native payload shape when tool already returns shell output entries. 
direct_output = payload.get("output") - if isinstance(direct_output, list) and all(isinstance(item, Mapping) for item in direct_output): - return [dict(item) for item in direct_output] + if isinstance(direct_output, list) and all(isinstance(item, Mapping) for item in direct_output): # type: ignore[reportUnknownMemberType] + return [dict(item) for item in direct_output] # type: ignore[reportUnknownMemberType] stdout = str(payload.get("stdout", "")) stderr = str(payload.get("stderr", "")) diff --git a/python/packages/redis/agent_framework_redis/_history_provider.py b/python/packages/redis/agent_framework_redis/_history_provider.py index 187af5a235..e1a20b6218 100644 --- a/python/packages/redis/agent_framework_redis/_history_provider.py +++ b/python/packages/redis/agent_framework_redis/_history_provider.py @@ -121,8 +121,8 @@ async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Mess redis_messages: list[str] = await self._redis_client.lrange(key, 0, -1) # type: ignore[misc] messages: list[Message] = [] if redis_messages: - for serialized in redis_messages: - messages.append(Message.from_dict(self._deserialize_json(serialized))) + for serialized in redis_messages: # type: ignore[union-attr] + messages.append(Message.from_dict(self._deserialize_json(serialized))) # type: ignore[union-attr] return messages async def save_messages(self, session_id: str | None, messages: Sequence[Message], **kwargs: Any) -> None: From b630e8d2b0c2447ca8d7dde85d4194c7fdd447a0 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Tue, 3 Mar 2026 21:02:03 +0100 Subject: [PATCH 04/15] Fix lint and type-check regressions Resolve current Python package check failures across lint, pyright, and mypy after recent code changes, including purview/declarative pyright issues and multiple ruff simplification findings. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../a2a/agent_framework_a2a/_agent.py | 2 +- .../agent_framework_anthropic/_chat_client.py | 6 +-- .../agent_framework_azure_ai/_chat_client.py | 5 +- .../_project_provider.py | 2 +- .../_history_provider.py | 25 +++------- .../samples/cosmos_history_provider.py | 4 +- .../tests/test_cosmos_history_provider.py | 12 +++-- .../agent_framework_azurefunctions/_app.py | 6 +-- .../agent_framework_bedrock/_chat_client.py | 2 +- .../_embedding_client.py | 6 +-- .../core/agent_framework/_middleware.py | 10 ++-- .../core/agent_framework/_sessions.py | 2 +- .../packages/core/agent_framework/_tools.py | 22 +++++---- .../packages/core/agent_framework/_types.py | 48 ++++++++++++------- .../core/agent_framework/observability.py | 23 ++++----- .../openai/_embedding_client.py | 4 +- .../_workflows/_executors_tools.py | 4 +- .../_workflows/_powerfx_functions.py | 11 +++-- .../devui/agent_framework_devui/__init__.py | 2 +- .../agent_framework_devui/_deployment.py | 2 +- .../devui/agent_framework_devui/_discovery.py | 8 +++- .../devui/agent_framework_devui/_executor.py | 4 +- .../devui/agent_framework_devui/_server.py | 10 ++-- .../devui/agent_framework_devui/_session.py | 22 ++++----- .../devui/agent_framework_devui/_utils.py | 1 + .../_foundry_local_client.py | 3 +- .../agent_framework_github_copilot/_agent.py | 2 +- .../lab/gaia/agent_framework_lab_gaia/gaia.py | 4 +- .../agent_framework_lab_tau2/_tau2_utils.py | 3 +- .../tau2/agent_framework_lab_tau2/runner.py | 1 + .../_handoff.py | 12 +++-- .../agent_framework_purview/_client.py | 8 +--- .../agent_framework_purview/_processor.py | 4 +- .../_context_provider.py | 6 +-- 34 files changed, 142 insertions(+), 144 deletions(-) diff --git a/python/packages/a2a/agent_framework_a2a/_agent.py b/python/packages/a2a/agent_framework_a2a/_agent.py index e6b0f49a14..fc6ebc894f 100644 --- a/python/packages/a2a/agent_framework_a2a/_agent.py +++ 
b/python/packages/a2a/agent_framework_a2a/_agent.py @@ -19,9 +19,9 @@ FileWithBytes, FileWithUri, Task, + TaskArtifactUpdateEvent, TaskIdParams, TaskQueryParams, - TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent, TextPart, diff --git a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py index cbe7c51c28..14544c071b 100644 --- a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py +++ b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py @@ -803,9 +803,9 @@ def _prepare_tools_for_anthropic(self, options: Mapping[str, Any]) -> dict[str, "allowed_tools": [str(item) for item in allowed_tools] # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType] } headers = tool_data.get("headers") - if isinstance(headers, Mapping): - if isinstance(auth := headers.get("authorization"), str): # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] - server_def["authorization_token"] = auth + authorization = headers.get("authorization") if isinstance(headers, Mapping) else None # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] + if isinstance(authorization, str): + server_def["authorization_token"] = authorization mcp_server_list.append(server_def) else: # Pass through all other tools (dicts, SDK types) unchanged diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py index ffad93eaa5..331b645ab6 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py @@ -1013,7 +1013,10 @@ async def _process_stream( logger.debug(f"Code Interpreter Input: {code_interpreter.input}") if code_interpreter.outputs is not None: for output in code_interpreter.outputs: - if isinstance(output, RunStepDeltaCodeInterpreterLogOutput) and output.logs: + if ( + 
isinstance(output, RunStepDeltaCodeInterpreterLogOutput) + and output.logs + ): code_contents.append(Content.from_text(text=output.logs)) if ( isinstance(output, RunStepDeltaCodeInterpreterImageOutput) diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_project_provider.py b/python/packages/azure-ai/agent_framework_azure_ai/_project_provider.py index 8274ab473f..335a7f16ec 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_project_provider.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_project_provider.py @@ -223,7 +223,7 @@ async def create_agent( for tool in normalized_tools: if isinstance(tool, MCPTool): mcp_tools.append(tool) - elif isinstance(tool, FunctionTool) or isinstance(tool, MutableMapping): + elif isinstance(tool, (FunctionTool, MutableMapping)): non_mcp_tools.append(tool) # type: ignore[reportUnknownArgumentType] # Connect MCP tools and discover their functions BEFORE creating the agent diff --git a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py index 3b27823332..84c1efac52 100644 --- a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py +++ b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py @@ -8,7 +8,7 @@ import time import uuid from collections.abc import Sequence -from typing import Any, ClassVar, TypeGuard, TypedDict, cast +from typing import Any, ClassVar, TypedDict, TypeGuard, cast from agent_framework import AGENT_FRAMEWORK_USER_AGENT, Message from agent_framework._sessions import BaseHistoryProvider @@ -25,11 +25,7 @@ def _is_str_key_dict(value: object) -> TypeGuard[dict[str, Any]]: return False candidate_dict = cast(dict[object, Any], value) - for key_obj in candidate_dict: - if not isinstance(key_obj, str): - return False - - return True + return all(isinstance(key_obj, str) for key_obj in candidate_dict) class 
AzureCosmosHistorySettings(TypedDict, total=False): @@ -136,7 +132,6 @@ def __init__( self._database_client = self._cosmos_client.get_database_client(self.database_name) - async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Message]: """Retrieve stored messages for this session from Azure Cosmos DB.""" await self._ensure_container_proxy() @@ -217,12 +212,8 @@ async def clear(self, session_id: str | None) -> None: async def list_sessions(self) -> list[str]: """List all session IDs stored in this provider's Cosmos container.""" await self._ensure_container_proxy() - query = ( - "SELECT DISTINCT VALUE c.session_id FROM c WHERE c.source_id = @source_id" - ) - parameters: list[dict[str, object]] = [ - {"name": "@source_id", "value": self.source_id} - ] + query = "SELECT DISTINCT VALUE c.session_id FROM c WHERE c.source_id = @source_id" + parameters: list[dict[str, object]] = [{"name": "@source_id", "value": self.source_id}] # without a partition key, it is automatically a cross-partition query items = self._container_proxy.query_items(query=query, parameters=parameters) # type: ignore[union-attr] @@ -261,11 +252,9 @@ async def _ensure_container_proxy(self) -> None: if self._database_client is None: raise RuntimeError("Cosmos database client is not initialized.") - self._container_proxy = ( - await self._database_client.create_container_if_not_exists( - id=self.container_name, - partition_key=PartitionKey(path="/session_id"), - ) + self._container_proxy = await self._database_client.create_container_if_not_exists( + id=self.container_name, + partition_key=PartitionKey(path="/session_id"), ) @staticmethod diff --git a/python/packages/azure-cosmos/samples/cosmos_history_provider.py b/python/packages/azure-cosmos/samples/cosmos_history_provider.py index ea476f9837..ff6138c1e5 100644 --- a/python/packages/azure-cosmos/samples/cosmos_history_provider.py +++ b/python/packages/azure-cosmos/samples/cosmos_history_provider.py @@ -5,10 +5,11 @@ import os 
from agent_framework.azure import AzureOpenAIResponsesClient -from agent_framework_azure_cosmos import CosmosHistoryProvider from azure.identity.aio import AzureCliCredential from dotenv import load_dotenv +from agent_framework_azure_cosmos import CosmosHistoryProvider + # Load environment variables from .env file. load_dotenv() @@ -31,7 +32,6 @@ """ - async def main() -> None: """Run the Cosmos history provider sample with an Agent.""" project_endpoint = os.getenv("AZURE_AI_PROJECT_ENDPOINT") diff --git a/python/packages/azure-cosmos/tests/test_cosmos_history_provider.py b/python/packages/azure-cosmos/tests/test_cosmos_history_provider.py index 33d7bf2414..e3ac636aa6 100644 --- a/python/packages/azure-cosmos/tests/test_cosmos_history_provider.py +++ b/python/packages/azure-cosmos/tests/test_cosmos_history_provider.py @@ -9,15 +9,16 @@ from typing import Any from unittest.mock import AsyncMock, MagicMock, patch -import agent_framework_azure_cosmos._history_provider as history_provider_module import pytest from agent_framework import AgentResponse, Message from agent_framework._sessions import AgentSession, SessionContext from agent_framework.exceptions import SettingNotFoundError -from agent_framework_azure_cosmos._history_provider import CosmosHistoryProvider from azure.cosmos.aio import CosmosClient from azure.cosmos.exceptions import CosmosResourceNotFoundError +import agent_framework_azure_cosmos._history_provider as history_provider_module +from agent_framework_azure_cosmos._history_provider import CosmosHistoryProvider + skip_if_cosmos_integration_tests_disabled = pytest.mark.skipif( any( os.getenv(name, "") == "" @@ -357,9 +358,10 @@ async def test_async_context_manager_closes_owned_client( async def test_async_context_manager_preserves_original_exception(self, mock_container: MagicMock) -> None: provider = CosmosHistoryProvider(source_id="mem", container_client=mock_container) - with patch.object( - provider, "close", 
AsyncMock(side_effect=RuntimeError("close failed")) - ), pytest.raises(ValueError, match="inner error"): + with ( + patch.object(provider, "close", AsyncMock(side_effect=RuntimeError("close failed"))), + pytest.raises(ValueError, match="inner error"), + ): async with provider: raise ValueError("inner error") diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index 3cb8d688af..401fda3ca7 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -283,11 +283,7 @@ def executor_activity(inputData: str) -> str: shared_state_raw = data.get("shared_state_snapshot", {}) source_executor_ids_raw = data.get("source_executor_ids", [SOURCE_ORCHESTRATOR]) - shared_state_snapshot: dict[str, Any] - if isinstance(shared_state_raw, dict): - shared_state_snapshot = cast(dict[str, Any], shared_state_raw) - else: - shared_state_snapshot = {} + shared_state_snapshot = cast(dict[str, Any], shared_state_raw) if isinstance(shared_state_raw, dict) else {} source_executor_ids: list[str] if isinstance(source_executor_ids_raw, list): diff --git a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py index c85a7ad836..004c5b254d 100644 --- a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py +++ b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py @@ -8,7 +8,7 @@ import sys from collections import deque from collections.abc import AsyncIterable, Awaitable, Mapping, MutableMapping, Sequence -from typing import Any, ClassVar, Generic, Literal, Protocol, TypeGuard, TypedDict, cast +from typing import Any, ClassVar, Generic, Literal, Protocol, TypedDict, TypeGuard, cast from uuid import uuid4 from agent_framework import ( diff --git 
a/python/packages/bedrock/agent_framework_bedrock/_embedding_client.py b/python/packages/bedrock/agent_framework_bedrock/_embedding_client.py index 5aac1dcc74..ac46a5c529 100644 --- a/python/packages/bedrock/agent_framework_bedrock/_embedding_client.py +++ b/python/packages/bedrock/agent_framework_bedrock/_embedding_client.py @@ -30,8 +30,6 @@ from typing_extensions import TypeVar # type: ignore # pragma: no cover - - class BedrockRuntimeMeta(Protocol): endpoint_url: str @@ -247,7 +245,9 @@ async def _generate_embedding_for_text( response_body_raw = response["body"] response_payload = response_body_raw.read() - payload_text = response_payload.decode() if isinstance(response_payload, (bytes, bytearray)) else response_payload + payload_text = ( + response_payload.decode() if isinstance(response_payload, (bytes, bytearray)) else response_payload + ) response_body_raw_map: object = json.loads(payload_text) if not isinstance(response_body_raw_map, dict): raise ValueError("Bedrock embedding response body must be a JSON object") diff --git a/python/packages/core/agent_framework/_middleware.py b/python/packages/core/agent_framework/_middleware.py index ceece1a410..7f3f3da13d 100644 --- a/python/packages/core/agent_framework/_middleware.py +++ b/python/packages/core/agent_framework/_middleware.py @@ -1045,11 +1045,10 @@ async def _execute_stream() -> ResponseStream[ChatResponseUpdate, ChatResponse]: # If result is ChatResponse (shouldn't happen for streaming), raise error raise ValueError("Expected ResponseStream for streaming, got ChatResponse") - stream_result = cast( + return cast( ResponseStream[ChatResponseUpdate, ChatResponse[Any]], cast(Any, ResponseStream).from_awaitable(_execute_stream()), ) - return stream_result # For non-streaming, return the coroutine directly return _execute() # type: ignore[return-value] @@ -1133,9 +1132,7 @@ def run( # Re-categorize self.middleware at runtime to support dynamic changes base_middleware_attr = getattr(self, "middleware", None) 
base_middleware: Sequence[MiddlewareTypes] = ( - cast(Sequence[MiddlewareTypes], base_middleware_attr) - if isinstance(base_middleware_attr, Sequence) - else [] + cast(Sequence[MiddlewareTypes], base_middleware_attr) if isinstance(base_middleware_attr, Sequence) else [] ) base_middleware_list = categorize_middleware(base_middleware) run_middleware_list = categorize_middleware(middleware) @@ -1182,11 +1179,10 @@ async def _execute_stream() -> ResponseStream[AgentResponseUpdate, AgentResponse # If result is AgentResponse (shouldn't happen for streaming), convert to stream raise ValueError("Expected ResponseStream for streaming, got AgentResponse") - stream_result = cast( + return cast( ResponseStream[AgentResponseUpdate, AgentResponse[Any]], cast(Any, ResponseStream).from_awaitable(_execute_stream()), ) - return stream_result # For non-streaming, return the coroutine directly return _execute() # type: ignore[return-value] diff --git a/python/packages/core/agent_framework/_sessions.py b/python/packages/core/agent_framework/_sessions.py index 26016f68cc..8c3457da26 100644 --- a/python/packages/core/agent_framework/_sessions.py +++ b/python/packages/core/agent_framework/_sessions.py @@ -230,7 +230,7 @@ def extend_tools(self, source_id: str, tools: Sequence[Any]) -> None: """ for tool in tools: if hasattr(tool, "additional_properties"): - additional_properties_obj = getattr(tool, "additional_properties") + additional_properties_obj = tool.additional_properties if isinstance(additional_properties_obj, dict): additional_properties = cast(dict[str, Any], additional_properties_obj) additional_properties["context_source"] = source_id diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 91e2ea1c75..4f1040baa1 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -26,8 +26,8 @@ Generic, Literal, TypeAlias, - TypeGuard, TypedDict, + TypeGuard, Union, cast, 
get_args, @@ -895,9 +895,11 @@ def _build_pydantic_model_from_json_schema( properties = properties_raw if _is_str_key_mapping(properties_raw) else None required_raw = schema.get("required", []) required_obj: object = required_raw - required: list[str] = [ - item for item in cast(list[object], required_obj) if isinstance(item, str) - ] if isinstance(required_obj, list) else [] + required: list[str] = ( + [item for item in cast(list[object], required_obj) if isinstance(item, str)] + if isinstance(required_obj, list) + else [] + ) defs_raw = schema.get("$defs", {}) definitions: Mapping[str, Any] = defs_raw if _is_str_key_mapping(defs_raw) else {} @@ -1010,9 +1012,11 @@ def _resolve_type(prop_details: Mapping[str, Any], parent_name: str = "") -> typ nested_properties = nested_properties_raw if _is_str_key_mapping(nested_properties_raw) else None nested_required_raw = prop_details.get("required", []) nested_required_obj: object = nested_required_raw - nested_required: set[str] = { - item for item in cast(list[object], nested_required_obj) if isinstance(item, str) - } if isinstance(nested_required_obj, list) else set() + nested_required: set[str] = ( + {item for item in cast(list[object], nested_required_obj) if isinstance(item, str)} + if isinstance(nested_required_obj, list) + else set() + ) if nested_properties: # Create the name for the nested model @@ -1637,9 +1641,7 @@ async def _try_execute_function_calls( declaration_only_flag = True break if ( - config.get("terminate_on_unknown_calls", False) - and fcc.type == "function_call" - and fcc.name not in tool_map # type: ignore[attr-defined] + config.get("terminate_on_unknown_calls", False) and fcc.type == "function_call" and fcc.name not in tool_map # type: ignore[attr-defined] ): raise KeyError(f'Error: Requested function "{fcc.name}" not found.') # type: ignore[attr-defined] if approval_needed: diff --git a/python/packages/core/agent_framework/_types.py b/python/packages/core/agent_framework/_types.py index 
e0797f64a2..4c61749cd6 100644 --- a/python/packages/core/agent_framework/_types.py +++ b/python/packages/core/agent_framework/_types.py @@ -3,8 +3,8 @@ from __future__ import annotations import base64 -import json import inspect +import json import logging import re import sys @@ -232,7 +232,7 @@ def _is_str_key_mapping(value: object) -> TypeGuard[Mapping[str, Any]]: if not isinstance(value, Mapping): return False mapping = cast(Mapping[object, object], value) - return all(isinstance(key, str) for key in mapping.keys()) + return all(isinstance(key, str) for key in mapping) def _validate_uri(uri: str, media_type: str | None) -> dict[str, Any]: @@ -1297,16 +1297,12 @@ def from_dict(cls: type[ContentT], data: Mapping[str, Any]) -> ContentT: input_items_obj: object = remaining.get("inputs") if isinstance(input_items_obj, list): input_items: list[Any] = list(cast(Iterable[Any], input_items_obj)) - remaining["inputs"] = [ - cls.from_dict(item) if _is_str_key_mapping(item) else item for item in input_items - ] + remaining["inputs"] = [cls.from_dict(item) if _is_str_key_mapping(item) else item for item in input_items] output_items_obj: object = remaining.get("outputs") if isinstance(output_items_obj, list): output_items: list[Any] = list(cast(Iterable[Any], output_items_obj)) - remaining["outputs"] = [ - cls.from_dict(item) if _is_str_key_mapping(item) else item for item in output_items - ] + remaining["outputs"] = [cls.from_dict(item) if _is_str_key_mapping(item) else item for item in output_items] return cls( type=content_type, @@ -1345,8 +1341,12 @@ def _add_text_content(self, other: Content) -> Content: else: self_raw_repr: object = self.raw_representation other_raw_repr: object = other.raw_representation - self_raw: list[object] = cast(list[object], self_raw_repr) if isinstance(self_raw_repr, list) else [self_raw_repr] - other_raw: list[object] = cast(list[object], other_raw_repr) if isinstance(other_raw_repr, list) else [other_raw_repr] + self_raw: list[object] = ( + 
cast(list[object], self_raw_repr) if isinstance(self_raw_repr, list) else [self_raw_repr] + ) + other_raw: list[object] = ( + cast(list[object], other_raw_repr) if isinstance(other_raw_repr, list) else [other_raw_repr] + ) raw_representation = self_raw + other_raw # Merge annotations @@ -1379,8 +1379,12 @@ def _add_text_reasoning_content(self, other: Content) -> Content: else: self_raw_repr: object = self.raw_representation other_raw_repr: object = other.raw_representation - self_raw: list[object] = cast(list[object], self_raw_repr) if isinstance(self_raw_repr, list) else [self_raw_repr] - other_raw: list[object] = cast(list[object], other_raw_repr) if isinstance(other_raw_repr, list) else [other_raw_repr] + self_raw: list[object] = ( + cast(list[object], self_raw_repr) if isinstance(self_raw_repr, list) else [self_raw_repr] + ) + other_raw: list[object] = ( + cast(list[object], other_raw_repr) if isinstance(other_raw_repr, list) else [other_raw_repr] + ) raw_representation = self_raw + other_raw # Merge annotations @@ -1441,8 +1445,12 @@ def _add_function_call_content(self, other: Content) -> Content: else: self_raw_repr: object = self.raw_representation other_raw_repr: object = other.raw_representation - self_raw: list[object] = cast(list[object], self_raw_repr) if isinstance(self_raw_repr, list) else [self_raw_repr] - other_raw: list[object] = cast(list[object], other_raw_repr) if isinstance(other_raw_repr, list) else [other_raw_repr] + self_raw: list[object] = ( + cast(list[object], self_raw_repr) if isinstance(self_raw_repr, list) else [self_raw_repr] + ) + other_raw: list[object] = ( + cast(list[object], other_raw_repr) if isinstance(other_raw_repr, list) else [other_raw_repr] + ) raw_representation = self_raw + other_raw return Content( @@ -1484,8 +1492,12 @@ def _add_usage_content(self, other: Content) -> Content: else: self_raw_repr: object = self.raw_representation other_raw_repr: object = other.raw_representation - self_raw: list[object] = 
cast(list[object], self_raw_repr) if isinstance(self_raw_repr, list) else [self_raw_repr] - other_raw: list[object] = cast(list[object], other_raw_repr) if isinstance(other_raw_repr, list) else [other_raw_repr] + self_raw: list[object] = ( + cast(list[object], self_raw_repr) if isinstance(self_raw_repr, list) else [self_raw_repr] + ) + other_raw: list[object] = ( + cast(list[object], other_raw_repr) if isinstance(other_raw_repr, list) else [other_raw_repr] + ) raw_representation = self_raw + other_raw return Content( @@ -3382,7 +3394,9 @@ def merge_chat_options( base_tools = result.get("tools") if base_tools and value: # Add tools that aren't already present - base_tool_values: list[Any] = list(cast(Iterable[Any], base_tools)) if isinstance(base_tools, list) else [base_tools] + base_tool_values: list[Any] = ( + list(cast(Iterable[Any], base_tools)) if isinstance(base_tools, list) else [base_tools] + ) merged_tools = list(base_tool_values) tool_values: list[Any] = list(cast(Iterable[Any], value)) if isinstance(value, list) else [value] for tool in tool_values: diff --git a/python/packages/core/agent_framework/observability.py b/python/packages/core/agent_framework/observability.py index 0e7cc2a8ee..6333a4a226 100644 --- a/python/packages/core/agent_framework/observability.py +++ b/python/packages/core/agent_framework/observability.py @@ -1253,9 +1253,9 @@ async def _finalize_stream() -> None: # Register a weak reference callback to close the span if stream is garbage collected # without being consumed. This ensures spans don't leak if users don't consume streams. 
- wrapped_stream: ResponseStream[ChatResponseUpdate, ChatResponse[Any]] = ( - result_stream.with_cleanup_hook(_record_duration).with_cleanup_hook(_finalize_stream) - ) + wrapped_stream: ResponseStream[ChatResponseUpdate, ChatResponse[Any]] = result_stream.with_cleanup_hook( + _record_duration + ).with_cleanup_hook(_finalize_stream) weakref.finalize(wrapped_stream, _close_span) return wrapped_stream @@ -1531,9 +1531,9 @@ async def _finalize_stream() -> None: # Register a weak reference callback to close the span if stream is garbage collected # without being consumed. This ensures spans don't leak if users don't consume streams. - wrapped_stream: ResponseStream[AgentResponseUpdate, AgentResponse[Any]] = ( - result_stream.with_cleanup_hook(_record_duration).with_cleanup_hook(_finalize_stream) - ) + wrapped_stream: ResponseStream[AgentResponseUpdate, AgentResponse[Any]] = result_stream.with_cleanup_hook( + _record_duration + ).with_cleanup_hook(_finalize_stream) weakref.finalize(wrapped_stream, _close_span) return wrapped_stream @@ -1647,9 +1647,8 @@ def _get_instructions_from_options(options: Any) -> str | list[str] | None: instructions = cast(Mapping[str, Any], options).get("instructions") if isinstance(instructions, str): return instructions - if isinstance(instructions, list): - if all(isinstance(item, str) for item in instructions): # pyright: ignore[reportUnknownVariableType] - return cast("list[str]", instructions) + if isinstance(instructions, list) and all(isinstance(item, str) for item in cast(list[object], instructions)): + return cast("list[str]", instructions) return None return None @@ -1709,11 +1708,7 @@ def _get_span_attributes(**kwargs: Any) -> dict[str, Any]: """Get the span attributes from a kwargs dictionary.""" attributes: dict[str, Any] = {} options = kwargs.get("all_options", kwargs.get("options")) - options_mapping: Mapping[str, Any] | None - if isinstance(options, Mapping): - options_mapping = cast(Mapping[str, Any], options) - else: - 
options_mapping = None + options_mapping = cast(Mapping[str, Any], options) if isinstance(options, Mapping) else None for source_keys, (otel_key, transform_func, check_options, default_value) in OTEL_ATTR_MAP.items(): # Normalize to tuple of keys diff --git a/python/packages/core/agent_framework/openai/_embedding_client.py b/python/packages/core/agent_framework/openai/_embedding_client.py index e730bf62d3..0efea66ae5 100644 --- a/python/packages/core/agent_framework/openai/_embedding_client.py +++ b/python/packages/core/agent_framework/openai/_embedding_client.py @@ -123,7 +123,9 @@ async def get_embeddings( "total_token_count": response.usage.total_tokens, } - return cast(GeneratedEmbeddings[list[float]], GeneratedEmbeddings(embeddings, options=options, usage=usage_dict)) + return cast( + GeneratedEmbeddings[list[float]], GeneratedEmbeddings(embeddings, options=options, usage=usage_dict) + ) class OpenAIEmbeddingClient( diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_tools.py b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_tools.py index 6ef171fce5..934717e9ec 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_tools.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_tools.py @@ -15,9 +15,9 @@ import logging import uuid from abc import abstractmethod +from collections.abc import Mapping from dataclasses import dataclass, field from inspect import isawaitable -from collections.abc import Mapping from typing import Any, cast from agent_framework import ( @@ -112,8 +112,6 @@ class ToolApprovalState: # ============================================================================ - - def _empty_messages() -> list[Message]: return [] diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_powerfx_functions.py b/python/packages/declarative/agent_framework_declarative/_workflows/_powerfx_functions.py 
index 499e577c96..04374d06a9 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_powerfx_functions.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_powerfx_functions.py @@ -44,8 +44,9 @@ def message_text(messages: Any) -> str: content: Any = messages_dict.get("content", "") if isinstance(content, str): return content - if hasattr(content, "text"): - return str(content.text) + text_attr = getattr(content, "text", None) + if text_attr is not None: + return str(text_attr) return str(content) if content else "" if isinstance(messages, list): @@ -65,11 +66,11 @@ def message_text(messages: Any) -> str: else: msg_obj: object = msg if hasattr(msg_obj, "content"): - msg_obj_content: Any = getattr(msg_obj, "content") + msg_obj_content: Any = getattr(msg_obj, "content", None) if isinstance(msg_obj_content, str): texts.append(msg_obj_content) - elif hasattr(msg_obj_content, "text"): - texts.append(str(getattr(msg_obj_content, "text"))) + elif (msg_obj_text := getattr(msg_obj_content, "text", None)) is not None: + texts.append(str(msg_obj_text)) elif msg_obj_content: texts.append(str(msg_obj_content)) return " ".join(texts) diff --git a/python/packages/devui/agent_framework_devui/__init__.py b/python/packages/devui/agent_framework_devui/__init__.py index 4b1130506e..a6dea87b90 100644 --- a/python/packages/devui/agent_framework_devui/__init__.py +++ b/python/packages/devui/agent_framework_devui/__init__.py @@ -265,8 +265,8 @@ def main() -> None: "OpenAIError", "OpenAIResponse", "ResponseStreamEvent", + "get_registered_cleanup_hooks", "main", "register_cleanup", - "get_registered_cleanup_hooks", "serve", ] diff --git a/python/packages/devui/agent_framework_devui/_deployment.py b/python/packages/devui/agent_framework_devui/_deployment.py index 70d785b0bb..34147db1f9 100644 --- a/python/packages/devui/agent_framework_devui/_deployment.py +++ b/python/packages/devui/agent_framework_devui/_deployment.py @@ -7,10 +7,10 @@ import 
re import secrets import uuid -from typing import cast from collections.abc import AsyncGenerator from datetime import datetime, timezone from pathlib import Path +from typing import cast from urllib.parse import urlparse from .models._discovery_models import Deployment, DeploymentConfig, DeploymentEvent diff --git a/python/packages/devui/agent_framework_devui/_discovery.py b/python/packages/devui/agent_framework_devui/_discovery.py index 5aad165571..37ab53044d 100644 --- a/python/packages/devui/agent_framework_devui/_discovery.py +++ b/python/packages/devui/agent_framework_devui/_discovery.py @@ -885,7 +885,11 @@ async def _extract_tools_from_object(self, obj: Any, obj_type: str) -> list[str] chat_options_tools = chat_options_dict.get("tools") if chat_options_tools is not None: - tool_iterable: list[object] = cast(list[object], chat_options_tools) if isinstance(chat_options_tools, list) else [chat_options_tools] + tool_iterable: list[object] = ( + cast(list[object], chat_options_tools) + if isinstance(chat_options_tools, list) + else [chat_options_tools] + ) for tool_obj in tool_iterable: tool_name = getattr(tool_obj, "__name__", None) if isinstance(tool_name, str): @@ -925,7 +929,7 @@ async def _extract_tools_from_object(self, obj: Any, obj_type: str) -> list[str] tools.append(str(getattr(executor_obj, "id", executor_obj))) elif isinstance(executors, dict): executors_dict = cast(dict[str, Any], executors) - for key_obj in executors_dict.keys(): + for key_obj in executors_dict: tools.append(str(key_obj)) except Exception as e: diff --git a/python/packages/devui/agent_framework_devui/_executor.py b/python/packages/devui/agent_framework_devui/_executor.py index 516ef6d8de..3f732dd80c 100644 --- a/python/packages/devui/agent_framework_devui/_executor.py +++ b/python/packages/devui/agent_framework_devui/_executor.py @@ -677,7 +677,9 @@ def _convert_openai_input_to_chat_message(self, input_items: list[Any], Message: # Use file_data or file_url # Include filename in 
additional_properties for OpenAI/Azure file handling - additional_props: dict[str, Any] | None = {"filename": filename} if filename else None + additional_props: dict[str, Any] | None = ( + {"filename": filename} if filename else None + ) if isinstance(file_data, str) and file_data: # Assume file_data is base64, create data URI data_uri = f"data:{media_type};base64,{file_data}" diff --git a/python/packages/devui/agent_framework_devui/_server.py b/python/packages/devui/agent_framework_devui/_server.py index f8d862b731..ff26937843 100644 --- a/python/packages/devui/agent_framework_devui/_server.py +++ b/python/packages/devui/agent_framework_devui/_server.py @@ -53,6 +53,7 @@ def _extract_error_details(body: object) -> tuple[str | None, str | None, str | code if isinstance(code, str) else None, ) + # Get package version try: __version__ = importlib.metadata.version("agent-framework-devui") @@ -1120,7 +1121,7 @@ async def delete_conversation_item(conversation_id: str, item_id: str) -> dict[s # Checkpoints are exposed as conversation items with type="checkpoint" # ============================================================================ - _registered_route_handlers = ( + registered_route_handlers = ( health_check, get_meta, discover_entities, @@ -1143,7 +1144,7 @@ async def delete_conversation_item(conversation_id: str, item_id: str) -> dict[s retrieve_conversation_item, delete_conversation_item, ) - _ = _registered_route_handlers + _ = registered_route_handlers async def _stream_execution( self, executor: AgentFrameworkExecutor, request: AgentFrameworkRequest @@ -1165,9 +1166,8 @@ async def _stream_execution( if conversation_id and hasattr(event, "type") and event.type == "response.trace.completed": try: trace_data = event.data if hasattr(event, "data") else None - if trace_data: - if isinstance(conversation_id, str): - executor.conversation_store.add_trace(conversation_id, trace_data) + if trace_data and isinstance(conversation_id, str): + 
executor.conversation_store.add_trace(conversation_id, trace_data) except Exception as e: logger.debug(f"Failed to store trace event: {e}") diff --git a/python/packages/devui/agent_framework_devui/_session.py b/python/packages/devui/agent_framework_devui/_session.py index 0f53e09e3d..93ac9b31e4 100644 --- a/python/packages/devui/agent_framework_devui/_session.py +++ b/python/packages/devui/agent_framework_devui/_session.py @@ -182,18 +182,16 @@ def get_active_sessions(self) -> list[SessionSummary]: for session_id, session in self.sessions.items(): if session["active"]: - active_sessions.append( - { - "session_id": session_id, - "created_at": session["created_at"].isoformat(), - "request_count": len(session["requests"]), - "last_activity": ( - session["requests"][-1]["timestamp"].isoformat() - if session["requests"] - else session["created_at"].isoformat() - ), - } - ) + active_sessions.append({ + "session_id": session_id, + "created_at": session["created_at"].isoformat(), + "request_count": len(session["requests"]), + "last_activity": ( + session["requests"][-1]["timestamp"].isoformat() + if session["requests"] + else session["created_at"].isoformat() + ), + }) return active_sessions diff --git a/python/packages/devui/agent_framework_devui/_utils.py b/python/packages/devui/agent_framework_devui/_utils.py index bdf76f1ec5..53e4b8416c 100644 --- a/python/packages/devui/agent_framework_devui/_utils.py +++ b/python/packages/devui/agent_framework_devui/_utils.py @@ -21,6 +21,7 @@ def _string_key_dict(value: object) -> dict[str, Any] | None: source: dict[str, Any] = cast(dict[str, Any], value) return {str(k): v for k, v in source.items()} + # ============================================================================ # Agent Metadata Extraction # ============================================================================ diff --git a/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py 
b/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py index 4709307299..16451ae85a 100644 --- a/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py +++ b/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py @@ -260,8 +260,7 @@ class MyOptions(FoundryLocalChatOptions, total=False): f"Model with ID or alias '{model_id_setting}:{device.value}' not found in Foundry Local." if device else ( - f"Model with ID or alias '{model_id_setting}' for your current device " - "not found in Foundry Local." + f"Model with ID or alias '{model_id_setting}' for your current device not found in Foundry Local." ) ) raise ValueError(message) diff --git a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py index 86ecf737f8..1c30af36dc 100644 --- a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py +++ b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py @@ -30,9 +30,9 @@ from copilot.types import ( CopilotClientOptions, MCPServerConfig, + MessageOptions, PermissionRequest, PermissionRequestResult, - MessageOptions, ResumeSessionConfig, SessionConfig, SystemMessageConfig, diff --git a/python/packages/lab/gaia/agent_framework_lab_gaia/gaia.py b/python/packages/lab/gaia/agent_framework_lab_gaia/gaia.py index ec7a90640d..cba407ded3 100644 --- a/python/packages/lab/gaia/agent_framework_lab_gaia/gaia.py +++ b/python/packages/lab/gaia/agent_framework_lab_gaia/gaia.py @@ -287,7 +287,9 @@ def _load_gaia_local(repo_dir: Path, wanted_levels: list[int] | None = None, max if wanted_levels and (lvl not in wanted_levels): continue - tasks.append(Task(task_id=qid, question=q, answer=str(ans), level=lvl, file_name=fname, metadata=record)) + tasks.append( + Task(task_id=qid, question=q, answer=str(ans), level=lvl, file_name=fname, metadata=record) + ) except ImportError: 
print("Warning: pyarrow not installed. Install with: pip install pyarrow") continue diff --git a/python/packages/lab/tau2/agent_framework_lab_tau2/_tau2_utils.py b/python/packages/lab/tau2/agent_framework_lab_tau2/_tau2_utils.py index 9847a460bc..5b1390c3dc 100644 --- a/python/packages/lab/tau2/agent_framework_lab_tau2/_tau2_utils.py +++ b/python/packages/lab/tau2/agent_framework_lab_tau2/_tau2_utils.py @@ -35,8 +35,6 @@ def _to_str(value: object, default: str = "") -> str: return str(value) - - def _is_any_list(value: Any) -> TypeGuard[list[Any]]: return isinstance(value, list) @@ -48,6 +46,7 @@ def _is_any_mapping(value: Any) -> TypeGuard[Mapping[Any, Any]]: def _is_any_sequence(value: Any) -> TypeGuard[list[Any] | tuple[Any, ...] | set[Any]]: return isinstance(value, (list, tuple, set)) + def convert_tau2_tool_to_function_tool(tau2_tool: Tool) -> FunctionTool: """Convert a tau2 Tool to a FunctionTool for agent framework compatibility. diff --git a/python/packages/lab/tau2/agent_framework_lab_tau2/runner.py b/python/packages/lab/tau2/agent_framework_lab_tau2/runner.py index 2a29e5b544..8d4aee310f 100644 --- a/python/packages/lab/tau2/agent_framework_lab_tau2/runner.py +++ b/python/packages/lab/tau2/agent_framework_lab_tau2/runner.py @@ -47,6 +47,7 @@ def _get_openai_schema(tool: Any) -> dict[str, Any]: return cast(dict[str, Any], schema_dict) raise TypeError(f"Tool {tool} does not expose a dict openai_schema") + # Agent instructions matching tau2's LLMAgent ASSISTANT_AGENT_INSTRUCTION = """ You are a customer service agent that helps the user according to the provided below. 
diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py index a9075fe77c..ddeedfea36 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py @@ -51,6 +51,7 @@ from agent_framework._workflows._workflow import Workflow from agent_framework._workflows._workflow_builder import WorkflowBuilder from agent_framework._workflows._workflow_context import WorkflowContext + from ._base_group_chat_orchestrator import TerminationCondition from ._orchestrator_helpers import clean_conversation_for_handoff @@ -251,7 +252,6 @@ def _prepare_agent_with_handoffs( Returns: A cloned ``Agent`` instance with handoff tools added """ - # Clone the agent to avoid mutating the original cloned_agent = self._clone_chat_agent(agent) # Add handoff tools to the cloned agent @@ -444,9 +444,7 @@ def _handoff_tool() -> None: return _handoff_tool @override - async def _run_agent_and_emit( - self, ctx: WorkflowContext[Any, Any] - ) -> None: + async def _run_agent_and_emit(self, ctx: WorkflowContext[Any, Any]) -> None: """Override to support handoff.""" incoming_messages = list(self._cache) cleaned_incoming_messages = clean_conversation_for_handoff(incoming_messages) @@ -481,7 +479,11 @@ async def _run_agent_and_emit( # Handoff workflows are orchestrator-stateful and provider-stateless by design. # If an existing session still has a service conversation id, clear it to avoid # replaying stale unresolved tool calls across resumed turns. 
- if is_chat_agent(self._agent) and self._agent.default_options.get("store") is False and self._session.service_session_id is not None: + if ( + is_chat_agent(self._agent) + and self._agent.default_options.get("store") is False + and self._session.service_session_id is not None + ): self._session.service_session_id = None # Check termination condition before running the agent diff --git a/python/packages/purview/agent_framework_purview/_client.py b/python/packages/purview/agent_framework_purview/_client.py index 7dedf465ee..e592f34da5 100644 --- a/python/packages/purview/agent_framework_purview/_client.py +++ b/python/packages/purview/agent_framework_purview/_client.py @@ -142,7 +142,6 @@ async def send_content_activities(self, request: ContentActivitiesRequest) -> Co url = f"{self._graph_uri}/users/{request.user_id}/dataSecurityAndGovernance/activities/contentActivities" return await self._post(url, request, ContentActivitiesResponse, token) - @overload async def _post( self, @@ -215,15 +214,12 @@ async def _post( try: # Prefer pydantic-style model_validate if present, else fall back to constructor. 
model_validate = getattr(response_type, "model_validate", None) - if callable(model_validate): - response_obj = model_validate(data) - else: - response_obj = response_type(**data) # type: ignore[call-arg] + response_obj = model_validate(data) if callable(model_validate) else response_type(**data) # type: ignore[call-arg] # Extract correlation_id from response headers if response object supports it if "client-request-id" in resp.headers and hasattr(response_obj, "correlation_id"): response_correlation_id = resp.headers["client-request-id"] - setattr(response_obj, "correlation_id", response_correlation_id) + response_obj.correlation_id = response_correlation_id # pyright: ignore[reportAttributeAccessIssue] logger.info(f"Purview response from {url} with correlation_id: {response_correlation_id}") typed_response_obj = response_obj if isinstance(response_obj, response_type) else response_type(**data) diff --git a/python/packages/purview/agent_framework_purview/_processor.py b/python/packages/purview/agent_framework_purview/_processor.py index c81785fbad..241de80d61 100644 --- a/python/packages/purview/agent_framework_purview/_processor.py +++ b/python/packages/purview/agent_framework_purview/_processor.py @@ -183,9 +183,7 @@ async def _map_messages( version=self._settings.get("app_version", "Unknown"), application_location=policy_location, ) - integrated_app = IntegratedAppMetadata( - name=app_name, version=self._settings.get("app_version", "Unknown") - ) + integrated_app = IntegratedAppMetadata(name=app_name, version=self._settings.get("app_version", "Unknown")) device_meta = DeviceMetadata( operating_system_specifications=OperatingSystemSpecifications( operating_system_platform="Unknown", operating_system_version="Unknown" diff --git a/python/packages/redis/agent_framework_redis/_context_provider.py b/python/packages/redis/agent_framework_redis/_context_provider.py index 82fad0b8b4..32b6a6cc5d 100644 --- 
a/python/packages/redis/agent_framework_redis/_context_provider.py +++ b/python/packages/redis/agent_framework_redis/_context_provider.py @@ -384,8 +384,7 @@ async def _redis_search( return_fields=return_fields, stopwords=None, ) - hybrid_results = await self.redis_index.query(query) - return hybrid_results # type: ignore[no-any-return] + return await self.redis_index.query(query) # type: ignore[no-any-return] query = TextQuery( text=q, text_field_name="content", @@ -395,8 +394,7 @@ async def _redis_search( return_fields=return_fields, stopwords=None, ) - text_results = await self.redis_index.query(query) - return text_results # type: ignore[no-any-return] + return await self.redis_index.query(query) # type: ignore[no-any-return] except Exception as exc: # pragma: no cover raise IntegrationInvalidRequestException(f"Redis text search failed: {exc}") from exc From d89fc0439cf6ac8375001d8ad437f3ae01456e64 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Tue, 3 Mar 2026 21:13:07 +0100 Subject: [PATCH 05/15] fixed hooks --- .../_serialization.py | 3 +- python/uv.lock | 28 +++++++++---------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py index ad62ba7a06..b353549dcb 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py @@ -20,8 +20,9 @@ import importlib import logging +from collections.abc import Callable from dataclasses import is_dataclass -from typing import Any, Callable, cast +from typing import Any, cast from agent_framework._workflows._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value diff --git a/python/uv.lock b/python/uv.lock index 28877c91d2..7233077c30 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -525,7 +525,7 @@ source = { 
editable = "packages/github_copilot" } dependencies = [ { name = "agent-framework-core", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "github-copilot-sdk", version = "0.1.25", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.11' and sys_platform == 'darwin') or (python_full_version < '3.11' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform == 'win32')" }, - { name = "github-copilot-sdk", version = "0.1.29", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and sys_platform == 'darwin') or (python_full_version >= '3.11' and sys_platform == 'linux') or (python_full_version >= '3.11' and sys_platform == 'win32')" }, + { name = "github-copilot-sdk", version = "0.1.30", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.11' and sys_platform == 'darwin') or (python_full_version >= '3.11' and sys_platform == 'linux') or (python_full_version >= '3.11' and sys_platform == 'win32')" }, ] [package.metadata] @@ -1377,19 +1377,19 @@ wheels = [ [[package]] name = "claude-agent-sdk" -version = "0.1.44" +version = "0.1.45" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "mcp", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "typing-extensions", marker = "(python_full_version < '3.11' and sys_platform == 'darwin') or (python_full_version < '3.11' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform == 'win32')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/30/40/5661e10daf69ee5c864f82a1888cc33c9378b2d7f7d11db3c2360aef3a30/claude_agent_sdk-0.1.44.tar.gz", hash = "sha256:8629436e7af367a1cbc81aa2a58a93aa68b8b2e4e14b0c5be5ac3627bd462c1b", size = 62439, 
upload-time = "2026-02-26T01:17:28.118Z" } +sdist = { url = "https://files.pythonhosted.org/packages/46/e2/c5d5c4743ece496492a930bb75b878c830a9a9878ae3327b2d292647a8fa/claude_agent_sdk-0.1.45.tar.gz", hash = "sha256:97c1e981431b5af1e08c34731906ab8d4a58fe0774a04df0ea9587dcabc85151", size = 62436, upload-time = "2026-03-03T17:21:08.595Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e9/1a/dcde83a6477bfdf8c5510fd84006cca763296e6bc5576e90cd89b97ec034/claude_agent_sdk-0.1.44-py3-none-macosx_11_0_arm64.whl", hash = "sha256:1dd976ad3efb673aefd5037dc75ee7926fb5033c4b9ab7382897ab647fed74e6", size = 55828889, upload-time = "2026-02-26T01:17:15.474Z" }, - { url = "https://files.pythonhosted.org/packages/4b/33/3b161256956968e18c81e2b2650fed7d2a1144d51042ed6317848643e5d7/claude_agent_sdk-0.1.44-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:d35b38ca40fa28f50fa88705599a298ab30c121c56b53655025eeceb463ac399", size = 70795212, upload-time = "2026-02-26T01:17:18.873Z" }, - { url = "https://files.pythonhosted.org/packages/17/cb/67af9796dad77a94dfe851138f5ffc9e2e0a14407ba55fea07462c1cc8e5/claude_agent_sdk-0.1.44-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:853c15501f71a913a6cc6b40dc0b24b9505166cad164206b8eab229889e670b8", size = 71424685, upload-time = "2026-02-26T01:17:22.345Z" }, - { url = "https://files.pythonhosted.org/packages/46/cd/2d3806c791250a76de2c1be863fc01d420729ad61496253e3d3033464c72/claude_agent_sdk-0.1.44-py3-none-win_amd64.whl", hash = "sha256:597e2fcad372086f93e4f6a380d3088ec4dd9b9efce309c5281b52a256fd5d25", size = 73493771, upload-time = "2026-02-26T01:17:25.837Z" }, + { url = "https://files.pythonhosted.org/packages/20/29/a28b6dfac54dfceddaa47e16c2b9cb61cc2ace4b4a1de064ab6d76debcbd/claude_agent_sdk-0.1.45-py3-none-macosx_11_0_arm64.whl", hash = "sha256:26a5cc60c3a394f5b814f6b2f67650819cbcd38c405bbdc11582b3e097b3a770", size = 57761380, upload-time = "2026-03-03T17:20:55.066Z" }, + { url = 
"https://files.pythonhosted.org/packages/aa/7c/a803cc6e40de8b13cc822c66fd96c96d88f994983c2622d80cb8b708bb30/claude_agent_sdk-0.1.45-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:decc741b53e0b2c10a64fd84c15acca1102077d9f99941c54905172cd95160c9", size = 73402101, upload-time = "2026-03-03T17:20:58.604Z" }, + { url = "https://files.pythonhosted.org/packages/32/51/bdb9832728189673c60c605854c2153e17dce384a64a6dc88cdbb254ce86/claude_agent_sdk-0.1.45-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:7d48dcf4178c704e4ccbf3f1f4ebf20b3de3f03d0592086c1f3abd16b8ca441e", size = 74091498, upload-time = "2026-03-03T17:21:02.332Z" }, + { url = "https://files.pythonhosted.org/packages/13/37/02e60d7f93aedc8f63f9404cbf2a48bf5d47c27ccb9c0a0f03c803882fa5/claude_agent_sdk-0.1.45-py3-none-win_amd64.whl", hash = "sha256:d1cf34995109c513d8daabcae7208edc260b553b53462a9ac06a7c40e240a288", size = 75784070, upload-time = "2026-03-03T17:21:05.573Z" }, ] [[package]] @@ -2301,7 +2301,7 @@ wheels = [ [[package]] name = "github-copilot-sdk" -version = "0.1.29" +version = "0.1.30" source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version >= '3.14' and sys_platform == 'darwin'", @@ -2322,12 +2322,12 @@ dependencies = [ { name = "python-dateutil", marker = "(python_full_version >= '3.11' and sys_platform == 'darwin') or (python_full_version >= '3.11' and sys_platform == 'linux') or (python_full_version >= '3.11' and sys_platform == 'win32')" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/11/8e/2155e40594a60084266d33cefd2333fe3ce44e7189773e6eff9943e25d81/github_copilot_sdk-0.1.29-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:0215045cf6ec2cebfc6dbb0e257e2116d4aa05751f80cc48d5f3c8c658933094", size = 58182462, upload-time = "2026-02-27T22:09:59.687Z" }, - { url = "https://files.pythonhosted.org/packages/55/6a/9fa577564702eb1eb143c16afcdadf7d6305da53fbbd05a0925035808d9e/github_copilot_sdk-0.1.29-py3-none-macosx_11_0_arm64.whl", hash = 
"sha256:441c917aad8501da5264026b0da5c0e834571256e812617437654ab16bdad77f", size = 54934772, upload-time = "2026-02-27T22:10:02.911Z" }, - { url = "https://files.pythonhosted.org/packages/69/77/0e0fd6f6a0177d93f5f3e5d0e9ed5044fc53c54e58e65bbc6b08eb789350/github_copilot_sdk-0.1.29-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:88230b779dee1695fc44043060006224138c5b5d6724890f7ecdc378ff0d8f73", size = 61071028, upload-time = "2026-02-27T22:10:06.332Z" }, - { url = "https://files.pythonhosted.org/packages/94/f5/9a73bd6e34db4d0ce546b04725cfad1c9fa58426265876b640376381b623/github_copilot_sdk-0.1.29-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:2019bbbaea39d8db54250d11431d89952dd0ad0a16b58159b6b018ea625c78c9", size = 59251702, upload-time = "2026-02-27T22:10:09.466Z" }, - { url = "https://files.pythonhosted.org/packages/ea/32/60713b1ae3ed80b62113f993bd2f4552d2b03753cfea37f90086ac8e6d6e/github_copilot_sdk-0.1.29-py3-none-win_amd64.whl", hash = "sha256:a326fe5ab6ecd7cef5de39d5a5fe18e09e629eb29b401be23a709e83fc578578", size = 53690857, upload-time = "2026-02-27T22:10:12.778Z" }, - { url = "https://files.pythonhosted.org/packages/58/31/d082f4ac13cf3e4ba3a7846b8468521d6d38967de3788a61b6001707fbb5/github_copilot_sdk-0.1.29-py3-none-win_arm64.whl", hash = "sha256:1ace40f23ab8d8c97f8d61d31d01946ade9c83ea7982671864ec5aef0cd7dd01", size = 51699152, upload-time = "2026-02-27T22:10:15.791Z" }, + { url = "https://files.pythonhosted.org/packages/18/37/92b8037c0673999ac1c49e9d079cf6d36283e6ee3453d66b54878da81bc8/github_copilot_sdk-0.1.30-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:47e95246a63beeebf192db6013662c5f39778ccfa6b1b718b79cbec6b6a88bf8", size = 58182964, upload-time = "2026-03-03T17:21:53.564Z" }, + { url = "https://files.pythonhosted.org/packages/08/79/9d0628fa819df73e92ebbd4af949cdd82850cc4bde79b3e78040fcd8ed80/github_copilot_sdk-0.1.30-py3-none-macosx_11_0_arm64.whl", hash = "sha256:601cbe1c5a576906b73cbf8591429451c91148bff5a564e56e1e83ff99b2dc10", size = 
54935274, upload-time = "2026-03-03T17:21:57.494Z" }, + { url = "https://files.pythonhosted.org/packages/10/5d/f407e9c9155f912780b4587ab74abf3b94fae91af0463bad317cc8aacdfe/github_copilot_sdk-0.1.30-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:735fb90683bea27a418a0d45df430492db2a395e5ae88d575ac138be49d6cf07", size = 61071530, upload-time = "2026-03-03T17:22:01.601Z" }, + { url = "https://files.pythonhosted.org/packages/b8/9f/5c2ab2baf5f185150058c774da2b5e4c613b4532c48b499ce127419da461/github_copilot_sdk-0.1.30-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:21ade06dfe5ca111663c42fff000ab3ec6595e51b1cf4ab56ff550cdd7a2992f", size = 59252204, upload-time = "2026-03-03T17:22:05.706Z" }, + { url = "https://files.pythonhosted.org/packages/ef/80/4e72ccdc8868250ba8c5d48a1fef5a8244361c2a586820de9b77df0c79ed/github_copilot_sdk-0.1.30-py3-none-win_amd64.whl", hash = "sha256:f1be9e49da2af370a914d4425bfecbc2daecf8e5de0074beaa1e22735bdd5da6", size = 53691358, upload-time = "2026-03-03T17:22:09.474Z" }, + { url = "https://files.pythonhosted.org/packages/53/4f/25ff085d0d5d50d1197fd6ae9a53adc4cc8298940212f5a69f7ced68c33e/github_copilot_sdk-0.1.30-py3-none-win_arm64.whl", hash = "sha256:3e0691eb3030c385f629d63d74ded938e0577fcd98f452259efd5d7fb2283576", size = 51699653, upload-time = "2026-03-03T17:22:13.215Z" }, ] [[package]] From d0d4b4b4d6c4d04ff7901ffc10ee2782de01d7ca Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 4 Mar 2026 12:23:53 +0100 Subject: [PATCH 06/15] Stabilize package tests and test tasks Resolve cross-package non-integration test failures, simplify streaming type flow, harden locale/culture handling, and standardize package test poe tasks to exclude integration tests where applicable. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../a2a/agent_framework_a2a/_agent.py | 4 +- python/packages/a2a/pyproject.toml | 2 +- python/packages/ag-ui/pyproject.toml | 2 +- python/packages/anthropic/pyproject.toml | 2 +- .../packages/azure-ai-search/pyproject.toml | 2 +- python/packages/azure-ai/pyproject.toml | 2 +- python/packages/azure-cosmos/pyproject.toml | 2 +- .../_context.py | 6 +- python/packages/azurefunctions/pyproject.toml | 2 +- python/packages/bedrock/pyproject.toml | 4 +- python/packages/chatkit/pyproject.toml | 2 +- python/packages/claude/pyproject.toml | 2 +- python/packages/copilotstudio/pyproject.toml | 2 +- .../packages/core/agent_framework/_tools.py | 54 +++----- .../packages/core/agent_framework/_types.py | 16 +-- .../_workflows/_runner_context.py | 10 +- .../core/agent_framework/observability.py | 26 ++-- python/packages/core/pyproject.toml | 2 +- .../openai/test_openai_embedding_client.py | 9 +- .../tests/workflow/test_agent_executor.py | 51 ++----- .../core/tests/workflow/test_agent_utils.py | 27 +++- .../packages/core/tests/workflow/test_edge.py | 3 +- .../core/tests/workflow/test_executor.py | 127 +++++------------- .../tests/workflow/test_workflow_agent.py | 36 ++++- .../tests/workflow/test_workflow_kwargs.py | 90 +++++++++++-- .../tests/workflow/test_workflow_states.py | 8 +- .../_workflows/_declarative_base.py | 39 ++++-- python/packages/declarative/pyproject.toml | 2 +- .../tests/test_powerfx_yaml_compatibility.py | 32 +++-- .../devui/agent_framework_devui/_utils.py | 5 +- python/packages/devui/pyproject.toml | 2 +- python/packages/durabletask/pyproject.toml | 4 +- python/packages/foundry_local/pyproject.toml | 2 +- python/packages/github_copilot/pyproject.toml | 2 +- python/packages/lab/pyproject.toml | 8 +- python/packages/mem0/pyproject.toml | 2 +- python/packages/ollama/pyproject.toml | 2 +- .../_handoff.py | 41 ++---- python/packages/orchestrations/pyproject.toml | 2 +- 
python/packages/purview/pyproject.toml | 2 +- python/packages/redis/pyproject.toml | 2 +- 41 files changed, 326 insertions(+), 314 deletions(-) diff --git a/python/packages/a2a/agent_framework_a2a/_agent.py b/python/packages/a2a/agent_framework_a2a/_agent.py index fc6ebc894f..31fac386b3 100644 --- a/python/packages/a2a/agent_framework_a2a/_agent.py +++ b/python/packages/a2a/agent_framework_a2a/_agent.py @@ -307,10 +307,12 @@ async def _map_a2a_stream( response_id=str(getattr(item, "message_id", uuid.uuid4())), raw_representation=item, ) - else: + elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], Task): task, _update_event = item for update in self._updates_from_task(task, background=background): yield update + else: + raise NotImplementedError("Only Message and Task responses are supported") # ------------------------------------------------------------------ # Task helpers diff --git a/python/packages/a2a/pyproject.toml b/python/packages/a2a/pyproject.toml index 43e0df726b..b7bfdb9275 100644 --- a/python/packages/a2a/pyproject.toml +++ b/python/packages/a2a/pyproject.toml @@ -87,7 +87,7 @@ include = "../../shared_tasks.toml" [tool.poe.tasks] mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_a2a" -test = "pytest --cov=agent_framework_a2a --cov-report=term-missing:skip-covered tests" +test = "pytest -m \"not integration\" --cov=agent_framework_a2a --cov-report=term-missing:skip-covered tests" [build-system] requires = ["flit-core >= 3.11,<4.0"] diff --git a/python/packages/ag-ui/pyproject.toml b/python/packages/ag-ui/pyproject.toml index cc5c081c44..044d7d935a 100644 --- a/python/packages/ag-ui/pyproject.toml +++ b/python/packages/ag-ui/pyproject.toml @@ -74,4 +74,4 @@ include = "../../shared_tasks.toml" [tool.poe.tasks] mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_ag_ui" -test = "pytest --cov=agent_framework_ag_ui --cov-report=term-missing:skip-covered -n auto --dist worksteal tests/ag_ui" +test = 
"pytest -m \"not integration\" --cov=agent_framework_ag_ui --cov-report=term-missing:skip-covered -n auto --dist worksteal tests/ag_ui" diff --git a/python/packages/anthropic/pyproject.toml b/python/packages/anthropic/pyproject.toml index ed31c4800a..51631bdd30 100644 --- a/python/packages/anthropic/pyproject.toml +++ b/python/packages/anthropic/pyproject.toml @@ -87,7 +87,7 @@ include = "../../shared_tasks.toml" [tool.poe.tasks] mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_anthropic" -test = "pytest --cov=agent_framework_anthropic --cov-report=term-missing:skip-covered -n auto --dist worksteal tests" +test = "pytest -m \"not integration\" --cov=agent_framework_anthropic --cov-report=term-missing:skip-covered -n auto --dist worksteal tests" [build-system] requires = ["flit-core >= 3.11,<4.0"] diff --git a/python/packages/azure-ai-search/pyproject.toml b/python/packages/azure-ai-search/pyproject.toml index 6af0688f3f..0827c2d816 100644 --- a/python/packages/azure-ai-search/pyproject.toml +++ b/python/packages/azure-ai-search/pyproject.toml @@ -89,7 +89,7 @@ include = "../../shared_tasks.toml" [tool.poe.tasks] mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_azure_ai_search" -test = "pytest --cov=agent_framework_azure_ai_search --cov-report=term-missing:skip-covered tests" +test = "pytest -m \"not integration\" --cov=agent_framework_azure_ai_search --cov-report=term-missing:skip-covered tests" [build-system] requires = ["flit-core >= 3.11,<4.0"] diff --git a/python/packages/azure-ai/pyproject.toml b/python/packages/azure-ai/pyproject.toml index 9dca3ea0f0..2bd51729c2 100644 --- a/python/packages/azure-ai/pyproject.toml +++ b/python/packages/azure-ai/pyproject.toml @@ -87,7 +87,7 @@ include = "../../shared_tasks.toml" [tool.poe.tasks] mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_azure_ai" -test = "pytest --cov=agent_framework_azure_ai --cov-report=term-missing:skip-covered tests" +test = "pytest -m \"not 
integration\" --cov=agent_framework_azure_ai --cov-report=term-missing:skip-covered tests" [tool.poe.tasks.integration-tests] cmd = """ diff --git a/python/packages/azure-cosmos/pyproject.toml b/python/packages/azure-cosmos/pyproject.toml index c05d3cd939..cae3b3168c 100644 --- a/python/packages/azure-cosmos/pyproject.toml +++ b/python/packages/azure-cosmos/pyproject.toml @@ -86,7 +86,7 @@ executor.type = "uv" include = "../../shared_tasks.toml" [tool.poe.tasks] mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_azure_cosmos" -test = "pytest --cov=agent_framework_azure_cosmos --cov-report=term-missing:skip-covered tests" +test = "pytest -m \"not integration\" --cov=agent_framework_azure_cosmos --cov-report=term-missing:skip-covered tests" integration-tests = "pytest tests/test_cosmos_history_provider.py -m integration" [build-system] diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py index 561e05bee4..a45dcf81fc 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_context.py @@ -44,10 +44,10 @@ def __init__(self) -> None: # region Messaging - async def send_message(self, WorkflowMessage: WorkflowMessage) -> None: + async def send_message(self, message: WorkflowMessage) -> None: """Capture a message sent by an executor.""" - self._messages.setdefault(WorkflowMessage.source_id, []) - self._messages[WorkflowMessage.source_id].append(WorkflowMessage) + self._messages.setdefault(message.source_id, []) + self._messages[message.source_id].append(message) async def drain_messages(self) -> dict[str, list[WorkflowMessage]]: """Drain and return all captured messages.""" diff --git a/python/packages/azurefunctions/pyproject.toml b/python/packages/azurefunctions/pyproject.toml index c55bd86785..0bb2ec9612 100644 --- 
a/python/packages/azurefunctions/pyproject.toml +++ b/python/packages/azurefunctions/pyproject.toml @@ -93,7 +93,7 @@ include = "../../shared_tasks.toml" [tool.poe.tasks] mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_azurefunctions" -test = "pytest --cov=agent_framework_azurefunctions --cov-report=term-missing:skip-covered tests" +test = "pytest -m \"not integration\" --cov=agent_framework_azurefunctions --cov-report=term-missing:skip-covered tests" [build-system] requires = ["flit-core >= 3.11,<4.0"] diff --git a/python/packages/bedrock/pyproject.toml b/python/packages/bedrock/pyproject.toml index 8fec38093f..b99ecb91ff 100644 --- a/python/packages/bedrock/pyproject.toml +++ b/python/packages/bedrock/pyproject.toml @@ -86,8 +86,8 @@ include = "../../shared_tasks.toml" [tool.poe.tasks] mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_bedrock" -test = "pytest --cov=agent_framework_bedrock --cov-report=term-missing:skip-covered tests" +test = "pytest -m \"not integration\" --cov=agent_framework_bedrock --cov-report=term-missing:skip-covered tests" [build-system] requires = ["hatchling"] -build-backend = "hatchling.build" \ No newline at end of file +build-backend = "hatchling.build" diff --git a/python/packages/chatkit/pyproject.toml b/python/packages/chatkit/pyproject.toml index 91ba8347b8..74d7216da6 100644 --- a/python/packages/chatkit/pyproject.toml +++ b/python/packages/chatkit/pyproject.toml @@ -88,7 +88,7 @@ include = "../../shared_tasks.toml" [tool.poe.tasks] mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_chatkit" -test = "pytest --cov=agent_framework_chatkit --cov-report=term-missing:skip-covered tests" +test = "pytest -m \"not integration\" --cov=agent_framework_chatkit --cov-report=term-missing:skip-covered tests" [build-system] requires = ["flit-core >= 3.11,<4.0"] diff --git a/python/packages/claude/pyproject.toml b/python/packages/claude/pyproject.toml index 74c9a6f358..f1891586f8 100644 --- 
a/python/packages/claude/pyproject.toml +++ b/python/packages/claude/pyproject.toml @@ -88,7 +88,7 @@ include = "../../shared_tasks.toml" [tool.poe.tasks] mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_claude" -test = "pytest --cov=agent_framework_claude --cov-report=term-missing:skip-covered tests" +test = "pytest -m \"not integration\" --cov=agent_framework_claude --cov-report=term-missing:skip-covered tests" [build-system] requires = ["flit-core >= 3.11,<4.0"] diff --git a/python/packages/copilotstudio/pyproject.toml b/python/packages/copilotstudio/pyproject.toml index df6531b623..c37fa71ecf 100644 --- a/python/packages/copilotstudio/pyproject.toml +++ b/python/packages/copilotstudio/pyproject.toml @@ -87,7 +87,7 @@ include = "../../shared_tasks.toml" [tool.poe.tasks] mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_copilotstudio" -test = "pytest --cov=agent_framework_copilotstudio --cov-report=term-missing:skip-covered tests" +test = "pytest -m \"not integration\" --cov=agent_framework_copilotstudio --cov-report=term-missing:skip-covered tests" [build-system] requires = ["flit-core >= 3.11,<4.0"] diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 4f1040baa1..8970b03ae6 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -1755,19 +1755,6 @@ def _update_conversation_id( options["conversation_id"] = conversation_id -async def _ensure_response_stream( - stream_like: ResponseStream[ChatResponseUpdate, ChatResponse[Any]] - | Awaitable[ResponseStream[ChatResponseUpdate, ChatResponse[Any]]], -) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: - from ._types import ResponseStream - - stream = await stream_like if isinstance(stream_like, Awaitable) else stream_like - if not isinstance(stream, ResponseStream): - raise ValueError("Streaming function invocation requires a ResponseStream result.") - 
await stream - return cast(ResponseStream[ChatResponseUpdate, ChatResponse[Any]], stream) - - def _extract_tools( options: dict[str, Any] | None, ) -> ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None: @@ -2151,7 +2138,10 @@ def get_response( ResponseStream, ) - super_get_response = super().get_response # type: ignore[misc] + super_get_response_untyped = super().get_response # type: ignore[misc] + + def super_get_response(*args: Any, **kwargs: Any) -> Any: + return super_get_response_untyped(*args, **kwargs) # pyright: ignore[reportUnknownVariableType] # ChatMiddleware adds this kwarg function_middleware_pipeline = FunctionMiddlewarePipeline( @@ -2332,17 +2322,13 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: mutable_options["tool_choice"] = "none" return - stream_like = cast( - ResponseStream[ChatResponseUpdate, ChatResponse[Any]] - | Awaitable[ResponseStream[ChatResponseUpdate, ChatResponse[Any]]], - super_get_response( - messages=prepped_messages, - stream=True, - options=mutable_options, - **filtered_kwargs, - ), + inner_stream: ResponseStream[ChatResponseUpdate, ChatResponse[Any]] = super_get_response( + messages=prepped_messages, + stream=True, + options=mutable_options, + **filtered_kwargs, ) - inner_stream = await _ensure_response_stream(stream_like) + await inner_stream # Collect result hooks from the inner stream to run later stream_result_hooks[:] = _get_result_hooks_from_stream(inner_stream) @@ -2425,21 +2411,17 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: self.function_invocation_configuration.get("max_iterations", DEFAULT_MAX_ITERATIONS), ) mutable_options["tool_choice"] = "none" - stream_like = cast( - ResponseStream[ChatResponseUpdate, ChatResponse[Any]] - | Awaitable[ResponseStream[ChatResponseUpdate, ChatResponse[Any]]], - super_get_response( - messages=prepped_messages, - stream=True, - options=mutable_options, - **filtered_kwargs, - ), + final_inner_stream: 
ResponseStream[ChatResponseUpdate, ChatResponse[Any]] = super_get_response( + messages=prepped_messages, + stream=True, + options=mutable_options, + **filtered_kwargs, ) - inner_stream = await _ensure_response_stream(stream_like) - async for update in inner_stream: + await final_inner_stream + async for update in final_inner_stream: yield update # Finalize the inner stream to trigger its hooks - await inner_stream.get_final_response() + await final_inner_stream.get_final_response() def _finalize(updates: Sequence[ChatResponseUpdate]) -> ChatResponse[Any]: # Note: stream_result_hooks are already run via inner stream's get_final_response() diff --git a/python/packages/core/agent_framework/_types.py b/python/packages/core/agent_framework/_types.py index 4c61749cd6..09015f4b43 100644 --- a/python/packages/core/agent_framework/_types.py +++ b/python/packages/core/agent_framework/_types.py @@ -2881,18 +2881,16 @@ def from_awaitable( async def _get_stream(self) -> AsyncIterable[UpdateT]: if self._stream is None: if hasattr(self._stream_source, "__aiter__"): - self._stream = cast(AsyncIterable[UpdateT], self._stream_source) + self._stream = self._stream_source # type: ignore[assignment] else: if not iscoroutine(self._stream_source): - self._stream = cast(AsyncIterable[UpdateT], self._stream_source) + self._stream = self._stream_source # type: ignore[assignment] else: self._stream = await self._stream_source - stream_obj = cast(Any, self._stream) - if isinstance(stream_obj, ResponseStream) and self._wrap_inner: - inner_stream: Any = cast(Any, stream_obj) - self._inner_stream = inner_stream - return cast(AsyncIterable[UpdateT], inner_stream) - return cast(AsyncIterable[UpdateT], cast(Any, self._stream)) + if isinstance(self._stream, ResponseStream) and self._wrap_inner: + self._inner_stream = self._stream # type: ignore[assignment] + return self._inner_stream + return self._stream # type: ignore[return-value] def __aiter__(self) -> ResponseStream[UpdateT, FinalT]: return 
self @@ -3497,7 +3495,7 @@ def dimensions(self) -> int | None: """ if self._dimensions is not None: return self._dimensions - if isinstance(self.vector, Sized): + if isinstance(self.vector, Sized) and not isinstance(self.vector, str): return len(cast(Sized, self.vector)) return None diff --git a/python/packages/core/agent_framework/_workflows/_runner_context.py b/python/packages/core/agent_framework/_workflows/_runner_context.py index d52e135e91..e3711ea96f 100644 --- a/python/packages/core/agent_framework/_workflows/_runner_context.py +++ b/python/packages/core/agent_framework/_workflows/_runner_context.py @@ -99,11 +99,11 @@ class RunnerContext(Protocol): If checkpoint storage is not configured, checkpoint methods may raise. """ - async def send_message(self, WorkflowMessage: WorkflowMessage) -> None: + async def send_message(self, message: WorkflowMessage) -> None: """Send a WorkflowMessage from the executor to the context. Args: - WorkflowMessage: The WorkflowMessage to be sent. + message: The WorkflowMessage to be sent. """ ... 
@@ -288,9 +288,9 @@ def __init__(self, checkpoint_storage: CheckpointStorage | None = None): self._streaming: bool = False # region Messaging and Events - async def send_message(self, WorkflowMessage: WorkflowMessage) -> None: - self._messages.setdefault(WorkflowMessage.source_id, []) - self._messages[WorkflowMessage.source_id].append(WorkflowMessage) + async def send_message(self, message: WorkflowMessage) -> None: + self._messages.setdefault(message.source_id, []) + self._messages[message.source_id].append(message) async def drain_messages(self) -> dict[str, list[WorkflowMessage]]: messages = copy(self._messages) diff --git a/python/packages/core/agent_framework/observability.py b/python/packages/core/agent_framework/observability.py index 6333a4a226..feb26f67f4 100644 --- a/python/packages/core/agent_framework/observability.py +++ b/python/packages/core/agent_framework/observability.py @@ -1163,7 +1163,7 @@ def get_response( return super_get_response(messages=messages, stream=stream, options=options, **kwargs) # type: ignore[no-any-return] opts: dict[str, Any] = options or {} # type: ignore[assignment] - provider_name = str(self.otel_provider_name) + provider_name = str(getattr(self, "otel_provider_name", "unknown")) model_id = kwargs.get("model_id") or opts.get("model_id") or getattr(self, "model_id", None) or "unknown" service_url_func = getattr(self, "service_url", None) service_url = str(service_url_func() if callable(service_url_func) else "unknown") @@ -1182,12 +1182,7 @@ def get_response( if isinstance(stream_result, ResponseStream): result_stream: ResponseStream[ChatResponseUpdate, ChatResponse[Any]] = stream_result # pyright: ignore[reportUnknownVariableType] elif isinstance(stream_result, Awaitable): - result_stream = cast( - ResponseStream[ChatResponseUpdate, ChatResponse[Any]], - ResponseStream.from_awaitable( # pyright: ignore[reportUnknownMemberType] - cast(Awaitable[ResponseStream[ChatResponseUpdate, ChatResponse[Any]]], stream_result) - ), - ) + 
result_stream = ResponseStream.from_awaitable(stream_result) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] else: raise RuntimeError("Streaming telemetry requires a ResponseStream result.") @@ -1230,8 +1225,8 @@ async def _finalize_stream() -> None: _capture_response( span=span, attributes=response_attributes, - token_usage_histogram=self.token_usage_histogram, - operation_duration_histogram=self.duration_histogram, + token_usage_histogram=getattr(self, "token_usage_histogram", None), + operation_duration_histogram=getattr(self, "duration_histogram", None), duration=duration, ) if ( @@ -1284,8 +1279,8 @@ async def _get_response() -> ChatResponse: _capture_response( span=span, attributes=response_attributes, - token_usage_histogram=self.token_usage_histogram, - operation_duration_histogram=self.duration_histogram, + token_usage_histogram=getattr(self, "token_usage_histogram", None), + operation_duration_histogram=getattr(self, "duration_histogram", None), duration=duration, ) if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and response.messages: @@ -1340,7 +1335,7 @@ async def get_embeddings( return await super_get_embeddings(values, options=options) # type: ignore[no-any-return] opts: dict[str, Any] = options or {} # type: ignore[assignment] - provider_name = str(self.otel_provider_name) + provider_name = str(getattr(self, "otel_provider_name", "unknown")) model_id = opts.get("model_id") or getattr(self, "model_id", None) or "unknown" service_url_func = getattr(self, "service_url", None) service_url = str(service_url_func() if callable(service_url_func) else "unknown") @@ -1463,12 +1458,7 @@ def run( if isinstance(run_result, ResponseStream): result_stream: ResponseStream[AgentResponseUpdate, AgentResponse[Any]] = run_result # pyright: ignore[reportUnknownVariableType] elif isinstance(run_result, Awaitable): - result_stream = cast( - ResponseStream[AgentResponseUpdate, AgentResponse[Any]], - ResponseStream.from_awaitable( # pyright: 
ignore[reportUnknownMemberType] - cast(Awaitable[ResponseStream[AgentResponseUpdate, AgentResponse[Any]]], run_result) - ), - ) + result_stream = ResponseStream.from_awaitable(run_result) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] else: raise RuntimeError("Streaming telemetry requires a ResponseStream result.") diff --git a/python/packages/core/pyproject.toml b/python/packages/core/pyproject.toml index c9708a084b..016650ba0a 100644 --- a/python/packages/core/pyproject.toml +++ b/python/packages/core/pyproject.toml @@ -131,7 +131,7 @@ include = "../../shared_tasks.toml" [tool.poe.tasks] mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework" -test = "pytest --cov=agent_framework --cov-report=term-missing:skip-covered -n auto --dist worksteal tests" +test = "pytest -m \"not integration\" --cov=agent_framework --cov-report=term-missing:skip-covered -n auto --dist worksteal tests" [tool.flit.module] name = "agent_framework" diff --git a/python/packages/core/tests/openai/test_openai_embedding_client.py b/python/packages/core/tests/openai/test_openai_embedding_client.py index c606b67e31..3ddb7538a6 100644 --- a/python/packages/core/tests/openai/test_openai_embedding_client.py +++ b/python/packages/core/tests/openai/test_openai_embedding_client.py @@ -212,7 +212,8 @@ def test_azure_construction_with_existing_client() -> None: assert client.client is mock_client -def test_azure_construction_missing_deployment_name_raises() -> None: +def test_azure_construction_missing_deployment_name_raises(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME", raising=False) with pytest.raises(ValueError, match="deployment name is required"): AzureOpenAIEmbeddingClient( api_key="test-key", @@ -272,6 +273,7 @@ def test_azure_otel_provider_name() -> None: @skip_if_openai_integration_tests_disabled @pytest.mark.flaky +@pytest.mark.integration async def test_integration_openai_get_embeddings() -> None: 
"""End-to-end test of OpenAI embedding generation.""" client = OpenAIEmbeddingClient(model_id="text-embedding-3-small") @@ -289,6 +291,7 @@ async def test_integration_openai_get_embeddings() -> None: @skip_if_openai_integration_tests_disabled @pytest.mark.flaky +@pytest.mark.integration async def test_integration_openai_get_embeddings_multiple() -> None: """Test embedding generation for multiple inputs.""" client = OpenAIEmbeddingClient(model_id="text-embedding-3-small") @@ -302,6 +305,7 @@ async def test_integration_openai_get_embeddings_multiple() -> None: @skip_if_openai_integration_tests_disabled @pytest.mark.flaky +@pytest.mark.integration async def test_integration_openai_get_embeddings_with_dimensions() -> None: """Test embedding generation with custom dimensions.""" client = OpenAIEmbeddingClient(model_id="text-embedding-3-small") @@ -315,6 +319,7 @@ async def test_integration_openai_get_embeddings_with_dimensions() -> None: @skip_if_azure_openai_integration_tests_disabled @pytest.mark.flaky +@pytest.mark.integration async def test_integration_azure_openai_get_embeddings() -> None: """End-to-end test of Azure OpenAI embedding generation.""" client = AzureOpenAIEmbeddingClient() @@ -332,6 +337,7 @@ async def test_integration_azure_openai_get_embeddings() -> None: @skip_if_azure_openai_integration_tests_disabled @pytest.mark.flaky +@pytest.mark.integration async def test_integration_azure_openai_get_embeddings_multiple() -> None: """Test Azure OpenAI embedding generation for multiple inputs.""" client = AzureOpenAIEmbeddingClient() @@ -345,6 +351,7 @@ async def test_integration_azure_openai_get_embeddings_multiple() -> None: @skip_if_azure_openai_integration_tests_disabled @pytest.mark.flaky +@pytest.mark.integration async def test_integration_azure_openai_get_embeddings_with_dimensions() -> None: """Test Azure OpenAI embedding generation with custom dimensions.""" client = AzureOpenAIEmbeddingClient() diff --git 
a/python/packages/core/tests/workflow/test_agent_executor.py b/python/packages/core/tests/workflow/test_agent_executor.py index 788e96e61e..599e62d635 100644 --- a/python/packages/core/tests/workflow/test_agent_executor.py +++ b/python/packages/core/tests/workflow/test_agent_executor.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Literal, overload import pytest + from agent_framework import ( AgentExecutor, AgentResponse, @@ -59,30 +60,19 @@ def run( stream: bool = False, session: AgentSession | None = None, **kwargs: Any, - ) -> ( - Awaitable[AgentResponse[Any]] - | ResponseStream[AgentResponseUpdate, AgentResponse[Any]] - ): + ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: self.call_count += 1 if stream: async def _stream() -> AsyncIterable[AgentResponseUpdate]: yield AgentResponseUpdate( - contents=[ - Content.from_text( - text=f"Response #{self.call_count}: {self.name}" - ) - ] + contents=[Content.from_text(text=f"Response #{self.call_count}: {self.name}")] ) return ResponseStream(_stream(), finalizer=AgentResponse.from_updates) async def _run() -> AgentResponse: - return AgentResponse( - messages=[ - Message("assistant", [f"Response #{self.call_count}: {self.name}"]) - ] - ) + return AgentResponse(messages=[Message("assistant", [f"Response #{self.call_count}: {self.name}"])]) return _run() @@ -120,10 +110,7 @@ def run( stream: bool = False, session: AgentSession | None = None, **kwargs: Any, - ) -> ( - Awaitable[AgentResponse[Any]] - | ResponseStream[AgentResponseUpdate, AgentResponse[Any]] - ): + ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: if stream: async def _stream() -> AsyncIterable[AgentResponseUpdate]: @@ -138,9 +125,9 @@ async def _mark_result_hook_called( self.result_hook_called = True return response - return ResponseStream( - _stream(), finalizer=AgentResponse.from_updates - ).with_result_hook(_mark_result_hook_called) + return 
ResponseStream(_stream(), finalizer=AgentResponse.from_updates).with_result_hook( + _mark_result_hook_called + ) async def _run() -> AgentResponse: return AgentResponse(messages=[Message("assistant", ["hook test"])]) @@ -148,9 +135,7 @@ async def _run() -> AgentResponse: return _run() -async def test_agent_executor_streaming_finalizes_stream_and_runs_result_hooks() -> ( - None -): +async def test_agent_executor_streaming_finalizes_stream_and_runs_result_hooks() -> None: """AgentExecutor should call get_final_response() so stream result hooks execute.""" agent = _StreamingHookAgent(id="hook_agent", name="HookAgent") executor = AgentExecutor(agent, id="hook_exec") @@ -217,9 +202,7 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None: executor_state = executor_states[executor.id] # type: ignore[index] assert "cache" in executor_state, "Checkpoint should store executor cache state" - assert "agent_session" in executor_state, ( - "Checkpoint should store executor session state" - ) + assert "agent_session" in executor_state, "Checkpoint should store executor session state" # Verify session state structure session_state = executor_state["agent_session"] # type: ignore[index] @@ -240,15 +223,11 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None: assert restored_agent.call_count == 0 # Build new workflow with the restored executor - wf_resume = SequentialBuilder( - participants=[restored_executor], checkpoint_storage=storage - ).build() + wf_resume = SequentialBuilder(participants=[restored_executor], checkpoint_storage=storage).build() # Resume from checkpoint resumed_output: AgentExecutorResponse | None = None - async for ev in wf_resume.run( - checkpoint_id=restore_checkpoint.checkpoint_id, stream=True - ): + async for ev in wf_resume.run(checkpoint_id=restore_checkpoint.checkpoint_id, stream=True): if ev.type == "output": resumed_output = ev.data # type: ignore[assignment] if ev.type == "status" and ev.state in ( 
@@ -391,11 +370,7 @@ async def test_prepare_agent_run_args_strips_all_reserved_kwargs_at_once( assert options is not None assert options["additional_function_arguments"]["custom"] == 1 - warned_keys = { - r.message.split("'")[1] - for r in caplog.records - if "reserved" in r.message.lower() - } + warned_keys = {r.message.split("'")[1] for r in caplog.records if "reserved" in r.message.lower()} assert warned_keys == {"session", "stream", "messages"} diff --git a/python/packages/core/tests/workflow/test_agent_utils.py b/python/packages/core/tests/workflow/test_agent_utils.py index 07d1e64c08..633ba1072c 100644 --- a/python/packages/core/tests/workflow/test_agent_utils.py +++ b/python/packages/core/tests/workflow/test_agent_utils.py @@ -16,10 +16,31 @@ def __init__(self, agent_id: str, name: str | None = None) -> None: self.description: str | None = None @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... - def run(self, messages: AgentRunInputs | None = None, *, stream: bool = False, session: AgentSession | None = None, **kwargs: Any) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... 
+ def run( + self, + messages: AgentRunInputs | None = None, + *, + stream: bool = False, + session: AgentSession | None = None, + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... def create_session(self, **kwargs: Any) -> AgentSession: """Creates a new conversation session for the agent.""" diff --git a/python/packages/core/tests/workflow/test_edge.py b/python/packages/core/tests/workflow/test_edge.py index ecaa341726..422d530631 100644 --- a/python/packages/core/tests/workflow/test_edge.py +++ b/python/packages/core/tests/workflow/test_edge.py @@ -4,9 +4,8 @@ from typing import Any from unittest.mock import patch -from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter - import pytest +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from agent_framework import ( Executor, diff --git a/python/packages/core/tests/workflow/test_executor.py b/python/packages/core/tests/workflow/test_executor.py index 77827c0634..77777e198b 100644 --- a/python/packages/core/tests/workflow/test_executor.py +++ b/python/packages/core/tests/workflow/test_executor.py @@ -3,6 +3,8 @@ from dataclasses import dataclass import pytest +from typing_extensions import Never + from agent_framework import ( Executor, Message, @@ -14,7 +16,6 @@ handler, response_handler, ) -from typing_extensions import Never # Module-level types for string forward reference tests @@ -155,11 +156,7 @@ async def handle(self, text: str, ctx: WorkflowContext) -> None: workflow = WorkflowBuilder(start_executor=upper).add_edge(upper, collector).build() events = await workflow.run("hello world") - invoked_events = [ - e - for e in events - if isinstance(e, WorkflowEvent) and e.type == "executor_invoked" - ] + invoked_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_invoked"] assert len(invoked_events) == 2 @@ -193,16 +190,10 @@ async def handle(self, 
text: str, ctx: WorkflowContext) -> None: sender = MultiSenderExecutor(id="sender") collector = CollectorExecutor(id="collector") - workflow = ( - WorkflowBuilder(start_executor=sender).add_edge(sender, collector).build() - ) + workflow = WorkflowBuilder(start_executor=sender).add_edge(sender, collector).build() events = await workflow.run("hello") - completed_events = [ - e - for e in events - if isinstance(e, WorkflowEvent) and e.type == "executor_completed" - ] + completed_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_completed"] # Sender should have completed with the sent messages sender_completed = next(e for e in completed_events if e.executor_id == "sender") @@ -210,9 +201,7 @@ async def handle(self, text: str, ctx: WorkflowContext) -> None: assert sender_completed.data == ["hello-first", "hello-second"] # Collector should have completed with no sent messages (None) - collector_completed_events = [ - e for e in completed_events if e.executor_id == "collector" - ] + collector_completed_events = [e for e in completed_events if e.executor_id == "collector"] # Collector is called twice (once per message from sender) assert len(collector_completed_events) == 2 for collector_completed in collector_completed_events: @@ -231,11 +220,7 @@ async def handle(self, text: str, ctx: WorkflowContext[Never, str]) -> None: workflow = WorkflowBuilder(start_executor=executor).build() events = await workflow.run("test") - completed_events = [ - e - for e in events - if isinstance(e, WorkflowEvent) and e.type == "executor_completed" - ] + completed_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_completed"] assert len(completed_events) == 1 assert completed_events[0].executor_id == "yielder" @@ -263,9 +248,7 @@ class Response: class ProcessorExecutor(Executor): @handler - async def handle( - self, request: Request, ctx: WorkflowContext[Response] - ) -> None: + async def handle(self, request: Request, ctx: 
WorkflowContext[Response]) -> None: response = Response(results=[request.query.upper()] * request.limit) await ctx.send_message(response) @@ -277,23 +260,13 @@ async def handle(self, response: Response, ctx: WorkflowContext) -> None: processor = ProcessorExecutor(id="processor") collector = CollectorExecutor(id="collector") - workflow = ( - WorkflowBuilder(start_executor=processor).add_edge(processor, collector).build() - ) + workflow = WorkflowBuilder(start_executor=processor).add_edge(processor, collector).build() input_request = Request(query="hello", limit=3) events = await workflow.run(input_request) - invoked_events = [ - e - for e in events - if isinstance(e, WorkflowEvent) and e.type == "executor_invoked" - ] - completed_events = [ - e - for e in events - if isinstance(e, WorkflowEvent) and e.type == "executor_completed" - ] + invoked_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_invoked"] + completed_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_completed"] # Check processor invoked event has the Request object processor_invoked = next(e for e in invoked_events if e.executor_id == "processor") @@ -302,9 +275,7 @@ async def handle(self, response: Response, ctx: WorkflowContext) -> None: assert processor_invoked.data.limit == 3 # Check processor completed event has the Response object - processor_completed = next( - e for e in completed_events if e.executor_id == "processor" - ) + processor_completed = next(e for e in completed_events if e.executor_id == "processor") assert processor_completed.data is not None assert len(processor_completed.data) == 1 assert isinstance(processor_completed.data[0], Response) @@ -390,9 +361,7 @@ async def handle(self, text: str, ctx: WorkflowContext[int, str]) -> None: # Test executor with union workflow output types class UnionWorkflowOutputExecutor(Executor): @handler - async def handle( - self, text: str, ctx: WorkflowContext[int, str | bool] - ) -> 
None: + async def handle(self, text: str, ctx: WorkflowContext[int, str | bool]) -> None: pass executor = UnionWorkflowOutputExecutor(id="union_workflow_output") @@ -403,15 +372,11 @@ async def handle( # Test executor with multiple handlers having different workflow output types class MultiHandlerWorkflowExecutor(Executor): @handler - async def handle_string( - self, text: str, ctx: WorkflowContext[int, str] - ) -> None: + async def handle_string(self, text: str, ctx: WorkflowContext[int, str]) -> None: pass @handler - async def handle_number( - self, num: int, ctx: WorkflowContext[bool, float] - ) -> None: + async def handle_number(self, num: int, ctx: WorkflowContext[bool, float]) -> None: pass executor = MultiHandlerWorkflowExecutor(id="multi_workflow") @@ -465,9 +430,7 @@ async def handle(self, text: str, ctx: WorkflowContext[int]) -> None: pass @response_handler - async def handle_response( - self, original_request: str, response: bool, ctx: WorkflowContext[float] - ) -> None: + async def handle_response(self, original_request: str, response: bool, ctx: WorkflowContext[float]) -> None: pass executor = RequestResponseExecutor(id="request_response") @@ -574,9 +537,7 @@ async def test_executor_invoked_event_data_not_mutated_by_handler(): """Test that executor_invoked event (type='executor_invoked').data captures original input, not mutated input.""" @executor(id="Mutator") - async def mutator( - messages: list[Message], ctx: WorkflowContext[list[Message]] - ) -> None: + async def mutator(messages: list[Message], ctx: WorkflowContext[list[Message]]) -> None: # The handler mutates the input list by appending new messages original_len = len(messages) messages.append(Message(role="assistant", text="Added by executor")) @@ -591,11 +552,7 @@ async def mutator( events = await workflow.run(input_messages) # Find the invoked event for the Mutator executor - invoked_events = [ - e - for e in events - if isinstance(e, WorkflowEvent) and e.type == "executor_invoked" - ] + 
invoked_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_invoked"] assert len(invoked_events) == 1 mutator_invoked = invoked_events[0] @@ -672,12 +629,8 @@ async def handle(self, message: Any, ctx: WorkflowContext) -> None: assert handler_func._handler_spec["output_types"] == [list] # pyright: ignore[reportFunctionMemberAccess] # Verify can_handle - assert exec_instance.can_handle( - WorkflowMessage(data={"key": "value"}, source_id="mock") - ) - assert not exec_instance.can_handle( - WorkflowMessage(data="string", source_id="mock") - ) + assert exec_instance.can_handle(WorkflowMessage(data={"key": "value"}, source_id="mock")) + assert not exec_instance.can_handle(WorkflowMessage(data="string", source_id="mock")) def test_handler_with_explicit_union_input_type(self): """Test that explicit union input_type is handled correctly.""" @@ -698,9 +651,7 @@ async def handle(self, message: Any, ctx: WorkflowContext) -> None: assert exec_instance.can_handle(WorkflowMessage(data="hello", source_id="mock")) assert exec_instance.can_handle(WorkflowMessage(data=42, source_id="mock")) # Cannot handle float - assert not exec_instance.can_handle( - WorkflowMessage(data=3.14, source_id="mock") - ) + assert not exec_instance.can_handle(WorkflowMessage(data=3.14, source_id="mock")) def test_handler_with_explicit_union_output_type(self): """Test that explicit union output is normalized to a list.""" @@ -776,9 +727,7 @@ async def handle(self, message: str, ctx: WorkflowContext[int]) -> None: class OnlyWorkflowOutputExecutor(Executor): # pyright: ignore[reportUnusedClass] @handler(workflow_output=bool) - async def handle( - self, message: str, ctx: WorkflowContext[int, str] - ) -> None: + async def handle(self, message: str, ctx: WorkflowContext[int, str]) -> None: pass def test_handler_explicit_input_type_allows_no_message_annotation(self): @@ -803,9 +752,7 @@ async def handle_explicit(self, message, ctx: WorkflowContext) -> None: # type: pass @handler 
- async def handle_introspected( - self, message: float, ctx: WorkflowContext[bool] - ) -> None: + async def handle_introspected(self, message: float, ctx: WorkflowContext[bool]) -> None: pass exec_instance = MixedExecutor(id="mixed") @@ -831,9 +778,7 @@ async def handle(self, message, ctx: WorkflowContext) -> None: # type: ignore[n # Should resolve the string to the actual type assert ForwardRefMessage in exec_instance._handlers # pyright: ignore[reportPrivateUsage] - assert exec_instance.can_handle( - WorkflowMessage(data=ForwardRefMessage("hello"), source_id="mock") - ) + assert exec_instance.can_handle(WorkflowMessage(data=ForwardRefMessage("hello"), source_id="mock")) def test_handler_with_string_forward_reference_union(self): """Test that string forward references work with union types.""" @@ -846,12 +791,8 @@ async def handle(self, message, ctx: WorkflowContext) -> None: # type: ignore[n exec_instance = StringUnionExecutor(id="string_union") # Should handle both types - assert exec_instance.can_handle( - WorkflowMessage(data=ForwardRefTypeA("hello"), source_id="mock") - ) - assert exec_instance.can_handle( - WorkflowMessage(data=ForwardRefTypeB(42), source_id="mock") - ) + assert exec_instance.can_handle(WorkflowMessage(data=ForwardRefTypeA("hello"), source_id="mock")) + assert exec_instance.can_handle(WorkflowMessage(data=ForwardRefTypeB(42), source_id="mock")) def test_handler_with_string_forward_reference_output_type(self): """Test that string forward references work for output_type.""" @@ -890,9 +831,7 @@ def test_handler_with_explicit_workflow_output_and_output(self): class PrecedenceExecutor(Executor): @handler(input=int, output=float, workflow_output=str) - async def handle( - self, message: int, ctx: WorkflowContext[int, bool] - ) -> None: + async def handle(self, message: int, ctx: WorkflowContext[int, bool]) -> None: pass exec_instance = PrecedenceExecutor(id="precedence") @@ -958,9 +897,7 @@ class StringUnionWorkflowOutputExecutor(Executor): async 
def handle(self, message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] pass - exec_instance = StringUnionWorkflowOutputExecutor( - id="string_union_workflow_output" - ) + exec_instance = StringUnionWorkflowOutputExecutor(id="string_union_workflow_output") # Should resolve both types from string union assert ForwardRefTypeA in exec_instance.workflow_output_types @@ -971,14 +908,10 @@ def test_handler_fallback_to_introspection_for_workflow_output_type(self): class IntrospectedWorkflowOutputExecutor(Executor): @handler - async def handle( - self, message: str, ctx: WorkflowContext[int, bool] - ) -> None: + async def handle(self, message: str, ctx: WorkflowContext[int, bool]) -> None: pass - exec_instance = IntrospectedWorkflowOutputExecutor( - id="introspected_workflow_output" - ) + exec_instance = IntrospectedWorkflowOutputExecutor(id="introspected_workflow_output") # Should use introspected types from WorkflowContext[int, bool] assert int in exec_instance.output_types diff --git a/python/packages/core/tests/workflow/test_workflow_agent.py b/python/packages/core/tests/workflow/test_workflow_agent.py index b5a8bb9902..eacf70c6db 100644 --- a/python/packages/core/tests/workflow/test_workflow_agent.py +++ b/python/packages/core/tests/workflow/test_workflow_agent.py @@ -717,9 +717,23 @@ def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession return AgentSession() @overload - def run(self, messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + def run( + self, + messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... 
@overload - def run(self, messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( + self, + messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... def run( self, @@ -813,9 +827,23 @@ def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession return AgentSession() @overload - def run(self, messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + def run( + self, + messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... @overload - def run(self, messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( + self, + messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... 
def run( self, diff --git a/python/packages/core/tests/workflow/test_workflow_kwargs.py b/python/packages/core/tests/workflow/test_workflow_kwargs.py index 0850c6b060..d315f75f85 100644 --- a/python/packages/core/tests/workflow/test_workflow_kwargs.py +++ b/python/packages/core/tests/workflow/test_workflow_kwargs.py @@ -52,9 +52,23 @@ def __init__(self, name: str = "test_agent") -> None: self.captured_kwargs = [] @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... def run( self, @@ -90,9 +104,23 @@ def __init__(self, name: str = "options_agent") -> None: self.captured_kwargs = [] @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... 
+ def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... def run( self, @@ -475,9 +503,23 @@ def __init__(self) -> None: self._asked = False @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... def run( self, @@ -538,9 +580,23 @@ def __init__(self) -> None: self._asked = False @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... 
def run( self, @@ -605,9 +661,23 @@ def __init__(self) -> None: self._asked = False @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... @overload - def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... def run( self, diff --git a/python/packages/core/tests/workflow/test_workflow_states.py b/python/packages/core/tests/workflow/test_workflow_states.py index 34c7e8c93f..bf2e277d10 100644 --- a/python/packages/core/tests/workflow/test_workflow_states.py +++ b/python/packages/core/tests/workflow/test_workflow_states.py @@ -38,7 +38,9 @@ async def test_executor_failed_and_workflow_failed_events_streaming(): events.append(ev) # executor_failed event (type='executor_failed') should be emitted before workflow failed event - executor_failed_events: list[WorkflowEvent[Any]] = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_failed"] + executor_failed_events: list[WorkflowEvent[Any]] = [ + e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_failed" + ] assert executor_failed_events, "executor_failed event should be emitted when start executor fails" assert executor_failed_events[0].executor_id == "f" assert executor_failed_events[0].origin is WorkflowEventSource.FRAMEWORK @@ -96,7 +98,9 @@ async def test_executor_failed_event_from_second_executor_in_chain(): 
events.append(ev) # executor_failed event should be emitted for the failing executor - executor_failed_events: list[WorkflowEvent[Any]] = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_failed"] + executor_failed_events: list[WorkflowEvent[Any]] = [ + e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_failed" + ] assert executor_failed_events, "executor_failed event should be emitted when second executor fails" assert executor_failed_events[0].executor_id == "failing" assert executor_failed_events[0].origin is WorkflowEventSource.FRAMEWORK diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py index 87cdb7dac7..e7af9fde9a 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py @@ -25,6 +25,7 @@ from __future__ import annotations +import locale import logging import sys import uuid @@ -103,6 +104,8 @@ class DeclarativeStateData(TypedDict, total=False): # Types that PowerFx can serialize directly # Note: Decimal is included because PowerFx returns Decimal for numeric values _POWERFX_SAFE_TYPES = (str, int, float, bool, type(None), _Decimal) +_POWERFX_EVAL_LOCALE = "en-US" +_POWERFX_NUMERIC_LOCALE_CANDIDATES = ("en_US.UTF-8", "en_US", "C") def _make_powerfx_safe(value: Any) -> Any: @@ -384,23 +387,33 @@ def eval(self, expression: str) -> Any: f"Install dotnet and the powerfx package for full PowerFx support." ) - engine = Engine() symbols = self._to_powerfx_symbols() + # Use setlocale(category) query form so we can restore the exact prior value. + # getlocale() returns a normalized tuple and is not always a lossless + # round-trip for setlocale across platforms/locales. 
+ original_numeric_locale = locale.setlocale(locale.LC_NUMERIC) try: - from System.Globalization import ( - CultureInfo, # pyright: ignore[reportMissingImports, reportUnknownVariableType] - ) + for locale_candidate in _POWERFX_NUMERIC_LOCALE_CANDIDATES: + try: + locale.setlocale(locale.LC_NUMERIC, locale_candidate) + break + except locale.Error: + continue + + engine = Engine() + try: + from System.Globalization import ( # pyright: ignore[reportMissingImports] + CultureInfo, # pyright: ignore[reportUnknownVariableType] + ) + except ImportError: + return engine.eval(formula, symbols=symbols, locale=_POWERFX_EVAL_LOCALE) - original_culture = CultureInfo.CurrentCulture - original_ui_culture = CultureInfo.CurrentUICulture - en_us_culture = CultureInfo("en-US") - CultureInfo.CurrentCulture = en_us_culture - CultureInfo.CurrentUICulture = en_us_culture + original_culture = cast(Any, CultureInfo.CurrentCulture) # pyright: ignore[reportUnknownMemberType] try: - return engine.eval(formula, symbols=symbols) + CultureInfo.CurrentCulture = CultureInfo(_POWERFX_EVAL_LOCALE) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + return engine.eval(formula, symbols=symbols, locale=_POWERFX_EVAL_LOCALE) finally: - CultureInfo.CurrentCulture = original_culture - CultureInfo.CurrentUICulture = original_ui_culture + CultureInfo.CurrentCulture = original_culture # pyright: ignore[reportUnknownMemberType] except ValueError as e: error_msg = str(e) # Handle undefined variable errors gracefully by returning None @@ -409,6 +422,8 @@ def eval(self, expression: str) -> Any: logger.debug(f"PowerFx: undefined variable in expression '{formula}', returning None") return None raise + finally: + locale.setlocale(locale.LC_NUMERIC, original_numeric_locale) def _eval_custom_function(self, formula: str) -> Any | None: """Handle custom functions not supported by the Python PowerFx library. 
diff --git a/python/packages/declarative/pyproject.toml b/python/packages/declarative/pyproject.toml index d2462353e7..2534339ad7 100644 --- a/python/packages/declarative/pyproject.toml +++ b/python/packages/declarative/pyproject.toml @@ -94,7 +94,7 @@ include = "../../shared_tasks.toml" [tool.poe.tasks] mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_declarative" -test = "pytest --cov=agent_framework_declarative --cov-report=term-missing:skip-covered tests" +test = "pytest -m \"not integration\" --cov=agent_framework_declarative --cov-report=term-missing:skip-covered tests" [build-system] requires = ["flit-core >= 3.11,<4.0"] diff --git a/python/packages/declarative/tests/test_powerfx_yaml_compatibility.py b/python/packages/declarative/tests/test_powerfx_yaml_compatibility.py index 8ea3c3af57..308982c632 100644 --- a/python/packages/declarative/tests/test_powerfx_yaml_compatibility.py +++ b/python/packages/declarative/tests/test_powerfx_yaml_compatibility.py @@ -16,6 +16,7 @@ - String interpolation: {Variable.Path} """ +import locale from unittest.mock import MagicMock import pytest @@ -494,29 +495,38 @@ async def test_undefined_nested_variable_returns_none(self, mock_state): assert result is None async def test_undefined_variable_returns_none_with_non_english_ui_culture(self, mock_state): - """Test that undefined variables return None even when CurrentUICulture is non-English. + """Test that undefined variables return None even when locale is non-English. - Regression test for #4321: on non-English systems, CurrentUICulture causes + Regression test for #4321: on non-English systems, locale settings can cause PowerFx to emit localized error messages that don't match the English string guards ("isn't recognized", "Name isn't valid"), crashing the workflow. - The fix sets CurrentUICulture to en-US alongside CurrentCulture before eval. + The fix evaluates with locale='en-US' and restores the ambient LC_NUMERIC. 
""" - from System.Globalization import CultureInfo - state = DeclarativeWorkflowState(mock_state) state.initialize() - # Simulate a non-English UI culture (e.g. Italian) - original_ui_culture = CultureInfo.CurrentUICulture - CultureInfo.CurrentUICulture = CultureInfo("it-IT") + # Simulate a non-English locale (e.g. Italian) + original_numeric_locale = locale.setlocale(locale.LC_NUMERIC) + test_numeric_locale: str | None = None try: + for locale_candidate in ("it_IT.UTF-8", "it_IT", "fr_FR.UTF-8", "fr_FR", "de_DE.UTF-8", "de_DE"): + try: + locale.setlocale(locale.LC_NUMERIC, locale_candidate) + test_numeric_locale = locale.setlocale(locale.LC_NUMERIC) + break + except locale.Error: + continue + + if test_numeric_locale is None: + pytest.skip("No non-English LC_NUMERIC locale available on this system") + # Should return None, not raise ValueError with Italian error text result = state.eval("=Local.StatusConversationId") assert result is None - # Verify the production code restored CurrentUICulture after eval - assert str(CultureInfo.CurrentUICulture) == str(CultureInfo("it-IT")) + # Verify the production code restored LC_NUMERIC after eval + assert locale.setlocale(locale.LC_NUMERIC) == test_numeric_locale finally: - CultureInfo.CurrentUICulture = original_ui_culture + locale.setlocale(locale.LC_NUMERIC, original_numeric_locale) class TestStringInterpolation: diff --git a/python/packages/devui/agent_framework_devui/_utils.py b/python/packages/devui/agent_framework_devui/_utils.py index 53e4b8416c..3eadc35926 100644 --- a/python/packages/devui/agent_framework_devui/_utils.py +++ b/python/packages/devui/agent_framework_devui/_utils.py @@ -123,9 +123,8 @@ def extract_executor_message_types(executor: Any) -> list[Any]: if not message_types and hasattr(executor, "_handlers"): try: handlers = executor._handlers - handlers_dict = _string_key_dict(handlers) - if handlers_dict is not None: - message_types = list(handlers_dict.keys()) + if isinstance(handlers, dict): + 
message_types = list(handlers.keys()) # type: ignore[arg-type] # pyright: ignore[reportUnknownArgumentType] except Exception as exc: # pragma: no cover - defensive logging path logger.debug(f"Failed to read executor handlers: {exc}") diff --git a/python/packages/devui/pyproject.toml b/python/packages/devui/pyproject.toml index 6f41307dde..a56cf1ab4f 100644 --- a/python/packages/devui/pyproject.toml +++ b/python/packages/devui/pyproject.toml @@ -94,7 +94,7 @@ include = "../../shared_tasks.toml" [tool.poe.tasks] mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_devui" -test = "pytest --cov=agent_framework_devui --cov-report=term-missing:skip-covered tests" +test = "pytest -m \"not integration\" --cov=agent_framework_devui --cov-report=term-missing:skip-covered tests" [build-system] requires = ["flit-core >= 3.11,<4.0"] diff --git a/python/packages/durabletask/pyproject.toml b/python/packages/durabletask/pyproject.toml index 923bfc9b1d..56493f3126 100644 --- a/python/packages/durabletask/pyproject.toml +++ b/python/packages/durabletask/pyproject.toml @@ -99,8 +99,8 @@ include = "../../shared_tasks.toml" [tool.poe.tasks] mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_durabletask" -test = "pytest --cov=agent_framework_durabletask --cov-report=term-missing:skip-covered tests" +test = "pytest -m \"not integration\" --cov=agent_framework_durabletask --cov-report=term-missing:skip-covered tests" [build-system] requires = ["flit-core >= 3.11,<4.0"] -build-backend = "flit_core.buildapi" \ No newline at end of file +build-backend = "flit_core.buildapi" diff --git a/python/packages/foundry_local/pyproject.toml b/python/packages/foundry_local/pyproject.toml index a04a21d82b..97dd99f1ca 100644 --- a/python/packages/foundry_local/pyproject.toml +++ b/python/packages/foundry_local/pyproject.toml @@ -86,7 +86,7 @@ include = "../../shared_tasks.toml" [tool.poe.tasks] mypy = "mypy --config-file $POE_ROOT/pyproject.toml 
agent_framework_foundry_local" -test = "pytest --cov=agent_framework_foundry_local --cov-report=term-missing:skip-covered tests" +test = "pytest -m \"not integration\" --cov=agent_framework_foundry_local --cov-report=term-missing:skip-covered tests" [build-system] requires = ["flit-core >= 3.11,<4.0"] diff --git a/python/packages/github_copilot/pyproject.toml b/python/packages/github_copilot/pyproject.toml index 940fbf5fa7..47069e34fa 100644 --- a/python/packages/github_copilot/pyproject.toml +++ b/python/packages/github_copilot/pyproject.toml @@ -87,7 +87,7 @@ include = "../../shared_tasks.toml" [tool.poe.tasks] mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_github_copilot" -test = "pytest --cov=agent_framework_github_copilot --cov-report=term-missing:skip-covered tests" +test = "pytest -m \"not integration\" --cov=agent_framework_github_copilot --cov-report=term-missing:skip-covered tests" [build-system] requires = ["flit-core >= 3.11,<4.0"] diff --git a/python/packages/lab/pyproject.toml b/python/packages/lab/pyproject.toml index d64c0b0593..17650293ac 100644 --- a/python/packages/lab/pyproject.toml +++ b/python/packages/lab/pyproject.toml @@ -152,10 +152,10 @@ mypy-gaia = "mypy --config-file $POE_ROOT/pyproject.toml gaia/agent_framework_la mypy-lightning = "mypy --config-file $POE_ROOT/pyproject.toml lightning/agent_framework_lab_lightning" mypy-tau2 = "mypy --config-file $POE_ROOT/pyproject.toml tau2/agent_framework_lab_tau2" mypy = ["mypy-gaia", "mypy-lightning", "mypy-tau2"] -test = "pytest --cov-report=term-missing:skip-covered --junitxml=test-results.xml" -test-gaia = "pytest gaia/tests --cov=agent_framework_lab_gaia --cov-report=term-missing:skip-covered" -test-lightning = "pytest lightning/tests --cov=agent_framework_lab_lightning --cov-report=term-missing:skip-covered" -test-tau2 = "pytest tau2/tests --cov=agent_framework_lab_tau2 --cov-report=term-missing:skip-covered" +test = "pytest -m \"not integration\" 
--cov-report=term-missing:skip-covered --junitxml=test-results.xml" +test-gaia = "pytest -m \"not integration\" gaia/tests --cov=agent_framework_lab_gaia --cov-report=term-missing:skip-covered" +test-lightning = "pytest -m \"not integration\" lightning/tests --cov=agent_framework_lab_lightning --cov-report=term-missing:skip-covered" +test-tau2 = "pytest -m \"not integration\" tau2/tests --cov=agent_framework_lab_tau2 --cov-report=term-missing:skip-covered" build = "echo 'Skipping build'" publish = "echo 'Skipping publish'" diff --git a/python/packages/mem0/pyproject.toml b/python/packages/mem0/pyproject.toml index 0d04fe802f..506c4d75b1 100644 --- a/python/packages/mem0/pyproject.toml +++ b/python/packages/mem0/pyproject.toml @@ -87,7 +87,7 @@ include = "../../shared_tasks.toml" [tool.poe.tasks] mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_mem0" -test = "pytest --cov=agent_framework_mem0 --cov-report=term-missing:skip-covered tests" +test = "pytest -m \"not integration\" --cov=agent_framework_mem0 --cov-report=term-missing:skip-covered tests" [build-system] requires = ["flit-core >= 3.11,<4.0"] diff --git a/python/packages/ollama/pyproject.toml b/python/packages/ollama/pyproject.toml index 57cbcd3b96..dd9ecaf46b 100644 --- a/python/packages/ollama/pyproject.toml +++ b/python/packages/ollama/pyproject.toml @@ -90,7 +90,7 @@ include = "../../shared_tasks.toml" [tool.poe.tasks] mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_ollama" -test = "pytest --cov=agent_framework_ollama --cov-report=term-missing:skip-covered tests" +test = "pytest -m \"not integration\" --cov=agent_framework_ollama --cov-report=term-missing:skip-covered tests" [tool.uv.build-backend] module-name = "agent_framework_ollama" diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py index ddeedfea36..66b8309f9e 100644 --- 
a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py @@ -34,6 +34,7 @@ import logging import sys from collections.abc import Awaitable, Callable, Mapping, Sequence +from copy import deepcopy from dataclasses import dataclass from typing import Any @@ -349,43 +350,21 @@ def _persist_missing_approved_function_results( def _clone_chat_agent(self, agent: Agent[Any]) -> Agent[Any]: """Produce a deep copy of the Agent while preserving runtime configuration.""" options = agent.default_options - middleware = list(agent.middleware or []) # Reconstruct the original tools list by combining regular tools with MCP tools. # Agent.__init__ separates MCP tools during initialization, # so we need to recombine them here to pass the complete tools list to the constructor. # This makes sure MCP tools are preserved when cloning agents for handoff workflows. - tools_from_options = options.get("tools") - all_tools = list(tools_from_options) if tools_from_options else [] + tools_from_options = options.pop("tools", []) if agent.mcp_tools: - all_tools.extend(agent.mcp_tools) - - logit_bias = options.get("logit_bias") - metadata = options.get("metadata") + tools_from_options.extend(agent.mcp_tools) + # this ensures all options (including custom ones) are kept + cloned_options = deepcopy(options) # Disable parallel tool calls to prevent the agent from invoking multiple handoff tools at once. - cloned_options: dict[str, Any] = { - "allow_multiple_tool_calls": False, - # Handoff workflows already manage full conversation context explicitly - # across executors. Keep provider-side conversation storage disabled to - # avoid stale tool-call state (Responses API previous_response chains). 
- "store": False, - } - cloned_options["frequency_penalty"] = options.get("frequency_penalty") - cloned_options["instructions"] = options.get("instructions") - cloned_options["logit_bias"] = dict(logit_bias) if logit_bias else None - cloned_options["max_tokens"] = options.get("max_tokens") - cloned_options["metadata"] = dict(metadata) if metadata else None - cloned_options["model_id"] = options.get("model_id") - cloned_options["presence_penalty"] = options.get("presence_penalty") - cloned_options["response_format"] = options.get("response_format") - cloned_options["seed"] = options.get("seed") - cloned_options["stop"] = options.get("stop") - cloned_options["temperature"] = options.get("temperature") - cloned_options["tool_choice"] = options.get("tool_choice") - cloned_options["tools"] = all_tools if all_tools else None - cloned_options["top_p"] = options.get("top_p") - cloned_options["user"] = options.get("user") + cloned_options["allow_multiple_tool_calls"] = False + cloned_options["store"] = False + cloned_options["tools"] = tools_from_options return Agent( client=agent.client, @@ -393,8 +372,8 @@ def _clone_chat_agent(self, agent: Agent[Any]) -> Agent[Any]: name=agent.name, description=agent.description, context_providers=agent.context_providers, - middleware=middleware, - default_options=cloned_options, # type: ignore[arg-type] + middleware=agent.agent_middleware, + default_options=cloned_options, # type: ignore[assignment] ) def _apply_auto_tools(self, agent: Agent, targets: Sequence[HandoffConfiguration]) -> None: diff --git a/python/packages/orchestrations/pyproject.toml b/python/packages/orchestrations/pyproject.toml index 52ac5424fb..e15e02f3e3 100644 --- a/python/packages/orchestrations/pyproject.toml +++ b/python/packages/orchestrations/pyproject.toml @@ -85,7 +85,7 @@ include = "../../shared_tasks.toml" [tool.poe.tasks] mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_orchestrations" -test = "pytest 
--cov=agent_framework_orchestrations --cov-report=term-missing:skip-covered -n auto --dist worksteal tests" +test = "pytest -m \"not integration\" --cov=agent_framework_orchestrations --cov-report=term-missing:skip-covered -n auto --dist worksteal tests" [build-system] requires = ["flit-core >= 3.11,<4.0"] diff --git a/python/packages/purview/pyproject.toml b/python/packages/purview/pyproject.toml index cb7819a36e..f30b749435 100644 --- a/python/packages/purview/pyproject.toml +++ b/python/packages/purview/pyproject.toml @@ -86,7 +86,7 @@ include = "../../shared_tasks.toml" [tool.poe.tasks] mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_purview" -test = "pytest --cov=agent_framework_purview --cov-report=term-missing:skip-covered tests" +test = "pytest -m \"not integration\" --cov=agent_framework_purview --cov-report=term-missing:skip-covered tests" [build-system] requires = ["flit-core >= 3.9,<4.0"] diff --git a/python/packages/redis/pyproject.toml b/python/packages/redis/pyproject.toml index c42b050115..21aaf47865 100644 --- a/python/packages/redis/pyproject.toml +++ b/python/packages/redis/pyproject.toml @@ -89,7 +89,7 @@ include = "../../shared_tasks.toml" [tool.poe.tasks] mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_redis" -test = "pytest --cov=agent_framework_redis --cov-report=term-missing:skip-covered tests" +test = "pytest -m \"not integration\" --cov=agent_framework_redis --cov-report=term-missing:skip-covered tests" [build-system] requires = ["flit-core >= 3.11,<4.0"] From ba18aef138a7b01817e732e7e093120576a217ec Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 4 Mar 2026 16:54:42 +0100 Subject: [PATCH 07/15] lots of small fixes --- .../agent_framework_anthropic/_chat_client.py | 41 ++++--- .../_context_provider.py | 22 +--- .../agent_framework_azure_ai/_chat_client.py | 101 ++++++++---------- .../agent_framework_azure_ai/_client.py | 72 ++++--------- .../_embedding_client.py | 2 +- 
.../agent_framework_azure_ai/_shared.py | 36 +++---- .../_history_provider.py | 19 ++-- .../agent_framework_azurefunctions/_app.py | 56 ++++------ .../_serialization.py | 20 ++-- .../_workflow.py | 13 ++- .../agent_framework_bedrock/__init__.py | 4 +- .../agent_framework_bedrock/_chat_client.py | 76 +++---------- .../_embedding_client.py | 84 ++++----------- .../claude/agent_framework_claude/_agent.py | 28 ++--- .../packages/core/agent_framework/_agents.py | 25 ++--- .../packages/core/agent_framework/_clients.py | 2 +- .../packages/core/agent_framework/_tools.py | 34 +++--- .../core/agent_framework/observability.py | 19 ++-- .../openai/_embedding_client.py | 12 +-- .../_embedding_client.py | 23 ++-- 20 files changed, 240 insertions(+), 449 deletions(-) diff --git a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py index 14544c071b..5cda4991c8 100644 --- a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py +++ b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py @@ -4,8 +4,8 @@ import logging import sys -from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, MutableMapping, Sequence -from typing import Any, ClassVar, Final, Generic, Literal, TypedDict, cast +from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, Sequence +from typing import Any, ClassVar, Final, Generic, Literal, TypedDict from agent_framework import ( AGENT_FRAMEWORK_USER_AGENT, @@ -788,28 +788,23 @@ def _prepare_tools_for_anthropic(self, options: Mapping[str, Any]) -> dict[str, "description": tool.description, "input_schema": tool.parameters(), }) - elif isinstance(tool, MutableMapping): - tool_data = cast(MutableMapping[str, Any], tool) - if tool_data.get("type") == "mcp": - # MCP servers must be routed to separate mcp_servers parameter - server_def: dict[str, Any] = { - "type": "url", - "name": tool_data.get("server_label", ""), - "url": 
tool_data.get("server_url", ""), + elif isinstance(tool, Mapping) and tool.get("type") == "mcp": # type: ignore[reportUnknownMemberType] + # MCP servers must be routed to separate mcp_servers parameter + server_def: dict[str, Any] = { + "type": "url", + "name": tool.get("server_label", ""), # type: ignore[reportUnknownMemberType] + "url": tool.get("server_url", ""), # type: ignore[reportUnknownMemberType] + } + allowed_tools = tool.get("allowed_tools") # type: ignore[reportUnknownMemberType] + if isinstance(allowed_tools, Sequence) and not isinstance(allowed_tools, str): + server_def["tool_configuration"] = { + "allowed_tools": [str(item) for item in allowed_tools] # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType] } - allowed_tools = tool_data.get("allowed_tools") - if isinstance(allowed_tools, Sequence) and not isinstance(allowed_tools, str): - server_def["tool_configuration"] = { - "allowed_tools": [str(item) for item in allowed_tools] # pyright: ignore[reportUnknownArgumentType,reportUnknownVariableType] - } - headers = tool_data.get("headers") - authorization = headers.get("authorization") if isinstance(headers, Mapping) else None # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] - if isinstance(authorization, str): - server_def["authorization_token"] = authorization - mcp_server_list.append(server_def) - else: - # Pass through all other tools (dicts, SDK types) unchanged - tool_list.append(tool) + headers = tool.get("headers") # type: ignore[reportUnknownMemberType] + authorization = headers.get("authorization") if isinstance(headers, Mapping) else None # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] + if isinstance(authorization, str): + server_def["authorization_token"] = authorization + mcp_server_list.append(server_def) else: # Pass through all other tools (dicts, SDK types) unchanged tool_list.append(tool) diff --git 
a/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py b/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py index 3e6f48a572..b2eb41e03f 100644 --- a/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py +++ b/python/packages/azure-ai-search/agent_framework_azure_ai_search/_context_provider.py @@ -11,7 +11,7 @@ import logging import sys from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypedDict, cast +from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypedDict from agent_framework import AGENT_FRAMEWORK_USER_AGENT, Annotation, Content, Message, SupportsGetEmbeddings from agent_framework._sessions import AgentSession, BaseContextProvider, SessionContext @@ -456,10 +456,10 @@ async def _semantic_search(self, query: str) -> list[Message]: elif self.embedding_function: if isinstance(self.embedding_function, SupportsGetEmbeddings): embeddings = await self.embedding_function.get_embeddings([query]) # type: ignore[reportUnknownVariableType] - query_vector = self._normalize_query_vector(embeddings[0].vector) # type: ignore[reportUnknownVariableType] + query_vector = embeddings[0].vector # type: ignore[reportUnknownVariableType] else: - query_vector = self._normalize_query_vector(await self.embedding_function(query)) - vector_queries = [VectorizedQuery(vector=query_vector, k=vector_k, fields=self.vector_field_name)] + query_vector = await self.embedding_function(query) # type: ignore[reportUnknownVariableType] + vector_queries = [VectorizedQuery(vector=query_vector, k=vector_k, fields=self.vector_field_name)] # type: ignore[reportUnknownArgumentType] search_params: dict[str, Any] = {"search_text": query, "top": self.top_k} if vector_queries: @@ -603,20 +603,6 @@ async def _agentic_search(self, messages: list[Message]) -> list[Message]: return self._parse_messages_from_kb_response(retrieval_result) - @staticmethod 
- def _normalize_query_vector(vector: object) -> list[float]: - """Normalize query vector values to floats for Azure Search vector query.""" - if not isinstance(vector, list): - raise TypeError("embedding_function must return list[float]") - - vector_values = cast(list[object], vector) - normalized: list[float] = [] - for value in vector_values: - if not isinstance(value, int | float): - raise TypeError("embedding_function must return list[float]") - normalized.append(float(value)) - return normalized - @staticmethod def _prepare_messages_for_kb_search(messages: list[Message]) -> list[KnowledgeBaseMessage]: """Convert framework Messages to KnowledgeBaseMessages for agentic retrieval. diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py index 331b645ab6..4e9f41e9d6 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py @@ -77,9 +77,9 @@ RunStatus, RunStep, RunStepDeltaChunk, - RunStepDeltaCodeInterpreterDetailItemObject, RunStepDeltaCodeInterpreterImageOutput, RunStepDeltaCodeInterpreterLogOutput, + RunStepDeltaToolCall, SubmitToolApprovalAction, SubmitToolOutputsAction, ThreadMessageOptions, @@ -997,43 +997,34 @@ async def _process_stream( role="assistant", ) case RunStepDeltaChunk(): # type: ignore - step_details: Any = event_data.delta.step_details - if step_details is not None and getattr(step_details, "type", None) == "tool_calls": - tool_calls = getattr(step_details, "tool_calls", None) - if isinstance(tool_calls, list): - for tool_call in cast(list[object], tool_calls): - tool_type = getattr(tool_call, "type", None) - code_interpreter = getattr(tool_call, "code_interpreter", None) - if tool_type == "code_interpreter" and isinstance( - code_interpreter, - RunStepDeltaCodeInterpreterDetailItemObject, - ): - code_contents: list[Content] = [] - if code_interpreter.input is not 
None: - logger.debug(f"Code Interpreter Input: {code_interpreter.input}") - if code_interpreter.outputs is not None: - for output in code_interpreter.outputs: - if ( - isinstance(output, RunStepDeltaCodeInterpreterLogOutput) - and output.logs - ): - code_contents.append(Content.from_text(text=output.logs)) - if ( - isinstance(output, RunStepDeltaCodeInterpreterImageOutput) - and output.image is not None - and output.image.file_id is not None - ): - code_contents.append( - Content.from_hosted_file(file_id=output.image.file_id) - ) - yield ChatResponseUpdate( - role="assistant", - contents=code_contents, - conversation_id=thread_id, - message_id=response_id, - raw_representation=code_interpreter, - response_id=response_id, - ) + step_details = event_data.delta.step_details + if step_details is not None and step_details.type == "tool_calls": + tool_calls = cast(list[RunStepDeltaToolCall], step_details.tool_calls) # type: ignore + for tool_call in tool_calls: + if tool_call.type == "code_interpreter" and tool_call.code_interpreter is not None: # type: ignore[attr-defined, reportUnknownMemberType] + code_contents: list[Content] = [] + if tool_call.code_interpreter.input is not None: # type: ignore[attr-defined, reportUnknownMemberType] + logger.debug(f"Code Interpreter Input: {tool_call.code_interpreter.input}") # type: ignore[attr-defined, reportUnknownMemberType] + if tool_call.code_interpreter.outputs is not None: # type: ignore[attr-defined, reportUnknownMemberType] + for output in tool_call.code_interpreter.outputs: # type: ignore[attr-defined, reportUnknownMemberType] + if isinstance(output, RunStepDeltaCodeInterpreterLogOutput) and output.logs: + code_contents.append(Content.from_text(text=output.logs)) + if ( + isinstance(output, RunStepDeltaCodeInterpreterImageOutput) + and output.image is not None + and output.image.file_id is not None + ): + code_contents.append( + Content.from_hosted_file(file_id=output.image.file_id) + ) + yield ChatResponseUpdate( + 
role="assistant", + contents=code_contents, + conversation_id=thread_id, + message_id=response_id, + raw_representation=tool_call.code_interpreter, # type: ignore[attr-defined, reportUnknownMemberType] + response_id=response_id, + ) case _: # ThreadMessage or string # possible event_types for ThreadMessage: # AgentStreamEvent.THREAD_MESSAGE_CREATED @@ -1060,7 +1051,7 @@ def _capture_azure_search_tool_calls( ) -> None: """Capture Azure AI Search tool call data from completed steps.""" try: - step_details: Any = getattr(step_data, "step_details", None) + step_details = getattr(step_data, "step_details", None) tool_calls = getattr(step_details, "tool_calls", None) if step_details is not None else None if isinstance(tool_calls, list): for tool_call in cast(list[object], tool_calls): @@ -1224,19 +1215,17 @@ def _prepare_tool_choice_mode( tool_choice = options.get("tool_choice") if tool_choice is None: return None - if tool_choice == "none": - return AgentsToolChoiceOptionMode.NONE - if tool_choice == "auto": - return AgentsToolChoiceOptionMode.AUTO - if isinstance(tool_choice, Mapping): - tool_choice_mapping = cast(Mapping[str, Any], tool_choice) - if tool_choice_mapping.get("mode") == "required": - req_fn = tool_choice_mapping.get("required_function_name") - if req_fn: - return AgentsNamedToolChoice( - type=AgentsNamedToolChoiceType.FUNCTION, - function=FunctionName(name=str(req_fn)), - ) + if tool_choice in {"none", "auto"}: + return AgentsToolChoiceOptionMode(tool_choice) + if ( + isinstance(tool_choice, Mapping) + and tool_choice.get("mode") == "required" # type: ignore[attr-unknown] + and (req_fn := tool_choice.get("required_function_name")) # type: ignore[attr-unknown] + ): + return AgentsNamedToolChoice( + type=AgentsNamedToolChoiceType.FUNCTION, + function=FunctionName(name=req_fn), # type: ignore[call-arg] + ) return None async def _prepare_tool_definitions_and_resources( @@ -1374,11 +1363,9 @@ async def _prepare_tools_for_azure_ai( 
tool_definitions.extend(tool.definitions) # Handle tool resources (MCP resources handled separately by _prepare_mcp_resources) resources = getattr(tool, "resources", None) - if run_options is not None and isinstance(resources, Mapping) and resources and "mcp" not in resources: - if "tool_resources" not in run_options: - run_options["tool_resources"] = {} - tool_resources = cast(MutableMapping[str, Any], run_options["tool_resources"]) - tool_resources.update(dict(cast(Mapping[str, Any], resources))) + if run_options is not None and resources and isinstance(resources, Mapping) and "mcp" not in resources: + run_options.setdefault("tool_resources", {}) + run_options["tool_resources"].update(tool.resources) else: # Pass through ToolDefinition, dict, and other types unchanged tool_definitions.append(tool) diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_client.py index 4b31c99b61..df0340a8f1 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_client.py @@ -2,12 +2,11 @@ from __future__ import annotations -import importlib import json import logging import re import sys -from collections.abc import Awaitable, Callable, Mapping, Sequence +from collections.abc import Awaitable, Callable, Mapping, MutableMapping, Sequence from contextlib import suppress from typing import Any, ClassVar, Generic, Literal, TypedDict, TypeVar, cast @@ -27,8 +26,6 @@ Message, MiddlewareTypes, ResponseStream, - Role, - RoleLiteral, TextSpanRegion, ) from agent_framework._settings import load_settings @@ -73,15 +70,6 @@ logger = logging.getLogger("agent_framework.azure") -AzureMonitorConfigurator = Callable[..., Any] - - -def _normalize_chat_role(role: Role | str | None) -> RoleLiteral | Role | None: - if role in {"system", "user", "assistant", "tool"}: - return cast(RoleLiteral, role) - return None - - class 
AzureAIProjectAgentOptions(OpenAIResponsesOptions, total=False): """Azure AI Project Agent options.""" @@ -316,18 +304,13 @@ async def configure_azure_monitor( # Import Azure Monitor with proper error handling try: - monitor_module = importlib.import_module("azure.monitor.opentelemetry") + from azure.monitor.opentelemetry import configure_azure_monitor # type: ignore[import] except ImportError as exc: raise ImportError( "azure-monitor-opentelemetry is required for Azure Monitor integration. " "Install it with: pip install azure-monitor-opentelemetry" ) from exc - configure_azure_monitor_attr = getattr(monitor_module, "configure_azure_monitor", None) - if not callable(configure_azure_monitor_attr): - raise ImportError("azure-monitor-opentelemetry does not expose configure_azure_monitor") - configure_azure_monitor: AzureMonitorConfigurator = configure_azure_monitor_attr - from agent_framework.observability import create_metric_views, create_resource, enable_instrumentation # Create resource if not provided in kwargs @@ -461,30 +444,25 @@ def _get_tool_name(self, tool: Any) -> str: return tool.name if isinstance(tool, Mapping): - tool_mapping = cast(Mapping[str, Any], tool) - tool_type = tool_mapping.get("type") + tool_type = tool.get("type") # type: ignore[reportUnknownMemberType] if tool_type == "function": - function_data = tool_mapping.get("function") - if isinstance(function_data, Mapping): - function_mapping = cast(Mapping[str, Any], function_data) - if function_name := function_mapping.get("name"): - return str(function_name) - if tool_name := tool_mapping.get("name"): - return str(tool_name) - if tool_name := tool_mapping.get("name"): - return str(tool_name) - if server_label := tool_mapping.get("server_label"): + function_data = tool.get("function") # type: ignore[reportUnknownMemberType] + if isinstance(function_data, Mapping) and (function_name := function_data.get("name")): # type: ignore[assignment] + return function_name # type: ignore[no-any-return] + 
if tool_name := tool.get("name"): # type: ignore[reportUnknownMemberType] + return tool_name # type: ignore[no-any-return] + if server_label := tool.get("server_label"): # type: ignore[reportUnknownMemberType] return f"mcp:{server_label}" if tool_type: - return str(tool_type) - return type(cast(Any, tool)).__name__ + return tool_type # type: ignore[no-any-return] + raise ValueError("Dict based tool definitions must include a 'name' property for runtime comparison.") if name_value := getattr(tool, "name", None): - return str(name_value) + return name_value # type: ignore[no-any-return] if server_label_value := getattr(tool, "server_label", None): return f"mcp:{server_label_value}" if tool_type_value := getattr(tool, "type", None): - return str(tool_type_value) + return tool_type_value # type: ignore[no-any-return] return type(tool).__name__ def _get_structured_output_signature(self, chat_options: Mapping[str, Any] | None) -> str | None: @@ -602,15 +580,14 @@ def _transform_input_for_azure_ai(self, input_items: list[dict[str, Any]]) -> li # Add 'annotations' only to output_text content items (assistant messages) # User messages (input_text) do NOT support annotations in Azure AI - if "content" in new_item and isinstance(new_item["content"], list): + if (content := new_item.get("content")) and isinstance(content, list): new_content: list[Any] = [] - for content_item in cast(list[object], new_item["content"]): - if isinstance(content_item, Mapping): - new_content_item: dict[str, Any] = dict(cast(Mapping[str, Any], content_item)) + for content_item in content: # type: ignore[list-item] + if isinstance(content_item, MutableMapping): # Only add annotations to output_text (assistant content) - if new_content_item.get("type") == "output_text" and "annotations" not in new_content_item: - new_content_item["annotations"] = [] - new_content.append(new_content_item) + if content_item.get("type") == "output_text" and "annotations" not in content_item: # type: 
ignore[reportUnknownMemberType] + content_item["annotations"] = [] + new_content.append(content_item) else: new_content.append(content_item) new_item["content"] = new_content @@ -748,15 +725,10 @@ def _extract_azure_search_urls(self, output_items: Any) -> list[str]: # Streaming "added" events send output as an empty list; skip. continue if output is not None: - urls: Any - if isinstance(output, Mapping): - output_mapping = cast(Mapping[str, Any], output) - urls = output_mapping.get("get_urls") - else: - urls = getattr(output, "get_urls", None) + urls = output.get("get_urls") if isinstance(output, Mapping) else getattr(output, "get_urls", None) # type: ignore if isinstance(urls, list): string_urls: list[str] = [] - for url_item in cast(list[object], urls): + for url_item in urls: # type: ignore[list-item] if isinstance(url_item, str): string_urls.append(url_item) get_urls.extend(string_urls) @@ -914,7 +886,7 @@ def _enrich_update(update: ChatResponseUpdate) -> ChatResponseUpdate: contents=contents_list, conversation_id=update.conversation_id, response_id=update.response_id, - role=_normalize_chat_role(update.role), + role=update.role, # type: ignore[union-attr] model_id=update.model_id, continuation_token=update.continuation_token, additional_properties=update.additional_properties, diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_embedding_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_embedding_client.py index 7e6cdfc8b7..a243f77a38 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_embedding_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_embedding_client.py @@ -186,7 +186,7 @@ async def get_embeddings( values: Sequence[Content | str], *, options: AzureAIInferenceEmbeddingOptionsT | None = None, - ) -> GeneratedEmbeddings[list[float]]: + ) -> GeneratedEmbeddings[list[float], AzureAIInferenceEmbeddingOptionsT]: """Generate embeddings for text and/or image inputs. 
Text inputs (``str`` or ``Content`` with ``type="text"``) are sent to the diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_shared.py b/python/packages/azure-ai/agent_framework_azure_ai/_shared.py index 4630280bb5..b9f8ed801f 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_shared.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_shared.py @@ -95,18 +95,18 @@ def _extract_project_connection_id(additional_properties: dict[str, Any] | None) return None # Check for direct project_connection_id (programmatic usage) - project_connection_id = additional_properties.get("project_connection_id") - if isinstance(project_connection_id, str): - return project_connection_id + + if (proj_conn_id := additional_properties.get("project_connection_id")) and isinstance(proj_conn_id, str): + return proj_conn_id # type: ignore[no-any-return] # Check for connection.name structure (declarative/YAML usage) - if "connection" in additional_properties: - conn = additional_properties["connection"] - if isinstance(conn, Mapping): - conn_mapping = cast(Mapping[str, Any], conn) - name = conn_mapping.get("name") - if isinstance(name, str): - return name + if ( + (connection := additional_properties.get("connection")) + and isinstance(connection, Mapping) + and (name := connection.get("name")) # type: ignore + and isinstance(name, str) + ): + return name # type: ignore[no-any-return] return None @@ -190,11 +190,9 @@ def to_azure_ai_agent_tools( and tool.resources and "mcp" not in tool.resources ): - if "tool_resources" not in run_options: - run_options["tool_resources"] = {} - tool_resources = cast(MutableMapping[str, Any], run_options["tool_resources"]) + run_options.setdefault("tool_resources", {}) if isinstance(tool.resources, Mapping): - tool_resources.update(dict(cast(Mapping[str, Any], tool.resources))) + run_options["tool_resources"].update(tool.resources) elif isinstance(tool, (dict, MutableMapping)): # Handle dict-based tools - pass through directly 
tool_dict = tool if isinstance(tool, dict) else dict(tool) @@ -456,13 +454,11 @@ def _prepare_mcp_tool_dict_for_azure_ai(tool_dict: dict[str, Any]) -> MCPTool: mcp["server_description"] = description # Check for project_connection_id - additional_properties = tool_dict.get("additional_properties") - extracted_project_connection_id = ( - _extract_project_connection_id(dict(cast(Mapping[str, Any], additional_properties))) - if isinstance(additional_properties, Mapping) + if project_connection_id := ( + tool_dict.get("project_connection_id") or _extract_project_connection_id(tool_dict.get("additional_properties")) + if isinstance(tool_dict.get("additional_properties"), Mapping) else None - ) - if project_connection_id := tool_dict.get("project_connection_id") or extracted_project_connection_id: + ): mcp["project_connection_id"] = project_connection_id elif headers := tool_dict.get("headers"): mcp["headers"] = headers diff --git a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py index 84c1efac52..e26f3e061a 100644 --- a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py +++ b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py @@ -8,7 +8,7 @@ import time import uuid from collections.abc import Sequence -from typing import Any, ClassVar, TypedDict, TypeGuard, cast +from typing import Any, ClassVar, TypedDict from agent_framework import AGENT_FRAMEWORK_USER_AGENT, Message from agent_framework._sessions import BaseHistoryProvider @@ -20,14 +20,6 @@ logger = logging.getLogger(__name__) -def _is_str_key_dict(value: object) -> TypeGuard[dict[str, Any]]: - if not isinstance(value, dict): - return False - - candidate_dict = cast(dict[object, Any], value) - return all(isinstance(key_obj, str) for key_obj in candidate_dict) - - class AzureCosmosHistorySettings(TypedDict, total=False): """Settings for 
CosmosHistoryProvider resolved from args and environment.""" @@ -152,9 +144,12 @@ async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Mess messages: list[Message] = [] async for item in items: - message_payload = item.get("message") - if _is_str_key_dict(message_payload): - messages.append(Message.from_dict(message_payload)) + try: + msg = Message.from_dict(item.get("message")) # type: ignore + except ValueError as e: + logger.warning("Failed to deserialize message from Cosmos DB item: %s", e) + continue + messages.append(msg) return messages diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index 401fda3ca7..01dcc102f4 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -280,17 +280,8 @@ def executor_activity(inputData: str) -> str: data = cast(dict[str, Any], data_obj) message_data = data.get("message") - shared_state_raw = data.get("shared_state_snapshot", {}) - source_executor_ids_raw = data.get("source_executor_ids", [SOURCE_ORCHESTRATOR]) - - shared_state_snapshot = cast(dict[str, Any], shared_state_raw) if isinstance(shared_state_raw, dict) else {} - - source_executor_ids: list[str] - if isinstance(source_executor_ids_raw, list): - source_executor_ids_values = cast(list[object], source_executor_ids_raw) - source_executor_ids = [str(source_executor_id) for source_executor_id in source_executor_ids_values] - else: - source_executor_ids = [SOURCE_ORCHESTRATOR] + shared_state_snapshot = data.get("shared_state_snapshot", {}) + source_executor_ids = cast(list[str], data.get("source_executor_ids", [SOURCE_ORCHESTRATOR])) if not self.workflow: raise RuntimeError("Workflow not initialized in AgentFunctionApp") @@ -478,29 +469,26 @@ async def get_workflow_status( } # Add pending HITL requests info if available - custom_status = 
status.custom_status - if isinstance(custom_status, dict): - custom_status_typed = cast(dict[str, Any], custom_status) - pending_requests_raw = custom_status_typed.get("pending_requests") - if isinstance(pending_requests_raw, dict): - base_url = self._build_base_url(req.url) - pending_requests: list[dict[str, Any]] = [] - pending_requests_dict = cast(dict[str, Any], pending_requests_raw) - for req_id_raw, req_data_raw in pending_requests_dict.items(): - if not isinstance(req_data_raw, dict): - continue - - req_id = str(req_id_raw) - req_data = cast(dict[str, Any], req_data_raw) - pending_requests.append({ - "requestId": req_id, - "sourceExecutor": req_data.get("source_executor_id"), - "requestData": req_data.get("data"), - "requestType": req_data.get("request_type"), - "responseType": req_data.get("response_type"), - "respondUrl": f"{base_url}/api/workflow/respond/{instance_id}/{req_id}", - }) - response["pendingHumanInputRequests"] = pending_requests + if ( + (custom_status := status.custom_status) + and isinstance(custom_status, dict) + and (pending_requests_dict := custom_status.get("pending_requests")) # type: ignore + and isinstance(pending_requests_dict, dict) + ): + base_url = self._build_base_url(req.url) + pending_requests: list[dict[str, Any]] = [] + for req_id, req_data in pending_requests_dict.items(): # type: ignore + if not isinstance(req_data, dict): + continue + pending_requests.append({ + "requestId": req_id, + "sourceExecutor": req_data.get("source_executor_id"), # type: ignore[reportUnknownMemberType] + "requestData": req_data.get("data"), # type: ignore[reportUnknownMemberType] + "requestType": req_data.get("request_type"), # type: ignore[reportUnknownMemberType] + "responseType": req_data.get("response_type"), # type: ignore[reportUnknownMemberType] + "respondUrl": f"{base_url}/api/workflow/respond/{instance_id}/{req_id}", + }) + response["pendingHumanInputRequests"] = pending_requests return func.HttpResponse( json.dumps(response, 
default=str), diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py index b353549dcb..f48e55f5d5 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_serialization.py @@ -20,11 +20,12 @@ import importlib import logging -from collections.abc import Callable +from contextlib import suppress from dataclasses import is_dataclass -from typing import Any, cast +from typing import Any from agent_framework._workflows._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value +from pydantic import BaseModel logger = logging.getLogger(__name__) @@ -109,11 +110,9 @@ def reconstruct_to_type(value: Any, target_type: type) -> Any: if value is None: return None - try: + with suppress(TypeError): if isinstance(value, target_type): return value - except TypeError: - pass if not isinstance(value, dict): return value @@ -124,19 +123,18 @@ def reconstruct_to_type(value: Any, target_type: type) -> Any: return decoded # Try Pydantic model validation (for unmarked dicts, e.g., external HITL data) - model_validate = getattr(target_type, "model_validate", None) - if callable(model_validate): + if issubclass(target_type, BaseModel): try: - model_validate_fn = cast(Callable[[Any], Any], model_validate) - return model_validate_fn(value) + return target_type.model_validate(value) except Exception: logger.debug("Could not validate Pydantic model %s", target_type) + return value # type: ignore[return-value] # Try dataclass construction (for unmarked dicts, e.g., external HITL data) - if is_dataclass(target_type) and isinstance(target_type, type): + if is_dataclass(target_type) and isinstance(target_type, type): # type: ignore try: return target_type(**value) except Exception: logger.debug("Could not construct dataclass %s", target_type) - return 
cast(Any, value) + return value # type: ignore[return-value] diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py index 0c24752905..60c04ad66c 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_workflow.py @@ -26,7 +26,7 @@ from dataclasses import dataclass from datetime import timedelta from enum import Enum -from typing import Any, cast +from typing import Any from agent_framework import ( AgentExecutor, @@ -348,16 +348,16 @@ def _process_agent_response( ExecutorResult containing the processed response """ response_text = agent_response.text if agent_response else None - structured_response = None + structured_response: dict[str, Any] | None = None if agent_response and agent_response.value is not None: model_dump = getattr(agent_response.value, "model_dump", None) if callable(model_dump): dumped = model_dump() if isinstance(dumped, dict): - structured_response = cast(dict[str, Any], dumped) + structured_response = dumped # type: ignore[assignment] elif isinstance(agent_response.value, dict): - structured_response = agent_response.value + structured_response = agent_response.value # type: ignore[assignment] output_message = build_agent_executor_response( executor_id=executor_id, @@ -869,9 +869,8 @@ def _extract_message_content(message: Any) -> str: # Extract text from the last message in the request message_content = message.messages[-1].text or "" elif isinstance(message, dict): - message_dict = cast(dict[str, Any], message) - key_names = list(message_dict.keys()) - logger.warning("Unexpected dict message in _extract_message_content. Keys: %s", key_names) + key_names = list(message.keys()) # type: ignore[union-attr] + logger.warning("Unexpected dict message in _extract_message_content. 
Keys: %s", key_names) # type: ignore elif isinstance(message, str): message_content = message diff --git a/python/packages/bedrock/agent_framework_bedrock/__init__.py b/python/packages/bedrock/agent_framework_bedrock/__init__.py index 3fbf5c15cf..b2dc511559 100644 --- a/python/packages/bedrock/agent_framework_bedrock/__init__.py +++ b/python/packages/bedrock/agent_framework_bedrock/__init__.py @@ -2,8 +2,8 @@ import importlib.metadata -from ._chat_client import BedrockChatClient, BedrockChatOptions, BedrockGuardrailConfig, BedrockSettings -from ._embedding_client import BedrockEmbeddingClient, BedrockEmbeddingOptions, BedrockEmbeddingSettings +from ._chat_client import BedrockChatClient, BedrockChatOptions, BedrockGuardrailConfig, BedrockSettings # type: ignore +from ._embedding_client import BedrockEmbeddingClient, BedrockEmbeddingOptions, BedrockEmbeddingSettings # type: ignore try: __version__ = importlib.metadata.version(__name__) diff --git a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py index 004c5b254d..1845c65a99 100644 --- a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py +++ b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. - +# type: ignore +# Because the Bedrock client does not have typing, we are ignoring type issues in this module. 
from __future__ import annotations import asyncio @@ -8,7 +9,7 @@ import sys from collections import deque from collections.abc import AsyncIterable, Awaitable, Mapping, MutableMapping, Sequence -from typing import Any, ClassVar, Generic, Literal, Protocol, TypedDict, TypeGuard, cast +from typing import Any, ClassVar, Generic, Literal, TypedDict from uuid import uuid4 from agent_framework import ( @@ -214,10 +215,6 @@ class BedrockSettings(TypedDict, total=False): session_token: SecretString | None -class BedrockRuntimeClient(Protocol): - def converse(self, **kwargs: Any) -> Mapping[str, object]: ... - - class BedrockChatClient( ChatMiddlewareLayer[BedrockChatOptionsT], FunctionInvocationLayer[BedrockChatOptionsT], @@ -228,37 +225,6 @@ class BedrockChatClient( """Async chat client for Amazon Bedrock's Converse API with middleware, telemetry, and function invocation.""" OTEL_PROVIDER_NAME: ClassVar[str] = "aws.bedrock" # type: ignore[reportIncompatibleVariableOverride, misc] - _bedrock_client: BedrockRuntimeClient - - @staticmethod - def _is_runtime_client(value: object) -> TypeGuard[BedrockRuntimeClient]: - converse = getattr(value, "converse", None) - return callable(converse) - - @staticmethod - def _get_str(value: object) -> str | None: - return value if isinstance(value, str) else None - - @staticmethod - def _get_dict(value: object) -> dict[str, Any] | None: - if not isinstance(value, dict): - return None - return cast(dict[str, Any], value) - - @staticmethod - def _is_nonstring_sequence(value: object) -> TypeGuard[Sequence[object]]: - return isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)) - - @staticmethod - def _get_content_blocks(value: object) -> list[dict[str, Any]]: - if not BedrockChatClient._is_nonstring_sequence(value): - return [] - blocks: list[dict[str, Any]] = [] - for item in value: - block = BedrockChatClient._get_dict(item) - if block is not None: - blocks.append(block) - return blocks def __init__( self, @@ 
-326,30 +292,21 @@ class MyOptions(BedrockChatOptions, total=False): region = settings.get("region") or DEFAULT_REGION chat_model_id = settings.get("chat_model_id") - if client is None: + if client: + self._bedrock_client = client + else: session = boto3_session or self._create_session(settings) - client_factory = getattr(session, "client", None) - if not callable(client_factory): - raise TypeError("Boto3 session does not provide a callable client factory.") - created_client: object = client_factory( + self._bedrock_client = session.client( "bedrock-runtime", region_name=region, config=BotoConfig(user_agent_extra=AGENT_FRAMEWORK_USER_AGENT), ) - if not self._is_runtime_client(created_client): - raise TypeError("Boto3 session did not create a compatible Bedrock runtime client.") - runtime_client = created_client - elif not self._is_runtime_client(client): - raise TypeError("Provided client must expose a callable 'converse' method.") - else: - runtime_client = client super().__init__( middleware=middleware, function_invocation_configuration=function_invocation_configuration, **kwargs, ) - self._bedrock_client = runtime_client self.model_id = chat_model_id self.region = region @@ -370,7 +327,7 @@ def _invoke_converse(self, request: Mapping[str, Any]) -> dict[str, Any]: response = self._bedrock_client.converse(**request) if not isinstance(response, Mapping): raise ChatClientInvalidResponseException("Bedrock converse response must be a mapping.") - return dict(response) + return response @override def _inner_get_response( @@ -642,18 +599,17 @@ def _generate_tool_call_id() -> str: return f"tool-call-{uuid4().hex}" def _process_converse_response(self, response: dict[str, Any]) -> ChatResponse: - output = self._get_dict(response.get("output")) or {} - message = self._get_dict(output.get("message")) or {} - content_blocks = self._get_content_blocks(message.get("content")) + """Convert Bedrock Converse API response to ChatResponse.""" + output = response.get("output") or {} 
+ message = output.get("message") or {} + content_blocks = message.get("content") or [] contents = self._parse_message_contents(content_blocks) chat_message = Message(role="assistant", contents=contents, raw_representation=message) - usage_source = self._get_dict(response.get("usage")) or self._get_dict(output.get("usage")) + usage_source = response.get("usage") or output.get("usage") usage_details = self._parse_usage(usage_source) - finish_reason = self._map_finish_reason( - self._get_str(output.get("completionReason")) or self._get_str(response.get("stopReason")) - ) - response_id = self._get_str(response.get("responseId")) or self._get_str(message.get("id")) - model_id = self._get_str(response.get("modelId")) or self._get_str(output.get("modelId")) or self.model_id + finish_reason = self._map_finish_reason(output.get("completionReason") or response.get("stopReason")) + response_id = response.get("responseId") or message.get("id") + model_id = response.get("modelId") or output.get("modelId") or self.model_id return ChatResponse( response_id=response_id, messages=[chat_message], diff --git a/python/packages/bedrock/agent_framework_bedrock/_embedding_client.py b/python/packages/bedrock/agent_framework_bedrock/_embedding_client.py index ac46a5c529..d07bdee45c 100644 --- a/python/packages/bedrock/agent_framework_bedrock/_embedding_client.py +++ b/python/packages/bedrock/agent_framework_bedrock/_embedding_client.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. - +# type: ignore +# Because the Bedrock client does not have typing, we are ignoring type issues in this module. 
from __future__ import annotations import asyncio @@ -7,7 +8,7 @@ import logging import sys from collections.abc import Sequence -from typing import Any, ClassVar, Generic, Protocol, TypedDict, cast +from typing import Any, ClassVar, Generic, TypedDict from agent_framework import ( AGENT_FRAMEWORK_USER_AGENT, @@ -30,31 +31,6 @@ from typing_extensions import TypeVar # type: ignore # pragma: no cover -class BedrockRuntimeMeta(Protocol): - endpoint_url: str - - -class BedrockResponseBody(Protocol): - def read(self) -> bytes | bytearray | str: ... - - -class BedrockInvokeModelResponse(TypedDict): - body: BedrockResponseBody - - -class BedrockRuntimeClient(Protocol): - meta: BedrockRuntimeMeta - - def invoke_model( - self, - *, - modelId: str, - contentType: str, - accept: str, - body: str, - ) -> BedrockInvokeModelResponse: ... - - logger = logging.getLogger("agent_framework.bedrock") DEFAULT_REGION = "us-east-1" @@ -126,7 +102,7 @@ def __init__( access_key: str | None = None, secret_key: str | None = None, session_token: str | None = None, - client: BaseClient | BedrockRuntimeClient | None = None, + client: BaseClient | None = None, boto3_session: Boto3Session | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, @@ -147,7 +123,9 @@ def __init__( ) resolved_region = settings.get("region") or DEFAULT_REGION - if client is None: + if client: + self._bedrock_client = client + else: if not boto3_session: session_kwargs: dict[str, Any] = {} if region := settings.get("region"): @@ -158,18 +136,13 @@ def __init__( if session_token := settings.get("session_token"): session_kwargs["aws_session_token"] = session_token.get_secret_value() boto3_session = Boto3Session(**session_kwargs) - region_name = cast(str | None, getattr(boto3_session, "region_name", None)) - client_factory = cast(Any, boto3_session.client) # pyright: ignore[reportUnknownMemberType] - client = cast( - BedrockRuntimeClient, - client_factory( - "bedrock-runtime", - 
region_name=region_name if isinstance(region_name, str) else resolved_region, - config=BotoConfig(user_agent_extra=AGENT_FRAMEWORK_USER_AGENT), - ), + region_name = boto3_session.region_name + self._bedrock_client = boto3_session.client( + "bedrock-runtime", + region_name=region_name or resolved_region, + config=BotoConfig(user_agent_extra=AGENT_FRAMEWORK_USER_AGENT), ) - self._bedrock_client: BedrockRuntimeClient = cast(BedrockRuntimeClient, client) self.model_id: str = settings["embedding_model_id"] # type: ignore[assignment] # pyright: ignore[reportTypedDictNotRequiredAccess] self.region = resolved_region super().__init__(**kwargs) @@ -183,7 +156,7 @@ async def get_embeddings( values: Sequence[str], *, options: BedrockEmbeddingOptionsT | None = None, - ) -> GeneratedEmbeddings[list[float]]: + ) -> GeneratedEmbeddings[list[float], BedrockEmbeddingOptionsT]: """Call the Bedrock invoke_model API for embeddings. Uses the Amazon Titan Embeddings model format. Each value is embedded @@ -199,9 +172,8 @@ async def get_embeddings( Raises: ValueError: If model_id is not provided or values is empty. 
""" - resolved_options = cast(EmbeddingGenerationOptions | None, options) if not values: - return GeneratedEmbeddings([], options=resolved_options) + return GeneratedEmbeddings([], options=options) opts: dict[str, Any] = dict(options) if options else {} model = opts.get("model_id") or self.model_id @@ -221,7 +193,7 @@ async def get_embeddings( if total_input_tokens > 0: usage_dict = {"input_token_count": total_input_tokens} - return GeneratedEmbeddings(embeddings, options=resolved_options, usage=usage_dict) + return GeneratedEmbeddings(embeddings, options=options, usage=usage_dict) async def _generate_embedding_for_text( self, @@ -242,28 +214,10 @@ async def _generate_embedding_for_text( accept="application/json", body=json.dumps(body), ) - - response_body_raw = response["body"] - response_payload = response_body_raw.read() - payload_text = ( - response_payload.decode() if isinstance(response_payload, (bytes, bytearray)) else response_payload - ) - response_body_raw_map: object = json.loads(payload_text) - if not isinstance(response_body_raw_map, dict): - raise ValueError("Bedrock embedding response body must be a JSON object") - response_body = cast(dict[str, Any], response_body_raw_map) - embedding_values = response_body.get("embedding") - if not isinstance(embedding_values, list): - raise ValueError("Bedrock embedding response missing 'embedding' list") - vector: list[float] = [] - for value in cast(list[object], embedding_values): - if isinstance(value, (int, float, str)): - vector.append(float(value)) - continue - raise ValueError("Bedrock embedding response contains non-numeric embedding value") + response_body = json.loads(response["body"].read()) embedding = Embedding( - vector=vector, - dimensions=len(vector), + vector=response_body["embedding"], + dimensions=len(response_body["embedding"]), model_id=model, ) input_tokens = int(response_body.get("inputTextTokenCount", 0)) @@ -317,7 +271,7 @@ def __init__( access_key: str | None = None, secret_key: str | 
None = None, session_token: str | None = None, - client: BaseClient | BedrockRuntimeClient | None = None, + client: BaseClient | None = None, boto3_session: Boto3Session | None = None, otel_provider_name: str | None = None, env_file_path: str | None = None, diff --git a/python/packages/claude/agent_framework_claude/_agent.py b/python/packages/claude/agent_framework_claude/_agent.py index e903627637..bc26a3d515 100644 --- a/python/packages/claude/agent_framework_claude/_agent.py +++ b/python/packages/claude/agent_framework_claude/_agent.py @@ -319,29 +319,17 @@ def _normalize_tools( if tools is None: return - non_builtin_tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None - if isinstance(tools, str): - self._builtin_tools.append(tools) - return - if isinstance(tools, Sequence) and not isinstance(tools, MutableMapping): - sequence_tools: list[ToolTypes | Callable[..., Any]] = [] - for tool in tools: # pyright: ignore[reportUnknownVariableType] - if isinstance(tool, str): - self._builtin_tools.append(tool) - else: - sequence_tools.append(tool) # pyright: ignore[reportUnknownArgumentType] - non_builtin_tools = sequence_tools - else: - non_builtin_tools = tools - - if not non_builtin_tools: - return - - for tool in normalize_tools(non_builtin_tools): + non_builtin_tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] = [] + if not isinstance(tools, list): + tools = [tools] # type: ignore[assignment, reportUnknownVariableType] + for tool in tools: # type: ignore[reportUnknownVariableType] if isinstance(tool, str): self._builtin_tools.append(tool) else: - self._custom_tools.append(tool) + non_builtin_tools.append(tool) # type: ignore[union-attr, reportUnknownArgumentType] + if not non_builtin_tools: + return + self._custom_tools.extend(normalize_tools(non_builtin_tools)) # type: ignore[reportUnknownVariableType] async def __aenter__(self) -> RawClaudeAgent[OptionsT]: """Start the agent when entering 
async context.""" diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 878edc1312..6d79c69cbb 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -167,12 +167,12 @@ def _sanitize_agent_name(agent_name: str | None) -> str | None: class _RunContext(TypedDict): session: AgentSession | None session_context: SessionContext - input_messages: list[Message] - session_messages: list[Message] + input_messages: Sequence[Message] + session_messages: Sequence[Message] agent_name: str - chat_options: dict[str, Any] - filtered_kwargs: dict[str, Any] - finalize_kwargs: dict[str, Any] + chat_options: MutableMapping[str, Any] + filtered_kwargs: Mapping[str, Any] + finalize_kwargs: Mapping[str, Any] # region Agent Protocol @@ -863,11 +863,11 @@ async def _run_non_streaming() -> AgentResponse[Any]: kwargs=kwargs, ) response = cast( - ChatResponse[Any] | None, - await self.client.get_response( # type: ignore[call-overload] + ChatResponse[Any], + await self.client.get_response( # type: ignore messages=ctx["session_messages"], stream=False, - options=cast(Any, ctx["chat_options"]), + options=ctx["chat_options"], # type: ignore[reportArgumentType] **ctx["filtered_kwargs"], ), ) @@ -947,7 +947,7 @@ async def _get_stream() -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]] return self.client.get_response( # type: ignore[call-overload, no-any-return] messages=ctx["session_messages"], stream=True, - options=cast(Any, ctx["chat_options"]), + options=ctx["chat_options"], # type: ignore[reportArgumentType] **ctx["filtered_kwargs"], ) @@ -974,12 +974,9 @@ def _finalizer(updates: Sequence[AgentResponseUpdate]) -> AgentResponse[Any]: ) return self._finalize_response_updates(updates, response_format=rf) - stream_response = cast( - ResponseStream[ChatResponseUpdate, ChatResponse[Any]], - cast(Any, ResponseStream).from_awaitable(_get_stream()), - ) return ( - 
stream_response + ResponseStream # type: ignore[reportUnknownMemberType] + .from_awaitable(_get_stream()) .map( transform=partial( map_chat_to_agent_update, diff --git a/python/packages/core/agent_framework/_clients.py b/python/packages/core/agent_framework/_clients.py index bc842bbfbd..5dd049ecd3 100644 --- a/python/packages/core/agent_framework/_clients.py +++ b/python/packages/core/agent_framework/_clients.py @@ -785,7 +785,7 @@ async def get_embeddings( values: Sequence[EmbeddingInputT], *, options: EmbeddingOptionsT | None = None, - ) -> GeneratedEmbeddings[EmbeddingT]: + ) -> GeneratedEmbeddings[EmbeddingT, EmbeddingOptionsT]: """Generate embeddings for the given values. Args: diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 8970b03ae6..8559d3beea 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -27,7 +27,6 @@ Literal, TypeAlias, TypedDict, - TypeGuard, Union, cast, get_args, @@ -80,13 +79,6 @@ logger = logging.getLogger("agent_framework") -def _is_str_key_mapping(value: object) -> TypeGuard[Mapping[str, Any]]: - if not isinstance(value, Mapping): - return False - keys = cast(Mapping[object, object], value).keys() - return all(isinstance(key, str) for key in keys) - - DEFAULT_MAX_ITERATIONS: Final[int] = 40 DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST: Final[int] = 3 SHELL_TOOL_KIND_VALUE: Final[str] = "shell" @@ -709,7 +701,7 @@ def normalize_tools( if isinstance(tool_item, FunctionTool): normalized.append(tool_item) continue - if _is_str_key_mapping(tool_item): + if isinstance(tool_item, dict): normalized.append(tool_item) continue if isinstance(tool_item, MCPTool): @@ -745,8 +737,8 @@ def _tools_to_dict( # pyright: ignore[reportUnusedFunction] if isinstance(tool_item, SerializationMixin): results.append(tool_item.to_dict()) continue - if _is_str_key_mapping(tool_item): - results.append(dict(tool_item)) + if 
isinstance(tool_item, dict): + results.append(tool_item) continue logger.warning("Can't parse tool.") return results @@ -825,7 +817,7 @@ def _validate_arguments_against_schema( raise TypeError(f"Missing required argument(s) for '{tool_name}': {', '.join(sorted(missing_fields))}") properties_raw = schema.get("properties") - properties: Mapping[str, Any] = properties_raw if _is_str_key_mapping(properties_raw) else {} + properties: Mapping[str, Any] = properties_raw if isinstance(properties_raw, dict) else {} if schema.get("additionalProperties") is False: unexpected_fields = sorted(field for field in parsed_arguments if field not in properties) @@ -834,7 +826,7 @@ def _validate_arguments_against_schema( for field_name, field_value in parsed_arguments.items(): field_schema_raw = properties.get(field_name) - if not _is_str_key_mapping(field_schema_raw): + if not isinstance(field_schema_raw, dict): continue field_schema = field_schema_raw @@ -892,7 +884,7 @@ def _build_pydantic_model_from_json_schema( The dynamically created Pydantic model class. 
""" properties_raw = schema.get("properties") - properties = properties_raw if _is_str_key_mapping(properties_raw) else None + properties = properties_raw if isinstance(properties_raw, dict) else None required_raw = schema.get("required", []) required_obj: object = required_raw required: list[str] = ( @@ -901,7 +893,7 @@ def _build_pydantic_model_from_json_schema( else [] ) defs_raw = schema.get("$defs", {}) - definitions: Mapping[str, Any] = defs_raw if _is_str_key_mapping(defs_raw) else {} + definitions: Mapping[str, Any] = defs_raw if isinstance(defs_raw, dict) else {} # Check if 'properties' is missing or not a dictionary if not properties: @@ -942,7 +934,7 @@ def _resolve_type(prop_details: Mapping[str, Any], parent_name: str = "") -> typ # Handle oneOf + discriminator (polymorphic objects) if "oneOf" in prop_details and "discriminator" in prop_details: discriminator_raw = prop_details["discriminator"] - discriminator: Mapping[str, Any] = discriminator_raw if _is_str_key_mapping(discriminator_raw) else {} + discriminator: Mapping[str, Any] = discriminator_raw if isinstance(discriminator_raw, dict) else {} disc_field_raw = discriminator.get("propertyName") disc_field = disc_field_raw if isinstance(disc_field_raw, str) else None @@ -950,7 +942,7 @@ def _resolve_type(prop_details: Mapping[str, Any], parent_name: str = "") -> typ one_of_raw = prop_details["oneOf"] one_of: list[object] = cast(list[object], one_of_raw) if isinstance(one_of_raw, list) else [] for variant_raw in one_of: - if not _is_str_key_mapping(variant_raw): + if not isinstance(variant_raw, dict): continue variant = variant_raw if "$ref" in variant: @@ -999,7 +991,7 @@ def _resolve_type(prop_details: Mapping[str, Any], parent_name: str = "") -> typ case "array": # Handle typed arrays items_schema = prop_details.get("items") - if _is_str_key_mapping(items_schema): + if isinstance(items_schema, dict): # Recursively resolve the item type item_type = _resolve_type(items_schema, f"{parent_name}_item") 
# Return list[ItemType] instead of bare list @@ -1009,7 +1001,7 @@ def _resolve_type(prop_details: Mapping[str, Any], parent_name: str = "") -> typ case "object": # Handle nested objects by creating a nested Pydantic model nested_properties_raw = prop_details.get("properties") - nested_properties = nested_properties_raw if _is_str_key_mapping(nested_properties_raw) else None + nested_properties = nested_properties_raw if isinstance(nested_properties_raw, dict) else None nested_required_raw = prop_details.get("required", []) nested_required_obj: object = nested_required_raw nested_required: set[str] = ( @@ -1030,7 +1022,7 @@ def _resolve_type(prop_details: Mapping[str, Any], parent_name: str = "") -> typ if isinstance(nested_prop_details_raw, str) else nested_prop_details_raw ) - if not _is_str_key_mapping(nested_prop_details_candidate): + if not isinstance(nested_prop_details_candidate, dict): continue nested_prop_details = nested_prop_details_candidate @@ -1079,7 +1071,7 @@ def _resolve_type(prop_details: Mapping[str, Any], parent_name: str = "") -> typ for prop_name, prop_details_raw in properties.items(): prop_details_candidate = json.loads(prop_details_raw) if isinstance(prop_details_raw, str) else prop_details_raw - if not _is_str_key_mapping(prop_details_candidate): + if not isinstance(prop_details_candidate, dict): continue prop_details = prop_details_candidate diff --git a/python/packages/core/agent_framework/observability.py b/python/packages/core/agent_framework/observability.py index feb26f67f4..b1940e6c22 100644 --- a/python/packages/core/agent_framework/observability.py +++ b/python/packages/core/agent_framework/observability.py @@ -199,6 +199,7 @@ class OtelAttr(str, Enum): T_TYPE_INPUT = "input" T_TYPE_OUTPUT = "output" DURATION_UNIT = "s" + # Agent attributes AGENT_NAME = "gen_ai.agent.name" AGENT_DESCRIPTION = "gen_ai.agent.description" @@ -1323,13 +1324,10 @@ async def get_embeddings( values: Sequence[EmbeddingInputT], *, options: 
EmbeddingOptionsT | None = None, - ) -> GeneratedEmbeddings[EmbeddingT]: + ) -> GeneratedEmbeddings[EmbeddingT, EmbeddingOptionsT]: """Trace embedding generation with OpenTelemetry spans and metrics.""" global OBSERVABILITY_SETTINGS - super_get_embeddings = cast( - "Callable[..., Awaitable[GeneratedEmbeddings[EmbeddingT]]]", - super().get_embeddings, # type: ignore[misc] - ) + super_get_embeddings = super().get_embeddings # type: ignore[misc] if not OBSERVABILITY_SETTINGS.ENABLED: return await super_get_embeddings(values, options=options) # type: ignore[no-any-return] @@ -1349,16 +1347,17 @@ async def get_embeddings( with _get_span(attributes=attributes, span_name_attribute=OtelAttr.REQUEST_MODEL) as span: start_time_stamp = perf_counter() try: - result: GeneratedEmbeddings[EmbeddingT] = await super_get_embeddings(values, options=options) + result: GeneratedEmbeddings[EmbeddingT, EmbeddingOptionsT] = await super_get_embeddings( + values, options=options + ) except Exception as exception: capture_exception(span=span, exception=exception, timestamp=time_ns()) raise duration = perf_counter() - start_time_stamp response_attributes: dict[str, Any] = {**attributes} - usage = cast(Mapping[str, Any], result.usage) if result.usage else None - prompt_tokens = usage.get("prompt_tokens") if usage is not None else None - if prompt_tokens is not None: - response_attributes[OtelAttr.INPUT_TOKENS] = prompt_tokens + usage = result.usage or {} + if (input_tokens := usage.get("input_token_count")) is not None: + response_attributes[OtelAttr.INPUT_TOKENS] = input_tokens _capture_response( span=span, attributes=response_attributes, diff --git a/python/packages/core/agent_framework/openai/_embedding_client.py b/python/packages/core/agent_framework/openai/_embedding_client.py index 0efea66ae5..b940e47c7c 100644 --- a/python/packages/core/agent_framework/openai/_embedding_client.py +++ b/python/packages/core/agent_framework/openai/_embedding_client.py @@ -6,7 +6,7 @@ import struct import 
sys from collections.abc import Awaitable, Callable, Mapping, Sequence -from typing import Any, Generic, Literal, TypedDict, cast +from typing import Any, Generic, Literal, TypedDict from openai import AsyncOpenAI @@ -67,7 +67,7 @@ async def get_embeddings( values: Sequence[str], *, options: OpenAIEmbeddingOptionsT | None = None, - ) -> GeneratedEmbeddings[list[float]]: + ) -> GeneratedEmbeddings[list[float], OpenAIEmbeddingOptionsT]: """Call the OpenAI embeddings API. Args: @@ -81,9 +81,9 @@ async def get_embeddings( ValueError: If model_id is not provided or values is empty. """ if not values: - return cast(GeneratedEmbeddings[list[float]], GeneratedEmbeddings([], options=options)) + return GeneratedEmbeddings([], options=options) # type: ignore - opts: dict[str, Any] = dict(options) if options else {} + opts: dict[str, Any] = options or {} # type: ignore model = opts.get("model_id") or self.model_id if not model: raise ValueError("model_id is required") @@ -123,9 +123,7 @@ async def get_embeddings( "total_token_count": response.usage.total_tokens, } - return cast( - GeneratedEmbeddings[list[float]], GeneratedEmbeddings(embeddings, options=options, usage=usage_dict) - ) + return GeneratedEmbeddings(embeddings, options=options, usage=usage_dict) class OpenAIEmbeddingClient( diff --git a/python/packages/ollama/agent_framework_ollama/_embedding_client.py b/python/packages/ollama/agent_framework_ollama/_embedding_client.py index f5063159ef..0a922b3276 100644 --- a/python/packages/ollama/agent_framework_ollama/_embedding_client.py +++ b/python/packages/ollama/agent_framework_ollama/_embedding_client.py @@ -5,7 +5,7 @@ import logging import sys from collections.abc import Sequence -from typing import Any, ClassVar, Generic, TypedDict +from typing import Any, ClassVar, Generic, TypedDict, cast from agent_framework import ( BaseEmbeddingClient, @@ -120,8 +120,8 @@ async def get_embeddings( self, values: Sequence[str], *, - options: OllamaEmbeddingOptionsT | None = None, 
- ) -> GeneratedEmbeddings[list[float]]: + options: OllamaEmbeddingOptionsT | None = None, # type: ignore + ) -> GeneratedEmbeddings[list[float], OllamaEmbeddingOptionsT]: """Call the Ollama embed API. Args: @@ -134,19 +134,10 @@ async def get_embeddings( Raises: ValueError: If model_id is not provided or values is empty. """ - opts: dict[str, Any] = dict(options) if options else {} - if not values: - return GeneratedEmbeddings([], options=None) - - response_options: EmbeddingGenerationOptions | None = None - if options: - response_options = {} - if (model_id := opts.get("model_id")) is not None: - response_options["model_id"] = model_id - if (dimensions := opts.get("dimensions")) is not None: - response_options["dimensions"] = dimensions + return GeneratedEmbeddings([], options=options) + opts: dict[str, Any] = options or {} # type: ignore model = opts.get("model_id") or self.model_id if not model: raise ValueError("model_id is required") @@ -165,7 +156,7 @@ async def get_embeddings( Embedding( vector=list(emb), dimensions=len(emb), - model_id=response.get("model") or model, + model_id=response.get("model") or model, # type: ignore[assignment] ) for emb in response.get("embeddings", []) ] @@ -175,7 +166,7 @@ async def get_embeddings( if prompt_eval_count is not None: usage_dict = {"input_token_count": prompt_eval_count} - return GeneratedEmbeddings(embeddings, options=response_options, usage=usage_dict) + return GeneratedEmbeddings(embeddings, options=cast(OllamaEmbeddingOptionsT, opts), usage=usage_dict) class OllamaEmbeddingClient( From c73c89711a457f29ae1d03d1b0a942444d737fe2 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 4 Mar 2026 17:12:07 +0100 Subject: [PATCH 08/15] Fix current Python test regressions Address current failing unit tests in azure-ai, bedrock, and azure-cosmos while keeping Bedrock parsing logic inline (no new static helper methods). 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../agent_framework_azure_ai/_chat_client.py | 19 ++++--- .../agent_framework_azure_ai/_shared.py | 17 ++++--- .../_history_provider.py | 6 ++- .../agent_framework_bedrock/_chat_client.py | 50 +++++++++++++------ 4 files changed, 61 insertions(+), 31 deletions(-) diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py index 4e9f41e9d6..d4146839d5 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py @@ -1215,17 +1215,16 @@ def _prepare_tool_choice_mode( tool_choice = options.get("tool_choice") if tool_choice is None: return None - if tool_choice in {"none", "auto"}: + if isinstance(tool_choice, str) and tool_choice in {"none", "auto"}: return AgentsToolChoiceOptionMode(tool_choice) - if ( - isinstance(tool_choice, Mapping) - and tool_choice.get("mode") == "required" # type: ignore[attr-unknown] - and (req_fn := tool_choice.get("required_function_name")) # type: ignore[attr-unknown] - ): - return AgentsNamedToolChoice( - type=AgentsNamedToolChoiceType.FUNCTION, - function=FunctionName(name=req_fn), # type: ignore[call-arg] - ) + if isinstance(tool_choice, dict): + mode: object = tool_choice.get("mode") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] + req_fn: object = tool_choice.get("required_function_name") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] + if mode == "required" and isinstance(req_fn, str): + return AgentsNamedToolChoice( + type=AgentsNamedToolChoiceType.FUNCTION, + function=FunctionName(name=req_fn), + ) return None async def _prepare_tool_definitions_and_resources( diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_shared.py b/python/packages/azure-ai/agent_framework_azure_ai/_shared.py index b9f8ed801f..59289d2746 100644 --- 
a/python/packages/azure-ai/agent_framework_azure_ai/_shared.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_shared.py @@ -79,7 +79,7 @@ class AzureAISettings(TypedDict, total=False): model_deployment_name: str | None -def _extract_project_connection_id(additional_properties: dict[str, Any] | None) -> str | None: +def _extract_project_connection_id(additional_properties: Mapping[str, Any] | None) -> str | None: """Extract project_connection_id from tool additional_properties. Checks for both direct 'project_connection_id' key (programmatic usage) @@ -454,11 +454,16 @@ def _prepare_mcp_tool_dict_for_azure_ai(tool_dict: dict[str, Any]) -> MCPTool: mcp["server_description"] = description # Check for project_connection_id - if project_connection_id := ( - tool_dict.get("project_connection_id") or _extract_project_connection_id(tool_dict.get("additional_properties")) - if isinstance(tool_dict.get("additional_properties"), Mapping) - else None - ): + project_connection_id = tool_dict.get("project_connection_id") + if not isinstance(project_connection_id, str): + additional_properties = tool_dict.get("additional_properties") + project_connection_id = ( + _extract_project_connection_id(additional_properties) # pyright: ignore[reportUnknownArgumentType] + if isinstance(additional_properties, Mapping) + else None + ) + + if project_connection_id: mcp["project_connection_id"] = project_connection_id elif headers := tool_dict.get("headers"): mcp["headers"] = headers diff --git a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py index e26f3e061a..35c4243c37 100644 --- a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py +++ b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py @@ -144,8 +144,12 @@ async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Mess messages: list[Message] = [] 
async for item in items: + message_payload = item.get("message") + if not isinstance(message_payload, dict): + logger.warning("Skipping Cosmos DB item with non-mapping message payload.") + continue try: - msg = Message.from_dict(item.get("message")) # type: ignore + msg = Message.from_dict(message_payload) # pyright: ignore[reportUnknownArgumentType] except ValueError as e: logger.warning("Failed to deserialize message from Cosmos DB item: %s", e) continue diff --git a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py index 1845c65a99..5bc9735846 100644 --- a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py +++ b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py @@ -348,7 +348,9 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: contents = list(parsed_response.messages[0].contents if parsed_response.messages else []) if parsed_response.usage_details: contents.append(Content.from_usage(usage_details=parsed_response.usage_details)) # type: ignore[arg-type] - raw_finish_reason = self._get_str(parsed_response.finish_reason) + raw_finish_reason = ( + parsed_response.finish_reason if isinstance(parsed_response.finish_reason, str) else None + ) finish_reason = self._map_finish_reason(raw_finish_reason) yield ChatResponseUpdate( response_id=parsed_response.response_id, @@ -549,7 +551,7 @@ def _convert_tool_result_to_blocks(self, result: Any) -> list[dict[str, Any]]: return self._convert_prepared_tool_result_to_blocks(parsed_result) def _convert_prepared_tool_result_to_blocks(self, value: object) -> list[dict[str, Any]]: - if self._is_nonstring_sequence(value): + if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): blocks: list[dict[str, Any]] = [] for item in value: blocks.extend(self._convert_prepared_tool_result_to_blocks(item)) @@ -559,7 +561,7 @@ def _convert_prepared_tool_result_to_blocks(self, value: object) -> 
list[dict[st def _normalize_tool_result_value(self, value: object) -> dict[str, Any]: if isinstance(value, dict): return {"json": value} - if self._is_nonstring_sequence(value): + if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): return {"json": [item for item in value]} if isinstance(value, str): return {"text": value} @@ -640,32 +642,50 @@ def _parse_message_contents(self, content_blocks: Sequence[dict[str, Any]]) -> l if (json_value := block.get("json")) is not None: contents.append(Content.from_text(text=json.dumps(json_value), raw_representation=block)) continue - tool_use = self._get_dict(block.get("toolUse")) + tool_use_value = block.get("toolUse") + tool_use = ( + tool_use_value + if isinstance(tool_use_value, dict) + else dict(tool_use_value) + if isinstance(tool_use_value, Mapping) + else None + ) if tool_use is not None: - tool_name = self._get_str(tool_use.get("name")) + tool_name_value = tool_use.get("name") + tool_name = tool_name_value if isinstance(tool_name_value, str) else None if not tool_name: raise ChatClientInvalidResponseException( "Bedrock response missing required tool name in toolUse block." 
) + tool_use_id = tool_use.get("toolUseId") contents.append( Content.from_function_call( - call_id=self._get_str(tool_use.get("toolUseId")) or self._generate_tool_call_id(), + call_id=tool_use_id if isinstance(tool_use_id, str) else self._generate_tool_call_id(), name=tool_name, arguments=tool_use.get("input"), raw_representation=block, ) ) continue - tool_result = self._get_dict(block.get("toolResult")) + tool_result_value = block.get("toolResult") + tool_result = ( + tool_result_value + if isinstance(tool_result_value, dict) + else dict(tool_result_value) + if isinstance(tool_result_value, Mapping) + else None + ) if tool_result is not None: - status = (self._get_str(tool_result.get("status")) or "success").lower() + status_value = tool_result.get("status") + status = (status_value if isinstance(status_value, str) else "success").lower() exception = None if status not in {"success", "ok"}: exception = RuntimeError(f"Bedrock tool result status: {status}") result_value = self._convert_bedrock_tool_result_to_value(tool_result.get("content")) + tool_use_id = tool_result.get("toolUseId") contents.append( Content.from_function_result( - call_id=self._get_str(tool_result.get("toolUseId")) or self._generate_tool_call_id(), + call_id=tool_use_id if isinstance(tool_use_id, str) else self._generate_tool_call_id(), result=result_value, exception=str(exception) if exception else None, # type: ignore[arg-type] raw_representation=block, @@ -691,12 +711,13 @@ def service_url(self) -> str: def _convert_bedrock_tool_result_to_value(self, content: object) -> object: if not content: return None - if self._is_nonstring_sequence(content): + if isinstance(content, Sequence) and not isinstance(content, (str, bytes, bytearray)): values: list[object] = [] for item in content: - item_dict = self._get_dict(item) + item_dict = item if isinstance(item, dict) else dict(item) if isinstance(item, Mapping) else None if item_dict is not None: - if (text_value := 
self._get_str(item_dict.get("text"))) is not None: + text_value = item_dict.get("text") + if isinstance(text_value, str): values.append(text_value) continue if "json" in item_dict: @@ -704,9 +725,10 @@ def _convert_bedrock_tool_result_to_value(self, content: object) -> object: continue values.append(item) return values[0] if len(values) == 1 else values - content_dict = self._get_dict(content) + content_dict = content if isinstance(content, dict) else dict(content) if isinstance(content, Mapping) else None if content_dict is not None: - if (text_value := self._get_str(content_dict.get("text"))) is not None: + text_value = content_dict.get("text") + if isinstance(text_value, str): return text_value if "json" in content_dict: return content_dict["json"] From a3ebd910c0977a712de33bc9c0fda62f18aba431 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 4 Mar 2026 20:09:06 +0100 Subject: [PATCH 09/15] small fixes --- python/CODING_STANDARD.md | 10 ++++------ .../agent_framework_azure_ai/_chat_client.py | 12 ++++++------ python/pyproject.toml | 1 + 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/python/CODING_STANDARD.md b/python/CODING_STANDARD.md index 3eb84e7ba0..92671b6a19 100644 --- a/python/CODING_STANDARD.md +++ b/python/CODING_STANDARD.md @@ -79,20 +79,18 @@ def process_config(config: MutableMapping[str, Any]) -> None: ... ``` -### Pyright Ignore and Cast Policy +### Typing Ignore and Cast Policy Use typing as a helper first and suppressions as a last resort: - **Prefer explicit typing before suppression**: Start with clearer type annotations, helper types, overloads, - protocols, or refactoring dynamic code into typed helpers. + protocols, or refactoring dynamic code into typed helpers. Prioritize performance over completeness of typing, but make a good-faith effort to reduce uncertainty with typing before ignoring. Prefer to use a cast over a typeguard function since that does add overhead. 
+- **Avoid redundant casts**: Do not add `cast(...)` if the type already matches; casts should be reserved for + unavoidable narrowing where the runtime contract is known; we will use mypy's check on redundant casts to enforce this. - **Line-level pyright ignores only**: If suppression is still required, use a line-level rule-specific ignore (`# pyright: ignore[reportGeneralTypeIssues]`), never file-level or global suppression for this workflow. - **Private usage boundary**: Accessing private members across `agent_framework*` packages can be acceptable for this codebase, but private member usage for non-Agent Framework dependencies should remain flagged. -- **Avoid redundant casts**: Do not add `cast(...)` if the type already matches; casts should be reserved for - unavoidable narrowing where the runtime contract is known. -- **Uncertainty handoff**: If you are still unsure after best-effort typing, leave a targeted TODO note - (`TODO(): ...`) that explains what reviewer guidance is needed.
## Function Parameter Guidelines diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py index d4146839d5..a0c9d9046c 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py @@ -704,7 +704,7 @@ async def _create_agent_stream( args["tool_approvals"] = tool_approvals await self.agents_client.runs.submit_tool_outputs_stream(**args) # type: ignore[reportUnknownMemberType] # Pass the handler to the stream to continue processing - stream = handler # type: ignore + stream = handler final_thread_id = thread_run.thread_id else: # Handle thread creation or cancellation @@ -881,7 +881,7 @@ async def _process_stream( azure_search_tool_calls: list[dict[str, Any]] = [] response_stream = await stream.__aenter__() if isinstance(stream, AsyncAgentRunStream) else stream # type: ignore[no-untyped-call] try: - async for event_type, event_data, _ in response_stream: # type: ignore + async for event_type, event_data, _ in response_stream: match event_data: case MessageDeltaChunk(): # only one event_type: AgentStreamEvent.THREAD_MESSAGE_DELTA @@ -1212,15 +1212,15 @@ def _prepare_tool_choice_mode( self, options: Mapping[str, Any] ) -> AgentsToolChoiceOptionMode | AgentsNamedToolChoice | None: """Prepare the tool choice mode for Azure AI Agents API.""" - tool_choice = options.get("tool_choice") + tool_choice = cast(str | dict[str, str] | None, options.get("tool_choice")) if tool_choice is None: return None if isinstance(tool_choice, str) and tool_choice in {"none", "auto"}: return AgentsToolChoiceOptionMode(tool_choice) if isinstance(tool_choice, dict): - mode: object = tool_choice.get("mode") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] - req_fn: object = tool_choice.get("required_function_name") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] - if mode == 
"required" and isinstance(req_fn, str): + mode = tool_choice.get("mode") + req_fn = tool_choice.get("required_function_name") + if mode == "required" and req_fn is not None: return AgentsNamedToolChoice( type=AgentsNamedToolChoiceType.FUNCTION, function=FunctionName(name=req_fn), diff --git a/python/pyproject.toml b/python/pyproject.toml index 10597ed68b..c4219a5483 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -199,6 +199,7 @@ check_untyped_defs = true warn_return_any = true show_error_codes = true warn_unused_ignores = false +warn_redundant_casts = true disallow_incomplete_defs = true disallow_untyped_decorators = true From 28534e80f56071957766a38cbb73195fa5aa5105 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 4 Mar 2026 20:24:18 +0100 Subject: [PATCH 10/15] small fixes --- python/packages/core/agent_framework/_tools.py | 5 +---- .../core/agent_framework/observability.py | 18 +++++------------- python/pyproject.toml | 2 +- 3 files changed, 7 insertions(+), 18 deletions(-) diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 8559d3beea..7245d36a9e 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -2130,10 +2130,7 @@ def get_response( ResponseStream, ) - super_get_response_untyped = super().get_response # type: ignore[misc] - - def super_get_response(*args: Any, **kwargs: Any) -> Any: - return super_get_response_untyped(*args, **kwargs) # pyright: ignore[reportUnknownVariableType] + super_get_response = super().get_response # type: ignore[misc] # ChatMiddleware adds this kwarg function_middleware_pipeline = FunctionMiddlewarePipeline( diff --git a/python/packages/core/agent_framework/observability.py b/python/packages/core/agent_framework/observability.py index b1940e6c22..73c08ff90e 100644 --- a/python/packages/core/agent_framework/observability.py +++ b/python/packages/core/agent_framework/observability.py 
@@ -1155,10 +1155,7 @@ def get_response( ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: """Trace chat responses with OpenTelemetry spans and metrics.""" global OBSERVABILITY_SETTINGS - super_get_response = cast( - "Callable[..., Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]]", - super().get_response, # type: ignore[misc] - ) + super_get_response = super().get_response # type: ignore[misc] if not OBSERVABILITY_SETTINGS.ENABLED: return super_get_response(messages=messages, stream=stream, options=options, **kwargs) # type: ignore[no-any-return] @@ -1177,15 +1174,10 @@ def get_response( ) if stream: - from ._types import ResponseStream - - stream_result: object = super_get_response(messages=messages, stream=True, options=opts, **kwargs) - if isinstance(stream_result, ResponseStream): - result_stream: ResponseStream[ChatResponseUpdate, ChatResponse[Any]] = stream_result # pyright: ignore[reportUnknownVariableType] - elif isinstance(stream_result, Awaitable): - result_stream = ResponseStream.from_awaitable(stream_result) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] - else: - raise RuntimeError("Streaming telemetry requires a ResponseStream result.") + result_stream = cast( + ResponseStream[ChatResponseUpdate, ChatResponse[Any]], + super_get_response(messages=messages, stream=True, options=opts, **kwargs), + ) # Create span directly without trace.use_span() context attachment. 
# Streaming spans are closed asynchronously in cleanup hooks, which run diff --git a/python/pyproject.toml b/python/pyproject.toml index c4219a5483..9f4ca3c08c 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -187,6 +187,7 @@ exclude = ["**/tests/**", "**/.venv/**", "packages/devui/frontend/**"] typeCheckingMode = "strict" reportUnnecessaryIsInstance = false reportMissingTypeStubs = false +reportUnnecessaryCast = "error" [tool.mypy] plugins = ['pydantic.mypy'] @@ -199,7 +200,6 @@ check_untyped_defs = true warn_return_any = true show_error_codes = true warn_unused_ignores = false -warn_redundant_casts = true disallow_incomplete_defs = true disallow_untyped_decorators = true From 57cd0e3bad3ba80cf4b47b6d32c5d1e6dad4660f Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Thu, 5 Mar 2026 10:28:56 +0100 Subject: [PATCH 11/15] removed pydantic from json --- .../packages/core/agent_framework/_tools.py | 302 +----------- python/packages/core/tests/core/test_tools.py | 466 +----------------- .../agent_framework_declarative/_loader.py | 8 +- .../tests/test_declarative_loader.py | 8 +- 4 files changed, 19 insertions(+), 765 deletions(-) diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 7245d36a9e..4051cef21b 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -27,7 +27,6 @@ Literal, TypeAlias, TypedDict, - Union, cast, get_args, get_origin, @@ -687,16 +686,13 @@ def normalize_tools( if not tools: return [] - tool_items: list[object] - if isinstance(tools, Sequence) and not isinstance(tools, (str, bytes, bytearray, Mapping)): - sequence_tools = cast(Sequence[object], tools) - tool_items = list(sequence_tools) - else: - tool_items = [tools] + if isinstance(tools, (str, bytes, bytearray, Mapping)) or not isinstance(tools, Sequence): + tools = cast(list[ToolTypes | Callable[..., Any]], [tools]) + from ._mcp import MCPTool normalized: 
list[ToolTypes] = [] - for tool_item in tool_items: + for tool_item in tools: # check known types, these are also callable, so we need to do that first if isinstance(tool_item, FunctionTool): normalized.append(tool_item) @@ -810,33 +806,28 @@ def _validate_arguments_against_schema( """Run lightweight argument checks for schema-supplied tools.""" parsed_arguments = dict(arguments) - required_raw = schema.get("required", []) - required_fields = [field for field in required_raw if isinstance(field, str)] + required_fields = [field for field in schema.get("required", []) if isinstance(field, str)] missing_fields = [field for field in required_fields if field not in parsed_arguments] if missing_fields: raise TypeError(f"Missing required argument(s) for '{tool_name}': {', '.join(sorted(missing_fields))}") - properties_raw = schema.get("properties") - properties: Mapping[str, Any] = properties_raw if isinstance(properties_raw, dict) else {} - + properties: Mapping[str, Any] = schema.get("properties", {}) if schema.get("additionalProperties") is False: unexpected_fields = sorted(field for field in parsed_arguments if field not in properties) if unexpected_fields: raise TypeError(f"Unexpected argument(s) for '{tool_name}': {', '.join(unexpected_fields)}") for field_name, field_value in parsed_arguments.items(): - field_schema_raw = properties.get(field_name) - if not isinstance(field_schema_raw, dict): + if not isinstance(properties.get(field_name), dict): continue - field_schema = field_schema_raw - enum_values = field_schema.get("enum") + enum_values = properties.get(field_name, {}).get("enum") # type: ignore if isinstance(enum_values, list) and enum_values and field_value not in enum_values: raise TypeError( f"Invalid value for '{field_name}' in '{tool_name}': {field_value!r} is not in {enum_values!r}" ) - schema_type = field_schema.get("type") + schema_type = properties.get(field_name, {}).get("type") # type: ignore if isinstance(schema_type, str): if not 
_matches_json_schema_type(field_value, schema_type): raise TypeError( @@ -845,10 +836,8 @@ def _validate_arguments_against_schema( ) continue - schema_type_obj: object = schema_type - if isinstance(schema_type_obj, list): - schema_type_list = cast(list[object], schema_type_obj) - allowed_types: list[str] = [item for item in schema_type_list if isinstance(item, str)] + if isinstance(schema_type, list): + allowed_types: list[str] = [item for item in schema_type if isinstance(item, str)] if allowed_types and not any(_matches_json_schema_type(field_value, item) for item in allowed_types): raise TypeError( f"Invalid type for '{field_name}' in '{tool_name}': expected one of " @@ -858,275 +847,6 @@ def _validate_arguments_against_schema( return parsed_arguments -# Map JSON Schema types to Pydantic types -TYPE_MAPPING = { - "string": str, - "integer": int, - "number": float, - "boolean": bool, - "array": list, - "object": dict, - "null": type(None), -} - - -def _build_pydantic_model_from_json_schema( - model_name: str, - schema: Mapping[str, Any], -) -> type[BaseModel]: - """Creates a Pydantic model from JSON Schema with support for $refs, nested objects, and typed arrays. - - Args: - model_name: The name of the model to be created. - schema: The JSON Schema definition (should contain 'properties', 'required', '$defs', etc.). - - Returns: - The dynamically created Pydantic model class. 
- """ - properties_raw = schema.get("properties") - properties = properties_raw if isinstance(properties_raw, dict) else None - required_raw = schema.get("required", []) - required_obj: object = required_raw - required: list[str] = ( - [item for item in cast(list[object], required_obj) if isinstance(item, str)] - if isinstance(required_obj, list) - else [] - ) - defs_raw = schema.get("$defs", {}) - definitions: Mapping[str, Any] = defs_raw if isinstance(defs_raw, dict) else {} - - # Check if 'properties' is missing or not a dictionary - if not properties: - return create_model(f"{model_name}_input") - - def _resolve_literal_type(prop_details: Mapping[str, Any]) -> type | None: - """Check if property should be a Literal type (const or enum). - - Args: - prop_details: The JSON Schema property details - - Returns: - Literal type if const or enum is present, None otherwise - """ - # const → Literal["value"] - if "const" in prop_details: - return Literal[prop_details["const"]] # type: ignore - - # enum → Literal["a", "b", ...] - enum_raw: object = prop_details.get("enum") - if isinstance(enum_raw, list): - enum_values = cast(list[object], enum_raw) - if enum_values: - return Literal[tuple(enum_values)] # type: ignore - - return None - - def _resolve_type(prop_details: Mapping[str, Any], parent_name: str = "") -> type: - """Resolve JSON Schema type to Python type, handling $ref, nested objects, and typed arrays. 
- - Args: - prop_details: The JSON Schema property details - parent_name: Name to use for creating nested models (for uniqueness) - - Returns: - Python type annotation (could be int, str, list[str], or a nested Pydantic model) - """ - # Handle oneOf + discriminator (polymorphic objects) - if "oneOf" in prop_details and "discriminator" in prop_details: - discriminator_raw = prop_details["discriminator"] - discriminator: Mapping[str, Any] = discriminator_raw if isinstance(discriminator_raw, dict) else {} - disc_field_raw = discriminator.get("propertyName") - disc_field = disc_field_raw if isinstance(disc_field_raw, str) else None - - variants: list[type] = [] - one_of_raw = prop_details["oneOf"] - one_of: list[object] = cast(list[object], one_of_raw) if isinstance(one_of_raw, list) else [] - for variant_raw in one_of: - if not isinstance(variant_raw, dict): - continue - variant = variant_raw - if "$ref" in variant: - ref_raw = variant["$ref"] - if not isinstance(ref_raw, str): - continue - ref = ref_raw - if ref.startswith("#/$defs/"): - def_name = ref.split("/")[-1] - resolved = definitions.get(def_name) - if resolved: - variant_model = _resolve_type( - resolved, - parent_name=f"{parent_name}_{def_name}", - ) - variants.append(variant_model) - - if variants and disc_field: - return Annotated[ - Union[tuple(variants)], # type: ignore - Field(discriminator=disc_field), - ] - - # Handle $ref by resolving the reference - if "$ref" in prop_details: - ref = prop_details["$ref"] - # Extract the reference path (e.g., "#/$defs/CustomerIdParam" -> "CustomerIdParam") - if ref.startswith("#/$defs/"): - def_name = ref.split("/")[-1] - if def_name in definitions: - # Resolve the reference and use its type - resolved = definitions[def_name] - return _resolve_type(resolved, def_name) - # If we can't resolve the ref, default to dict for safety - return dict - - # Map JSON Schema types to Python types - json_type = prop_details.get("type", "string") - match json_type: - case 
"integer": - return int - case "number": - return float - case "boolean": - return bool - case "array": - # Handle typed arrays - items_schema = prop_details.get("items") - if isinstance(items_schema, dict): - # Recursively resolve the item type - item_type = _resolve_type(items_schema, f"{parent_name}_item") - # Return list[ItemType] instead of bare list - return list[item_type] # type: ignore - # If no items schema or invalid, return bare list - return list - case "object": - # Handle nested objects by creating a nested Pydantic model - nested_properties_raw = prop_details.get("properties") - nested_properties = nested_properties_raw if isinstance(nested_properties_raw, dict) else None - nested_required_raw = prop_details.get("required", []) - nested_required_obj: object = nested_required_raw - nested_required: set[str] = ( - {item for item in cast(list[object], nested_required_obj) if isinstance(item, str)} - if isinstance(nested_required_obj, list) - else set() - ) - - if nested_properties: - # Create the name for the nested model - nested_model_name = f"{parent_name}_nested" if parent_name else "NestedModel" - - # Recursively build field definitions for the nested model - nested_field_definitions: dict[str, Any] = {} - for nested_prop_name, nested_prop_details_raw in nested_properties.items(): - nested_prop_details_candidate = ( - json.loads(nested_prop_details_raw) - if isinstance(nested_prop_details_raw, str) - else nested_prop_details_raw - ) - if not isinstance(nested_prop_details_candidate, dict): - continue - nested_prop_details = nested_prop_details_candidate - - # Check for Literal types first (const/enum) - literal_type = _resolve_literal_type(nested_prop_details) - if literal_type is not None: - nested_python_type = literal_type - else: - nested_python_type = _resolve_type( - nested_prop_details, - f"{nested_model_name}_{nested_prop_name}", - ) - nested_description = nested_prop_details.get("description", "") - - # Build field kwargs for nested 
property - nested_field_kwargs: dict[str, Any] = {} - if nested_description: - nested_field_kwargs["description"] = nested_description - - # Create field definition - if nested_prop_name in nested_required: - nested_field_definitions[nested_prop_name] = ( - ( - nested_python_type, - Field(**nested_field_kwargs), - ) - if nested_field_kwargs - else (nested_python_type, ...) - ) - else: - nested_field_kwargs["default"] = nested_prop_details.get("default", None) - nested_field_definitions[nested_prop_name] = ( - nested_python_type, - Field(**nested_field_kwargs), - ) - - # Create and return the nested Pydantic model - return create_model(nested_model_name, **nested_field_definitions) # type: ignore - - # If no properties defined, return bare dict - return dict - case _: - return str # default - - field_definitions: dict[str, Any] = {} - - for prop_name, prop_details_raw in properties.items(): - prop_details_candidate = json.loads(prop_details_raw) if isinstance(prop_details_raw, str) else prop_details_raw - if not isinstance(prop_details_candidate, dict): - continue - prop_details = prop_details_candidate - - # Check for Literal types first (const/enum) - literal_type = _resolve_literal_type(prop_details) - if literal_type is not None: - python_type = literal_type - else: - python_type = _resolve_type(prop_details, f"{model_name}_{prop_name}") - description = prop_details.get("description", "") - - # Build field kwargs (description, etc.) - field_kwargs: dict[str, Any] = {} - if description: - field_kwargs["description"] = description - - # Create field definition for create_model - if prop_name in required: - if field_kwargs: - field_definitions[prop_name] = (python_type, Field(**field_kwargs)) - else: - field_definitions[prop_name] = (python_type, ...) 
- else: - default_value = prop_details.get("default", None) - field_kwargs["default"] = default_value - if field_kwargs and any(k != "default" for k in field_kwargs): - field_definitions[prop_name] = (python_type, Field(**field_kwargs)) - else: - field_definitions[prop_name] = (python_type, default_value) - - return create_model(f"{model_name}_input", **field_definitions) - - -def _create_model_from_json_schema( # pyright: ignore[reportUnusedFunction] - tool_name: str, schema_json: Mapping[str, Any] -) -> type[BaseModel]: - """Creates a Pydantic model from a given JSON Schema. - - Args: - tool_name: The name of the model to be created. - schema_json: The JSON Schema definition. - - Returns: - The dynamically created Pydantic model class. - """ - # Validate that 'properties' exists and is a dict - if "properties" not in schema_json or not isinstance(schema_json["properties"], dict): - raise ValueError( - f"JSON schema for tool '{tool_name}' must contain a 'properties' key of type dict. " - f"Got: {schema_json.get('properties', None)}" - ) - - return _build_pydantic_model_from_json_schema(tool_name, schema_json) - - @overload def tool( func: Callable[..., Any], diff --git a/python/packages/core/tests/core/test_tools.py b/python/packages/core/tests/core/test_tools.py index 8d74dc181d..f7674edc9b 100644 --- a/python/packages/core/tests/core/test_tools.py +++ b/python/packages/core/tests/core/test_tools.py @@ -5,7 +5,7 @@ import pytest from opentelemetry import trace from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel from agent_framework import ( Content, @@ -13,7 +13,6 @@ tool, ) from agent_framework._tools import ( - _build_pydantic_model_from_json_schema, _parse_annotation, _parse_inputs, ) @@ -1001,467 +1000,4 @@ def test_parse_annotation_with_annotated_and_literal(): assert get_args(literal_type) == ("A", "B", "C") -def 
test_build_pydantic_model_from_json_schema_array_of_objects_issue(): - """Test for Tools with complex input schema (array of objects). - - This test verifies that JSON schemas with array properties containing nested objects - are properly parsed, ensuring that the nested object schema is preserved - and not reduced to a bare dict. - - Example from issue: - ``` - const SalesOrderItemSchema = z.object({ - customerMaterialNumber: z.string().optional(), - quantity: z.number(), - unitOfMeasure: z.string() - }); - - const CreateSalesOrderInputSchema = z.object({ - contract: z.string(), - items: z.array(SalesOrderItemSchema) - }); - ``` - - The issue was that agents only saw: - ``` - {"contract": "str", "items": "list[dict]"} - ``` - - Instead of the proper nested schema with all fields. - """ - # Schema matching the issue description - schema = { - "type": "object", - "properties": { - "contract": {"type": "string", "description": "Reference contract number"}, - "items": { - "type": "array", - "description": "Sales order line items", - "items": { - "type": "object", - "properties": { - "customerMaterialNumber": { - "type": "string", - "description": "Customer's material number", - }, - "quantity": {"type": "number", "description": "Order quantity"}, - "unitOfMeasure": { - "type": "string", - "description": "Unit of measure (e.g., 'ST', 'KG', 'TO')", - }, - }, - "required": ["quantity", "unitOfMeasure"], - }, - }, - }, - "required": ["contract", "items"], - } - - model = _build_pydantic_model_from_json_schema("create_sales_order", schema) - - # Test valid data - valid_data = { - "contract": "CONTRACT-123", - "items": [ - { - "customerMaterialNumber": "MAT-001", - "quantity": 10, - "unitOfMeasure": "ST", - }, - {"quantity": 5.5, "unitOfMeasure": "KG"}, - ], - } - - instance = model(**valid_data) - - # Verify the data was parsed correctly - assert instance.contract == "CONTRACT-123" - assert len(instance.items) == 2 - - # Verify first item - assert 
instance.items[0].customerMaterialNumber == "MAT-001" - assert instance.items[0].quantity == 10 - assert instance.items[0].unitOfMeasure == "ST" - - # Verify second item (optional field not provided) - assert instance.items[1].quantity == 5.5 - assert instance.items[1].unitOfMeasure == "KG" - - # Verify that items are proper BaseModel instances, not bare dicts - assert isinstance(instance.items[0], BaseModel) - assert isinstance(instance.items[1], BaseModel) - - # Verify that the nested object has the expected fields - assert hasattr(instance.items[0], "customerMaterialNumber") - assert hasattr(instance.items[0], "quantity") - assert hasattr(instance.items[0], "unitOfMeasure") - - # CRITICAL: Validate using the same methods that actual chat clients use - # This is what would actually be sent to the LLM - - # Create a FunctionTool wrapper to access the client-facing APIs - def dummy_func(**kwargs): - return kwargs - - test_func = FunctionTool( - func=dummy_func, - name="create_sales_order", - description="Create a sales order", - input_model=model, - ) - - # Test 1: Anthropic client uses tool.parameters() directly - anthropic_schema = test_func.parameters() - - # Verify contract property - assert "contract" in anthropic_schema["properties"] - assert anthropic_schema["properties"]["contract"]["type"] == "string" - - # Verify items array property exists - assert "items" in anthropic_schema["properties"] - items_prop = anthropic_schema["properties"]["items"] - assert items_prop["type"] == "array" - - # THE KEY TEST for Anthropic: array items must have proper object schema - assert "items" in items_prop, "Array should have 'items' schema definition" - array_items_schema = items_prop["items"] - - # Resolve schema if using $ref - if "$ref" in array_items_schema: - ref_path = array_items_schema["$ref"] - assert ref_path.startswith("#/$defs/") or ref_path.startswith("#/definitions/") - ref_name = ref_path.split("/")[-1] - defs = anthropic_schema.get("$defs", 
anthropic_schema.get("definitions", {})) - assert ref_name in defs, f"Referenced schema '{ref_name}' should exist" - item_schema = defs[ref_name] - else: - item_schema = array_items_schema - - # Verify the nested object has all properties defined - assert "properties" in item_schema, "Array items should have properties (not bare dict)" - item_properties = item_schema["properties"] - - # All three fields must be present in schema sent to LLM - assert "customerMaterialNumber" in item_properties, "customerMaterialNumber missing from LLM schema" - assert "quantity" in item_properties, "quantity missing from LLM schema" - assert "unitOfMeasure" in item_properties, "unitOfMeasure missing from LLM schema" - - # Verify types are correct - assert item_properties["customerMaterialNumber"]["type"] == "string" - assert item_properties["quantity"]["type"] in ["number", "integer"] - assert item_properties["unitOfMeasure"]["type"] == "string" - - # Test 2: OpenAI client uses tool.to_json_schema_spec() - openai_spec = test_func.to_json_schema_spec() - - assert openai_spec["type"] == "function" - assert "function" in openai_spec - openai_schema = openai_spec["function"]["parameters"] - - # Verify the same structure is present in OpenAI format - assert "items" in openai_schema["properties"] - openai_items_prop = openai_schema["properties"]["items"] - assert openai_items_prop["type"] == "array" - assert "items" in openai_items_prop - - openai_array_items = openai_items_prop["items"] - if "$ref" in openai_array_items: - ref_path = openai_array_items["$ref"] - ref_name = ref_path.split("/")[-1] - defs = openai_schema.get("$defs", openai_schema.get("definitions", {})) - openai_item_schema = defs[ref_name] - else: - openai_item_schema = openai_array_items - - assert "properties" in openai_item_schema - openai_props = openai_item_schema["properties"] - assert "customerMaterialNumber" in openai_props - assert "quantity" in openai_props - assert "unitOfMeasure" in openai_props - - # Test 
validation - missing required quantity - with pytest.raises(ValidationError): - model( - contract="CONTRACT-456", - items=[ - { - "customerMaterialNumber": "MAT-002", - "unitOfMeasure": "TO", - # Missing required 'quantity' - } - ], - ) - - # Test validation - missing required unitOfMeasure - with pytest.raises(ValidationError): - model( - contract="CONTRACT-789", - items=[ - { - "quantity": 20 - # Missing required 'unitOfMeasure' - } - ], - ) - - -def test_one_of_discriminator_polymorphism(): - """Test that oneOf with discriminator creates proper polymorphic union types. - - Tests that oneOf + discriminator patterns are properly converted to Pydantic discriminated unions. - """ - schema = { - "$defs": { - "CreateProject": { - "description": "Action: Create an Azure DevOps project.", - "properties": { - "name": { - "const": "create_project", - "default": "create_project", - "type": "string", - }, - "params": {"$ref": "#/$defs/CreateProjectParams"}, - }, - "required": ["params"], - "type": "object", - }, - "CreateProjectParams": { - "description": "Parameters for the create_project action.", - "properties": { - "orgUrl": {"minLength": 1, "type": "string"}, - "projectName": {"minLength": 1, "type": "string"}, - "description": {"default": "", "type": "string"}, - "template": {"default": "Agile", "type": "string"}, - "sourceControl": { - "default": "Git", - "enum": ["Git", "Tfvc"], - "type": "string", - }, - "visibility": {"default": "private", "type": "string"}, - }, - "required": ["orgUrl", "projectName"], - "type": "object", - }, - "DeployRequest": { - "description": "Request to deploy Azure DevOps resources.", - "properties": { - "projectName": {"minLength": 1, "type": "string"}, - "organization": {"minLength": 1, "type": "string"}, - "actions": { - "items": { - "discriminator": { - "mapping": { - "create_project": "#/$defs/CreateProject", - "hello_world": "#/$defs/HelloWorld", - }, - "propertyName": "name", - }, - "oneOf": [ - {"$ref": "#/$defs/HelloWorld"}, - 
{"$ref": "#/$defs/CreateProject"}, - ], - }, - "type": "array", - }, - }, - "required": ["projectName", "organization"], - "type": "object", - }, - "HelloWorld": { - "description": "Action: Prints a greeting message.", - "properties": { - "name": { - "const": "hello_world", - "default": "hello_world", - "type": "string", - }, - "params": {"$ref": "#/$defs/HelloWorldParams"}, - }, - "required": ["params"], - "type": "object", - }, - "HelloWorldParams": { - "description": "Parameters for the hello_world action.", - "properties": { - "name": { - "description": "Name to greet", - "minLength": 1, - "type": "string", - } - }, - "required": ["name"], - "type": "object", - }, - }, - "properties": {"params": {"$ref": "#/$defs/DeployRequest"}}, - "required": ["params"], - "type": "object", - } - - # Build the model - model = _build_pydantic_model_from_json_schema("deploy_tool", schema) - - # Verify the model structure - assert model is not None - assert issubclass(model, BaseModel) - - # Test with HelloWorld action - hello_world_data = { - "params": { - "projectName": "MyProject", - "organization": "MyOrg", - "actions": [ - { - "name": "hello_world", - "params": {"name": "Alice"}, - } - ], - } - } - - instance = model(**hello_world_data) - assert instance.params.projectName == "MyProject" - assert instance.params.organization == "MyOrg" - assert len(instance.params.actions) == 1 - assert instance.params.actions[0].name == "hello_world" - assert instance.params.actions[0].params.name == "Alice" - - # Test with CreateProject action - create_project_data = { - "params": { - "projectName": "MyProject", - "organization": "MyOrg", - "actions": [ - { - "name": "create_project", - "params": { - "orgUrl": "https://dev.azure.com/myorg", - "projectName": "NewProject", - "sourceControl": "Git", - }, - } - ], - } - } - - instance2 = model(**create_project_data) - assert instance2.params.actions[0].name == "create_project" - assert instance2.params.actions[0].params.projectName == 
"NewProject" - assert instance2.params.actions[0].params.sourceControl == "Git" - - # Test with mixed actions - mixed_data = { - "params": { - "projectName": "MyProject", - "organization": "MyOrg", - "actions": [ - {"name": "hello_world", "params": {"name": "Bob"}}, - { - "name": "create_project", - "params": { - "orgUrl": "https://dev.azure.com/myorg", - "projectName": "AnotherProject", - }, - }, - ], - } - } - - instance3 = model(**mixed_data) - assert len(instance3.params.actions) == 2 - assert instance3.params.actions[0].name == "hello_world" - assert instance3.params.actions[1].name == "create_project" - - -def test_const_creates_literal(): - """Test that const in JSON Schema creates Literal type.""" - schema = { - "properties": { - "action": { - "const": "create", - "type": "string", - "description": "Action type", - }, - "value": {"type": "integer"}, - }, - "required": ["action", "value"], - } - - model = _build_pydantic_model_from_json_schema("test_const", schema) - - # Verify valid const value works - instance = model(action="create", value=42) - assert instance.action == "create" - assert instance.value == 42 - - # Verify incorrect const value fails - with pytest.raises(ValidationError): - model(action="delete", value=42) - - -def test_enum_creates_literal(): - """Test that enum in JSON Schema creates Literal type.""" - schema = { - "properties": { - "status": { - "enum": ["pending", "approved", "rejected"], - "type": "string", - "description": "Status", - }, - "priority": {"enum": [1, 2, 3], "type": "integer"}, - }, - "required": ["status"], - } - - model = _build_pydantic_model_from_json_schema("test_enum", schema) - - # Verify valid enum values work - instance = model(status="approved", priority=2) - assert instance.status == "approved" - assert instance.priority == 2 - - # Verify invalid enum value fails - with pytest.raises(ValidationError): - model(status="unknown") - - with pytest.raises(ValidationError): - model(status="pending", priority=5) - - 
-def test_nested_object_with_const_and_enum(): - """Test that const and enum work in nested objects.""" - schema = { - "properties": { - "config": { - "type": "object", - "properties": { - "type": { - "const": "production", - "default": "production", - "type": "string", - }, - "level": {"enum": ["low", "medium", "high"], "type": "string"}, - }, - "required": ["level"], - } - }, - "required": ["config"], - } - - model = _build_pydantic_model_from_json_schema("test_nested", schema) - - # Valid data - instance = model(config={"type": "production", "level": "high"}) - assert instance.config.type == "production" - assert instance.config.level == "high" - - # Invalid const in nested object - with pytest.raises(ValidationError): - model(config={"type": "development", "level": "low"}) - - # Invalid enum in nested object - with pytest.raises(ValidationError): - model(config={"type": "production", "level": "critical"}) - - # endregion diff --git a/python/packages/declarative/agent_framework_declarative/_loader.py b/python/packages/declarative/agent_framework_declarative/_loader.py index ebbc61280d..1ff403e87a 100644 --- a/python/packages/declarative/agent_framework_declarative/_loader.py +++ b/python/packages/declarative/agent_framework_declarative/_loader.py @@ -15,7 +15,6 @@ from agent_framework import ( FunctionTool as AFFunctionTool, ) -from agent_framework._tools import _create_model_from_json_schema # type: ignore from agent_framework.exceptions import AgentException from dotenv import load_dotenv @@ -445,7 +444,7 @@ def create_agent_from_dict(self, agent_def: dict[str, Any]) -> Agent: if tools := self._parse_tools(prompt_agent.tools): chat_options["tools"] = tools if output_schema := prompt_agent.outputSchema: - chat_options["response_format"] = _create_model_from_json_schema("agent", output_schema.to_json_schema()) + chat_options["response_format"] = output_schema.to_json_schema() # Step 3: Create the agent instance return Agent( client=client, @@ -563,7 +562,7 @@ 
async def create_agent_from_dict_async(self, agent_def: dict[str, Any]) -> Agent if tools := self._parse_tools(prompt_agent.tools): chat_options["tools"] = tools if output_schema := prompt_agent.outputSchema: - chat_options["response_format"] = _create_model_from_json_schema("agent", output_schema.to_json_schema()) + chat_options["response_format"] = output_schema.to_json_schema() return Agent( client=client, name=prompt_agent.name, @@ -611,8 +610,7 @@ async def _create_agent_with_provider(self, prompt_agent: PromptAgent, mapping: # Parse response format into default_options default_options: dict[str, Any] | None = None if prompt_agent.outputSchema: - response_format = _create_model_from_json_schema("agent", prompt_agent.outputSchema.to_json_schema()) - default_options = {"response_format": response_format} + default_options = {"response_format": prompt_agent.outputSchema.to_json_schema()} # Create the agent using the provider # The provider's create_agent returns a Agent directly diff --git a/python/packages/declarative/tests/test_declarative_loader.py b/python/packages/declarative/tests/test_declarative_loader.py index aee0d762d9..2ca87bfa65 100644 --- a/python/packages/declarative/tests/test_declarative_loader.py +++ b/python/packages/declarative/tests/test_declarative_loader.py @@ -560,8 +560,6 @@ def test_create_agent_from_dict_output_schema_in_default_options(self): """Test that outputSchema is passed as response_format in Agent.default_options.""" from unittest.mock import MagicMock - from pydantic import BaseModel - from agent_framework_declarative import AgentFactory agent_def = { @@ -580,8 +578,10 @@ def test_create_agent_from_dict_output_schema_in_default_options(self): agent = factory.create_agent_from_dict(agent_def) assert "response_format" in agent.default_options - assert isinstance(agent.default_options["response_format"], type) - assert issubclass(agent.default_options["response_format"], BaseModel) + response_format = 
agent.default_options["response_format"] + assert isinstance(response_format, dict) + assert response_format["type"] == "object" + assert response_format["properties"]["answer"]["type"] == "string" def test_create_agent_from_dict_chat_options_in_default_options(self): """Test that chat options (temperature, top_p) are in Agent.default_options.""" From 84918301c68dca584ac170078386f75a045f87bd Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Thu, 5 Mar 2026 15:46:49 +0100 Subject: [PATCH 12/15] final updates --- python/CODING_STANDARD.md | 10 +- .../claude/agent_framework_claude/_agent.py | 61 +--- .../packages/core/agent_framework/__init__.py | 6 +- .../packages/core/agent_framework/_types.py | 340 ++++++------------ .../core/agent_framework/observability.py | 10 +- .../packages/core/tests/core/test_skills.py | 8 +- python/packages/core/tests/core/test_types.py | 20 +- .../agent_framework_declarative/_loader.py | 2 +- .../_workflows/_executors_agents.py | 4 +- .../_workflows/_executors_basic.py | 32 +- .../_workflows/_executors_tools.py | 29 +- .../_workflows/_powerfx_functions.py | 8 +- .../_workflows/_state.py | 6 +- .../devui/agent_framework_devui/__init__.py | 7 +- .../agent_framework_devui/_conversations.py | 29 +- .../devui/agent_framework_devui/_discovery.py | 8 +- .../devui/agent_framework_devui/_mapper.py | 8 +- .../_openai/_executor.py | 10 +- .../devui/agent_framework_devui/_utils.py | 5 +- .../models/_discovery_models.py | 11 +- .../agent_framework_ollama/_chat_client.py | 4 +- .../_embedding_client.py | 4 +- .../_handoff.py | 8 +- 23 files changed, 227 insertions(+), 403 deletions(-) diff --git a/python/CODING_STANDARD.md b/python/CODING_STANDARD.md index 92671b6a19..ccb8e058e3 100644 --- a/python/CODING_STANDARD.md +++ b/python/CODING_STANDARD.md @@ -27,6 +27,12 @@ Public modules must include a module-level docstring, including `__init__.py` fi ## Type Annotations +We use typing as a helper, it is not a goal in and of itself, so be pragmatic about 
where and when to strictly type, versus when to use a targeted cast or ignore.
+In general, the public interfaces of our classes are important to get right, internally it is okay to have loosely typed code, as long as tests cover the code itself.
+This includes making a conscious choice when to program defensively, you can always do `getattr(item, 'attribute')` but that might end up causing you issues down the road
+because the type of `item` in this case, should have that attribute and if it doesn't it points to a larger issue, so if the type is expected to have that attribute, you should
+use `item.attribute` to ensure it fails at that point, rather than somewhere downstream where a value is expected but none was found.
+
 ### Future Annotations
 
 > **Note:** This convention is being adopted. See [#3578](https://github.com/microsoft/agent-framework/issues/3578) for progress.
@@ -87,8 +93,10 @@ Use typing as a helper first and suppressions as a last resort:
   protocols, or refactoring dynamic code into typed helpers. Prioritize performance over completeness of typing,
   but make a good-faith effort to reduce uncertainty with typing before ignoring. Prefer to use a cast over a typeguard function since that does add overhead.
 - **Avoid redundant casts**: Do not add `cast(...)` if the type already matches; casts should be reserved for unavoidable narrowing where the runtime contract is known, we will use mypy's check on redundant casts to enforce this.
+- **Avoid multiple assignments**: Avoid assigning multiple variables just to get typing to pass, that has performance impact while typing should not have that.
 - **Line-level pyright ignores only**: If suppression is still required, use a line-level rule-specific ignore
-  (`# pyright: ignore[reportGeneralTypeIssues]`), never file-level or global suppression for this workflow.
+ (`# pyright: ignore[reportGeneralTypeIssues]`), file-level is allowed if there is a compelling reason for it, that should be documented right beneath the ignore. + Never change the global suppression flags for mypy and pyright unless the dev team okays it. - **Private usage boundary**: Accessing private members across `agent_framework*` packages can be acceptable for this codebase, but private member usage for non-Agent Framework dependencies should remain flagged. diff --git a/python/packages/claude/agent_framework_claude/_agent.py b/python/packages/claude/agent_framework_claude/_agent.py index bc26a3d515..127e3647ee 100644 --- a/python/packages/claude/agent_framework_claude/_agent.py +++ b/python/packages/claude/agent_framework_claude/_agent.py @@ -225,11 +225,7 @@ def __init__( description: str | None = None, context_providers: Sequence[BaseContextProvider] | None = None, middleware: Sequence[AgentMiddlewareTypes] | None = None, - tools: ToolTypes - | Callable[..., Any] - | str - | Sequence[ToolTypes | Callable[..., Any] | str] - | None = None, + tools: ToolTypes | Callable[..., Any] | str | Sequence[ToolTypes | Callable[..., Any] | str] | None = None, default_options: OptionsT | MutableMapping[str, Any] | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, @@ -305,11 +301,7 @@ def __init__( def _normalize_tools( self, - tools: ToolTypes - | Callable[..., Any] - | str - | Sequence[ToolTypes | Callable[..., Any] | str] - | None, + tools: ToolTypes | Callable[..., Any] | str | Sequence[ToolTypes | Callable[..., Any] | str] | None, ) -> None: """Separate built-in tools (strings) from custom tools. @@ -374,9 +366,7 @@ async def _ensure_session(self, session_id: str | None = None) -> None: session_id: The session ID to use, or None for a new session. 
""" needs_new_client = ( - not self._started - or self._client is None - or (session_id and session_id != self._current_session_id) + not self._started or self._client is None or (session_id and session_id != self._current_session_id) ) if needs_new_client: @@ -399,9 +389,7 @@ async def _ensure_session(self, session_id: str | None = None) -> None: self._client = None raise AgentException(f"Failed to start Claude SDK client: {ex}") from ex - def _prepare_client_options( - self, resume_session_id: str | None = None - ) -> SDKOptions: + def _prepare_client_options(self, resume_session_id: str | None = None) -> SDKOptions: """Prepare SDK options for client initialization. Args: @@ -441,9 +429,7 @@ def _prepare_client_options( # Prepare custom tools (FunctionTool instances) custom_tools_server, custom_tool_names = ( - self._prepare_tools(self._custom_tools) - if self._custom_tools - else (None, []) + self._prepare_tools(self._custom_tools) if self._custom_tools else (None, []) ) # MCP servers - merge user-provided servers with custom tools server @@ -490,13 +476,9 @@ def _prepare_tools( if not sdk_tools: return None, [] - return create_sdk_mcp_server( - name=TOOLS_MCP_SERVER_NAME, tools=sdk_tools - ), tool_names + return create_sdk_mcp_server(name=TOOLS_MCP_SERVER_NAME, tools=sdk_tools), tool_names - def _function_tool_to_sdk_mcp_tool( - self, func_tool: FunctionTool - ) -> SdkMcpTool[Any]: + def _function_tool_to_sdk_mcp_tool(self, func_tool: FunctionTool) -> SdkMcpTool[Any]: """Convert a FunctionTool to an SDK MCP tool. 
Args: @@ -519,9 +501,7 @@ async def handler(args: dict[str, Any]) -> dict[str, Any]: return {"content": [{"type": "text", "text": f"Error: {e}"}]} # Get JSON schema from pydantic model - schema: dict[str, Any] = ( - func_tool.input_model.model_json_schema() if func_tool.input_model else {} - ) + schema: dict[str, Any] = func_tool.input_model.model_json_schema() if func_tool.input_model else {} input_schema: dict[str, Any] = { "type": "object", "properties": schema.get("properties", {}), @@ -582,9 +562,7 @@ def default_options(self) -> dict[str, Any]: opts["instructions"] = system_prompt return opts - def _finalize_response( - self, updates: Sequence[AgentResponseUpdate] - ) -> AgentResponse[Any]: + def _finalize_response(self, updates: Sequence[AgentResponseUpdate]) -> AgentResponse[Any]: """Build AgentResponse and propagate structured_output as value. Args: @@ -623,10 +601,7 @@ def run( stream: bool = False, session: AgentSession | None = None, **kwargs: Any, - ) -> ( - Awaitable[AgentResponse[Any]] - | ResponseStream[AgentResponseUpdate, AgentResponse[Any]] - ): + ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: """Run the agent with the given messages. 
Args: @@ -692,11 +667,7 @@ async def _get_stream( if text: yield AgentResponseUpdate( role="assistant", - contents=[ - Content.from_text( - text=text, raw_representation=message - ) - ], + contents=[Content.from_text(text=text, raw_representation=message)], raw_representation=message, ) elif delta_type == "thinking_delta": @@ -704,11 +675,7 @@ async def _get_stream( if thinking: yield AgentResponseUpdate( role="assistant", - contents=[ - Content.from_text_reasoning( - text=thinking, raw_representation=message - ) - ], + contents=[Content.from_text_reasoning(text=thinking, raw_representation=message)], raw_representation=message, ) elif isinstance(message, AssistantMessage): @@ -725,9 +692,7 @@ async def _get_stream( "server_error": "Claude API server error", "unknown": "Unknown error from Claude API", } - error_msg = error_messages.get( - message.error, f"Claude API error: {message.error}" - ) + error_msg = error_messages.get(message.error, f"Claude API error: {message.error}") # Extract any error details from content blocks if message.content: for block in message.content: diff --git a/python/packages/core/agent_framework/__init__.py b/python/packages/core/agent_framework/__init__.py index 1cbcc7a8cb..ef03652898 100644 --- a/python/packages/core/agent_framework/__init__.py +++ b/python/packages/core/agent_framework/__init__.py @@ -205,9 +205,6 @@ "AgentResponseUpdate", "AgentRunInputs", "AgentSession", - "Skill", - "SkillResource", - "SkillsProvider", "Annotation", "BaseAgent", "BaseChatClient", @@ -272,6 +269,9 @@ "SecretString", "SessionContext", "SingleEdgeGroup", + "Skill", + "SkillResource", + "SkillsProvider", "SubWorkflowRequestMessage", "SubWorkflowResponseMessage", "SupportsAgentRun", diff --git a/python/packages/core/agent_framework/_types.py b/python/packages/core/agent_framework/_types.py index 09015f4b43..7ae9dbaa3d 100644 --- a/python/packages/core/agent_framework/_types.py +++ b/python/packages/core/agent_framework/_types.py @@ -3,7 +3,6 @@ from 
__future__ import annotations import base64 -import inspect import json import logging import re @@ -22,9 +21,11 @@ ) from copy import deepcopy from datetime import datetime -from typing import TYPE_CHECKING, Any, ClassVar, Final, Generic, Literal, NewType, TypeGuard, cast, overload +from inspect import isawaitable +from typing import TYPE_CHECKING, Any, ClassVar, Final, Generic, Literal, NewType, cast, overload from pydantic import BaseModel +from typing_extensions import TypedDict from ._serialization import SerializationMixin from ._tools import ToolTypes @@ -35,10 +36,6 @@ from typing import TypeVar # pragma: no cover else: from typing_extensions import TypeVar # pragma: no cover -if sys.version_info >= (3, 11): - from typing import TypedDict # type: ignore # pragma: no cover -else: - from typing_extensions import TypedDict # type: ignore # pragma: no cover logger = logging.getLogger("agent_framework") @@ -221,20 +218,6 @@ def _get_data_bytes(content: Content) -> bytes | None: # pyright: ignore[report KNOWN_URI_SCHEMAS: Final[set[str]] = {"http", "https", "ftp", "ftps", "file", "s3", "gs", "azure", "blob"} -def _is_legacy_value_mapping(value: object) -> TypeGuard[Mapping[str, str]]: - if not isinstance(value, Mapping): - return False - mapping = cast(Mapping[object, object], value) - return isinstance(mapping.get("value"), str) - - -def _is_str_key_mapping(value: object) -> TypeGuard[Mapping[str, Any]]: - if not isinstance(value, Mapping): - return False - mapping = cast(Mapping[object, object], value) - return all(isinstance(key, str) for key in mapping) - - def _validate_uri(uri: str, media_type: str | None) -> dict[str, Any]: """Validate URI format and return validation result. 
@@ -286,18 +269,9 @@ def _serialize_value(value: Any, exclude_none: bool) -> Any: if isinstance(value, Content): return value.to_dict(exclude_none=exclude_none) if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): - serialized_items: list[Any] = [] - for item in cast(Iterable[Any], value): - item_any: Any = item - serialized_items.append(_serialize_value(item_any, exclude_none)) - return serialized_items + return [_serialize_value(item, exclude_none) for item in cast(Iterable[Any], value)] if isinstance(value, Mapping): - serialized_mapping: dict[Any, Any] = {} - for key, map_value in cast(Mapping[object, object], value).items(): - key_any: Any = key - map_value_any: Any = map_value - serialized_mapping[key_any] = _serialize_value(map_value_any, exclude_none) - return serialized_mapping + return {k: _serialize_value(v, exclude_none) for k, v in value.items()} # type: ignore[reportUnknownVariableType] if hasattr(value, "to_dict"): return value.to_dict() # type: ignore[call-arg] return value @@ -401,7 +375,7 @@ class Annotation(TypedDict, total=False): # endregion -class UsageDetails(TypedDict, total=False): +class UsageDetails(TypedDict, total=False, extra_items=int): # type: ignore[call-arg] """A dictionary representing usage details. This is a non-closed dictionary, so any specific provider fields can be added as needed. @@ -422,6 +396,9 @@ class UsageDetails(TypedDict, total=False): def add_usage_details(usage1: UsageDetails | None, usage2: UsageDetails | None) -> UsageDetails: """Add two UsageDetails dictionaries by summing all numeric values. + If any of the two usage details contains a key with a non-int value, it will be skipped, + even if the other contains a int-value on that key. + Args: usage1: First usage details dictionary. usage2: Second usage details dictionary. 
@@ -445,22 +422,15 @@ def add_usage_details(usage1: UsageDetails | None, usage2: UsageDetails | None) return usage1 result = UsageDetails() - # Combine all keys from both dictionaries all_keys = set(usage1.keys()) | set(usage2.keys()) - for key in all_keys: - val1 = usage1.get(key) - val2 = usage2.get(key) - - # Sum if both present, otherwise use the non-None value - if val1 is not None and val2 is not None: - result[key] = val1 + val2 # type: ignore[literal-required, operator] - elif val1 is not None: - result[key] = val1 # type: ignore[literal-required] - elif val2 is not None: - result[key] = val2 # type: ignore[literal-required] - + if not isinstance((val1 := usage1.get(key, 0)), (int | None)) or not isinstance( + (val2 := usage2.get(key, 0)), (int | None) + ): + logger.warning("Non `int` value found in usage details, skipping.") + continue + result[key] = (val1 or 0) + (val2 or 0) # type: ignore[literal-required] return result @@ -490,7 +460,7 @@ def __init__( error_code: str | None = None, error_details: str | None = None, # Usage content fields - usage_details: dict[str, Any] | UsageDetails | None = None, + usage_details: UsageDetails | None = None, # Function call/result fields call_id: str | None = None, name: str | None = None, @@ -1289,20 +1259,14 @@ def from_dict(cls: type[ContentT], data: Mapping[str, Any]) -> ContentT: return cls.from_data(remaining["data"], remaining["media_type"]) # Handle nested Content objects (e.g., function_call in function_approval_request) - function_call_raw = remaining.get("function_call") - if _is_str_key_mapping(function_call_raw): - remaining["function_call"] = cls.from_dict(function_call_raw) + if (function_call := remaining.get("function_call")) and isinstance(function_call, dict): + remaining["function_call"] = cls.from_dict(function_call) # type: ignore[reportUnknownArgumentType] # Handle list of Content objects (e.g., inputs in code_interpreter_tool_call) - input_items_obj: object = remaining.get("inputs") - if 
isinstance(input_items_obj, list): - input_items: list[Any] = list(cast(Iterable[Any], input_items_obj)) - remaining["inputs"] = [cls.from_dict(item) if _is_str_key_mapping(item) else item for item in input_items] - - output_items_obj: object = remaining.get("outputs") - if isinstance(output_items_obj, list): - output_items: list[Any] = list(cast(Iterable[Any], output_items_obj)) - remaining["outputs"] = [cls.from_dict(item) if _is_str_key_mapping(item) else item for item in output_items] + if (input_items := remaining.get("inputs")) and isinstance(input_items, list): + remaining["inputs"] = [cls.from_dict(item) if isinstance(item, dict) else item for item in input_items] # type: ignore[reportUnknownVariableType] + if (output_items := remaining.get("outputs")) and isinstance(output_items, list): + remaining["outputs"] = [cls.from_dict(item) if isinstance(item, dict) else item for item in output_items] # type: ignore[reportUnknownVariableType] return cls( type=content_type, @@ -1332,70 +1296,16 @@ def __add__(self, other: Content) -> Content: def _add_text_content(self, other: Content) -> Content: """Add two TextContent instances.""" - # Merge raw representations - raw_representation: Any - if self.raw_representation is None: - raw_representation = other.raw_representation - elif other.raw_representation is None: - raw_representation = self.raw_representation - else: - self_raw_repr: object = self.raw_representation - other_raw_repr: object = other.raw_representation - self_raw: list[object] = ( - cast(list[object], self_raw_repr) if isinstance(self_raw_repr, list) else [self_raw_repr] - ) - other_raw: list[object] = ( - cast(list[object], other_raw_repr) if isinstance(other_raw_repr, list) else [other_raw_repr] - ) - raw_representation = self_raw + other_raw - - # Merge annotations - annotations: Sequence[Annotation] | None - if self.annotations is None: - annotations = other.annotations - elif other.annotations is None: - annotations = self.annotations - else: - 
annotations = [*self.annotations, *other.annotations] - return Content( "text", text=self.text + other.text, # type: ignore[attr-defined, operator] - annotations=annotations, - additional_properties={ - **(other.additional_properties or {}), - **(self.additional_properties or {}), - }, - raw_representation=raw_representation, + annotations=_combine_annotations(self.annotations, other.annotations), + additional_properties=_combine_additional_props(self.additional_properties, other.additional_properties), + raw_representation=_combine_raw_representations(self.raw_representation, other.raw_representation), ) def _add_text_reasoning_content(self, other: Content) -> Content: """Add two TextReasoningContent instances.""" - # Merge raw representations - if self.raw_representation is None: - raw_representation = other.raw_representation - elif other.raw_representation is None: - raw_representation = self.raw_representation - else: - self_raw_repr: object = self.raw_representation - other_raw_repr: object = other.raw_representation - self_raw: list[object] = ( - cast(list[object], self_raw_repr) if isinstance(self_raw_repr, list) else [self_raw_repr] - ) - other_raw: list[object] = ( - cast(list[object], other_raw_repr) if isinstance(other_raw_repr, list) else [other_raw_repr] - ) - raw_representation = self_raw + other_raw - - # Merge annotations - annotations: Sequence[Annotation] | None - if self.annotations is None: - annotations = other.annotations - elif other.annotations is None: - annotations = self.annotations - else: - annotations = [*self.annotations, *other.annotations] - # Concatenate text, handling None values self_text = self.text or "" # type: ignore[attr-defined] other_text = other.text or "" # type: ignore[attr-defined] @@ -1408,12 +1318,9 @@ def _add_text_reasoning_content(self, other: Content) -> Content: "text_reasoning", text=combined_text, protected_data=protected_data, - annotations=annotations, - additional_properties={ - 
**(other.additional_properties or {}), - **(self.additional_properties or {}), - }, - raw_representation=raw_representation, + annotations=_combine_annotations(self.annotations, other.annotations), + additional_properties=_combine_additional_props(self.additional_properties, other.additional_properties), + raw_representation=_combine_raw_representations(self.raw_representation, other.raw_representation), ) def _add_function_call_content(self, other: Content) -> Content: @@ -1437,77 +1344,23 @@ def _add_function_call_content(self, other: Content) -> Content: else: raise TypeError("Incompatible argument types") - # Merge raw representations - if self.raw_representation is None: - raw_representation: Any = other.raw_representation - elif other.raw_representation is None: - raw_representation = self.raw_representation - else: - self_raw_repr: object = self.raw_representation - other_raw_repr: object = other.raw_representation - self_raw: list[object] = ( - cast(list[object], self_raw_repr) if isinstance(self_raw_repr, list) else [self_raw_repr] - ) - other_raw: list[object] = ( - cast(list[object], other_raw_repr) if isinstance(other_raw_repr, list) else [other_raw_repr] - ) - raw_representation = self_raw + other_raw - return Content( "function_call", call_id=self_call_id, name=getattr(self, "name", getattr(other, "name", None)), arguments=arguments, exception=getattr(self, "exception", None) or getattr(other, "exception", None), - additional_properties={ - **(self.additional_properties or {}), - **(other.additional_properties or {}), - }, - raw_representation=raw_representation, + additional_properties=_combine_additional_props(self.additional_properties, other.additional_properties), + raw_representation=_combine_raw_representations(self.raw_representation, other.raw_representation), ) def _add_usage_content(self, other: Content) -> Content: """Add two UsageContent instances by combining their usage details.""" - self_details = getattr(self, "usage_details", {}) - 
other_details = getattr(other, "usage_details", {}) - - # Combine token counts - combined_details: dict[str, Any] = {} - for key in set(list(self_details.keys()) + list(other_details.keys())): - self_val = self_details.get(key) - other_val = other_details.get(key) - if isinstance(self_val, int) and isinstance(other_val, int): - combined_details[key] = self_val + other_val - elif self_val is not None: - combined_details[key] = self_val - elif other_val is not None: - combined_details[key] = other_val - - # Merge raw representations - raw_representation: Any - if self.raw_representation is None: - raw_representation = other.raw_representation - elif other.raw_representation is None: - raw_representation = self.raw_representation - else: - self_raw_repr: object = self.raw_representation - other_raw_repr: object = other.raw_representation - self_raw: list[object] = ( - cast(list[object], self_raw_repr) if isinstance(self_raw_repr, list) else [self_raw_repr] - ) - other_raw: list[object] = ( - cast(list[object], other_raw_repr) if isinstance(other_raw_repr, list) else [other_raw_repr] - ) - raw_representation = self_raw + other_raw - return Content( "usage", - usage_details=combined_details, - additional_properties={ - **(self.additional_properties or {}), - **(other.additional_properties or {}), - }, - raw_representation=raw_representation, + usage_details=add_usage_details(self.usage_details, other.usage_details), + additional_properties=_combine_additional_props(self.additional_properties, other.additional_properties), + raw_representation=_combine_raw_representations(self.raw_representation, other.raw_representation), ) def has_top_level_media_type(self, top_level_media_type: Literal["application", "audio", "image", "text"]) -> bool: @@ -1584,6 +1437,42 @@ def parse_arguments(self) -> dict[str, Any | None] | None: return self.arguments # type: ignore[return-value] +def _combine_additional_props( + self_additional_properties: dict[str, Any], 
other_additional_properties: dict[str, Any] +) -> dict[str, Any]: + """Combine additional properties for addition operations.""" + return { + **other_additional_properties, + **self_additional_properties, + } + + +def _combine_raw_representations( + self_repr: Any, + other_repr: Any, +) -> Any: + """Combine raw representations for addition operations.""" + if self_repr is None: + return other_repr + if other_repr is None: + return self_repr + self_list = self_repr if isinstance(self_repr, list) else [self_repr] # type: ignore[reportUnknownVariableType] + other_list = other_repr if isinstance(other_repr, list) else [other_repr] # type: ignore[reportUnknownVariableType] + return self_list + other_list # type: ignore[reportUnknownVariableType] + + +def _combine_annotations( + self_annotations: Sequence[Annotation] | None, + other_annotations: Sequence[Annotation] | None, +) -> Sequence[Annotation] | None: + """Combine annotations for addition operations.""" + if self_annotations is None: + return other_annotations + if other_annotations is None: + return self_annotations + return [*self_annotations, *other_annotations] + + # endregion @@ -1719,10 +1608,6 @@ def __init__( Additional properties are used within Agent Framework, they are not sent to services. raw_representation: Optional raw representation of the chat message. 
""" - # Handle role conversion from legacy dict format - if _is_legacy_value_mapping(role): - role = role["value"] - # Handle contents conversion parsed_contents = [] if contents is None else _parse_content_list(contents) @@ -2080,9 +1965,6 @@ def __init__( self.conversation_id = conversation_id self.model_id = model_id self.created_at = created_at - # Handle legacy dict format for finish_reason - if _is_legacy_value_mapping(finish_reason): - finish_reason = cast(FinishReasonLiteral | FinishReason, finish_reason["value"]) self.finish_reason = finish_reason self.usage_details = usage_details self._value: ResponseModelT | None = value @@ -2674,10 +2556,6 @@ def __init__( processed_contents.append(c) self.contents = processed_contents - # Handle legacy dict format for role - if _is_legacy_value_mapping(role): - role = role["value"] - self.role: str | None = role self.author_name = author_name self.agent_id = agent_id @@ -2726,12 +2604,6 @@ def map_chat_to_agent_update(update: ChatResponseUpdate, agent_name: str | None) OuterFinalT = TypeVar("OuterFinalT") -async def _await_if_needed(value: _T | Awaitable[_T]) -> _T: - if inspect.isawaitable(value): - return await cast(Awaitable[_T], value) - return value - - class ResponseStream(AsyncIterable[UpdateT], Generic[UpdateT, FinalT]): """Async stream wrapper that supports iteration and deferred finalization.""" @@ -2777,7 +2649,7 @@ def __init__( self._inner_stream: ResponseStream[Any, Any] | None = None self._inner_stream_source: ResponseStream[Any, Any] | Awaitable[ResponseStream[Any, Any]] | None = None self._wrap_inner: bool = False - self._map_update: Callable[[Any], Any | Awaitable[Any]] | None = None + self._map_update: Callable[[Any], UpdateT | Awaitable[UpdateT]] | None = None def map( self, @@ -2909,22 +2781,16 @@ async def __anext__(self) -> UpdateT: await self._run_cleanup_hooks() raise if self._map_update is not None: - mapped = self._map_update(update) - if isinstance(mapped, Awaitable): - mapped_any: Any = 
cast(Any, await mapped) - update = cast(UpdateT, mapped_any) - else: - mapped_any = mapped - update = cast(UpdateT, mapped_any) + update = self._map_update(update) # type: ignore[assignment] + if isawaitable(update): + update = await update self._updates.append(update) for hook in self._transform_hooks: hooked = hook(update) - if isinstance(hooked, Awaitable): - hooked_any: Any = cast(Any, await hooked) - update = cast(UpdateT, hooked_any) - elif hooked is not None: - hooked_any = cast(Any, hooked) - update = cast(UpdateT, hooked_any) + if isawaitable(hooked): + hooked = await hooked + if hooked is not None: + update = hooked return update def __await__(self) -> Any: @@ -2971,14 +2837,18 @@ async def get_final_response(self) -> FinalT: inner_result: Any if inner_stream._finalizer is not None: inner_finalizer = inner_stream._finalizer - inner_result = await _await_if_needed(inner_finalizer(inner_stream._updates)) + inner_result = inner_finalizer(inner_stream._updates) + if isawaitable(inner_result): + inner_result = await inner_result else: inner_result = list(inner_stream._updates) # Run inner stream's result hooks inner_hooks = cast(list[Callable[[Any], Any | Awaitable[Any] | None]], inner_stream._result_hooks) for hook in inner_hooks: - hooked_result = await _await_if_needed(hook(inner_result)) + hooked_result = hook(inner_result) + if isawaitable(hooked_result): + hooked_result = await hooked_result if hooked_result is not None: inner_result = hooked_result inner_stream._final_result = inner_result @@ -2988,7 +2858,9 @@ async def get_final_response(self) -> FinalT: # If outer has no finalizer, use inner's result (preserves from_awaitable behavior) outer_result: Any if self._finalizer is not None: - outer_result = await _await_if_needed(self._finalizer(self._updates)) + outer_result = self._finalizer(self._updates) + if isawaitable(outer_result): + outer_result = await outer_result else: # No outer finalizer - use inner's finalized result outer_result = 
inner_result @@ -2996,12 +2868,14 @@ async def get_final_response(self) -> FinalT: # Apply outer's result_hooks outer_hooks = cast(list[Callable[[Any], Any | Awaitable[Any] | None]], self._result_hooks) for hook in outer_hooks: - outer_hooked_result = await _await_if_needed(hook(outer_result)) - if outer_hooked_result is not None: - outer_result = outer_hooked_result - self._final_result = cast(FinalT, outer_result) + outer_hook_result = hook(outer_result) + if isawaitable(outer_hook_result): + outer_hook_result = await outer_hook_result + if outer_hook_result is not None: + outer_result = outer_hook_result + self._final_result = outer_result self._finalized = True - return cast(FinalT, self._final_result) + return self._final_result # type: ignore[return-value] if not self._finalized: if not self._consumed: @@ -3011,18 +2885,22 @@ async def get_final_response(self) -> FinalT: # Use finalizer if configured, otherwise return collected updates result: Any if self._finalizer is not None: - result = await _await_if_needed(self._finalizer(self._updates)) + result = self._finalizer(self._updates) + if isawaitable(result): + result = await result else: result = list(self._updates) final_hooks = cast(list[Callable[[Any], Any | Awaitable[Any] | None]], self._result_hooks) for hook in final_hooks: - final_hook_result = await _await_if_needed(hook(result)) + final_hook_result = hook(result) + if isawaitable(final_hook_result): + final_hook_result = await final_hook_result if final_hook_result is not None: result = final_hook_result - self._final_result = cast(FinalT, result) + self._final_result = result self._finalized = True - return cast(FinalT, self._final_result) + return self._final_result # type: ignore[return-value] def with_transform_hook( self, @@ -3056,7 +2934,7 @@ async def _run_cleanup_hooks(self) -> None: self._cleanup_run = True for hook in self._cleanup_hooks: result = hook() - if isinstance(result, Awaitable): + if isawaitable(result): await result @property 
@@ -3367,11 +3245,9 @@ def merge_chat_options( # Copy base values (shallow copy for simple values, dict copy for dicts) for key, value in base.items(): if isinstance(value, dict): - dict_value = cast(Mapping[Any, Any], value) - result[key] = dict(dict_value) + result[key] = dict(value) # type: ignore[reportUnknownArgumentType] elif isinstance(value, list): - list_value: list[Any] = list(cast(Iterable[Any], value)) - result[key] = list(list_value) + result[key] = list(value) # type: ignore[reportUnknownArgumentType] else: result[key] = value @@ -3392,12 +3268,8 @@ def merge_chat_options( base_tools = result.get("tools") if base_tools and value: # Add tools that aren't already present - base_tool_values: list[Any] = ( - list(cast(Iterable[Any], base_tools)) if isinstance(base_tools, list) else [base_tools] - ) - merged_tools = list(base_tool_values) - tool_values: list[Any] = list(cast(Iterable[Any], value)) if isinstance(value, list) else [value] - for tool in tool_values: + merged_tools = list(base_tools) + for tool in value if isinstance(value, Iterable) else [value]: # type: ignore[reportUnknownVariableType] if tool not in merged_tools: merged_tools.append(tool) result["tools"] = merged_tools diff --git a/python/packages/core/agent_framework/observability.py b/python/packages/core/agent_framework/observability.py index 73c08ff90e..4fe84d5f78 100644 --- a/python/packages/core/agent_framework/observability.py +++ b/python/packages/core/agent_framework/observability.py @@ -1060,8 +1060,8 @@ def configure_otel_providers( OBSERVABILITY_SETTINGS.vs_code_extension_port = updated_settings.vs_code_extension_port OBSERVABILITY_SETTINGS.env_file_path = updated_settings.env_file_path OBSERVABILITY_SETTINGS.env_file_encoding = updated_settings.env_file_encoding - OBSERVABILITY_SETTINGS._resource = updated_settings._resource # pyright: ignore[reportPrivateUsage] - OBSERVABILITY_SETTINGS._executed_setup = False # pyright: ignore[reportPrivateUsage] + 
OBSERVABILITY_SETTINGS._resource = updated_settings._resource # type: ignore[reportPrivateUsage] + OBSERVABILITY_SETTINGS._executed_setup = False # type: ignore[reportPrivateUsage] else: # Update the observability settings with the provided values OBSERVABILITY_SETTINGS.enable_instrumentation = True @@ -1154,6 +1154,8 @@ def get_response( **kwargs: Any, ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: """Trace chat responses with OpenTelemetry spans and metrics.""" + from ._types import ChatResponse, ChatResponseUpdate, ResponseStream + global OBSERVABILITY_SETTINGS super_get_response = super().get_response # type: ignore[misc] @@ -1628,8 +1630,8 @@ def _get_instructions_from_options(options: Any) -> str | list[str] | None: instructions = cast(Mapping[str, Any], options).get("instructions") if isinstance(instructions, str): return instructions - if isinstance(instructions, list) and all(isinstance(item, str) for item in cast(list[object], instructions)): - return cast("list[str]", instructions) + if isinstance(instructions, list) and all(isinstance(item, str) for item in instructions): # type: ignore[reportUnknownVariableType] + return instructions # type: ignore[reportUnknownVariableType] return None return None diff --git a/python/packages/core/tests/core/test_skills.py b/python/packages/core/tests/core/test_skills.py index c572f4727b..e64691e655 100644 --- a/python/packages/core/tests/core/test_skills.py +++ b/python/packages/core/tests/core/test_skills.py @@ -10,7 +10,7 @@ import pytest -from agent_framework import Skill, SkillResource, SkillsProvider, SessionContext +from agent_framework import SessionContext, Skill, SkillResource, SkillsProvider from agent_framework._skills import ( DEFAULT_RESOURCE_EXTENSIONS, _create_instructions, @@ -1348,9 +1348,7 @@ class TestReadAndParseSkillFile: def test_valid_file(self, tmp_path: Path) -> None: skill_dir = tmp_path / "my-skill" skill_dir.mkdir() - (skill_dir / 
"SKILL.md").write_text( - "---\nname: my-skill\ndescription: A skill.\n---\nBody.", encoding="utf-8" - ) + (skill_dir / "SKILL.md").write_text("---\nname: my-skill\ndescription: A skill.\n---\nBody.", encoding="utf-8") result = _read_and_parse_skill_file(str(skill_dir)) assert result is not None name, desc, content = result @@ -1393,7 +1391,7 @@ def test_with_description(self) -> None: def test_xml_escapes_name(self) -> None: r = SkillResource(name='ref"special', content="data") elem = _create_resource_element(r) - assert '"' in elem + assert """ in elem def test_xml_escapes_description(self) -> None: r = SkillResource(name="ref", description='Uses & "quotes"', content="data") diff --git a/python/packages/core/tests/core/test_types.py b/python/packages/core/tests/core/test_types.py index bcf3a6891b..0d314c1aa5 100644 --- a/python/packages/core/tests/core/test_types.py +++ b/python/packages/core/tests/core/test_types.py @@ -550,7 +550,6 @@ def test_usage_details(): assert usage["input_token_count"] == 5 assert usage["output_token_count"] == 10 assert usage["total_token_count"] == 15 - assert usage.get("additional_counts", {}) == {} def test_usage_details_addition(): @@ -581,8 +580,8 @@ def test_usage_details_addition(): def test_usage_details_fail(): # TypedDict doesn't validate types at runtime, so this test no longer applies # Creating UsageDetails with wrong types won't raise ValueError - usage = UsageDetails(input_token_count=5, output_token_count=10, total_token_count=15, wrong_type="42.923") # type: ignore[typeddict-item] - assert usage["wrong_type"] == "42.923" # type: ignore[typeddict-item] + usage = UsageDetails(input_token_count=5, output_token_count=10, total_token_count=15, wrong_type="42.923") + assert usage["wrong_type"] == "42.923" def test_usage_details_additional_counts(): @@ -601,6 +600,15 @@ def test_usage_details_add_with_none_and_type_errors(): # TypedDict doesn't support + operator, use add_usage_details +def 
test_usage_details_add_skips_non_int(): + u1 = UsageDetails(input_token_count=10, other="test") + u2 = UsageDetails(input_token_count=10, another="test") + u3 = add_usage_details(u1, u2) + assert len(u3.keys()) == 1 + assert "input_token_count" in u3 + assert u3["input_token_count"] == 20 + + # region UserInputRequest and Response @@ -1705,7 +1713,7 @@ def test_chat_response_complex_serialization(): {"role": "user", "contents": [{"type": "text", "text": "Hello"}]}, {"role": "assistant", "contents": [{"type": "text", "text": "Hi there"}]}, ], - "finish_reason": {"value": "stop"}, + "finish_reason": "stop", "usage_details": { "type": "usage_details", "input_token_count": 5, @@ -1831,7 +1839,7 @@ def test_agent_run_response_update_all_content_types(): }, {"type": "text_reasoning", "text": "reasoning"}, ], - "role": {"value": "assistant"}, # Test role as dict + "role": "assistant", # Test role as dict } update = AgentResponseUpdate.from_dict(update_data) @@ -2394,7 +2402,7 @@ def test_content_add_usage_content_non_integer_values(): result = usage1 + usage2 # Non-integer "model" should take first non-None value - assert result.usage_details["model"] == "gpt-4" + assert "model" not in result.usage_details # Integer "count" should be summed assert result.usage_details["count"] == 30 diff --git a/python/packages/declarative/agent_framework_declarative/_loader.py b/python/packages/declarative/agent_framework_declarative/_loader.py index 1ff403e87a..625189a2f4 100644 --- a/python/packages/declarative/agent_framework_declarative/_loader.py +++ b/python/packages/declarative/agent_framework_declarative/_loader.py @@ -33,7 +33,7 @@ RemoteConnection, Tool, WebSearchTool, - _safe_mode_context, # pyright: ignore[reportPrivateUsage] + _safe_mode_context, # type: ignore[reportPrivateUsage] agent_schema_dispatch, ) diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_agents.py 
b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_agents.py index 20345bb750..02cc6dab11 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_agents.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_agents.py @@ -188,9 +188,9 @@ def _validate_conversation_history(messages: list[Message], agent_name: str) -> tool_result_ids: set[str] = set() for i, msg in enumerate(messages): - if not hasattr(msg, "contents"): + if not (contents := getattr(msg, "contents", None)): continue - for content in msg.contents: + for content in contents: if content.type == "function_call" and content.call_id: tool_call_ids.add(content.call_id) logger.debug( diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_basic.py b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_basic.py index 1e9b4a8bc9..677fd1aac8 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_basic.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_basic.py @@ -30,8 +30,7 @@ def _get_variable_path(action_def: dict[str, Any], key: str = "variable") -> str if isinstance(variable, str): return variable if isinstance(variable, Mapping): - variable_map = cast(Mapping[str, Any], variable) - path = variable_map.get("path") + path = variable.get("path") # type: ignore[reportUnknownVariableType] return path if isinstance(path, str) else None fallback_path = action_def.get("path") @@ -155,19 +154,19 @@ async def handle_action( """Handle the SetMultipleVariables action.""" state = await self._ensure_state_initialized(ctx, trigger) - assignments_obj = self._action_def.get("assignments", []) - assignments = cast(list[Any], assignments_obj) if isinstance(assignments_obj, list) else [] # type: ignore[redundant-cast] - for assignment_obj in assignments: - if not isinstance(assignment_obj, Mapping): + 
assignments = cast( + list[Mapping[str, Any]], + self._action_def.get("assignments") if isinstance(self._action_def.get("assignments"), list) else [], + ) + for assignment in assignments: + if not isinstance(assignment, Mapping): continue - assignment = cast(Mapping[str, Any], assignment_obj) variable = assignment.get("variable") path: str | None if isinstance(variable, str): path = variable elif isinstance(variable, Mapping): - variable_map = cast(Mapping[str, Any], variable) - path_value = variable_map.get("path") + path_value = variable.get("path") # type: ignore[reportUnknownMemberType] path = path_value if isinstance(path_value, str) else None else: fallback_path = assignment.get("path") @@ -262,8 +261,7 @@ async def handle_action( # Activity can be a string directly or a dict with a "text" field if isinstance(activity, Mapping): - activity_map = cast(Mapping[str, Any], activity) - text: Any = activity_map.get("text", "") + text: Any = activity.get("text", "") # type: ignore[reportUnknownMemberType] else: text = activity @@ -276,7 +274,7 @@ async def handle_action( # Yield the text as workflow output if text: - await ctx.yield_output(str(text)) + await ctx.yield_output(str(text)) # type: ignore[reportUnknownArgumentType] await ctx.send_message(ActionComplete()) @@ -357,7 +355,7 @@ async def handle_action( if current_table_value is None: current_table = [] elif isinstance(current_table_value, list): - current_table = list(cast(list[Any], current_table_value)) # type: ignore[redundant-cast] + current_table = list(current_table_value) # type: ignore[reportUnknownArgumentType] else: current_table = [current_table_value] @@ -437,7 +435,7 @@ async def handle_action( if current_table_value is None: current_table = [] elif isinstance(current_table_value, list): - current_table = list(cast(list[Any], current_table_value)) # type: ignore[redundant-cast] + current_table = list(current_table_value) # type: ignore[reportUnknownArgumentType] else: current_table = 
[current_table_value] @@ -476,8 +474,7 @@ async def handle_action( elif operation == "addorupdate": evaluated_item = state.eval_if_expression(item) if key_field and isinstance(evaluated_item, dict): - evaluated_item_dict = cast(dict[str, Any], evaluated_item) - key_value = evaluated_item_dict.get(key_field) + key_value = evaluated_item.get(key_field) # type: ignore[reportUnknownArgumentType] # Find existing item with same key found_idx = -1 for i, r in enumerate(current_table): @@ -502,8 +499,7 @@ async def handle_action( if 0 <= idx < len(current_table): current_table[idx] = evaluated_item elif key_field and isinstance(evaluated_item, dict): - evaluated_item_dict = cast(dict[str, Any], evaluated_item) - key_value = evaluated_item_dict.get(key_field) + key_value = evaluated_item.get(key_field) # type: ignore[reportUnknownArgumentType] for i, r in enumerate(current_table): if isinstance(r, dict) and cast(dict[str, Any], r).get(key_field) == key_value: current_table[i] = evaluated_item diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_tools.py b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_tools.py index 934717e9ec..85aa4f6a5a 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_tools.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_tools.py @@ -19,6 +19,7 @@ from dataclasses import dataclass, field from inspect import isawaitable from typing import Any, cast +from collections.abc import Callable from agent_framework import ( Content, @@ -112,10 +113,6 @@ class ToolApprovalState: # ============================================================================ -def _empty_messages() -> list[Message]: - return [] - - @dataclass class ToolInvocationResult: """Result from a tool invocation. 
@@ -132,7 +129,7 @@ class ToolInvocationResult: success: bool result: Any = None error: str | None = None - messages: list[Message] = field(default_factory=_empty_messages) + messages: list[Message] = field(default_factory=cast(Callable[..., list[Message]], list)) rejected: bool = False rejection_reason: str | None = None @@ -272,20 +269,19 @@ def _get_output_config(self) -> tuple[str | None, str | None, bool]: Returns: Tuple of (messages_var, result_var, auto_send) """ - output_config_obj = self._action_def.get("output", {}) + output_config: dict[str, str | bool] = self._action_def.get("output", {}) - if not isinstance(output_config_obj, Mapping): + if not isinstance(output_config, Mapping): return None, None, True - output_config = cast(Mapping[str, Any], output_config_obj) - messages_var_obj = output_config.get("messages") - result_var_obj = output_config.get("result") + messages_var = output_config.get("messages") + result_var = output_config.get("result") auto_send = bool(output_config.get("autoSend", True)) - - messages_var = messages_var_obj if isinstance(messages_var_obj, str) else None - result_var = result_var_obj if isinstance(result_var_obj, str) else None - - return (messages_var, result_var, auto_send) + return ( + str(messages_var) if messages_var else None, + str(result_var) if result_var else None, + auto_send, + ) def _store_result( self, @@ -499,8 +495,7 @@ async def handle_action( type(arguments_def).__name__, ) elif isinstance(arguments_def, dict): - arguments_map = cast(dict[str, Any], arguments_def) - for key, value in arguments_map.items(): + for key, value in arguments_def.items(): # type: ignore[reportUnknownVariableType] arguments[key] = state.eval_if_expression(value) # Check if approval is required diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_powerfx_functions.py b/python/packages/declarative/agent_framework_declarative/_workflows/_powerfx_functions.py index 04374d06a9..f61120a469 100644 --- 
a/python/packages/declarative/agent_framework_declarative/_workflows/_powerfx_functions.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_powerfx_functions.py @@ -195,12 +195,8 @@ def is_blank(value: Any) -> bool: return True if isinstance(value, str) and not value.strip(): return True - if isinstance(value, list): - value_list = cast(list[Any], value) # type: ignore[redundant-cast] - return len(value_list) == 0 - if isinstance(value, dict): - value_dict = cast(dict[Any, Any], value) # type: ignore[redundant-cast] - return len(value_dict) == 0 + if isinstance(value, (list, dict)): + return len(value) == 0 # type: ignore[reportUnknownArgumentType] return False diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_state.py b/python/packages/declarative/agent_framework_declarative/_workflows/_state.py index fb0abc1086..76530f50dd 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_state.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_state.py @@ -615,11 +615,9 @@ def eval_if_expression(self, value: Any) -> Any: if isinstance(value, str): return self.eval(value) if isinstance(value, dict): - value_dict = cast(dict[Any, Any], value) # type: ignore[redundant-cast] - return {str(k): self.eval_if_expression(v) for k, v in value_dict.items()} + return {str(k): self.eval_if_expression(v) for k, v in value.items()} # type: ignore[reportUnknownVariableType] if isinstance(value, list): - value_list = cast(list[Any], value) # type: ignore[redundant-cast] - return [self.eval_if_expression(item) for item in value_list] + return [self.eval_if_expression(item) for item in value] # type: ignore[reportUnknownVariableType] return value def reset_local(self) -> None: diff --git a/python/packages/devui/agent_framework_devui/__init__.py b/python/packages/devui/agent_framework_devui/__init__.py index a6dea87b90..6af274743a 100644 --- 
a/python/packages/devui/agent_framework_devui/__init__.py +++ b/python/packages/devui/agent_framework_devui/__init__.py @@ -73,7 +73,7 @@ def register_cleanup(entity: Any, *hooks: Callable[[], Any]) -> None: ) -def get_registered_cleanup_hooks(entity: Any) -> list[Callable[[], Any]]: +def _get_registered_cleanup_hooks(entity: Any) -> list[Callable[[], Any]]: # type: ignore[reportUnusedFunction] """Get cleanup hooks registered for an entity (internal use). Args: @@ -86,10 +86,6 @@ def get_registered_cleanup_hooks(entity: Any) -> list[Callable[[], Any]]: return _cleanup_registry.get(entity_id, []) -# Backward-compatible private alias -_get_registered_cleanup_hooks = get_registered_cleanup_hooks - - def serve( entities: list[Any] | None = None, entities_dir: str | None = None, @@ -265,7 +261,6 @@ def main() -> None: "OpenAIError", "OpenAIResponse", "ResponseStreamEvent", - "get_registered_cleanup_hooks", "main", "register_cleanup", "serve", diff --git a/python/packages/devui/agent_framework_devui/_conversations.py b/python/packages/devui/agent_framework_devui/_conversations.py index ba2fa21586..8130835002 100644 --- a/python/packages/devui/agent_framework_devui/_conversations.py +++ b/python/packages/devui/agent_framework_devui/_conversations.py @@ -11,12 +11,14 @@ import time import uuid from abc import ABC, abstractmethod +from collections.abc import MutableSequence from typing import Any, Literal, cast from agent_framework import AgentSession, Message from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage, WorkflowCheckpoint from openai.types.conversations import Conversation, ConversationDeletedResource from openai.types.conversations.conversation_item import ConversationItem +from openai.types.conversations.message import Content as OpenAIContent from openai.types.conversations.message import Message as OpenAIMessage from openai.types.conversations.text_content import TextContent from openai.types.responses import ( @@ -304,9 +306,11 @@ 
async def add_items(self, conversation_id: str, items: list[dict[str, Any]]) -> for item in items: # Simple conversion - assume text content for now role = item.get("role", "user") - content_obj = item.get("content", []) - content = cast(list[dict[str, Any]], content_obj) if isinstance(content_obj, list) else [] - first_content = content[0] if content and isinstance(content[0], dict) else {} + content = item.get("content", []) + first_content = cast( + dict[str, Any], + content[0] if content and isinstance(content, list) and isinstance(content[0], dict) else {}, + ) text_obj = first_content.get("text", "") text = text_obj if isinstance(text_obj, str) else str(text_obj) @@ -321,26 +325,19 @@ async def add_items(self, conversation_id: str, items: list[dict[str, Any]]) -> for msg in chat_messages: item_id = f"item_{uuid.uuid4().hex}" - # Extract role - handle both string and enum - msg_role_obj: object = getattr(msg, "role", "user") - role_str = str(getattr(msg_role_obj, "value", msg_role_obj)) - role = cast(MessageRole, role_str) # Safe: Agent Framework roles match OpenAI roles - # Convert Message contents to OpenAI TextContent format - message_content: list[TextContent] = [] - for content_item in cast(list[Any], msg.contents): - if getattr(content_item, "type", None) == "text": + message_content: MutableSequence[OpenAIContent] = [] + for content_item in msg.contents: + if content_item.type == "text": # Extract text from TextContent object - text_value_obj = getattr(content_item, "text", "") - text_value = text_value_obj if isinstance(text_value_obj, str) else str(text_value_obj) - message_content.append(TextContent(type="text", text=text_value)) + message_content.append(TextContent(type="text", text=content_item.text or "")) # Create Message object (concrete type from ConversationItem union) message = OpenAIMessage( id=item_id, type="message", # Required discriminator for union - role=role, - content=cast(Any, message_content), + role=cast(MessageRole, msg.role), # 
Safe: Agent Framework roles match OpenAI roles, + content=message_content, status="completed", # Required field ) conv_items.append(message) diff --git a/python/packages/devui/agent_framework_devui/_discovery.py b/python/packages/devui/agent_framework_devui/_discovery.py index 37ab53044d..372e870c15 100644 --- a/python/packages/devui/agent_framework_devui/_discovery.py +++ b/python/packages/devui/agent_framework_devui/_discovery.py @@ -141,9 +141,9 @@ async def load_entity(self, entity_id: str, checkpoint_manager: Any = None) -> A self._loaded_objects[entity_id] = entity_obj # Check module-level registry for cleanup hooks - from . import get_registered_cleanup_hooks + from . import _get_registered_cleanup_hooks # type: ignore[reportPrivateUsage] - registered_hooks = get_registered_cleanup_hooks(entity_obj) + registered_hooks = _get_registered_cleanup_hooks(entity_obj) if registered_hooks: if entity_id not in self._cleanup_hooks: self._cleanup_hooks[entity_id] = [] @@ -299,9 +299,9 @@ def register_entity(self, entity_id: str, entity_info: EntityInfo, entity_object self._loaded_objects[entity_id] = entity_object # Check module-level registry for cleanup hooks - from . import get_registered_cleanup_hooks + from . 
import _get_registered_cleanup_hooks # type: ignore[reportPrivateUsage] - registered_hooks = get_registered_cleanup_hooks(entity_object) + registered_hooks = _get_registered_cleanup_hooks(entity_object) if registered_hooks: if entity_id not in self._cleanup_hooks: self._cleanup_hooks[entity_id] = [] diff --git a/python/packages/devui/agent_framework_devui/_mapper.py b/python/packages/devui/agent_framework_devui/_mapper.py index 59728dd03c..9e79b308c5 100644 --- a/python/packages/devui/agent_framework_devui/_mapper.py +++ b/python/packages/devui/agent_framework_devui/_mapper.py @@ -62,14 +62,10 @@ def _to_str_dict(value: Any) -> dict[str, Any] | None: - """Convert arbitrary dict-like payload to a string-keyed dictionary.""" + """Cast arbitrary dict-like payload to a string-keyed dictionary.""" if not isinstance(value, dict): return None - source: dict[str, Any] = cast(dict[str, Any], value) - result: dict[str, Any] = {} - for key_obj, val_obj in source.items(): - result[str(key_obj)] = val_obj - return result + return cast(dict[str, Any], value) def _stringify_name(value: Any) -> str: diff --git a/python/packages/devui/agent_framework_devui/_openai/_executor.py b/python/packages/devui/agent_framework_devui/_openai/_executor.py index c62e7acc98..ac0e641e60 100644 --- a/python/packages/devui/agent_framework_devui/_openai/_executor.py +++ b/python/packages/devui/agent_framework_devui/_openai/_executor.py @@ -11,7 +11,7 @@ import logging import os from collections.abc import AsyncGenerator -from typing import Any, cast +from typing import Any from openai import APIStatusError, AsyncOpenAI, AsyncStream, AuthenticationError, PermissionDeniedError, RateLimitError from openai.types.responses import Response, ResponseStreamEvent @@ -22,17 +22,15 @@ logger = logging.getLogger(__name__) -def _extract_error_details(body: object) -> tuple[str | None, str | None, str | None]: +def _extract_error_details(body: Any) -> tuple[str | None, str | None, str | None]: """Extract typed 
OpenAI error fields from error body payload.""" if not isinstance(body, dict): return None, None, None - body_dict = cast(dict[str, object], body) - error_obj = body_dict.get("error") - if not isinstance(error_obj, dict): + error_dict: dict[str, Any] = body.get("error") # type: ignore[assignment, reportUnknownVariableType] + if not isinstance(error_dict, dict): return None, None, None - error_dict = cast(dict[str, object], error_obj) message = error_dict.get("message") error_type = error_dict.get("type") code = error_dict.get("code") diff --git a/python/packages/devui/agent_framework_devui/_utils.py b/python/packages/devui/agent_framework_devui/_utils.py index 3eadc35926..889a690c87 100644 --- a/python/packages/devui/agent_framework_devui/_utils.py +++ b/python/packages/devui/agent_framework_devui/_utils.py @@ -15,11 +15,10 @@ def _string_key_dict(value: object) -> dict[str, Any] | None: + """Cast value to a dict.""" if not isinstance(value, dict): return None - - source: dict[str, Any] = cast(dict[str, Any], value) - return {str(k): v for k, v in source.items()} + return cast(dict[str, Any], value) # ============================================================================ diff --git a/python/packages/devui/agent_framework_devui/models/_discovery_models.py b/python/packages/devui/agent_framework_devui/models/_discovery_models.py index e3fcccb5f9..47e6d1bdcc 100644 --- a/python/packages/devui/agent_framework_devui/models/_discovery_models.py +++ b/python/packages/devui/agent_framework_devui/models/_discovery_models.py @@ -2,16 +2,15 @@ """Discovery API models for entity information.""" +from __future__ import annotations + import re -from typing import Any +from typing import Any, cast +from collections.abc import Callable from pydantic import BaseModel, Field, field_validator -def _default_entities() -> list["EntityInfo"]: - return [] - - class EnvVarRequirement(BaseModel): """Environment variable requirement for an entity.""" @@ -61,7 +60,7 @@ class 
EntityInfo(BaseModel): class DiscoveryResponse(BaseModel): """Response model for entity discovery.""" - entities: list[EntityInfo] = Field(default_factory=_default_entities) + entities: list[EntityInfo] = Field(default_factory=cast(Callable[..., list[EntityInfo]], list)) # ============================================================================ diff --git a/python/packages/ollama/agent_framework_ollama/_chat_client.py b/python/packages/ollama/agent_framework_ollama/_chat_client.py index 796c17107e..e31c1971da 100644 --- a/python/packages/ollama/agent_framework_ollama/_chat_client.py +++ b/python/packages/ollama/agent_framework_ollama/_chat_client.py @@ -329,11 +329,11 @@ def __init__( env_file_path=env_file_path, ) - self.model_id = ollama_settings["model_id"] # type: ignore[assignment] # pyright: ignore[reportTypedDictNotRequiredAccess] + self.model_id = ollama_settings["model_id"] # type: ignore[assignment, reportTypedDictNotRequiredAccess] # we can just pass in None for the host, the default is set by the Ollama package. 
self.client = client or AsyncClient(host=ollama_settings.get("host")) # Save Host URL for serialization with to_dict() - self.host = str(self.client._client.base_url) # pyright: ignore[reportUnknownMemberType,reportPrivateUsage,reportUnknownArgumentType] + self.host = str(self.client._client.base_url) # type: ignore[reportUnknownMemberType,reportPrivateUsage,reportUnknownArgumentType] super().__init__( middleware=middleware, diff --git a/python/packages/ollama/agent_framework_ollama/_embedding_client.py b/python/packages/ollama/agent_framework_ollama/_embedding_client.py index 0a922b3276..5cd35fc9f3 100644 --- a/python/packages/ollama/agent_framework_ollama/_embedding_client.py +++ b/python/packages/ollama/agent_framework_ollama/_embedding_client.py @@ -107,9 +107,9 @@ def __init__( env_file_encoding=env_file_encoding, ) - self.model_id = ollama_settings["embedding_model_id"] # type: ignore[assignment] # pyright: ignore[reportTypedDictNotRequiredAccess] + self.model_id = ollama_settings["embedding_model_id"] # type: ignore[assignment,reportTypedDictNotRequiredAccess] self.client = client or AsyncClient(host=ollama_settings.get("host")) - self.host = str(self.client._client.base_url) # pyright: ignore[reportUnknownMemberType,reportPrivateUsage,reportUnknownArgumentType] + self.host = str(self.client._client.base_url) # type: ignore[reportUnknownMemberType,reportPrivateUsage,reportUnknownArgumentType] super().__init__(**kwargs) def service_url(self) -> str: diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py index 66b8309f9e..4352a8af47 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py @@ -356,15 +356,17 @@ def _clone_chat_agent(self, agent: Agent[Any]) -> Agent[Any]: # so we need to recombine them here to pass the complete tools list to the 
constructor. # This makes sure MCP tools are preserved when cloning agents for handoff workflows. tools_from_options = options.pop("tools", []) - if agent.mcp_tools: - tools_from_options.extend(agent.mcp_tools) + new_tools = [*tools_from_options, *(agent.mcp_tools if agent.mcp_tools else [])] # this ensures all options (including custom ones) are kept cloned_options = deepcopy(options) # Disable parallel tool calls to prevent the agent from invoking multiple handoff tools at once. cloned_options["allow_multiple_tool_calls"] = False cloned_options["store"] = False - cloned_options["tools"] = tools_from_options + cloned_options["tools"] = new_tools + + # restore the original tools, in case they are shared between agents + options["tools"] = tools_from_options return Agent( client=agent.client, From 688063734d56e70b8810374c6b09cc0308bad1bd Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Thu, 5 Mar 2026 15:57:11 +0100 Subject: [PATCH 13/15] fix core --- .../packages/core/agent_framework/_agents.py | 2 +- .../packages/core/agent_framework/_tools.py | 38 +++++++++++-------- .../core/agent_framework/observability.py | 20 ++++++---- python/packages/core/pyproject.toml | 2 +- 4 files changed, 36 insertions(+), 26 deletions(-) diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 6d79c69cbb..3aaf9f1419 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -1083,7 +1083,7 @@ async def _prepare_run_context( # Merge runtime kwargs into additional_function_arguments so they're available # in function middleware context and tool invocation. 
- existing_additional_args = opts.pop("additional_function_arguments", None) or {} + existing_additional_args: dict[str, Any] = opts.pop("additional_function_arguments", None) or {} additional_function_arguments = {**kwargs, **existing_additional_args} # Include session so as_tool() wrappers with propagate_session=True can access it. if active_session is not None: diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 4051cef21b..3f11189fdc 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -692,21 +692,21 @@ def normalize_tools( from ._mcp import MCPTool normalized: list[ToolTypes] = [] - for tool_item in tools: + for tool_item in tools: # type: ignore[reportUnknownVariableType] # check known types, these are also callable, so we need to do that first if isinstance(tool_item, FunctionTool): normalized.append(tool_item) continue if isinstance(tool_item, dict): - normalized.append(tool_item) + normalized.append(tool_item) # type: ignore[reportUnknownArgumentType] continue if isinstance(tool_item, MCPTool): normalized.append(tool_item) continue - if callable(tool_item): + if callable(tool_item): # type: ignore[reportUnknownArgumentType] normalized.append(tool(tool_item)) continue - normalized.append(tool_item) + normalized.append(tool_item) # type: ignore[reportUnknownArgumentType] return normalized @@ -734,7 +734,7 @@ def _tools_to_dict( # pyright: ignore[reportUnusedFunction] results.append(tool_item.to_dict()) continue if isinstance(tool_item, dict): - results.append(tool_item) + results.append(tool_item) # type: ignore[reportUnknownArgumentType] continue logger.warning("Can't parse tool.") return results @@ -837,7 +837,7 @@ def _validate_arguments_against_schema( continue if isinstance(schema_type, list): - allowed_types: list[str] = [item for item in schema_type if isinstance(item, str)] + allowed_types: list[str] = [item for item in 
schema_type if isinstance(item, str)] # type: ignore[reportUnknownVariableType] if allowed_types and not any(_matches_json_schema_type(field_value, item) for item in allowed_types): raise TypeError( f"Invalid type for '{field_name}' in '{tool_name}': expected one of " @@ -2031,11 +2031,14 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: mutable_options["tool_choice"] = "none" return - inner_stream: ResponseStream[ChatResponseUpdate, ChatResponse[Any]] = super_get_response( - messages=prepped_messages, - stream=True, - options=mutable_options, - **filtered_kwargs, + inner_stream = cast( + ResponseStream[ChatResponseUpdate, ChatResponse[Any]], + super_get_response( + messages=prepped_messages, + stream=True, + options=mutable_options, + **filtered_kwargs, + ), ) await inner_stream # Collect result hooks from the inner stream to run later @@ -2120,11 +2123,14 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: self.function_invocation_configuration.get("max_iterations", DEFAULT_MAX_ITERATIONS), ) mutable_options["tool_choice"] = "none" - final_inner_stream: ResponseStream[ChatResponseUpdate, ChatResponse[Any]] = super_get_response( - messages=prepped_messages, - stream=True, - options=mutable_options, - **filtered_kwargs, + final_inner_stream = cast( + ResponseStream[ChatResponseUpdate, ChatResponse[Any]], + super_get_response( + messages=prepped_messages, + stream=True, + options=mutable_options, + **filtered_kwargs, + ), ) await final_inner_stream async for update in final_inner_stream: diff --git a/python/packages/core/agent_framework/observability.py b/python/packages/core/agent_framework/observability.py index 4fe84d5f78..c622e0c603 100644 --- a/python/packages/core/agent_framework/observability.py +++ b/python/packages/core/agent_framework/observability.py @@ -1154,7 +1154,7 @@ def get_response( **kwargs: Any, ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: """Trace chat responses with OpenTelemetry 
spans and metrics.""" - from ._types import ChatResponse, ChatResponseUpdate, ResponseStream + from ._types import ChatResponse, ChatResponseUpdate, ResponseStream # type: ignore[reportUnusedImport] global OBSERVABILITY_SETTINGS super_get_response = super().get_response # type: ignore[misc] @@ -1260,11 +1260,14 @@ async def _get_response() -> ChatResponse: ) start_time_stamp = perf_counter() try: - response: ChatResponse[Any] = await super_get_response( - messages=messages, - stream=False, - options=opts, - **kwargs, + response = cast( + ChatResponse[Any], + await super_get_response( + messages=messages, + stream=False, + options=opts, + **kwargs, + ), ) except Exception as exception: capture_exception(span=span, exception=exception, timestamp=time_ns()) @@ -1341,8 +1344,9 @@ async def get_embeddings( with _get_span(attributes=attributes, span_name_attribute=OtelAttr.REQUEST_MODEL) as span: start_time_stamp = perf_counter() try: - result: GeneratedEmbeddings[EmbeddingT, EmbeddingOptionsT] = await super_get_embeddings( - values, options=options + result = cast( + GeneratedEmbeddings[EmbeddingT, EmbeddingOptionsT], + await super_get_embeddings(values, options=options), ) except Exception as exception: capture_exception(span=span, exception=exception, timestamp=time_ns()) diff --git a/python/packages/core/pyproject.toml b/python/packages/core/pyproject.toml index 016650ba0a..9d002453df 100644 --- a/python/packages/core/pyproject.toml +++ b/python/packages/core/pyproject.toml @@ -104,7 +104,7 @@ extend = "../../pyproject.toml" [tool.pyright] extends = "../../pyproject.toml" -include = ["tests/workflow"] +include = ["agent_framework", "tests/workflow"] [tool.mypy] plugins = ['pydantic.mypy'] From a6bb84cdb76db003759e5d201326b7fb2bc80757 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Thu, 5 Mar 2026 16:00:21 +0100 Subject: [PATCH 14/15] fix tests --- python/packages/core/agent_framework/observability.py | 2 ++ 1 file changed, 2 insertions(+) diff --git 
a/python/packages/core/agent_framework/observability.py b/python/packages/core/agent_framework/observability.py index c622e0c603..1e9b7cea6a 100644 --- a/python/packages/core/agent_framework/observability.py +++ b/python/packages/core/agent_framework/observability.py @@ -1323,6 +1323,8 @@ async def get_embeddings( options: EmbeddingOptionsT | None = None, ) -> GeneratedEmbeddings[EmbeddingT, EmbeddingOptionsT]: """Trace embedding generation with OpenTelemetry spans and metrics.""" + from ._types import EmbeddingOptionsT, EmbeddingT, GeneratedEmbeddings # type: ignore[reportUnusedImport] + global OBSERVABILITY_SETTINGS super_get_embeddings = super().get_embeddings # type: ignore[misc] From 7ed92f41efffbfc4bca548b8356cdea8cbba53ac Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Thu, 5 Mar 2026 16:05:28 +0100 Subject: [PATCH 15/15] fix obser --- python/packages/core/agent_framework/observability.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/packages/core/agent_framework/observability.py b/python/packages/core/agent_framework/observability.py index 1e9b7cea6a..a595582b33 100644 --- a/python/packages/core/agent_framework/observability.py +++ b/python/packages/core/agent_framework/observability.py @@ -1323,7 +1323,7 @@ async def get_embeddings( options: EmbeddingOptionsT | None = None, ) -> GeneratedEmbeddings[EmbeddingT, EmbeddingOptionsT]: """Trace embedding generation with OpenTelemetry spans and metrics.""" - from ._types import EmbeddingOptionsT, EmbeddingT, GeneratedEmbeddings # type: ignore[reportUnusedImport] + from ._types import GeneratedEmbeddings # type: ignore[reportUnusedImport] global OBSERVABILITY_SETTINGS super_get_embeddings = super().get_embeddings # type: ignore[misc]