13 changes: 9 additions & 4 deletions sdk/voice/speechmatics/voice/_audio.py
@@ -90,7 +90,8 @@ async def put_bytes(self, data: bytes) -> None:
data: The data frame to add to the buffer.
"""

# If the right length and buffer zero
# If data is exactly one frame and there's no buffered remainder,
# put the frame directly into the buffer.
if len(data) // self._sample_width == self._frame_size and len(self._buffer) == 0:
return await self.put_frame(data)

@@ -109,19 +110,23 @@ async def put_bytes(self, data: bytes) -> None:
await self.put_frame(frame)

async def put_frame(self, data: bytes) -> None:
"""Add data to the buffer.
"""Add data frame to the buffer.

New data added to the end of the buffer. The oldest data is removed
to maintain the total number of seconds in the buffer.
The new data frame is appended to the end of the buffer. The oldest data is removed
to maintain the total number of seconds in the buffer.`
Review comment (Member): What's the "`" for? :)


Args:
data: The data frame to add to the buffer.
"""
# Verify number of bytes matches frame size
if len(data) != self._frame_bytes:
raise ValueError(f"Invalid frame size: {len(data)} bytes, expected {self._frame_bytes} bytes")

# Add data to the buffer
async with self._lock:
self._frames.append(data)
self._total_frames += 1
# Trim to rolling window, keep last _max_frames frames
if len(self._frames) > self._max_frames:
self._frames = self._frames[-self._max_frames :]

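For orientation, the put_frame logic above implements a bounded rolling window over fixed-size frames. Below is a minimal, self-contained sketch of that pattern; the class name and constructor parameters are hypothetical, not the SDK's actual API:

import asyncio


class RollingFrameBuffer:
    """Sketch of a rolling window that keeps only the newest max_frames frames."""

    def __init__(self, frame_bytes: int, max_frames: int) -> None:
        self._frame_bytes = frame_bytes  # expected size of one frame, in bytes
        self._max_frames = max_frames    # rolling-window capacity, in frames
        self._frames: list[bytes] = []
        self._lock = asyncio.Lock()

    async def put_frame(self, data: bytes) -> None:
        # Reject data that is not exactly one frame long.
        if len(data) != self._frame_bytes:
            raise ValueError(f"Invalid frame size: {len(data)} bytes, expected {self._frame_bytes} bytes")
        async with self._lock:
            self._frames.append(data)
            # Trim to the rolling window: drop the oldest frames.
            if len(self._frames) > self._max_frames:
                self._frames = self._frames[-self._max_frames:]

Keeping the trim inside the same lock as the append, as the PR does, avoids a race where two concurrent writers both observe an over-full list.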
4 changes: 2 additions & 2 deletions sdk/voice/speechmatics/voice/_models.py
@@ -142,9 +142,9 @@ class AgentServerMessageType(str, Enum):
StartOfTurn: Start of turn has been detected.
EndOfTurnPrediction: End of turn prediction timing.
EndOfTurn: End of turn has been detected.
SmartTurn: Smart turn metadata.
SmartTurnResult: Smart turn metadata.
SpeakersResult: Speakers result has been detected.
Metrics: Metrics for the STT engine.
SessionMetrics: Metrics for the STT engine.
SpeakerMetrics: Metrics relating to speakers.

Examples:
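For readers unfamiliar with the (str, Enum) pattern used by AgentServerMessageType: subclassing str makes each member compare equal to its wire-format string, which is convenient when dispatching on JSON message types. A minimal sketch; the member names follow the renamed docstring entries, and the values are assumptions for illustration:

from enum import Enum


class AgentServerMessageType(str, Enum):
    # Values are illustrative assumptions; see _models.py for the real definitions.
    SMART_TURN_RESULT = "SmartTurnResult"
    SESSION_METRICS = "SessionMetrics"


# A str-valued enum member compares equal to the raw string off the wire.
assert AgentServerMessageType.SMART_TURN_RESULT == "SmartTurnResult"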
98 changes: 66 additions & 32 deletions sdk/voice/speechmatics/voice/_smart_turn.py
@@ -78,6 +78,9 @@ class SmartTurnDetector:
Further information at https://github.com/pipecat-ai/smart-turn
"""

WINDOW_SECONDS = 8
DEFAULT_SAMPLE_RATE = 16000

def __init__(self, auto_init: bool = True, threshold: float = 0.8):
"""Create the new SmartTurnDetector.

@@ -125,7 +128,7 @@ def setup(self) -> None:
self.session = self.build_session(SMART_TURN_MODEL_LOCAL_PATH)

# Load the feature extractor
self.feature_extractor = WhisperFeatureExtractor(chunk_length=8)
self.feature_extractor = WhisperFeatureExtractor(chunk_length=self.WINDOW_SECONDS)

# Set initialized
self._is_initialized = True
@@ -156,83 +159,113 @@ def build_session(self, onnx_path: str) -> ort.InferenceSession:
# Return the new session
return ort.InferenceSession(onnx_path, sess_options=so)
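For reference, a typical onnxruntime session construction looks like the sketch below; the specific SessionOptions values are illustrative assumptions, not necessarily what build_session configures:

import onnxruntime as ort

so = ort.SessionOptions()
# Illustrative settings: single-threaded ops, full graph optimization.
so.intra_op_num_threads = 1
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

# "smart-turn.onnx" is a placeholder path, not the SDK's actual model location.
session = ort.InferenceSession("smart-turn.onnx", sess_options=so)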

async def predict(
self, audio_array: bytes, language: str, sample_rate: int = 16000, sample_width: int = 2
) -> SmartTurnPredictionResult:
"""Predict whether an audio segment is complete (turn ended) or incomplete.
def _prepare_audio(self, audio_array: bytes, sample_rate: int, sample_width: int) -> np.ndarray:
"""Prepare the audio for inference.

Args:
audio_array: Numpy array containing audio samples at 16kHz. The function
will convert the audio into float32 and truncate to 8 seconds (keeping the end)
or pad to 8 seconds.
language: Language of the audio.
sample_rate: Sample rate of the audio.
sample_width: Sample width of the audio.

Returns:
Prediction result containing completion status and probability.
Numpy array containing audio samples at 16kHz.
"""

# Check if initialized
if not self._is_initialized:
return SmartTurnPredictionResult(error="SmartTurnDetector is not initialized")

# Check a valid language
if not self.valid_language(language):
logger.warning(f"Invalid language: {language}. Results may be unreliable.")

# Record start time
start_time = datetime.datetime.now()

# Convert into numpy array
dtype = np.int16 if sample_width == 2 else np.int8
int16_array: np.ndarray = np.frombuffer(audio_array, dtype=dtype).astype(np.int16)

# Truncate to last 8 seconds if needed (keep the tail/end of audio)
max_samples = 8 * sample_rate
# Truncate to last WINDOW_SECONDS seconds if needed (keep the tail/end of audio)
max_samples = self.WINDOW_SECONDS * sample_rate
if len(int16_array) > max_samples:
int16_array = int16_array[-max_samples:]

# Convert int16 to float32 in range [-1, 1] (same as reference implementation)
float32_array: np.ndarray = int16_array.astype(np.float32) / 32768.0

# Process audio using Whisper's feature extractor
return float32_array
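As a quick, runnable sanity check of the same conversion logic with synthetic data (the constants mirror the class attributes above):

import numpy as np

WINDOW_SECONDS = 8
SAMPLE_RATE = 16000

# 10 seconds of synthetic 16-bit PCM, as little-endian bytes.
rng = np.random.default_rng(0)
pcm = rng.integers(-32768, 32767, 10 * SAMPLE_RATE).astype(np.int16).tobytes()

int16_array = np.frombuffer(pcm, dtype=np.int16)
# Keep only the tail of the window, as _prepare_audio does.
int16_array = int16_array[-WINDOW_SECONDS * SAMPLE_RATE:]
float32_array = int16_array.astype(np.float32) / 32768.0

assert len(float32_array) == WINDOW_SECONDS * SAMPLE_RATE
assert float32_array.min() >= -1.0 and float32_array.max() < 1.0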

def _get_input_features(self, audio_data: np.ndarray, sample_rate: int) -> np.ndarray:
"""
Get the input features for the audio data using Whisper's feature extractor.

Args:
audio_data: Numpy array containing audio samples.
sample_rate: Sample rate of the audio.
"""

inputs = self.feature_extractor(
float32_array,
audio_data,
sampling_rate=sample_rate,
return_tensors="np",
padding="max_length",
max_length=max_samples,
max_length=self.WINDOW_SECONDS * sample_rate,
truncation=True,
do_normalize=True,
)

# Extract features and ensure correct shape for ONNX
# Ensure the dimensions are the correct shape for ONNX
input_features = inputs.input_features.squeeze(0).astype(np.float32)
input_features = np.expand_dims(input_features, axis=0)

# Run ONNX inference
outputs = self.session.run(None, {"input_features": input_features})
return input_features
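A hedged sanity check of the feature-extraction path: with Whisper's defaults (80 mel bins, 100 frames per second) and an 8-second window, the ONNX input should come out as (1, 80, 800), though exact shapes can vary with the transformers version:

import numpy as np
from transformers import WhisperFeatureExtractor

SAMPLE_RATE = 16000
WINDOW_SECONDS = 8

fe = WhisperFeatureExtractor(chunk_length=WINDOW_SECONDS)
audio = np.zeros(WINDOW_SECONDS * SAMPLE_RATE, dtype=np.float32)
inputs = fe(
    audio,
    sampling_rate=SAMPLE_RATE,
    return_tensors="np",
    padding="max_length",
    max_length=WINDOW_SECONDS * SAMPLE_RATE,
    truncation=True,
    do_normalize=True,
)

features = np.expand_dims(inputs.input_features.squeeze(0).astype(np.float32), axis=0)
print(features.shape)  # expected: (1, 80, 800)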

async def predict(
self, audio_array: bytes, language: str, sample_rate: int = DEFAULT_SAMPLE_RATE, sample_width: int = 2
) -> SmartTurnPredictionResult:
"""Predict whether an audio segment is complete (turn ended) or incomplete.

Args:
audio_array: Raw PCM audio bytes sampled at 16kHz. The function
converts the audio into float32 and truncates to the last 8 seconds
(keeping the end) or pads to 8 seconds.
language: Language of the audio.
sample_rate: Sample rate of the audio.
sample_width: Sample width of the audio.

Returns:
Prediction result containing completion status and probability.
"""

# Extract probability (ONNX model returns sigmoid probabilities)
# Check if initialized
if not self._is_initialized:
return SmartTurnPredictionResult(error="SmartTurnDetector is not initialized")

# Check for a valid language
if not self.valid_language(language):
logger.warning(f"Invalid language: {language}. Results may be unreliable.")

# Record start time
start_time = datetime.datetime.now()

# Convert the audio into required format
prepared_audio = self._prepare_audio(audio_array, sample_rate, sample_width)

# Feature extraction
input_features = self._get_input_features(prepared_audio, sample_rate)

# Model inference
outputs = self.session.run(None, {"input_features": input_features})
probability = outputs[0][0].item()

# Make prediction (True for Complete, False for Incomplete)
prediction = probability >= self._threshold

# Record end time
# Compute processing time
end_time = datetime.datetime.now()
duration = float((end_time - start_time).total_seconds())

# Return the result
return SmartTurnPredictionResult(
prediction=prediction,
probability=round(probability, 3),
processing_time=round(float((end_time - start_time).total_seconds()), 3),
processing_time=round(duration, 3),
)

@staticmethod
def truncate_audio_to_last_n_seconds(
audio_array: np.ndarray, n_seconds: float = 8.0, sample_rate: int = 16000
audio_array: np.ndarray, n_seconds: float = 8.0, sample_rate: int = DEFAULT_SAMPLE_RATE
) -> np.ndarray:
"""Truncate audio to last n seconds or pad with zeros to meet n seconds.

@@ -300,7 +333,8 @@ def model_exists() -> bool:

@staticmethod
def valid_language(language: str) -> bool:
"""Check if the language is valid.
"""Check if the language is valid against list of supported languages
for the Pipecat model.

Args:
language: Language code to validate.
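A minimal end-to-end usage sketch of the detector. The import path is inferred from the file location, the two seconds of silence are synthetic, and "en" is assumed to be in the supported-language list:

import asyncio

from speechmatics.voice._smart_turn import SmartTurnDetector  # path assumed from the diff


async def main() -> None:
    detector = SmartTurnDetector()  # auto_init=True loads the model and feature extractor
    audio_bytes = b"\x00\x00" * (16000 * 2)  # 2 seconds of int16 silence at 16 kHz
    result = await detector.predict(audio_bytes, language="en")
    print(result.prediction, result.probability, result.processing_time)


asyncio.run(main())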
4 changes: 2 additions & 2 deletions sdk/voice/speechmatics/voice/_utils.py
@@ -110,7 +110,7 @@ def segment_list_from_fragments(
speaker_groups.append([])
speaker_groups[-1].append(frag)

# Create SpeakerFragments objects
# Create SpeakerSegment objects
segments: list[SpeakerSegment] = []
for group in speaker_groups:
# Skip if the group is empty
@@ -143,7 +143,7 @@ def segment_list_from_fragments(
FragmentUtils.update_segment_text(session=session, segment=segment)
segments.append(segment)

# Return the grouped SpeakerFragments objects
# Return the grouped SpeakerSegment objects
return segments

@staticmethod
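The grouping step in segment_list_from_fragments collects consecutive fragments that share a speaker before building SpeakerSegment objects. A standalone sketch of that pattern using itertools.groupby; the dict-based fragments are hypothetical stand-ins for the SDK's fragment objects:

from itertools import groupby

# Hypothetical stand-ins for the SDK's fragment objects.
fragments = [
    {"speaker": "S1", "text": "hello"},
    {"speaker": "S1", "text": "there"},
    {"speaker": "S2", "text": "hi"},
    {"speaker": "S1", "text": "welcome back"},
]

# Runs of consecutive fragments with the same speaker become one group each.
speaker_groups = [list(group) for _, group in groupby(fragments, key=lambda f: f["speaker"])]
assert [len(g) for g in speaker_groups] == [2, 1, 1]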