Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 40 additions & 15 deletions pyht/async_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from __future__ import annotations

import logging
from typing import Any, AsyncGenerator, AsyncIterable, AsyncIterator, Coroutine

import asyncio
Expand All @@ -15,7 +17,7 @@
from grpc import ssl_channel_credentials, StatusCode


from .client import TTSOptions
from .client import TTSOptions, CongestionCtrl
from .lease import Lease, LeaseFactory
from .protos import api_pb2, api_pb2_grpc
from .utils import ensure_sentence_end, normalize, split_text, SENTENCE_END_REGEX
Expand Down Expand Up @@ -51,6 +53,7 @@ class AdvancedOptions:
fallback_enabled: bool = False
auto_refresh_lease: bool = True
disable_lease_disk_cache: bool = False
congestion_ctrl: CongestionCtrl = CongestionCtrl.OFF

def __init__(
self,
Expand Down Expand Up @@ -207,26 +210,48 @@ async def tts(
text = ensure_sentence_end(text)

request = api_pb2.TtsRequest(params=options.tts_params(text, voice_engine), lease=lease_data)
try:
stub = api_pb2_grpc.TtsStub(self._rpc[1])
stream: TtsUnaryStream = stub.Tts(request)
if context is not None:
context.assign(stream)
async for response in stream:
yield response.data
except grpc.RpcError as e:
error_code = getattr(e, "code")()
if error_code not in {StatusCode.RESOURCE_EXHAUSTED, StatusCode.UNAVAILABLE} or self._fallback_rpc is None:
raise

max_attempts = 1
backoff = 0
if self._advanced.congestion_ctrl == CongestionCtrl.STATIC_MAR_2023:
max_attempts = 3
backoff = 0.05

for attempt in range(1, max_attempts + 1):
try:
stub = api_pb2_grpc.TtsStub(self._fallback_rpc[1])
stub = api_pb2_grpc.TtsStub(self._rpc[1])
stream: TtsUnaryStream = stub.Tts(request)
if context is not None:
context.assign(stream)
async for response in stream:
yield response.data
except grpc.RpcError as fallback_e:
raise fallback_e from e
break
except grpc.RpcError as e:
error_code = getattr(e, "code")()
logging.debug(f"Error: {error_code}")
if error_code not in {StatusCode.RESOURCE_EXHAUSTED, StatusCode.UNAVAILABLE}:
raise

if attempt < max_attempts:
logging.debug(f"Retrying in {backoff * 1000} sec ({attempt} attempts so far)... ({error_code})")
if backoff > 0:
await asyncio.sleep(backoff)
continue

if self._fallback_rpc is None:
raise

logging.info(f"Falling back to {self._fallback_rpc[0]} because {self._rpc[0]} threw: {error_code}")
try:
stub = api_pb2_grpc.TtsStub(self._fallback_rpc[1])
stream: TtsUnaryStream = stub.Tts(request)
if context is not None:
context.assign(stream)
async for response in stream:
yield response.data
break
except grpc.RpcError as fallback_e:
raise fallback_e from e

def get_stream_pair(
self,
Expand Down
70 changes: 58 additions & 12 deletions pyht/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

import logging
import time
from enum import Enum
from typing import Generator, Iterable, Iterator, List, Tuple

from dataclasses import dataclass
Expand Down Expand Up @@ -59,6 +62,23 @@ def tts_params(self, text: list[str], voice_engine: str | None) -> api_pb2.TtsPa
return params


class CongestionCtrl(Enum):
"""
Enumerates a streaming congestion control algorithms, used to optimize the rate at which text is sent to PlayHT.
"""

# The client will not do any congestion control.
OFF = 0

# The client will retry requests to the primary address up to two times with a 50ms backoff between attempts.
#
# Then it will fall back to the fallback address (if one is configured). No retry attempts will be made
# against the fallback address.
#
# If you're using PlayHT On-Prem, you should probably be using this congestion control algorithm.
STATIC_MAR_2023 = 1


class Client:
LEASE_DATA: bytes | None = None
LEASE_CACHE_PATH: str = os.path.join(tempfile.gettempdir(), 'playht.temporary.lease')
Expand All @@ -72,6 +92,7 @@ class AdvancedOptions:
fallback_enabled: bool = False
auto_refresh_lease: bool = True
disable_lease_disk_cache: bool = False
congestion_ctrl: CongestionCtrl = CongestionCtrl.OFF

def __init__(
self,
Expand Down Expand Up @@ -227,22 +248,47 @@ def tts(
text = ensure_sentence_end(text)

request = api_pb2.TtsRequest(params=options.tts_params(text, voice_engine), lease=lease_data)
try:
stub = api_pb2_grpc.TtsStub(self._rpc[1])
response = stub.Tts(request) # type: Iterable[api_pb2.TtsResponse]
for item in response:
yield item.data
except grpc.RpcError as e:
error_code = getattr(e, "code")()
if error_code not in {StatusCode.RESOURCE_EXHAUSTED, StatusCode.UNAVAILABLE} or self._fallback_rpc is None:
raise

max_attempts = 1
backoff = 0
if self._advanced.congestion_ctrl == CongestionCtrl.STATIC_MAR_2023:
max_attempts = 3
backoff = 0.05

for attempt in range(1, max_attempts + 1):
try:
stub = api_pb2_grpc.TtsStub(self._fallback_rpc[1])
stub = api_pb2_grpc.TtsStub(self._rpc[1])
response = stub.Tts(request) # type: Iterable[api_pb2.TtsResponse]
for item in response:
yield item.data
except grpc.RpcError as fallback_e:
raise fallback_e from e
break
except grpc.RpcError as e:
error_code = getattr(e, "code")()
logging.debug(f"Error: {error_code}")
if error_code not in {StatusCode.RESOURCE_EXHAUSTED, StatusCode.UNAVAILABLE}:
raise

if attempt < max_attempts:
# It's a poor customer experience to show internal details about retries, so we only debug log here.
logging.debug(f"Retrying in {backoff * 1000} ms ({attempt} attempts so far)... ({error_code})")
if backoff > 0:
time.sleep(backoff)
continue

if self._fallback_rpc is None:
raise

# We log fallbacks to give customers an extra signal that they should scale up their on-prem appliance
# (e.g. by paying for more GPU quota)
logging.info(f"Falling back to {self._fallback_rpc[0]} because {self._rpc[0]} threw: {error_code}")
try:
stub = api_pb2_grpc.TtsStub(self._fallback_rpc[1])
response = stub.Tts(request) # type: Iterable[api_pb2.TtsResponse]
for item in response:
yield item.data
break
except grpc.RpcError as fallback_e:
raise fallback_e from e

def get_stream_pair(
self,
Expand Down