From cbe8cd33d8712c12cd9a378a76a3c39422e1b018 Mon Sep 17 00:00:00 2001
From: Emmanuel Schmidbauer
Date: Tue, 22 Jul 2025 12:56:06 -0400
Subject: [PATCH] package

---
 .gitignore                                    |   3 +
 pyproject.toml                                |  54 ++++
 src/dmospeech2/__init__.py                    |   0
 src/{ => dmospeech2}/ctcmodel.py              |  77 ++---
 src/{ => dmospeech2}/demo.ipynb               |   0
 .../discriminator_conformer.py                | 101 +++---
 src/{ => dmospeech2}/dmd_trainer.py           | 147 ++++-----
 src/{ => dmospeech2}/duration_predictor.py    |  31 +-
 src/{ => dmospeech2}/duration_trainer.py      |  92 +++---
 .../duration_trainer_with_prompt.py           |  75 ++---
 src/{ => dmospeech2}/ecapa_tdnn.py            |  38 ++-
 src/{ => dmospeech2}/grpo_duration_trainer.py | 239 +++++++-------
 src/{ => dmospeech2}/guidance_model.py        | 306 +++++++++---------
 src/{ => dmospeech2}/infer.py                 | 214 ++++++------
 src/{ => dmospeech2}/unimodel.py              | 139 ++++----
 src/f5_tts/__init__.py                        |   0
 src/f5_tts/api.py                             |  13 +-
 .../data}/Emilia_ZH_EN_pinyin/vocab.txt       |   0
 ...brispeech_pc_test_clean_cross_sentence.lst |   0
 src/f5_tts/eval/ecapa_tdnn.py                 |   1 -
 src/f5_tts/eval/eval_infer_batch.py           |  10 +-
 .../eval/eval_librispeech_test_clean.py       |   2 -
 src/f5_tts/eval/eval_seedtts_testset.py       |   2 -
 src/f5_tts/infer/infer_cli.py                 |  23 +-
 src/f5_tts/infer/infer_gradio.py              |  15 +-
 src/f5_tts/infer/speech_edit.py               |   5 +-
 src/f5_tts/infer/utils_infer.py               |   3 -
 src/f5_tts/model/__init__.py                  |   7 +-
 src/f5_tts/model/backbones/dit.py             |  17 +-
 src/f5_tts/model/backbones/mmdit.py           |  28 +-
 src/f5_tts/model/backbones/unett.py           |  19 +-
 src/f5_tts/model/cfm.py                       |  10 +-
 src/f5_tts/model/dataset.py                   |   4 +-
 src/f5_tts/model/modules.py                   |   1 -
 src/f5_tts/model/trainer.py                   |   3 +-
 src/f5_tts/model/utils.py                     |  15 +-
 src/f5_tts/model_new/__init__.py              |   1 -
 src/f5_tts/model_new/backbones/dit.py         |  14 +-
 src/f5_tts/model_new/backbones/mmdit.py       |  13 +-
 src/f5_tts/model_new/backbones/unett.py       |  16 +-
 src/f5_tts/model_new/cfm.py                   |  12 +-
 src/f5_tts/model_new/modules.py               |   3 +-
 src/f5_tts/model_new/trainer.py               |   4 +-
 src/f5_tts/model_new/utils.py                 |   1 -
 src/f5_tts/runtime/triton_trtllm/benchmark.py |   1 -
 .../runtime/triton_trtllm/patch/__init__.py   |  18 +-
 .../triton_trtllm/patch/f5tts/model.py        |   4 +-
 .../triton_trtllm/patch/f5tts/modules.py      |  26 +-
 .../triton_trtllm/scripts/conv_stft.py        |   1 -
 .../scripts/export_vocoder_to_onnx.py         |   1 -
 src/f5_tts/scripts/count_params_gflops.py     |   2 -
 src/f5_tts/socket_client.py                   |   1 -
 src/f5_tts/socket_server.py                   |  11 +-
 src/f5_tts/train/datasets/prepare_csv_wavs.py |   2 -
 src/f5_tts/train/datasets/prepare_emilia.py   |   2 -
 .../train/datasets/prepare_emilia_v2.py       |   1 -
 src/f5_tts/train/datasets/prepare_libritts.py |   1 -
 src/f5_tts/train/datasets/prepare_ljspeech.py |   1 -
 .../train/datasets/prepare_wenetspeech4tts.py |   1 -
 src/f5_tts/train/finetune_cli.py              |   1 -
 src/f5_tts/train/finetune_gradio.py           |   1 -
 src/f5_tts/train/train.py                     |   1 -
 62 files changed, 825 insertions(+), 1009 deletions(-)
 create mode 100644 .gitignore
 create mode 100644 pyproject.toml
 create mode 100644 src/dmospeech2/__init__.py
 rename src/{ => dmospeech2}/ctcmodel.py (85%)
 rename src/{ => dmospeech2}/demo.ipynb (100%)
 rename src/{ => dmospeech2}/discriminator_conformer.py (78%)
 rename src/{ => dmospeech2}/dmd_trainer.py (90%)
 rename src/{ => dmospeech2}/duration_predictor.py (85%)
 rename src/{ => dmospeech2}/duration_trainer.py (92%)
 rename src/{ => dmospeech2}/duration_trainer_with_prompt.py (92%)
 rename src/{ => dmospeech2}/ecapa_tdnn.py (94%)
 rename src/{ => dmospeech2}/grpo_duration_trainer.py (90%)
 rename src/{ => dmospeech2}/guidance_model.py (85%)
 rename src/{ => dmospeech2}/infer.py (91%)
 rename src/{ => dmospeech2}/unimodel.py (87%)
 create mode 100644 src/f5_tts/__init__.py
 rename {data => src/f5_tts/data}/Emilia_ZH_EN_pinyin/vocab.txt (100%)
 rename {data => src/f5_tts/data}/librispeech_pc_test_clean_cross_sentence.lst (100%)

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..61f7eca
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+notes.txt
+.venv
+venv
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..ee1abb1
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,54 @@
+[build-system]
+requires = ["setuptools>=77.0.3", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "dmospeech2"
+version = "0.1.0"
+description = "DMOSpeech 2 - Reinforcement learning for duration prediction in speech synthesis"
+readme = "README.md"
+requires-python = ">=3.9"
+license = {file = "LICENSE"}
+authors = [{name = "Yinghao Aaron Li", email = "71044569+yl4579@users.noreply.github.com"}]
+dependencies = [
+    "accelerate>=0.33.0",
+    "bitsandbytes>0.37.0",
+    "cached_path",
+    "click",
+    "datasets",
+    "ema_pytorch>=0.5.2",
+    "gradio>=3.45.2",
+    "hydra-core>=1.3.0",
+    "jieba",
+    "librosa",
+    "matplotlib",
+    "numpy<=1.26.4",
+    "pydantic<=2.10.6",
+    "pydub",
+    "pypinyin",
+    "safetensors",
+    "soundfile",
+    "tomli",
+    "torch>=2.0.0",
+    "torchaudio>=2.0.0",
+    "torchdiffeq",
+    "tqdm>=4.65.0",
+    "transformers",
+    "transformers_stream_generator",
+    "unidecode",
+    "vocos",
+    "wandb",
+    "x_transformers>=1.31.14",
+]
+
+[tool.setuptools]
+include-package-data = true
+
+[tool.setuptools.packages.find]
+where = ["src"]
+
+[tool.setuptools.package-data]
+"f5_tts" = [
+    "data/Emilia_ZH_EN_pinyin/vocab.txt",
+    "data/librispeech_pc_test_clean_cross_sentence.lst",
+]
diff --git a/src/dmospeech2/__init__.py b/src/dmospeech2/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/ctcmodel.py b/src/dmospeech2/ctcmodel.py
similarity index 85%
rename from src/ctcmodel.py
rename to src/dmospeech2/ctcmodel.py
index 0bc89d8..508cf3f 100644
--- a/src/ctcmodel.py
+++ b/src/dmospeech2/ctcmodel.py
@@ -1,27 +1,13 @@
-from torch import nn
-import torch
 import copy
-
 from pathlib import Path
-from torchaudio.models import Conformer
+import torch
+from torch import nn
+from torchaudio.models import Conformer
 
-from f5_tts.model.utils import default
-from f5_tts.model.utils import exists
-from f5_tts.model.utils import list_str_to_idx
-from f5_tts.model.utils import list_str_to_tensor
-from f5_tts.model.utils import lens_to_mask
-from f5_tts.model.utils import mask_from_frac_lengths
-
+from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx,
+                                list_str_to_tensor, mask_from_frac_lengths)
 
-from f5_tts.model.utils import (
-    default,
-    exists,
-    list_str_to_idx,
-    list_str_to_tensor,
-    lens_to_mask,
-    mask_from_frac_lengths,
-)
 
 class ResBlock(nn.Module):
     def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2):
@@ -31,7 +17,6 @@ def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2):
             self._get_conv(hidden_dim, dilation=3**i, dropout_p=dropout_p)
             for i in range(n_conv)])
 
-
     def forward(self, x):
         for block in self.blocks:
             res = x
@@ -55,26 +40,25 @@ def _get_conv(self, hidden_dim, dilation, dropout_p=0.2):
 class ConformerCTC(nn.Module):
     def __init__(self, vocab_size,
-                 mel_dim=100,
-                 num_heads=8,
-                 d_hid=512,
+                 mel_dim=100,
+                 num_heads=8,
+                 d_hid=512,
                  nlayers=6):
         super().__init__()
-
+
         self.mel_proj = nn.Conv1d(mel_dim, d_hid, kernel_size=3, padding=1)
-
+
         self.d_hid = d_hid
-
+
         self.resblock1 = nn.Sequential(
-        ResBlock(d_hid),
-
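# A minimal usage sketch for the packaging above (an aside, assuming an
# editable install via `pip install -e .` from the repo root): both packages
# that [tool.setuptools.packages.find] discovers under src/ become importable,
# and the files listed in [tool.setuptools.package-data] ship with f5_tts.
import importlib.resources

import dmospeech2  # noqa: F401  (empty package created by this patch)
import f5_tts      # noqa: F401

vocab = importlib.resources.files("f5_tts") / "data" / "Emilia_ZH_EN_pinyin" / "vocab.txt"
print(vocab.is_file())  # True once the package data is installed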
nn.GroupNorm(num_groups=1, num_channels=d_hid) - ) - + ResBlock(d_hid), + nn.GroupNorm(num_groups=1, num_channels=d_hid) + ) + self.resblock2 = nn.Sequential( - ResBlock(d_hid), - nn.GroupNorm(num_groups=1, num_channels=d_hid) - ) - + ResBlock(d_hid), + nn.GroupNorm(num_groups=1, num_channels=d_hid) + ) self.conf_pre = torch.nn.ModuleList( [Conformer( @@ -85,9 +69,9 @@ def __init__(self, depthwise_conv_kernel_size=15, use_group_norm=True,) for _ in range(nlayers // 2) - ] + ] ) - + self.conf_after = torch.nn.ModuleList( [Conformer( input_dim=d_hid, @@ -97,14 +81,13 @@ def __init__(self, depthwise_conv_kernel_size=7, use_group_norm=True,) for _ in range(nlayers // 2) - ] + ] ) - self.out = nn.Linear(d_hid, 1 + vocab_size) # 1 for blank + self.out = nn.Linear(d_hid, 1 + vocab_size) # 1 for blank self.ctc_loss = nn.CTCLoss(blank=vocab_size, zero_infinity=True).cuda() - def forward(self, latent, text=None, text_lens=None): layers = [] @@ -147,9 +130,8 @@ def forward(self, latent, text=None, text_lens=None): if __name__ == "__main__": from f5_tts.model.utils import get_tokenizer - bsz = 16 - + tokenizer = "pinyin" # 'pinyin', 'char', or 'custom' tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) dataset_name = "Emilia_ZH_EN" @@ -158,15 +140,15 @@ def forward(self, latent, text=None, text_lens=None): else: tokenizer_path = dataset_name vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer) - + model = ConformerCTC(vocab_size, mel_dim=80, num_heads=8, d_hid=512, nlayers=6).cuda() - + text = ["hello world"] * bsz lens = torch.randint(1, 1000, (bsz,)).cuda() inp = torch.randn(bsz, lens.max(), 80).cuda() - + batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device - + # handle text as string text_lens = torch.tensor([len(t) for t in text], device=device) if isinstance(text, list): @@ -198,7 +180,6 @@ def forward(self, latent, text=None, text_lens=None): char_vocab_map = list(vocab_char_map.keys()) - for batch in best_path: decoded_sequence = [] previous_token = None @@ -216,6 +197,6 @@ def forward(self, latent, text=None, text_lens=None): gt_texts = [] for i in range(text_lens.size(0)): gt_texts.append(''.join([char_vocab_map[token] for token in text[i, :text_lens[i]]])) - + print(decoded_texts) - print(gt_texts) \ No newline at end of file + print(gt_texts) diff --git a/src/demo.ipynb b/src/dmospeech2/demo.ipynb similarity index 100% rename from src/demo.ipynb rename to src/dmospeech2/demo.ipynb diff --git a/src/discriminator_conformer.py b/src/dmospeech2/discriminator_conformer.py similarity index 78% rename from src/discriminator_conformer.py rename to src/dmospeech2/discriminator_conformer.py index 058e151..9b5e349 100644 --- a/src/discriminator_conformer.py +++ b/src/dmospeech2/discriminator_conformer.py @@ -2,21 +2,17 @@ from __future__ import annotations +from pathlib import Path + import torch import torch.nn as nn import torch.nn.functional as F import torchaudio.transforms as trans -from pathlib import Path from torchaudio.models import Conformer -from f5_tts.model.utils import ( - default, - exists, - list_str_to_idx, - list_str_to_tensor, - lens_to_mask, - mask_from_frac_lengths, -) +from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx, + list_str_to_tensor, mask_from_frac_lengths) + class ResBlock(nn.Module): def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2): @@ -26,7 +22,6 @@ def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2): self._get_conv(hidden_dim, 
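# A self-contained sketch of the greedy (best-path) CTC decoding exercised in
# the ctcmodel __main__ block above: argmax per frame, collapse repeated
# tokens, drop blanks. ConformerCTC reserves index vocab_size for the blank.
import torch

def ctc_greedy_decode(frame_scores, blank):
    # frame_scores: (B, T, vocab_size + 1) frame-level logits
    best_path = frame_scores.argmax(dim=-1)  # (B, T)
    decoded = []
    for seq in best_path.tolist():
        out, prev = [], None
        for tok in seq:
            if tok != blank and tok != prev:  # skip blanks, collapse repeats
                out.append(tok)
            prev = tok
        decoded.append(out)
    return decoded

vocab_size = 2545  # size of the Emilia pinyin vocab used elsewhere in this patch
print(ctc_greedy_decode(torch.randn(2, 50, vocab_size + 1), blank=vocab_size))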
dilation=3**i, dropout_p=dropout_p) for i in range(n_conv)]) - def forward(self, x): for block in self.blocks: res = x @@ -46,36 +41,37 @@ def _get_conv(self, hidden_dim, dilation, dropout_p=0.2): ] return nn.Sequential(*layers) + class ConformerDiscirminator(nn.Module): def __init__(self, input_dim, channels=512, num_layers=3, num_heads=8, depthwise_conv_kernel_size=15, use_group_norm=True): super().__init__() - + self.input_layer = nn.Conv1d(input_dim, channels, kernel_size=3, padding=1) self.resblock1 = nn.Sequential( - ResBlock(channels), - nn.GroupNorm(num_groups=1, num_channels=channels) - ) - + ResBlock(channels), + nn.GroupNorm(num_groups=1, num_channels=channels) + ) + self.resblock2 = nn.Sequential( - ResBlock(channels), - nn.GroupNorm(num_groups=1, num_channels=channels) - ) - - self.conformer1 = Conformer(**{"input_dim": channels, - "num_heads": num_heads, - "ffn_dim": channels * 2, - "num_layers": 1, - "depthwise_conv_kernel_size": depthwise_conv_kernel_size // 2, - "use_group_norm": use_group_norm}) - - self.conformer2 = Conformer(**{"input_dim": channels, - "num_heads": num_heads, - "ffn_dim": channels * 2, - "num_layers": num_layers - 1, - "depthwise_conv_kernel_size": depthwise_conv_kernel_size, - "use_group_norm": use_group_norm}) - + ResBlock(channels), + nn.GroupNorm(num_groups=1, num_channels=channels) + ) + + self.conformer1 = Conformer(**{"input_dim": channels, + "num_heads": num_heads, + "ffn_dim": channels * 2, + "num_layers": 1, + "depthwise_conv_kernel_size": depthwise_conv_kernel_size // 2, + "use_group_norm": use_group_norm}) + + self.conformer2 = Conformer(**{"input_dim": channels, + "num_heads": num_heads, + "ffn_dim": channels * 2, + "num_layers": num_layers - 1, + "depthwise_conv_kernel_size": depthwise_conv_kernel_size, + "use_group_norm": use_group_norm}) + self.linear = nn.Conv1d(channels, 1, kernel_size=1) def forward(self, x): @@ -89,7 +85,7 @@ def forward(self, x): x = nn.functional.avg_pool1d(x, 2) x = self.resblock2(x) x = nn.functional.avg_pool1d(x, 2) - + # Transpose to (B, T, C) for the conformer. 
x = x.transpose(1, 2) batch_size, time_steps, _ = x.shape @@ -107,12 +103,13 @@ def forward(self, x): return out + if __name__ == "__main__": - from f5_tts.model.utils import get_tokenizer from f5_tts.model import DiT + from f5_tts.model.utils import get_tokenizer bsz = 2 - + tokenizer = "pinyin" # 'pinyin', 'char', or 'custom' tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) dataset_name = "Emilia_ZH_EN" @@ -121,8 +118,7 @@ def forward(self, x): else: tokenizer_path = dataset_name vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer) - - + fake_unet = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4, text_num_embeds=vocab_size, mel_dim=80) fake_unet = fake_unet.cuda() @@ -130,11 +126,11 @@ def forward(self, x): text = ["hello world"] * bsz lens = torch.randint(1, 1000, (bsz,)).cuda() inp = torch.randn(bsz, lens.max(), 80).cuda() - + batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device - + # handle text as string if isinstance(text, list): if exists(vocab_char_map): @@ -149,11 +145,11 @@ def forward(self, x): mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch frac_lengths_mask = (0.7, 1.0) - + # get a random span to mask out for training conditionally frac_lengths = torch.zeros((batch,), device=device).float().uniform_(*frac_lengths_mask) rand_span_mask = mask_from_frac_lengths(lens, frac_lengths) - + if exists(mask): rand_span_mask &= mask @@ -163,16 +159,16 @@ def forward(self, x): x1 = inp x0 = torch.randn_like(x1) t = time.unsqueeze(-1).unsqueeze(-1) - + phi = (1 - t) * x0 + t * x1 flow = x1 - x0 cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) layers = fake_unet( - x=phi, + x=phi, cond=cond, - text=text, - time=time, + text=text, + time=time, drop_audio_cond=False, drop_text=False, classify_mode=True @@ -181,20 +177,19 @@ def forward(self, x): # layers = torch.stack(layers, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2) # print(layers.shape) - from ctcmodel import ConformerCTC + from dmospeech2.ctcmodel import ConformerCTC ctcmodel = ConformerCTC(vocab_size=vocab_size, mel_dim=80, num_heads=8, d_hid=512, nlayers=6).cuda() real_out, layer = ctcmodel(inp) - layer = layer[-3:] # only use the last 3 layers + layer = layer[-3:] # only use the last 3 layers layer = [F.interpolate(l, mode='nearest', scale_factor=4).transpose(-1, -2) for l in layer] if layer[0].size(1) < layers[0].size(1): layer = [F.pad(l, (0, 0, 0, layers[0].size(1) - l.size(1))) for l in layer] - + layers = layer + layers - model = ConformerDiscirminator(input_dim=23 * 1024 + 3 * 512, - channels=512 - ) - + model = ConformerDiscirminator(input_dim=23 * 1024 + 3 * 512, + channels=512 + ) model = model.cuda() print(model) diff --git a/src/dmd_trainer.py b/src/dmospeech2/dmd_trainer.py similarity index 90% rename from src/dmd_trainer.py rename to src/dmospeech2/dmd_trainer.py index cf7f94f..fb430da 100644 --- a/src/dmd_trainer.py +++ b/src/dmospeech2/dmd_trainer.py @@ -1,28 +1,26 @@ from __future__ import annotations -import os import gc -from tqdm import tqdm -import wandb +import math +import os import torch import torch.nn as nn -from torch.optim import AdamW -from torch.utils.data import DataLoader, Dataset, SequentialSampler -from torch.optim.lr_scheduler import LinearLR, SequentialLR - +import wandb from accelerate import Accelerator from 
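# The conditional flow-matching construction from the __main__ demo above, in
# isolation: phi interpolates noise x0 toward data x1 at time t, and the
# regression target is the constant velocity x1 - x0.
import torch

B, T, D = 2, 100, 80
x1 = torch.randn(B, T, D)        # data (mel frames)
x0 = torch.randn_like(x1)        # noise
t = torch.rand(B).view(B, 1, 1)  # per-sample time in [0, 1]
phi = (1 - t) * x0 + t * x1
flow = x1 - x0
print(phi.shape, flow.shape)     # both (2, 100, 80)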
accelerate.utils import DistributedDataParallelKwargs +from torch.optim import AdamW +from torch.optim.lr_scheduler import LinearLR, SequentialLR +from torch.utils.data import DataLoader, Dataset, SequentialSampler +from tqdm import tqdm +from dmospeech2.unimodel import UniModel -from unimodel import UniModel from f5_tts.model import CFM -from f5_tts.model.utils import exists, default from f5_tts.model.dataset import DynamicBatchSampler, collate_fn - +from f5_tts.model.utils import default, exists # trainer -import math class RunningStats: def __init__(self): @@ -49,7 +47,6 @@ def std(self): return math.sqrt(self.variance) - class Trainer: def __init__( self, @@ -74,7 +71,7 @@ def __init__( accelerate_kwargs: dict = dict(), bnb_optimizer: bool = False, scale: float = 1.0, - + # training parameters for DMDSpeech num_student_step: int = 1, gen_update_ratio: int = 5, @@ -142,25 +139,25 @@ def __init__( self.noise_scheduler = noise_scheduler self.duration_predictor = duration_predictor - + self.log_step = log_step - self.gen_update_ratio = gen_update_ratio # number of generator updates per guidance (fake score function and discriminator) update - self.lambda_discriminator_loss = lambda_discriminator_loss # weight for discriminator loss (L_adv) - self.lambda_generator_loss = lambda_generator_loss # weight for generator loss (L_adv) - self.lambda_ctc_loss = lambda_ctc_loss # weight for ctc loss - self.lambda_sim_loss = lambda_sim_loss # weight for similarity loss - + self.gen_update_ratio = gen_update_ratio # number of generator updates per guidance (fake score function and discriminator) update + self.lambda_discriminator_loss = lambda_discriminator_loss # weight for discriminator loss (L_adv) + self.lambda_generator_loss = lambda_generator_loss # weight for generator loss (L_adv) + self.lambda_ctc_loss = lambda_ctc_loss # weight for ctc loss + self.lambda_sim_loss = lambda_sim_loss # weight for similarity loss + # create distillation schedule for student model self.student_steps = ( - torch.linspace(0.0, 1.0, num_student_step + 1)[:-1]) - - self.GAN = model.guidance_model.gen_cls_loss # whether to use GAN training - self.num_GAN = num_GAN # number of steps before adversarial training - self.num_D = num_D # number of steps to train the discriminator before adversarial training - self.num_ctc = num_ctc # number of steps before CTC training - self.num_sim = num_sim # number of steps before similarity training - self.num_simu = num_simu # number of steps before using simulated data + torch.linspace(0.0, 1.0, num_student_step + 1)[:-1]) + + self.GAN = model.guidance_model.gen_cls_loss # whether to use GAN training + self.num_GAN = num_GAN # number of steps before adversarial training + self.num_D = num_D # number of steps to train the discriminator before adversarial training + self.num_ctc = num_ctc # number of steps before CTC training + self.num_sim = num_sim # number of steps before similarity training + self.num_simu = num_simu # number of steps before using simulated data # Assuming `self.model.fake_unet.parameters()` and `self.model.guidance_model.parameters()` are accessible if bnb_optimizer: @@ -176,7 +173,6 @@ def __init__( self.generator_norm = RunningStats() self.guidance_norm = RunningStats() - @property def is_main(self): return self.accelerator.is_main_process @@ -232,7 +228,6 @@ def load_checkpoint(self): del checkpoint gc.collect() return step - def train(self, train_dataset: Dataset, num_workers=64, resumable_with_seed: int = None, vocoder: nn.Module = None): if 
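# The student distillation schedule built above is just evenly spaced
# flow-matching times with the endpoint dropped; with num_student_step = 4:
import torch

student_steps = torch.linspace(0.0, 1.0, 4 + 1)[:-1]
print(student_steps)  # tensor([0.0000, 0.2500, 0.5000, 0.7500])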
exists(resumable_with_seed): @@ -274,12 +269,12 @@ def train(self, train_dataset: Dataset, num_workers=64, resumable_with_seed: int warmup_steps = ( self.num_warmup_updates * self.accelerator.num_processes ) - + # consider a fixed warmup steps while using accelerate multi-gpu ddp # otherwise by default with split_batches=False, warmup steps change with num_processes total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps decay_steps = total_steps - warmup_steps - + warmup_scheduler_generator = LinearLR(self.optimizer_generator, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps // (self.gen_update_ratio * self.grad_accumulation_steps)) decay_scheduler_generator = LinearLR(self.optimizer_generator, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps // (self.gen_update_ratio * self.grad_accumulation_steps)) self.scheduler_generator = SequentialLR(self.optimizer_generator, schedulers=[warmup_scheduler_generator, decay_scheduler_generator], milestones=[warmup_steps // (self.gen_update_ratio * self.grad_accumulation_steps)]) @@ -307,7 +302,7 @@ def train(self, train_dataset: Dataset, num_workers=64, resumable_with_seed: int if exists(resumable_with_seed) and epoch == skipped_epoch: progress_bar = tqdm( skipped_dataloader, - desc=f"Epoch {epoch+1}/{self.epochs}", + desc=f"Epoch {epoch + 1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process, initial=skipped_batch, @@ -316,36 +311,36 @@ def train(self, train_dataset: Dataset, num_workers=64, resumable_with_seed: int else: progress_bar = tqdm( train_dataloader, - desc=f"Epoch {epoch+1}/{self.epochs}", + desc=f"Epoch {epoch + 1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process, ) for batch in progress_bar: update_generator = global_step % self.gen_update_ratio == 0 - + with self.accelerator.accumulate(self.model): metrics = {} text_inputs = batch["text"] mel_spec = batch["mel"].permute(0, 2, 1) mel_lengths = batch["mel_lengths"] - + mel_spec = mel_spec / self.scale - - guidance_loss_dict, guidance_log_dict = self.model(inp=mel_spec, - text=text_inputs, - lens=mel_lengths, - student_steps=self.student_steps, - update_generator=False, - use_simulated=global_step >= self.num_simu, - ) + + guidance_loss_dict, guidance_log_dict = self.model(inp=mel_spec, + text=text_inputs, + lens=mel_lengths, + student_steps=self.student_steps, + update_generator=False, + use_simulated=global_step >= self.num_simu, + ) # if self.GAN and update_generator: # # only add discriminator loss if GAN is enabled and generator is being updated # guidance_cls_loss = guidance_loss_dict["guidance_cls_loss"] * (self.lambda_discriminator_loss if global_step >= self.num_GAN and update_generator else 0) # metrics['loss/discriminator_loss'] = guidance_loss_dict["guidance_cls_loss"] # self.accelerator.backward(guidance_cls_loss, retain_graph=True) - + # if self.max_grad_norm > 0 and self.accelerator.sync_gradients: # metrics['grad_norm_guidance'] = self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) @@ -360,7 +355,7 @@ def train(self, train_dataset: Dataset, num_workers=64, resumable_with_seed: int metrics['loss/discriminator_loss'] = guidance_loss_dict["guidance_cls_loss"] guidance_loss += guidance_cls_loss - + self.accelerator.backward(guidance_loss) if self.max_grad_norm > 0 and self.accelerator.sync_gradients: @@ -376,20 +371,19 @@ def train(self, train_dataset: Dataset, num_workers=64, resumable_with_seed: int # elif self.guidance_norm.count >= 100: # 
self.guidance_norm.update(metrics['grad_norm_guidance']) - self.optimizer_guidance.step() self.scheduler_guidance.step() self.optimizer_guidance.zero_grad() self.optimizer_generator.zero_grad() # zero out the generator's gradient as well - + if update_generator: - generator_loss_dict, generator_log_dict = self.model(inp=mel_spec, - text=text_inputs, - lens=mel_lengths, - student_steps=self.student_steps, - update_generator=True, - use_simulated=global_step >= self.num_ctc, - ) + generator_loss_dict, generator_log_dict = self.model(inp=mel_spec, + text=text_inputs, + lens=mel_lengths, + student_steps=self.student_steps, + update_generator=True, + use_simulated=global_step >= self.num_ctc, + ) # if self.GAN: # gen_cls_loss = generator_loss_dict["gen_cls_loss"] * (self.lambda_generator_loss if global_step >= (self.num_GAN + self.num_D) and update_generator else 0) # metrics["loss/gen_cls_loss"] = generator_loss_dict["gen_cls_loss"] @@ -402,7 +396,7 @@ def train(self, train_dataset: Dataset, num_workers=64, resumable_with_seed: int generator_loss = 0 generator_loss += generator_loss_dict["loss_dm"] if "loss_mse" in generator_loss_dict: - generator_loss += generator_loss_dict["loss_mse"] + generator_loss += generator_loss_dict["loss_mse"] generator_loss += generator_loss_dict["loss_ctc"] * (self.lambda_ctc_loss if global_step >= self.num_ctc else 0) generator_loss += generator_loss_dict["loss_sim"] * (self.lambda_sim_loss if global_step >= self.num_sim else 0) generator_loss += generator_loss_dict["loss_kl"] * (self.lambda_ctc_loss if global_step >= self.num_ctc else 0) @@ -416,7 +410,7 @@ def train(self, train_dataset: Dataset, num_workers=64, resumable_with_seed: int metrics['loss/similarity_loss'] = generator_loss_dict["loss_sim"] metrics['loss/generator_loss'] = generator_loss - + if "loss_mse" in generator_loss_dict and generator_loss_dict["loss_mse"] != 0: metrics['loss/mse_loss'] = generator_loss_dict["loss_mse"] if "loss_kl" in generator_loss_dict and generator_loss_dict["loss_kl"] != 0: @@ -427,7 +421,7 @@ def train(self, train_dataset: Dataset, num_workers=64, resumable_with_seed: int if self.max_grad_norm > 0 and self.accelerator.sync_gradients: metrics['grad_norm_generator'] = self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) # self.generator_norm.update(metrics['grad_norm_generator']) - + # if metrics['grad_norm_generator'] > self.generator_norm.mean + 15 * self.generator_norm.std: # self.optimizer_generator.zero_grad() # self.optimizer_guidance.zero_grad() @@ -440,16 +434,14 @@ def train(self, train_dataset: Dataset, num_workers=64, resumable_with_seed: int self.optimizer_generator.zero_grad() self.optimizer_guidance.zero_grad() # zero out the guidance's gradient as well - global_step += 1 if self.accelerator.is_local_main_process: self.accelerator.log({**metrics, "lr_generator": self.scheduler_generator.get_last_lr()[0], "lr_guidance": self.scheduler_guidance.get_last_lr()[0], - } - , step=global_step) - + }, step=global_step) + if global_step % self.log_step == 0 and self.accelerator.is_local_main_process and vocoder is not None: # log the first batch of the epoch with torch.no_grad(): @@ -468,7 +460,7 @@ def train(self, train_dataset: Dataset, num_workers=64, resumable_with_seed: int sample_rate=24000, caption="time: " + str(generator_log_dict['time'][0].float().cpu().numpy()) ) - + generator_cond = generator_log_dict['generator_cond'][0].unsqueeze(0).permute(0, 2, 1) * self.scale generator_cond = vocoder.decode(generator_cond.float().cpu()) 
generator_cond = wandb.Audio( @@ -476,7 +468,7 @@ def train(self, train_dataset: Dataset, num_workers=64, resumable_with_seed: int sample_rate=24000, caption="time: " + str(generator_log_dict['time'][0].float().cpu().numpy()) ) - + ground_truth = generator_log_dict['ground_truth'][0].unsqueeze(0).permute(0, 2, 1) * self.scale ground_truth = vocoder.decode(ground_truth.float().cpu()) ground_truth = wandb.Audio( @@ -484,7 +476,7 @@ def train(self, train_dataset: Dataset, num_workers=64, resumable_with_seed: int sample_rate=24000, caption="time: " + str(generator_log_dict['time'][0].float().cpu().numpy()) ) - + dmtrain_noisy_inp = generator_log_dict['dmtrain_noisy_inp'][0].unsqueeze(0).permute(0, 2, 1) * self.scale dmtrain_noisy_inp = vocoder.decode(dmtrain_noisy_inp.float().cpu()) dmtrain_noisy_inp = wandb.Audio( @@ -492,7 +484,7 @@ def train(self, train_dataset: Dataset, num_workers=64, resumable_with_seed: int sample_rate=24000, caption="dmtrain_time: " + str(generator_log_dict['dmtrain_time'][0].float().cpu().numpy()) ) - + dmtrain_pred_real_image = generator_log_dict['dmtrain_pred_real_image'][0].unsqueeze(0).permute(0, 2, 1) * self.scale dmtrain_pred_real_image = vocoder.decode(dmtrain_pred_real_image.float().cpu()) dmtrain_pred_real_image = wandb.Audio( @@ -500,7 +492,7 @@ def train(self, train_dataset: Dataset, num_workers=64, resumable_with_seed: int sample_rate=24000, caption="dmtrain_time: " + str(generator_log_dict['dmtrain_time'][0].float().cpu().numpy()) ) - + dmtrain_pred_fake_image = generator_log_dict['dmtrain_pred_fake_image'][0].unsqueeze(0).permute(0, 2, 1) * self.scale dmtrain_pred_fake_image = vocoder.decode(dmtrain_pred_fake_image.float().cpu()) dmtrain_pred_fake_image = wandb.Audio( @@ -508,17 +500,16 @@ def train(self, train_dataset: Dataset, num_workers=64, resumable_with_seed: int sample_rate=24000, caption="dmtrain_time: " + str(generator_log_dict['dmtrain_time'][0].float().cpu().numpy()) ) - - - self.accelerator.log({"noisy_input": generator_input, + + self.accelerator.log({"noisy_input": generator_input, "output": generator_output, - "cond": generator_cond, - "ground_truth": ground_truth, - "dmtrain_noisy_inp": dmtrain_noisy_inp, - "dmtrain_pred_real_image": dmtrain_pred_real_image, - "dmtrain_pred_fake_image": dmtrain_pred_fake_image, - - }, step=global_step) + "cond": generator_cond, + "ground_truth": ground_truth, + "dmtrain_noisy_inp": dmtrain_noisy_inp, + "dmtrain_pred_real_image": dmtrain_pred_real_image, + "dmtrain_pred_fake_image": dmtrain_pred_fake_image, + + }, step=global_step) progress_bar.set_postfix(step=str(global_step), metrics=metrics) @@ -531,5 +522,3 @@ def train(self, train_dataset: Dataset, num_workers=64, resumable_with_seed: int self.save_checkpoint(global_step, last=True) self.accelerator.end_training() - - diff --git a/src/duration_predictor.py b/src/dmospeech2/duration_predictor.py similarity index 85% rename from src/duration_predictor.py rename to src/dmospeech2/duration_predictor.py index 03c73fc..a8decc4 100644 --- a/src/duration_predictor.py +++ b/src/dmospeech2/duration_predictor.py @@ -1,7 +1,6 @@ import torch import torch.nn as nn -# from tts_encode import tts_encode def calculate_remaining_lengths(mel_lengths): B = mel_lengths.shape[0] @@ -33,42 +32,42 @@ def forward(self, x): class SpeechLengthPredictor(nn.Module): - def __init__(self, - vocab_size=2545, n_mel=100, hidden_dim=256, - n_text_layer=4, n_cross_layer=4, n_head=8, - output_dim=1, - ): + def __init__(self, + vocab_size=2545, n_mel=100, hidden_dim=256, + 
n_text_layer=4, n_cross_layer=4, n_head=8, + output_dim=1, + ): super().__init__() - + # Text Encoder: Embedding + Transformer Layers - self.text_embedder = nn.Embedding(vocab_size+1, hidden_dim, padding_idx=vocab_size) + self.text_embedder = nn.Embedding(vocab_size + 1, hidden_dim, padding_idx=vocab_size) self.text_pe = PositionalEncoding(hidden_dim) encoder_layer = nn.TransformerEncoderLayer( - d_model=hidden_dim, nhead=n_head, dim_feedforward=hidden_dim*2, batch_first=True + d_model=hidden_dim, nhead=n_head, dim_feedforward=hidden_dim * 2, batch_first=True ) self.text_encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_text_layer) - + # Mel Spectrogram Embedder self.mel_embedder = nn.Linear(n_mel, hidden_dim) self.mel_pe = PositionalEncoding(hidden_dim) # Transformer Decoder Layers with Cross-Attention in Every Layer decoder_layer = nn.TransformerDecoderLayer( - d_model=hidden_dim, nhead=n_head, dim_feedforward=hidden_dim*2, batch_first=True + d_model=hidden_dim, nhead=n_head, dim_feedforward=hidden_dim * 2, batch_first=True ) self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_cross_layer) - + # Final Classification Layer self.predictor = nn.Linear(hidden_dim, output_dim) def forward(self, text_ids, mel): # Encode text text_embedded = self.text_pe(self.text_embedder(text_ids)) - text_features = self.text_encoder(text_embedded) # (B, L_text, D) - + text_features = self.text_encoder(text_embedded) # (B, L_text, D) + # Encode Mel spectrogram mel_features = self.mel_pe(self.mel_embedder(mel)) # (B, L_mel, D) - + # Causal Masking for Decoder seq_len = mel_features.size(1) causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool().to(mel.device) @@ -78,7 +77,7 @@ def forward(self, text_ids, mel): # Transformer Decoder with Cross-Attention in Each Layer decoder_out = self.decoder(mel_features, text_features, tgt_mask=causal_mask) - + # Length Prediction length_logits = self.predictor(decoder_out).squeeze(-1) return length_logits diff --git a/src/duration_trainer.py b/src/dmospeech2/duration_trainer.py similarity index 92% rename from src/duration_trainer.py rename to src/dmospeech2/duration_trainer.py index 3757bbf..df49e0c 100644 --- a/src/duration_trainer.py +++ b/src/dmospeech2/duration_trainer.py @@ -1,47 +1,38 @@ from __future__ import annotations import gc -import os - import math +import os import torch +import torch.nn.functional as F import torchaudio import wandb from accelerate import Accelerator from accelerate.utils import DistributedDataParallelKwargs +from duration_predictor import calculate_remaining_lengths from ema_pytorch import EMA from torch.optim import AdamW from torch.optim.lr_scheduler import LinearLR, SequentialLR -from torch.utils.data import DataLoader, Dataset, SequentialSampler, Subset # <-- Added Subset import +from torch.utils.data import (DataLoader, Dataset, # <-- Added Subset import + SequentialSampler, Subset) from tqdm import tqdm -import torch.nn.functional as F - from f5_tts.model import CFM -from f5_tts.model.dataset import collate_fn, DynamicBatchSampler -from f5_tts.model.utils import default, exists - -from duration_predictor import calculate_remaining_lengths +from f5_tts.model.dataset import DynamicBatchSampler, collate_fn +from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx, + list_str_to_tensor, mask_from_frac_lengths) # trainer -from f5_tts.model.utils import ( - default, - exists, - list_str_to_idx, - list_str_to_tensor, - lens_to_mask, - mask_from_frac_lengths, -) SAMPLE_RATE = 
24_000 def masked_l1_loss(est_lengths, tar_lengths): - first_zero_idx = (tar_lengths == 0).int().argmax(dim=1) + first_zero_idx = (tar_lengths == 0).int().argmax(dim=1) B, L = tar_lengths.shape - range_tensor = torch.arange(L, device=tar_lengths.device).expand(B, L) + range_tensor = torch.arange(L, device=tar_lengths.device).expand(B, L) mask = range_tensor <= first_zero_idx[:, None] # Include the first 0 loss = F.l1_loss(est_lengths, tar_lengths, reduction='none') # (B, L) loss = loss * mask # Zero out ignored positions @@ -55,8 +46,8 @@ def masked_cross_entropy_loss(est_length_logits, tar_length_labels): range_tensor = torch.arange(L, device=tar_length_labels.device).expand(B, L) mask = range_tensor <= first_zero_idx[:, None] # Include the first 0 loss = F.cross_entropy( - est_length_logits.reshape(-1, est_length_logits.size(-1)), - tar_length_labels.reshape(-1), + est_length_logits.reshape(-1, est_length_logits.size(-1)), + tar_length_labels.reshape(-1), reduction='none' ).reshape(B, L) loss = loss * mask @@ -165,14 +156,14 @@ def __init__( else: self.optimizer = AdamW(model.parameters(), lr=learning_rate) self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) - + @property def is_main(self): return self.accelerator.is_main_process def save_checkpoint(self, step, last=False): self.accelerator.wait_for_everyone() - if self.is_main: + if self.is_main: checkpoint = dict( model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(), optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(), @@ -229,7 +220,7 @@ def load_checkpoint(self): } self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"]) step = 0 - + del checkpoint gc.collect() @@ -237,7 +228,6 @@ def load_checkpoint(self): return step - def validate(self, valid_dataloader, global_step): """ Runs evaluation on the validation set, computes the average loss, @@ -251,12 +241,12 @@ def validate(self, valid_dataloader, global_step): with torch.no_grad(): for batch in valid_dataloader: # Inputs - mel = batch['mel'].permute(0, 2, 1) # (B, L_mel, D) + mel = batch['mel'].permute(0, 2, 1) # (B, L_mel, D) text = batch['text'] if self.process_token_to_id: text_ids = list_str_to_idx(text, self.vocab_char_map).to(mel.device) - text_ids = text_ids.masked_fill(text_ids==-1, self.vocab_size) + text_ids = text_ids.masked_fill(text_ids == -1, self.vocab_size) else: text_ids = text @@ -274,7 +264,7 @@ def validate(self, valid_dataloader, global_step): elif self.loss_fn == 'CE': tar_length_labels = (tar_lengths // self.n_frame_per_class) \ - .clamp(min=0, max=self.n_class-1) # [0, 1, ..., n_class-1] + .clamp(min=0, max=self.n_class - 1) # [0, 1, ..., n_class-1] est_length_logtis = predictions est_length_labels = torch.argmax(est_length_logtis, dim=-1) loss = masked_cross_entropy_loss( @@ -287,7 +277,7 @@ def validate(self, valid_dataloader, global_step): elif self.loss_fn == 'L1_and_CE': tar_length_labels = (tar_lengths // self.n_frame_per_class) \ - .clamp(min=0, max=self.n_class-1) # [0, 1, ..., n_class-1] + .clamp(min=0, max=self.n_class - 1) # [0, 1, ..., n_class-1] est_length_logtis = predictions est_length_1hots = F.gumbel_softmax( est_length_logtis, tau=self.gumbel_tau, hard=True, dim=-1 @@ -324,15 +314,14 @@ def validate(self, valid_dataloader, global_step): { f"valid_loss": avg_valid_loss, f"valid_sec_error": avg_valid_sec_error - }, + }, step=global_step ) - - self.model.train() + self.model.train() def train(self, train_dataset: Dataset, valid_dataset: 
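# A worked example of the masking in masked_l1_loss above: tar_lengths holds
# the remaining frame count at each position, and everything after the first 0
# is padding that must not contribute to the loss.
import torch
import torch.nn.functional as F

tar_lengths = torch.tensor([[5., 3., 1., 0., 0., 0.]])  # first 0 marks the real end
est_lengths = torch.tensor([[6., 3., 2., 0., 1., 7.]])

first_zero_idx = (tar_lengths == 0).int().argmax(dim=1)       # -> tensor([3])
B, L = tar_lengths.shape
mask = torch.arange(L).expand(B, L) <= first_zero_idx[:, None]  # keep 0..3

loss = F.l1_loss(est_lengths, tar_lengths, reduction='none') * mask
print(loss.sum() / mask.sum())  # (1 + 0 + 1 + 0) / 4 = 0.5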
Dataset, - num_workers=64, resumable_with_seed: int = None): + num_workers=64, resumable_with_seed: int = None): if exists(resumable_with_seed): generator = torch.Generator() generator.manual_seed(resumable_with_seed) @@ -386,13 +375,13 @@ def train(self, train_dataset: Dataset, valid_dataset: Dataset, valid_dataset, collate_fn=collate_fn, num_workers=num_workers, - pin_memory=True, + pin_memory=True, persistent_workers=True, batch_sampler=batch_sampler, ) else: raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}") - + # accelerator.prepare() dispatches batches to devices; # which means the length of dataloader calculated before, should consider the number of devices warmup_steps = ( @@ -427,7 +416,7 @@ def train(self, train_dataset: Dataset, valid_dataset: Dataset, if exists(resumable_with_seed) and epoch == skipped_epoch: progress_bar = tqdm( skipped_dataloader, - desc=f"Epoch {epoch+1}/{self.epochs}", + desc=f"Epoch {epoch + 1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process, initial=skipped_batch, @@ -436,7 +425,7 @@ def train(self, train_dataset: Dataset, valid_dataset: Dataset, else: progress_bar = tqdm( train_dataloader, - desc=f"Epoch {epoch+1}/{self.epochs}", + desc=f"Epoch {epoch + 1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process, ) @@ -444,12 +433,12 @@ def train(self, train_dataset: Dataset, valid_dataset: Dataset, for batch in progress_bar: with self.accelerator.accumulate(self.model): # Inputs - mel = batch['mel'].permute(0, 2, 1) # (B, L_mel, D) + mel = batch['mel'].permute(0, 2, 1) # (B, L_mel, D) text = batch['text'] if self.process_token_to_id: text_ids = list_str_to_idx(text, self.vocab_char_map).to(mel.device) - text_ids = text_ids.masked_fill(text_ids==-1, self.vocab_size) + text_ids = text_ids.masked_fill(text_ids == -1, self.vocab_size) else: text_ids = text @@ -469,15 +458,15 @@ def train(self, train_dataset: Dataset, valid_dataset: Dataset, sec_error = frame_error * 256 / 24000 log_dict = { - 'loss': loss.item(), - 'loss_L1': loss.item(), + 'loss': loss.item(), + 'loss_L1': loss.item(), 'sec_error': sec_error.item(), 'lr': self.scheduler.get_last_lr()[0] - } + } elif self.loss_fn == 'CE': tar_length_labels = (tar_lengths // self.n_frame_per_class) \ - .clamp(min=0, max=self.n_class-1) # [0, 1, ..., n_class-1] + .clamp(min=0, max=self.n_class - 1) # [0, 1, ..., n_class-1] est_length_logtis = predictions est_length_labels = torch.argmax(est_length_logtis, dim=-1) loss = masked_cross_entropy_loss( @@ -491,15 +480,15 @@ def train(self, train_dataset: Dataset, valid_dataset: Dataset, sec_error = frame_error * 256 / 24000 log_dict = { - 'loss': loss.item(), - 'loss_CE': loss.item(), + 'loss': loss.item(), + 'loss_CE': loss.item(), 'sec_error': sec_error.item(), 'lr': self.scheduler.get_last_lr()[0] - } + } elif self.loss_fn == 'L1_and_CE': tar_length_labels = (tar_lengths // self.n_frame_per_class) \ - .clamp(min=0, max=self.n_class-1) # [0, 1, ..., n_class-1] + .clamp(min=0, max=self.n_class - 1) # [0, 1, ..., n_class-1] est_length_logtis = predictions est_length_1hots = F.gumbel_softmax( est_length_logtis, tau=self.gumbel_tau, hard=True, dim=-1 @@ -513,7 +502,7 @@ def train(self, train_dataset: Dataset, valid_dataset: Dataset, est_length_logits=est_length_logtis, tar_length_labels=tar_length_labels ) - loss_L1 = masked_l1_loss( + loss_L1 = masked_l1_loss( est_lengths=est_lengths, tar_lengths=tar_lengths ) @@ -524,9 +513,9 @@ def train(self, 
train_dataset: Dataset, valid_dataset: Dataset, sec_error = frame_error * 256 / 24000 log_dict = { - 'loss': loss.item(), - 'loss_L1': loss_L1.item(), - 'loss_CE': loss_CE.item(), + 'loss': loss.item(), + 'loss_L1': loss_L1.item(), + 'loss_CE': loss_CE.item(), 'sec_error': sec_error.item(), 'lr': self.scheduler.get_last_lr()[0] } @@ -534,7 +523,6 @@ def train(self, train_dataset: Dataset, valid_dataset: Dataset, else: raise NotImplementedError(self.loss_fn) - self.accelerator.backward(loss) if self.max_grad_norm > 0 and self.accelerator.sync_gradients: diff --git a/src/duration_trainer_with_prompt.py b/src/dmospeech2/duration_trainer_with_prompt.py similarity index 92% rename from src/duration_trainer_with_prompt.py rename to src/dmospeech2/duration_trainer_with_prompt.py index cb96457..751e13c 100644 --- a/src/duration_trainer_with_prompt.py +++ b/src/dmospeech2/duration_trainer_with_prompt.py @@ -1,11 +1,11 @@ from __future__ import annotations import gc -import os - import math +import os import torch +import torch.nn.functional as F import torchaudio import wandb from accelerate import Accelerator @@ -13,25 +13,17 @@ from ema_pytorch import EMA from torch.optim import AdamW from torch.optim.lr_scheduler import LinearLR, SequentialLR -from torch.utils.data import DataLoader, Dataset, SequentialSampler, Subset # <-- Added Subset import +from torch.utils.data import (DataLoader, Dataset, # <-- Added Subset import + SequentialSampler, Subset) from tqdm import tqdm -import torch.nn.functional as F - from f5_tts.model import CFM -from f5_tts.model.dataset import collate_fn, DynamicBatchSampler -from f5_tts.model.utils import default, exists +from f5_tts.model.dataset import DynamicBatchSampler, collate_fn +from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx, + list_str_to_tensor, mask_from_frac_lengths) # trainer -from f5_tts.model.utils import ( - default, - exists, - list_str_to_idx, - list_str_to_tensor, - lens_to_mask, - mask_from_frac_lengths, -) SAMPLE_RATE = 24_000 @@ -138,14 +130,14 @@ def __init__( else: self.optimizer = AdamW(model.parameters(), lr=learning_rate) self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) - + @property def is_main(self): return self.accelerator.is_main_process def save_checkpoint(self, step, last=False): self.accelerator.wait_for_everyone() - if self.is_main: + if self.is_main: checkpoint = dict( model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(), optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(), @@ -202,7 +194,7 @@ def load_checkpoint(self): } self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"]) step = 0 - + del checkpoint gc.collect() @@ -210,7 +202,6 @@ def load_checkpoint(self): return step - def validate(self, valid_dataloader, global_step): """ Runs evaluation on the validation set, computes the average loss, @@ -226,29 +217,29 @@ def validate(self, valid_dataloader, global_step): for batch in valid_dataloader: # Inputs - prompt_mel = batch['pmt_mel_specs'].permute(0, 2, 1) # (B, L_mel, D) + prompt_mel = batch['pmt_mel_specs'].permute(0, 2, 1) # (B, L_mel, D) prompt_text = batch['pmt_text'] text = batch['text'] target_ids = list_str_to_idx(text, self.vocab_char_map).to(prompt_mel.device) - target_ids = target_ids.masked_fill(target_ids==-1, vocab_size) + target_ids = target_ids.masked_fill(target_ids == -1, vocab_size) prompt_ids = list_str_to_idx(prompt_text, self.vocab_char_map).to(prompt_mel.device) - 
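# The class/frame/seconds bookkeeping used above, worked through once. With
# n_frame_per_class = 10, a 237-frame target falls in class 23; decoding the
# class back gives 230 frames, and the factor 256 / 24000 (hop length over
# sample rate, about 10.7 ms per frame) converts the 7-frame error to seconds.
n_frame_per_class = 10
tar_frames = 237
label = tar_frames // n_frame_per_class         # 23
est_frames = label * n_frame_per_class          # 230
print((tar_frames - est_frames) * 256 / 24000)  # ~0.0747 s quantization error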
prompt_ids = prompt_ids.masked_fill(prompt_ids==-1, vocab_size) + prompt_ids = prompt_ids.masked_fill(prompt_ids == -1, vocab_size) # Targets tar_lengths = batch['mel_lengths'] # Forward - predictions = SLP(target_ids=target_ids, prompt_ids=prompt_ids, prompt_mel=prompt_mel) # (B, C) + predictions = SLP(target_ids=target_ids, prompt_ids=prompt_ids, prompt_mel=prompt_mel) # (B, C) if self.loss_fn == 'CE': tar_length_labels = (tar_lengths // self.n_frame_per_class) \ - .clamp(min=0, max=self.n_class-1) # [0, 1, ..., n_class-1] + .clamp(min=0, max=self.n_class - 1) # [0, 1, ..., n_class-1] est_length_logtis = predictions est_length_labels = torch.argmax(est_length_logtis, dim=-1) loss = F.cross_entropy(est_length_logtis, tar_length_labels) - + est_lengths = est_length_labels * self.n_frame_per_class frame_error = (est_lengths.float() - tar_lengths.float()).abs().mean() sec_error = frame_error * 256 / 24000 @@ -265,15 +256,14 @@ def validate(self, valid_dataloader, global_step): { f"valid_loss": avg_valid_loss, f"valid_sec_error": avg_valid_sec_error - }, + }, step=global_step ) - - self.model.train() + self.model.train() def train(self, train_dataset: Dataset, valid_dataset: Dataset, - num_workers=64, resumable_with_seed: int = None): + num_workers=64, resumable_with_seed: int = None): if exists(resumable_with_seed): generator = torch.Generator() generator.manual_seed(resumable_with_seed) @@ -327,13 +317,13 @@ def train(self, train_dataset: Dataset, valid_dataset: Dataset, valid_dataset, collate_fn=collate_fn, num_workers=num_workers, - pin_memory=True, + pin_memory=True, persistent_workers=True, batch_sampler=batch_sampler, ) else: raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}") - + # accelerator.prepare() dispatches batches to devices; # which means the length of dataloader calculated before, should consider the number of devices warmup_steps = ( @@ -368,7 +358,7 @@ def train(self, train_dataset: Dataset, valid_dataset: Dataset, if exists(resumable_with_seed) and epoch == skipped_epoch: progress_bar = tqdm( skipped_dataloader, - desc=f"Epoch {epoch+1}/{self.epochs}", + desc=f"Epoch {epoch + 1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process, initial=skipped_batch, @@ -377,7 +367,7 @@ def train(self, train_dataset: Dataset, valid_dataset: Dataset, else: progress_bar = tqdm( train_dataloader, - desc=f"Epoch {epoch+1}/{self.epochs}", + desc=f"Epoch {epoch + 1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process, ) @@ -385,45 +375,44 @@ def train(self, train_dataset: Dataset, valid_dataset: Dataset, for batch in progress_bar: with self.accelerator.accumulate(self.model): # Inputs - prompt_mel = batch['pmt_mel_specs'].permute(0, 2, 1) # (B, L_mel, D) + prompt_mel = batch['pmt_mel_specs'].permute(0, 2, 1) # (B, L_mel, D) prompt_text = batch['pmt_text'] text = batch['text'] target_ids = list_str_to_idx(text, self.vocab_char_map).to(prompt_mel.device) - target_ids = target_ids.masked_fill(target_ids==-1, vocab_size) + target_ids = target_ids.masked_fill(target_ids == -1, vocab_size) prompt_ids = list_str_to_idx(prompt_text, self.vocab_char_map).to(prompt_mel.device) - prompt_ids = prompt_ids.masked_fill(prompt_ids==-1, vocab_size) + prompt_ids = prompt_ids.masked_fill(prompt_ids == -1, vocab_size) # Targets tar_lengths = batch['mel_lengths'] # Forward - predictions = SLP(target_ids=target_ids, prompt_ids=prompt_ids, prompt_mel=prompt_mel) # (B, C) + predictions = 
SLP(target_ids=target_ids, prompt_ids=prompt_ids, prompt_mel=prompt_mel) # (B, C) if self.loss_fn == 'CE': tar_length_labels = (tar_lengths // self.n_frame_per_class) \ - .clamp(min=0, max=self.n_class-1) # [0, 1, ..., n_class-1] + .clamp(min=0, max=self.n_class - 1) # [0, 1, ..., n_class-1] est_length_logtis = predictions est_length_labels = torch.argmax(est_length_logtis, dim=-1) loss = F.cross_entropy(est_length_logtis, tar_length_labels) - + with torch.no_grad(): est_lengths = est_length_labels * self.n_frame_per_class frame_error = (est_lengths.float() - tar_lengths.float()).abs().mean() sec_error = frame_error * 256 / 24000 log_dict = { - 'loss': loss.item(), - 'loss_CE': loss.item(), + 'loss': loss.item(), + 'loss_CE': loss.item(), 'sec_error': sec_error.item(), 'lr': self.scheduler.get_last_lr()[0] - } + } else: raise NotImplementedError(self.loss_fn) - self.accelerator.backward(loss) if self.max_grad_norm > 0 and self.accelerator.sync_gradients: diff --git a/src/ecapa_tdnn.py b/src/dmospeech2/ecapa_tdnn.py similarity index 94% rename from src/ecapa_tdnn.py rename to src/dmospeech2/ecapa_tdnn.py index b55aaf2..3be2ea0 100644 --- a/src/ecapa_tdnn.py +++ b/src/dmospeech2/ecapa_tdnn.py @@ -1,12 +1,13 @@ # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN +# from ctcmodel_nopool import ConformerCTC as ConformerCTCNoPool +from pathlib import Path + import torch import torch.nn as nn import torch.nn.functional as F import torchaudio.transforms as trans -from ctcmodel import ConformerCTC -# from ctcmodel_nopool import ConformerCTC as ConformerCTCNoPool -from pathlib import Path +from dmospeech2.ctcmodel import ConformerCTC ''' Res2Conv1d + BatchNorm1d + ReLU ''' @@ -87,6 +88,7 @@ def forward(self, x): ''' SE-Res2Block of the ECAPA-TDNN architecture. ''' + class SE_Res2Block(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim): super().__init__() @@ -119,6 +121,7 @@ def forward(self, x): ''' Attentive weighted mean and standard deviation pooling. 
''' + class AttentiveStatsPool(nn.Module): def __init__(self, in_dim, attention_channels=128, global_context_att=False): super().__init__() @@ -151,13 +154,13 @@ def forward(self, x): class ECAPA_TDNN(nn.Module): - def __init__(self, channels=512, emb_dim=512, - global_context_att=False, use_fp16=True, - ctc_cls=ConformerCTC, - ctc_path='/data4/F5TTS/ckpts/F5TTS_norm_ASR_vocos_pinyin_Emilia_ZH_EN/model_last.pt', - ctc_args={'vocab_size': 2545, 'mel_dim': 100, 'num_heads': 8, 'd_hid': 512, 'nlayers': 6}, - ctc_no_grad=False - ): + def __init__(self, channels=512, emb_dim=512, + global_context_att=False, use_fp16=True, + ctc_cls=ConformerCTC, + ctc_path='/data4/F5TTS/ckpts/F5TTS_norm_ASR_vocos_pinyin_Emilia_ZH_EN/model_last.pt', + ctc_args={'vocab_size': 2545, 'mel_dim': 100, 'num_heads': 8, 'd_hid': 512, 'nlayers': 6}, + ctc_no_grad=False + ): super().__init__() if ctc_path != None: ctc_path = Path(ctc_path) @@ -170,7 +173,7 @@ def __init__(self, channels=512, emb_dim=512, self.ctc_model = model self.ctc_model.out.requires_grad_(False) - + if ctc_cls == ConformerCTC: self.feat_num = ctc_args['nlayers'] + 2 + 1 # elif ctc_cls == ConformerCTCNoPool: @@ -180,7 +183,7 @@ def __init__(self, channels=512, emb_dim=512, feat_dim = ctc_args['d_hid'] self.emb_dim = emb_dim - + self.feature_weight = nn.Parameter(torch.zeros(self.feat_num)) self.instance_norm = nn.InstanceNorm1d(feat_dim) @@ -208,19 +211,19 @@ def __init__(self, channels=512, emb_dim=512, self.ctc_no_grad = ctc_no_grad print('ctc_no_grad: ', self.ctc_no_grad) - def forward(self, latent, input_lengths, return_asr=False): + def forward(self, latent, input_lengths, return_asr=False): if self.ctc_no_grad: with torch.no_grad(): asr, h = self.ctc_model(latent, input_lengths) else: asr, h = self.ctc_model(latent, input_lengths) - + x = torch.stack(h, dim=0) norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) x = (norm_weights * x).sum(dim=0) x = x + 1e-6 # x = torch.transpose(x, 1, 2) + 1e-6 - + x = self.instance_norm(x) # x = torch.transpose(x, 1, 2) @@ -238,9 +241,10 @@ def forward(self, latent, input_lengths, return_asr=False): return out, asr return out + if __name__ == "__main__": - from diffspeech.ldm.model import DiT from diffspeech.data.collate import get_mask_from_lengths + from diffspeech.ldm.model import DiT from diffspeech.tools.text.vocab import IPA bsz = 3 @@ -265,4 +269,4 @@ def forward(self, latent, input_lengths, return_asr=False): emb = model(latent, latent_mask.sum(axis=-1)) - print(emb.shape) \ No newline at end of file + print(emb.shape) diff --git a/src/grpo_duration_trainer.py b/src/dmospeech2/grpo_duration_trainer.py similarity index 90% rename from src/grpo_duration_trainer.py rename to src/dmospeech2/grpo_duration_trainer.py index 34b98c2..574a4ed 100644 --- a/src/grpo_duration_trainer.py +++ b/src/dmospeech2/grpo_duration_trainer.py @@ -1,25 +1,24 @@ -import os +import copy import gc +import io import json +import os import random import time -import io -import copy -from typing import List, Dict, Any, Optional, Callable, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F +import wandb +from accelerate import Accelerator +from accelerate.utils import DistributedDataParallelKwargs from torch.optim import AdamW from torch.optim.lr_scheduler import LinearLR, SequentialLR from torch.utils.data import DataLoader, Dataset, SequentialSampler, Subset from tqdm import tqdm -from 
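# The learned layer weighting from ECAPA_TDNN.forward above, in isolation
# (shapes here are illustrative; feat_num = nlayers + 2 + 1 = 9 for the default
# 6-layer ConformerCTC): a softmax over per-layer scalars mixes the stacked
# CTC hidden states into a single feature map.
import torch
import torch.nn.functional as F

feat_num, B, D, T = 9, 2, 512, 100
h = [torch.randn(B, D, T) for _ in range(feat_num)]  # one hidden state per layer
feature_weight = torch.zeros(feat_num)               # uniform mix at initialization

x = torch.stack(h, dim=0)                            # (feat_num, B, D, T)
norm_weights = F.softmax(feature_weight, dim=-1).view(-1, 1, 1, 1)
mixed = (norm_weights * x).sum(dim=0)                # (B, D, T)
print(mixed.shape)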
accelerate import Accelerator -from accelerate.utils import DistributedDataParallelKwargs -import wandb - -from f5_tts.model.dataset import collate_fn, DynamicBatchSampler +from f5_tts.model.dataset import DynamicBatchSampler, collate_fn from f5_tts.model.utils import list_str_to_idx # torch.autograd.set_detect_anomaly(True) @@ -33,16 +32,16 @@ def safe_sample(logits, temperature=1.0): """ # Apply temperature scaling scaled_logits = logits / temperature - + # Compute categorical distribution probs = F.softmax(scaled_logits, dim=-1) - + # Sample from the distribution once per batch element samples = torch.multinomial(probs, num_samples=1) # (B, 1) - + # Convert to one-hot encoding one_hot_samples = torch.zeros_like(probs).scatter_(1, samples, 1) - + return one_hot_samples @@ -51,51 +50,52 @@ class GRPODurationTrainer: Trainer class that implements GRPO (Generative Reinforcement Learning from Preference Optimization) for a duration predictor in text-to-speech synthesis. """ + def __init__( self, model, # Duration predictor model inference_fn, # Function to generate speech reward_fn, # Function to compute rewards from generated speech - + vocab_size: int, # Size of the vocabulary vocab_char_map: dict, # Mapping from characters to token IDs # Duration model parameters n_class: int = 301, # Number of duration classes - n_frame_per_class: int = 10, # Number of frames per class + n_frame_per_class: int = 10, # Number of frames per class gumbel_tau: int = 0.7, - + # GRPO parameters beta: float = 0.04, # KL regularization weight clip_param: float = 0.2, # PPO clip parameter num_pre_samples: int = 8, # Number of samples per prompt - compute_gen_logps: bool = True, # Whether to compute generation log probabilities - + compute_gen_logps: bool = True, # Whether to compute generation log probabilities + # Training parameters learning_rate: float = 5e-6, num_warmup_updates: int = 10000, save_per_updates: int = 10000, checkpoint_path: Optional[str] = None, all_steps: int = 100000, # Total training steps - + # Batch parameters batch_size: int = 8, batch_size_type: str = "sample", max_samples: int = 16, grad_accumulation_steps: int = 2, max_grad_norm: float = 1.0, - + # Logging parameters logger: Optional[str] = "wandb", wandb_project: str = "tts-duration-grpo", wandb_run_name: str = "grpo_run", wandb_resume_id: Optional[str] = None, - + accelerate_kwargs: dict = dict(), ): # Initialize accelerator for distributed training ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False) - + if logger == "wandb" and not wandb.api.api_key: logger = None print(f"Using logger: {logger}") @@ -138,20 +138,20 @@ def __init__( # Store model, inference function, and reward function self.model = model - + # Create reference model (frozen clone of the initial model) self.ref_model = copy.deepcopy(model) for param in self.ref_model.parameters(): param.requires_grad = False self.ref_model.eval() - + # prepare inference_fn self.inference_fn = inference_fn self.inference_fn.scale = self.inference_fn.scale.to(self.accelerator.device) self.inference_fn.tts_model = self.inference_fn.tts_model.to(self.accelerator.device) # prepare reward_fn self.reward_fn = reward_fn - + # Store vocabulary and mapping self.vocab_size = vocab_size self.vocab_char_map = vocab_char_map @@ -160,47 +160,47 @@ def __init__( self.n_class = n_class self.n_frame_per_class = n_frame_per_class self.gumbel_tau = gumbel_tau - + # Store GRPO parameters self.beta = beta self.clip_param = clip_param self.num_pre_samples = num_pre_samples 
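# Example use of safe_sample above (assuming it is in scope): one categorical
# draw per batch element over the n_class = 301 duration classes, returned
# one-hot so the sampled class index is recovered with argmax.
import torch

logits = torch.randn(4, 301)               # (B, n_class)
one_hot = safe_sample(logits, temperature=1.0)
sampled_classes = one_hot.argmax(dim=-1)   # (B,)
print(one_hot.shape, sampled_classes)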
self.compute_gen_logps = compute_gen_logps - + # Store training parameters self.learning_rate = learning_rate self.num_warmup_updates: int = num_warmup_updates self.save_per_updates = save_per_updates self.checkpoint_path = checkpoint_path or f"ckpts/{wandb_run_name}" self.all_steps = all_steps - + # Store batch parameters self.batch_size = batch_size self.batch_size_type = batch_size_type self.max_samples = max_samples self.grad_accumulation_steps = grad_accumulation_steps self.max_grad_norm = max_grad_norm - + # Initialize optimizer self.optimizer = AdamW(model.parameters(), lr=learning_rate) - + # Prepare model and optimizer with accelerator self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) self.ref_model = self.accelerator.prepare(self.ref_model) - self.reward_fn, self.inference_fn = self.accelerator.prepare(self.reward_fn, self.inference_fn) - + self.reward_fn, self.inference_fn = self.accelerator.prepare(self.reward_fn, self.inference_fn) + # GRPO batch queue self.batch_queue = [] - + # Store distributed rank self.rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 self.device = f'cuda:{self.rank}' - + @property def is_main(self): return self.accelerator.is_main_process - + def save_checkpoint(self, step, last=False): """Save model and optimizer state""" self.accelerator.wait_for_everyone() @@ -217,7 +217,7 @@ def save_checkpoint(self, step, last=False): self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt") else: self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt") - + def load_checkpoint(self): """Load latest checkpoint if available""" if ( @@ -238,8 +238,8 @@ def load_checkpoint(self): print(f'Loading checkpoint: {latest_checkpoint}') checkpoint = torch.load( - f"{self.checkpoint_path}/{latest_checkpoint}", - weights_only=True, + f"{self.checkpoint_path}/{latest_checkpoint}", + weights_only=True, map_location="cpu" ) @@ -252,13 +252,13 @@ def load_checkpoint(self): else: self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"]) step = 0 - + del checkpoint gc.collect() - + print(f'Successfully loaded checkpoint at step {step}') return step - + @torch.no_grad() def get_ref_logps(self, text_ids, mel, sampled_classes): """ @@ -268,31 +268,31 @@ def get_ref_logps(self, text_ids, mel, sampled_classes): K = self.num_pre_samples with torch.no_grad(): ref_logits = self.ref_model(text_ids=text_ids, mel=mel)[:, -1, :] - ref_logits = ref_logits.unsqueeze(1).repeat(1, K, 1).view(B*K, -1) + ref_logits = ref_logits.unsqueeze(1).repeat(1, K, 1).view(B * K, -1) ref_log_probs = F.log_softmax(ref_logits, dim=-1) ref_logps = torch.gather( - ref_log_probs, - dim=-1, + ref_log_probs, + dim=-1, index=sampled_classes.unsqueeze(-1) ).squeeze(-1) return ref_logps - + @torch.no_grad() def generate_duration_samples(self, batch_inputs): """ Generate multiple duration predictions from the model for each input and evaluate them using the inference function and reward model - + Args: batch_inputs: Dictionary with text, prompt audio, etc. 
- + Returns: Dictionary with duration samples, rewards, and reference logits """ if self.rank == 0: print("Generating duration samples...") - + # all_logits = [] all_text_ids = [] all_mels = [] @@ -306,7 +306,7 @@ def generate_duration_samples(self, batch_inputs): # Fetch batch inputs # prompt_mel = batch_inputs['mel'].permute(0, 2, 1).to(self.device) - prompt_mel = batch_inputs['mel'].permute(0, 2, 1) # (B, T, 100) + prompt_mel = batch_inputs['mel'].permute(0, 2, 1) # (B, T, 100) prompt_text = batch_inputs['text'] batch_size = prompt_mel.shape[0] @@ -315,11 +315,11 @@ def generate_duration_samples(self, batch_inputs): target_text = batch_inputs['target_text'] target_text_lengths = torch.LongTensor([len(t) for t in target_text]).to(prompt_mel.device) try: - full_text = [prompt+[' ']+target for prompt, target in zip(prompt_text, target_text)] + full_text = [prompt + [' '] + target for prompt, target in zip(prompt_text, target_text)] except: target_text = [batch_inputs['text'][-1]] + batch_inputs['text'][:-1] target_text_lengths = batch_inputs['text_lengths'].clone().roll(1, 0) - full_text = [prompt+[' ']+target for prompt, target in zip(prompt_text, target_text)] + full_text = [prompt + [' '] + target for prompt, target in zip(prompt_text, target_text)] # Goes to reward model target_text_ids = list_str_to_idx(target_text, self.vocab_char_map).to(self.accelerator.device) # to device, the dataloader only gives list @@ -329,7 +329,7 @@ def generate_duration_samples(self, batch_inputs): # Deepcopy to separate text_ids for SLP and TTS slp_text_ids = full_text_ids.detach().clone() - slp_text_ids = slp_text_ids.masked_fill(slp_text_ids==-1, self.vocab_size) # (B, L) + slp_text_ids = slp_text_ids.masked_fill(slp_text_ids == -1, self.vocab_size) # (B, L) # Pre-compute duration logits K = self.num_pre_samples @@ -340,40 +340,40 @@ def generate_duration_samples(self, batch_inputs): # Run model once for B inputs old_logits = self.model( - text_ids=slp_text_ids, # (B, L) + text_ids=slp_text_ids, # (B, L) mel=prompt_mel # (B, T, 100) )[:, -1, :] # (B, n_class) # Repeat each result K times along batch dimension - old_logits = old_logits.unsqueeze(1).repeat(1, K, 1) # (B, K, n_class) + old_logits = old_logits.unsqueeze(1).repeat(1, K, 1) # (B, K, n_class) # logits_nograd = logits_grad.detach().clone().view(B, K, -1) # (B, K, n_class) for _full_text_ids, _target_text_ids, _target_text_lengths, \ - _prompt_mel, _old_logits in zip( - full_text_ids, target_text_ids, target_text_lengths, - prompt_mel, old_logits - ): + _prompt_mel, _old_logits in zip( + full_text_ids, target_text_ids, target_text_lengths, + prompt_mel, old_logits + ): duration_sample = F.gumbel_softmax(_old_logits, tau=self.gumbel_tau, hard=True, dim=-1) duration2frames = torch.arange(self.n_class).float().to(self.accelerator.device) * self.n_frame_per_class - est_frames = (duration_sample * duration2frames).sum(-1) # (K, ) + est_frames = (duration_sample * duration2frames).sum(-1) # (K, ) # Compute log probabilities of the samples sampled_classes = duration_sample.argmax(dim=-1) log_probs = F.log_softmax(_old_logits, dim=-1) gen_logps = torch.gather( - log_probs, - dim=-1, + log_probs, + dim=-1, index=sampled_classes.unsqueeze(-1) ).squeeze(-1) # Shape: [K, n_class] - + # Generate speech using the sampled durations sampled_rewards = [] for i in range(K): cur_duration = est_frames[i] if cur_duration == 0: - cur_duration = cur_duration + 50 # prevent 0 duration + cur_duration = cur_duration + 50 # prevent 0 duration infer_full_text_ids = 
_full_text_ids.unsqueeze(0) infer_prompt_mel = _prompt_mel.unsqueeze(0) cur_duration = cur_duration.unsqueeze(0) @@ -382,13 +382,13 @@ def generate_duration_samples(self, batch_inputs): with torch.inference_mode(): try: _est_mel = self.inference_fn( - full_text_ids=infer_full_text_ids, - prompt_mel=infer_prompt_mel, - target_duration=cur_duration, + full_text_ids=infer_full_text_ids, + prompt_mel=infer_prompt_mel, + target_duration=cur_duration, teacher_steps=0 ) - _est_mel = _est_mel.permute(0, 2, 1) # (1, T, 100) - + _est_mel = _est_mel.permute(0, 2, 1) # (1, T, 100) + loss_dict = self.reward_fn( prompt_mel=infer_prompt_mel, est_mel=_est_mel, @@ -396,7 +396,7 @@ def generate_duration_samples(self, batch_inputs): target_text_length=infer_target_text_lengths ) # #TODO reweight the loss for reward - reward_sim = loss_dict['loss_sim'] # 0 to 1 + reward_sim = loss_dict['loss_sim'] # 0 to 1 reward_ctc = loss_dict['loss_ctc'] reward = -(reward_ctc + reward_sim * 3) all_ctc_loss.append(reward_ctc) @@ -405,9 +405,9 @@ def generate_duration_samples(self, batch_inputs): if self.rank == 0: print(f"Error in speech synthesis: {e}") reward = torch.tensor(-1.0).to(cur_duration.device) - + sampled_rewards.append(reward) - # list with length of K + # list with length of K sampled_rewards = torch.stack(sampled_rewards) # (K, ) # Normalize rewards if (sampled_rewards.max() - sampled_rewards.min()).item() > 1e-6: @@ -421,7 +421,7 @@ def generate_duration_samples(self, batch_inputs): all_durations.append(est_frames) all_gen_logps.append(gen_logps) all_rewards.extend(sampled_rewards) # list with length of B*K - + # Concatenate all data # logits = torch.cat(all_logits, dim=0) # text_ids = torch.cat(all_text_ids, dim=0) @@ -433,7 +433,7 @@ def generate_duration_samples(self, batch_inputs): ctc_losses = torch.stack(all_ctc_loss) sv_losses = torch.stack(all_sv_loss) - + if self.is_main: self.accelerator.log({ "ctc_loss": ctc_losses.mean().item(), @@ -459,30 +459,30 @@ def generate_duration_samples(self, batch_inputs): "sampled_classes": sampled_classes, "durations": durations, } - + if self.compute_gen_logps: batch_outputs["gen_logps"] = gen_logps - + if self.rank == 0: print(f"Generated {len(rewards)} samples with reward min/mean/max: {rewards.min().item():.4f}/{rewards.mean().item():.4f}/{rewards.max().item():.4f}") - + return batch_outputs - + def GRPO_step(self, batch): """ Perform a GRPO update step - + Args: batch: Dictionary with inputs, rewards, reference logits, etc. - + Returns: Loss value """ # Extract batch data # NOTE: why .unsqueeze(1) ??? 
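# --- Editorial sketch (not part of the original patch): the objective that
# GRPO_step assembles below, shown standalone. The clipped-surrogate form and
# all names here are assumptions based on standard GRPO/PPO practice, not
# code taken from this repository.
import torch

def grpo_objective(cur_logps, gen_logps, ref_logps, rewards,
                   beta=0.04, clip_param=0.2):
    # Importance ratio between the current policy and the sampling-time policy.
    ratio = torch.exp(cur_logps - gen_logps)
    clipped = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)
    # Take the pessimistic (clipped) surrogate, as in PPO.
    surrogate = torch.min(ratio * rewards, clipped * rewards)
    # k3 KL estimator: exp(ref - cur) - (ref - cur) - 1, always >= 0.
    kl = torch.exp(ref_logps - cur_logps) - (ref_logps - cur_logps) - 1.0
    return -(surrogate - beta * kl).mean()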
- rewards = batch['rewards'] #.unsqueeze(1) + rewards = batch['rewards'] # .unsqueeze(1) ref_logps = batch['refs'] # (B) - sampled_classes = batch['sampled_classes'] # (B) + sampled_classes = batch['sampled_classes'] # (B) prompt_mel = batch['prompt_mel'] text_ids = batch['text_ids'] @@ -491,23 +491,23 @@ def GRPO_step(self, batch): B, T, _ = prompt_mel.shape _, L = text_ids.shape cur_logits = self.model( - text_ids=text_ids, # (B, L) + text_ids=text_ids, # (B, L) mel=prompt_mel # (B, T, 100) )[:, -1, :] - cur_logits = cur_logits.unsqueeze(1).repeat(1, K, 1).view(B*K, -1) + cur_logits = cur_logits.unsqueeze(1).repeat(1, K, 1).view(B * K, -1) # Compute current log probabilities for sampled actions log_probs = F.log_softmax(cur_logits, dim=-1) cur_logps = torch.gather( - log_probs, - dim=-1, + log_probs, + dim=-1, index=sampled_classes.unsqueeze(-1) ).squeeze(-1) # (B) # KL divergence computation (same as in Qwen2.5 code) # KL = exp(ref - cur) - (ref - cur) - 1 - kl_div = torch.exp(ref_logps - cur_logps) - (ref_logps - cur_logps) - 1 # (B) - + kl_div = torch.exp(ref_logps - cur_logps) - (ref_logps - cur_logps) - 1 # (B) + # Compute probability ratio for PPO if "gen_logps" in batch: gen_logps = batch['gen_logps'] @@ -517,30 +517,30 @@ def GRPO_step(self, batch): else: # Simplification if gen_logps not available loss = torch.exp(cur_logps - cur_logps.detach()) * rewards - + # Final GRPO loss with KL regularization - loss = -(loss - self.beta * kl_div) # (B) + loss = -(loss - self.beta * kl_div) # (B) loss = loss.mean() - + return loss - + def get_batch(self): """Get a batch from the queue or return None if empty""" if not self.batch_queue: return None return self.batch_queue.pop(0) - + def generate_mode(self, num_batches=5): """ Generate samples and add them to the batch queue - + Args: dataset: Dataset to sample from num_batches: Number of batches to generate """ if self.rank == 0: print("Entering generate mode...") - + tic = time.time() for _ in range(num_batches): try: @@ -559,14 +559,14 @@ def generate_mode(self, num_batches=5): continue # Add batch to queue self.batch_queue.append(batch_outputs) - + if self.rank == 0: print(f"Exiting generate mode: {time.time() - tic:.3f}s") - + def train(self, train_dataset, valid_dataset=None, num_workers=64, resumable_with_seed=666): """ Train the model using GRPO - + Args: train_dataset: Training dataset valid_dataset: Validation dataset (optional) @@ -597,7 +597,7 @@ def train(self, train_dataset, valid_dataset=None, num_workers=64, resumable_wit self.train_iterator = iter(self.train_dataloader) self.valid_iterator = iter(self.valid_dataloader) - + elif self.batch_size_type == "frame": self.accelerator.even_batches = False @@ -623,11 +623,11 @@ def train(self, train_dataset, valid_dataset=None, num_workers=64, resumable_wit valid_dataset, collate_fn=collate_fn, num_workers=num_workers, - pin_memory=True, + pin_memory=True, persistent_workers=True, batch_sampler=batch_sampler, ) - + self.train_dataloader, self.valid_dataloader = self.accelerator.prepare(self.train_dataloader, self.valid_dataloader) self.train_iterator = iter(self.train_dataloader) @@ -635,53 +635,52 @@ def train(self, train_dataset, valid_dataset=None, num_workers=64, resumable_wit else: raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}") - # Setup schedulers warmup_steps = self.num_warmup_updates * self.accelerator.num_processes total_steps = self.all_steps decay_steps = total_steps - warmup_steps - + warmup_scheduler = 
LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps) decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps) - + self.scheduler = SequentialLR( - self.optimizer, - schedulers=[warmup_scheduler, decay_scheduler], + self.optimizer, + schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps] ) - + self.scheduler = self.accelerator.prepare(self.scheduler) - + # Load checkpoint if available start_step = self.load_checkpoint() self.global_step = start_step - + # Generate initial batches self.generate_mode() - + # Training loop progress = range(1, self.all_steps + 1) - + # Skip steps that are already done progress = [step for step in progress if step > start_step] if self.is_main: progress = tqdm(progress, desc="Training", unit="step") - + for step in progress: # Get batch from queue or generate more batch = self.get_batch() while batch is None: self.generate_mode() batch = self.get_batch() - + # GRPO update with self.accelerator.accumulate(self.model): loss = self.GRPO_step(batch) # for param in self.model.parameters(): - # custom_loss = loss + 0 * param.sum() + # custom_loss = loss + 0 * param.sum() self.accelerator.backward(loss) - + if self.max_grad_norm > 0 and self.accelerator.sync_gradients: total_norm = self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) else: @@ -693,7 +692,7 @@ def train(self, train_dataset, valid_dataset=None, num_workers=64, resumable_wit ]), 2 ) - + self.accelerator.log({ "grad_norm": total_norm.item() }, step=self.global_step) @@ -701,9 +700,9 @@ def train(self, train_dataset, valid_dataset=None, num_workers=64, resumable_wit self.optimizer.step() self.scheduler.step() self.optimizer.zero_grad() - + self.global_step += 1 - + # Log metrics if self.is_main: self.accelerator.log({ @@ -717,13 +716,13 @@ def train(self, train_dataset, valid_dataset=None, num_workers=64, resumable_wit loss=f"{loss.item():.4f}", lr=f"{self.scheduler.get_last_lr()[0]:.8f}" ) - + # Save checkpoint if self.global_step % self.save_per_updates == 0: self.save_checkpoint(self.global_step) - + # Optional validation logic could be added here - + # Save final checkpoint self.save_checkpoint(self.global_step, last=True) self.accelerator.end_training() diff --git a/src/guidance_model.py b/src/dmospeech2/guidance_model.py similarity index 85% rename from src/guidance_model.py rename to src/dmospeech2/guidance_model.py index cf064fe..3c7c9a7 100644 --- a/src/guidance_model.py +++ b/src/dmospeech2/guidance_model.py @@ -8,28 +8,22 @@ """ from __future__ import annotations -from typing import Callable + from random import random -import numpy as np +from typing import Callable +import numpy as np import torch -from torch import nn import torch.nn.functional as F +from dmospeech2.ctcmodel import ConformerCTC +from dmospeech2.discriminator_conformer import ConformerDiscirminator +from dmospeech2.ecapa_tdnn import ECAPA_TDNN +from torch import nn from f5_tts.model import DiT +from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx, + list_str_to_tensor, mask_from_frac_lengths) -from f5_tts.model.utils import ( - default, - exists, - list_str_to_idx, - list_str_to_tensor, - lens_to_mask, - mask_from_frac_lengths, -) - -from discriminator_conformer import ConformerDiscirminator -from ctcmodel import ConformerCTC -from ecapa_tdnn import ECAPA_TDNN class NoOpContext: def __enter__(self): @@ -38,62 +32,64 @@ def __enter__(self): def __exit__(self, *args): pass -def 
predict_flow(transformer, # flow model - x, # noisy input - cond, # mask (prompt mask + length mask) - text, # text input - time, # time step - second_time=None, - cfg_strength=1.0 -): + +def predict_flow(transformer, # flow model + x, # noisy input + cond, # mask (prompt mask + length mask) + text, # text input + time, # time step + second_time=None, + cfg_strength=1.0 + ): pred = transformer( - x=x, - cond=cond, - text=text, time=time, + x=x, + cond=cond, + text=text, time=time, second_time=second_time, - drop_audio_cond=False, + drop_audio_cond=False, drop_text=False ) - + if cfg_strength < 1e-5: return pred - + null_pred = transformer( - x=x, - cond=cond, - text=text, time=time, - second_time=second_time, - drop_audio_cond=True, - drop_text=True + x=x, + cond=cond, + text=text, time=time, + second_time=second_time, + drop_audio_cond=True, + drop_text=True ) return pred + (pred - null_pred) * cfg_strength + def _kl_dist_func(x, y): log_probs = F.log_softmax(x, dim=2) - target_probs = F.log_softmax(y, dim=2) + target_probs = F.log_softmax(y, dim=2) return torch.nn.functional.kl_div(log_probs, target_probs, reduction="batchmean", log_target=True) class Guidance(nn.Module): - def __init__(self, - real_unet: DiT, # teacher flow model - fake_unet: DiT, # student flow model - - use_fp16: bool = True, - real_guidance_scale: float = 0.0, - fake_guidance_scale: float = 0.0, - gen_cls_loss: bool = False, - - sv_path_en: str = "", - sv_path_zh: str = "", - ctc_path: str = "", - sway_coeff: float = 0.0, - scale: float = 1.0, - ): + def __init__(self, + real_unet: DiT, # teacher flow model + fake_unet: DiT, # student flow model + + use_fp16: bool = True, + real_guidance_scale: float = 0.0, + fake_guidance_scale: float = 0.0, + gen_cls_loss: bool = False, + + sv_path_en: str = "", + sv_path_zh: str = "", + ctc_path: str = "", + sway_coeff: float = 0.0, + scale: float = 1.0, + ): super().__init__() self.vocab_size = real_unet.vocab_size - + if ctc_path != "": model = ConformerCTC(vocab_size=real_unet.vocab_size, mel_dim=real_unet.mel_dim, num_heads=8, d_hid=512, nlayers=6) self.ctc_model = model.eval() @@ -113,43 +109,43 @@ def __init__(self, self.sv_model_zh.load_state_dict(torch.load(sv_path_zh, weights_only=True, map_location='cpu')['model_state_dict']) self.scale = scale - + self.real_unet = real_unet - self.real_unet.requires_grad_(False) # no update on the teacher model + self.real_unet.requires_grad_(False) # no update on the teacher model self.fake_unet = fake_unet - self.fake_unet.requires_grad_(True) # update the student model - - self.real_guidance_scale = real_guidance_scale + self.fake_unet.requires_grad_(True) # update the student model + + self.real_guidance_scale = real_guidance_scale self.fake_guidance_scale = fake_guidance_scale - + assert self.fake_guidance_scale == 0, "no guidance for fake" self.use_fp16 = use_fp16 - self.gen_cls_loss = gen_cls_loss - + self.gen_cls_loss = gen_cls_loss + self.sway_coeff = sway_coeff - + if self.gen_cls_loss: self.cls_pred_branch = ConformerDiscirminator( - input_dim=(self.fake_unet.depth + 1) * self.fake_unet.dim + 3 * 512, # 3 is the number of layers from the CTC model + input_dim=(self.fake_unet.depth + 1) * self.fake_unet.dim + 3 * 512, # 3 is the number of layers from the CTC model num_layers=3, channels=self.fake_unet.dim // 2, ) self.cls_pred_branch.requires_grad_(True) - + self.network_context_manager = torch.autocast(device_type="cuda", dtype=torch.float16) if self.use_fp16 else NoOpContext() + from torch.utils.data import DataLoader, 
Dataset, SequentialSampler + from f5_tts.model.dataset import (DynamicBatchSampler, collate_fn, + load_dataset) from f5_tts.model.utils import get_tokenizer - from torch.utils.data import DataLoader, Dataset, SequentialSampler - from f5_tts.model.dataset import load_dataset - from f5_tts.model.dataset import DynamicBatchSampler, collate_fn bsz = 16 - + tokenizer = "pinyin" # 'pinyin', 'char', or 'custom' tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) dataset_name = "Emilia_ZH_EN" @@ -161,10 +157,8 @@ def __init__(self, self.vocab_char_map = vocab_char_map - - def compute_distribution_matching_loss( - self, + self, inp: float["b n d"] | float["b nw"], # mel or raw wave, ground truth latent text: int["b nt"] | list[str], # text input @@ -183,12 +177,12 @@ def compute_distribution_matching_loss( The code is adapted from F5-TTS but conceptualized per DMD: L_DMD encourages p_theta to match p_data via the difference between teacher and student predictions. """ - + original_inp = inp - + with torch.no_grad(): batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device - + # mel is x1 x1 = inp @@ -197,67 +191,67 @@ def compute_distribution_matching_loss( # time step time = torch.rand((batch,), dtype=dtype, device=device) - + # get flow t = time.unsqueeze(-1).unsqueeze(-1) # t = t + self.sway_coeff * (torch.cos(torch.pi / 2 * t) - 1 + t) sigma_t, alpha_t = (1 - t), t - phi = (1 - t) * x0 + t * x1 # noisy x - flow = x1 - x0 # flow target - + phi = (1 - t) * x0 + t * x1 # noisy x + flow = x1 - x0 # flow target + # only predict what is within the random mask span for infilling cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) - - # run at full precision as autocast and no_grad doesn't work well together + + # run at full precision as autocast and no_grad doesn't work well together with self.network_context_manager: pred_fake = predict_flow( - self.fake_unet, - phi, - cond, # mask (prompt mask + length mask) - text, # text input - time, # time step + self.fake_unet, + phi, + cond, # mask (prompt mask + length mask) + text, # text input + time, # time step second_time=second_time, cfg_strength=self.fake_guidance_scale ) - # pred = (x1 - x0), thus phi + (1-t) * pred = (1 - t) * x0 + t * x1 + (1 - t) * (x1 - x0) = (1 - t) * x1 + t * x1 = x1 + # pred = (x1 - x0), thus phi + (1-t) * pred = (1 - t) * x0 + t * x1 + (1 - t) * (x1 - x0) = (1 - t) * x1 + t * x1 = x1 pred_fake_image = phi + (1 - t) * pred_fake pred_fake_image[~rand_span_mask] = inp[~rand_span_mask] - + with self.network_context_manager: pred_real = predict_flow( self.real_unet, phi, cond, text, time, cfg_strength=self.real_guidance_scale ) - + pred_real_image = phi + (1 - t) * pred_real pred_real_image[~rand_span_mask] = inp[~rand_span_mask] p_real = (inp - pred_real_image) p_fake = (inp - pred_fake_image) - - grad = (p_real - p_fake) / torch.abs(p_real).mean(dim=[1, 2], keepdim=True) + + grad = (p_real - p_fake) / torch.abs(p_real).mean(dim=[1, 2], keepdim=True) grad = torch.nan_to_num(grad) - + # grad = grad / sigma_t # pred_fake - pred_real # grad = grad * (1 + sigma_t / alpha_t) - + # grad = grad / (1 + sigma_t / alpha_t) # noise # grad = grad / sigma_t # score difference # grad = grad * alpha_t # grad = grad * (sigma_t ** 2 / alpha_t) - + # grad = grad * (alpha_t + sigma_t ** 2 / alpha_t) - + # The DMD loss: MSE to move student distribution closer to teacher distribution # Only optimize over the masked region - loss = 0.5 * 
F.mse_loss(original_inp.float(), (original_inp-grad).detach().float(), reduction="none") * rand_span_mask.unsqueeze( - -1 - ) + loss = 0.5 * F.mse_loss(original_inp.float(), (original_inp - grad).detach().float(), reduction="none") * rand_span_mask.unsqueeze( + -1 + ) loss = loss.sum() / (rand_span_mask.sum() * grad.size(-1)) - + loss_dict = { - "loss_dm": loss + "loss_dm": loss } dm_log_dict = { @@ -270,8 +264,7 @@ def compute_distribution_matching_loss( } return loss_dict, dm_log_dict - - + def compute_ctc_sv_loss( self, real_inp: torch.Tensor, # real data latent @@ -295,11 +288,11 @@ def compute_ctc_sv_loss( with torch.no_grad(): real_out, real_layers, ctc_loss_test = self.ctc_model(real_inp * self.scale, text, text_lens) real_logits = real_out.log_softmax(dim=2) - # emb_real = self.sv_model(real_inp * self.scale) # snippet from prompt region - + # emb_real = self.sv_model(real_inp * self.scale) # snippet from prompt region + fake_logits = out.log_softmax(dim=2) kl_loss = F.kl_div(fake_logits, real_logits, reduction="mean", log_target=True) - + # For SV: # Extract speaker embeddings from real (prompt) and fake: # emb_fake = self.sv_model(fake_inp * self.scale) @@ -325,7 +318,7 @@ def compute_ctc_sv_loss( random_start = np.random.randint(0, mel_length - mel_len) else: random_start = np.random.randint(prompt_start, prompt_end - mel_len) - + chunks_fake.append(fake_inp[bib, random_start:random_start + mel_len, :]) chunks_real.append(real_inp[bib, :mel_len, :]) @@ -352,21 +345,19 @@ def compute_ctc_sv_loss( "loss_sim": sv_loss }, layer, real_layers - - def compute_loss_fake( self, - inp: torch.Tensor, # student generator output + inp: torch.Tensor, # student generator output text: torch.Tensor | list[str], rand_span_mask: torch.Tensor, second_time: torch.Tensor | None = None, ): """ Compute flow loss for the fake flow model, which is trained to estimate the flow (score) of the student distribution. - + This is the same as L_diff in the paper. """ - + # Similar to distribution matching, but only train fake to predict flow directly batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device @@ -383,26 +374,26 @@ def compute_loss_fake( x1 = inp x0 = torch.randn_like(x1) t = time.unsqueeze(-1).unsqueeze(-1) - + phi = (1 - t) * x0 + t * x1 flow = x1 - x0 cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) with self.network_context_manager: pred = self.fake_unet( - x=phi, + x=phi, cond=cond, - text=text, - time=time, + text=text, + time=time, second_time=second_time, - drop_audio_cond=False, - drop_text=False # make sure the cfg=1 + drop_audio_cond=False, + drop_text=False # make sure the cfg=1 ) # Compute MSE between predicted flow and actual flow, masked by rand_span_mask loss = F.mse_loss(pred, flow, reduction="none") loss = loss[rand_span_mask].mean() - + loss_dict = { "loss_fake_mean": loss } @@ -416,7 +407,7 @@ def compute_loss_fake( def compute_cls_logits( self, - inp: torch.Tensor, # student generator output + inp: torch.Tensor, # student generator output layer: torch.Tensor, text: torch.Tensor, rand_span_mask: torch.Tensor, @@ -425,9 +416,9 @@ def compute_cls_logits( ): ''' Compute adversarial loss logits for the generator. - + This is used to compute L_adv in the paper. 
- + ''' context_no_grad = torch.no_grad if guidance else NoOpContext @@ -438,7 +429,7 @@ def compute_cls_logits( # For classification, we need some representation: # We'll mimic the logic from compute_loss_fake - + batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device if isinstance(text, list): if exists(self.vocab_char_map): @@ -453,26 +444,26 @@ def compute_cls_logits( x1 = inp x0 = torch.randn_like(x1) t = time.unsqueeze(-1).unsqueeze(-1) - + phi = (1 - t) * x0 + t * x1 cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) with self.network_context_manager: layers = self.fake_unet( - x=phi, + x=phi, cond=cond, - text=text, - time=time, + text=text, + time=time, second_time=second_time, - drop_audio_cond=False, - drop_text=False, # make sure the cfg=1 + drop_audio_cond=False, + drop_text=False, # make sure the cfg=1 classify_mode=True ) # layers = torch.stack(layers, dim=0) if guidance: layers = [layer.detach() for layer in layers] - layer = layer[-3:] # only use the last 3 layers + layer = layer[-3:] # only use the last 3 layers layer = [l.transpose(-1, -2) for l in layer] # layer = [F.interpolate(l, mode='nearest', scale_factor=4).transpose(-1, -2) for l in layer] if layer[0].size(1) < layers[0].size(1): @@ -484,10 +475,9 @@ def compute_cls_logits( return logits, layers - def compute_generator_cls_loss( self, - inp: torch.Tensor, # student generator output + inp: torch.Tensor, # student generator output layer: torch.Tensor, real_layers: torch.Tensor, text: torch.Tensor, @@ -499,7 +489,7 @@ def compute_generator_cls_loss( ''' Compute the adversarial loss for the generator. ''' - + # Compute classification loss for generator: if not self.gen_cls_loss: return {"gen_cls_loss": 0} @@ -509,7 +499,7 @@ def compute_generator_cls_loss( loss = ((1 - logits) ** 2).mean() return {"gen_cls_loss": loss, "loss_mse": 0} - + def compute_guidance_cls_loss( self, fake_inp: torch.Tensor, @@ -536,7 +526,7 @@ def compute_guidance_cls_loss( with torch.no_grad(): # get layers from CTC model _, layer = self.ctc_model(real_inp * self.scale) - + logits_real, _ = self.compute_cls_logits(real_inp.detach(), layer, text, rand_span_mask, second_time, guidance=True) loss_real = ((1 - logits_real)**2).mean() @@ -560,19 +550,19 @@ def generator_forward( text_normalized: torch.Tensor, text_normalized_lens: torch.Tensor, rand_span_mask: torch.Tensor, - real_data: dict | None = None, # ground truth data (primarily prompt) to compute SV loss + real_data: dict | None = None, # ground truth data (primarily prompt) to compute SV loss second_time: torch.Tensor | None = None, mse_loss: bool = False, ): ''' Forward pass for the generator. - + This function computes the loss for the generator, which includes: - Distribution matching loss (L_DMD) - Adversarial generator loss (L_adv(G; D)) - CTC/SV loss (L_ctc + L_sv) ''' - + # 1. Compute DM loss dm_loss_dict, dm_log_dict = self.compute_distribution_matching_loss(inp, text, rand_span_mask=rand_span_mask, second_time=second_time) @@ -587,12 +577,11 @@ def generator_forward( # 3. 
Compute optional classification loss if self.gen_cls_loss: cls_loss_dict = self.compute_generator_cls_loss(inp, layer, real_layers, text, - rand_span_mask=rand_span_mask, - second_time=second_time, - mse_inp = real_data["inp"] if real_data is not None else None, - mse_loss = mse_loss, - ) - + rand_span_mask=rand_span_mask, + second_time=second_time, + mse_inp=real_data["inp"] if real_data is not None else None, + mse_loss=mse_loss, + ) loss_dict = {**dm_loss_dict, **cls_loss_dict, **ctc_sv_loss_dict} log_dict = {**dm_log_dict} @@ -610,13 +599,13 @@ def guidance_forward( ): ''' Forward pass for the guidance module (discriminator + fake flow function). - + This function computes the loss for the guidance module, which includes: - Flow matching loss (L_diff) - Adversarial discriminator loss (L_adv(D; G)) - + ''' - + # Compute fake loss (like epsilon prediction loss in Guidance) fake_loss_dict, fake_log_dict = self.compute_loss_fake(fake_inp, text, rand_span_mask=rand_span_mask, second_time=second_time) @@ -630,7 +619,7 @@ def guidance_forward( log_dict = {**fake_log_dict, **cls_log_dict} return loss_dict, log_dict - + def forward( self, generator_turn=False, @@ -665,14 +654,11 @@ def forward( return loss_dict, log_dict - - if __name__ == "__main__": from f5_tts.model.utils import get_tokenizer - bsz = 16 - + tokenizer = "pinyin" # 'pinyin', 'char', or 'custom' tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) dataset_name = "Emilia_ZH_EN" @@ -681,25 +667,24 @@ def forward( else: tokenizer_path = dataset_name vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer) - - + real_unet = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4, text_num_embeds=vocab_size, mel_dim=100) fake_unet = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4, text_num_embeds=vocab_size, mel_dim=100) - - guidance = Guidance(real_unet, + + guidance = Guidance(real_unet, fake_unet, real_guidance_scale=1.0, fake_guidance_scale=0.0, use_fp16=True, - gen_cls_loss=True, + gen_cls_loss=True, ).cuda() - + text = ["hello world"] * bsz lens = torch.randint(1, 1000, (bsz,)).cuda() inp = torch.randn(bsz, lens.max(), 100).cuda() # mel_dim=100 to match the DiT models above - + batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device - + # handle text as string if isinstance(text, list): if exists(vocab_char_map): @@ -714,14 +699,14 @@ def forward( mask = lens_to_mask(lens, length=seq_len) # useless here, as collate_fn will pad to max length in batch frac_lengths_mask = (0.7, 1.0) - + # get a random span to mask out for training conditionally frac_lengths = torch.zeros((batch,), device=device).float().uniform_(*frac_lengths_mask) rand_span_mask = mask_from_frac_lengths(lens, frac_lengths) - + if exists(mask): rand_span_mask &= mask - + # Construct data dicts for generator and guidance phases # For flow, `real_data` can just be the ground truth if available; here we simulate it real_data_dict = { @@ -742,7 +727,6 @@ def forward( "real_data": real_data_dict } - # Generator forward pass loss_dict, log_dict = guidance(generator_turn=True, generator_data_dict=generator_data_dict) print("Generator turn losses:", loss_dict) diff --git a/src/infer.py b/src/dmospeech2/infer.py similarity index 91% rename from src/infer.py rename to src/dmospeech2/infer.py index 1726fd8..f104cfc 100644 --- a/src/infer.py +++ b/src/dmospeech2/infer.py @@ -1,33 +1,25 @@ -import os import torch -import torchaudio import torch.nn.functional as F +import torchaudio +from
dmospeech2.duration_predictor import SpeechLengthPredictor from torch.nn.utils.rnn import pad_sequence from torchdiffeq import odeint -from safetensors.torch import load_file -import IPython.display as ipd +# Import custom modules +from dmospeech2.unimodel import UniModel +from f5_tts.infer.utils_infer import (convert_char_to_pinyin, load_vocoder, + preprocess_ref_audio_text, speed, + target_rms, transcribe) # Import F5-TTS modules -from f5_tts.model import CFM, UNetT, DiT +from f5_tts.model import CFM, DiT, UNetT from f5_tts.model.modules import MelSpec -from f5_tts.model.utils import ( - default, exists, list_str_to_idx, list_str_to_tensor, - lens_to_mask, mask_from_frac_lengths, get_tokenizer -) -from f5_tts.infer.utils_infer import ( - load_vocoder, preprocess_ref_audio_text, chunk_text, - convert_char_to_pinyin, transcribe, target_rms, - target_sample_rate, hop_length, speed -) - -# Import custom modules -from unimodel import UniModel -from duration_predictor import SpeechLengthPredictor +from f5_tts.model.utils import (exists, get_tokenizer, lens_to_mask, + list_str_to_idx, list_str_to_tensor) class DMOInference: """F5-TTS Inference wrapper class for easy text-to-speech generation.""" - + def __init__( self, student_checkpoint_path="", @@ -39,7 +31,7 @@ def __init__( ): """ Initialize F5-TTS inference model. - + Args: student_checkpoint_path: Path to student model checkpoint duration_predictor_path: Path to duration predictor checkpoint @@ -49,13 +41,11 @@ def __init__( dataset_name: Dataset name for tokenizer cuda_device_id: CUDA device ID to use """ - + self.device = device self.model_type = model_type self.tokenizer = tokenizer self.dataset_name = dataset_name - - # Model parameters self.target_sample_rate = 24000 self.n_mel_channels = 100 self.hop_length = 256 @@ -63,23 +53,20 @@ def __init__( self.fake_guidance_scale = 0 self.gen_cls_loss = False self.num_student_step = 4 - - # Initialize components self._setup_tokenizer() self._setup_models(student_checkpoint_path) self._setup_mel_spec() self._setup_vocoder() self._setup_duration_predictor(duration_predictor_path) - + def _setup_tokenizer(self): """Setup tokenizer and vocabulary.""" if self.tokenizer == "custom": tokenizer_path = self.tokenizer_path else: tokenizer_path = self.dataset_name - self.vocab_char_map, self.vocab_size = get_tokenizer(tokenizer_path, self.tokenizer) - + def _setup_models(self, student_checkpoint_path): """Initialize teacher and student models.""" # Model configuration @@ -91,11 +78,10 @@ def _setup_models(self, student_checkpoint_path): model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) else: raise ValueError(f"Unknown model type: {self.model_type}") - + # Initialize UniModel (student) self.model = UniModel( - model_cls(**model_cfg, text_num_embeds=self.vocab_size, mel_dim=self.n_mel_channels, - second_time=self.num_student_step > 1), + model_cls(**model_cfg, text_num_embeds=self.vocab_size, mel_dim=self.n_mel_channels, second_time=self.num_student_step > 1), checkpoint_path="", vocab_char_map=self.vocab_char_map, frac_lengths_mask=(0.5, 0.9), @@ -104,17 +90,17 @@ def _setup_models(self, student_checkpoint_path): gen_cls_loss=self.gen_cls_loss, sway_coeff=0, ) - + # Load student checkpoint checkpoint = torch.load(student_checkpoint_path, map_location='cpu') self.model.load_state_dict(checkpoint['model_state_dict'], strict=False) - + # Setup generator and teacher self.generator = self.model.feedforward_model.to(self.device) self.teacher = self.model.guidance_model.real_unet.to(self.device) - + 
self.scale = checkpoint['scale'] - + def _setup_mel_spec(self): """Initialize mel spectrogram module.""" mel_spec_kwargs = dict( @@ -123,12 +109,12 @@ def _setup_mel_spec(self): hop_length=self.hop_length, ) self.mel_spec = MelSpec(**mel_spec_kwargs) - + def _setup_vocoder(self): """Initialize vocoder.""" self.vocos = load_vocoder(is_local=False, local_path="") self.vocos = self.vocos.to(self.device) - + def _setup_duration_predictor(self, checkpoint_path): """Initialize duration predictor.""" self.wav2mel = MelSpec( @@ -139,7 +125,7 @@ def _setup_duration_predictor(self, checkpoint_path): n_fft=1024, mel_spec_type='vocos' ).to(self.device) - + self.SLP = SpeechLengthPredictor( vocab_size=2545, n_mel=100, @@ -149,14 +135,14 @@ def _setup_duration_predictor(self, checkpoint_path): n_head=8, output_dim=301 ).to(self.device) - + self.SLP.eval() self.SLP.load_state_dict(torch.load(checkpoint_path, map_location='cpu')['model_state_dict']) - + def predict_duration(self, pmt_wav_path, tar_text, pmt_text, dp_softmax_range=0.7, temperature=0): """ Predict duration for target text based on prompt audio. - + Args: pmt_wav_path: Path to prompt audio tar_text: Target text to generate @@ -173,41 +159,41 @@ def predict_duration(self, pmt_wav_path, tar_text, pmt_text, dp_softmax_range=0. if pmt_wav.size(0) > 1: pmt_wav = pmt_wav[0].unsqueeze(0) pmt_wav = pmt_wav.to(self.device) - + pmt_mel = self.wav2mel(pmt_wav).permute(0, 2, 1) tar_tokens = self._convert_to_pinyin(list(tar_text)) pmt_tokens = self._convert_to_pinyin(list(pmt_text)) - + # Calculate duration ref_text_len = len(pmt_tokens) gen_text_len = len(tar_tokens) ref_audio_len = pmt_mel.size(1) duration = int(ref_audio_len / ref_text_len * gen_text_len / speed) duration = duration // 10 - + min_duration = max(int(duration * dp_softmax_range), 0) max_duration = min(int(duration * (1 + dp_softmax_range)), 301) - + all_tokens = pmt_tokens + [' '] + tar_tokens - + text_ids = list_str_to_idx([all_tokens], self.vocab_char_map).to(self.device) text_ids = text_ids.masked_fill(text_ids == -1, self.vocab_size) - + with torch.no_grad(): predictions = self.SLP(text_ids=text_ids, mel=pmt_mel) predictions = predictions[:, -1, :] predictions[:, :min_duration] = float('-inf') predictions[:, max_duration:] = float('-inf') - + if temperature == 0: est_label = predictions.argmax(-1)[..., -1].item() * 10 else: probs = torch.softmax(predictions / temperature, dim=-1) sampled_idx = torch.multinomial(probs.squeeze(0), num_samples=1) # Remove the -1 index est_label = sampled_idx.item() * 10 - + return est_label - + def _convert_to_pinyin(self, char_list): """Convert character list to pinyin.""" result = [] @@ -216,7 +202,7 @@ def _convert_to_pinyin(self, char_list): while result[0] == ' ' and len(result) > 1: result = result[1:] return result - + def generate( self, gen_text, @@ -235,7 +221,7 @@ def generate( ): """ Generate speech from text using teacher-student distillation. 
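        Example (editorial sketch; checkpoint paths and argument values are
        hypothetical, not taken from the original patch):

            tts = DMOInference(
                student_checkpoint_path="ckpts/student.pt",
                duration_predictor_path="ckpts/slp.pt",
                device="cuda",
            )
            wav = tts.generate(gen_text="Hello there!", audio_path="prompt.wav")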
- + Args: gen_text: Text to generate audio_path: Path to prompt audio @@ -250,39 +236,39 @@ def generate( cfg_strength: Classifier-free guidance strength sway_coefficient: Sway sampling coefficient verbose: Output sampling steps - + Returns: Generated audio waveform """ if prompt_text is None: prompt_text = transcribe(audio_path) - + # Predict duration if not provided if duration is None: duration = self.predict_duration(audio_path, gen_text, prompt_text, dp_softmax_range, temperature) - + # Preprocess audio and text ref_audio, ref_text = preprocess_ref_audio_text(audio_path, prompt_text) audio, sr = torchaudio.load(ref_audio) - + if audio.shape[0] > 1: audio = torch.mean(audio, dim=0, keepdim=True) - + # Normalize audio rms = torch.sqrt(torch.mean(torch.square(audio))) if rms < target_rms: audio = audio * target_rms / rms - + if sr != self.target_sample_rate: resampler = torchaudio.transforms.Resample(sr, self.target_sample_rate) audio = resampler(audio) - + audio = audio.to(self.device) - + # Prepare text text_list = [ref_text + gen_text] final_text_list = convert_char_to_pinyin(text_list) - + # Calculate durations ref_audio_len = audio.shape[-1] // self.hop_length if duration is None: @@ -291,7 +277,7 @@ def generate( duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed) else: duration = ref_audio_len + duration - + if verbose: print('audio:', audio.shape) print('text:', final_text_list) @@ -303,7 +289,7 @@ def generate( cond, text, step_cond, cond_mask, max_duration, duration_tensor = self._prepare_inputs( audio, final_text_list, duration ) - + # Teacher-student sampling if teacher_steps > 0 and student_start_step > 0: if verbose: @@ -314,18 +300,18 @@ def generate( ) else: x1 = step_cond - + if verbose: print('Start student sampling...') # Student sampling x1 = self._student_sampling(x1, cond, text, student_start_step, verbose, sway_coefficient) - + # Decode to audio mel = x1.permute(0, 2, 1) * self.scale generated_wave = self.vocos.decode(mel[..., cond_mask.sum():]) - + return generated_wave.cpu().numpy().squeeze() - + def generate_teacher_only( self, gen_text, @@ -339,7 +325,7 @@ def generate_teacher_only( ): """ Generate speech using teacher model only (no student distillation). 
- + Args: gen_text: Text to generate audio_path: Path to prompt audio @@ -349,39 +335,39 @@ def generate_teacher_only( eta: Stochasticity control (0=DDIM, 1=DDPM) cfg_strength: Classifier-free guidance strength sway_coefficient: Sway sampling coefficient - + Returns: Generated audio waveform """ if prompt_text is None: prompt_text = transcribe(audio_path) - + # Predict duration if not provided if duration is None: duration = self.predict_duration(audio_path, gen_text, prompt_text) - + # Preprocess audio and text ref_audio, ref_text = preprocess_ref_audio_text(audio_path, prompt_text) audio, sr = torchaudio.load(ref_audio) - + if audio.shape[0] > 1: audio = torch.mean(audio, dim=0, keepdim=True) - + # Normalize audio rms = torch.sqrt(torch.mean(torch.square(audio))) if rms < target_rms: audio = audio * target_rms / rms - + if sr != self.target_sample_rate: resampler = torchaudio.transforms.Resample(sr, self.target_sample_rate) audio = resampler(audio) - + audio = audio.to(self.device) - + # Prepare text text_list = [ref_text + gen_text] final_text_list = convert_char_to_pinyin(text_list) - + # Calculate durations ref_audio_len = audio.shape[-1] // self.hop_length if duration is None: @@ -390,44 +376,44 @@ def generate_teacher_only( duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed) else: duration = ref_audio_len + duration - + # Run inference with torch.inference_mode(): cond, text, step_cond, cond_mask, max_duration = self._prepare_inputs( audio, final_text_list, duration ) - + # Teacher-only sampling x1 = self._teacher_sampling( step_cond, text, cond_mask, max_duration, duration, teacher_steps, 1.0, eta, cfg_strength, sway_coefficient # stopping_time=1.0 for full sampling ) - + # Decode to audio mel = x1.permute(0, 2, 1) * self.scale generated_wave = self.vocos.decode(mel[..., cond_mask.sum():]) - + return generated_wave - + def _prepare_inputs(self, audio, text_list, duration): """Prepare inputs for generation.""" lens = None max_duration_limit = 4096 - + cond = audio text = text_list - + if cond.ndim == 2: cond = self.mel_spec(cond) cond = cond.permute(0, 2, 1) assert cond.shape[-1] == 100 - + cond = cond / self.scale batch, cond_seq_len, device = *cond.shape[:2], cond.device - + if not exists(lens): lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long) - + # Process text if isinstance(text, list): if exists(self.vocab_char_map): @@ -435,40 +421,40 @@ def _prepare_inputs(self, audio, text_list, duration): else: text = list_str_to_tensor(text).to(device) assert text.shape[0] == batch - + if exists(text): text_lens = (text != -1).sum(dim=-1) lens = torch.maximum(text_lens, lens) - + # Process duration cond_mask = lens_to_mask(lens) - + if isinstance(duration, int): duration = torch.full((batch,), duration, device=device, dtype=torch.long) - + duration = torch.maximum(lens + 1, duration) duration = duration.clamp(max=max_duration_limit) max_duration = duration.amax() - + # Pad conditioning cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0) cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False) cond_mask = cond_mask.unsqueeze(-1) step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) - + return cond, text, step_cond, cond_mask, max_duration, duration - + def _teacher_sampling(self, step_cond, text, cond_mask, max_duration, duration, - teacher_steps, teacher_stopping_time, eta, cfg_strength, verbose, sway_sampling_coef = -1): + teacher_steps, teacher_stopping_time, eta, 
cfg_strength, verbose, sway_sampling_coef=-1): """Perform teacher model sampling.""" device = step_cond.device - + # Pre-generate noise sequence for stochastic sampling noise_seq = None if eta > 0: - noise_seq = [torch.randn(1, max_duration, 100, device=device) - for _ in range(teacher_steps)] - + noise_seq = [torch.randn(1, max_duration, 100, device=device) + for _ in range(teacher_steps)] + def fn(t, x): with torch.inference_mode(): with torch.autocast(device_type="cuda", dtype=torch.float16): @@ -476,20 +462,20 @@ def fn(t, x): print(f'current t: {t}') step_frac = 1.0 - t.item() step_idx = min(int(step_frac * len(noise_seq)), len(noise_seq) - 1) if noise_seq else 0 - + # Predict flow pred = self.teacher( x=x, cond=step_cond, text=text, time=t, mask=None, drop_audio_cond=False, drop_text=False ) - + if cfg_strength > 1e-5: null_pred = self.teacher( x=x, cond=step_cond, text=text, time=t, mask=None, drop_audio_cond=True, drop_text=True ) pred = pred + (pred - null_pred) * cfg_strength - + # Add stochasticity if eta > 0 if eta > 0 and noise_seq is not None: alpha_t = 1.0 - t.item() @@ -501,23 +487,23 @@ def fn(t, x): return pred + noise_scale * noise_seq[step_idx] else: return pred - + # Initialize noise y0 = [] for dur in duration: y0.append(torch.randn(dur, 100, device=device, dtype=step_cond.dtype)) y0 = pad_sequence(y0, padding_value=0, batch_first=True) - + # Setup time steps t = torch.linspace(0, 1, teacher_steps + 1, device=device, dtype=step_cond.dtype) if sway_sampling_coef is not None: t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) t = t[:(t > teacher_stopping_time).float().argmax() + 2] t = t[:-1] - + # Solve ODE trajectory = odeint(fn, y0, t, method="euler") - + if teacher_stopping_time < 1.0: # If early stopping, compute final step pred = fn(t[-1], trajectory[-1]) @@ -525,20 +511,20 @@ def fn(t, x): return test_out else: return trajectory[-1] - - def _student_sampling(self, x1, cond, text, student_start_step, verbose, sway_coeff = -1): + + def _student_sampling(self, x1, cond, text, student_start_step, verbose, sway_coeff=-1): """Perform student model sampling.""" steps = torch.Tensor([0, 0.25, 0.5, 0.75]) steps = steps + sway_coeff * (torch.cos(torch.pi / 2 * steps) - 1 + steps) steps = steps[student_start_step:] - + for step in steps: time = torch.Tensor([step]).to(x1.device) - + x0 = torch.randn_like(x1) t = time.unsqueeze(-1).unsqueeze(-1) phi = (1 - t) * x0 + t * x1 - + if verbose: print(f'current step: {step}') with torch.no_grad(): @@ -550,10 +536,10 @@ def _student_sampling(self, x1, cond, text, student_start_step, verbose, sway_co drop_audio_cond=False, drop_text=False ) - + # Predicted mel spectrogram output = phi + (1 - t) * pred - + x1 = output - + return x1 diff --git a/src/unimodel.py b/src/dmospeech2/unimodel.py similarity index 87% rename from src/unimodel.py rename to src/dmospeech2/unimodel.py index afdb142..3eaa93c 100644 --- a/src/unimodel.py +++ b/src/dmospeech2/unimodel.py @@ -1,43 +1,37 @@ from __future__ import annotations -from typing import Callable -from random import random import contextlib - -from torch import nn -import torch import copy import os +from pathlib import Path +from random import random +from typing import Callable + +import torch +from dmospeech2.guidance_model import Guidance +from torch import nn from f5_tts.model import DiT, UNetT -from pathlib import Path -from guidance_model import Guidance -from f5_tts.model.utils import ( - default, - exists, - list_str_to_idx, - list_str_to_tensor, - lens_to_mask, - 
mask_from_frac_lengths, - sample_consecutive_steps, - sample_from_list, -) +from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx, + list_str_to_tensor, mask_from_frac_lengths, + sample_consecutive_steps, sample_from_list) + class UniModel(nn.Module): - def __init__(self, - model: DiT, # teacher model (dit model) + def __init__(self, + model: DiT, # teacher model (dit model) checkpoint_path: str = "", second_time: bool = True, use_fp16: bool = True, - real_guidance_scale: float = 2.0, - fake_guidance_scale: float = 0.0, + real_guidance_scale: float = 2.0, + fake_guidance_scale: float = 0.0, gen_cls_loss: bool = False, sway_coeff: float = -1.0, vocab_char_map: dict[str, int] | None = None, frac_lengths_mask: tuple[float, float] = (0.7, 1.0)): - + super().__init__() - + if checkpoint_path != "": if "model_last.pt" in os.listdir(checkpoint_path): latest_checkpoint = "model_last.pt" @@ -74,12 +68,12 @@ def __init__(self, model.load_state_dict(filtered_state_dict, strict=False) else: self.scale = 1.0 - + real_unet = copy.deepcopy(model) real_unet.time_embed2 = None - + fake_unet = copy.deepcopy(model) - + # Instantiate Guidance, which internally uses real_unet and fake_unet initialized from the teacher self.guidance_model = Guidance( real_unet=real_unet, @@ -90,15 +84,15 @@ def __init__(self, gen_cls_loss=gen_cls_loss, sway_coeff=sway_coeff, ) - - self.feedforward_model = copy.deepcopy(model) # initialize the student model + + self.feedforward_model = copy.deepcopy(model) # initialize the student model self.feedforward_model.requires_grad_(True) self.feedforward_model.time_embed2 = None self.vocab_char_map = vocab_char_map self.frac_lengths_mask = frac_lengths_mask - - self.second_time = second_time # fake_unet.time_embed2 is not None + + self.second_time = second_time # fake_unet.time_embed2 is not None def forward(self, inp: float["b n d"], # mel @@ -107,7 +101,7 @@ def forward(self, lens: int["b"] | None = None, student_steps: list[int] = [0, 0.25, 0.5, 0.75], update_generator: bool = False, - ): + ): """ Forward pass that routes to either generator_forward or guidance_forward in the Guidance class, depending on the arguments. @@ -126,7 +120,7 @@ def forward(self, "rand_span_mask": Tensor (B, N) - boolean mask "real_data": dict with keys like: "inp", "text", "rand_span_mask" - + Returns: -------- loss_dict: dict[str, Tensor] @@ -134,7 +128,7 @@ def forward(self, log_dict: dict[str, Tensor or float] Dictionary of logging tensors or values. 
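        Example (editorial sketch; shapes follow the __main__ demo at the
        bottom of this file, mel_dim=100):

            inp = torch.randn(2, 400, 100).cuda()   # (batch, frames, mel_dim)
            text = ["hello world", "another sentence"]
            lens = torch.tensor([400, 320]).cuda()
            loss_dict, log_dict = model(inp, text, lens=lens, update_generator=True)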
""" - + batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device # handle text as string @@ -159,10 +153,9 @@ def forward(self, frac_lengths = torch.zeros((batch,), device=device).float().uniform_(*self.frac_lengths_mask) rand_span_mask = mask_from_frac_lengths(lens, frac_lengths) - + if exists(mask): rand_span_mask &= mask - # # use generated output from previous step as input with torch.no_grad(): @@ -171,41 +164,41 @@ def forward(self, t = p_time.unsqueeze(-1).unsqueeze(-1) phi = (1 - t) * x0 + t * x1 cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) - + pred = self.feedforward_model( - x=phi, + x=phi, cond=cond, - text=text, - time=p_time, - drop_audio_cond=False, - drop_text=False # make sure the cfg=1 - ) # flow prediction - + text=text, + time=p_time, + drop_audio_cond=False, + drop_text=False # make sure the cfg=1 + ) # flow prediction + # predicted mel spectrogram - output = phi + (1 - t) * pred + output = phi + (1 - t) * pred output[~rand_span_mask] = inp[~rand_span_mask] - + # forward diffusion x1 = output x0 = torch.randn_like(x1) t = time.unsqueeze(-1).unsqueeze(-1) phi = (1 - t) * x0 + t * x1 cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) - + with torch.no_grad() if not update_generator else contextlib.nullcontext(): pred = self.feedforward_model( - x=phi, + x=phi, cond=cond, - text=text, - time=time, - drop_audio_cond=False, - drop_text=False # make sure no cfg is used + text=text, + time=time, + drop_audio_cond=False, + drop_text=False # make sure no cfg is used ) - + # predicted mel spectrogram output = phi + (1 - t) * pred output[~rand_span_mask] = inp[~rand_span_mask] - + if update_generator: generator_data_dict = { "inp": output, @@ -219,7 +212,7 @@ def forward(self, "rand_span_mask": rand_span_mask } } - + # avoid any side effects of gradient accumulation # self.guidance_model.requires_grad_(False) # self.feedforward_model.requires_grad_(True) @@ -229,13 +222,13 @@ def forward(self, generator_data_dict=generator_data_dict, guidance_data_dict=None ) - + generator_log_dict['ground_truth'] = x1 generator_log_dict['generator_input'] = phi generator_log_dict['generator_output'] = output generator_log_dict['generator_cond'] = cond generator_log_dict['time'] = time - + return generator_loss_dict, generator_log_dict else: guidance_data_dict = { @@ -249,7 +242,7 @@ def forward(self, "rand_span_mask": rand_span_mask } } - + # avoid any side effects of gradient accumulation # self.feedforward_model.requires_grad_(False) # self.guidance_model.requires_grad_(True) @@ -260,21 +253,22 @@ def forward(self, guidance_data_dict=guidance_data_dict ) # self.feedforward_model.requires_grad_(True) - + return guidance_loss_dict, guidance_log_dict - + # return guidance_loss_dict, guidance_log_dict, generator_loss_dict, generator_log_dict - + if __name__ == "__main__": - - from f5_tts.model.utils import get_tokenizer + from torch.utils.data import DataLoader, Dataset, SequentialSampler - from f5_tts.model.dataset import load_dataset - from f5_tts.model.dataset import DynamicBatchSampler, collate_fn + + from f5_tts.model.dataset import (DynamicBatchSampler, collate_fn, + load_dataset) + from f5_tts.model.utils import get_tokenizer bsz = 16 - + tokenizer = "pinyin" # 'pinyin', 'char', or 'custom' tokenizer_path = None # if tokenizer = 'custom', define the path to the tokenizer you want to use (should be vocab.txt) dataset_name = "Emilia_ZH_EN" @@ -285,20 +279,19 @@ def forward(self, vocab_char_map, vocab_size = 
get_tokenizer(tokenizer_path, tokenizer) dit = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4, text_num_embeds=vocab_size, mel_dim=100) - - model = UniModel(dit, + + model = UniModel(dit, checkpoint_path="/data4/F5TTS/ckpts/F5TTS_Base_norm_flow_8GPU_vocos_pinyin_Emilia_ZH_EN", gen_cls_loss=True, vocab_char_map=vocab_char_map, frac_lengths_mask=(0.7, 1.0) ).cuda() - + # batch = next(iter(train_dataloader)) # torch.save(batch, "batch.pt") batch = torch.load("batch.pt") - inp, text, lens = batch["mel"].permute(0, 2, 1).cuda(), batch["text"], batch["mel_lengths"].cuda() + inp, text, lens = batch["mel"].permute(0, 2, 1).cuda(), batch["text"], batch["mel_lengths"].cuda() - # text = ["hello world"] * bsz # lens = torch.randint(1, 1000, (bsz,)).cuda() # inp = torch.randn(bsz, lens.max(), 100).cuda() @@ -308,10 +301,10 @@ def forward(self, guidance_loss_dict, guidance_log_dict = model(inp, text, lens=lens, update_generator=False, student_steps=(torch.linspace(0.0, 1.0, num_student_step + 1)[:-1])) generator_loss_dict, generator_log_dict = model(inp, text, lens=lens, update_generator=True, student_steps=(torch.linspace(0.0, 1.0, num_student_step + 1)[:-1])) - + print(guidance_loss_dict) print(generator_loss_dict) - + guidance_loss = 0 guidance_loss += guidance_loss_dict["loss_fake_mean"] guidance_loss += guidance_loss_dict["guidance_cls_loss"] @@ -324,4 +317,4 @@ def forward(self, generator_loss += generator_loss_dict["loss_mse"] guidance_loss.backward() - generator_loss.backward() \ No newline at end of file + generator_loss.backward() diff --git a/src/f5_tts/__init__.py b/src/f5_tts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/f5_tts/api.py b/src/f5_tts/api.py index d73ee1b..c68b4c6 100644 --- a/src/f5_tts/api.py +++ b/src/f5_tts/api.py @@ -8,15 +8,10 @@ from hydra.utils import get_class from omegaconf import OmegaConf -from f5_tts.infer.utils_infer import ( - infer_process, - load_model, - load_vocoder, - preprocess_ref_audio_text, - remove_silence_for_generated_wav, - save_spectrogram, - transcribe, -) +from f5_tts.infer.utils_infer import (infer_process, load_model, load_vocoder, + preprocess_ref_audio_text, + remove_silence_for_generated_wav, + save_spectrogram, transcribe) from f5_tts.model.utils import seed_everything diff --git a/data/Emilia_ZH_EN_pinyin/vocab.txt b/src/f5_tts/data/Emilia_ZH_EN_pinyin/vocab.txt similarity index 100% rename from data/Emilia_ZH_EN_pinyin/vocab.txt rename to src/f5_tts/data/Emilia_ZH_EN_pinyin/vocab.txt diff --git a/data/librispeech_pc_test_clean_cross_sentence.lst b/src/f5_tts/data/librispeech_pc_test_clean_cross_sentence.lst similarity index 100% rename from data/librispeech_pc_test_clean_cross_sentence.lst rename to src/f5_tts/data/librispeech_pc_test_clean_cross_sentence.lst diff --git a/src/f5_tts/eval/ecapa_tdnn.py b/src/f5_tts/eval/ecapa_tdnn.py index f0e4c9c..7cabf61 100644 --- a/src/f5_tts/eval/ecapa_tdnn.py +++ b/src/f5_tts/eval/ecapa_tdnn.py @@ -9,7 +9,6 @@ import torch.nn as nn import torch.nn.functional as F - """ Res2Conv1d + BatchNorm1d + ReLU """ diff --git a/src/f5_tts/eval/eval_infer_batch.py b/src/f5_tts/eval/eval_infer_batch.py index cea5b7a..b5cb3af 100644 --- a/src/f5_tts/eval/eval_infer_batch.py +++ b/src/f5_tts/eval/eval_infer_batch.py @@ -1,7 +1,6 @@ import os import sys - sys.path.append(os.getcwd()) import argparse @@ -15,16 +14,13 @@ from omegaconf import OmegaConf from tqdm import tqdm -from f5_tts.eval.utils_eval import ( - get_inference_prompt, - 
-    get_librispeech_test_clean_metainfo,
-    get_seedtts_testset_metainfo,
-)
+from f5_tts.eval.utils_eval import (get_inference_prompt,
+                                    get_librispeech_test_clean_metainfo,
+                                    get_seedtts_testset_metainfo)
 from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder
 from f5_tts.model import CFM
 from f5_tts.model.utils import get_tokenizer
 
-
 accelerator = Accelerator()
 device = f"cuda:{accelerator.process_index}"
diff --git a/src/f5_tts/eval/eval_librispeech_test_clean.py b/src/f5_tts/eval/eval_librispeech_test_clean.py
index 0fef801..381b45f 100644
--- a/src/f5_tts/eval/eval_librispeech_test_clean.py
+++ b/src/f5_tts/eval/eval_librispeech_test_clean.py
@@ -5,7 +5,6 @@
 import os
 import sys
 
-
 sys.path.append(os.getcwd())
 
 import multiprocessing as mp
@@ -15,7 +14,6 @@
 
 from f5_tts.eval.utils_eval import get_librispeech_test, run_asr_wer, run_sim
 
-
 rel_path = str(files("f5_tts").joinpath("../../"))
diff --git a/src/f5_tts/eval/eval_seedtts_testset.py b/src/f5_tts/eval/eval_seedtts_testset.py
index 158a3dd..24128b6 100644
--- a/src/f5_tts/eval/eval_seedtts_testset.py
+++ b/src/f5_tts/eval/eval_seedtts_testset.py
@@ -5,7 +5,6 @@
 import os
 import sys
 
-
 sys.path.append(os.getcwd())
 
 import multiprocessing as mp
@@ -15,7 +14,6 @@
 
 from f5_tts.eval.utils_eval import get_seed_tts_test, run_asr_wer, run_sim
 
-
 rel_path = str(files("f5_tts").joinpath("../../"))
diff --git a/src/f5_tts/infer/infer_cli.py b/src/f5_tts/infer/infer_cli.py
index 7d51170..ff2b3c3 100644
--- a/src/f5_tts/infer/infer_cli.py
+++ b/src/f5_tts/infer/infer_cli.py
@@ -14,23 +14,12 @@
 from omegaconf import OmegaConf
 from unidecode import unidecode
 
-from f5_tts.infer.utils_infer import (
-    cfg_strength,
-    cross_fade_duration,
-    device,
-    fix_duration,
-    infer_process,
-    load_model,
-    load_vocoder,
-    mel_spec_type,
-    nfe_step,
-    preprocess_ref_audio_text,
-    remove_silence_for_generated_wav,
-    speed,
-    sway_sampling_coef,
-    target_rms,
-)
-
+from f5_tts.infer.utils_infer import (cfg_strength, cross_fade_duration,
+                                      device, fix_duration, infer_process,
+                                      load_model, load_vocoder, mel_spec_type,
+                                      nfe_step, preprocess_ref_audio_text,
+                                      remove_silence_for_generated_wav, speed,
+                                      sway_sampling_coef, target_rms)
 
 parser = argparse.ArgumentParser(
     prog="python3 infer-cli.py",
diff --git a/src/f5_tts/infer/infer_gradio.py b/src/f5_tts/infer/infer_gradio.py
index f4c3aef..f867e2d 100644
--- a/src/f5_tts/infer/infer_gradio.py
+++ b/src/f5_tts/infer/infer_gradio.py
@@ -19,7 +19,6 @@
 from cached_path import cached_path
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-
 try:
     import spaces
 
@@ -35,18 +34,12 @@ def gpu_decorator(func):
         return func
 
-from f5_tts.infer.utils_infer import (
-    infer_process,
-    load_model,
-    load_vocoder,
-    preprocess_ref_audio_text,
-    remove_silence_for_generated_wav,
-    save_spectrogram,
-    tempfile_kwargs,
-)
+from f5_tts.infer.utils_infer import (infer_process, load_model, load_vocoder,
+                                      preprocess_ref_audio_text,
+                                      remove_silence_for_generated_wav,
+                                      save_spectrogram, tempfile_kwargs)
 from f5_tts.model import DiT, UNetT
 
-
 DEFAULT_TTS_MODEL = "F5-TTS_v1"
 tts_model_choice = DEFAULT_TTS_MODEL
diff --git a/src/f5_tts/infer/speech_edit.py b/src/f5_tts/infer/speech_edit.py
index fdeda9f..374400c 100644
--- a/src/f5_tts/infer/speech_edit.py
+++ b/src/f5_tts/infer/speech_edit.py
@@ -1,6 +1,5 @@
 import os
 
-
 os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # for MPS device compatibility
 
 from importlib.resources import files
@@ -12,11 +11,11 @@
 from hydra.utils import get_class
 from omegaconf import OmegaConf
 
-from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
+from f5_tts.infer.utils_infer import (load_checkpoint, load_vocoder,
+                                      save_spectrogram)
 from f5_tts.model import CFM
 from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
 
-
 device = (
     "cuda"
     if torch.cuda.is_available()
diff --git a/src/f5_tts/infer/utils_infer.py b/src/f5_tts/infer/utils_infer.py
index a1f3111..711aa5c 100644
--- a/src/f5_tts/infer/utils_infer.py
+++ b/src/f5_tts/infer/utils_infer.py
@@ -4,7 +4,6 @@
 import sys
 from concurrent.futures import ThreadPoolExecutor
 
-
 os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # for MPS device compatibility
 sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../third_party/BigVGAN/")
 
@@ -15,7 +14,6 @@
 
 import matplotlib
 
-
 matplotlib.use("Agg")
 
 import matplotlib.pylab as plt
@@ -31,7 +29,6 @@
 from f5_tts.model import CFM
 from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
 
-
 _ref_audio_cache = {}
 _ref_text_cache = {}
diff --git a/src/f5_tts/model/__init__.py b/src/f5_tts/model/__init__.py
index 59cf691..4e11f92 100644
--- a/src/f5_tts/model/__init__.py
+++ b/src/f5_tts/model/__init__.py
@@ -1,10 +1,7 @@
-from f5_tts.model.cfm import CFM
-
-from f5_tts.model.backbones.unett import UNetT
 from f5_tts.model.backbones.dit import DiT
 from f5_tts.model.backbones.mmdit import MMDiT
-
+from f5_tts.model.backbones.unett import UNetT
+from f5_tts.model.cfm import CFM
 from f5_tts.model.trainer import Trainer
 
-
 __all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]
diff --git a/src/f5_tts/model/backbones/dit.py b/src/f5_tts/model/backbones/dit.py
index 70d4cdb..f96c99d 100644
--- a/src/f5_tts/model/backbones/dit.py
+++ b/src/f5_tts/model/backbones/dit.py
@@ -10,21 +10,14 @@
 from __future__ import annotations
 
 import torch
-from torch import nn
 import torch.nn.functional as F
-
+from torch import nn
 from x_transformers.x_transformers import RotaryEmbedding
 
-from f5_tts.model.modules import (
-    TimestepEmbedding,
-    ConvNeXtV2Block,
-    ConvPositionEmbedding,
-    DiTBlock,
-    AdaLayerNormZero_Final,
-    precompute_freqs_cis,
-    get_pos_embed_indices,
-)
-
+from f5_tts.model.modules import (AdaLayerNormZero_Final, ConvNeXtV2Block,
+                                  ConvPositionEmbedding, DiTBlock,
+                                  TimestepEmbedding, get_pos_embed_indices,
+                                  precompute_freqs_cis)
 
 # Text embedding
diff --git a/src/f5_tts/model/backbones/mmdit.py b/src/f5_tts/model/backbones/mmdit.py
index d3f9a91..69314fb 100644
--- a/src/f5_tts/model/backbones/mmdit.py
+++ b/src/f5_tts/model/backbones/mmdit.py
@@ -10,29 +10,17 @@
 from __future__ import annotations
 
 import torch
-from torch import nn
 import torch.nn.functional as F
-
+from torch import nn
 from x_transformers.x_transformers import RotaryEmbedding
 
-from f5_tts.model.modules import (
-    TimestepEmbedding,
-    ConvPositionEmbedding,
-    MMDiTBlock,
-    DiTBlock,
-    AdaLayerNormZero_Final,
-    precompute_freqs_cis,
-    get_pos_embed_indices,
-)
-
-from f5_tts.model.utils import (
-    default,
-    exists,
-    lens_to_mask,
-    list_str_to_idx,
-    list_str_to_tensor,
-    mask_from_frac_lengths,
-)
+from f5_tts.model.modules import (AdaLayerNormZero_Final,
+                                  ConvPositionEmbedding, DiTBlock, MMDiTBlock,
+                                  TimestepEmbedding, get_pos_embed_indices,
+                                  precompute_freqs_cis)
+from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx,
+                                list_str_to_tensor, mask_from_frac_lengths)
+
 
 # text embedding
diff --git a/src/f5_tts/model/backbones/unett.py b/src/f5_tts/model/backbones/unett.py
index acf649a..47f0b6f 100644
--- a/src/f5_tts/model/backbones/unett.py
+++ b/src/f5_tts/model/backbones/unett.py
@@ -8,26 +8,19 @@
 """
 
 from __future__ import annotations
+
 from typing import Literal
 
 import torch
-from torch import nn
 import torch.nn.functional as F
-
+from torch import nn
 from x_transformers import RMSNorm
 from x_transformers.x_transformers import RotaryEmbedding
 
-from f5_tts.model.modules import (
-    TimestepEmbedding,
-    ConvNeXtV2Block,
-    ConvPositionEmbedding,
-    Attention,
-    AttnProcessor,
-    FeedForward,
-    precompute_freqs_cis,
-    get_pos_embed_indices,
-)
-
+from f5_tts.model.modules import (Attention, AttnProcessor, ConvNeXtV2Block,
+                                  ConvPositionEmbedding, FeedForward,
+                                  TimestepEmbedding, get_pos_embed_indices,
+                                  precompute_freqs_cis)
 
 # Text embedding
diff --git a/src/f5_tts/model/cfm.py b/src/f5_tts/model/cfm.py
index 9de0f3e..ebe52fe 100644
--- a/src/f5_tts/model/cfm.py
+++ b/src/f5_tts/model/cfm.py
@@ -19,14 +19,8 @@
 from torchdiffeq import odeint
 
 from f5_tts.model.modules import MelSpec
-from f5_tts.model.utils import (
-    default,
-    exists,
-    lens_to_mask,
-    list_str_to_idx,
-    list_str_to_tensor,
-    mask_from_frac_lengths,
-)
+from f5_tts.model.utils import (default, exists, lens_to_mask, list_str_to_idx,
+                                list_str_to_tensor, mask_from_frac_lengths)
 
 
 class CFM(nn.Module):
diff --git a/src/f5_tts/model/dataset.py b/src/f5_tts/model/dataset.py
index e17b854..593dc43 100644
--- a/src/f5_tts/model/dataset.py
+++ b/src/f5_tts/model/dataset.py
@@ -1,6 +1,6 @@
-import re
 import json
 import random
+import re
 from importlib.resources import files
 
 import torch
@@ -16,8 +16,6 @@
 
 from f5_tts.model.utils import default
 
-
-
 def get_speaker_id(path):
     parts = path.split('/')
     speaker_id = parts[-3]
diff --git a/src/f5_tts/model/modules.py b/src/f5_tts/model/modules.py
index 62507a4..2aeb1ae 100644
--- a/src/f5_tts/model/modules.py
+++ b/src/f5_tts/model/modules.py
@@ -19,7 +19,6 @@
 from torch import nn
 from x_transformers.x_transformers import apply_rotary_pos_emb
 
-
 # raw wav to mel spec
diff --git a/src/f5_tts/model/trainer.py b/src/f5_tts/model/trainer.py
index 63c74a0..cf771c0 100644
--- a/src/f5_tts/model/trainer.py
+++ b/src/f5_tts/model/trainer.py
@@ -215,7 +215,8 @@ def load_checkpoint(self):
 
     def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
         if self.log_samples:
-            from f5_tts.infer.utils_infer import cfg_strength, load_vocoder, nfe_step, sway_sampling_coef
+            from f5_tts.infer.utils_infer import (cfg_strength, load_vocoder,
+                                                  nfe_step, sway_sampling_coef)
 
             vocoder = load_vocoder(
                 vocoder_name=self.vocoder_name, is_local=self.is_local_vocoder, local_path=self.local_vocoder_path
diff --git a/src/f5_tts/model/utils.py b/src/f5_tts/model/utils.py
index 0aec09a..13af1a2 100644
--- a/src/f5_tts/model/utils.py
+++ b/src/f5_tts/model/utils.py
@@ -5,13 +5,11 @@
 from collections import defaultdict
 from importlib.resources import files
 
+import jieba
 import torch
+from pypinyin import Style, lazy_pinyin
 from torch.nn.utils.rnn import pad_sequence
 
-import jieba
-from pypinyin import lazy_pinyin, Style
-
-
 
 # seed everything
@@ -109,7 +107,7 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
         - if use "byte", set to 256 (unicode byte range)
     """
     if tokenizer in ["pinyin", "char"]:
-        tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}_{tokenizer}/vocab.txt")
+        tokenizer_path = os.path.join(files("f5_tts").joinpath("data"), f"{dataset_name}_{tokenizer}/vocab.txt")
         with open(tokenizer_path, "r", encoding="utf-8") as f:
             vocab_char_map = {}
             for i, char in enumerate(f):
@@ -131,9 +129,7 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
 
     return vocab_char_map, vocab_size
 
-
 # convert char to pinyin
-
 jieba.initialize()
 print("Word segmentation module jieba initialized.\n")
 
@@ -184,7 +180,7 @@ def is_chinese(c):
 def repetition_found(text, length=2, tolerance=10):
     pattern_count = defaultdict(int)
     for i in range(len(text) - length + 1):
-        pattern = text[i : i + length]
+        pattern = text[i: i + length]
         pattern_count[pattern] += 1
     for pattern, count in pattern_count.items():
         if count > tolerance:
@@ -224,7 +220,7 @@ def load_checkpoint(model, ckpt_path, device, use_ema=True):
 def sample_consecutive_steps(float_list):
     idx = torch.randint(0, len(float_list), size=(1,))
     next_idx = idx - 1
-    
+
     if next_idx < 0:
         next_idx = 0
     else:
@@ -247,4 +243,3 @@ def sample_from_list(float_list, N):
     random_samples = float_tensor[random_indices]
 
     return random_samples
-
diff --git a/src/f5_tts/model_new/__init__.py b/src/f5_tts/model_new/__init__.py
index d7d5fbe..7c6b8a5 100644
--- a/src/f5_tts/model_new/__init__.py
+++ b/src/f5_tts/model_new/__init__.py
@@ -4,5 +4,4 @@
 from f5_tts.model_new.cfm import CFM
 from f5_tts.model_new.trainer import Trainer
 
-
 __all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]
diff --git a/src/f5_tts/model_new/backbones/dit.py b/src/f5_tts/model_new/backbones/dit.py
index d20434c..c7d31af 100644
--- a/src/f5_tts/model_new/backbones/dit.py
+++ b/src/f5_tts/model_new/backbones/dit.py
@@ -14,16 +14,10 @@
 from torch import nn
 from x_transformers.x_transformers import RotaryEmbedding
 
-from f5_tts.model_new.modules import (
-    AdaLayerNorm_Final,
-    ConvNeXtV2Block,
-    ConvPositionEmbedding,
-    DiTBlock,
-    TimestepEmbedding,
-    get_pos_embed_indices,
-    precompute_freqs_cis,
-)
-
+from f5_tts.model_new.modules import (AdaLayerNorm_Final, ConvNeXtV2Block,
+                                      ConvPositionEmbedding, DiTBlock,
+                                      TimestepEmbedding, get_pos_embed_indices,
+                                      precompute_freqs_cis)
 
 # Text embedding
diff --git a/src/f5_tts/model_new/backbones/mmdit.py b/src/f5_tts/model_new/backbones/mmdit.py
index 0a785d2..769dc3a 100644
--- a/src/f5_tts/model_new/backbones/mmdit.py
+++ b/src/f5_tts/model_new/backbones/mmdit.py
@@ -13,15 +13,10 @@
 from torch import nn
 from x_transformers.x_transformers import RotaryEmbedding
 
-from f5_tts.model_new.modules import (
-    AdaLayerNorm_Final,
-    ConvPositionEmbedding,
-    MMDiTBlock,
-    TimestepEmbedding,
-    get_pos_embed_indices,
-    precompute_freqs_cis,
-)
-
+from f5_tts.model_new.modules import (AdaLayerNorm_Final,
+                                      ConvPositionEmbedding, MMDiTBlock,
+                                      TimestepEmbedding, get_pos_embed_indices,
+                                      precompute_freqs_cis)
 
 # text embedding
diff --git a/src/f5_tts/model_new/backbones/unett.py b/src/f5_tts/model_new/backbones/unett.py
index f9fda55..cf82b0e 100644
--- a/src/f5_tts/model_new/backbones/unett.py
+++ b/src/f5_tts/model_new/backbones/unett.py
@@ -17,17 +17,11 @@
 from x_transformers import RMSNorm
 from x_transformers.x_transformers import RotaryEmbedding
 
-from f5_tts.model_new.modules import (
-    Attention,
-    AttnProcessor,
-    ConvNeXtV2Block,
-    ConvPositionEmbedding,
-    FeedForward,
-    TimestepEmbedding,
-    get_pos_embed_indices,
-    precompute_freqs_cis,
-)
-
+from f5_tts.model_new.modules import (Attention, AttnProcessor,
+                                      ConvNeXtV2Block, ConvPositionEmbedding,
+                                      FeedForward, TimestepEmbedding,
+                                      get_pos_embed_indices,
+                                      precompute_freqs_cis)
 
 # Text embedding
diff --git a/src/f5_tts/model_new/cfm.py b/src/f5_tts/model_new/cfm.py
index a803faf..5e6a09c 100644
--- a/src/f5_tts/model_new/cfm.py
+++ b/src/f5_tts/model_new/cfm.py
@@ -19,15 +19,9 @@
 from torchdiffeq import odeint
 
 from f5_tts.model_new.modules import MelSpec
-from f5_tts.model_new.utils import (
-    default,
-    exists,
-    get_epss_timesteps,
-    lens_to_mask,
-    list_str_to_idx,
-    list_str_to_tensor,
-    mask_from_frac_lengths,
-)
+from f5_tts.model_new.utils import (default, exists, get_epss_timesteps,
+                                    lens_to_mask, list_str_to_idx,
+                                    list_str_to_tensor, mask_from_frac_lengths)
 
 
 class CFM(nn.Module):
diff --git a/src/f5_tts/model_new/modules.py b/src/f5_tts/model_new/modules.py
index 655a3b6..4eb1989 100644
--- a/src/f5_tts/model_new/modules.py
+++ b/src/f5_tts/model_new/modules.py
@@ -22,7 +22,6 @@
 
 from f5_tts.model_new.utils import is_package_available
 
-
 # raw wav to mel spec
@@ -435,8 +434,8 @@ def forward(
 # Attention processor
 
 if is_package_available("flash_attn"):
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import pad_input, unpad_input
-    from flash_attn import flash_attn_varlen_func, flash_attn_func
 
 
 class AttnProcessor:
diff --git a/src/f5_tts/model_new/trainer.py b/src/f5_tts/model_new/trainer.py
index 45cbc53..5941899 100644
--- a/src/f5_tts/model_new/trainer.py
+++ b/src/f5_tts/model_new/trainer.py
@@ -19,7 +19,6 @@
 from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
 from f5_tts.model.utils import default, exists
 
-
 # trainer
@@ -261,7 +260,8 @@ def load_checkpoint(self):
 
     def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
         if self.log_samples:
-            from f5_tts.infer.utils_infer import cfg_strength, load_vocoder, nfe_step, sway_sampling_coef
+            from f5_tts.infer.utils_infer import (cfg_strength, load_vocoder,
+                                                  nfe_step, sway_sampling_coef)
 
             vocoder = load_vocoder(
                 vocoder_name=self.vocoder_name, is_local=self.is_local_vocoder, local_path=self.local_vocoder_path
diff --git a/src/f5_tts/model_new/utils.py b/src/f5_tts/model_new/utils.py
index c5c3829..8802921 100644
--- a/src/f5_tts/model_new/utils.py
+++ b/src/f5_tts/model_new/utils.py
@@ -10,7 +10,6 @@
 from pypinyin import Style, lazy_pinyin
 from torch.nn.utils.rnn import pad_sequence
 
-
 # seed everything
diff --git a/src/f5_tts/runtime/triton_trtllm/benchmark.py b/src/f5_tts/runtime/triton_trtllm/benchmark.py
index cb054ec..9f39445 100644
--- a/src/f5_tts/runtime/triton_trtllm/benchmark.py
+++ b/src/f5_tts/runtime/triton_trtllm/benchmark.py
@@ -51,7 +51,6 @@
 from tqdm import tqdm
 from vocos import Vocos
 
-
 torch.manual_seed(0)
diff --git a/src/f5_tts/runtime/triton_trtllm/patch/__init__.py b/src/f5_tts/runtime/triton_trtllm/patch/__init__.py
index ab19e9f..dc4aedf 100644
--- a/src/f5_tts/runtime/triton_trtllm/patch/__init__.py
+++ b/src/f5_tts/runtime/triton_trtllm/patch/__init__.py
@@ -13,14 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from .baichuan.model import BaichuanForCausalLM
-from .bert.model import (
-    BertForQuestionAnswering,
-    BertForSequenceClassification,
-    BertModel,
-    RobertaForQuestionAnswering,
-    RobertaForSequenceClassification,
-    RobertaModel,
-)
+from .bert.model import (BertForQuestionAnswering,
+                         BertForSequenceClassification, BertModel,
+                         RobertaForQuestionAnswering,
+                         RobertaForSequenceClassification, RobertaModel)
 from .bloom.model import BloomForCausalLM, BloomModel
 from .chatglm.config import ChatGLMConfig
 from .chatglm.model import ChatGLMForCausalLM, ChatGLMModel
@@ -51,17 +47,17 @@
 from .medusa.config import MedusaConfig
 from .medusa.model import MedusaForCausalLm
 from .mllama.model import MLLaMAModel
-from .modeling_utils import PretrainedConfig, PretrainedModel, SpeculativeDecodingMode
+from .modeling_utils import (PretrainedConfig, PretrainedModel,
+                             SpeculativeDecodingMode)
 from .mpt.model import MPTForCausalLM, MPTModel
 from .nemotron_nas.model import DeciLMForCausalLM
 from .opt.model import OPTForCausalLM, OPTModel
-from .phi.model import PhiForCausalLM, PhiModel
 from .phi3.model import Phi3ForCausalLM, Phi3Model
+from .phi.model import PhiForCausalLM, PhiModel
 from .qwen.model import QWenForCausalLM
 from .recurrentgemma.model import RecurrentGemmaForCausalLM
 from .redrafter.model import ReDrafterForCausalLM
 
-
 __all__ = [
     "BertModel",
     "BertForQuestionAnswering",
diff --git a/src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py b/src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py
index 2f8007f..be4c7c1 100644
--- a/src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py
+++ b/src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py
@@ -13,8 +13,8 @@
 from ...module import Module, ModuleList
 from ...plugin import current_all_reduce_helper
 from ..modeling_utils import PretrainedConfig, PretrainedModel
-from .modules import AdaLayerNormZero_Final, ConvPositionEmbedding, DiTBlock, TimestepEmbedding
-
+from .modules import (AdaLayerNormZero_Final, ConvPositionEmbedding, DiTBlock,
+                      TimestepEmbedding)
 
 current_file_path = os.path.abspath(__file__)
 parent_dir = os.path.dirname(current_file_path)
diff --git a/src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py b/src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py
index 2121d28..558458e 100644
--- a/src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py
+++ b/src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py
@@ -9,28 +9,10 @@
 from tensorrt_llm._common import default_net
 
 from ..._utils import str_dtype_to_trt, trt_dtype_to_np
-from ...functional import (
-    Tensor,
-    bert_attention,
-    cast,
-    chunk,
-    concat,
-    constant,
-    expand,
-    expand_dims,
-    expand_dims_like,
-    expand_mask,
-    gelu,
-    matmul,
-    permute,
-    shape,
-    silu,
-    slice,
-    softmax,
-    squeeze,
-    unsqueeze,
-    view,
-)
+from ...functional import (Tensor, bert_attention, cast, chunk, concat,
+                           constant, expand, expand_dims, expand_dims_like,
+                           expand_mask, gelu, matmul, permute, shape, silu,
+                           slice, softmax, squeeze, unsqueeze, view)
 from ...layers import ColumnLinear, Conv1d, LayerNorm, Linear, Mish, RowLinear
 from ...module import Module
diff --git a/src/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py b/src/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py
index 993e472..563ba84 100644
--- a/src/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py
+++ b/src/f5_tts/runtime/triton_trtllm/scripts/conv_stft.py
@@ -40,7 +40,6 @@
 import torch.nn.functional as F
 from scipy.signal import check_COLA, get_window
 
-
 support_clp_op = None
 if th.__version__ >= "1.7.0":
     from torch.fft import rfft as fft
diff --git a/src/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py b/src/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py
index 6743aec..946916e 100644
--- a/src/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py
+++ b/src/f5_tts/runtime/triton_trtllm/scripts/export_vocoder_to_onnx.py
@@ -20,7 +20,6 @@
 from huggingface_hub import hf_hub_download
 from vocos import Vocos
 
-
 opset_version = 17
diff --git a/src/f5_tts/scripts/count_params_gflops.py b/src/f5_tts/scripts/count_params_gflops.py
index d706388..5c0fee3 100644
--- a/src/f5_tts/scripts/count_params_gflops.py
+++ b/src/f5_tts/scripts/count_params_gflops.py
@@ -1,7 +1,6 @@
 import os
 import sys
 
-
 sys.path.append(os.getcwd())
 
 import thop
@@ -9,7 +8,6 @@
 
 from f5_tts.model import CFM, DiT
 
-
 """  ~155M  """
 # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
 # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
diff --git a/src/f5_tts/socket_client.py b/src/f5_tts/socket_client.py
index c47ad44..55459b2 100644
--- a/src/f5_tts/socket_client.py
+++ b/src/f5_tts/socket_client.py
@@ -6,7 +6,6 @@
 import numpy as np
 import pyaudio
 
-
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
diff --git a/src/f5_tts/socket_server.py b/src/f5_tts/socket_server.py
index 3fd780a..5698b48 100644
--- a/src/f5_tts/socket_server.py
+++ b/src/f5_tts/socket_server.py
@@ -16,14 +16,9 @@
 from hydra.utils import get_class
 from omegaconf import OmegaConf
 
-from f5_tts.infer.utils_infer import (
-    chunk_text,
-    infer_batch_process,
-    load_model,
-    load_vocoder,
-    preprocess_ref_audio_text,
-)
-
+from f5_tts.infer.utils_infer import (chunk_text, infer_batch_process,
+                                      load_model, load_vocoder,
+                                      preprocess_ref_audio_text)
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
diff --git a/src/f5_tts/train/datasets/prepare_csv_wavs.py b/src/f5_tts/train/datasets/prepare_csv_wavs.py
index 26ad6f8..2c738fe 100644
--- a/src/f5_tts/train/datasets/prepare_csv_wavs.py
+++ b/src/f5_tts/train/datasets/prepare_csv_wavs.py
@@ -7,7 +7,6 @@
 import sys
 from contextlib import contextmanager
 
-
 sys.path.append(os.getcwd())
 
 import argparse
@@ -22,7 +21,6 @@
 
 from f5_tts.model.utils import convert_char_to_pinyin
 
-
 PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt")
diff --git a/src/f5_tts/train/datasets/prepare_emilia.py b/src/f5_tts/train/datasets/prepare_emilia.py
index 4c4a771..bf7d1bb 100644
--- a/src/f5_tts/train/datasets/prepare_emilia.py
+++ b/src/f5_tts/train/datasets/prepare_emilia.py
@@ -7,7 +7,6 @@
 import os
 import sys
 
-
 sys.path.append(os.getcwd())
 
 import json
@@ -20,7 +19,6 @@
 
 from f5_tts.model.utils import convert_char_to_pinyin, repetition_found
 
-
 out_zh = {
     "ZH_B00041_S06226",
     "ZH_B00042_S09204",
diff --git a/src/f5_tts/train/datasets/prepare_emilia_v2.py b/src/f5_tts/train/datasets/prepare_emilia_v2.py
index 50322c0..d7d9e80 100644
--- a/src/f5_tts/train/datasets/prepare_emilia_v2.py
+++ b/src/f5_tts/train/datasets/prepare_emilia_v2.py
@@ -12,7 +12,6 @@
 
 from f5_tts.model.utils import repetition_found
 
-
 # Define filters for exclusion
 out_en = set()
 en_filters = ["ا", "い", "て"]
diff --git a/src/f5_tts/train/datasets/prepare_libritts.py b/src/f5_tts/train/datasets/prepare_libritts.py
index a892dd6..20ef774 100644
--- a/src/f5_tts/train/datasets/prepare_libritts.py
+++ b/src/f5_tts/train/datasets/prepare_libritts.py
@@ -1,7 +1,6 @@
 import os
 import sys
 
-
 sys.path.append(os.getcwd())
 
 import json
diff --git a/src/f5_tts/train/datasets/prepare_ljspeech.py b/src/f5_tts/train/datasets/prepare_ljspeech.py
index 9f64b0a..1298bf8 100644
--- a/src/f5_tts/train/datasets/prepare_ljspeech.py
+++ b/src/f5_tts/train/datasets/prepare_ljspeech.py
@@ -1,7 +1,6 @@
 import os
 import sys
 
-
 sys.path.append(os.getcwd())
 
 import json
diff --git a/src/f5_tts/train/datasets/prepare_wenetspeech4tts.py b/src/f5_tts/train/datasets/prepare_wenetspeech4tts.py
index 6498421..49ca1b3 100644
--- a/src/f5_tts/train/datasets/prepare_wenetspeech4tts.py
+++ b/src/f5_tts/train/datasets/prepare_wenetspeech4tts.py
@@ -4,7 +4,6 @@
 import os
 import sys
 
-
 sys.path.append(os.getcwd())
 
 import json
diff --git a/src/f5_tts/train/finetune_cli.py b/src/f5_tts/train/finetune_cli.py
index cdf42a9..465021a 100644
--- a/src/f5_tts/train/finetune_cli.py
+++ b/src/f5_tts/train/finetune_cli.py
@@ -9,7 +9,6 @@
 from f5_tts.model.dataset import load_dataset
 from f5_tts.model.utils import get_tokenizer
 
-
 # -------------------------- Dataset Settings --------------------------- #
 target_sample_rate = 24000
 n_mel_channels = 100
diff --git a/src/f5_tts/train/finetune_gradio.py b/src/f5_tts/train/finetune_gradio.py
index eee2a3f..8acd18e 100644
--- a/src/f5_tts/train/finetune_gradio.py
+++ b/src/f5_tts/train/finetune_gradio.py
@@ -32,7 +32,6 @@
 from f5_tts.infer.utils_infer import transcribe
 from f5_tts.model.utils import convert_char_to_pinyin
 
-
 training_process = None
 system = platform.system()
 python_executable = sys.executable or "python"
diff --git a/src/f5_tts/train/train.py b/src/f5_tts/train/train.py
index b948ab1..a935e36 100644
--- a/src/f5_tts/train/train.py
+++ b/src/f5_tts/train/train.py
@@ -10,7 +10,6 @@
 from f5_tts.model.dataset import load_dataset
 from f5_tts.model.utils import get_tokenizer
 
-
 os.chdir(str(files("f5_tts").joinpath("../..")))  # change working directory to root of project (local editable)
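
Note on the get_tokenizer() path change above: with data/ moved inside the package (src/f5_tts/data/...), the vocab is resolved through importlib.resources instead of climbing out of the installed tree via "../../data". A minimal sketch of the new lookup, assuming an installed (or editable) f5_tts package that bundles the Emilia_ZH_EN_pinyin vocab; resolve_vocab_path() is an illustrative helper, not part of the patch:

    import os
    from importlib.resources import files

    def resolve_vocab_path(dataset_name: str, tokenizer: str = "pinyin") -> str:
        # data/ now ships inside the f5_tts package, so this resolves for
        # site-packages installs as well as local editable checkouts
        data_dir = files("f5_tts").joinpath("data")
        return os.path.join(str(data_dir), f"{dataset_name}_{tokenizer}/vocab.txt")

    print(resolve_vocab_path("Emilia_ZH_EN"))
    # e.g. .../site-packages/f5_tts/data/Emilia_ZH_EN_pinyin/vocab.txt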