diff --git a/pyht/async_client.py b/pyht/async_client.py
index 0b1adb1..6374325 100644
--- a/pyht/async_client.py
+++ b/pyht/async_client.py
@@ -1,4 +1,6 @@
 from __future__ import annotations
+
+import logging
 from typing import Any, AsyncGenerator, AsyncIterable, AsyncIterator, Coroutine
 
 import asyncio
@@ -15,7 +17,7 @@
 
 from grpc import ssl_channel_credentials, StatusCode
 
-from .client import TTSOptions
+from .client import TTSOptions, CongestionCtrl
 from .lease import Lease, LeaseFactory
 from .protos import api_pb2, api_pb2_grpc
 from .utils import ensure_sentence_end, normalize, split_text, SENTENCE_END_REGEX
@@ -51,6 +53,7 @@ class AdvancedOptions:
         fallback_enabled: bool = False
         auto_refresh_lease: bool = True
         disable_lease_disk_cache: bool = False
+        congestion_ctrl: CongestionCtrl = CongestionCtrl.OFF
 
     def __init__(
         self,
@@ -207,26 +210,51 @@ async def tts(
         text = ensure_sentence_end(text)
 
         request = api_pb2.TtsRequest(params=options.tts_params(text, voice_engine), lease=lease_data)
-        try:
-            stub = api_pb2_grpc.TtsStub(self._rpc[1])
-            stream: TtsUnaryStream = stub.Tts(request)
-            if context is not None:
-                context.assign(stream)
-            async for response in stream:
-                yield response.data
-        except grpc.RpcError as e:
-            error_code = getattr(e, "code")()
-            if error_code not in {StatusCode.RESOURCE_EXHAUSTED, StatusCode.UNAVAILABLE} or self._fallback_rpc is None:
-                raise
+
+        max_attempts = 1
+        backoff = 0
+        if self._advanced.congestion_ctrl == CongestionCtrl.STATIC_MAR_2023:
+            max_attempts = 3
+            backoff = 0.05
+
+        for attempt in range(1, max_attempts + 1):
             try:
-                stub = api_pb2_grpc.TtsStub(self._fallback_rpc[1])
+                stub = api_pb2_grpc.TtsStub(self._rpc[1])
                 stream: TtsUnaryStream = stub.Tts(request)
                 if context is not None:
                     context.assign(stream)
                 async for response in stream:
                     yield response.data
-            except grpc.RpcError as fallback_e:
-                raise fallback_e from e
+                break
+            except grpc.RpcError as e:
+                error_code = getattr(e, "code")()
+                logging.debug(f"Error: {error_code}")
+                if error_code not in {StatusCode.RESOURCE_EXHAUSTED, StatusCode.UNAVAILABLE}:
+                    raise
+
+                if attempt < max_attempts:
+                    # It's a poor customer experience to show internal details about retries, so we only debug log here.
+                    logging.debug(f"Retrying in {backoff * 1000} ms ({attempt} attempts so far)... ({error_code})")
+                    if backoff > 0:
+                        await asyncio.sleep(backoff)
+                    continue
+
+                if self._fallback_rpc is None:
+                    raise
+
+                # We log fallbacks to give customers an extra signal that they should scale up their on-prem appliance
+                # (e.g. by paying for more GPU quota)
+                logging.info(f"Falling back to {self._fallback_rpc[0]} because {self._rpc[0]} threw: {error_code}")
+                try:
+                    stub = api_pb2_grpc.TtsStub(self._fallback_rpc[1])
+                    stream: TtsUnaryStream = stub.Tts(request)
+                    if context is not None:
+                        context.assign(stream)
+                    async for response in stream:
+                        yield response.data
+                    break
+                except grpc.RpcError as fallback_e:
+                    raise fallback_e from e
 
     def get_stream_pair(
         self,
diff --git a/pyht/client.py b/pyht/client.py
index 87d8ac4..ca5a43a 100644
--- a/pyht/client.py
+++ b/pyht/client.py
@@ -1,5 +1,8 @@
 from __future__ import annotations
 
+import logging
+import time
+from enum import Enum
 from typing import Generator, Iterable, Iterator, List, Tuple
 
 from dataclasses import dataclass
@@ -59,6 +62,23 @@ def tts_params(self, text: list[str], voice_engine: str | None) -> api_pb2.TtsPa
         return params
 
 
+class CongestionCtrl(Enum):
+    """
+    Enumerates the streaming congestion control algorithms used to optimize the rate at which text is sent to PlayHT.
+    """
+
+    # The client will not do any congestion control.
+    OFF = 0
+
+    # The client will retry requests to the primary address up to two times with a 50ms backoff between attempts.
+    #
+    # Then it will fall back to the fallback address (if one is configured). No retry attempts will be made
+    # against the fallback address.
+    #
+    # If you're using PlayHT On-Prem, you should probably be using this congestion control algorithm.
+    STATIC_MAR_2023 = 1
+
+
 class Client:
     LEASE_DATA: bytes | None = None
     LEASE_CACHE_PATH: str = os.path.join(tempfile.gettempdir(), 'playht.temporary.lease')
@@ -72,6 +92,7 @@ class AdvancedOptions:
         fallback_enabled: bool = False
         auto_refresh_lease: bool = True
         disable_lease_disk_cache: bool = False
+        congestion_ctrl: CongestionCtrl = CongestionCtrl.OFF
 
     def __init__(
         self,
@@ -227,22 +248,47 @@ def tts(
         text = ensure_sentence_end(text)
 
         request = api_pb2.TtsRequest(params=options.tts_params(text, voice_engine), lease=lease_data)
-        try:
-            stub = api_pb2_grpc.TtsStub(self._rpc[1])
-            response = stub.Tts(request)  # type: Iterable[api_pb2.TtsResponse]
-            for item in response:
-                yield item.data
-        except grpc.RpcError as e:
-            error_code = getattr(e, "code")()
-            if error_code not in {StatusCode.RESOURCE_EXHAUSTED, StatusCode.UNAVAILABLE} or self._fallback_rpc is None:
-                raise
+
+        max_attempts = 1
+        backoff = 0
+        if self._advanced.congestion_ctrl == CongestionCtrl.STATIC_MAR_2023:
+            max_attempts = 3
+            backoff = 0.05
+
+        for attempt in range(1, max_attempts + 1):
             try:
-                stub = api_pb2_grpc.TtsStub(self._fallback_rpc[1])
+                stub = api_pb2_grpc.TtsStub(self._rpc[1])
                 response = stub.Tts(request)  # type: Iterable[api_pb2.TtsResponse]
                 for item in response:
                     yield item.data
-            except grpc.RpcError as fallback_e:
-                raise fallback_e from e
+                break
+            except grpc.RpcError as e:
+                error_code = getattr(e, "code")()
+                logging.debug(f"Error: {error_code}")
+                if error_code not in {StatusCode.RESOURCE_EXHAUSTED, StatusCode.UNAVAILABLE}:
+                    raise
+
+                if attempt < max_attempts:
+                    # It's a poor customer experience to show internal details about retries, so we only debug log here.
+                    logging.debug(f"Retrying in {backoff * 1000} ms ({attempt} attempts so far)... ({error_code})")
+                    if backoff > 0:
+                        time.sleep(backoff)
+                    continue
+
+                if self._fallback_rpc is None:
+                    raise
+
+                # We log fallbacks to give customers an extra signal that they should scale up their on-prem appliance
+                # (e.g. by paying for more GPU quota)
+                logging.info(f"Falling back to {self._fallback_rpc[0]} because {self._rpc[0]} threw: {error_code}")
+                try:
+                    stub = api_pb2_grpc.TtsStub(self._fallback_rpc[1])
+                    response = stub.Tts(request)  # type: Iterable[api_pb2.TtsResponse]
+                    for item in response:
+                        yield item.data
+                    break
+                except grpc.RpcError as fallback_e:
+                    raise fallback_e from e
 
     def get_stream_pair(
         self,