sdk/voice/speechmatics/voice/_client.py (61 changes: 40 additions & 21 deletions)

@@ -335,6 +335,7 @@ def __init__(
        self._session_speakers: dict[str, SessionSpeaker] = {}
        self._is_speaking: bool = False
        self._current_speaker: Optional[str] = None
+       self._last_valid_partial_word_count: int = 0
        self._dz_enabled: bool = self._config.enable_diarization
        self._dz_config = self._config.speaker_config
        self._last_speak_start_time: Optional[float] = None
@@ -454,7 +455,7 @@ def _prepare_config(
        )

        # Punctuation overrides
-       if config.punctuation_overrides:
+       if config.punctuation_overrides is not None:
            transcription_config.punctuation_overrides = config.punctuation_overrides

        # Configure the audio
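
Aside: the switch to `is not None` matters when punctuation overrides are set but falsy. An explicitly configured empty value (for example an empty dict) failed the old truthiness check and was silently dropped; the new check honours it. A minimal illustration of the distinction, using a hypothetical `set_overrides` stand-in rather than the SDK's real config plumbing:

def set_overrides(value):  # hypothetical stand-in for the config assignment
    print("applied:", value)

overrides = {}               # explicitly configured, but empty: falsy
if overrides:                # old check: branch skipped, setting silently lost
    set_overrides(overrides)
if overrides is not None:    # new check: branch taken, empty value still applied
    set_overrides(overrides)
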
@@ -1122,8 +1123,7 @@ async def _add_speech_fragments(self, message: dict[str, Any], is_final: bool =
            self._last_fragment_end_time = max(self._last_fragment_end_time, fragment.end_time)

-       # Evaluate for VAD (only done on partials)
-       if not is_final:
-           await self._vad_evaluation(fragments)
+       # Evaluate for VAD (runs on both partials and finals)
+       await self._vad_evaluation(fragments, is_final=is_final)

        # Fragments to retain
        retained_fragments = [
@@ -1698,52 +1698,71 @@ async def _await_forced_eou(self, timeout: float = 1.0) -> None:
    # VAD (VOICE ACTIVITY DETECTION) / SPEAKER DETECTION
    # ============================================================================

-   async def _vad_evaluation(self, fragments: list[SpeechFragment]) -> None:
+   async def _vad_evaluation(self, fragments: list[SpeechFragment], is_final: bool) -> None:
        """Emit a VAD event.

        This will emit `SPEAKER_STARTED` and `SPEAKER_ENDED` events to the client and is
        based on valid transcription for active speakers. Ignored speakers, or speakers
        not in focus, will not be considered active participants.

-       This should only run on partial / non-final words.
-
        Args:
            fragments: The list of fragments to use for evaluation.
+           is_final: Whether the fragments are final.
        """

-       # Find the valid list of partial words
+       # Filter fragments for valid speakers, if required
        if self._dz_enabled and self._dz_config.focus_speakers:
-           new_partials = [
-               frag
-               for frag in fragments
-               if frag.speaker in self._dz_config.focus_speakers and frag.type_ == "word" and not frag.is_final
-           ]
-       else:
-           new_partials = [frag for frag in fragments if frag.type_ == "word" and not frag.is_final]
+           fragments = [f for f in fragments if f.speaker in self._dz_config.focus_speakers]

+       # Find partial and final words
+       words = [f for f in fragments if f.type_ == "word"]
+
+       # Check if we have any new words
+       has_words = len(words) > 0
+
+       # Handle finals
+       if is_final:
+           """Check for finals without partials.
+
+           When a forced end of utterance is used, the transcription may skip partials
+           and go straight to finals. In that case, we check whether we saw any partials
+           last time; if not, we assume we have a new speaker.
+           """
+
+           # Check if the transcript went straight to finals (typical with a forced end of utterance)
+           if not self._is_speaking and has_words and self._last_valid_partial_word_count == 0:
+               # Track the current speaker
+               self._current_speaker = words[0].speaker
+               self._is_speaking = True
+
+               # Emit speaker started event
+               await self._handle_speaker_started(self._current_speaker, words[0].start_time)
+
+           # No further processing needed
+           return

-       # Check if we have new partials
-       has_valid_partial = len(new_partials) > 0
+       # Track partial count
+       self._last_valid_partial_word_count = len(words)

        # Current states
        current_is_speaking = self._is_speaking
        current_speaker = self._current_speaker

        # Establish the speaker from latest partials
-       latest_speaker = new_partials[-1].speaker if has_valid_partial else current_speaker
+       latest_speaker = words[-1].speaker if has_words else current_speaker

        # Determine if the speaker has changed (and we have a speaker)
        speaker_changed = latest_speaker != current_speaker and current_speaker is not None

        # Start / end times (earliest and latest)
-       speaker_start_time = new_partials[0].start_time if has_valid_partial else None
+       speaker_start_time = words[0].start_time if has_words else None
        speaker_end_time = self._last_fragment_end_time

        # If diarization is enabled, indicate speaker switching
        if self._dz_enabled and latest_speaker is not None:
            """When enabled, we send a speech event if the speaker has changed.

-           This
-           will emit a SPEAKER_ENDED for the previous speaker and a SPEAKER_STARTED
+           This will emit a SPEAKER_ENDED for the previous speaker and a SPEAKER_STARTED
            for the new speaker.

            For any client that wishes to show _which_ speaker is speaking, this will

@@ -1774,7 +1793,7 @@ async def _vad_evaluation(self, fragments: list[SpeechFragment]) -> None:
            self._current_speaker = latest_speaker

        # No further processing if we have no new fragments and we are not speaking
-       if has_valid_partial == current_is_speaking:
+       if has_words == current_is_speaking:
            return

        # Update speaking state
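
To make the new control flow concrete outside the client: on a forced end of utterance, the engine may emit final words with no preceding partials, and the old code (which only ran VAD on partials) never raised SPEAKER_STARTED for them. Below is a minimal, self-contained sketch of that path; Fragment and VadSketch are illustrative stand-ins, not the SDK's SpeechFragment or client API, and print stands in for event emission:

from dataclasses import dataclass
from typing import Optional


@dataclass
class Fragment:
    speaker: str
    start_time: float
    type_: str = "word"


class VadSketch:
    def __init__(self) -> None:
        self.is_speaking = False
        self.current_speaker: Optional[str] = None
        self.last_partial_count = 0

    def evaluate(self, fragments: list[Fragment], is_final: bool) -> None:
        words = [f for f in fragments if f.type_ == "word"]
        if is_final:
            # Finals arrived with no partials seen beforehand: assume a new speaker
            if not self.is_speaking and words and self.last_partial_count == 0:
                self.current_speaker = words[0].speaker
                self.is_speaking = True
                print(f"SPEAKER_STARTED: {words[0].speaker} @ {words[0].start_time}")
            return
        # Partials: remember how many words we saw for the next evaluation
        self.last_partial_count = len(words)


vad = VadSketch()
vad.evaluate([Fragment("S1", 0.32)], is_final=True)  # forced EOU: prints SPEAKER_STARTED
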
Binary file added tests/voice/assets/audio_07a_16kHz.wav
Binary file added tests/voice/assets/audio_07b_16kHz.wav
tests/voice/test_17_eou_feou.py (168 changes: 168 additions & 0 deletions)

@@ -0,0 +1,168 @@
import datetime
import json
import os

import pytest
from _utils import get_client
from _utils import send_audio_file
from pydantic import Field

from speechmatics.voice import AdditionalVocabEntry
from speechmatics.voice import AgentServerMessageType
from speechmatics.voice._models import BaseModel
from speechmatics.voice._presets import VoiceAgentConfigPreset

# Skip for CI testing
pytestmark = pytest.mark.skipif(os.getenv("CI") == "true", reason="Skipping smart turn tests in CI")


# Constants
API_KEY = os.getenv("SPEECHMATICS_API_KEY")
SHOW_LOG = os.getenv("SPEECHMATICS_SHOW_LOG", "0").lower() in ["1", "true"]


class TranscriptionSpeaker(BaseModel):
    text: str
    speaker_id: str = "S1"


class TranscriptionTest(BaseModel):
    id: str
    path: str
    sample_rate: int
    language: str
    segments: list[TranscriptionSpeaker]
    additional_vocab: list[AdditionalVocabEntry] = Field(default_factory=list)


class TranscriptionTests(BaseModel):
    samples: list[TranscriptionTest]


SAMPLES: TranscriptionTests = TranscriptionTests.from_dict(
    {
        "samples": [
            {
                "id": "07",
                "path": "./assets/audio_07b_16kHz.wav",
                "sample_rate": 16000,
                "language": "en",
                "segments": [
                    {"text": "Hello."},
                    {"text": "So tomorrow."},
                    {"text": "Wednesday."},
                    {"text": "Of course. That's fine."},
                    {"text": "Because."},
                    {"text": "In front."},
                    {"text": "Do you think so?"},
                    {"text": "Brilliant."},
                    {"text": "Banana."},
                    {"text": "When?"},
                    {"text": "Today."},
                    {"text": "This morning."},
                    {"text": "Goodbye."},
                ],
            },
        ]
    }
)
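
Since the segments above omit speaker_id and the sample omits additional_vocab, they presumably fall back to the defaults declared in the models, assuming from_dict validates like a standard pydantic-style constructor:

# Illustrative check of the model defaults
sample = SAMPLES.samples[0]
assert sample.segments[0].speaker_id == "S1"
assert sample.additional_vocab == []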


@pytest.mark.asyncio
@pytest.mark.parametrize("sample", SAMPLES.samples, ids=lambda s: f"{s.id}:{s.path}")
async def test_prediction(sample: TranscriptionTest):
    """Test transcription and prediction."""

    # API key
    api_key = os.getenv("SPEECHMATICS_API_KEY")
    if not api_key:
        pytest.skip("Valid API key required for test")

    # Start time
    start_time = datetime.datetime.now()

    # Results
    sot_count: int = 0
    eot_count: int = 0
    segment_transcribed: list[str] = []

    # Client
    client = await get_client(
        api_key=api_key,
        connect=False,
        config=VoiceAgentConfigPreset.ADAPTIVE(),
    )

    # SOT detected
    def sot_detected(message):
        nonlocal sot_count
        sot_count += 1
        print("✅ START_OF_TURN: {turn_id}".format(**message))

    # Finalized segment
    def add_segments(message):
        segments = message["segments"]
        for s in segments:
            segment_transcribed.append(s["text"])
            print('🚀 ADD_SEGMENT: {speaker_id} @ "{text}"'.format(**s))

    # EOT detected
    def eot_detected(message):
        nonlocal eot_count
        eot_count += 1
        print("🏁 END_OF_TURN: {turn_id}\n".format(**message))

    # Callback for each message
    def log_message(message):
        ts = (datetime.datetime.now() - start_time).total_seconds()
        log = json.dumps({"ts": round(ts, 3), "payload": message})
        if SHOW_LOG:
            print(log)

    # Add listeners (uncomment to log every message type)
    # for message_type in AgentServerMessageType:
    #     if message_type not in [AgentServerMessageType.AUDIO_ADDED]:
    #         client.on(message_type, log_message)

    # Custom listeners
    client.on(AgentServerMessageType.START_OF_TURN, sot_detected)
    client.on(AgentServerMessageType.END_OF_TURN, eot_detected)
    client.on(AgentServerMessageType.ADD_SEGMENT, add_segments)

    # HEADER
    if SHOW_LOG:
        print()
        print()
        print("---")

    # Connect
    try:
        await client.connect()
    except Exception:
        pytest.skip("Failed to connect to server")

    # Check we are connected
    assert client._is_connected

    # Individual payloads
    await send_audio_file(client, sample.path)

    # FOOTER
    if SHOW_LOG:
        print("---")
        print()
        print()

    # Close session
    await client.disconnect()
    assert not client._is_connected

    # Debug counts
    print(f"SOT count: {sot_count}")
    print(f"EOT count: {eot_count}")
    print(f"Segments transcribed: {len(segment_transcribed)}")

    # Check the length of the results
    assert len(segment_transcribed) == len(sample.segments)

    # Validate (if we have expected results)
    for idx, result in enumerate(segment_transcribed):
        assert result.lower() == sample.segments[idx].text.lower()
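
Assuming the repo's usual pytest setup, the test is skipped in CI and presumably runs locally with the key exported and -s to surface the emoji prints, along the lines of:

SPEECHMATICS_API_KEY=... SPEECHMATICS_SHOW_LOG=1 pytest tests/voice/test_17_eou_feou.py -s

Note the final comparison is case-insensitive but otherwise exact, so punctuation in the expected segments must match the transcript verbatim.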