diff --git a/demo/main.py b/demo/main.py index 21b9f8f..576d54f 100644 --- a/demo/main.py +++ b/demo/main.py @@ -43,12 +43,15 @@ def main( # Get the streams if use_http: - voice_engine = "Play3.0-mini-http" + voice_engine = "Play3.0-mini" + protocol = "http" elif use_ws: - voice_engine = "Play3.0-mini-ws" + voice_engine = "Play3.0-mini" + protocol = "ws" else: - voice_engine = "PlayHT2.0" - in_stream, out_stream = client.get_stream_pair(options, voice_engine=voice_engine) + voice_engine = "PlayHT2.0-turbo" + protocol = "grpc" + in_stream, out_stream = client.get_stream_pair(options, voice_engine=voice_engine, protocol=protocol) # Start a player thread. audio_thread = threading.Thread(None, save_audio, args=(out_stream,)) @@ -115,12 +118,15 @@ async def async_main( # Get the streams if use_http: - voice_engine = "Play3.0-mini-http" + voice_engine = "Play3.0-mini" + protocol = "http" elif use_ws: - voice_engine = "Play3.0-mini-ws" + voice_engine = "Play3.0-mini" + protocol = "ws" else: voice_engine = "PlayHT2.0-turbo" - in_stream, out_stream = client.get_stream_pair(options, voice_engine=voice_engine) + protocol = "grpc" + in_stream, out_stream = client.get_stream_pair(options, voice_engine=voice_engine, protocol=protocol) audio_task = asyncio.create_task(async_save_audio(out_stream)) diff --git a/pyht/async_client.py b/pyht/async_client.py index 419f0f5..07cb29a 100644 --- a/pyht/async_client.py +++ b/pyht/async_client.py @@ -57,6 +57,7 @@ class AdvancedOptions: congestion_ctrl: CongestionCtrl = CongestionCtrl.OFF metrics_buffer_size: int = 1000 remove_ssml_tags: bool = False + interruptible_ws: bool = False # gRPC (PlayHT2.0-turbo, Play3.0-mini-grpc) grpc_addr: Optional[str] = None @@ -443,6 +444,11 @@ async def _tts_ws( start = time.perf_counter() await self.ensure_inference_coordinates() + if self._advanced.interruptible_ws: + query_params = "&_ws_mode=interrupt" + else: + query_params = "" + text = prepare_text(text, self._advanced.remove_ssml_tags) assert self._inference_coordinates is not None, "No connection" metrics.append("text", str(text)).append("endpoint", @@ -452,7 +458,7 @@ async def _tts_ws( for attempt in range(1, self._max_attempts + 1): try: assert self._inference_coordinates is not None, "No connection" - ws_address = self._inference_coordinates[voice_engine]["websocket_url"] + ws_address = self._inference_coordinates[voice_engine]["websocket_url"] + query_params if self._ws is None: self._ws = await connect(ws_address) self._ws_requests_sent = 0 diff --git a/pyht/client.py b/pyht/client.py index b5eb08e..d0cfd0b 100644 --- a/pyht/client.py +++ b/pyht/client.py @@ -336,6 +336,7 @@ class AdvancedOptions: congestion_ctrl: CongestionCtrl = CongestionCtrl.OFF metrics_buffer_size: int = 1000 remove_ssml_tags: bool = False + interruptible_ws: bool = False # gRPC (PlayHT2.0-turbo and Play3.0-mini-grpc) grpc_addr: Optional[str] = None @@ -712,6 +713,11 @@ def _tts_ws( start = time.perf_counter() self.ensure_inference_coordinates() + if self._advanced.interruptible_ws: + query_params = "&_ws_mode=interrupt" + else: + query_params = "" + text = prepare_text(text, self._advanced.remove_ssml_tags) assert self._inference_coordinates is not None, "No connection" metrics.append("text", str(text)).append("endpoint", @@ -721,7 +727,7 @@ def _tts_ws( for attempt in range(1, self._max_attempts + 1): try: assert self._inference_coordinates is not None, "No connection" - ws_address = self._inference_coordinates[voice_engine]["websocket_url"] + ws_address = self._inference_coordinates[voice_engine]["websocket_url"] + query_params if self._ws is None: self._ws = connect(ws_address) self._ws_requests_sent = 0 @@ -749,7 +755,7 @@ def _tts_ws( request_id = msg["request_id"] elif self._ws_responses_received > self._ws_requests_sent: raise Exception("Received more responses than requests") - elif msg["type"] == "end" and msg["request_id"] == request_id: + elif (msg["type"] == "end" or msg["type"] == "interrupt") and msg["request_id"] == request_id: break else: continue