Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion modules/aux_decoder/convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import torch
import torch.nn as nn

from modules.commons.common_layers import AdamWCov1d


class ConvNeXtBlock(nn.Module):
"""ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
Expand Down Expand Up @@ -71,7 +73,7 @@ def __init__(
layer_scale_init_value=1e-6, drop_out=dropout_rate
) for _ in range(num_layers)
)
self.outconv = nn.Conv1d(
self.outconv = AdamWCov1d(
num_channels, out_dims, kernel_size,
stride=1, padding=(kernel_size - 1) // 2
)
Expand Down
4 changes: 2 additions & 2 deletions modules/backbones/lynxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch.nn as nn
import torch.nn.functional as F

from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, Transpose
from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, Transpose, AdamWCov1d
from modules.commons.common_layers import KaimingNormalConv1d as Conv1d
from utils.hparams import hparams

Expand Down Expand Up @@ -106,7 +106,7 @@ def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansio
]
)
self.norm = nn.LayerNorm(num_channels)
self.output_projection = Conv1d(num_channels, in_dims * n_feats, kernel_size=1)
self.output_projection = AdamWCov1d(num_channels, in_dims * n_feats, kernel_size=1)
self.strong_cond = strong_cond
nn.init.zeros_(self.output_projection.weight)

Expand Down
4 changes: 2 additions & 2 deletions modules/backbones/lynxnet2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch.nn as nn
import torch.nn.functional as F

from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, ATanGLU, Transpose
from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, ATanGLU, Transpose, AdamWLinear
from utils.hparams import hparams


Expand Down Expand Up @@ -72,7 +72,7 @@ def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansio
]
)
self.norm = nn.LayerNorm(num_channels)
self.output_projection = nn.Linear(num_channels, in_dims * n_feats)
self.output_projection = AdamWLinear(num_channels, in_dims * n_feats)
nn.init.kaiming_normal_(self.input_projection.weight)
nn.init.kaiming_normal_(self.conditioner_projection.weight)
nn.init.zeros_(self.output_projection.weight)
Expand Down
4 changes: 2 additions & 2 deletions modules/backbones/wavenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch.nn as nn
import torch.nn.functional as F

from modules.commons.common_layers import SinusoidalPosEmb
from modules.commons.common_layers import SinusoidalPosEmb, AdamWCov1d
from modules.commons.common_layers import KaimingNormalConv1d as Conv1d
from utils.hparams import hparams

Expand Down Expand Up @@ -64,7 +64,7 @@ def __init__(self, in_dims, n_feats, *, num_layers=20, num_channels=256, dilatio
for i in range(num_layers)
])
self.skip_projection = Conv1d(num_channels, num_channels, 1)
self.output_projection = Conv1d(num_channels, in_dims * n_feats, 1)
self.output_projection = AdamWCov1d(num_channels, in_dims * n_feats, 1)
nn.init.zeros_(self.output_projection.weight)

def forward(self, spec, diffusion_step, cond):
Expand Down
21 changes: 21 additions & 0 deletions modules/commons/common_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,21 @@ def __init__(
nn.init.constant_(self.weight[padding_idx], 0)


class AdamWLinear(torch.nn.Linear):
    """Linear layer with Xavier-uniform weights and a zeroed bias.

    Apart from the initialisation performed in the constructor, this class
    behaves exactly like ``torch.nn.Linear``. Being a distinct subclass, it
    can be told apart from a plain ``nn.Linear`` with ``isinstance``.
    NOTE(review): the name suggests these layers are meant to be routed to
    AdamW by the optimizer setup -- confirm against the optimizer code.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        *args: Extra positional arguments, forwarded to ``torch.nn.Linear``
            after ``in_features`` and ``out_features``.
        bias: Whether the layer has an additive bias. When true, the bias
            is set to zero after construction.
        **kwargs: Extra keyword arguments forwarded to ``torch.nn.Linear``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        *args,
        bias: bool = True,
        **kwargs
    ):
        # Let the parent allocate the parameters first; they are overwritten
        # in place below.
        super().__init__(in_features, out_features, *args, bias=bias, **kwargs)
        torch.nn.init.xavier_uniform_(self.weight)
        if not bias:
            # No bias parameter was created, so there is nothing to reset.
            return
        torch.nn.init.zeros_(self.bias)


