Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions python/packages/azure-ai/agent_framework_azure_ai/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,13 @@
from azure.ai.projects.aio import AIProjectClient
from azure.ai.projects.models import (
ApproximateLocation,
CodeInterpreterContainerAuto,
CodeInterpreterTool,
CodeInterpreterToolAuto,
FoundryFeaturesOptInKeys,
ImageGenTool,
MCPTool,
PromptAgentDefinition,
PromptAgentDefinitionText,
PromptAgentDefinitionTextOptions,
RaiConfig,
Reasoning,
WebSearchPreviewTool,
Expand Down Expand Up @@ -78,6 +79,9 @@ class AzureAIProjectAgentOptions(OpenAIResponsesOptions, total=False):
reasoning: Reasoning # type: ignore[misc]
"""Configuration for enabling reasoning capabilities (requires azure.ai.projects.models.Reasoning)."""

foundry_features: FoundryFeaturesOptInKeys | str
"""Optional Foundry preview feature opt-in for agent version creation."""


AzureAIClientOptionsT = TypeVar(
"AzureAIClientOptionsT",
Expand Down Expand Up @@ -392,7 +396,7 @@ async def _get_agent_reference_or_create(
# response_format is accessed from chat_options or additional_properties
# since the base class excludes it from run_options
if chat_options and (response_format := chat_options.get("response_format")):
args["text"] = PromptAgentDefinitionText(format=create_text_format_config(response_format))
args["text"] = PromptAgentDefinitionTextOptions(format=create_text_format_config(response_format))

# Combine instructions from messages and options
# instructions is accessed from chat_options since the base class excludes it from run_options
Expand All @@ -404,11 +408,15 @@ async def _get_agent_reference_or_create(
if combined_instructions:
args["instructions"] = "".join(combined_instructions)

created_agent = await self.project_client.agents.create_version(
agent_name=self.agent_name,
definition=PromptAgentDefinition(**args),
description=self.agent_description,
)
create_version_kwargs: dict[str, Any] = {
"agent_name": self.agent_name,
"definition": PromptAgentDefinition(**args),
"description": self.agent_description,
}
if foundry_features := run_options.get("foundry_features"):
create_version_kwargs["foundry_features"] = foundry_features

created_agent = await self.project_client.agents.create_version(**create_version_kwargs)

self.agent_version = created_agent.version
self.warn_runtime_tools_and_structure_changed = True
Expand Down Expand Up @@ -500,6 +508,7 @@ def _remove_agent_level_run_options(
"temperature": ("temperature",),
"top_p": ("top_p",),
"reasoning": ("reasoning",),
"foundry_features": ("foundry_features",),
}

for run_keys in agent_level_option_to_run_keys.values():
Expand All @@ -526,9 +535,9 @@ async def _prepare_options(
run_options["input"] = self._transform_input_for_azure_ai(cast(list[dict[str, Any]], run_options["input"]))

if not self._is_application_endpoint:
# Application-scoped response APIs do not support "agent" property.
# Application-scoped response APIs do not support "agent_reference" property.
agent_reference = await self._get_agent_reference_or_create(run_options, instructions, options)
run_options["extra_body"] = {"agent": agent_reference}
run_options["extra_body"] = {"agent_reference": agent_reference}

# Remove only keys that map to this client's declared options TypedDict.
self._remove_agent_level_run_options(run_options, options)
Expand Down Expand Up @@ -922,7 +931,7 @@ def get_code_interpreter_tool( # type: ignore[override]
if file_ids is None and isinstance(container, dict):
file_ids = container.get("file_ids")
resolved = resolve_file_ids(file_ids)
tool_container = CodeInterpreterToolAuto(file_ids=resolved)
tool_container = CodeInterpreterContainerAuto(file_ids=resolved)
return CodeInterpreterTool(container=tool_container, **kwargs)

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from agent_framework._settings import load_settings
from agent_framework.azure._entra_id_authentication import AzureCredentialTypes
from azure.ai.projects.aio import AIProjectClient
from azure.ai.projects.models import ItemParam, ResponsesAssistantMessageItemParam, ResponsesUserMessageItemParam

from ._shared import AzureAISettings

Expand Down Expand Up @@ -149,7 +148,7 @@ async def before_run(
# On first run, retrieve static memories (user profile memories)
if not state.get("initialized"):
try:
static_search_result = await self.project_client.memory_stores.search_memories(
static_search_result = await self.project_client.beta.memory_stores.search_memories(
name=self.memory_store_name,
scope=self.scope or context.session_id, # type: ignore[arg-type]
)
Expand All @@ -169,15 +168,15 @@ async def before_run(
if not has_input:
return

# Convert input messages to ItemParam format for search
# Convert input messages to memory search item format
items = [
ItemParam({"type": "text", "text": msg.text})
{"type": "text", "text": msg.text}
for msg in context.input_messages
if msg and msg.text and msg.text.strip()
]

try:
search_result = await self.project_client.memory_stores.search_memories(
search_result = await self.project_client.beta.memory_stores.search_memories(
name=self.memory_store_name,
scope=self.scope or context.session_id, # type: ignore[arg-type]
items=items,
Expand Down Expand Up @@ -224,24 +223,24 @@ async def after_run(
if context.response and context.response.messages:
messages_to_store.extend(context.response.messages)

# Filter and convert messages to ItemParam format
items: list[ResponsesUserMessageItemParam | ResponsesAssistantMessageItemParam] = []
# Filter and convert messages to memory update item format
items: list[dict[str, str]] = []
for message in messages_to_store:
if message.role in {"user", "assistant", "system"} and message.text and message.text.strip():
if message.role == "user":
items.append(ResponsesUserMessageItemParam(content=message.text))
items.append({"role": "user", "type": "message", "content": message.text})
elif message.role == "assistant":
items.append(ResponsesAssistantMessageItemParam(content=message.text))
items.append({"role": "assistant", "type": "message", "content": message.text})

if not items:
return

try:
# Fire and forget - don't wait for the update to complete
update_poller = await self.project_client.memory_stores.begin_update_memories(
update_poller = await self.project_client.beta.memory_stores.begin_update_memories(
name=self.memory_store_name,
scope=self.scope or context.session_id, # type: ignore[arg-type]
items=items, # type: ignore[arg-type]
items=items,
previous_update_id=state.get("previous_update_id"),
update_delay=self.update_delay,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import logging
import sys
from collections.abc import Callable, MutableMapping, Sequence
from collections.abc import Callable, Mapping, MutableMapping, Sequence
from typing import Any, Generic

from agent_framework import (
Expand All @@ -21,10 +21,9 @@
from agent_framework.azure._entra_id_authentication import AzureCredentialTypes
from azure.ai.projects.aio import AIProjectClient
from azure.ai.projects.models import (
AgentReference,
AgentVersionDetails,
PromptAgentDefinition,
PromptAgentDefinitionText,
PromptAgentDefinitionTextOptions,
)
from azure.ai.projects.models import (
FunctionTool as AzureFunctionTool,
Expand Down Expand Up @@ -200,13 +199,14 @@ async def create_agent(
response_format = opts.get("response_format")
rai_config = opts.get("rai_config")
reasoning = opts.get("reasoning")
foundry_features = opts.get("foundry_features")

args: dict[str, Any] = {"model": resolved_model}

if instructions:
args["instructions"] = instructions
if response_format and isinstance(response_format, (type, dict)):
args["text"] = PromptAgentDefinitionText(
args["text"] = PromptAgentDefinitionTextOptions(
format=create_text_format_config(response_format) # type: ignore[arg-type]
)
if rai_config:
Expand Down Expand Up @@ -241,11 +241,15 @@ async def create_agent(
if all_tools_for_azure:
args["tools"] = to_azure_ai_tools(all_tools_for_azure)

created_agent = await self._project_client.agents.create_version(
agent_name=name,
definition=PromptAgentDefinition(**args),
description=description,
)
create_version_kwargs: dict[str, Any] = {
"agent_name": name,
"definition": PromptAgentDefinition(**args),
"description": description,
}
if foundry_features:
create_version_kwargs["foundry_features"] = foundry_features

created_agent = await self._project_client.agents.create_version(**create_version_kwargs)

return self._to_chat_agent_from_details(
created_agent,
Expand All @@ -259,7 +263,7 @@ async def get_agent(
self,
*,
name: str | None = None,
reference: AgentReference | None = None,
reference: Mapping[str, str | None] | None = None,
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
default_options: OptionsCoT | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
Expand All @@ -272,7 +276,7 @@ async def get_agent(

Args:
name: The name of the agent to retrieve (fetches latest version).
reference: Reference containing the agent's name and optionally a specific version.
reference: Mapping containing the agent's ``name`` and optionally a specific ``version``.
tools: Tools to make available to the agent. Required if the agent has function tools.
default_options: A TypedDict containing default chat options for the agent.
These options are applied to every run unless overridden.
Expand All @@ -287,12 +291,15 @@ async def get_agent(
"""
existing_agent: AgentVersionDetails

if reference and reference.version:
reference_name = str(reference.get("name")) if reference and reference.get("name") else None
reference_version = str(reference.get("version")) if reference and reference.get("version") else None

if reference_name and reference_version:
# Fetch specific version
existing_agent = await self._project_client.agents.get_version(
agent_name=reference.name, agent_version=reference.version
agent_name=reference_name, agent_version=reference_version
)
elif agent_name := (reference.name if reference else name):
elif agent_name := (reference_name if reference_name else name):
# Fetch latest version
details = await self._project_client.agents.get(agent_name=agent_name)
existing_agent = details.versions.latest
Expand Down
20 changes: 10 additions & 10 deletions python/packages/azure-ai/agent_framework_azure_ai/_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
from azure.ai.projects.models import (
CodeInterpreterTool,
MCPTool,
ResponseTextFormatConfigurationJsonObject,
ResponseTextFormatConfigurationJsonSchema,
ResponseTextFormatConfigurationText,
TextResponseFormatConfigurationResponseFormatJsonObject,
TextResponseFormatConfigurationResponseFormatText,
TextResponseFormatJsonSchema,
Tool,
WebSearchPreviewTool,
)
Expand Down Expand Up @@ -463,17 +463,17 @@ def _prepare_mcp_tool_dict_for_azure_ai(tool_dict: dict[str, Any]) -> MCPTool:
def create_text_format_config(
response_format: type[BaseModel] | Mapping[str, Any],
) -> (
ResponseTextFormatConfigurationJsonSchema
| ResponseTextFormatConfigurationJsonObject
| ResponseTextFormatConfigurationText
TextResponseFormatJsonSchema
| TextResponseFormatConfigurationResponseFormatJsonObject
| TextResponseFormatConfigurationResponseFormatText
):
"""Convert response_format into Azure text format configuration."""
if isinstance(response_format, type) and issubclass(response_format, BaseModel):
schema = response_format.model_json_schema()
# Ensure additionalProperties is explicitly false to satisfy Azure validation
if isinstance(schema, dict):
schema.setdefault("additionalProperties", False)
return ResponseTextFormatConfigurationJsonSchema(
return TextResponseFormatJsonSchema(
name=response_format.__name__,
schema=schema,
strict=True,
Expand All @@ -494,11 +494,11 @@ def create_text_format_config(
config_kwargs["strict"] = format_config["strict"]
if "description" in format_config:
config_kwargs["description"] = format_config["description"]
return ResponseTextFormatConfigurationJsonSchema(**config_kwargs)
return TextResponseFormatJsonSchema(**config_kwargs)
if format_type == "json_object":
return ResponseTextFormatConfigurationJsonObject()
return TextResponseFormatConfigurationResponseFormatJsonObject()
if format_type == "text":
return ResponseTextFormatConfigurationText()
return TextResponseFormatConfigurationResponseFormatText()

raise IntegrationInvalidRequestException("response_format must be a Pydantic model or mapping.")

Expand Down
20 changes: 10 additions & 10 deletions python/packages/azure-ai/tests/test_azure_ai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@
from azure.ai.projects.aio import AIProjectClient
from azure.ai.projects.models import (
ApproximateLocation,
CodeInterpreterContainerAuto,
CodeInterpreterTool,
CodeInterpreterToolAuto,
FileSearchTool,
ImageGenTool,
MCPTool,
ResponseTextFormatConfigurationJsonSchema,
TextResponseFormatJsonSchema,
WebSearchPreviewTool,
)
from azure.core.exceptions import ResourceNotFoundError
Expand Down Expand Up @@ -427,7 +427,7 @@ async def test_prepare_options_basic(mock_project_client: MagicMock) -> None:
run_options = await client._prepare_options(messages, {})

assert "extra_body" in run_options
assert run_options["extra_body"]["agent"]["name"] == "test-agent"
assert run_options["extra_body"]["agent_reference"]["name"] == "test-agent"


@pytest.mark.parametrize(
Expand Down Expand Up @@ -465,7 +465,7 @@ async def test_prepare_options_with_application_endpoint(

if expects_agent:
assert "extra_body" in run_options
assert run_options["extra_body"]["agent"]["name"] == "test-agent"
assert run_options["extra_body"]["agent_reference"]["name"] == "test-agent"
else:
assert "extra_body" not in run_options

Expand Down Expand Up @@ -507,7 +507,7 @@ async def test_prepare_options_with_application_project_client(

if expects_agent:
assert "extra_body" in run_options
assert run_options["extra_body"]["agent"]["name"] == "test-agent"
assert run_options["extra_body"]["agent_reference"]["name"] == "test-agent"
else:
assert "extra_body" not in run_options

Expand Down Expand Up @@ -979,10 +979,10 @@ async def test_agent_creation_with_response_format(
assert hasattr(created_definition, "text")
assert created_definition.text is not None

# Check that the format is a ResponseTextFormatConfigurationJsonSchema
# Check that the format is a TextResponseFormatJsonSchema
assert hasattr(created_definition.text, "format")
format_config = created_definition.text.format
assert isinstance(format_config, ResponseTextFormatConfigurationJsonSchema)
assert isinstance(format_config, TextResponseFormatJsonSchema)

# Check the schema name matches the model class name
assert format_config.name == "ResponseFormatModel"
Expand Down Expand Up @@ -1040,7 +1040,7 @@ async def test_agent_creation_with_mapping_response_format(
assert hasattr(created_definition, "text")
assert created_definition.text is not None
format_config = created_definition.text.format
assert isinstance(format_config, ResponseTextFormatConfigurationJsonSchema)
assert isinstance(format_config, TextResponseFormatJsonSchema)
assert format_config.name == runtime_schema["title"]
assert format_config.schema == runtime_schema
assert format_config.strict is True
Expand Down Expand Up @@ -1110,7 +1110,7 @@ async def test_prepare_options_excludes_response_format(
assert "text_format" not in run_options
# But extra_body should contain agent reference
assert "extra_body" in run_options
assert run_options["extra_body"]["agent"]["name"] == "test-agent"
assert run_options["extra_body"]["agent_reference"]["name"] == "test-agent"


async def test_prepare_options_keeps_values_for_unsupported_option_keys(
Expand Down Expand Up @@ -1254,7 +1254,7 @@ def test_from_azure_ai_tools_mcp() -> None:

def test_from_azure_ai_tools_code_interpreter() -> None:
"""Test from_azure_ai_tools with Code Interpreter tool."""
ci_tool = CodeInterpreterTool(container=CodeInterpreterToolAuto(file_ids=["file-1"]))
ci_tool = CodeInterpreterTool(container=CodeInterpreterContainerAuto(file_ids=["file-1"]))
parsed_tools = from_azure_ai_tools([ci_tool])
assert len(parsed_tools) == 1
assert parsed_tools[0]["type"] == "code_interpreter"
Expand Down
Loading