Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gemini_live_boilerplate/client/pcm-player-processor.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class PCMPlayerProcessor extends AudioWorkletProcessor {
super();

// Init buffer
this.bufferSize = 24000 * 1; // 24kHz x 1 seconds(reduced from 180s, to not buffer on client side, client should playback what's coming)
this.bufferSize = 24000 * 180; // 24kHz x 180 seconds of audio
this.buffer = new Float32Array(this.bufferSize);
this.writeIndex = 0;
this.readIndex = 0;
Expand Down
38 changes: 29 additions & 9 deletions gemini_live_boilerplate/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from quart import Quart, websocket

from gemini_live_handler import GeminiClient
from tools import schedule_meet_tool, cancel_meet_tool # pylint: disable=no-name-in-module
from utils import function_tool
from tools import schedule_meet_tool, cancel_meet_tool, get_current_time # pylint: disable=no-name-in-module

RECORDINGS_DIR = "recordings"

Expand Down Expand Up @@ -43,6 +44,7 @@ def __init__(self, ws):
self._gemini_client_initialized = False
self.session_usage_metadata = None
self.session_started_event = asyncio.Event()
self._end_session_event = asyncio.Event()
self._tasks = []

self.session_start_time = None
Expand Down Expand Up @@ -71,6 +73,22 @@ def _finalize_and_store_model_utterance(self):
self.current_model_utterance_chunks = []
self.current_model_utterance_start_ms = None

# NOTE: Change the name as needed and add extra functions
# such as static goodbye read out if needed.
@function_tool
async def end_call(self):
    """Handles ending the session by the bot"""
    # NOTE(review): docstring kept verbatim -- presumably @function_tool
    # exposes it to the model as the tool description; confirm before editing.
    logger.info(f"[{self.connection_id}] end_call tool invoked by LLM.")

    # Queue the end_session event on the response queue before signalling
    # shutdown, so the notification is enqueued ahead of the event being set.
    end_session_message = {
        "event": "end_session",
        "data": {"reason": "Bot ended the call"},
    }
    await self.response_queue.put(end_session_message)

    # Wake up whoever is waiting on the end-of-session event.
    self._end_session_event.set()

    # Value handed back to the caller as the tool's result.
    tool_result = {"status": "ok", "message": "Session will end now."}
    return tool_result

async def _receive_from_client(self):
"""Handles receiving messages from the client"""
logger.info(f"[{self.connection_id}] Receiving task started.")
Expand Down Expand Up @@ -107,6 +125,7 @@ async def _receive_from_client(self):
elif message_event == "end_session":
# Client wants to end the session
logger.info(f"[{self.connection_id}] End session event received.")
self._end_session_event.set()
# Exit this task gracefully. The main handler will clean up.
return
else:
Expand Down Expand Up @@ -230,7 +249,7 @@ async def _process_gemini_stream(self):
tool_call_args = func_call.args
logger.info(f"Tool call with {tool_call_name} {tool_call_args}")
await self.response_queue.put({"event": "tool_call","data": {"name": tool_call_name, "args": tool_call_args}})
function_response = self.gemini.call_function(
function_response = await self.gemini.call_function(
fc_id=func_call.id,
fc_name=tool_call_name,
fc_args=tool_call_args
Expand Down Expand Up @@ -319,7 +338,7 @@ async def handle_websocket_connection(self):
# Initialize Gemini client class
self.gemini = GeminiClient(
api_key=gemini_api_key,
tools=[schedule_meet_tool, cancel_meet_tool]
tools=[schedule_meet_tool, cancel_meet_tool, get_current_time, self.end_call]
)
self._gemini_client_initialized = True

Expand All @@ -332,15 +351,16 @@ async def handle_websocket_connection(self):
sender = asyncio.create_task(self._send_to_client(), name=f"sender[{self.connection_id}]")

self._tasks = [receiver, processor, sender]
end_event_task = asyncio.create_task(self._end_session_event.wait())

# Wait for any one of the tasks to complete
# done, pending = await asyncio.wait(
# self._tasks,
# return_when=asyncio.FIRST_COMPLETED,
# )
# Since google-genai library throws an error if we wait for first complete
# we wait for all tasks to complete themselves
await asyncio.gather(*self._tasks)
# await asyncio.gather(*self._tasks)
# Since we are using wait, if end event happens, we will see a harmless warning
await asyncio.wait(
[*self._tasks, end_event_task],
return_when=asyncio.FIRST_COMPLETED
)

except ConnectionClosedOK:
# Clean shutdown. Gemini library throws this exception
Expand Down
68 changes: 45 additions & 23 deletions gemini_live_boilerplate/server/gemini_live_handler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Gemini live client handler"""
import base64
import asyncio
import logging
from typing import Callable, List
from google import genai
Expand Down Expand Up @@ -61,33 +62,54 @@ def _setup_config(self):
output_audio_transcription=genai_types.AudioTranscriptionConfig(),
)

# NOTE: Can add support for injected tool arg like langchain, for later
async def call_function(self, fc_id: str, fc_name: str, fc_args=None):
    """Invoke the registered tool named ``fc_name`` and wrap its result.

    A tool matches when its ``name`` attribute equals ``fc_name`` or, for
    callables, when its ``__name__`` does. Both plain and async tools are
    supported: whatever the tool returns is awaited if it is awaitable.

    Args:
        fc_id: Identifier of the function call, echoed back in the response.
        fc_name: Name of the tool to invoke.
        fc_args: Keyword arguments for the tool; ``None`` or empty means
            the tool is called without arguments.

    Returns:
        A ``genai_types.FunctionResponse`` whose ``response["result"]`` is
        the tool's return value, or an explanatory message when the tool is
        unknown or raised. This method never raises for tool failures.
    """
    import inspect  # local import: only this method needs it

    func_args = fc_args if fc_args else {}

    def _build_response(result):
        # Single place that shapes every outcome into a FunctionResponse.
        return genai_types.FunctionResponse(
            id=fc_id,
            name=fc_name,
            response={
                "result": result
            }
        )

    # getattr with a default keeps the search going past callables that
    # have no ``__name__`` instead of aborting on AttributeError.
    tool = next(
        (
            candidate
            for candidate in self.tools
            if getattr(candidate, "name", None) == fc_name
            or (
                callable(candidate)
                and getattr(candidate, "__name__", None) == fc_name
            )
        ),
        None,
    )

    if tool is None:
        logger.error(f"Function with name '{fc_name}' is not defined.")
        return _build_response(f"Function with `{fc_name}` is not defined.")

    try:
        function_result = tool(**func_args)
        # Check the result rather than the callable, so wrapper objects
        # around coroutine functions are awaited too.
        if inspect.isawaitable(function_result):
            function_result = await function_result
    except Exception as e:
        # Boundary handler: report the failure to the model, don't crash.
        logger.error(
            f"Error occured invoking '{fc_name}' with {fc_args}. Error: {str(e)}.",
            exc_info=True,
        )
        return _build_response(
            "Error occured invoking function. Unable to execute the function"
        )

    return _build_response(function_result)

def convert_audio_for_client(self, audio_data: bytes) -> str:
"""Converts audio data to base64 for client"""
Expand Down
2 changes: 1 addition & 1 deletion gemini_live_boilerplate/server/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
python-dotenv==1.1.1
quart==0.20.0
quart-cors==0.8.0
google-genai==1.23.0
google-genai==1.39.1
pydub==0.25.1
20 changes: 20 additions & 0 deletions gemini_live_boilerplate/server/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# pylint: disable=line-too-long
from enum import Enum
from typing import Optional, Annotated
import datetime
from utils import function_tool # pylint: disable=no-name-in-module

class MeetingRoom(Enum):
Expand Down Expand Up @@ -36,3 +37,22 @@ def cancel_meet_tool(meet_id: Annotated[str, "The id of the meeting to cancel in
return response_message

return "An error occurred while cancelling the meeting. Please make sure meeting ID is valid."

@function_tool
async def get_current_time(country: Annotated[str, "Name of the country"]) -> dict:
    """Returns the current time in a specified country."""
    # NOTE(review): docstring and parameter annotation kept verbatim --
    # presumably @function_tool exposes them to the model; confirm before editing.

    # Normalise the input so "INDIA" or " India " still match.
    country_name = country.strip()

    # Only India is supported; anything else gets a structured error so the
    # caller can relay it instead of handling an exception.
    if country_name.casefold() != "india":
        return {
            "status": "error",
            "error_message": (
                f"Sorry, I don't have timezone information for {country_name}."
            ),
        }

    # Name the fixed offset explicitly: an unnamed timezone renders %Z as
    # "UTC+05:30", which combined with %z printed the offset twice.
    tz = datetime.timezone(datetime.timedelta(hours=5, minutes=30), "IST")

    now = datetime.datetime.now(tz)
    report = (
        f'The current time in {country_name} is {now.strftime("%Y-%m-%d %H:%M:%S %Z%z")}'
    )
    return {"status": "success", "result": report}