diff --git a/.gitignore b/.gitignore index 6ac0bf3f..36485683 100644 --- a/.gitignore +++ b/.gitignore @@ -18,7 +18,7 @@ venv.bak/ .dmypy.json dmypy.json -.copilot-instructions.md +private/ # other .DS_STORE @@ -30,6 +30,7 @@ ref/ py.typed CLAUDE.md +.copilot-instructions.md .env.claude/ .claude/ diff --git a/examples/a2a-test/src/main.py b/examples/a2a-test/src/main.py index 7981c6e9..12df51c8 100644 --- a/examples/a2a-test/src/main.py +++ b/examples/a2a-test/src/main.py @@ -180,7 +180,11 @@ async def location_handler(params: LocationParams) -> str: parts=[Part(root=TextPart(kind="text", text="Please provide a location"))], ) else: - return result.response.content if result.response.content else "No weather information available." + return ( + result.response.content + if result.response and result.response.content + else "No weather information available." + ) # A2A Server Message Event Handler @@ -206,7 +210,7 @@ async def handle_a2a_message(message: A2AMessageEvent) -> None: await respond(result) -async def handler(message: str) -> ModelMessage: +async def handler(message: str) -> ModelMessage | None: # Now we can send the message to the prompt and it will decide if # the a2a agent should be used or not and also manages contacting the agent result = await prompt.send(message) @@ -219,7 +223,7 @@ async def handle_message(ctx: ActivityContext[MessageActivity]): await ctx.reply(TypingActivityInput()) result = await handler(ctx.activity.text) - if result.content: + if result and result.content: await ctx.send(result.content) diff --git a/examples/ai-test/src/handlers/function_calling.py b/examples/ai-test/src/handlers/function_calling.py index 71087dd8..d1e60e29 100644 --- a/examples/ai-test/src/handlers/function_calling.py +++ b/examples/ai-test/src/handlers/function_calling.py @@ -68,7 +68,7 @@ async def handle_pokemon_search(model: AIModel, ctx: ActivityContext[MessageActi input=ctx.activity.text, instructions="You are a helpful assistant that can look up 
Pokemon for the user." ) - if chat_result.response.content: + if chat_result.response and chat_result.response.content: message = MessageActivityInput(text=chat_result.response.content).add_ai_generated() await ctx.send(message) else: @@ -129,7 +129,7 @@ async def handle_multiple_functions(model: AIModel, ctx: ActivityContext[Message ), ) - if chat_result.response.content: + if chat_result.response and chat_result.response.content: message = MessageActivityInput(text=chat_result.response.content).add_ai_generated() await ctx.send(message) else: diff --git a/examples/ai-test/src/handlers/memory_management.py b/examples/ai-test/src/handlers/memory_management.py index f713780e..92bf826f 100644 --- a/examples/ai-test/src/handlers/memory_management.py +++ b/examples/ai-test/src/handlers/memory_management.py @@ -39,7 +39,7 @@ async def handle_stateful_conversation(model: AIModel, ctx: ActivityContext[Mess input=ctx.activity.text, instructions="You are a helpful assistant that remembers our previous conversation." 
) - if chat_result.response.content: + if chat_result.response and chat_result.response.content: message = MessageActivityInput(text=chat_result.response.content).add_ai_generated() await ctx.send(message) else: diff --git a/examples/ai-test/src/main.py b/examples/ai-test/src/main.py index a156c1ee..bdfc01c2 100644 --- a/examples/ai-test/src/main.py +++ b/examples/ai-test/src/main.py @@ -68,7 +68,7 @@ async def handle_simple_chat(ctx: ActivityContext[MessageActivity]): input=ctx.activity.text, instructions="You are a friendly assistant who talks like a pirate" ) - if chat_result.response.content: + if chat_result.response and chat_result.response.content: message = MessageActivityInput(text=chat_result.response.content).add_ai_generated() await ctx.send(message) @@ -107,7 +107,7 @@ async def handle_streaming(ctx: ActivityContext[MessageActivity]): if hasattr(ctx.activity.conversation, "is_group") and ctx.activity.conversation.is_group: # Group chat - send final response - if chat_result.response.content: + if chat_result.response and chat_result.response.content: message = MessageActivityInput(text=chat_result.response.content).add_ai_generated() await ctx.send(message) else: @@ -167,7 +167,7 @@ async def handle_feedback_demo(ctx: ActivityContext[MessageActivity]): input="Tell me a short joke", instructions="You are a comedian. Keep responses brief and funny." 
) - if chat_result.response.content: + if chat_result.response and chat_result.response.content: # Create message with feedback enabled and initialize storage message = MessageActivityInput(text=chat_result.response.content).add_ai_generated().add_feedback() sent_message = await ctx.send(message) diff --git a/examples/mcp-client/src/main.py b/examples/mcp-client/src/main.py index 2a91a0ac..8c4b007a 100644 --- a/examples/mcp-client/src/main.py +++ b/examples/mcp-client/src/main.py @@ -88,7 +88,7 @@ async def handle_agent_chat(ctx: ActivityContext[MessageActivity]): # Use ChatPrompt with MCP tools (stateful conversation) result = await responses_prompt.send(query) - if result.response.content: + if result.response and result.response.content: message = MessageActivityInput(text=result.response.content).add_ai_generated() await ctx.send(message) @@ -111,7 +111,7 @@ async def handle_prompt_chat(ctx: ActivityContext[MessageActivity]): ), ) - if result.response.content: + if result.response and result.response.content: message = MessageActivityInput(text=result.response.content).add_ai_generated() await ctx.send(message) @@ -157,7 +157,7 @@ async def handle_fallback_message(ctx: ActivityContext[MessageActivity]): # Use ChatPrompt with MCP tools for general conversation result = await responses_prompt.send(ctx.activity.text) - if result.response.content: + if result.response and result.response.content: message = MessageActivityInput(text=result.response.content).add_ai_generated() await ctx.send(message) diff --git a/packages/ai/src/microsoft/teams/ai/__init__.py b/packages/ai/src/microsoft/teams/ai/__init__.py index b256d155..533859d4 100644 --- a/packages/ai/src/microsoft/teams/ai/__init__.py +++ b/packages/ai/src/microsoft/teams/ai/__init__.py @@ -3,11 +3,21 @@ Licensed under the MIT License. """ +from . 
import plugins, utils from .ai_model import AIModel from .chat_prompt import ChatPrompt, ChatSendResult -from .function import Function, FunctionCall, FunctionHandler, FunctionHandlers, FunctionHandlerWithNoParams +from .function import ( + DeferredResult, + Function, + FunctionCall, + FunctionHandler, + FunctionHandlers, + FunctionHandlerWithNoParams, +) from .memory import ListMemory, Memory -from .message import FunctionMessage, Message, ModelMessage, SystemMessage, UserMessage +from .message import DeferredMessage, FunctionMessage, Message, ModelMessage, SystemMessage, UserMessage +from .plugin import AIPluginProtocol, BaseAIPlugin +from .utils import * # noqa: F401, F403 __all__ = [ "ChatSendResult", @@ -17,12 +27,18 @@ "ModelMessage", "SystemMessage", "FunctionMessage", + "DeferredMessage", "Function", "FunctionCall", + "DeferredResult", "Memory", "ListMemory", "AIModel", + "AIPluginProtocol", + "BaseAIPlugin", "FunctionHandler", "FunctionHandlerWithNoParams", "FunctionHandlers", ] +__all__.extend(utils.__all__) +__all__.extend(plugins.__all__) diff --git a/packages/ai/src/microsoft/teams/ai/agent.py b/packages/ai/src/microsoft/teams/ai/agent.py new file mode 100644 index 00000000..23959fbc --- /dev/null +++ b/packages/ai/src/microsoft/teams/ai/agent.py @@ -0,0 +1,70 @@ +""" +Copyright (c) Microsoft Corporation. All rights reserved. +Licensed under the MIT License. +""" + +from typing import Any, Awaitable, Callable + +from microsoft.teams.ai.plugin import AIPluginProtocol + +from .ai_model import AIModel +from .chat_prompt import ChatPrompt, ChatSendResult +from .function import Function +from .memory import ListMemory, Memory +from .message import Message, SystemMessage + + +class Agent(ChatPrompt): + """ + A stateful implementation of ChatPrompt with persistent memory. 
+ + Agent extends ChatPrompt by providing default memory management, + making it easier to maintain conversation context across multiple + interactions without manually passing memory each time. + """ + + def __init__( + self, + model: AIModel, + *, + memory: Memory | None = None, + functions: list[Function[Any]] | None = None, + plugins: list[AIPluginProtocol] | None = None, + ): + """ + Initialize Agent with model and persistent memory. + + Args: + model: AI model implementation for text generation + memory: Memory for conversation persistence. Defaults to InMemory ListMemory + functions: Optional list of functions the model can call + plugins: Optional list of plugins for extending functionality + """ + super().__init__(model, functions=functions, plugins=plugins) + self.memory = memory or ListMemory() + + async def send( + self, + input: str | Message | None, + *, + instructions: str | SystemMessage | None = None, + memory: Memory | None = None, + on_chunk: Callable[[str], Awaitable[None]] | Callable[[str], None] | None = None, + ) -> ChatSendResult: + """ + Send a message using the agent's persistent memory. + + Args: + input: Message to send (string will be converted to UserMessage) + instructions: Optional system message to guide model behavior + memory: Optional memory override. Defaults to agent's persistent memory + on_chunk: Optional callback for streaming response chunks + + Returns: + ChatSendResult containing the final model response + + Note: + If no memory is provided, uses the agent's default memory, + making conversation state persistent across calls. 
+ """ + return await super().send(input, memory=memory or self.memory, instructions=instructions, on_chunk=on_chunk) diff --git a/packages/ai/src/microsoft/teams/ai/ai_model.py b/packages/ai/src/microsoft/teams/ai/ai_model.py index 4e935671..cec695eb 100644 --- a/packages/ai/src/microsoft/teams/ai/ai_model.py +++ b/packages/ai/src/microsoft/teams/ai/ai_model.py @@ -9,7 +9,7 @@ from .function import Function from .memory import Memory -from .message import Message, ModelMessage, SystemMessage +from .message import DeferredMessage, Message, ModelMessage, SystemMessage class AIModel(Protocol): @@ -23,13 +23,13 @@ class AIModel(Protocol): async def generate_text( self, - input: Message, + input: Message | None, *, system: SystemMessage | None = None, memory: Memory | None = None, functions: dict[str, Function[BaseModel]] | None = None, on_chunk: Callable[[str], Awaitable[None]] | None = None, - ) -> ModelMessage: + ) -> ModelMessage | list[DeferredMessage]: """ Generate a text response from the AI model. 
diff --git a/packages/ai/src/microsoft/teams/ai/chat_prompt.py b/packages/ai/src/microsoft/teams/ai/chat_prompt.py index eed8313d..a8037cd7 100644 --- a/packages/ai/src/microsoft/teams/ai/chat_prompt.py +++ b/packages/ai/src/microsoft/teams/ai/chat_prompt.py @@ -6,14 +6,16 @@ import inspect from dataclasses import dataclass from inspect import isawaitable +from logging import Logger from typing import Any, Awaitable, Callable, Dict, Optional, Self, TypeVar, Union, cast, overload +from microsoft.teams.common.logging import ConsoleLogger from pydantic import BaseModel from .ai_model import AIModel from .function import Function, FunctionHandler, FunctionHandlers, FunctionHandlerWithNoParams from .memory import Memory -from .message import Message, ModelMessage, SystemMessage, UserMessage +from .message import DeferredMessage, FunctionMessage, Message, ModelMessage, SystemMessage, UserMessage from .plugin import AIPluginProtocol T = TypeVar("T", bound=BaseModel) @@ -28,7 +30,8 @@ class ChatSendResult: calls and plugin processing have been completed. """ - response: ModelMessage # Final model response after processing + response: ModelMessage | None # Final model response after processing + is_deferred: bool = False class ChatPrompt: @@ -43,9 +46,11 @@ def __init__( self, model: AIModel, *, - memory: Memory | None = None, functions: list[Function[Any]] | None = None, plugins: list[AIPluginProtocol] | None = None, + memory: Memory | None = None, + logger: Logger | None = None, + instructions: str | SystemMessage | None = None, ): """ Initialize ChatPrompt with model and optional functions/plugins. 
@@ -55,11 +60,17 @@ def __init__( memory: Optional default memory for conversation persistence functions: Optional list of functions the model can call plugins: Optional list of plugins for extending functionality + memory: Optional memory for conversation context and deferred state + logger: Optional logger for debugging and monitoring + instructions: Optional default system instructions for the model """ self.model = model self.memory = memory self.functions: dict[str, Function[Any]] = {func.name: func for func in functions} if functions else {} self.plugins: list[AIPluginProtocol] = plugins or [] + self.memory = memory + self.logger = logger or ConsoleLogger().create_logger("@teams/ai/chat_prompt") + self.instructions = instructions @overload def with_function(self, function: Function[T]) -> Self: ... @@ -137,9 +148,136 @@ def with_plugin(self, plugin: AIPluginProtocol) -> Self: self.plugins.append(plugin) return self + async def requires_resuming(self) -> bool: + """ + Check if there are any deferred functions that need resuming. + + Returns: + True if there are DeferredMessage objects in memory that need resuming + """ + if not self.memory: + return False + + messages = await self.memory.get_all() + return any(isinstance(msg, DeferredMessage) for msg in messages) + + async def resolve_deferred(self, activity: Any) -> list[str]: + """ + Resolve deferred functions with the provided activity input. + + Only attempts to resolve deferred functions whose resumers can handle + the provided activity type (determined by can_handle method). 
+ + Args: + activity: Activity data to use for resolving deferred functions + + Returns: + List of resolution results from successfully resolved functions + """ + if not self.memory: + return [] + + messages = await self.memory.get_all() + deferred_messages = [msg for msg in messages if isinstance(msg, DeferredMessage)] + + if not deferred_messages: + return [] + + results: list[str] = [] + updated_messages = messages.copy() # Work with a copy + + for i, msg in enumerate(updated_messages): + if not isinstance(msg, DeferredMessage): + continue + + # Try plugins first, then fall back to built-in resumer + result = await self._try_resolve_with_plugins(msg, activity) + if result is None: + result = await self._try_resolve_with_builtin_resumer(msg, activity) + + if result is not None: + updated_messages[i] = FunctionMessage(content=result, function_id=msg.function_id) + results.append(result) + + # Update memory with resolved messages + if results: # Only update if we actually resolved something + await self.memory.set_all(updated_messages) + + return results + + async def _try_resolve_with_plugins(self, msg: DeferredMessage, activity: Any) -> str | None: + """ + Try to resolve a deferred message using plugins. + + Args: + msg: The deferred message to resolve + activity: Activity data for resolution + + Returns: + Result string if a plugin handled it, None otherwise + """ + for plugin in self.plugins: + result = await plugin.on_resume(msg.function_name, activity, msg.deferred_result.state) + if result is not None: + return result + return None + + async def _try_resolve_with_builtin_resumer(self, msg: DeferredMessage, activity: Any) -> str | None: + """ + Try to resolve a deferred message using the built-in resumer. 
+ + Args: + msg: The deferred message to resolve + activity: Activity data for resolution + + Returns: + Result string if resolved successfully, None if skipped, raises on error + """ + resumer_name = msg.function_name + associated_func = self.functions.get(resumer_name) + + if not associated_func or associated_func.resumer is None: + raise ValueError(f"Expected a resumer for {resumer_name} but chat prompt was not set up with one") + + # Check if the resumer can handle this type of activity + if not associated_func.resumer.can_handle(activity): + return None # Skip this deferred function + + try: + # Call the resumer with the activity and saved state + result = associated_func.resumer(activity, msg.deferred_result.state) + if isawaitable(result): + result = await result + return result + + except Exception as e: + # Return error message instead of raising + return f"Error resolving {resumer_name}: {str(e)}" + + async def resume(self, activity: Any) -> ChatSendResult: + """ + Resume deferred functions with the provided activity input. + + If all deferred functions are resolved, automatically continues with + normal chat processing using the activity text as input. 
+ + Args: + activity: Activity data to use for resolving deferred functions + + Returns: + ChatSendResult - either indicating still deferred or containing the chat response + """ + await self.resolve_deferred(activity) + + # If there are still deferred functions pending, return early + if await self.requires_resuming(): + return ChatSendResult(response=None, is_deferred=True) + + return await self.send(input=None) + async def send( self, - input: str | Message, + input: str | Message | None, *, memory: Memory | None = None, on_chunk: Callable[[str], Awaitable[None]] | Callable[[str], None] | None = None, @@ -162,11 +300,18 @@ async def send( if isinstance(input, str): input = UserMessage(content=input) + # Use constructor instructions as default if none provided + if instructions is None: + instructions = self.instructions + # Convert string instructions to SystemMessage if isinstance(instructions, str): instructions = SystemMessage(content=instructions) - current_input = await self._run_before_send_hooks(input) + if input is not None: + current_input = await self._run_before_send_hooks(input) + else: + current_input = None current_system_message = await self._run_build_instructions_hooks(instructions) wrapped_functions = await self._build_wrapped_functions() @@ -184,6 +329,8 @@ async def on_chunk_fn(chunk: str): functions=wrapped_functions, on_chunk=on_chunk_fn if on_chunk else None, ) + if isinstance(response, list): + return ChatSendResult(response=None, is_deferred=True) current_response = await self._run_after_send_hooks(response) @@ -287,7 +434,9 @@ async def _build_wrapped_functions(self) -> dict[str, Function[BaseModel]] | Non name=func.name, description=func.description, parameter_schema=func.parameter_schema, - handler=self._wrap_function_handler(func.handler, func.name), + handler=self._wrap_function_handler(cast(FunctionHandler[BaseModel], func.handler), func.name) + if func.resumer is None + else func.handler, ) return wrapped_functions diff --git 
a/packages/ai/src/microsoft/teams/ai/function.py b/packages/ai/src/microsoft/teams/ai/function.py index beafcc1b..ada67971 100644 --- a/packages/ai/src/microsoft/teams/ai/function.py +++ b/packages/ai/src/microsoft/teams/ai/function.py @@ -3,12 +3,13 @@ Licensed under the MIT License. """ -from dataclasses import dataclass, field -from typing import Any, Awaitable, Dict, Generic, Protocol, TypeVar, Union +from dataclasses import dataclass +from typing import Any, Awaitable, Dict, Generic, Literal, Protocol, TypeVar, Union from pydantic import BaseModel Params = TypeVar("Params", bound=BaseModel, contravariant=True) +ResumableData = TypeVar("ResumableData") """ Type variable for function parameter schemas. @@ -38,6 +39,61 @@ def __call__(self, params: Params) -> Union[str, Awaitable[str]]: ... +class DeferredFunctionResumer(Generic[Params, ResumableData]): + """ + The resumable function returns the actual string + """ + + def can_handle(self, activity: Any) -> bool: + """ + Check if this resumer can handle the given activity input. + + Args: + activity: The activity data to check + + Returns: + True if this resumer can process the activity, False otherwise + """ + ... + + def __call__(self, params: Params, resumableData: ResumableData) -> Awaitable[str]: ... + + +@dataclass +class DeferredResult: + """ + Represents a deferred result that can be resumed later on + """ + + state: dict[str, Any] + type: Literal["deferred"] = "deferred" + + +@dataclass +class FunctionCall: + """ + Represents a function call request from an AI model. + + Contains the function name, unique call ID, and parsed arguments + that will be passed to the function handler. 
+ """ + + id: str # Unique identifier for this function call + name: str # Name of the function to call + arguments: dict[str, Any] # Parsed arguments for the function + + +class DeferredFunctionHandler(Protocol[Params]): + """ + The Deferred Function handler defers the job and returns the name + of the resumable function + Returns the name of the resumable function, and the parameters to save + state + """ + + def __call__(self, params: Params) -> Awaitable[DeferredResult]: ... + + class FunctionHandlerWithNoParams(Protocol): """ Protocol for function handlers that can be called by AI models. @@ -81,19 +137,8 @@ class Function(Generic[Params]): name: str # Unique identifier for the function description: str # Human-readable description of what the function does - parameter_schema: Union[type[Params], Dict[str, Any], None] # Pydantic model class, JSON schema dict, or None - handler: Union[FunctionHandler[Params], FunctionHandlerWithNoParams] # Function implementation (sync or async) - - -@dataclass -class FunctionCall: - """ - Represents a function call request from an AI model. - - Contains the function name, unique call ID, and parsed arguments - that will be passed to the function handler if any. 
- """ - - id: str # Unique identifier for this function call - name: str # Name of the function to call - arguments: dict[str, Any] = field(default_factory=dict[str, Any]) # Parsed arguments for the function + parameter_schema: Union[type[Params], Dict[str, Any], None] # Pydantic model class or JSON schema dict + handler: ( + FunctionHandler[Params] | FunctionHandlerWithNoParams | DeferredFunctionHandler[Params] + ) # Function implementation (sync or async) + resumer: DeferredFunctionResumer[Params, Any] | None = None # Optional resumer for deferred functions diff --git a/packages/ai/src/microsoft/teams/ai/message.py b/packages/ai/src/microsoft/teams/ai/message.py index 6981d09b..c1e43ce9 100644 --- a/packages/ai/src/microsoft/teams/ai/message.py +++ b/packages/ai/src/microsoft/teams/ai/message.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from typing import Literal, Union -from .function import FunctionCall +from .function import DeferredResult, FunctionCall @dataclass @@ -64,7 +64,19 @@ class FunctionMessage: role: Literal["function"] = "function" # Message type identifier -Message = Union[UserMessage, ModelMessage, SystemMessage, FunctionMessage] +@dataclass +class DeferredMessage: + """ + Represents a function call that is deferred + """ + + deferred_result: DeferredResult + function_name: str + function_id: str + content: None = None + + +Message = Union[UserMessage, ModelMessage, SystemMessage, FunctionMessage, DeferredMessage] """ Union type representing any message in a conversation. 
diff --git a/packages/ai/src/microsoft/teams/ai/plugin.py b/packages/ai/src/microsoft/teams/ai/plugin.py index 82194610..b324090b 100644 --- a/packages/ai/src/microsoft/teams/ai/plugin.py +++ b/packages/ai/src/microsoft/teams/ai/plugin.py @@ -4,7 +4,7 @@ """ from abc import abstractmethod -from typing import Optional, Protocol, TypeVar, runtime_checkable +from typing import Any, Optional, Protocol, TypeVar, runtime_checkable from pydantic import BaseModel @@ -111,6 +111,20 @@ async def on_build_instructions(self, instructions: SystemMessage | None) -> Sys """ ... + async def on_resume(self, function_name: str, activity: Any, state: dict[str, Any]) -> str | None: + """ + Called when ChatPrompt is attempting to resume a deferred function. + + Args: + function_name: Name of the function that was deferred + activity: The activity data to use for resolving + state: The state that was saved when function was deferred + + Returns: + Result string if this plugin handled the resuming, None otherwise + """ + ... + class BaseAIPlugin: """ @@ -165,3 +179,7 @@ async def on_build_functions(self, functions: list[Function[BaseModel]]) -> list async def on_build_instructions(self, instructions: SystemMessage | None) -> SystemMessage | None: """Modify the system message before sending to model.""" return instructions + + async def on_resume(self, function_name: str, activity: Any, state: dict[str, Any]) -> str | None: + """Called when ChatPrompt is attempting to resume a deferred function.""" + return None diff --git a/packages/ai/src/microsoft/teams/ai/utils/__init__.py b/packages/ai/src/microsoft/teams/ai/utils/__init__.py new file mode 100644 index 00000000..7d06c498 --- /dev/null +++ b/packages/ai/src/microsoft/teams/ai/utils/__init__.py @@ -0,0 +1,8 @@ +""" +Copyright (c) Microsoft Corporation. All rights reserved. +Licensed under the MIT License. 
+""" + +from .function_utils import execute_function, get_function_schema, parse_function_arguments + +__all__ = ["get_function_schema", "parse_function_arguments", "execute_function"] diff --git a/packages/openai/src/microsoft/teams/openai/function_utils.py b/packages/ai/src/microsoft/teams/ai/utils/function_utils.py similarity index 69% rename from packages/openai/src/microsoft/teams/openai/function_utils.py rename to packages/ai/src/microsoft/teams/ai/utils/function_utils.py index 012647a4..d60b83e5 100644 --- a/packages/openai/src/microsoft/teams/openai/function_utils.py +++ b/packages/ai/src/microsoft/teams/ai/utils/function_utils.py @@ -3,11 +3,19 @@ Licensed under the MIT License. """ -from typing import Any, Dict, Optional +import inspect +from typing import Any, Dict, Optional, cast -from microsoft.teams.ai import Function from pydantic import BaseModel, ConfigDict, create_model +from ..function import ( + DeferredFunctionHandler, + DeferredResult, + Function, + FunctionHandler, + FunctionHandlerWithNoParams, +) + def get_function_schema(func: Function[Any]) -> Dict[str, Any]: """ @@ -59,3 +67,20 @@ def parse_function_arguments(func: Function[Any], arguments: Dict[str, Any]) -> else: # For Pydantic model schemas, parse normally return func.parameter_schema(**arguments) + + +async def execute_function(function: Function[Any], arguments: Dict[str, Any]) -> str | DeferredResult: + parsed_args = parse_function_arguments(function, arguments) + if parsed_args: + # Handle both sync and async function handlers + handler = cast(FunctionHandler[BaseModel] | DeferredFunctionHandler[BaseModel], function.handler) + result = handler(parsed_args) + else: + handler = cast(FunctionHandlerWithNoParams, function.handler) + result = handler() + + if inspect.isawaitable(result): + fn_res = await result + else: + fn_res = result + return fn_res diff --git a/packages/ai/tests/test_chat_prompt.py b/packages/ai/tests/test_chat_prompt.py index 0f39dab7..94c496c5 100644 --- 
a/packages/ai/tests/test_chat_prompt.py +++ b/packages/ai/tests/test_chat_prompt.py @@ -139,7 +139,7 @@ async def test_string_input_conversion(self, mock_model: MockAIModel) -> None: result = await prompt.send("Hello world") assert isinstance(result, ChatSendResult) - assert result.response.content == "GENERATED - Hello world" + assert result.response and result.response.content == "GENERATED - Hello world" @pytest.mark.asyncio async def test_memory_updates(self) -> None: @@ -182,6 +182,7 @@ async def test_function_handler_execution(self, mock_function_handler: Mock) -> result = await prompt.send("Call the function") # Verify the function call is in the response + assert isinstance(result.response, ModelMessage) assert result.response.function_calls is not None assert len(result.response.function_calls) == 1 assert result.response.function_calls[0].name == "test_function" @@ -223,10 +224,12 @@ async def test_full_conversation_flow(self, test_function: Function[MockFunction # First exchange result1 = await prompt.send("Hello", memory=memory) + assert isinstance(result1.response, ModelMessage) assert result1.response.content == "GENERATED - Hello" # Second exchange result2 = await prompt.send("How are you?", memory=memory) + assert isinstance(result2.response, ModelMessage) assert result2.response.content == "GENERATED - How are you?" 
# Verify memory contains complete conversation history @@ -284,16 +287,19 @@ async def test_different_message_types(self, mock_model: MockAIModel) -> None: # String input result1 = await prompt.send("String input") + assert isinstance(result1.response, ModelMessage) assert result1.response.content == "GENERATED - String input" # UserMessage input user_msg = UserMessage(content="User message") result2 = await prompt.send(user_msg) + assert isinstance(result2.response, ModelMessage) assert result2.response.content == "GENERATED - User message" # ModelMessage input (for function calling scenarios) model_msg = ModelMessage(content="Model message", function_calls=None) result3 = await prompt.send(model_msg) + assert isinstance(result3.response, ModelMessage) assert result3.response.content == "GENERATED - Model message" @pytest.mark.asyncio @@ -334,6 +340,7 @@ def handler_no_params() -> str: # Verify both work in send result = await prompt.send("Test message") + assert isinstance(result.response, ModelMessage) assert result.response.content == "GENERATED - Test message" @@ -431,6 +438,7 @@ async def test_on_before_send_hook(self, mock_model: MockAIModel) -> None: assert mock_model.last_input is not None assert mock_model.last_input.content == "MODIFIED: Original message" # Verify the response reflects the modified input + assert isinstance(result.response, ModelMessage) assert result.response.content == "GENERATED - MODIFIED: Original message" @pytest.mark.asyncio @@ -443,6 +451,7 @@ async def test_on_after_send_hook(self, mock_model: MockAIModel) -> None: result = await prompt.send("Test message") assert plugin.after_send_called + assert isinstance(result.response, ModelMessage) assert result.response.content == "RESPONSE_MODIFIED: GENERATED - Test message" @pytest.mark.asyncio @@ -493,6 +502,7 @@ async def test_function_call_hooks(self, mock_function_handler: Mock) -> None: # Verify after hook was called and modified result assert len(plugin.after_function_called) == 
1 assert plugin.after_function_called[0][0] == "test_function" + assert isinstance(result.response, ModelMessage) assert result.response.content is not None assert "FUNCTION_MODIFIED: Function executed successfully" in result.response.content @@ -536,6 +546,7 @@ async def test_multiple_plugins_execution_order(self, mock_model: MockAIModel) - assert plugin2.after_send_called # Input should be modified by both plugins in order + assert isinstance(result.response, ModelMessage) assert result.response.content == "SECOND_RESP: FIRST_RESP: GENERATED - SECOND: FIRST: Original" @pytest.mark.asyncio @@ -557,6 +568,7 @@ async def on_after_send(self, response: ModelMessage) -> ModelMessage | None: result = await prompt.send("Test message") # Should be unchanged since plugin returned None + assert isinstance(result.response, ModelMessage) assert result.response.content == "GENERATED - Test message" @pytest.mark.asyncio @@ -568,6 +580,8 @@ async def test_empty_plugin_list_maintains_compatibility(self, mock_model: MockA result_with = await prompt_with_plugins.send("Test message") result_without = await prompt_without_plugins.send("Test message") + assert isinstance(result_with.response, ModelMessage) + assert isinstance(result_without.response, ModelMessage) assert result_with.response.content == result_without.response.content @pytest.mark.asyncio @@ -592,6 +606,7 @@ async def test_plugin_with_async_function_handler(self, mock_function_handler: M # Verify function was called and result was modified by plugin assert len(plugin.before_function_called) == 1 assert len(plugin.after_function_called) == 1 + assert isinstance(result.response, ModelMessage) assert result.response.content is not None assert "ASYNC_MODIFIED: Function executed successfully" in result.response.content @@ -621,6 +636,7 @@ async def test_base_plugin_default_implementations(self, mock_model: MockAIModel # Should work without any issues using default implementations result = await prompt.send("Test with base 
plugin") + assert isinstance(result.response, ModelMessage) assert result.response.content == "GENERATED - Test with base plugin" # Test with functions too @@ -631,6 +647,7 @@ def handler(params: MockFunctionParams) -> str: prompt_with_func = ChatPrompt(mock_model, functions=[test_function], plugins=[base_plugin]) result2 = await prompt_with_func.send("Test with function") + assert isinstance(result2.response, ModelMessage) assert result2.response.content == "GENERATED - Test with function" @pytest.mark.asyncio @@ -673,6 +690,7 @@ async def test_comprehensive_plugin_behavior_verification(self, mock_function_ha assert "test_function" in mock_model.last_functions # Verify final response includes all modifications + assert isinstance(result.response, ModelMessage) assert result.response.content is not None assert "RESP_MOD:" in result.response.content assert "FUNC_MOD: Function executed successfully" in result.response.content diff --git a/packages/openai/tests/test_function_utils.py b/packages/ai/tests/test_function_utils.py similarity index 98% rename from packages/openai/tests/test_function_utils.py rename to packages/ai/tests/test_function_utils.py index dc202983..faf7005c 100644 --- a/packages/openai/tests/test_function_utils.py +++ b/packages/ai/tests/test_function_utils.py @@ -8,8 +8,7 @@ from typing import Optional import pytest -from microsoft.teams.ai import Function -from microsoft.teams.openai.function_utils import get_function_schema, parse_function_arguments +from microsoft.teams.ai import Function, get_function_schema, parse_function_arguments from pydantic import BaseModel, ValidationError diff --git a/packages/common/tests/test_client.py b/packages/common/tests/test_client.py index f24e253a..36505075 100644 --- a/packages/common/tests/test_client.py +++ b/packages/common/tests/test_client.py @@ -3,6 +3,8 @@ Licensed under the MIT License. 
""" +# pyright: basic + import httpx import pytest from microsoft.teams.common.http import Client, ClientOptions, Interceptor diff --git a/packages/common/tests/test_event_emitter.py b/packages/common/tests/test_event_emitter.py index ad6a67c2..6a1ce9b5 100644 --- a/packages/common/tests/test_event_emitter.py +++ b/packages/common/tests/test_event_emitter.py @@ -3,6 +3,8 @@ Licensed under the MIT License. """ +# pyright: basic + import asyncio from unittest.mock import Mock diff --git a/packages/common/tests/test_logging_filter.py b/packages/common/tests/test_logging_filter.py index 0f99409e..6b812101 100644 --- a/packages/common/tests/test_logging_filter.py +++ b/packages/common/tests/test_logging_filter.py @@ -3,6 +3,8 @@ Licensed under the MIT License. """ +# pyright: basic + import logging from unittest.mock import MagicMock diff --git a/packages/common/tests/test_logging_formatter.py b/packages/common/tests/test_logging_formatter.py index b30923f0..bd9b5737 100644 --- a/packages/common/tests/test_logging_formatter.py +++ b/packages/common/tests/test_logging_formatter.py @@ -3,6 +3,8 @@ Licensed under the MIT License. """ +# pyright: basic + import logging from typing import Collection, Union diff --git a/packages/openai/src/microsoft/teams/openai/completions_model.py b/packages/openai/src/microsoft/teams/openai/completions_model.py index ad531496..70ec2e67 100644 --- a/packages/openai/src/microsoft/teams/openai/completions_model.py +++ b/packages/openai/src/microsoft/teams/openai/completions_model.py @@ -3,7 +3,6 @@ Licensed under the MIT License. 
""" -import inspect import json from dataclasses import dataclass from typing import Any, Awaitable, Callable, TypedDict, cast @@ -19,8 +18,11 @@ ModelMessage, SystemMessage, UserMessage, + get_function_schema, ) -from microsoft.teams.ai.function import FunctionHandler, FunctionHandlerWithNoParams +from microsoft.teams.ai.function import DeferredResult +from microsoft.teams.ai.message import DeferredMessage +from microsoft.teams.ai.utils.function_utils import execute_function from microsoft.teams.openai.common import OpenAIBaseModel from pydantic import BaseModel @@ -39,8 +41,6 @@ ChatCompletionUserMessageParam, ) -from .function_utils import get_function_schema, parse_function_arguments - class _ToolCallData(TypedDict): """ @@ -68,13 +68,13 @@ class OpenAICompletionsAIModel(OpenAIBaseModel, AIModel): async def generate_text( self, - input: Message, + input: Message | None, *, system: SystemMessage | None = None, memory: Memory | None = None, functions: dict[str, Function[BaseModel]] | None = None, on_chunk: Callable[[str], Awaitable[None]] | None = None, - ) -> ModelMessage: + ) -> ModelMessage | list[DeferredMessage]: """ Generate text using OpenAI Chat Completions API. 
@@ -97,28 +97,36 @@ async def generate_text( if memory is None: memory = ListMemory() - # Execute any pending function calls first - function_results = await self._execute_functions(input, functions) - # Get conversation history from memory (make a copy to avoid modifying memory's internal state) messages = list(await memory.get_all()) + + # Execute any pending function calls first + function_results = await self._execute_functions(input, messages, functions) self.logger.debug(f"Retrieved {len(messages)} messages from memory, {len(function_results)} function results") # Push current input to memory - await memory.push(input) + if input is not None: + await memory.push(input) # Push function results to memory and add to messages + deferred_messages: list[DeferredMessage] = [] if function_results: # Add the original ModelMessage with function_calls to messages first - messages.append(input) + if input is not None: + messages.append(input) for result in function_results: await memory.push(result) messages.append(result) + if isinstance(result, DeferredMessage): + deferred_messages.append(result) # Don't add input again at the end - Order matters here! 
input_to_send = None else: input_to_send = input + if len(deferred_messages) > 0: + return deferred_messages + # Convert messages to OpenAI format openai_messages = self._convert_messages(input_to_send, system, messages) self.logger.debug(f"Converted to {len(openai_messages)} OpenAI messages") @@ -153,35 +161,37 @@ async def generate_text( return model_response async def _execute_functions( - self, input: Message, functions: dict[str, Function[BaseModel]] | None - ) -> list[FunctionMessage]: + self, input: Message | None, memory_messages: list[Message], functions: dict[str, Function[BaseModel]] | None + ) -> list[FunctionMessage | DeferredMessage]: """Execute any pending function calls in the input message.""" - function_results: list[FunctionMessage] = [] + function_results: list[FunctionMessage | DeferredMessage] = [] if isinstance(input, ModelMessage) and input.function_calls: # Execute any pending function calls self.logger.debug(f"Executing {len(input.function_calls)} function calls") for call in input.function_calls: + existing_function_result = next( + ( + message + for message in memory_messages + if isinstance(message, FunctionMessage) and message.function_id == call.id + ), + None, + ) + if existing_function_result is None: + self.logger.debug(f"{call.name} already called. 
Skipping execution")
+ return ChatCompletionUserMessageParam(role=message.role, content=message.content) + case SystemMessage(): + return ChatCompletionSystemMessageParam(role=message.role, content=message.content) + case FunctionMessage(): + return ChatCompletionToolMessageParam( + role="tool", + content=message.content or [], + tool_call_id=message.function_id, + ) + case ModelMessage(): + if message.function_calls: + tool_calls = [ + ChatCompletionMessageFunctionToolCallParam( + id=call.id, + function={"name": call.name, "arguments": json.dumps(call.arguments)}, + type="function", + ) + for call in message.function_calls + ] + else: + # we need to do this cast because Completions expects tool_calls to be >= 1, + # but the type is not Optional + tool_calls = cast(list[ChatCompletionMessageFunctionToolCallParam], None) + return ChatCompletionAssistantMessageParam( + role="assistant", content=message.content, tool_calls=tool_calls + ) + case DeferredMessage(): + raise ValueError( + ( + "A deferred_message should not be sent to OpenAI. It needs to be resolved " + "and converted to a FunctionMessage." 
) - for call in message.function_calls - ] - else: - # we need to do this cast because Completions expects tool_calls to be >= 1, - # but the type is not Optional - tool_calls = cast(list[ChatCompletionMessageFunctionToolCallParam], None) - return ChatCompletionAssistantMessageParam(role="assistant", content=message.content, tool_calls=tool_calls) - else: - raise Exception(f"Message {message.role} not supported") + ) + case _: + raise Exception(f"Message {message.role} not supported") def _convert_functions(self, functions: dict[str, Function[BaseModel]]) -> list[ChatCompletionToolUnionParam]: function_values = functions.values() diff --git a/packages/openai/src/microsoft/teams/openai/responses_chat_model.py b/packages/openai/src/microsoft/teams/openai/responses_chat_model.py index 5ffd8b4d..c5f64f86 100644 --- a/packages/openai/src/microsoft/teams/openai/responses_chat_model.py +++ b/packages/openai/src/microsoft/teams/openai/responses_chat_model.py @@ -10,6 +10,7 @@ from microsoft.teams.ai import ( AIModel, + DeferredMessage, Function, FunctionCall, FunctionHandler, @@ -21,6 +22,8 @@ ModelMessage, SystemMessage, UserMessage, + get_function_schema, + parse_function_arguments, ) from pydantic import BaseModel @@ -40,7 +43,6 @@ ) from .common import OpenAIBaseModel -from .function_utils import get_function_schema, parse_function_arguments @dataclass @@ -57,13 +59,13 @@ class OpenAIResponsesAIModel(OpenAIBaseModel, AIModel): async def generate_text( self, - input: Message, + input: Message | None, *, system: SystemMessage | None = None, memory: Memory | None = None, functions: dict[str, Function[BaseModel]] | None = None, on_chunk: Callable[[str], Awaitable[None]] | None = None, - ) -> ModelMessage: + ) -> ModelMessage | list[DeferredMessage]: """ Generate text using OpenAI Responses API. 
@@ -95,13 +97,13 @@ async def generate_text( async def _send_stateful( self, - input: Message, + input: Message | None, system: SystemMessage | None, memory: Memory, functions: dict[str, Function[BaseModel]] | None, on_chunk: Callable[[str], Awaitable[None]] | None, function_results: list[FunctionMessage], - ) -> ModelMessage: + ) -> ModelMessage | list[DeferredMessage]: """Handle stateful conversation using OpenAI Responses API state management.""" # Get response IDs from memory - OpenAI manages conversation state messages = list(await memory.get_all()) @@ -163,21 +165,22 @@ async def _send_stateful( async def _send_stateless( self, - input: Message, + input: Message | None, system: SystemMessage | None, memory: Memory, functions: dict[str, Function[BaseModel]] | None, on_chunk: Callable[[str], Awaitable[None]] | None, function_results: list[FunctionMessage], - ) -> ModelMessage: + ) -> ModelMessage | list[DeferredMessage]: """Handle stateless conversation using standard OpenAI API pattern.""" # Get conversation history from memory (make a copy to avoid modifying memory's internal state) messages = list(await memory.get_all()) self.logger.debug(f"Retrieved {len(messages)} messages from memory") - # Push current input to memory - await memory.push(input) - messages.append(input) + if input: + # Push current input to memory + await memory.push(input) + messages.append(input) # Push function results to memory and add to messages if function_results: @@ -229,7 +232,7 @@ async def _send_stateless( return model_response async def _execute_functions( - self, input: Message, functions: dict[str, Function[BaseModel]] | None + self, input: Message | None, functions: dict[str, Function[BaseModel]] | None ) -> list[FunctionMessage]: """Execute any pending function calls in the input message.""" function_results: list[FunctionMessage] = [] diff --git a/packages/openai/tests/test_openai_completions_model.py b/packages/openai/tests/test_openai_completions_model.py index 
2da098a2..26063486 100644 --- a/packages/openai/tests/test_openai_completions_model.py +++ b/packages/openai/tests/test_openai_completions_model.py @@ -76,6 +76,7 @@ async def test_generate_text_basic_message( result = await model.generate_text(input_msg) # Assertions + assert isinstance(result, ModelMessage) assert result.content == "Hello, world!" assert result.function_calls is None diff --git a/tests/defferred_ai/README.md b/tests/defferred_ai/README.md new file mode 100644 index 00000000..da3b70c7 --- /dev/null +++ b/tests/defferred_ai/README.md @@ -0,0 +1,82 @@ +# Deferred AI Test + +Test application demonstrating approval workflow using `ApprovalPlugin`. + +## What This Demonstrates + +This test shows how to use the `ApprovalPlugin` to wrap functions that require human approval before execution. + +### How It Works + +1. **User asks to buy stocks**: "Buy 10 shares of MSFT" +2. **AI calls the function**: The AI model calls `buy_stock(stock="MSFT", quantity=10)` +3. **Plugin intercepts**: ApprovalPlugin wraps the function and defers execution +4. **Approval requested**: User sees approval request with function details +5. **User responds**: "yes" or "no" +6. **Plugin resumes**: + - If approved → executes original function and returns result + - If denied → returns cancellation message + +## Usage + +```bash +# Start the app +python src/main.py + +# In chat, ask to buy stocks +> Buy 10 shares of MSFT + +# You'll see approval request +> Approval Required +> Function: buy_stock +> Parameters: {'stock': 'MSFT', 'quantity': 10} +> +> Please respond with: +> - 'yes' or 'approve' to confirm +> - 'no' or 'deny' to cancel + +# Respond with approval +> yes + +# Stock purchase executes +> ✅ Successfully purchased 10 shares of MSFT. Order executed at market price. 
+approval_plugin = ApprovalPlugin(
+    sender=ctx,
+    functions=[stock_function]  # Functions that need approval
+)
+            fn_names: List of names of the functions to wrap with approval workflow
+                        f"{func.name} seems to be a resumable function. ApprovalPlugin only works "
+                        "for functions that are not resumable themselves."
+                    "Functions that use ApprovalPlugin cannot be deferrable! "
+ f"And {original_func.name} just returned a DeferredResult" + ) + return result + except Exception as e: + return f"Approved but Failed\nError executing {function_name}: {str(e)}" diff --git a/tests/defferred_ai/src/main.py b/tests/defferred_ai/src/main.py new file mode 100644 index 00000000..91074035 --- /dev/null +++ b/tests/defferred_ai/src/main.py @@ -0,0 +1,100 @@ +""" +Copyright (c) Microsoft Corporation. All rights reserved. +Licensed under the MIT License. +""" + +import asyncio +from os import getenv + +from approval_for_function import ApprovalPlugin +from dotenv import find_dotenv, load_dotenv +from microsoft.teams.ai import ChatPrompt, Function, ListMemory +from microsoft.teams.api import MessageActivity, MessageActivityInput +from microsoft.teams.apps import ActivityContext, App +from microsoft.teams.devtools import DevToolsPlugin +from microsoft.teams.openai import OpenAICompletionsAIModel +from pydantic import BaseModel + +load_dotenv(find_dotenv(usecwd=True)) + + +app = App(plugins=[DevToolsPlugin()]) + + +def get_required_env(key: str) -> str: + value = getenv(key) + if not value: + raise ValueError(f"Required environment variable {key} is not set") + return value + + +# Get OpenAI model (like in ai-test) +AZURE_OPENAI_MODEL = get_required_env("AZURE_OPENAI_MODEL") +ai_model = OpenAICompletionsAIModel(model=AZURE_OPENAI_MODEL) + + +class BuyStockParams(BaseModel): + stock: str + quantity: int + + +def create_buy_stock_function() -> Function[BuyStockParams]: + """Create a buy stock function.""" + + def handler(params: BuyStockParams) -> str: + print("Actually running the buy stock fn") + return f"✅ Successfully purchased {params.quantity} shares of {params.stock}. Order executed at market price." 
+ + return Function( + name="buy_stock", + description="purchase stocks by specifying ticker symbol and quantity", + parameter_schema=BuyStockParams, + handler=handler, + ) + + +# Global memory instance +memory = ListMemory() + + +@app.on_message +async def handle_stock_trading(ctx: ActivityContext[MessageActivity]) -> None: + """Handle stock trading with approval using ApprovalPlugin.""" + print(f"[STOCK TRADING] Message received: {ctx.activity.text}") + + try: + # Create stock function (will be wrapped by plugin) + stock_function = create_buy_stock_function() + + # Create approval plugin with fn_names to wrap + approval_plugin = ApprovalPlugin(sender=ctx, functions=[stock_function]) + + chat_prompt = ChatPrompt( + instructions=( + "You are a helpful assistant. Use the available stock trading tool when users want to buy stocks." + ), + model=ai_model, + functions=[stock_function], # Plugin will wrap this function + memory=memory, + ).with_plugin(approval_plugin) + + # Handle deferred functions or normal chat + if await chat_prompt.requires_resuming(): + chat_result = await chat_prompt.resume(ctx.activity) + else: + chat_result = await chat_prompt.send(input=ctx.activity.text) + + if chat_result.response and chat_result.response.content: + message = MessageActivityInput(text=chat_result.response.content).add_ai_generated() + await ctx.send(message) + elif chat_result.is_deferred: + # Approval message already sent by the plugin + pass + + except Exception as e: + print(f"[STOCK TRADING] Error: {str(e)}") + await ctx.send(f"❌ Error: {str(e)}") + + +if __name__ == "__main__": + asyncio.run(app.start()) diff --git a/uv.lock b/uv.lock index c8f4267e..c8e662aa 100644 --- a/uv.lock +++ b/uv.lock @@ -11,6 +11,7 @@ members = [ "a2a", "ai-test", "cards", + "defferred-ai", "dialogs", "echo", "graph", @@ -618,6 +619,21 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/f0/8b/2c95f0645c6f40211896375e6fa51f504b8ccb29c21f6ae661fe87ab044e/cyclopts-3.24.0-py3-none-any.whl", hash = "sha256:809d04cde9108617106091140c3964ee6fceb33cecdd537f7ffa360bde13ed71", size = 86154, upload-time = "2025-09-08T15:40:56.41Z" }, ] +[[package]] +name = "defferred-ai" +version = "0.1.0" +source = { virtual = "tests/defferred_ai" } +dependencies = [ + { name = "dotenv" }, + { name = "microsoft-teams-apps" }, +] + +[package.metadata] +requires-dist = [ + { name = "dotenv", specifier = ">=0.9.9" }, + { name = "microsoft-teams-apps", editable = "packages/apps" }, +] + [[package]] name = "dependency-injector" version = "4.48.2"