From 80a8cc77a0daf809b2e7209791d42f9863c1a55e Mon Sep 17 00:00:00 2001 From: Ruthvik-1411 Date: Sun, 28 Sep 2025 23:04:42 +0530 Subject: [PATCH 1/2] feat: add state/tool context management to live api --- gemini_live_boilerplate/server/app.py | 8 +- .../server/gemini_live_handler.py | 37 ++++---- .../server/tool_context.py | 85 +++++++++++++++++++ gemini_live_boilerplate/server/tools.py | 46 +++++++--- gemini_live_boilerplate/server/utils.py | 17 ++++ 5 files changed, 160 insertions(+), 33 deletions(-) create mode 100644 gemini_live_boilerplate/server/tool_context.py diff --git a/gemini_live_boilerplate/server/app.py b/gemini_live_boilerplate/server/app.py index 333f18d..ea29bfd 100644 --- a/gemini_live_boilerplate/server/app.py +++ b/gemini_live_boilerplate/server/app.py @@ -14,6 +14,7 @@ from gemini_live_handler import GeminiClient from utils import function_tool from tools import schedule_meet_tool, cancel_meet_tool, get_current_time # pylint: disable=no-name-in-module +from tool_context import ToolContext # pylint: disable=no-name-in-module RECORDINGS_DIR = "recordings" @@ -36,10 +37,12 @@ class WebSocketHandler: def __init__(self, ws): self.websocket = ws + self.connection_id = f"conn_{int(time.time() * 1000)}" self.gemini = None + self.tool_context = ToolContext(session_id=self.connection_id) + self.audio_queue = asyncio.Queue() self.response_queue = asyncio.Queue() - self.connection_id = f"conn_{int(time.time() * 1000)}" self._gemini_session = None self._gemini_client_initialized = False self.session_usage_metadata = None @@ -252,7 +255,8 @@ async def _process_gemini_stream(self): function_response = await self.gemini.call_function( fc_id=func_call.id, fc_name=tool_call_name, - fc_args=tool_call_args + fc_args=tool_call_args, + tool_ctx=self.tool_context ) await self.response_queue.put({"event": "tool_response","data": {"name": tool_call_name, "args": function_response.response}}) function_responses.append(function_response) diff --git a/gemini_live_boilerplate/server/gemini_live_handler.py b/gemini_live_boilerplate/server/gemini_live_handler.py index d1d1874..f32c08d 100644 --- a/gemini_live_boilerplate/server/gemini_live_handler.py +++ b/gemini_live_boilerplate/server/gemini_live_handler.py @@ -2,12 +2,14 @@ import base64 import asyncio import logging -from typing import Callable, List +from typing import Callable, List, Optional from google import genai from google.genai import types as genai_types # TODO: Add more specific prompt from prompt import SYSTEM_PROMPT +from tool_context import ToolContext # pylint: disable=no-name-in-module +from utils import accepts_tool_context # TODO: Use same logger everywhere, don't import new one logging.basicConfig( @@ -63,35 +65,32 @@ def _setup_config(self): ) # NOTE: Can add support for injected tool arg like langchain, for later - async def call_function(self, fc_id: str, fc_name: str, fc_args=None): + async def call_function(self, + fc_id: str, + fc_name: str, + fc_args=None, + tool_ctx: ToolContext = None): """Calls the functions that were defined and returns the function response""" - func_args = fc_args if fc_args else {} + func_args = fc_args.copy() if fc_args else {} try: for tool in self.tools: - if getattr(tool, "name", None) == fc_name: - if asyncio.iscoroutinefunction(tool): - function_result = await tool(**func_args) - else: - function_result = tool(**func_args) - return genai_types.FunctionResponse( - id=fc_id, - name=fc_name, - response={ - "result": function_result - } - ) - if callable(tool) and tool.__name__ == fc_name: + tool_name = getattr(tool, "name", None) or getattr(tool, "__name__", None) + if tool_name == fc_name and callable(tool): + # Check if func accepts tool context + param, accepts = accepts_tool_context(tool) + if accepts and tool_ctx: + func_args[param] = tool_ctx + if asyncio.iscoroutinefunction(tool): function_result = await tool(**func_args) else: function_result = tool(**func_args) + return genai_types.FunctionResponse( id=fc_id, name=fc_name, - response={ - "result": function_result - } + response={"result": function_result} ) logger.error(f"Function with name '{fc_name}' is not defined.") return genai_types.FunctionResponse( diff --git a/gemini_live_boilerplate/server/tool_context.py b/gemini_live_boilerplate/server/tool_context.py new file mode 100644 index 0000000..6d20746 --- /dev/null +++ b/gemini_live_boilerplate/server/tool_context.py @@ -0,0 +1,85 @@ +"""Manages state for agent and tools throughout the session""" +from typing import Optional, Dict, Any + +class ToolContext: + """Manages state and tool context for tools within a session + This class provides a simple and safe interface for tools to read, write, + and delete data related to the current session, enabling stateful + multi-turn interactions. + """ + def __init__(self, + session_id: str, + initial_state: Optional[Dict[str, Any]] = None): + """Initialize the context + Args: + session_id: A unique identifier for the current session. + initial_state: An optional dictionary to pre-populate the state. + """ + self.session_id = session_id + if initial_state: + self._state = initial_state.copy() + else: + self._state = {} + + def dump_state(self) -> Dict[str, Any]: + """ + Returns a copy of the entire state dictionary. + Returns: + A shallow copy of the internal state dictionary, preventing direct mutation. + """ + return self._state.copy() + + def get(self, + key: str, + default: Optional[Any] = None) -> Any: + """ + Retrieves the value of a key from the state. + + Args: + key: The key of the value to retrieve. + default: The value to return if the key is not found. + + Returns: + The value associated with the key, or the default value if not found. + """ + + return self._state.get(key, default) + + def update(self, **kwargs) -> Dict[str, Any]: + """ + Updates the state with one or more key-value pairs. + Overwrites keys if they are already present. + + Args: + **kwargs: Arbitrary keyword arguments to add or update in the state. + + Returns: + The instance of the class to allow for method chaining (e.g., context.update(...).update(...)) + + Usage: + updated_data = ctx.update(key1=value1, key2=value2...) + or + data_to_update = {"key1": "value1", "key2": "value2"} + updated_data = ctx.update(**data_to_update) + """ + self._state.update(kwargs) + return self._state.copy() + + def delete(self, key: str) -> bool: + """ + Deletes a key-value pair from the state if it exists. + + Args: + key: The key to delete from the state. + + Returns: + True if the key was found and deleted, False otherwise. + """ + if key in self._state: + del self._state[key] + return True + return False + + def clear(self) -> None: + """Clears the entire state, resetting it to an empty dictionary.""" + self._state = {} diff --git a/gemini_live_boilerplate/server/tools.py b/gemini_live_boilerplate/server/tools.py index ddd4224..7e1e2fc 100644 --- a/gemini_live_boilerplate/server/tools.py +++ b/gemini_live_boilerplate/server/tools.py @@ -2,8 +2,10 @@ # pylint: disable=line-too-long from enum import Enum from typing import Optional, Annotated +import uuid import datetime from utils import function_tool # pylint: disable=no-name-in-module +from tool_context import ToolContext # pylint: disable=no-name-in-module class MeetingRoom(Enum): """Enum class for meeting room""" @@ -13,30 +15,50 @@ class MeetingRoom(Enum): @function_tool def schedule_meet_tool( + tool_ctx: ToolContext, attendees: Annotated[list[str], "List of the people attending the meeting"], topic: Annotated[str, "The subject or the topic of the meeting"], date: Annotated[str, "The date of the meeting (e.g., 25/06/2025)"], meeting_room: Annotated[MeetingRoom, "The name of the meeting room."] = MeetingRoom.VIRTUAL, - time_slot: Annotated[Optional[str], "Time of the meeting (e.g., '14:00'-'15:00'). Immediate schedule if value not provided."] = "Now"): + time_slot: Annotated[Optional[str], "Time of the meeting (e.g., '14:00'-'15:00'). Immediate schedule if value not provided."] = "Now", + ): """Schedules meeting for a given list of attendees at a given time and date""" + meet_id = str(uuid.uuid4())[:5] - response_message = f"Meeting with Topic: '{topic}' is successfully scheduled. \n\n" - response_message += f"**Meeting Details**:\nAttendees: {attendees}.\nMeeting room: {meeting_room.value}\nDate: {date}\nTime slot: {time_slot}" + if tool_ctx: + meeting_details = { + "meet_id": meet_id, + "attendees": attendees, + "topic": topic, + "date": date, + "meeting_room": meeting_room, + "time_slot": time_slot + } + # Update the state with these new values + # There are a lot of ways to customize this + tool_ctx.update(**meeting_details) + print(f"Current State: {tool_ctx.dump_state()}") + + response_message = f"Meeting with Topic: '{topic}' is successfully scheduled with ID {meet_id}. \n\n" + response_message += f"**Meeting Details**:\nAttendees: {attendees}.\nMeeting room: {meeting_room}\nDate: {date}\nTime slot: {time_slot}" return response_message @function_tool -def cancel_meet_tool(meet_id: Annotated[str, "The id of the meeting to cancel in lower case"]): +def cancel_meet_tool(meet_id: Annotated[str, "The id of the meeting to cancel in lower case"], + tool_ctx: ToolContext): """Cancels the meeting with the given ID""" - if meet_id.startswith("a"): - response_message = f"Successfully cancelled meeting with ID: {meet_id}" - return response_message - if meet_id.startswith("b"): - response_message = "The meeting is currently in progress, unable to cancel this meeting." - return response_message - - return "An error occurred while cancelling the meeting. Please make sure meeting ID is valid." + try: + if tool_ctx.get("meet_id") == meet_id: + # We can add more logic here, but keeping it simple for now + print(f"Current State: {tool_ctx.dump_state()}") + return f"Successfully cancelled meeting with ID: {meet_id}" + else: + return f"Meeting with ID: {meet_id} does not exist." + except Exception as e: + print(f"Error occured when canceling a meeting. {str(e)}") + return "An error occurred while cancelling the meeting. Please make sure meeting ID is valid." @function_tool async def get_current_time(country: Annotated[str, "Name of the country"]) -> dict: diff --git a/gemini_live_boilerplate/server/utils.py b/gemini_live_boilerplate/server/utils.py index 576bd2e..fab29b2 100644 --- a/gemini_live_boilerplate/server/utils.py +++ b/gemini_live_boilerplate/server/utils.py @@ -4,6 +4,8 @@ from enum import Enum from typing import get_type_hints, Optional, Union, get_args, get_origin, Annotated +from tool_context import ToolContext # pylint: disable=no-name-in-module + class OpenAPITypes(Enum): """The basic data types defined by OpenAPI 3.0""" @@ -140,6 +142,13 @@ def _process_parameters(self): # Unpack the annotated type and its description actual_type, description_annotation = get_args(py_type) description = description_annotation + + # Check if the type is `ToolContext` directly or wrapped in Optional (Union) + is_direct_context = actual_type is ToolContext + is_optional_context = get_origin(actual_type) is Union and ToolContext in get_args(actual_type) + + if is_direct_context or is_optional_context: + continue # Create the base schema from the actual type param_schema = self._create_schema_from_type(actual_type) @@ -188,3 +197,11 @@ def function_tool(func): """ func.tool_metadata = FunctionTool(func).to_declaration().oas_format return func + +def accepts_tool_context(func): + """Return (param_name, True) if a function has a ToolContext parameter, else (None, False).""" + type_hints = get_type_hints(func) + for param, annotation in type_hints.items(): + if annotation is ToolContext: + return param, True + return None, False From 45a9f6a88dc77e304983a80c2495a638364b12d6 Mon Sep 17 00:00:00 2001 From: Ruthvik-1411 Date: Sun, 28 Sep 2025 23:16:06 +0530 Subject: [PATCH 2/2] chore: lint fixes --- .../server/gemini_live_handler.py | 2 +- gemini_live_boilerplate/server/tool_context.py | 13 +++++++------ gemini_live_boilerplate/server/tools.py | 5 ++--- gemini_live_boilerplate/server/utils.py | 4 ++-- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/gemini_live_boilerplate/server/gemini_live_handler.py b/gemini_live_boilerplate/server/gemini_live_handler.py index f32c08d..fa23f46 100644 --- a/gemini_live_boilerplate/server/gemini_live_handler.py +++ b/gemini_live_boilerplate/server/gemini_live_handler.py @@ -81,7 +81,7 @@ async def call_function(self, param, accepts = accepts_tool_context(tool) if accepts and tool_ctx: func_args[param] = tool_ctx - + if asyncio.iscoroutinefunction(tool): function_result = await tool(**func_args) else: diff --git a/gemini_live_boilerplate/server/tool_context.py b/gemini_live_boilerplate/server/tool_context.py index 6d20746..8b34b1b 100644 --- a/gemini_live_boilerplate/server/tool_context.py +++ b/gemini_live_boilerplate/server/tool_context.py @@ -20,7 +20,7 @@ def __init__(self, self._state = initial_state.copy() else: self._state = {} - + def dump_state(self) -> Dict[str, Any]: """ Returns a copy of the entire state dictionary. @@ -44,7 +44,7 @@ def get(self, """ return self._state.get(key, default) - + def update(self, **kwargs) -> Dict[str, Any]: """ Updates the state with one or more key-value pairs. @@ -54,7 +54,8 @@ def update(self, **kwargs) -> Dict[str, Any]: **kwargs: Arbitrary keyword arguments to add or update in the state. Returns: - The instance of the class to allow for method chaining (e.g., context.update(...).update(...)) + The instance of the class to allow for method chaining + (e.g., context.update(...).update(...)) Usage: updated_data = ctx.update(key1=value1, key2=value2...) @@ -64,14 +65,14 @@ def update(self, **kwargs) -> Dict[str, Any]: """ self._state.update(kwargs) return self._state.copy() - + def delete(self, key: str) -> bool: """ Deletes a key-value pair from the state if it exists. Args: key: The key to delete from the state. - + Returns: True if the key was found and deleted, False otherwise. """ @@ -79,7 +80,7 @@ def delete(self, key: str) -> bool: del self._state[key] return True return False - + def clear(self) -> None: """Clears the entire state, resetting it to an empty dictionary.""" self._state = {} diff --git a/gemini_live_boilerplate/server/tools.py b/gemini_live_boilerplate/server/tools.py index 7e1e2fc..302e88f 100644 --- a/gemini_live_boilerplate/server/tools.py +++ b/gemini_live_boilerplate/server/tools.py @@ -1,5 +1,5 @@ """Common module for function tools""" -# pylint: disable=line-too-long +# pylint: disable=line-too-long, too-many-positional-arguments from enum import Enum from typing import Optional, Annotated import uuid @@ -54,8 +54,7 @@ def cancel_meet_tool(meet_id: Annotated[str, "The id of the meeting to cancel in # We can add more logic here, but keeping it simple for now print(f"Current State: {tool_ctx.dump_state()}") return f"Successfully cancelled meeting with ID: {meet_id}" - else: - return f"Meeting with ID: {meet_id} does not exist." + return f"Meeting with ID: {meet_id} does not exist." except Exception as e: print(f"Error occured when canceling a meeting. {str(e)}") return "An error occurred while cancelling the meeting. Please make sure meeting ID is valid." diff --git a/gemini_live_boilerplate/server/utils.py b/gemini_live_boilerplate/server/utils.py index fab29b2..0fb7144 100644 --- a/gemini_live_boilerplate/server/utils.py +++ b/gemini_live_boilerplate/server/utils.py @@ -142,10 +142,10 @@ def _process_parameters(self): # Unpack the annotated type and its description actual_type, description_annotation = get_args(py_type) description = description_annotation - + # Check if the type is `ToolContext` directly or wrapped in Optional (Union) is_direct_context = actual_type is ToolContext - is_optional_context = get_origin(actual_type) is Union and ToolContext in get_args(actual_type) + is_optional_context = get_origin(actual_type) is Union and ToolContext in get_args(actual_type) # pylint: disable=line-too-long if is_direct_context or is_optional_context: continue