class XavierUniformInitLinear(torch.nn.Linear):
def __init__(
self,
Expand Down Expand Up @@ -160,6 +175,12 @@ def forward(self, x):
return out * torch.atan(gate)


class AdamWCov1d(torch.nn.Conv1d):
    """1-D convolution whose weight is re-initialised with Kaiming-normal.

    All constructor arguments are forwarded unchanged to ``torch.nn.Conv1d``.
    Only ``weight`` is touched afterwards; the bias, if present, keeps the
    value assigned by the parent constructor. Being a distinct subclass, it
    can be told apart from a plain ``nn.Conv1d`` with ``isinstance``.
    NOTE(review): the name suggests these layers are meant to be routed to
    AdamW by the optimizer setup -- confirm against the optimizer code.
    """

    def __init__(self, *args, **kwargs):
        # The parent constructor creates ``self.weight`` with its final shape.
        super().__init__(*args, **kwargs)
        # Overwrite the parent's default weight initialisation in place.
        torch.nn.init.kaiming_normal_(self.weight)


class KaimingNormalConv1d(torch.nn.Conv1d):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down
11 changes: 6 additions & 5 deletions modules/fastspeech/acoustic_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
NormalInitEmbedding as Embedding,
XavierUniformInitLinear as Linear,
SinusoidalPosEmb,
AdamWLinear,
)
from modules.fastspeech.tts_modules import FastSpeech2Encoder, mel2ph_to_dur, StretchRegulator
from utils.hparams import hparams
Expand All @@ -32,7 +33,7 @@ def __init__(self, vocab_size):
)
self.stretch_embed_rnn = nn.GRU(hparams['hidden_size'], hparams['hidden_size'], 1, batch_first=True)

self.dur_embed = Linear(1, hparams['hidden_size'])
self.dur_embed = AdamWLinear(1, hparams['hidden_size'])
self.encoder = FastSpeech2Encoder(
hidden_size=hparams['hidden_size'], num_layers=hparams['enc_layers'],
ffn_kernel_size=hparams['enc_ffn_kernel_size'], ffn_act=hparams['ffn_act'],
Expand All @@ -41,7 +42,7 @@ def __init__(self, vocab_size):
use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True)
)

self.pitch_embed = Linear(1, hparams['hidden_size'])
self.pitch_embed = AdamWLinear(1, hparams['hidden_size'])
self.variance_embed_list = []
self.use_energy_embed = hparams.get('use_energy_embed', False)
self.use_breathiness_embed = hparams.get('use_breathiness_embed', False)
Expand All @@ -59,7 +60,7 @@ def __init__(self, vocab_size):
self.use_variance_embeds = len(self.variance_embed_list) > 0
if self.use_variance_embeds:
self.variance_embeds = nn.ModuleDict({
v_name: Linear(1, hparams['hidden_size'])
v_name: AdamWLinear(1, hparams['hidden_size'])
for v_name in self.variance_embed_list
})

Expand All @@ -85,11 +86,11 @@ def __init__(self, vocab_size):

self.use_key_shift_embed = hparams.get('use_key_shift_embed', False)
if self.use_key_shift_embed:
self.key_shift_embed = Linear(1, hparams['hidden_size'])
self.key_shift_embed = AdamWLinear(1, hparams['hidden_size'])

self.use_speed_embed = hparams.get('use_speed_embed', False)
if self.use_speed_embed:
self.speed_embed = Linear(1, hparams['hidden_size'])
self.speed_embed = AdamWLinear(1, hparams['hidden_size'])

self.use_spk_id = hparams['use_spk_id']
if self.use_spk_id:
Expand Down
4 changes: 2 additions & 2 deletions modules/fastspeech/tts_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch.nn as nn
from torch.nn import functional as F
from modules.commons.rotary_embedding_torch import RotaryEmbedding
from modules.commons.common_layers import SinusoidalPositionalEmbedding, EncSALayer
from modules.commons.common_layers import SinusoidalPositionalEmbedding, EncSALayer, AdamWLinear
from modules.commons.espnet_positional_embedding import RelPositionalEncoding

