diff --git a/gemini_live_boilerplate/server/app.py b/gemini_live_boilerplate/server/app.py
index 333f18d..ea29bfd 100644
--- a/gemini_live_boilerplate/server/app.py
+++ b/gemini_live_boilerplate/server/app.py
@@ -14,6 +14,7 @@
 from gemini_live_handler import GeminiClient
 from utils import function_tool
 from tools import schedule_meet_tool, cancel_meet_tool, get_current_time # pylint: disable=no-name-in-module
+from tool_context import ToolContext # pylint: disable=no-name-in-module
 
 RECORDINGS_DIR = "recordings"
 
@@ -36,10 +37,12 @@ class WebSocketHandler:
 
     def __init__(self, ws):
         self.websocket = ws
+        self.connection_id = f"conn_{int(time.time() * 1000)}"
         self.gemini = None
+        self.tool_context = ToolContext(session_id=self.connection_id)
+
         self.audio_queue = asyncio.Queue()
         self.response_queue = asyncio.Queue()
-        self.connection_id = f"conn_{int(time.time() * 1000)}"
         self._gemini_session = None
         self._gemini_client_initialized = False
         self.session_usage_metadata = None
@@ -252,7 +255,8 @@ async def _process_gemini_stream(self):
                             function_response = await self.gemini.call_function(
                                 fc_id=func_call.id,
                                 fc_name=tool_call_name,
-                                fc_args=tool_call_args
+                                fc_args=tool_call_args,
+                                tool_ctx=self.tool_context
                             )
                             await self.response_queue.put({"event": "tool_response","data": {"name": tool_call_name, "args": function_response.response}})
                             function_responses.append(function_response)
diff --git a/gemini_live_boilerplate/server/gemini_live_handler.py b/gemini_live_boilerplate/server/gemini_live_handler.py
index d1d1874..fa23f46 100644
--- a/gemini_live_boilerplate/server/gemini_live_handler.py
+++ b/gemini_live_boilerplate/server/gemini_live_handler.py
@@ -2,12 +2,14 @@
 import base64
 import asyncio
 import logging
-from typing import Callable, List
+from typing import Callable, List, Optional
 from google import genai
 from google.genai import types as genai_types
 
 # TODO: Add more specific prompt
 from prompt import SYSTEM_PROMPT
+from tool_context import ToolContext # pylint: disable=no-name-in-module
+from utils import accepts_tool_context
 
 # TODO: Use same logger everywhere, don't import new one
 logging.basicConfig(
@@ -63,35 +65,32 @@ def _setup_config(self):
         )
 
     # NOTE: Can add support for injected tool arg like langchain, for later
-    async def call_function(self, fc_id: str, fc_name: str, fc_args=None):
+    async def call_function(self,
+                            fc_id: str,
+                            fc_name: str,
+                            fc_args=None,
+                            tool_ctx: Optional[ToolContext] = None):
         """Calls the functions that were defined and returns the function response"""
-        func_args = fc_args if fc_args else {}
+        func_args = fc_args.copy() if fc_args else {}
 
         try:
             for tool in self.tools:
-                if getattr(tool, "name", None) == fc_name:
-                    if asyncio.iscoroutinefunction(tool):
-                        function_result = await tool(**func_args)
-                    else:
-                        function_result = tool(**func_args)
-                    return genai_types.FunctionResponse(
-                        id=fc_id,
-                        name=fc_name,
-                        response={
-                            "result": function_result
-                        }
-                    )
-                if callable(tool) and tool.__name__ == fc_name:
+                tool_name = getattr(tool, "name", None) or getattr(tool, "__name__", None)
+                if tool_name == fc_name and callable(tool):
+                    # Inject the session context only into tools that declare a ToolContext parameter
+                    param, accepts = accepts_tool_context(tool)
+                    if accepts and tool_ctx is not None:
+                        func_args[param] = tool_ctx
+
                     if asyncio.iscoroutinefunction(tool):
                         function_result = await tool(**func_args)
                     else:
                         function_result = tool(**func_args)
