Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 32 additions & 8 deletions src/agents/voice/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,12 @@ async def _add_error(self, error: Exception):
def _transform_audio_buffer(
self, buffer: list[bytes], output_dtype: npt.DTypeLike
) -> npt.NDArray[np.int16 | np.float32]:
np_array = np.frombuffer(b"".join(buffer), dtype=np.int16)
combined_buffer = b"".join(buffer)
if len(combined_buffer) % 2 != 0:
# np.int16 needs 2-byte alignment; pad odd-length chunks safely.
combined_buffer += b"\x00"

np_array = np.frombuffer(combined_buffer, dtype=np.int16)

if output_dtype == np.int16:
return np_array
Expand Down Expand Up @@ -118,6 +123,7 @@ async def _stream_audio(
first_byte_received = False
buffer: list[bytes] = []
full_audio_data: list[bytes] = []
pending_byte = b""

async for chunk in self.tts_model.run(text, self.tts_settings):
if not first_byte_received:
Expand All @@ -128,15 +134,33 @@ async def _stream_audio(
buffer.append(chunk)
full_audio_data.append(chunk)
if len(buffer) >= self._buffer_size:
audio_np = self._transform_audio_buffer(buffer, self.tts_settings.dtype)
if self.tts_settings.transform_data:
audio_np = self.tts_settings.transform_data(audio_np)
await local_queue.put(
VoiceStreamEventAudio(data=audio_np)
) # Use local queue
combined = pending_byte + b"".join(buffer)
if len(combined) % 2 != 0:
pending_byte = combined[-1:]
combined = combined[:-1]
else:
pending_byte = b""

if combined:
audio_np = self._transform_audio_buffer(
[combined], self.tts_settings.dtype
)
if self.tts_settings.transform_data:
audio_np = self.tts_settings.transform_data(audio_np)
await local_queue.put(
VoiceStreamEventAudio(data=audio_np)
) # Use local queue
buffer = []
if buffer:
audio_np = self._transform_audio_buffer(buffer, self.tts_settings.dtype)
combined = pending_byte + b"".join(buffer)
else:
combined = pending_byte

if combined:
# Final flush: pad the remaining half sample if needed.
if len(combined) % 2 != 0:
combined += b"\x00"
audio_np = self._transform_audio_buffer([combined], self.tts_settings.dtype)
if self.tts_settings.transform_data:
audio_np = self.tts_settings.transform_data(audio_np)
await local_queue.put(VoiceStreamEventAudio(data=audio_np)) # Use local queue
Expand Down
68 changes: 67 additions & 1 deletion tests/voice/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,84 @@
from __future__ import annotations

import asyncio

import numpy as np
import numpy.typing as npt
import pytest

try:
from agents.voice import AudioInput, TTSModelSettings, VoicePipeline, VoicePipelineConfig
from agents.voice import (
AudioInput,
StreamedAudioResult,
TTSModelSettings,
VoicePipeline,
VoicePipelineConfig,
VoiceStreamEvent,
VoiceStreamEventAudio,
VoiceStreamEventLifecycle,
)

from .fake_models import FakeStreamedAudioInput, FakeSTT, FakeTTS, FakeWorkflow
from .helpers import extract_events
except ImportError:
pass


def test_streamed_audio_result_odd_length_buffer_int16() -> None:
    """A one-byte (half-sample) buffer is padded and decoded as a single int16."""
    audio_result = StreamedAudioResult(
        FakeTTS(),
        TTSModelSettings(dtype=np.int16),
        VoicePipelineConfig(),
    )

    samples = audio_result._transform_audio_buffer([b"\x01"], np.int16)

    # Padding with a zero high byte yields exactly one little-endian sample == 1.
    assert samples.dtype == np.int16
    assert samples.tolist() == [1]


def test_streamed_audio_result_odd_length_buffer_float32() -> None:
    """A one-byte buffer is padded, then scaled into the float32 output range."""
    audio_result = StreamedAudioResult(
        FakeTTS(),
        TTSModelSettings(dtype=np.float32),
        VoicePipelineConfig(),
    )

    samples = audio_result._transform_audio_buffer([b"\x01"], np.float32)

    assert samples.dtype == np.float32
    # Float output is reshaped to (1, n_samples); the single sample is 1/32767.
    assert samples.shape == (1, 1)
    assert samples[0, 0] == pytest.approx(1 / 32767.0)


@pytest.mark.asyncio
async def test_streamed_audio_result_preserves_cross_chunk_sample_boundaries() -> None:
    """A 16-bit sample split across two TTS chunks must be reassembled, not mangled."""

    class HalfSampleTTS(FakeTTS):
        # Emits one int16 sample (value 1, little-endian) split byte-by-byte.
        async def run(self, text: str, settings: TTSModelSettings):
            del text, settings
            yield b"\x01"
            yield b"\x00"

    audio_result = StreamedAudioResult(
        HalfSampleTTS(),
        TTSModelSettings(buffer_size=1, dtype=np.int16),
        VoicePipelineConfig(),
    )
    event_queue: asyncio.Queue[VoiceStreamEvent | None] = asyncio.Queue()

    await audio_result._stream_audio("hello", event_queue, finish_turn=True)

    collected: list[bytes] = []
    while True:
        event = await event_queue.get()
        if isinstance(event, VoiceStreamEventLifecycle) and event.event == "turn_ended":
            break
        if isinstance(event, VoiceStreamEventAudio) and event.data is not None:
            collected.append(event.data.tobytes())

    # Exactly one audio event carrying the reassembled sample [1].
    assert collected == [np.array([1], dtype=np.int16).tobytes()]


@pytest.mark.asyncio
async def test_voicepipeline_run_single_turn() -> None:
# Single turn. Should produce a single audio output, which is the TTS output for "out_1".
Expand Down