DEFAULT_MAX_SOURCE_POSITIONS = 2000
Expand Down Expand Up @@ -110,7 +110,7 @@ def __init__(self, in_dims, n_layers=2, n_chans=384, kernel_size=3,
# self.crf = CRF(out_dims, batch_first=True)
else:
raise NotImplementedError()
self.linear = torch.nn.Linear(n_chans, self.out_dims)
self.linear = AdamWLinear(n_chans, self.out_dims)

def out2dur(self, xs):
if self.loss_type in ['mse', 'huber']:
Expand Down
9 changes: 5 additions & 4 deletions modules/fastspeech/variance_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from modules.commons.common_layers import (
NormalInitEmbedding as Embedding,
XavierUniformInitLinear as Linear,
AdamWLinear,
)
from modules.fastspeech.tts_modules import FastSpeech2Encoder, DurationPredictor
from utils.hparams import hparams
Expand All @@ -24,9 +25,9 @@ def __init__(self, vocab_size):

if self.predict_dur:
self.onset_embed = Embedding(2, hparams['hidden_size'])
self.word_dur_embed = Linear(1, hparams['hidden_size'])
self.word_dur_embed = AdamWLinear(1, hparams['hidden_size'])
else:
self.ph_dur_embed = Linear(1, hparams['hidden_size'])
self.ph_dur_embed = AdamWLinear(1, hparams['hidden_size'])

self.encoder = FastSpeech2Encoder(
hidden_size=hparams['hidden_size'], num_layers=hparams['enc_layers'],
Expand Down Expand Up @@ -112,8 +113,8 @@ def get_hparam(key):
# MIDI inputs
hidden_size = get_hparam('hidden_size')
self.use_variance_scaling = hparams.get('use_variance_scaling', False)
self.note_midi_embed = Linear(1, hidden_size)
self.note_dur_embed = Linear(1, hidden_size)
self.note_midi_embed = AdamWLinear(1, hidden_size)
self.note_dur_embed = AdamWLinear(1, hidden_size)

# ornament inputs
self.use_glide_embed = hparams['use_glide_embed']
Expand Down
16 changes: 13 additions & 3 deletions modules/optimizer/muon.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import collections
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -6,6 +7,8 @@
from typing import List
from .chained_optimizer import ChainedOptimizer, OptimizerSpec

from modules.commons.common_layers import AdamWLinear, AdamWCov1d


def get_bf16_support_map():
bf16_support_map = {}
Expand Down Expand Up @@ -129,13 +132,20 @@ def get_params_for_muon(model) -> List[Parameter]:
Returns:
A list of parameters that should be optimized with muon.
"""
excluded_module_classes = (nn.Embedding, AdamWLinear, AdamWCov1d)
muon_params = []
for module in model.modules():
# BFS through all submodules and exclude parameters from certain module types
queue = collections.deque([model])
while queue:
module = queue.popleft()
if isinstance(module, excluded_module_classes):
continue
for param in module.parameters(recurse=False):
if not param.requires_grad:
continue
if not isinstance(module, nn.Embedding) and param.ndim >= 2:
if param.ndim >= 2:
muon_params.append(param)
queue.extend(list(module.children()))
return muon_params


Expand All @@ -150,4 +160,4 @@ def __init__(self, model, lr=0.0005, weight_decay=0.0, muon_args={}, adamw_args=
callback = lambda p, spec_idx: print(
f"Adding param {p.shape} to optimizer{spec_idx} {str(specs[spec_idx].class_type)}"
)
super().__init__(model.parameters(), specs, lr=lr, weight_decay=weight_decay, optimizer_selection_callback=callback)
super().__init__(model.parameters(), specs, lr=lr, weight_decay=weight_decay, optimizer_selection_callback=callback)
10 changes: 5 additions & 5 deletions modules/toplevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from modules.commons.common_layers import (
XavierUniformInitLinear as Linear,
NormalInitEmbedding as Embedding,
SinusoidalPosEmb
SinusoidalPosEmb, AdamWLinear,
)
from modules.core import (
GaussianDiffusion, PitchDiffusion, MultiVarianceDiffusion,
Expand Down Expand Up @@ -160,9 +160,9 @@ def __init__(self, vocab_size):
self.use_melody_encoder = hparams.get('use_melody_encoder', False)
if self.use_melody_encoder:
self.melody_encoder = MelodyEncoder(enc_hparams=hparams['melody_encoder_args'])
self.delta_pitch_embed = Linear(1, hparams['hidden_size'])
self.delta_pitch_embed = AdamWLinear(1, hparams['hidden_size'])
else:
self.base_pitch_embed = Linear(1, hparams['hidden_size'])
self.base_pitch_embed = AdamWLinear(1, hparams['hidden_size'])

self.pitch_retake_embed = Embedding(2, hparams['hidden_size'])
pitch_hparams = hparams['pitch_prediction_args']
Expand Down Expand Up @@ -195,9 +195,9 @@ def __init__(self, vocab_size):
raise ValueError(f"Invalid diffusion type: {self.diffusion_type}")

if self.predict_variances:
self.pitch_embed = Linear(1, hparams['hidden_size'])
self.pitch_embed = AdamWLinear(1, hparams['hidden_size'])
self.variance_embeds = nn.ModuleDict({
v_name: Linear(1, hparams['hidden_size'])
v_name: AdamWLinear(1, hparams['hidden_size'])
for v_name in self.variance_prediction_list
})

Expand Down