diff --git a/src/agents/voice/pipeline.py b/src/agents/voice/pipeline.py
index cd46806964..ac641471ff 100644
--- a/src/agents/voice/pipeline.py
+++ b/src/agents/voice/pipeline.py
@@ -84,24 +84,20 @@ async def _process_audio_input(self, audio_input: AudioInput) -> str:
         )
 
     async def _run_single_turn(self, audio_input: AudioInput) -> StreamedAudioResult:
-        # Since this is single turn, we can use the TraceCtxManager to manage starting/ending the
-        # trace
-        with TraceCtxManager(
-            workflow_name=self.config.workflow_name or "Voice Agent",
-            trace_id=None,  # Automatically generated
-            group_id=self.config.group_id,
-            metadata=self.config.trace_metadata,
-            tracing=self.config.tracing,
-            disabled=self.config.tracing_disabled,
-        ):
-            input_text = await self._process_audio_input(audio_input)
-
-            output = StreamedAudioResult(
-                self._get_tts_model(), self.config.tts_settings, self.config
-            )
-
-            async def stream_events():
+        output = StreamedAudioResult(self._get_tts_model(), self.config.tts_settings, self.config)
+
+        async def stream_events():
+            # Keep the trace scope active for the entire async processing lifecycle.
+            with TraceCtxManager(
+                workflow_name=self.config.workflow_name or "Voice Agent",
+                trace_id=None,  # Automatically generated
+                group_id=self.config.group_id,
+                metadata=self.config.trace_metadata,
+                tracing=self.config.tracing,
+                disabled=self.config.tracing_disabled,
+            ):
                 try:
+                    input_text = await self._process_audio_input(audio_input)
                     async for text_event in self.workflow.run(input_text):
                         await output._add_text(text_event)
                     await output._turn_done()
@@ -111,37 +107,37 @@ async def stream_events():
                     await output._add_error(e)
                     raise e
 
-            output._set_task(asyncio.create_task(stream_events()))
-            return output
+        output._set_task(asyncio.create_task(stream_events()))
+        return output
 
     async def _run_multi_turn(self, audio_input: StreamedAudioInput) -> StreamedAudioResult:
-        with TraceCtxManager(
-            workflow_name=self.config.workflow_name or "Voice Agent",
-            trace_id=None,
-            group_id=self.config.group_id,
-            metadata=self.config.trace_metadata,
-            tracing=self.config.tracing,
-            disabled=self.config.tracing_disabled,
-        ):
-            output = StreamedAudioResult(
-                self._get_tts_model(), self.config.tts_settings, self.config
-            )
-
-            try:
-                async for intro_text in self.workflow.on_start():
-                    await output._add_text(intro_text)
-            except Exception as e:
-                logger.warning(f"on_start() failed: {e}")
-
-            transcription_session = await self._get_stt_model().create_session(
-                audio_input,
-                self.config.stt_settings,
-                self.config.trace_include_sensitive_data,
-                self.config.trace_include_sensitive_audio_data,
-            )
-
-            async def process_turns():
+        output = StreamedAudioResult(self._get_tts_model(), self.config.tts_settings, self.config)
+
+        async def process_turns():
+            # Keep the trace scope active for the full streamed session.
+            with TraceCtxManager(
+                workflow_name=self.config.workflow_name or "Voice Agent",
+                trace_id=None,
+                group_id=self.config.group_id,
+                metadata=self.config.trace_metadata,
+                tracing=self.config.tracing,
+                disabled=self.config.tracing_disabled,
+            ):
+                transcription_session = None
                 try:
+                    try:
+                        async for intro_text in self.workflow.on_start():
+                            await output._add_text(intro_text)
+                    except Exception as e:
+                        logger.warning(f"on_start() failed: {e}")
+
+                    transcription_session = await self._get_stt_model().create_session(
+                        audio_input,
+                        self.config.stt_settings,
+                        self.config.trace_include_sensitive_data,
+                        self.config.trace_include_sensitive_audio_data,
+                    )
+
                     async for input_text in transcription_session.transcribe_turns():
                         result = self.workflow.run(input_text)
                         async for text_event in result:
@@ -152,8 +148,9 @@ async def process_turns():
                     await output._add_error(e)
                     raise e
                 finally:
-                    await transcription_session.close()
+                    if transcription_session is not None:
+                        await transcription_session.close()
                     await output._done()
 
