Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 30 additions & 30 deletions src/openai/_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,24 +60,24 @@ def __stream__(self) -> Iterator[_T]:
if sse.data.startswith("[DONE]"):
break

# Handle explicit "error" event type from Assistants API
if sse.event == "error":
data = sse.json()
message = None
if is_mapping(data):
message = data.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"

raise APIError(
message=message,
request=self.response.request,
body=data,
)

# we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
if sse.event and sse.event.startswith("thread."):
data = sse.json()

if sse.event == "error" and is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"

raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)

yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
else:
data = sse.json()
Expand Down Expand Up @@ -163,24 +163,24 @@ async def __stream__(self) -> AsyncIterator[_T]:
if sse.data.startswith("[DONE]"):
break

# Handle explicit "error" event type from Assistants API
if sse.event == "error":
data = sse.json()
message = None
if is_mapping(data):
message = data.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"

raise APIError(
message=message,
request=self.response.request,
body=data,
)

# we have to special case the Assistants `thread.` events since we won't have an "event" key in the data
if sse.event and sse.event.startswith("thread."):
data = sse.json()

if sse.event == "error" and is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"

raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)

yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
else:
data = sse.json()
Expand Down
80 changes: 80 additions & 0 deletions tests/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from openai import OpenAI, AsyncOpenAI
from openai._exceptions import APIError
from openai._streaming import Stream, AsyncStream, ServerSentEvent


Expand Down Expand Up @@ -216,6 +217,85 @@ def body() -> Iterator[bytes]:
assert sse.json() == {"content": "известни"}


@pytest.mark.asyncio
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
async def test_error_event_raises_api_error(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
    """
    An SSE chunk tagged ``event: error`` must surface to the caller as an APIError.

    Regression coverage for issue #2796, where the error-event check was
    unreachable dead code nested inside the ``thread.`` event handling block.
    """

    def sse_body() -> Iterator[bytes]:
        for chunk in (
            b"event: error\n",
            b'data: {"message": "Test error message", "type": "server_error", "code": "internal_error"}\n',
            b"\n",
        ):
            yield chunk

    # APIError construction reads the request off the response, so attach one.
    mock_request = httpx.Request("POST", "https://api.openai.com/v1/test")

    if not sync:
        response = httpx.Response(200, content=to_aiter(sse_body()), request=mock_request)
        async_stream = AsyncStream(
            cast_to=object,
            client=async_client,
            response=response,
        )
        with pytest.raises(APIError) as exc_info:
            async for _ in async_stream:
                pass
        assert exc_info.value.message == "Test error message"
        return

    response = httpx.Response(200, content=sse_body(), request=mock_request)
    sync_stream = Stream(
        cast_to=object,
        client=client,
        response=response,
    )
    with pytest.raises(APIError) as exc_info:
        for _ in sync_stream:
            pass
    assert exc_info.value.message == "Test error message"


@pytest.mark.asyncio
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
async def test_error_event_with_missing_message(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
    """An error event lacking a ``message`` field falls back to the default text."""

    def sse_body() -> Iterator[bytes]:
        for chunk in (
            b"event: error\n",
            b'data: {"type": "server_error"}\n',
            b"\n",
        ):
            yield chunk

    # APIError construction reads the request off the response, so attach one.
    mock_request = httpx.Request("POST", "https://api.openai.com/v1/test")

    if not sync:
        response = httpx.Response(200, content=to_aiter(sse_body()), request=mock_request)
        async_stream = AsyncStream(
            cast_to=object,
            client=async_client,
            response=response,
        )
        with pytest.raises(APIError) as exc_info:
            async for _ in async_stream:
                pass
        assert exc_info.value.message == "An error occurred during streaming"
        return

    response = httpx.Response(200, content=sse_body(), request=mock_request)
    sync_stream = Stream(
        cast_to=object,
        client=client,
        response=response,
    )
    with pytest.raises(APIError) as exc_info:
        for _ in sync_stream:
            pass
    assert exc_info.value.message == "An error occurred during streaming"


async def to_aiter(chunks: Iterator[bytes]) -> AsyncIterator[bytes]:
    """Adapt a synchronous byte iterator into an asynchronous one.

    Lets the sync SSE body generators above be fed to ``httpx.Response`` as
    async content for the ``AsyncStream`` test paths.

    Note: the parameter was renamed from ``iter`` to ``chunks`` — the old name
    shadowed the ``iter`` builtin.
    """
    for chunk in chunks:
        yield chunk
Expand Down