diff --git a/sdk/voice/speechmatics/voice/_client.py b/sdk/voice/speechmatics/voice/_client.py
index 2e1ee43..059022d 100644
--- a/sdk/voice/speechmatics/voice/_client.py
+++ b/sdk/voice/speechmatics/voice/_client.py
@@ -335,6 +335,7 @@ def __init__(
         self._session_speakers: dict[str, SessionSpeaker] = {}
         self._is_speaking: bool = False
         self._current_speaker: Optional[str] = None
+        self._last_valid_partial_word_count: int = 0
         self._dz_enabled: bool = self._config.enable_diarization
         self._dz_config = self._config.speaker_config
         self._last_speak_start_time: Optional[float] = None
@@ -454,7 +455,7 @@ def _prepare_config(
         )

         # Punctuation overrides
-        if config.punctuation_overrides:
+        if config.punctuation_overrides is not None:
             transcription_config.punctuation_overrides = config.punctuation_overrides

         # Configure the audio
@@ -1122,8 +1123,7 @@ async def _add_speech_fragments(self, message: dict[str, Any], is_final: bool =
             self._last_fragment_end_time = max(self._last_fragment_end_time, fragment.end_time)

-        # Evaluate for VAD (only done on partials)
-        if not is_final:
-            await self._vad_evaluation(fragments)
+        # Evaluate for VAD (partials and finals)
+        await self._vad_evaluation(fragments, is_final=is_final)

         # Fragments to retain
         retained_fragments = [
@@ -1698,52 +1698,71 @@ async def _await_forced_eou(self, timeout: float = 1.0) -> None:
     # VAD (VOICE ACTIVITY DETECTION) / SPEAKER DETECTION
     # ============================================================================

-    async def _vad_evaluation(self, fragments: list[SpeechFragment]) -> None:
+    async def _vad_evaluation(self, fragments: list[SpeechFragment], is_final: bool) -> None:
         """Emit a VAD event.

         This will emit `SPEAKER_STARTED` and `SPEAKER_ENDED` events to the client
         and is based on valid transcription for active speakers. Ignored speakers or
         speakers not in focus will not be considered active participants.

-        This should only run on partial / non-final words.
-
         Args:
             fragments: The list of fragments to use for evaluation.
+            is_final: Whether the fragments are final.
         """

-        # Find the valid list of partial words
+        # Filter fragments for valid speakers, if required
         if self._dz_enabled and self._dz_config.focus_speakers:
-            new_partials = [
-                frag
-                for frag in fragments
-                if frag.speaker in self._dz_config.focus_speakers and frag.type_ == "word" and not frag.is_final
-            ]
-        else:
-            new_partials = [frag for frag in fragments if frag.type_ == "word" and not frag.is_final]
+            fragments = [f for f in fragments if f.speaker in self._dz_config.focus_speakers]
+
+        # Find partial and final words
+        words = [f for f in fragments if f.type_ == "word"]
+
+        # Check if we have any new words
+        has_words = len(words) > 0
+
+        # Handle finals
+        if is_final:
+            """Check for finals without partials.
+
+            When a forced end of utterance is used, the transcription may skip partials
+            and go straight to finals. In that case, check whether the previous evaluation
+            saw any partial words; if not, assume a new speaker has started.
+ """ + + # Check if transcript went straight to finals (typical with forced end of utterance) + if not self._is_speaking and has_words and self._last_valid_partial_word_count == 0: + # Track the current speaker + self._current_speaker = words[0].speaker + self._is_speaking = True + + # Emit speaker started event + await self._handle_speaker_started(self._current_speaker, words[0].start_time) + + # No further processing needed + return - # Check if we have new partials - has_valid_partial = len(new_partials) > 0 + # Track partial count + self._last_valid_partial_word_count = len(words) # Current states current_is_speaking = self._is_speaking current_speaker = self._current_speaker # Establish the speaker from latest partials - latest_speaker = new_partials[-1].speaker if has_valid_partial else current_speaker + latest_speaker = words[-1].speaker if has_words else current_speaker # Determine if the speaker has changed (and we have a speaker) speaker_changed = latest_speaker != current_speaker and current_speaker is not None # Start / end times (earliest and latest) - speaker_start_time = new_partials[0].start_time if has_valid_partial else None + speaker_start_time = words[0].start_time if has_words else None speaker_end_time = self._last_fragment_end_time # If diarization is enabled, indicate speaker switching if self._dz_enabled and latest_speaker is not None: """When enabled, we send a speech events if the speaker has changed. - This - will emit a SPEAKER_ENDED for the previous speaker and a SPEAKER_STARTED + This will emit a SPEAKER_ENDED for the previous speaker and a SPEAKER_STARTED for the new speaker. For any client that wishes to show _which_ speaker is speaking, this will @@ -1774,7 +1793,7 @@ async def _vad_evaluation(self, fragments: list[SpeechFragment]) -> None: self._current_speaker = latest_speaker # No further processing if we have no new fragments and we are not speaking - if has_valid_partial == current_is_speaking: + if has_words == current_is_speaking: return # Update speaking state diff --git a/tests/voice/assets/audio_07a_16kHz.wav b/tests/voice/assets/audio_07a_16kHz.wav new file mode 100644 index 0000000..0f03adc Binary files /dev/null and b/tests/voice/assets/audio_07a_16kHz.wav differ diff --git a/tests/voice/assets/audio_07b_16kHz.wav b/tests/voice/assets/audio_07b_16kHz.wav new file mode 100644 index 0000000..94e1c67 Binary files /dev/null and b/tests/voice/assets/audio_07b_16kHz.wav differ diff --git a/tests/voice/test_17_eou_feou.py b/tests/voice/test_17_eou_feou.py new file mode 100644 index 0000000..f0fe137 --- /dev/null +++ b/tests/voice/test_17_eou_feou.py @@ -0,0 +1,168 @@ +import datetime +import json +import os + +import pytest +from _utils import get_client +from _utils import send_audio_file +from pydantic import Field + +from speechmatics.voice import AdditionalVocabEntry +from speechmatics.voice import AgentServerMessageType +from speechmatics.voice._models import BaseModel +from speechmatics.voice._presets import VoiceAgentConfigPreset + +# Skip for CI testing +pytestmark = pytest.mark.skipif(os.getenv("CI") == "true", reason="Skipping smart turn tests in CI") + + +# Constants +API_KEY = os.getenv("SPEECHMATICS_API_KEY") +SHOW_LOG = os.getenv("SPEECHMATICS_SHOW_LOG", "0").lower() in ["1", "true"] + + +class TranscriptionSpeaker(BaseModel): + text: str + speaker_id: int = "S1" + + +class TranscriptionTest(BaseModel): + id: str + path: str + sample_rate: int + language: str + segments: list[TranscriptionSpeaker] + additional_vocab: 
+
+
+class TranscriptionTests(BaseModel):
+    samples: list[TranscriptionTest]
+
+
+SAMPLES: TranscriptionTests = TranscriptionTests.from_dict(
+    {
+        "samples": [
+            {
+                "id": "07",
+                "path": "./assets/audio_07b_16kHz.wav",
+                "sample_rate": 16000,
+                "language": "en",
+                "segments": [
+                    {"text": "Hello."},
+                    {"text": "So tomorrow."},
+                    {"text": "Wednesday."},
+                    {"text": "Of course. That's fine."},
+                    {"text": "Because."},
+                    {"text": "In front."},
+                    {"text": "Do you think so?"},
+                    {"text": "Brilliant."},
+                    {"text": "Banana."},
+                    {"text": "When?"},
+                    {"text": "Today."},
+                    {"text": "This morning."},
+                    {"text": "Goodbye."},
+                ],
+            },
+        ]
+    }
+)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sample", SAMPLES.samples, ids=lambda s: f"{s.id}:{s.path}")
+async def test_prediction(sample: TranscriptionTest):
+    """Test transcription and prediction"""
+
+    # API key
+    api_key = os.getenv("SPEECHMATICS_API_KEY")
+    if not api_key:
+        pytest.skip("Valid API key required for test")
+
+    # Start time
+    start_time = datetime.datetime.now()
+
+    # Results
+    eot_count: int = 0
+    segment_transcribed: list[str] = []
+
+    # Client
+    client = await get_client(
+        api_key=api_key,
+        connect=False,
+        config=VoiceAgentConfigPreset.ADAPTIVE(),
+    )
+
+    # SOT detected
+    def sot_detected(message):
+        nonlocal eot_count
+        eot_count += 1
+        print("✅ START_OF_TURN: {turn_id}".format(**message))
+
+    # Finalized segment
+    def add_segments(message):
+        segments = message["segments"]
+        for s in segments:
+            segment_transcribed.append(s["text"])
+            print('🚀 ADD_SEGMENT: {speaker_id} @ "{text}"'.format(**s))
+
+    # EOT detected
+    def eot_detected(message):
+        nonlocal eot_count
+        eot_count += 1
+        print("🏁 END_OF_TURN: {turn_id}\n".format(**message))
+
+    # Callback for each message
+    def log_message(message):
+        ts = (datetime.datetime.now() - start_time).total_seconds()
+        log = json.dumps({"ts": round(ts, 3), "payload": message})
+        if SHOW_LOG:
+            print(log)
+
+    # # Add listeners
+    # for message_type in AgentServerMessageType:
+    #     if message_type not in [AgentServerMessageType.AUDIO_ADDED]:
+    #         client.on(message_type, log_message)
+
+    # Custom listeners
+    client.on(AgentServerMessageType.START_OF_TURN, sot_detected)
+    client.on(AgentServerMessageType.END_OF_TURN, eot_detected)
+    client.on(AgentServerMessageType.ADD_SEGMENT, add_segments)
+
+    # HEADER
+    if SHOW_LOG:
+        print()
+        print()
+        print("---")
+
+    # Connect
+    try:
+        await client.connect()
+    except Exception:
+        pytest.skip("Failed to connect to server")
+
+    # Check we are connected
+    assert client._is_connected
+
+    # Individual payloads
+    await send_audio_file(client, sample.path)
+
+    # FOOTER
+    if SHOW_LOG:
+        print("---")
+        print()
+        print()
+
+    # Close session
+    await client.disconnect()
+    assert not client._is_connected
+
+    # Debug count
+    print(f"EOT count: {eot_count}")
+    print(f"Segment transcribed: {len(segment_transcribed)}")
+
+    # Check the length of the results
+    assert len(segment_transcribed) == len(sample.segments)
+
+    # Validate (if we have expected results)
+    for idx, result in enumerate(segment_transcribed):
+        assert result.lower() == sample.segments[idx].text.lower()
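
The rule this patch adds to _vad_evaluation can be summarised outside the SDK as: a final batch that arrives before any partial words were seen still triggers a speaker-started event, while partials continue to drive the normal start/stop transitions. Below is a minimal, self-contained sketch of that rule; ToyVadState, its field names and event strings are illustrative assumptions only and are not part of the speechmatics.voice API.

    from dataclasses import dataclass, field


    @dataclass
    class ToyVadState:
        """Toy restatement of the finals-without-partials rule (not the SDK class)."""

        is_speaking: bool = False
        last_partial_word_count: int = 0
        events: list[str] = field(default_factory=list)

        def evaluate(self, words: list[str], is_final: bool) -> None:
            if is_final:
                # A final batch with no preceding partials (forced end of utterance)
                # still opens the turn with a speaker-started event.
                if not self.is_speaking and words and self.last_partial_word_count == 0:
                    self.is_speaking = True
                    self.events.append("SPEAKER_STARTED")
                return
            # Partials drive the normal start/stop transitions.
            self.last_partial_word_count = len(words)
            if bool(words) != self.is_speaking:
                self.is_speaking = bool(words)
                self.events.append("SPEAKER_STARTED" if self.is_speaking else "SPEAKER_ENDED")


    state = ToyVadState()
    state.evaluate(["Hello."], is_final=True)  # straight to finals, no partials seen
    assert state.events == ["SPEAKER_STARTED"]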