From 46101eca7953615c2e31dee84f18e6cb4beafc04 Mon Sep 17 00:00:00 2001 From: mtenpow <3721+mtenpow@users.noreply.github.com> Date: Thu, 29 Feb 2024 13:16:31 -0800 Subject: [PATCH 1/5] Introduce configurable congestion control. The primary motivation for this (as of 2024/02/28) is increase availability for customers using PlayHT On-Prem appliance by adding quick retries in response to RESOURCE_EXHAUSTED errors. This change allows customers to turn on one of an enumerated set of congestion control algorithms. We've implemented just one for now, STATIC_MAR_2024, which retries at most twice with a 50ms backoff between attempts. This is a dead simple congestion control algorithm with static constants; it leaves a lot to be desired. We should iterate on these algorithms in the future. The CongestionCtrl enum was added so that algorithms can be added without requiring customers to change their code much. --- pyht/async_client.py | 56 ++++++++++++++++++++++++++---------- pyht/client.py | 67 ++++++++++++++++++++++++++++++++++++-------- 2 files changed, 96 insertions(+), 27 deletions(-) diff --git a/pyht/async_client.py b/pyht/async_client.py index 0b1adb1..c43d236 100644 --- a/pyht/async_client.py +++ b/pyht/async_client.py @@ -1,4 +1,6 @@ from __future__ import annotations + +import logging from typing import Any, AsyncGenerator, AsyncIterable, AsyncIterator, Coroutine import asyncio @@ -15,7 +17,7 @@ from grpc import ssl_channel_credentials, StatusCode -from .client import TTSOptions +from .client import TTSOptions, CongestionCtrl from .lease import Lease, LeaseFactory from .protos import api_pb2, api_pb2_grpc from .utils import ensure_sentence_end, normalize, split_text, SENTENCE_END_REGEX @@ -51,6 +53,7 @@ class AdvancedOptions: fallback_enabled: bool = False auto_refresh_lease: bool = True disable_lease_disk_cache: bool = False + congestion_ctrl: CongestionCtrl = CongestionCtrl.OFF def __init__( self, @@ -207,26 +210,49 @@ async def tts( text = ensure_sentence_end(text) request = api_pb2.TtsRequest(params=options.tts_params(text, voice_engine), lease=lease_data) - try: - stub = api_pb2_grpc.TtsStub(self._rpc[1]) - stream: TtsUnaryStream = stub.Tts(request) - if context is not None: - context.assign(stream) - async for response in stream: - yield response.data - except grpc.RpcError as e: - error_code = getattr(e, "code")() - if error_code not in {StatusCode.RESOURCE_EXHAUSTED, StatusCode.UNAVAILABLE} or self._fallback_rpc is None: - raise + + retries = 0 + max_retries = 0 + backoff = 0 + if self._advanced.congestion_ctrl == CongestionCtrl.STATIC_MAR_2023: + max_retries = 2 + backoff = 0.05 + + while True: try: - stub = api_pb2_grpc.TtsStub(self._fallback_rpc[1]) + stub = api_pb2_grpc.TtsStub(self._rpc[1]) stream: TtsUnaryStream = stub.Tts(request) if context is not None: context.assign(stream) async for response in stream: yield response.data - except grpc.RpcError as fallback_e: - raise fallback_e from e + except grpc.RpcError as e: + error_code = getattr(e, "code")() + logging.debug(f"Error: {error_code}") + if error_code not in {StatusCode.RESOURCE_EXHAUSTED, StatusCode.UNAVAILABLE}: + raise + + if retries < max_retries: + retries += 1 + logging.debug(f"Retrying in {backoff} ms ({retries} attempts so far)... ({error_code})") + if backoff > 0: + await asyncio.sleep(backoff) + continue + + if self._fallback_rpc is None: + raise + + logging.info(f"Falling back to {self._fallback_rpc[0]}... ({error_code})") + try: + stub = api_pb2_grpc.TtsStub(self._fallback_rpc[1]) + stream: TtsUnaryStream = stub.Tts(request) + if context is not None: + context.assign(stream) + async for response in stream: + yield response.data + break + except grpc.RpcError as fallback_e: + raise fallback_e from e def get_stream_pair( self, diff --git a/pyht/client.py b/pyht/client.py index 87d8ac4..129ad3c 100644 --- a/pyht/client.py +++ b/pyht/client.py @@ -1,5 +1,8 @@ from __future__ import annotations +import logging +import time +from enum import Enum from typing import Generator, Iterable, Iterator, List, Tuple from dataclasses import dataclass @@ -59,6 +62,19 @@ def tts_params(self, text: list[str], voice_engine: str | None) -> api_pb2.TtsPa return params +class CongestionCtrl(Enum): + """ + Enumerates a streaming congestion control algorithms, used to optimize the rate at which text is sent to PlayHT. + """ + + # The client will not do any congestion control. Text will be sent to PlayHT as fast as possible. + OFF = 0 + + # The client will optimize for minimizing the number of physical resources required to handle a single stream. + # If you're using PlayHT On-Prem, you should use this {@link CongestionCtrl} algorithm. + STATIC_MAR_2023 = 1 + + class Client: LEASE_DATA: bytes | None = None LEASE_CACHE_PATH: str = os.path.join(tempfile.gettempdir(), 'playht.temporary.lease') @@ -72,6 +88,7 @@ class AdvancedOptions: fallback_enabled: bool = False auto_refresh_lease: bool = True disable_lease_disk_cache: bool = False + congestion_ctrl: CongestionCtrl = CongestionCtrl.OFF def __init__( self, @@ -227,22 +244,48 @@ def tts( text = ensure_sentence_end(text) request = api_pb2.TtsRequest(params=options.tts_params(text, voice_engine), lease=lease_data) - try: - stub = api_pb2_grpc.TtsStub(self._rpc[1]) - response = stub.Tts(request) # type: Iterable[api_pb2.TtsResponse] - for item in response: - yield item.data - except grpc.RpcError as e: - error_code = getattr(e, "code")() - if error_code not in {StatusCode.RESOURCE_EXHAUSTED, StatusCode.UNAVAILABLE} or self._fallback_rpc is None: - raise + + retries = 0 + max_retries = 0 + backoff = 0 + if self._advanced.congestion_ctrl == CongestionCtrl.STATIC_MAR_2023: + max_retries = 2 + backoff = 0.05 + + while True: try: - stub = api_pb2_grpc.TtsStub(self._fallback_rpc[1]) + stub = api_pb2_grpc.TtsStub(self._rpc[1]) response = stub.Tts(request) # type: Iterable[api_pb2.TtsResponse] for item in response: yield item.data - except grpc.RpcError as fallback_e: - raise fallback_e from e + break + except grpc.RpcError as e: + error_code = getattr(e, "code")() + logging.debug(f"Error: {error_code}") + if error_code not in {StatusCode.RESOURCE_EXHAUSTED, StatusCode.UNAVAILABLE}: + raise + + if retries < max_retries: + retries += 1 + logging.debug(f"Retrying in {backoff} ms ({retries} attempts so far)... ({error_code})") + print(f"Retrying in {backoff} ms ({retries} attempts so far)... ({error_code})") + if backoff > 0: + time.sleep(backoff) + continue + + if self._fallback_rpc is None: + raise + + logging.info(f"Falling back to {self._fallback_rpc[0]}... ({error_code})") + print(f"Falling back to {self._fallback_rpc[0]}... ({error_code})") + try: + stub = api_pb2_grpc.TtsStub(self._fallback_rpc[1]) + response = stub.Tts(request) # type: Iterable[api_pb2.TtsResponse] + for item in response: + yield item.data + break + except grpc.RpcError as fallback_e: + raise fallback_e from e def get_stream_pair( self, From b56f1bc5d2b053fdd442fa59ef4f173e1e98a62d Mon Sep 17 00:00:00 2001 From: mtenpow <3721+mtenpow@users.noreply.github.com> Date: Thu, 29 Feb 2024 13:20:38 -0800 Subject: [PATCH 2/5] Improve docs. --- pyht/client.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/pyht/client.py b/pyht/client.py index 129ad3c..9262984 100644 --- a/pyht/client.py +++ b/pyht/client.py @@ -67,11 +67,15 @@ class CongestionCtrl(Enum): Enumerates a streaming congestion control algorithms, used to optimize the rate at which text is sent to PlayHT. """ - # The client will not do any congestion control. Text will be sent to PlayHT as fast as possible. + # The client will not do any congestion control. OFF = 0 - # The client will optimize for minimizing the number of physical resources required to handle a single stream. - # If you're using PlayHT On-Prem, you should use this {@link CongestionCtrl} algorithm. + # The client will retry requests to the primary address up to two times with a 50ms backoff between attempts. + # + # Then it will fall back to the fallback address (if one is configured). No retry attempts will be made + # against the fallback address. + # + # If you're using PlayHT On-Prem, you should probably be using this congestion control algorithm. STATIC_MAR_2023 = 1 From 266f296ae0b065d2583060c0e1fda6bd3af25f8b Mon Sep 17 00:00:00 2001 From: mtenpow <3721+mtenpow@users.noreply.github.com> Date: Thu, 29 Feb 2024 13:22:18 -0800 Subject: [PATCH 3/5] Improve docs. --- pyht/client.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pyht/client.py b/pyht/client.py index 9262984..d8ae4ad 100644 --- a/pyht/client.py +++ b/pyht/client.py @@ -271,8 +271,8 @@ def tts( if retries < max_retries: retries += 1 + # It's a poor customer experience to show internal details about retries, so we only debug log here. logging.debug(f"Retrying in {backoff} ms ({retries} attempts so far)... ({error_code})") - print(f"Retrying in {backoff} ms ({retries} attempts so far)... ({error_code})") if backoff > 0: time.sleep(backoff) continue @@ -280,8 +280,9 @@ def tts( if self._fallback_rpc is None: raise + # We log fallbacks to give customers an extra signal that they should scale up their on-prem appliance + # (e.g. by paying for more GPU quota) logging.info(f"Falling back to {self._fallback_rpc[0]}... ({error_code})") - print(f"Falling back to {self._fallback_rpc[0]}... ({error_code})") try: stub = api_pb2_grpc.TtsStub(self._fallback_rpc[1]) response = stub.Tts(request) # type: Iterable[api_pb2.TtsResponse] From 2a1d55ebf72d00a345c88bcdb07c0034dc2612ac Mon Sep 17 00:00:00 2001 From: mtenpow <3721+mtenpow@users.noreply.github.com> Date: Thu, 29 Feb 2024 16:08:30 -0800 Subject: [PATCH 4/5] Address code review comments --- pyht/async_client.py | 14 ++++++-------- pyht/client.py | 14 ++++++-------- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/pyht/async_client.py b/pyht/async_client.py index c43d236..0b954d5 100644 --- a/pyht/async_client.py +++ b/pyht/async_client.py @@ -211,14 +211,13 @@ async def tts( request = api_pb2.TtsRequest(params=options.tts_params(text, voice_engine), lease=lease_data) - retries = 0 - max_retries = 0 + max_attempts = 1 backoff = 0 if self._advanced.congestion_ctrl == CongestionCtrl.STATIC_MAR_2023: - max_retries = 2 + max_attempts = 3 backoff = 0.05 - while True: + for attempt in range(1, max_attempts + 1): try: stub = api_pb2_grpc.TtsStub(self._rpc[1]) stream: TtsUnaryStream = stub.Tts(request) @@ -232,9 +231,8 @@ async def tts( if error_code not in {StatusCode.RESOURCE_EXHAUSTED, StatusCode.UNAVAILABLE}: raise - if retries < max_retries: - retries += 1 - logging.debug(f"Retrying in {backoff} ms ({retries} attempts so far)... ({error_code})") + if attempt < max_attempts: + logging.debug(f"Retrying in {backoff * 1000} sec ({attempt} attempts so far)... ({error_code})") if backoff > 0: await asyncio.sleep(backoff) continue @@ -242,7 +240,7 @@ async def tts( if self._fallback_rpc is None: raise - logging.info(f"Falling back to {self._fallback_rpc[0]}... ({error_code})") + logging.info(f"Falling back to {self._fallback_rpc[0]} because {self._rpc[0]} threw: {error_code}") try: stub = api_pb2_grpc.TtsStub(self._fallback_rpc[1]) stream: TtsUnaryStream = stub.Tts(request) diff --git a/pyht/client.py b/pyht/client.py index d8ae4ad..ca5a43a 100644 --- a/pyht/client.py +++ b/pyht/client.py @@ -249,14 +249,13 @@ def tts( request = api_pb2.TtsRequest(params=options.tts_params(text, voice_engine), lease=lease_data) - retries = 0 - max_retries = 0 + max_attempts = 1 backoff = 0 if self._advanced.congestion_ctrl == CongestionCtrl.STATIC_MAR_2023: - max_retries = 2 + max_attempts = 3 backoff = 0.05 - while True: + for attempt in range(1, max_attempts + 1): try: stub = api_pb2_grpc.TtsStub(self._rpc[1]) response = stub.Tts(request) # type: Iterable[api_pb2.TtsResponse] @@ -269,10 +268,9 @@ def tts( if error_code not in {StatusCode.RESOURCE_EXHAUSTED, StatusCode.UNAVAILABLE}: raise - if retries < max_retries: - retries += 1 + if attempt < max_attempts: # It's a poor customer experience to show internal details about retries, so we only debug log here. - logging.debug(f"Retrying in {backoff} ms ({retries} attempts so far)... ({error_code})") + logging.debug(f"Retrying in {backoff * 1000} ms ({attempt} attempts so far)... ({error_code})") if backoff > 0: time.sleep(backoff) continue @@ -282,7 +280,7 @@ def tts( # We log fallbacks to give customers an extra signal that they should scale up their on-prem appliance # (e.g. by paying for more GPU quota) - logging.info(f"Falling back to {self._fallback_rpc[0]}... ({error_code})") + logging.info(f"Falling back to {self._fallback_rpc[0]} because {self._rpc[0]} threw: {error_code}") try: stub = api_pb2_grpc.TtsStub(self._fallback_rpc[1]) response = stub.Tts(request) # type: Iterable[api_pb2.TtsResponse] From 5929415820f788eac92c399791bf7ce2d44f7e72 Mon Sep 17 00:00:00 2001 From: mtenpow <3721+mtenpow@users.noreply.github.com> Date: Thu, 29 Feb 2024 16:09:54 -0800 Subject: [PATCH 5/5] Fix loop breaking --- pyht/async_client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyht/async_client.py b/pyht/async_client.py index 0b954d5..6374325 100644 --- a/pyht/async_client.py +++ b/pyht/async_client.py @@ -225,6 +225,7 @@ async def tts( context.assign(stream) async for response in stream: yield response.data + break except grpc.RpcError as e: error_code = getattr(e, "code")() logging.debug(f"Error: {error_code}")