diff --git a/packages/playht/src/api/APISettingsStore.ts b/packages/playht/src/api/APISettingsStore.ts index 66d5fb1..a66a915 100644 --- a/packages/playht/src/api/APISettingsStore.ts +++ b/packages/playht/src/api/APISettingsStore.ts @@ -20,6 +20,7 @@ export class APISettingsStore { apiKey: settings.apiKey, customAddr: settings.customAddr, fallbackEnabled: settings.fallbackEnabled, + congestionCtrl: settings.congestionCtrl, }); APISettingsStore._instance = this; diff --git a/packages/playht/src/api/apiCommon.ts b/packages/playht/src/api/apiCommon.ts index 5c88966..f7d2617 100644 --- a/packages/playht/src/api/apiCommon.ts +++ b/packages/playht/src/api/apiCommon.ts @@ -1,13 +1,13 @@ import type { - SpeechOptions, - SpeechStreamOptions, - SpeechOutput, - OutputQuality, Emotion, - VoiceEngine, + OutputFormat, + OutputQuality, PlayHT10OutputStreamFormat, PlayHT20OutputStreamFormat, - OutputFormat, + SpeechOptions, + SpeechOutput, + SpeechStreamOptions, + VoiceEngine, } from '..'; import { PassThrough, Readable, Writable } from 'node:stream'; import { APISettingsStore } from './APISettingsStore'; @@ -17,6 +17,7 @@ import { generateV2Speech } from './generateV2Speech'; import { generateV2Stream } from './generateV2Stream'; import { textStreamToSentences } from './textStreamToSentences'; import { generateGRpcStream } from './generateGRpcStream'; +import { CongestionController } from './congestionCtrl'; export type V1ApiOptions = { narrationStyle?: string; @@ -198,8 +199,8 @@ async function audioStreamFromSentences( writableStream: NodeJS.WritableStream, options?: SpeechStreamOptions, ) { - // Create a stream for promises - const promiseStream = new Readable({ + // Create a stream of audio chunk promises -- each corresponding to a sentence + const audioChunkStream = new Readable({ objectMode: true, read() {}, }); @@ -226,23 +227,45 @@ async function audioStreamFromSentences( writableStream.end(); } + const congestionController = new 
CongestionController(APISettingsStore.getSettings().congestionCtrl ?? 'Off'); + // For each sentence in the stream, add a task to the queue + let sentenceIdx = 0; sentencesStream.on('data', async (data) => { const sentence = data.toString(); - const generatePromise = (async () => { - return await internalGenerateStreamFromString(sentence, options); - })(); - promiseStream.push(generatePromise); + /** + * NOTE: + * + * If the congestion control algorithm is set to "Off", + * then this {@link CongestionController#enqueue} method will invoke the task immediately; + * thereby generating the audio chunk for this sentence immediately. + * + * @see CongestionController + * @see CongestionCtrl + */ + congestionController.enqueue(() => { + const nextAudioChunk = (async () => { + return await internalGenerateStreamFromString(sentence, options); + })(); + audioChunkStream.push(nextAudioChunk); + }, `createAudioChunk#${sentenceIdx}`); + + sentenceIdx++; }); sentencesStream.on('end', async () => { - promiseStream.push(null); + /** + * NOTE: if the congestion control algorithm is set to "Off", then this enqueue method will simply invoke the task immediately. 
+ */ + congestionController.enqueue(() => { + audioChunkStream.push(null); + }, 'endAudioChunks'); }); sentencesStream.on('error', onError); - // Read from the promiseStream and await for each promise in order + // Await each audio chunk in order, and write the raw audio to the output audio stream const writeAudio = new Writable({ objectMode: true, write: async (generatePromise, _, callback) => { @@ -252,8 +275,29 @@ async function audioStreamFromSentences( onError(); return; } + const completion = { + headersRemaining: 0, + gotAudio: false, + }; + // NOTE: The cast below is to avoid a cyclic dependency warning from "yarn verify" + switch ((<{ outputFormat: string }>options).outputFormat) { + case 'wav': + completion.headersRemaining = 1; + break; + case 'mp3': + completion.headersRemaining = 1; + break; + default: + break; + } await new Promise((resolve) => { resultStream.on('data', (chunk: Buffer) => { + if (completion.headersRemaining > 0) { + completion.headersRemaining -= 1; + } else if (!completion.gotAudio) { + completion.gotAudio = true; + congestionController.audioRecvd(); + } writableStream.write(chunk); }); @@ -272,9 +316,9 @@ async function audioStreamFromSentences( writeAudio.on('error', onError); - promiseStream.on('error', onError); + audioChunkStream.on('error', onError); - promiseStream.on('end', () => { + audioChunkStream.on('end', () => { setTimeout( () => writeAudio.on('finish', () => { @@ -284,5 +328,5 @@ async function audioStreamFromSentences( ); }); - promiseStream.pipe(writeAudio); + audioChunkStream.pipe(writeAudio); } diff --git a/packages/playht/src/api/congestionCtrl.ts b/packages/playht/src/api/congestionCtrl.ts new file mode 100644 index 0000000..c553e41 --- /dev/null +++ b/packages/playht/src/api/congestionCtrl.ts @@ -0,0 +1,113 @@ +/** + * Enumerates a streaming congestion control algorithms, used to optimize the rate at which text is sent to PlayHT. 
+ */ +export type CongestionCtrl = + /** + * The client will not do any congestion control. Text will be sent to PlayHT as fast as possible. + */ + | 'Off' + + /** + * The client will optimize for minimizing the number of physical resources required to handle a single stream. + * + * If you're using PlayHT On-Prem, you should use this {@link CongestionCtrl} algorithm. + */ + | 'StaticMar2024'; + +/** + * Responsible for optimizing the rate at which text is sent to the underlying API endpoint, according to the + * specified {@link CongestionCtrl} algorithm. {@link CongestionController} is essentially a task queue + * that throttles the parallelism of, and delay between, task execution. + * + * The primary motivation for this (as of 2024/02/28) is to protect PlayHT On-Prem appliances + * from being inundated with a burst of text-to-speech requests that it can't satisfy. Prior to the introduction + * of {@link CongestionController} (and more generally {@link CongestionCtrl}), the client would split + * a text stream into two text chunks (referred to as "sentences") and send them to the API client (i.e. gRPC client) + * simultaneously. This would routinely overload on-prem appliances that operate without a lot of GPU capacity headroom[1]. + * + * The result would be that most requests that clients sent would immediately result in a gRPC error 8: RESOURCE_EXHAUSTED; + * and therefore, a bad customer experience. {@link CongestionController}, if configured with "StaticMar2024", + * will now delay sending subsequent text chunks (i.e. sentences) to the gRPC client until audio for the preceding text + * chunk has started streaming. + * + * The current {@link CongestionCtrl} algorithm ("StaticMar2024") is very simple and leaves a lot to + * be desired. We should iterate on these algorithms. The {@link CongestionCtrl} type was added so that algorithms + * can be added without requiring customers to change their code much. 
+ * + * [1] Customers tend to be very cost sensitive regarding expensive GPU capacity, and therefore want to keep their appliances + * running near 100% utilization. + * + * --mtp@2024/02/28 + * + * This class is largely inert if the specified {@link CongestionCtrl} is "Off". + */ +export class CongestionController { + algo: CongestionCtrl; + taskQ: Array = []; + inflight = 0; + parallelism: number; + postChunkBackoff: number; + + constructor(algo: CongestionCtrl) { + this.algo = algo; + switch (algo) { + case 'Off': + this.parallelism = Infinity; + this.postChunkBackoff = 0; + break; + case 'StaticMar2024': + this.parallelism = 1; + this.postChunkBackoff = 50; + break; + default: + throw new Error(`Unrecognized congestion control algorithm: ${algo}`); + } + } + + enqueue(task: () => void, name: string) { + // if congestion control is turned off - just execute the task immediately + if (this.algo == 'Off') { + task(); + return; + } + + this.taskQ.push(new Task(task, name)); + this.maybeDoMore(); + } + + private maybeDoMore() { + // if congestion control is turned off - there's nothing to do here because all tasks were executed immediately + if (this.algo == 'Off') return; + + while (this.inflight < this.parallelism && this.taskQ.length > 0) { + const task = this.taskQ.shift()!; + this.inflight++; + //console.debug(`[PlayHT SDK] Started congestion control task: ${task.name}. inflight=${this.inflight}`); + task.fn(); + } + } + + audioRecvd() { + // if congestion control is turned off - there's nothing to do here because all tasks were executed immediately + if (this.algo == 'Off') return; + + this.inflight = Math.max(this.inflight - 1, 0); + //console.debug('[PlayHT SDK] Congestion control received audio'); + setTimeout(() => { + this.maybeDoMore(); + }, this.postChunkBackoff); + } +} + +/** + * NOTE: + * + * {@link #name} is currently unused, but exists so that we can log task names during development. 
+ * Without {@link #name}, it's hard to understand which tasks were executed and in which order. + */ +class Task { + constructor( + public fn: () => void, + public name: string, + ) {} +} diff --git a/packages/playht/src/grpc-client/client.ts b/packages/playht/src/grpc-client/client.ts index afb89a4..584323a 100644 --- a/packages/playht/src/grpc-client/client.ts +++ b/packages/playht/src/grpc-client/client.ts @@ -1,9 +1,10 @@ -import { credentials, Client as GrpcClient } from '@grpc/grpc-js'; +import { Client as GrpcClient, credentials } from '@grpc/grpc-js'; import fetch from 'cross-fetch'; import apiProto from './protos/api'; import { Lease } from './lease'; import { ReadableStream } from './readable-stream'; import { TTSStreamSource } from './tts-stream-source'; +import { CongestionCtrl } from './congestion-ctrl'; export type TTSParams = apiProto.playht.v1.ITtsParams; export const Quality = apiProto.playht.v1.Quality; @@ -35,6 +36,11 @@ export interface ClientOptions { * (configured with "customAddr" above) to the standard PlayHT address. */ fallbackEnabled?: boolean; + + /** + * @see CongestionCtrl + */ + congestionCtrl?: CongestionCtrl; } const USE_INSECURE_CONNECTION = false; @@ -259,7 +265,8 @@ export class Client { rpcClient = isPremium ? this.premiumRpc!.client : this.rpc!.client; fallbackClient = undefined; } - const stream = new ReadableStream(new TTSStreamSource(request, rpcClient, fallbackClient)); + const congestionCtrl = this.options.congestionCtrl ?? 
'Off'; + const stream = new ReadableStream(new TTSStreamSource(request, rpcClient, fallbackClient, congestionCtrl)); // fix for TypeScript not DOM types not including Symbol.asyncIterator in ReadableStream return stream as unknown as AsyncIterable & ReadableStream; } } diff --git a/packages/playht/src/grpc-client/congestion-ctrl.ts b/packages/playht/src/grpc-client/congestion-ctrl.ts new file mode 100644 index 0000000..4f84684 --- /dev/null +++ b/packages/playht/src/grpc-client/congestion-ctrl.ts @@ -0,0 +1,15 @@ +/** + * Enumerates the streaming congestion control algorithms, used to optimize the rate at which text is sent to PlayHT. + */ +export type CongestionCtrl = + /** + * The client will not do any congestion control. Text will be sent to PlayHT as fast as possible. + */ + | 'Off' + + /** + * The client will optimize for minimizing the number of physical resources required to handle a single stream. + * + * If you're using PlayHT On-Prem, you should use this {@link CongestionCtrl} algorithm. 
+ */ + | 'StaticMar2024'; diff --git a/packages/playht/src/grpc-client/tts-stream-source.ts b/packages/playht/src/grpc-client/tts-stream-source.ts index 370fd4d..b69c3bd 100644 --- a/packages/playht/src/grpc-client/tts-stream-source.ts +++ b/packages/playht/src/grpc-client/tts-stream-source.ts @@ -1,16 +1,46 @@ import type * as grpc from '@grpc/grpc-js'; import * as apiProto from './protos/api'; +import { CongestionCtrl } from './congestion-ctrl'; export class TTSStreamSource implements UnderlyingByteSource { private stream?: grpc.ClientReadableStream; readonly type = 'bytes'; private retryable = true; + private retries = 0; + private maxRetries = 0; + private backoff = 0; constructor( private readonly request: apiProto.playht.v1.ITtsRequest, private readonly rpcClient: grpc.Client, private readonly fallbackClient?: grpc.Client, - ) {} + private readonly congestionCtrl?: CongestionCtrl, + ) { + if (congestionCtrl != undefined) { + switch (congestionCtrl) { + case 'Off': + this.maxRetries = 0; + this.backoff = 0; + break; + case 'StaticMar2024': + /** + * NOTE: + * + * The values below were experimentally chosen. + * + * The experiments were not rigorous and certainly leave a lot to be desired. We should tune them over time. + * We might end up with something dynamic inspired by additive-increase-multiplicative-decrease. + * + * --mtp@2024/02/28 + */ + this.maxRetries = 2; + this.backoff = 50; + break; + default: + throw new Error(`Unrecognized congestion control algorithm: ${congestionCtrl}`); + } + } + } start(controller: ReadableByteStreamController) { this.startAndMaybeFallback(controller, this.rpcClient, this.fallbackClient); @@ -67,14 +97,27 @@ export class TTSStreamSource implements UnderlyingByteSource { }); this.stream.on('error', (err) => { // if we get an error while this stream source is still retryable (i.e. 
we haven't started streaming data back and haven't canceled) - // then we can fallback if there is a fallback rpc client - if (this.retryable && fallbackClient) { - console.warn(`[PlayHT SDK] Falling back...`, fallbackClient.getChannel().getTarget(), err.message); - this.end(); - // start again with the fallback client and the primary client - // we won't specify a second order fallback client - so if this client fails, this stream will fail - this.startAndMaybeFallback(controller, fallbackClient, undefined); - return; + // then we can retry or fall back (if there is a fallback rpc client) + if (this.retryable) { + if (this.retries < this.maxRetries) { + this.end(); + this.retries++; + // NOTE: It's a poor customer experience to show internal details about retries -- so we don't log here. + //console.debug(`[PlayHT SDK] Retrying in ${this.backoff} ms ... (${this.retries} attempts so far)`, err.message); + // retry with the same primary and fallback client + setTimeout(() => { + this.startAndMaybeFallback(controller, client, fallbackClient); + }, this.backoff); + } else if (fallbackClient) { + // NOTE: We log fallbacks to give customers a signal that they should scale up their on-prem appliance (e.g. 
by paying for more GPU quota) + console.warn(`[PlayHT SDK] Falling back to ${fallbackClient.getChannel().getTarget()} ...`, err.message); + this.end(); + // start again with the fallback client and the primary client + // we won't specify a second order fallback client - so if this client fails, this stream will fail + // we also won't reset the number of retries - so we'll try at most once with the fallback client + this.startAndMaybeFallback(controller, fallbackClient, undefined); + return; + } } // if we reach here - we couldn't fallback and therefore this stream has failed diff --git a/packages/playht/src/index.ts b/packages/playht/src/index.ts index c05cbef..98ee22a 100644 --- a/packages/playht/src/index.ts +++ b/packages/playht/src/index.ts @@ -2,6 +2,7 @@ import { APISettingsStore } from './api/APISettingsStore'; import { commonGenerateSpeech, commonGenerateStream } from './api/apiCommon'; import { commonGetAllVoices } from './api/commonGetAllVoices'; import { commonInstantClone, internalDeleteClone } from './api/instantCloneInternal'; +import { CongestionCtrl } from './api/congestionCtrl'; /** * Type representing the various voice engines that can be used for speech synthesis. @@ -397,6 +398,16 @@ export type APISettingsInput = { * (configured with "customAddr" above) to the standard PlayHT address. */ fallbackEnabled?: boolean; + + /** + * If specified, the client will use the specified {@link CongestionCtrl} algorithm to optimize + * the rate at which it sends text to PlayHT. + * + * If you're using PlayHT On-Prem, you should set this to "StaticMar2024". + * + * @see CongestionCtrl + */ + congestionCtrl?: CongestionCtrl; }; /**