From 8cf0ec41051a40b16e93f63055157917a3a3355e Mon Sep 17 00:00:00 2001 From: Anurag Tomer Date: Tue, 13 Jan 2026 11:04:36 +0530 Subject: [PATCH 1/4] Adding ChatterBox changes --- WORKSPACE | 2 +- riva/clients/tts/riva_tts_client.cc | 7 +++++++ riva/clients/tts/riva_tts_perf_client.cc | 17 ++++++++++++++--- 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index a91c61f..5acf18c 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -77,7 +77,7 @@ grpc_extra_deps() git_repository( name = "nvriva_common", remote = "https://github.com/atomer-nvidia/common.git", - commit = "3085c065085d15a284b37847470fe0182c9a6c67" + commit = "84337dbb94e4dc1b5cb081f1988a365e67895cd0" ) http_archive( diff --git a/riva/clients/tts/riva_tts_client.cc b/riva/clients/tts/riva_tts_client.cc index 446fcc6..5df8919 100644 --- a/riva/clients/tts/riva_tts_client.cc +++ b/riva/clients/tts/riva_tts_client.cc @@ -55,6 +55,7 @@ DEFINE_string(zero_shot_transcript, "", "Transcript corresponding to Zero shot a DEFINE_uint64(timeout_ms, 10000, "Timeout for GRPC channel creation"); DEFINE_uint64(max_grpc_message_size, MAX_GRPC_MESSAGE_SIZE, "Max GRPC message size"); DEFINE_double(speed, 1.0, "Speed of generated audio, ranges between 0.5-2.0"); +DEFINE_double(exaggeration_factor, 1.0, "Exaggeration factor for generated audio, ranges between 0.0-2.0"); static const std::string LC_enUS = "en-US"; @@ -120,6 +121,7 @@ main(int argc, char** argv) str_usage << " --timeout_ms= " << std::endl; str_usage << " --max_grpc_message_size= " << std::endl; str_usage << " --speed= " << std::endl; + str_usage << " --exaggeration_factor= " << std::endl; gflags::SetUsageMessage(str_usage.str()); gflags::SetVersionString(::riva::utils::kBuildScmRevision); @@ -225,6 +227,11 @@ main(int argc, char** argv) if (not FLAGS_online and not FLAGS_zero_shot_transcript.empty()) { zero_shot_data->set_transcript(FLAGS_zero_shot_transcript); } + if (FLAGS_exaggeration_factor < 0.0 || FLAGS_exaggeration_factor > 2.0) { + LOG(ERROR) << "Exaggeration factor must be between 0.0 and 2.0" << std::endl; + return -1; + } + zero_shot_data->set_exaggeration_factor(FLAGS_exaggeration_factor); } // Send text content using Synthesize(). diff --git a/riva/clients/tts/riva_tts_perf_client.cc b/riva/clients/tts/riva_tts_perf_client.cc index 8125049..a1d71e1 100644 --- a/riva/clients/tts/riva_tts_perf_client.cc +++ b/riva/clients/tts/riva_tts_perf_client.cc @@ -64,6 +64,7 @@ DEFINE_int32(zero_shot_quality, 20, "Required quality of output audio, ranges be DEFINE_string(custom_dictionary, "", " User dictionary containing graph-to-phone custom words"); DEFINE_string(zero_shot_transcript, "", "Transcript corresponding to Zero shot audio prompt."); DEFINE_double(speed, 1.0, "Speed of generated audio, ranges between 0.5-2.0"); +DEFINE_double(exaggeration_factor, 1.0, "Exaggeration factor for generated audio, ranges between 0.0-2.0"); static const std::string LC_enUS = "en-US"; @@ -115,7 +116,7 @@ synthesizeBatch( std::unique_ptr tts, std::string text, std::string language, uint32_t rate, std::string voice_name, std::string filepath, std::string zero_shot_prompt_filename, int32_t zero_shot_quality, std::string custom_dictionary, - std::string zero_shot_transcript, double speed) + std::string zero_shot_transcript, double speed, double exaggeration_factor) { // Parse command line arguments. nr_tts::SynthesizeSpeechRequest request; @@ -168,6 +169,11 @@ synthesizeBatch( if (not FLAGS_zero_shot_transcript.empty()) { zero_shot_data->set_transcript(FLAGS_zero_shot_transcript); } + if (exaggeration_factor < 0.0 || exaggeration_factor > 2.0) { + LOG(ERROR) << "Exaggeration factor must be between 0.0 and 2.0" << std::endl; + return -1; + } + zero_shot_data->set_exaggeration_factor(exaggeration_factor); } // Send text content using Synthesize(). @@ -211,7 +217,7 @@ synthesizeOnline( std::unique_ptr tts, std::string text, std::string language, uint32_t rate, std::string voice_name, double* time_to_first_chunk, std::vector* time_to_next_chunk, size_t* num_samples, std::string filepath, - std::string zero_shot_prompt_filename, int32_t zero_shot_quality, double speed) + std::string zero_shot_prompt_filename, int32_t zero_shot_quality, double speed, double exaggeration_factor) { nr_tts::SynthesizeSpeechRequest request; request.set_text(text); @@ -260,6 +266,11 @@ synthesizeOnline( } zero_shot_data->set_sample_rate_hz(zero_shot_sample_rate); zero_shot_data->set_quality(zero_shot_quality); + if (exaggeration_factor < 0.0 || exaggeration_factor > 2.0) { + LOG(ERROR) << "Exaggeration factor must be between 0.0 and 2.0" << std::endl; + return; + } + zero_shot_data->set_exaggeration_factor(exaggeration_factor); } @@ -367,7 +378,7 @@ main(int argc, char** argv) str_usage << " --zero_shot_transcript=" << std::endl; str_usage << " --custom_dictionary= " << std::endl; str_usage << " --speed= " << std::endl; - + str_usage << " --exaggeration_factor= " << std::endl; gflags::SetUsageMessage(str_usage.str()); gflags::SetVersionString(::riva::utils::kBuildScmRevision); From c2b418f6e0696e863560d767ee33e5b288cc2f5b Mon Sep 17 00:00:00 2001 From: Anurag Tomer Date: Wed, 14 Jan 2026 13:43:36 +0530 Subject: [PATCH 2/4] Adding ChatterBox changes --- WORKSPACE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/WORKSPACE b/WORKSPACE index 5acf18c..4ff14ab 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -77,7 +77,7 @@ grpc_extra_deps() git_repository( name = "nvriva_common", remote = "https://github.com/atomer-nvidia/common.git", - commit = "84337dbb94e4dc1b5cb081f1988a365e67895cd0" + commit = "60e67e8ba30eac99d8cfb30275b03b76b6562a29" ) http_archive( From 4047ee5b2afbf805c4f8d685a886c7c31408bfd9 Mon Sep 17 00:00:00 2001 From: Anurag Tomer Date: Wed, 14 Jan 2026 14:26:47 +0530 Subject: [PATCH 3/4] Removing speed --- riva/clients/tts/riva_tts_client.cc | 7 ------- riva/clients/tts/riva_tts_perf_client.cc | 20 ++++---------------- 2 files changed, 4 insertions(+), 23 deletions(-) diff --git a/riva/clients/tts/riva_tts_client.cc b/riva/clients/tts/riva_tts_client.cc index 5df8919..a9ef3e9 100644 --- a/riva/clients/tts/riva_tts_client.cc +++ b/riva/clients/tts/riva_tts_client.cc @@ -54,7 +54,6 @@ DEFINE_string(custom_dictionary, "", " User dictionary containing graph-to-phone DEFINE_string(zero_shot_transcript, "", "Transcript corresponding to Zero shot audio prompt."); DEFINE_uint64(timeout_ms, 10000, "Timeout for GRPC channel creation"); DEFINE_uint64(max_grpc_message_size, MAX_GRPC_MESSAGE_SIZE, "Max GRPC message size"); -DEFINE_double(speed, 1.0, "Speed of generated audio, ranges between 0.5-2.0"); DEFINE_double(exaggeration_factor, 1.0, "Exaggeration factor for generated audio, ranges between 0.0-2.0"); static const std::string LC_enUS = "en-US"; @@ -120,7 +119,6 @@ main(int argc, char** argv) str_usage << " --custom_dictionary= " << std::endl; str_usage << " --timeout_ms= " << std::endl; str_usage << " --max_grpc_message_size= " << std::endl; - str_usage << " --speed= " << std::endl; str_usage << " --exaggeration_factor= " << std::endl; gflags::SetUsageMessage(str_usage.str()); gflags::SetVersionString(::riva::utils::kBuildScmRevision); @@ -190,11 +188,6 @@ main(int argc, char** argv) request.set_sample_rate_hz(rate); request.set_voice_name(FLAGS_voice_name); - if (FLAGS_speed < 0.5 || FLAGS_speed > 2.0) { - LOG(ERROR) << "Speed must be between 0.5 and 2.0" << std::endl; - return -1; - } - request.set_speed(FLAGS_speed); if (not FLAGS_zero_shot_audio_prompt.empty()) { auto zero_shot_data = request.mutable_zero_shot_data(); std::vector> audio_prompt; diff --git a/riva/clients/tts/riva_tts_perf_client.cc b/riva/clients/tts/riva_tts_perf_client.cc index a1d71e1..65a6a3a 100644 --- a/riva/clients/tts/riva_tts_perf_client.cc +++ b/riva/clients/tts/riva_tts_perf_client.cc @@ -63,7 +63,6 @@ DEFINE_string( DEFINE_int32(zero_shot_quality, 20, "Required quality of output audio, ranges between 1-40."); DEFINE_string(custom_dictionary, "", " User dictionary containing graph-to-phone custom words"); DEFINE_string(zero_shot_transcript, "", "Transcript corresponding to Zero shot audio prompt."); -DEFINE_double(speed, 1.0, "Speed of generated audio, ranges between 0.5-2.0"); DEFINE_double(exaggeration_factor, 1.0, "Exaggeration factor for generated audio, ranges between 0.0-2.0"); static const std::string LC_enUS = "en-US"; @@ -116,7 +115,7 @@ synthesizeBatch( std::unique_ptr tts, std::string text, std::string language, uint32_t rate, std::string voice_name, std::string filepath, std::string zero_shot_prompt_filename, int32_t zero_shot_quality, std::string custom_dictionary, - std::string zero_shot_transcript, double speed, double exaggeration_factor) + std::string zero_shot_transcript, double exaggeration_factor) { // Parse command line arguments. nr_tts::SynthesizeSpeechRequest request; @@ -124,11 +123,6 @@ synthesizeBatch( request.set_language_code(language); request.set_sample_rate_hz(rate); request.set_voice_name(voice_name); - if (speed < 0.5 || speed > 2.0) { - LOG(ERROR) << "Speed must be between 0.5 and 2.0" << std::endl; - return -1; - } - request.set_speed(speed); if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") { request.set_encoding(nr::LINEAR_PCM); } else if (FLAGS_audio_encoding == "opus") { @@ -217,18 +211,13 @@ synthesizeOnline( std::unique_ptr tts, std::string text, std::string language, uint32_t rate, std::string voice_name, double* time_to_first_chunk, std::vector* time_to_next_chunk, size_t* num_samples, std::string filepath, - std::string zero_shot_prompt_filename, int32_t zero_shot_quality, double speed, double exaggeration_factor) + std::string zero_shot_prompt_filename, int32_t zero_shot_quality, double exaggeration_factor) { nr_tts::SynthesizeSpeechRequest request; request.set_text(text); request.set_language_code(language); request.set_sample_rate_hz(rate); request.set_voice_name(voice_name); - if (speed < 0.5 || speed > 2.0) { - LOG(ERROR) << "Speed must be between 0.5 and 2.0" << std::endl; - return; - } - request.set_speed(speed); auto ae = nr::AudioEncoding::ENCODING_UNSPECIFIED; if (FLAGS_audio_encoding.empty() || FLAGS_audio_encoding == "pcm") { ae = nr::LINEAR_PCM; @@ -377,7 +366,6 @@ main(int argc, char** argv) str_usage << " --zero_shot_quality=" << std::endl; str_usage << " --zero_shot_transcript=" << std::endl; str_usage << " --custom_dictionary= " << std::endl; - str_usage << " --speed= " << std::endl; str_usage << " --exaggeration_factor= " << std::endl; gflags::SetUsageMessage(str_usage.str()); gflags::SetVersionString(::riva::utils::kBuildScmRevision); @@ -505,7 +493,7 @@ main(int argc, char** argv) std::move(tts), sentences[i][s].second, FLAGS_language, rate, FLAGS_voice_name, &time_to_first_chunk, time_to_next_chunk, &num_samples, std::to_string(sentences[i][s].first) + ".wav", FLAGS_zero_shot_audio_prompt, - FLAGS_zero_shot_quality, FLAGS_speed); + FLAGS_zero_shot_quality); latencies_first_chunk[i]->push_back(time_to_first_chunk); latencies_next_chunks[i]->insert( latencies_next_chunks[i]->end(), time_to_next_chunk->begin(), @@ -581,7 +569,7 @@ main(int argc, char** argv) int32_t num_samples = synthesizeBatch( std::move(tts), sentences[i][s].second, FLAGS_language, rate, FLAGS_voice_name, std::to_string(sentences[i][s].first) + ".wav", FLAGS_zero_shot_audio_prompt, - FLAGS_zero_shot_quality, FLAGS_custom_dictionary, FLAGS_zero_shot_transcript, FLAGS_speed); + FLAGS_zero_shot_quality, FLAGS_custom_dictionary, FLAGS_zero_shot_transcript); results_num_samples[i]->push_back(num_samples); } })); From 8f4e8b41e6db185ee78ff25fbe216e62668a5f9a Mon Sep 17 00:00:00 2001 From: Anurag Tomer Date: Wed, 14 Jan 2026 18:54:43 +0530 Subject: [PATCH 4/4] fix: missing parameters --- riva/clients/tts/riva_tts_perf_client.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/riva/clients/tts/riva_tts_perf_client.cc b/riva/clients/tts/riva_tts_perf_client.cc index 65a6a3a..adf867e 100644 --- a/riva/clients/tts/riva_tts_perf_client.cc +++ b/riva/clients/tts/riva_tts_perf_client.cc @@ -63,7 +63,8 @@ DEFINE_string( DEFINE_int32(zero_shot_quality, 20, "Required quality of output audio, ranges between 1-40."); DEFINE_string(custom_dictionary, "", " User dictionary containing graph-to-phone custom words"); DEFINE_string(zero_shot_transcript, "", "Transcript corresponding to Zero shot audio prompt."); -DEFINE_double(exaggeration_factor, 1.0, "Exaggeration factor for generated audio, ranges between 0.0-2.0"); +DEFINE_double( + exaggeration_factor, 1.0, "Exaggeration factor for generated audio, ranges between 0.0-2.0"); static const std::string LC_enUS = "en-US"; @@ -493,7 +494,7 @@ main(int argc, char** argv) std::move(tts), sentences[i][s].second, FLAGS_language, rate, FLAGS_voice_name, &time_to_first_chunk, time_to_next_chunk, &num_samples, std::to_string(sentences[i][s].first) + ".wav", FLAGS_zero_shot_audio_prompt, - FLAGS_zero_shot_quality); + FLAGS_zero_shot_quality, FLAGS_exaggeration_factor); latencies_first_chunk[i]->push_back(time_to_first_chunk); latencies_next_chunks[i]->insert( latencies_next_chunks[i]->end(), time_to_next_chunk->begin(), @@ -569,7 +570,8 @@ main(int argc, char** argv) int32_t num_samples = synthesizeBatch( std::move(tts), sentences[i][s].second, FLAGS_language, rate, FLAGS_voice_name, std::to_string(sentences[i][s].first) + ".wav", FLAGS_zero_shot_audio_prompt, - FLAGS_zero_shot_quality, FLAGS_custom_dictionary, FLAGS_zero_shot_transcript); + FLAGS_zero_shot_quality, FLAGS_custom_dictionary, FLAGS_zero_shot_transcript, + FLAGS_exaggeration_factor); results_num_samples[i]->push_back(num_samples); } }));