Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions demo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,15 @@ def main(

# Get the streams
if use_http:
voice_engine = "Play3.0-mini-http"
voice_engine = "Play3.0-mini"
protocol = "http"
elif use_ws:
voice_engine = "Play3.0-mini-ws"
voice_engine = "Play3.0-mini"
protocol = "ws"
else:
voice_engine = "PlayHT2.0"
in_stream, out_stream = client.get_stream_pair(options, voice_engine=voice_engine)
voice_engine = "PlayHT2.0-turbo"
protocol = "grpc"
in_stream, out_stream = client.get_stream_pair(options, voice_engine=voice_engine, protocol=protocol)

# Start a player thread.
audio_thread = threading.Thread(None, save_audio, args=(out_stream,))
Expand Down Expand Up @@ -115,12 +118,15 @@ async def async_main(

# Get the streams
if use_http:
voice_engine = "Play3.0-mini-http"
voice_engine = "Play3.0-mini"
protocol = "http"
elif use_ws:
voice_engine = "Play3.0-mini-ws"
voice_engine = "Play3.0-mini"
protocol = "ws"
else:
voice_engine = "PlayHT2.0-turbo"
in_stream, out_stream = client.get_stream_pair(options, voice_engine=voice_engine)
protocol = "grpc"
in_stream, out_stream = client.get_stream_pair(options, voice_engine=voice_engine, protocol=protocol)

audio_task = asyncio.create_task(async_save_audio(out_stream))

Expand Down
8 changes: 7 additions & 1 deletion pyht/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class AdvancedOptions:
congestion_ctrl: CongestionCtrl = CongestionCtrl.OFF
metrics_buffer_size: int = 1000
remove_ssml_tags: bool = False
interruptible_ws: bool = False

# gRPC (PlayHT2.0-turbo, Play3.0-mini-grpc)
grpc_addr: Optional[str] = None
Expand Down Expand Up @@ -443,6 +444,11 @@ async def _tts_ws(
start = time.perf_counter()
await self.ensure_inference_coordinates()

if self._advanced.interruptible_ws:
query_params = "&_ws_mode=interrupt"
else:
query_params = ""

text = prepare_text(text, self._advanced.remove_ssml_tags)
assert self._inference_coordinates is not None, "No connection"
metrics.append("text", str(text)).append("endpoint",
Expand All @@ -452,7 +458,7 @@ async def _tts_ws(
for attempt in range(1, self._max_attempts + 1):
try:
assert self._inference_coordinates is not None, "No connection"
ws_address = self._inference_coordinates[voice_engine]["websocket_url"]
ws_address = self._inference_coordinates[voice_engine]["websocket_url"] + query_params
if self._ws is None:
self._ws = await connect(ws_address)
self._ws_requests_sent = 0
Expand Down
10 changes: 8 additions & 2 deletions pyht/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ class AdvancedOptions:
congestion_ctrl: CongestionCtrl = CongestionCtrl.OFF
metrics_buffer_size: int = 1000
remove_ssml_tags: bool = False
interruptible_ws: bool = False

# gRPC (PlayHT2.0-turbo and Play3.0-mini-grpc)
grpc_addr: Optional[str] = None
Expand Down Expand Up @@ -712,6 +713,11 @@ def _tts_ws(
start = time.perf_counter()
self.ensure_inference_coordinates()

if self._advanced.interruptible_ws:
query_params = "&_ws_mode=interrupt"
else:
query_params = ""

text = prepare_text(text, self._advanced.remove_ssml_tags)
assert self._inference_coordinates is not None, "No connection"
metrics.append("text", str(text)).append("endpoint",
Expand All @@ -721,7 +727,7 @@ def _tts_ws(
for attempt in range(1, self._max_attempts + 1):
try:
assert self._inference_coordinates is not None, "No connection"
ws_address = self._inference_coordinates[voice_engine]["websocket_url"]
ws_address = self._inference_coordinates[voice_engine]["websocket_url"] + query_params
if self._ws is None:
self._ws = connect(ws_address)
self._ws_requests_sent = 0
Expand Down Expand Up @@ -749,7 +755,7 @@ def _tts_ws(
request_id = msg["request_id"]
elif self._ws_responses_received > self._ws_requests_sent:
raise Exception("Received more responses than requests")
elif msg["type"] == "end" and msg["request_id"] == request_id:
elif (msg["type"] == "end" or msg["type"] == "interrupt") and msg["request_id"] == request_id:
break
else:
continue
Expand Down