Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions gemini_live_boilerplate/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from gemini_live_handler import GeminiClient
from utils import function_tool
from tools import schedule_meet_tool, cancel_meet_tool, get_current_time # pylint: disable=no-name-in-module
from tool_context import ToolContext # pylint: disable=no-name-in-module

RECORDINGS_DIR = "recordings"

Expand All @@ -36,10 +37,12 @@ class WebSocketHandler:

def __init__(self, ws):
self.websocket = ws
self.connection_id = f"conn_{int(time.time() * 1000)}"
self.gemini = None
self.tool_context = ToolContext(session_id=self.connection_id)

self.audio_queue = asyncio.Queue()
self.response_queue = asyncio.Queue()
self.connection_id = f"conn_{int(time.time() * 1000)}"
self._gemini_session = None
self._gemini_client_initialized = False
self.session_usage_metadata = None
Expand Down Expand Up @@ -252,7 +255,8 @@ async def _process_gemini_stream(self):
function_response = await self.gemini.call_function(
fc_id=func_call.id,
fc_name=tool_call_name,
fc_args=tool_call_args
fc_args=tool_call_args,
tool_ctx=self.tool_context
)
await self.response_queue.put({"event": "tool_response","data": {"name": tool_call_name, "args": function_response.response}})
function_responses.append(function_response)
Expand Down
37 changes: 18 additions & 19 deletions gemini_live_boilerplate/server/gemini_live_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
import base64
import asyncio
import logging
from typing import Callable, List
from typing import Callable, List, Optional
from google import genai
from google.genai import types as genai_types

# TODO: Add more specific prompt
from prompt import SYSTEM_PROMPT
from tool_context import ToolContext # pylint: disable=no-name-in-module
from utils import accepts_tool_context

# TODO: Use same logger everywhere, don't import new one
logging.basicConfig(
Expand Down Expand Up @@ -63,35 +65,32 @@ def _setup_config(self):
)

# NOTE: Can add support for injected tool arg like langchain, for later
async def call_function(self, fc_id: str, fc_name: str, fc_args=None):
async def call_function(self,
fc_id: str,
fc_name: str,
fc_args=None,
tool_ctx: ToolContext = None):
"""Calls the functions that were defined and returns the function response"""
func_args = fc_args if fc_args else {}
func_args = fc_args.copy() if fc_args else {}

try:
for tool in self.tools:
if getattr(tool, "name", None) == fc_name:
if asyncio.iscoroutinefunction(tool):
function_result = await tool(**func_args)
else:
function_result = tool(**func_args)
return genai_types.FunctionResponse(
id=fc_id,
name=fc_name,
response={
"result": function_result
}
)
if callable(tool) and tool.__name__ == fc_name:
tool_name = getattr(tool, "name", None) or getattr(tool, "__name__", None)
if tool_name == fc_name and callable(tool):
# Check if func accepts tool context
param, accepts = accepts_tool_context(tool)
if accepts and tool_ctx:
func_args[param] = tool_ctx

if asyncio.iscoroutinefunction(tool):
function_result = await tool(**func_args)
else:
function_result = tool(**func_args)

return genai_types.FunctionResponse(
id=fc_id,
name=fc_name,
response={
"result": function_result
}
response={"result": function_result}
)
logger.error(f"Function with name '{fc_name}' is not defined.")
return genai_types.FunctionResponse(
Expand Down
86 changes: 86 additions & 0 deletions gemini_live_boilerplate/server/tool_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""Manages state for agent and tools throughout the session"""
from typing import Optional, Dict, Any

class ToolContext:
"""Manages state and tool context for tools within a session
This class provides a simple and safe interface for tools to read, write,
and delete data related to the current session, enabling stateful
multi-turn interactions.
"""
def __init__(self,
session_id: str,
initial_state: Optional[Dict[str, Any]] = None):
"""Initialize the context
Args:
session_id: A unique identifier for the current session.
initial_state: An optional dictionary to pre-populate the state.
"""
self.session_id = session_id
if initial_state:
self._state = initial_state.copy()
else:
self._state = {}

def dump_state(self) -> Dict[str, Any]:
"""
Returns a copy of the entire state dictionary.
Returns:
A shallow copy of the internal state dictionary, preventing direct mutation.
"""
return self._state.copy()

def get(self,
key: str,
default: Optional[Any] = None) -> Any:
"""
Retrieves the value of a key from the state.

Args:
key: The key of the value to retrieve.
default: The value to return if the key is not found.

Returns:
The value associated with the key, or the default value if not found.
"""

return self._state.get(key, default)

def update(self, **kwargs) -> Dict[str, Any]:
"""
Updates the state with one or more key-value pairs.
Overwrites keys if they are already present.

Args:
**kwargs: Arbitrary keyword arguments to add or update in the state.

Returns:
The instance of the class to allow for method chaining
(e.g., context.update(...).update(...))

Usage:
updated_data = ctx.update(key1=value1, key2=value2...)
or
data_to_update = {"key1": "value1", "key2": "value2"}
updated_data = ctx.update(**data_to_update)
"""
self._state.update(kwargs)
return self._state.copy()

def delete(self, key: str) -> bool:
"""
Deletes a key-value pair from the state if it exists.

Args:
key: The key to delete from the state.

Returns:
True if the key was found and deleted, False otherwise.
"""
if key in self._state:
del self._state[key]
return True
return False

def clear(self) -> None:
"""Clears the entire state, resetting it to an empty dictionary."""
self._state = {}
47 changes: 34 additions & 13 deletions gemini_live_boilerplate/server/tools.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Common module for function tools"""
# pylint: disable=line-too-long
# pylint: disable=line-too-long, too-many-positional-arguments
from enum import Enum
from typing import Optional, Annotated
import uuid
import datetime
from utils import function_tool # pylint: disable=no-name-in-module
from tool_context import ToolContext # pylint: disable=no-name-in-module

class MeetingRoom(Enum):
"""Enum class for meeting room"""
Expand All @@ -13,30 +15,49 @@ class MeetingRoom(Enum):

@function_tool
def schedule_meet_tool(
tool_ctx: ToolContext,
attendees: Annotated[list[str], "List of the people attending the meeting"],
topic: Annotated[str, "The subject or the topic of the meeting"],
date: Annotated[str, "The date of the meeting (e.g., 25/06/2025)"],
meeting_room: Annotated[MeetingRoom, "The name of the meeting room."] = MeetingRoom.VIRTUAL,
time_slot: Annotated[Optional[str], "Time of the meeting (e.g., '14:00'-'15:00'). Immediate schedule if value not provided."] = "Now"):
time_slot: Annotated[Optional[str], "Time of the meeting (e.g., '14:00'-'15:00'). Immediate schedule if value not provided."] = "Now",
):
"""Schedules meeting for a given list of attendees at a given time and date"""
meet_id = str(uuid.uuid4())[:5]

