From d932c05183d9a04a10e925c25950bfd19f0ab55e Mon Sep 17 00:00:00 2001 From: Amy Wu Date: Fri, 30 Jan 2026 10:15:19 -0800 Subject: [PATCH] fix: add error handling for live and live music APIs fixes #668 PiperOrigin-RevId: 863291046 --- google/genai/errors.py | 25 ++++++++++---- google/genai/live.py | 24 +++++++++++-- google/genai/live_music.py | 15 +++++++- google/genai/tests/errors/test_api_error.py | 38 +++++++++++++++++++++ 4 files changed, 93 insertions(+), 9 deletions(-) diff --git a/google/genai/errors.py b/google/genai/errors.py index 64979ab0e..63d9334b9 100644 --- a/google/genai/errors.py +++ b/google/genai/errors.py @@ -18,6 +18,7 @@ from typing import Any, Callable, Optional, TYPE_CHECKING, Union import httpx import json +import websockets from . import _common @@ -69,14 +70,26 @@ def _rebuild(state: dict[str, Any]) -> 'APIError': return obj def _get_status(self, response_json: Any) -> Any: - return response_json.get( - 'status', response_json.get('error', {}).get('status', None) - ) + try: + status = response_json.get( + 'status', response_json.get('error', {}).get('status', None) + ) + return status + except AttributeError: + # If response_json is not a dict, return close code to handle the case + # when encountering a websocket error. + return None def _get_message(self, response_json: Any) -> Any: - return response_json.get( - 'message', response_json.get('error', {}).get('message', None) - ) + try: + message = response_json.get( + 'message', response_json.get('error', {}).get('message', None) + ) + return message + except AttributeError: + # If response_json is not a dict, return it as None. + # This is to handle the case when encountering a websocket error. + return None def _get_code(self, response_json: Any) -> Any: return response_json.get( diff --git a/google/genai/live.py b/google/genai/live.py index 966fa641e..93953a02f 100644 --- a/google/genai/live.py +++ b/google/genai/live.py @@ -26,7 +26,7 @@ import google.auth import pydantic -from websockets import ConnectionClosed +import websockets from . import _api_module from . import _common @@ -41,6 +41,7 @@ from .live_music import AsyncLiveMusic from .models import _Content_to_mldev +ConnectionClosed = websockets.ConnectionClosed try: from websockets.asyncio.client import ClientConnection @@ -534,6 +535,14 @@ async def _receive(self) -> types.LiveServerMessage: raw_response = await self._ws.recv(decode=False) except TypeError: raw_response = await self._ws.recv() # type: ignore[assignment] + except ConnectionClosed as e: + if e.rcvd: + code = e.rcvd.code + reason = e.rcvd.reason + else: + code = 1006 + reason = websockets.frames.CLOSE_CODE_EXPLANATIONS.get(code, 'Abnormal closure.') + errors.APIError.raise_error(code, reason, None) if raw_response: try: response = json.loads(raw_response) @@ -545,8 +554,11 @@ async def _receive(self) -> types.LiveServerMessage: if self._api_client.vertexai: response_dict = live_converters._LiveServerMessage_from_vertex(response) else: - response_dict = response + response_dict = live_converters._LiveServerMessage_from_mldev(response) + if not response_dict and response: + # Error handling. + errors.APIError.raise_error(response.get('code'), response, None) return types.LiveServerMessage._from_response( response=response_dict, kwargs=parameter_model.model_dump() ) @@ -1093,6 +1105,14 @@ async def connect( raw_response = await ws.recv(decode=False) except TypeError: raw_response = await ws.recv() # type: ignore[assignment] + except ConnectionClosed as e: + if e.rcvd: + code = e.rcvd.code + reason = e.rcvd.reason + else: + code = 1006 + reason = 'Abnormal closure.' + errors.APIError.raise_error(code, reason, None) if raw_response: try: response = json.loads(raw_response) diff --git a/google/genai/live_music.py b/google/genai/live_music.py index 2f739d5b6..8730e08f9 100644 --- a/google/genai/live_music.py +++ b/google/genai/live_music.py @@ -19,15 +19,18 @@ import json import logging from typing import AsyncIterator +import websockets from . import _api_module from . import _common from . import _live_converters as live_converters from . import _transformers as t +from . import errors from . import types from ._api_client import BaseApiClient from ._common import set_value_by_path as setv +ConnectionClosed = websockets.ConnectionClosed try: from websockets.asyncio.client import ClientConnection @@ -122,6 +125,14 @@ async def _receive(self) -> types.LiveMusicServerMessage: raw_response = await self._ws.recv(decode=False) except TypeError: raw_response = await self._ws.recv() # type: ignore[assignment] + except ConnectionClosed as e: + if e.rcvd: + code = e.rcvd.code + reason = e.rcvd.reason + else: + code = 1006 + reason = websockets.frames.CLOSE_CODE_EXPLANATIONS.get(code, 'Abnormal closure.') + errors.APIError.raise_error(code, reason, None) if raw_response: try: response = json.loads(raw_response) @@ -134,7 +145,9 @@ async def _receive(self) -> types.LiveMusicServerMessage: raise NotImplementedError('Live music generation is not supported in Vertex AI.') else: response_dict = response - + if not response_dict and response: + # Error handling. + errors.APIError.raise_error(response.get('code'), response, None) return types.LiveMusicServerMessage._from_response( response=response_dict, kwargs=parameter_model.model_dump() ) diff --git a/google/genai/tests/errors/test_api_error.py b/google/genai/tests/errors/test_api_error.py index dfaf61b40..bc77b374e 100644 --- a/google/genai/tests/errors/test_api_error.py +++ b/google/genai/tests/errors/test_api_error.py @@ -23,6 +23,7 @@ import httpx import pytest +import websockets from ... import errors @@ -257,6 +258,43 @@ def test_constructor_message_not_present(): } +def test_constructor_with_websocket_connection_closed_error(): + actual_error = errors.APIError( + 1007, + 'At most one response modality can be specified in the setup request.' + ' To enable simultaneous transcription and audio output,', + None, + ) + assert actual_error.code == 1007 + assert ( + actual_error.details + == 'At most one response modality can be specified in the setup request.' + ' To enable simultaneous transcription and audio output,', + ) + assert actual_error.status == None + assert actual_error.message == None + + +def test_raise_for_websocket_connection_closed_error(): + try: + errors.APIError.raise_error( + 1007, + 'At most one response modality can be specified in the setup request.' + ' To enable simultaneous transcription and audio output,', + None, + ) + except errors.APIError as actual_error: + assert actual_error.code == 1007 + assert ( + actual_error.details + == 'At most one response modality can be specified in the setup' + ' request.' + ' To enable simultaneous transcription and audio output,' + ) + assert actual_error.status == None + assert actual_error.message == None + + def test_raise_for_response_code_exist_json_decoder_error(): class FakeResponse(httpx.Response):