diff --git a/lib/req_llm/stream_server.ex b/lib/req_llm/stream_server.ex
index f0f9d1895..703041e02 100644
--- a/lib/req_llm/stream_server.ex
+++ b/lib/req_llm/stream_server.ex
@@ -92,7 +92,8 @@ defmodule ReqLLM.StreamServer do
     object_acc: [],
     fixture_saved?: false,
     raw_iodata: [],
-    raw_bytes: 0
+    raw_bytes: 0,
+    terminated?: false
   ]
 
   @doc """
@@ -601,9 +602,11 @@ defmodule ReqLLM.StreamServer do
     })
 
     # Check if any events signaled completion
+    terminated? = Enum.any?(events, &termination_event?/1)
+
     new_state =
-      if Enum.any?(events, &termination_event?/1) do
-        finalize_stream_with_fixture(new_state)
+      if terminated? do
+        finalize_stream_with_fixture(%{new_state | terminated?: true})
       else
         new_state
       end
@@ -689,6 +692,11 @@ defmodule ReqLLM.StreamServer do
   end
 
   defp finalize_stream(state) do
+    # Flush any remaining SSE buffer content before finalizing.
+    # The last SSE event may be buffered if the terminating blank line
+    # arrived in a separate HTTP chunk or was missing entirely.
+    state = flush_sse_buffer(state)
+
     {flush_chunks, new_provider_state} =
       if function_exported?(state.provider_mod, :flush_stream_state, 2) do
         state.provider_mod.flush_stream_state(state.model, state.provider_state)
@@ -730,6 +738,42 @@ defmodule ReqLLM.StreamServer do
     %{state | status: :done, metadata: metadata}
   end
 
+  # Drains an SSE event left in `sse_buffer` without its terminating blank line.
+  # Returns the state with the decoded chunks enqueued, the provider state
+  # advanced, `terminated?` raised if a termination event was found, and the
+  # buffer cleared; returns the state untouched when nothing could be parsed.
+  defp flush_sse_buffer(%{sse_buffer: buffer} = state) when byte_size(buffer) > 0 do
+    # Force-parse the buffer by appending a terminating blank line.
+    # This handles the case where the server closed the connection
+    # without a trailing \n\n after the last SSE event.
+    case parse_protocol_events("\n\n", state) do
+      {[], _remaining} ->
+        state
+
+      {events, _remaining} ->
+        # flat_map_reduce threads the provider state through every event while
+        # collecting chunks, avoiding a repeated `acc ++ new` list copy.
+        {stream_chunks, new_provider_state} =
+          Enum.flat_map_reduce(events, state.provider_state, fn event, prov_state ->
+            event
+            |> SSE.process_sse_event()
+            |> decode_provider_event(state.provider_mod, state.model, prov_state)
+          end)
+
+        # Clear the buffer so a repeated finalize cannot replay these events.
+        flushed_state = %{
+          state
+          | sse_buffer: "",
+            provider_state: new_provider_state,
+            terminated?: state.terminated? or Enum.any?(events, &termination_event?/1)
+        }
+
+        enqueue_chunks(stream_chunks, flushed_state)
+    end
+  end
+
+  defp flush_sse_buffer(state), do: state
+
   defp finalize_stream_with_fixture(state) do
     Debug.dbug(
       fn ->
@@ -798,10 +842,14 @@ defmodule ReqLLM.StreamServer do
   end
 
   defp extract_final_metadata(state) do
-    # Return accumulated metadata with HTTP status and headers
+    # A finish_reason already present in the metadata is kept (put_new); otherwise
+    # one is inferred from whether a termination event was seen before the end.
+    default_finish_reason = if state.terminated?, do: :stop, else: :incomplete
+
     state.metadata
     |> Map.put(:status, state.http_status)
     |> Map.put(:headers, state.headers)
+    |> Map.put_new(:finish_reason, default_finish_reason)
   end
 
   defp reply_to_waiting_callers(state) do
diff --git a/test/req_llm/stream_server/streaming_test.exs b/test/req_llm/stream_server/streaming_test.exs
index 7c66278b6..934e1564d 100644
--- a/test/req_llm/stream_server/streaming_test.exs
+++ b/test/req_llm/stream_server/streaming_test.exs
@@ -5,6 +5,8 @@ defmodule ReqLLM.StreamServer.StreamingTest do
   Covers:
   - Backpressure handling
   - SSE edge cases (large events, incomplete events, multi-line events)
+  - SSE buffer flushing on stream finalization
+  - Default finish_reason metadata
   - Timeout handling
 
   Uses mocked HTTP tasks and the shared MockProvider for isolated testing.
@@ -119,6 +121,107 @@ defmodule ReqLLM.StreamServer.StreamingTest do
     end
   end
 
+  # Events are sent with and without their terminating blank line; in every case
+  # exactly one chunk followed by :halt must be delivered once :done arrives.
+  describe "SSE buffer flushing on finalize" do
+    test "flushes buffered event missing trailing blank line on :done" do
+      server = start_server()
+
+      sse_without_terminator = ~s(data: {"choices": [{"delta": {"content": "buffered"}}]}\n)
+      StreamServer.http_event(server, {:data, sse_without_terminator})
+      StreamServer.http_event(server, :done)
+
+      assert {:ok, chunk} = StreamServer.next(server, 100)
+      assert chunk.type == :content
+      assert chunk.text == "buffered"
+      assert :halt = StreamServer.next(server, 100)
+    end
+
+    test "flushes buffered event split across chunks without trailing blank line" do
+      server = start_server()
+
+      StreamServer.http_event(server, {:data, "data: {\"cho"})
+
+      StreamServer.http_event(
+        server,
+        {:data, "ices\": [{\"delta\": {\"content\": \"split\"}}]}\n"}
+      )
+
+      StreamServer.http_event(server, :done)
+
+      assert {:ok, chunk} = StreamServer.next(server, 100)
+      assert chunk.type == :content
+      assert chunk.text == "split"
+      assert :halt = StreamServer.next(server, 100)
+    end
+
+    test "noop when sse_buffer is empty at finalize" do
+      server = start_server()
+
+      sse_data = ~s(data: {"choices": [{"delta": {"content": "complete"}}]}\n\n)
+      StreamServer.http_event(server, {:data, sse_data})
+      StreamServer.http_event(server, :done)
+
+      assert {:ok, chunk} = StreamServer.next(server, 100)
+      assert chunk.text == "complete"
+      assert :halt = StreamServer.next(server, 100)
+    end
+  end
+
+  # finish_reason resolution: a provider-supplied value is kept; otherwise :stop
+  # is expected after a [DONE] termination event and :incomplete without one.
+  describe "finish_reason metadata" do
+    test "defaults to :stop when provider sends termination event without finish_reason" do
+      server = start_server()
+
+      sse_data = ~s(data: {"choices": [{"delta": {"content": "hi"}}]}\n\n)
+      done_event = "data: [DONE]\n\n"
+
+      StreamServer.http_event(server, {:data, sse_data})
+      StreamServer.http_event(server, {:data, done_event})
+      StreamServer.http_event(server, :done)
+
+      assert {:ok, metadata} = StreamServer.await_metadata(server, 500)
+      assert metadata.finish_reason == :stop
+    end
+
+    test "defaults to :stop when buffered done event is missing trailing blank line" do
+      server = start_server()
+
+      StreamServer.http_event(server, {:data, "data: [DONE]\n"})
+      StreamServer.http_event(server, :done)
+
+      assert {:ok, metadata} = StreamServer.await_metadata(server, 500)
+      assert metadata.finish_reason == :stop
+    end
+
+    test "sets finish_reason to :incomplete when stream ends without termination event" do
+      server = start_server()
+
+      sse_data = ~s(data: {"choices": [{"delta": {"content": "hi"}}]}\n\n)
+      StreamServer.http_event(server, {:data, sse_data})
+      StreamServer.http_event(server, :done)
+
+      assert {:ok, metadata} = StreamServer.await_metadata(server, 500)
+      assert metadata.finish_reason == :incomplete
+    end
+
+    test "preserves provider-supplied finish_reason" do
+      server = start_server()
+
+      sse_data = ~s(data: {"choices": [{"delta": {"content": "hi"}}]}\n\n)
+      finish_json = Jason.encode!(%{"choices" => [%{"finish_reason" => "tool_use"}]})
+      finish_event = "data: #{finish_json}\n\n"
+
+      StreamServer.http_event(server, {:data, sse_data})
+      StreamServer.http_event(server, {:data, finish_event})
+      StreamServer.http_event(server, :done)
+
+      assert {:ok, metadata} = StreamServer.await_metadata(server, 500)
+      assert metadata.finish_reason == "tool_use"
+    end
+  end
+
   describe "timeout handling" do
     test "next/2 respects timeout parameter" do
       server = start_server()
diff --git a/test/support/stream_server_helpers.ex b/test/support/stream_server_helpers.ex
index 189e58fa0..8ded84bac 100644
--- a/test/support/stream_server_helpers.ex
+++ b/test/support/stream_server_helpers.ex
@@ -41,6 +41,15 @@ defmodule ReqLLM.Test.StreamServerHelpers do
       [StreamChunk.meta(%{usage: usage})]
     end
 
+    # Surfaces a string "finish_reason" from a single-choice payload as a meta chunk.
+    def decode_stream_event(
+          %{data: %{"choices" => [%{"finish_reason" => reason}]}},
+          _model
+        )
+        when is_binary(reason) do
+      [StreamChunk.meta(%{finish_reason: reason})]
+    end
+
     def decode_stream_event(_event, _model), do: []
 
     def prepare_request(_op, _model, _data, _opts), do: {:error, :not_implemented}