response_message = f"Meeting with Topic: '{topic}' is successfully scheduled. \n\n"
response_message += f"**Meeting Details**:\nAttendees: {attendees}.\nMeeting room: {meeting_room.value}\nDate: {date}\nTime slot: {time_slot}"
if tool_ctx:
meeting_details = {
"meet_id": meet_id,
"attendees": attendees,
"topic": topic,
"date": date,
"meeting_room": meeting_room,
"time_slot": time_slot
}
# Update the state with these new values
# There are a lot of ways to customize this
tool_ctx.update(**meeting_details)
print(f"Current State: {tool_ctx.dump_state()}")

response_message = f"Meeting with Topic: '{topic}' is successfully scheduled with ID {meet_id}. \n\n"
response_message += f"**Meeting Details**:\nAttendees: {attendees}.\nMeeting room: {meeting_room}\nDate: {date}\nTime slot: {time_slot}"

return response_message

@function_tool
def cancel_meet_tool(meet_id: Annotated[str, "The id of the meeting to cancel in lower case"]):
def cancel_meet_tool(meet_id: Annotated[str, "The id of the meeting to cancel in lower case"],
tool_ctx: ToolContext):
"""Cancels the meeting with the given ID"""

if meet_id.startswith("a"):
response_message = f"Successfully cancelled meeting with ID: {meet_id}"
return response_message
if meet_id.startswith("b"):
response_message = "The meeting is currently in progress, unable to cancel this meeting."
return response_message

return "An error occurred while cancelling the meeting. Please make sure meeting ID is valid."
try:
if tool_ctx.get("meet_id") == meet_id:
# We can add more logic here, but keeping it simple for now
print(f"Current State: {tool_ctx.dump_state()}")
return f"Successfully cancelled meeting with ID: {meet_id}"
return f"Meeting with ID: {meet_id} does not exist."
except Exception as e:
print(f"Error occured when canceling a meeting. {str(e)}")
return "An error occurred while cancelling the meeting. Please make sure meeting ID is valid."

@function_tool
async def get_current_time(country: Annotated[str, "Name of the country"]) -> dict:
Expand Down
17 changes: 17 additions & 0 deletions gemini_live_boilerplate/server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from enum import Enum
from typing import get_type_hints, Optional, Union, get_args, get_origin, Annotated

from tool_context import ToolContext # pylint: disable=no-name-in-module

class OpenAPITypes(Enum):
"""The basic data types defined by OpenAPI 3.0"""

Expand Down Expand Up @@ -141,6 +143,13 @@ def _process_parameters(self):
actual_type, description_annotation = get_args(py_type)
description = description_annotation

# Check if the type is `ToolContext` directly or wrapped in Optional (Union)
is_direct_context = actual_type is ToolContext
is_optional_context = get_origin(actual_type) is Union and ToolContext in get_args(actual_type) # pylint: disable=line-too-long

if is_direct_context or is_optional_context:
continue

# Create the base schema from the actual type
param_schema = self._create_schema_from_type(actual_type)

Expand Down Expand Up @@ -188,3 +197,11 @@ def function_tool(func):
"""
func.tool_metadata = FunctionTool(func).to_declaration().oas_format
return func

def accepts_tool_context(func):
"""Return (param_name, True) if a function has a ToolContext parameter, else (None, False)."""
type_hints = get_type_hints(func)
for param, annotation in type_hints.items():
if annotation is ToolContext:
return param, True
return None, False