Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions google/genai/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
import httpx
import json
import websockets
from . import _common


Expand Down Expand Up @@ -69,14 +70,26 @@ def _rebuild(state: dict[str, Any]) -> 'APIError':
return obj

def _get_status(self, response_json: Any) -> Any:
return response_json.get(
'status', response_json.get('error', {}).get('status', None)
)
try:
status = response_json.get(
'status', response_json.get('error', {}).get('status', None)
)
return status
except AttributeError:
# If response_json is not a dict, return close code to handle the case
# when encountering a websocket error.
return None

def _get_message(self, response_json: Any) -> Any:
return response_json.get(
'message', response_json.get('error', {}).get('message', None)
)
try:
message = response_json.get(
'message', response_json.get('error', {}).get('message', None)
)
return message
except AttributeError:
# If response_json is not a dict, return it as None.
# This is to handle the case when encountering a websocket error.
return None

def _get_code(self, response_json: Any) -> Any:
return response_json.get(
Expand Down
24 changes: 22 additions & 2 deletions google/genai/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

import google.auth
import pydantic
from websockets import ConnectionClosed
import websockets

from . import _api_module
from . import _common
Expand All @@ -41,6 +41,7 @@
from .live_music import AsyncLiveMusic
from .models import _Content_to_mldev

ConnectionClosed = websockets.ConnectionClosed

try:
from websockets.asyncio.client import ClientConnection
Expand Down Expand Up @@ -534,6 +535,14 @@ async def _receive(self) -> types.LiveServerMessage:
raw_response = await self._ws.recv(decode=False)
except TypeError:
raw_response = await self._ws.recv() # type: ignore[assignment]
except ConnectionClosed as e:
if e.rcvd:
code = e.rcvd.code
reason = e.rcvd.reason
else:
code = 1006
reason = websockets.frames.CLOSE_CODE_EXPLANATIONS.get(code, 'Abnormal closure.')
errors.APIError.raise_error(code, reason, None)
if raw_response:
try:
response = json.loads(raw_response)
Expand All @@ -545,8 +554,11 @@ async def _receive(self) -> types.LiveServerMessage:
if self._api_client.vertexai:
response_dict = live_converters._LiveServerMessage_from_vertex(response)
else:
response_dict = response
response_dict = live_converters._LiveServerMessage_from_mldev(response)

if not response_dict and response:
# Error handling.
errors.APIError.raise_error(response.get('code'), response, None)
return types.LiveServerMessage._from_response(
response=response_dict, kwargs=parameter_model.model_dump()
)
Expand Down Expand Up @@ -1093,6 +1105,14 @@ async def connect(
raw_response = await ws.recv(decode=False)
except TypeError:
raw_response = await ws.recv() # type: ignore[assignment]
except ConnectionClosed as e:
if e.rcvd:
code = e.rcvd.code
reason = e.rcvd.reason
else:
code = 1006
reason = 'Abnormal closure.'
errors.APIError.raise_error(code, reason, None)
if raw_response:
try:
response = json.loads(raw_response)
Expand Down
15 changes: 14 additions & 1 deletion google/genai/live_music.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,18 @@
import json
import logging
from typing import AsyncIterator
import websockets

from . import _api_module
from . import _common
from . import _live_converters as live_converters
from . import _transformers as t
from . import errors
from . import types
from ._api_client import BaseApiClient
from ._common import set_value_by_path as setv

ConnectionClosed = websockets.ConnectionClosed

try:
from websockets.asyncio.client import ClientConnection
Expand Down Expand Up @@ -122,6 +125,14 @@ async def _receive(self) -> types.LiveMusicServerMessage:
raw_response = await self._ws.recv(decode=False)
except TypeError:
raw_response = await self._ws.recv() # type: ignore[assignment]
except ConnectionClosed as e:
if e.rcvd:
code = e.rcvd.code
reason = e.rcvd.reason
else:
code = 1006
reason = websockets.frames.CLOSE_CODE_EXPLANATIONS.get(code, 'Abnormal closure.')
errors.APIError.raise_error(code, reason, None)
if raw_response:
try:
response = json.loads(raw_response)
Expand All @@ -134,7 +145,9 @@ async def _receive(self) -> types.LiveMusicServerMessage:
raise NotImplementedError('Live music generation is not supported in Vertex AI.')
else:
response_dict = response

if not response_dict and response:
# Error handling.
errors.APIError.raise_error(response.get('code'), response, None)
return types.LiveMusicServerMessage._from_response(
response=response_dict, kwargs=parameter_model.model_dump()
)
Expand Down
38 changes: 38 additions & 0 deletions google/genai/tests/errors/test_api_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import httpx
import pytest
import websockets

from ... import errors

Expand Down Expand Up @@ -257,6 +258,43 @@ def test_constructor_message_not_present():
}


def test_constructor_with_websocket_connection_closed_error():
actual_error = errors.APIError(
1007,
'At most one response modality can be specified in the setup request.'
' To enable simultaneous transcription and audio output,',
None,
)
assert actual_error.code == 1007
assert (
actual_error.details
== 'At most one response modality can be specified in the setup request.'
' To enable simultaneous transcription and audio output,',
)
assert actual_error.status == None
assert actual_error.message == None


def test_raise_for_websocket_connection_closed_error():
try:
errors.APIError.raise_error(
1007,
'At most one response modality can be specified in the setup request.'
' To enable simultaneous transcription and audio output,',
None,
)
except errors.APIError as actual_error:
assert actual_error.code == 1007
assert (
actual_error.details
== 'At most one response modality can be specified in the setup'
' request.'
' To enable simultaneous transcription and audio output,'
)
assert actual_error.status == None
assert actual_error.message == None


def test_raise_for_response_code_exist_json_decoder_error():
class FakeResponse(httpx.Response):

Expand Down