+
                     return genai_types.FunctionResponse(
                         id=fc_id,
                         name=fc_name,
-                        response={
-                            "result": function_result
-                        }
+                        response={"result": function_result}
                     )
             logger.error(f"Function with name '{fc_name}' is not defined.")
             return genai_types.FunctionResponse(
diff --git a/gemini_live_boilerplate/server/tool_context.py b/gemini_live_boilerplate/server/tool_context.py
new file mode 100644
index 0000000..8b34b1b
--- /dev/null
+++ b/gemini_live_boilerplate/server/tool_context.py
@@ -0,0 +1,86 @@
+"""Manages state for agent and tools throughout the session"""
+from typing import Optional, Dict, Any
+
+class ToolContext:
+    """Manages state and tool context for tools within a session
+    This class provides a simple and safe interface for tools to read, write,
+    and delete data related to the current session, enabling stateful
+    multi-turn interactions.
+    """
+    def __init__(self,
+                 session_id: str,
+                 initial_state: Optional[Dict[str, Any]] = None):
+        """Initialize the context
+        Args:
+            session_id: A unique identifier for the current session.
+            initial_state: An optional dictionary to pre-populate the state.
+        """
+        self.session_id = session_id
+        if initial_state:
+            self._state = initial_state.copy()
+        else:
+            self._state = {}
+
+    def dump_state(self) -> Dict[str, Any]:
+        """
+        Returns a copy of the entire state dictionary.
+        Returns:
+            A shallow copy of the internal state dictionary, preventing direct mutation.
+        """
+        return self._state.copy()
+
+    def get(self,
+            key: str,
+            default: Optional[Any] = None) -> Any:
+        """
+        Retrieves the value of a key from the state.
+
+        Args:
+            key: The key of the value to retrieve.
+            default: The value to return if the key is not found.
+
+        Returns:
+            The value associated with the key, or the default value if not found.
+        """
+
+        return self._state.get(key, default)
+
+    def update(self, **kwargs) -> Dict[str, Any]:
+        """
+        Updates the state with one or more key-value pairs.
+        Overwrites keys if they are already present.
+
+        Args:
+            **kwargs: Arbitrary keyword arguments to add or update in the state.
+
+        Returns:
+            A shallow copy of the state dictionary after the update; mutating
+            the returned dict does not change the stored state.
+
+        Usage:
+            updated_data = ctx.update(key1=value1, key2=value2...)
+            or
+            data_to_update = {"key1": "value1", "key2": "value2"}
+            updated_data = ctx.update(**data_to_update)
+        """
+        self._state.update(kwargs)
+        return self._state.copy()
+
+    def delete(self, key: str) -> bool:
+        """
+        Deletes a key-value pair from the state if it exists.
+
+        Args:
+            key: The key to delete from the state.
+
+        Returns:
+            True if the key was found and deleted, False otherwise.
+        """
+        if key in self._state:
+            del self._state[key]
+            return True
+        return False
+
+    def clear(self) -> None:
+        """Clears the entire state, resetting it to an empty dictionary."""
+        self._state = {}
diff --git a/gemini_live_boilerplate/server/tools.py b/gemini_live_boilerplate/server/tools.py
index ddd4224..302e88f 100644
--- a/gemini_live_boilerplate/server/tools.py
+++ b/gemini_live_boilerplate/server/tools.py
@@ -1,9 +1,11 @@
 """Common module for function tools"""
-# pylint: disable=line-too-long
+# pylint: disable=line-too-long, too-many-positional-arguments
 from enum import Enum
 from typing import Optional, Annotated
+import uuid
 import datetime
 from utils import function_tool # pylint: disable=no-name-in-module
+from tool_context import ToolContext # pylint: disable=no-name-in-module
 
 class MeetingRoom(Enum):
     """Enum class for meeting room"""
@@ -13,30 +15,51 @@ class MeetingRoom(Enum):
 
 @function_tool
 def schedule_meet_tool(
+    tool_ctx: ToolContext,
     attendees: Annotated[list[str], "List of the people attending the meeting"],
     topic: Annotated[str, "The subject or the topic of the meeting"],
     date: Annotated[str, "The date of the meeting (e.g., 25/06/2025)"],
     meeting_room: Annotated[MeetingRoom, "The name of the meeting room."] = MeetingRoom.VIRTUAL,
-    time_slot: Annotated[Optional[str], "Time of the meeting (e.g., '14:00'-'15:00'). Immediate schedule if value not provided."] = "Now"):
+    time_slot: Annotated[Optional[str], "Time of the meeting (e.g., '14:00'-'15:00'). Immediate schedule if value not provided."] = "Now",
+    ):
     """Schedules meeting for a given list of attendees at a given time and date"""
+    meet_id = str(uuid.uuid4())[:5]
+    # Accept a MeetingRoom member or a plain value; use the member's value, not "MeetingRoom.X"
+    room_name = getattr(meeting_room, "value", meeting_room)
 