-            output._set_task(asyncio.create_task(process_turns()))
-            return output
+        output._set_task(asyncio.create_task(process_turns()))
+        return output
diff --git a/src/agents/voice/result.py b/src/agents/voice/result.py
index fea79902ea..aaa2ba3bd6 100644
--- a/src/agents/voice/result.py
+++ b/src/agents/voice/result.py
@@ -265,6 +265,7 @@ def _check_errors(self):
 
     async def stream(self) -> AsyncIterator[VoiceStreamEvent]:
         """Stream the events and audio data as they're generated."""
+        saw_session_end = False
         while True:
             try:
                 event = await self._queue.get()
@@ -278,8 +279,18 @@ async def stream(self) -> AsyncIterator[VoiceStreamEvent]:
                 break
             yield event
             if event.type == "voice_stream_event_lifecycle" and event.event == "session_ended":
+                saw_session_end = True
                 break
 
+        # On the normal completion path, let the producer task finish gracefully so any active
+        # trace context can emit `trace_end` before we run cleanup.
+        if (
+            saw_session_end
+            and self.text_generation_task is not None
+            and not self.text_generation_task.done()
+        ):
+            await asyncio.shield(self.text_generation_task)
+
         self._check_errors()
         self._cleanup_tasks()
 
diff --git a/tests/voice/test_pipeline.py b/tests/voice/test_pipeline.py
index 5190446879..76142c0868 100644
--- a/tests/voice/test_pipeline.py
+++ b/tests/voice/test_pipeline.py
@@ -1,9 +1,13 @@
 from __future__ import annotations
 
+import asyncio
+
 import numpy as np
 import numpy.typing as npt
 import pytest
 
+from tests.testing_processor import fetch_events
+
 try:
     from agents.voice import AudioInput, TTSModelSettings, VoicePipeline, VoicePipelineConfig
 
@@ -177,3 +181,71 @@ def _transform_data(
         "session_ended",
     ]
     await fake_tts.verify_audio("out_1", audio_chunks[0], dtype=np.int16)
+
+
+class _BlockingWorkflow(FakeWorkflow):
+    def __init__(self, gate: asyncio.Event):
+        super().__init__()
+        self._gate = gate
+
+    async def run(self, _: str):
+        await self._gate.wait()
+        yield "out_1"
+
+
+class _OnStartYieldThenFailWorkflow(FakeWorkflow):
+    async def on_start(self):
+        yield "intro"
+        raise RuntimeError("boom")
+
+
+@pytest.mark.asyncio
+async def test_voicepipeline_trace_not_finished_before_single_turn_completes() -> None:
+    fake_stt = FakeSTT(["first"])
+    fake_tts = FakeTTS()
+    gate = asyncio.Event()
+    workflow = _BlockingWorkflow(gate)
+    config = VoicePipelineConfig(tts_settings=TTSModelSettings(buffer_size=1))
+    pipeline = VoicePipeline(
+        workflow=workflow, stt_model=fake_stt, tts_model=fake_tts, config=config
+    )
+
+    audio_input = AudioInput(buffer=np.zeros(2, dtype=np.int16))
+    result = await pipeline.run(audio_input)
+    await asyncio.sleep(0)
+
+    events_before_unblock = fetch_events()
+    assert "trace_start" in events_before_unblock
+    assert "trace_end" not in events_before_unblock
+
+    gate.set()
+    await extract_events(result)
+    assert fetch_events()[-1] == "trace_end"
+
+
+@pytest.mark.asyncio
+async def test_voicepipeline_trace_finishes_after_multi_turn_processing() -> None:
+    fake_stt = FakeSTT(["first", "second"])
+    workflow = FakeWorkflow([["out_1"], ["out_2"]])
+    fake_tts = FakeTTS()
+    pipeline = VoicePipeline(workflow=workflow, stt_model=fake_stt, tts_model=fake_tts)
+
+    streamed_audio_input = await FakeStreamedAudioInput.get(count=2)
+    result = await pipeline.run(streamed_audio_input)
+    await extract_events(result)
+    assert fetch_events()[-1] == "trace_end"
+
+
+@pytest.mark.asyncio
+async def test_voicepipeline_multi_turn_on_start_exception_does_not_abort() -> None:
+    fake_stt = FakeSTT(["first"])
+    workflow = _OnStartYieldThenFailWorkflow([["out_1"]])
+    fake_tts = FakeTTS()
+    pipeline = VoicePipeline(workflow=workflow, stt_model=fake_stt, tts_model=fake_tts)
+
+    streamed_audio_input = await FakeStreamedAudioInput.get(count=1)
+    result = await pipeline.run(streamed_audio_input)
+    events, _ = await extract_events(result)
+
+    assert events[-1] == "session_ended"
+    assert "error" not in events