-    response_message = f"Meeting with Topic: '{topic}' is successfully scheduled. \n\n"
-    response_message += f"**Meeting Details**:\nAttendees: {attendees}.\nMeeting room: {meeting_room.value}\nDate: {date}\nTime slot: {time_slot}"
+    if tool_ctx:
+        meeting_details = {
+            "meet_id": meet_id,
+            "attendees": attendees,
+            "topic": topic,
+            "date": date,
+            "meeting_room": room_name,
+            "time_slot": time_slot
+        }
+        # Update the state with these new values
+        # There are a lot of ways to customize this
+        tool_ctx.update(**meeting_details)
+        print(f"Current State: {tool_ctx.dump_state()}")
+
+    response_message = f"Meeting with Topic: '{topic}' is successfully scheduled with ID {meet_id}. \n\n"
+    response_message += f"**Meeting Details**:\nAttendees: {attendees}.\nMeeting room: {room_name}\nDate: {date}\nTime slot: {time_slot}"
 
     return response_message
 
 @function_tool
-def cancel_meet_tool(meet_id: Annotated[str, "The id of the meeting to cancel in lower case"]):
+def cancel_meet_tool(meet_id: Annotated[str, "The id of the meeting to cancel in lower case"],
+                     tool_ctx: ToolContext):
     """Cancels the meeting with the given ID"""
 
-    if meet_id.startswith("a"):
-        response_message = f"Successfully cancelled meeting with ID: {meet_id}"
-        return response_message
-    if meet_id.startswith("b"):
-        response_message = "The meeting is currently in progress, unable to cancel this meeting."
-        return response_message
-
-    return "An error occurred while cancelling the meeting. Please make sure meeting ID is valid."
+    try:
+        if tool_ctx.get("meet_id") == meet_id:
+            # We can add more logic here, but keeping it simple for now
+            print(f"Current State: {tool_ctx.dump_state()}")
+            return f"Successfully cancelled meeting with ID: {meet_id}"
+        return f"Meeting with ID: {meet_id} does not exist."
+    except Exception as e:
+        print(f"Error occured when canceling a meeting. {str(e)}")
+        return "An error occurred while cancelling the meeting. Please make sure meeting ID is valid."
 
 @function_tool
 async def get_current_time(country: Annotated[str, "Name of the country"]) -> dict:
diff --git a/gemini_live_boilerplate/server/utils.py b/gemini_live_boilerplate/server/utils.py
index 576bd2e..0fb7144 100644
--- a/gemini_live_boilerplate/server/utils.py
+++ b/gemini_live_boilerplate/server/utils.py
@@ -4,6 +4,8 @@
 from enum import Enum
 from typing import get_type_hints, Optional, Union, get_args, get_origin, Annotated
 
+from tool_context import ToolContext # pylint: disable=no-name-in-module
+
 class OpenAPITypes(Enum):
     """The basic data types defined by OpenAPI 3.0"""
 
@@ -141,6 +143,13 @@ def _process_parameters(self):
                 actual_type, description_annotation = get_args(py_type)
                 description = description_annotation
 
+            # Check if the type is `ToolContext` directly or wrapped in Optional (Union)
+            is_direct_context = actual_type is ToolContext
+            is_optional_context = get_origin(actual_type) is Union and ToolContext in get_args(actual_type) # pylint: disable=line-too-long
+
+            if is_direct_context or is_optional_context:
+                continue
+
             # Create the base schema from the actual type
             param_schema = self._create_schema_from_type(actual_type)
 
@@ -188,3 +197,19 @@ def function_tool(func):
     """
     func.tool_metadata = FunctionTool(func).to_declaration().oas_format
     return func
+
+def accepts_tool_context(func):
+    """Return (param_name, True) if a function has a ToolContext parameter, else (None, False).
+
+    Both `ToolContext` and `Optional[ToolContext]` annotations are recognised, matching
+    the parameters that `_process_parameters` leaves out of the tool schema.
+    """
+    type_hints = get_type_hints(func)
+    for param, annotation in type_hints.items():
+        if param == "return":
+            # get_type_hints() also reports the return annotation; it is not a parameter
+            continue
+        is_optional_context = get_origin(annotation) is Union and ToolContext in get_args(annotation) # pylint: disable=line-too-long
+        if annotation is ToolContext or is_optional_context:
+            return param, True
+    return None, False