From 3de9aecb2bab049ba2e8b228e3b57c7bc23cf544 Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Sun, 4 Jan 2026 16:51:12 +0800 Subject: [PATCH 1/2] fix "input Linear" and "output Linear" optim by Muon fix "input Linear" and "output Linear" optim by Muon --- modules/aux_decoder/convnext.py | 4 +++- modules/backbones/lynxnet.py | 4 ++-- modules/backbones/lynxnet2.py | 4 ++-- modules/backbones/wavenet.py | 4 ++-- modules/commons/common_layers.py | 21 +++++++++++++++++++++ modules/fastspeech/acoustic_encoder.py | 11 ++++++----- modules/fastspeech/tts_modules.py | 4 ++-- modules/fastspeech/variance_encoder.py | 9 +++++---- modules/optimizer/muon.py | 12 +++++++++++- modules/toplevel.py | 10 +++++----- 10 files changed, 59 insertions(+), 24 deletions(-) diff --git a/modules/aux_decoder/convnext.py b/modules/aux_decoder/convnext.py index a03959ddf..2b6ef1a80 100644 --- a/modules/aux_decoder/convnext.py +++ b/modules/aux_decoder/convnext.py @@ -3,6 +3,8 @@ import torch import torch.nn as nn +from modules.commons.common_layers import AdamWCov1d + class ConvNeXtBlock(nn.Module): """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. @@ -71,7 +73,7 @@ def __init__( layer_scale_init_value=1e-6, drop_out=dropout_rate ) for _ in range(num_layers) ) - self.outconv = nn.Conv1d( + self.outconv = AdamWCov1d( num_channels, out_dims, kernel_size, stride=1, padding=(kernel_size - 1) // 2 ) diff --git a/modules/backbones/lynxnet.py b/modules/backbones/lynxnet.py index 9f5c6a383..766dc960f 100644 --- a/modules/backbones/lynxnet.py +++ b/modules/backbones/lynxnet.py @@ -6,7 +6,7 @@ import torch.nn as nn import torch.nn.functional as F -from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, Transpose +from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, Transpose, AdamWCov1d from modules.commons.common_layers import KaimingNormalConv1d as Conv1d from utils.hparams import hparams @@ -106,7 +106,7 @@ def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansio ] ) self.norm = nn.LayerNorm(num_channels) - self.output_projection = Conv1d(num_channels, in_dims * n_feats, kernel_size=1) + self.output_projection = AdamWCov1d(num_channels, in_dims * n_feats, kernel_size=1) self.strong_cond = strong_cond nn.init.zeros_(self.output_projection.weight) diff --git a/modules/backbones/lynxnet2.py b/modules/backbones/lynxnet2.py index 76e8580b5..864ba3779 100644 --- a/modules/backbones/lynxnet2.py +++ b/modules/backbones/lynxnet2.py @@ -2,7 +2,7 @@ import torch.nn as nn import torch.nn.functional as F -from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, ATanGLU, Transpose +from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, ATanGLU, Transpose, AdamWLinear from utils.hparams import hparams @@ -72,7 +72,7 @@ def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansio ] ) self.norm = nn.LayerNorm(num_channels) - self.output_projection = nn.Linear(num_channels, in_dims * n_feats) + self.output_projection = AdamWLinear(num_channels, in_dims * n_feats) nn.init.kaiming_normal_(self.input_projection.weight) nn.init.kaiming_normal_(self.conditioner_projection.weight) nn.init.zeros_(self.output_projection.weight) diff --git a/modules/backbones/wavenet.py b/modules/backbones/wavenet.py index 2cbff961d..1baedbfa3 100644 --- a/modules/backbones/wavenet.py +++ b/modules/backbones/wavenet.py @@ -5,7 +5,7 @@ import torch.nn as nn import torch.nn.functional as F -from modules.commons.common_layers import SinusoidalPosEmb +from modules.commons.common_layers import SinusoidalPosEmb, AdamWCov1d from modules.commons.common_layers import KaimingNormalConv1d as Conv1d from utils.hparams import hparams @@ -64,7 +64,7 @@ def __init__(self, in_dims, n_feats, *, num_layers=20, num_channels=256, dilatio for i in range(num_layers) ]) self.skip_projection = Conv1d(num_channels, num_channels, 1) - self.output_projection = Conv1d(num_channels, in_dims * n_feats, 1) + self.output_projection = AdamWCov1d(num_channels, in_dims * n_feats, 1) nn.init.zeros_(self.output_projection.weight) def forward(self, spec, diffusion_step, cond): diff --git a/modules/commons/common_layers.py b/modules/commons/common_layers.py index 0012b99c3..cb5aa72ac 100644 --- a/modules/commons/common_layers.py +++ b/modules/commons/common_layers.py @@ -26,6 +26,21 @@ def __init__( nn.init.constant_(self.weight[padding_idx], 0) +class AdamWLinear(torch.nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + *args, + bias: bool = True, + **kwargs + ): + super().__init__(in_features, out_features, *args, bias=bias, **kwargs) + nn.init.xavier_uniform_(self.weight) + if bias: + nn.init.constant_(self.bias, 0.) + + class XavierUniformInitLinear(torch.nn.Linear): def __init__( self, @@ -160,6 +175,12 @@ def forward(self, x): return out * torch.atan(gate) +class AdamWCov1d(torch.nn.Conv1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + nn.init.kaiming_normal_(self.weight) + + class KaimingNormalConv1d(torch.nn.Conv1d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/modules/fastspeech/acoustic_encoder.py b/modules/fastspeech/acoustic_encoder.py index 868d383fd..f75ab2d5d 100644 --- a/modules/fastspeech/acoustic_encoder.py +++ b/modules/fastspeech/acoustic_encoder.py @@ -6,6 +6,7 @@ NormalInitEmbedding as Embedding, XavierUniformInitLinear as Linear, SinusoidalPosEmb, + AdamWLinear, ) from modules.fastspeech.tts_modules import FastSpeech2Encoder, mel2ph_to_dur, StretchRegulator from utils.hparams import hparams @@ -32,7 +33,7 @@ def __init__(self, vocab_size): ) self.stretch_embed_rnn = nn.GRU(hparams['hidden_size'], hparams['hidden_size'], 1, batch_first=True) - self.dur_embed = Linear(1, hparams['hidden_size']) + self.dur_embed = AdamWLinear(1, hparams['hidden_size']) self.encoder = FastSpeech2Encoder( hidden_size=hparams['hidden_size'], num_layers=hparams['enc_layers'], ffn_kernel_size=hparams['enc_ffn_kernel_size'], ffn_act=hparams['ffn_act'], @@ -41,7 +42,7 @@ def __init__(self, vocab_size): use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True) ) - self.pitch_embed = Linear(1, hparams['hidden_size']) + self.pitch_embed = AdamWLinear(1, hparams['hidden_size']) self.variance_embed_list = [] self.use_energy_embed = hparams.get('use_energy_embed', False) self.use_breathiness_embed = hparams.get('use_breathiness_embed', False) @@ -59,7 +60,7 @@ def __init__(self, vocab_size): self.use_variance_embeds = len(self.variance_embed_list) > 0 if self.use_variance_embeds: self.variance_embeds = nn.ModuleDict({ - v_name: Linear(1, hparams['hidden_size']) + v_name: AdamWLinear(1, hparams['hidden_size']) for v_name in self.variance_embed_list }) @@ -85,11 +86,11 @@ def __init__(self, vocab_size): self.use_key_shift_embed = hparams.get('use_key_shift_embed', False) if self.use_key_shift_embed: - self.key_shift_embed = Linear(1, hparams['hidden_size']) + self.key_shift_embed = AdamWLinear(1, hparams['hidden_size']) self.use_speed_embed = hparams.get('use_speed_embed', False) if self.use_speed_embed: - self.speed_embed = Linear(1, hparams['hidden_size']) + self.speed_embed = AdamWLinear(1, hparams['hidden_size']) self.use_spk_id = hparams['use_spk_id'] if self.use_spk_id: diff --git a/modules/fastspeech/tts_modules.py b/modules/fastspeech/tts_modules.py index cc840aed3..57b360115 100644 --- a/modules/fastspeech/tts_modules.py +++ b/modules/fastspeech/tts_modules.py @@ -4,7 +4,7 @@ import torch.nn as nn from torch.nn import functional as F from modules.commons.rotary_embedding_torch import RotaryEmbedding -from modules.commons.common_layers import SinusoidalPositionalEmbedding, EncSALayer +from modules.commons.common_layers import SinusoidalPositionalEmbedding, EncSALayer, AdamWLinear from modules.commons.espnet_positional_embedding import RelPositionalEncoding DEFAULT_MAX_SOURCE_POSITIONS = 2000 @@ -110,7 +110,7 @@ def __init__(self, in_dims, n_layers=2, n_chans=384, kernel_size=3, # self.crf = CRF(out_dims, batch_first=True) else: raise NotImplementedError() - self.linear = torch.nn.Linear(n_chans, self.out_dims) + self.linear = AdamWLinear(n_chans, self.out_dims) def out2dur(self, xs): if self.loss_type in ['mse', 'huber']: diff --git a/modules/fastspeech/variance_encoder.py b/modules/fastspeech/variance_encoder.py index ba6994c1e..712964846 100644 --- a/modules/fastspeech/variance_encoder.py +++ b/modules/fastspeech/variance_encoder.py @@ -5,6 +5,7 @@ from modules.commons.common_layers import ( NormalInitEmbedding as Embedding, XavierUniformInitLinear as Linear, + AdamWLinear, ) from modules.fastspeech.tts_modules import FastSpeech2Encoder, DurationPredictor from utils.hparams import hparams @@ -24,9 +25,9 @@ def __init__(self, vocab_size): if self.predict_dur: self.onset_embed = Embedding(2, hparams['hidden_size']) - self.word_dur_embed = Linear(1, hparams['hidden_size']) + self.word_dur_embed = AdamWLinear(1, hparams['hidden_size']) else: - self.ph_dur_embed = Linear(1, hparams['hidden_size']) + self.ph_dur_embed = AdamWLinear(1, hparams['hidden_size']) self.encoder = FastSpeech2Encoder( hidden_size=hparams['hidden_size'], num_layers=hparams['enc_layers'], @@ -112,8 +113,8 @@ def get_hparam(key): # MIDI inputs hidden_size = get_hparam('hidden_size') self.use_variance_scaling = hparams.get('use_variance_scaling', False) - self.note_midi_embed = Linear(1, hidden_size) - self.note_dur_embed = Linear(1, hidden_size) + self.note_midi_embed = AdamWLinear(1, hidden_size) + self.note_dur_embed = AdamWLinear(1, hidden_size) # ornament inputs self.use_glide_embed = hparams['use_glide_embed'] diff --git a/modules/optimizer/muon.py b/modules/optimizer/muon.py index caf3a45e4..39c233b1c 100644 --- a/modules/optimizer/muon.py +++ b/modules/optimizer/muon.py @@ -1,3 +1,4 @@ +import collections import torch import torch.nn as nn import torch.nn.functional as F @@ -6,6 +7,8 @@ from typing import List from .chained_optimizer import ChainedOptimizer, OptimizerSpec +from modules.commons.common_layers import AdamWLinear, AdamWCov1d + def get_bf16_support_map(): bf16_support_map = {} @@ -129,13 +132,20 @@ def get_params_for_muon(model) -> List[Parameter]: Returns: A list of parameters that should be optimized with muon. """ + excluded_module_classes = (AdamWLinear, AdamWCov1d) muon_params = [] - for module in model.modules(): + # BFS through all submodules and exclude parameters from certain module types + queue = collections.deque([model]) + while queue: + module = queue.popleft() + if isinstance(module, excluded_module_classes): + continue for param in module.parameters(recurse=False): if not param.requires_grad: continue if not isinstance(module, nn.Embedding) and param.ndim >= 2: muon_params.append(param) + queue.extend(list(module.children())) return muon_params diff --git a/modules/toplevel.py b/modules/toplevel.py index 3c3129665..bc8029af3 100644 --- a/modules/toplevel.py +++ b/modules/toplevel.py @@ -11,7 +11,7 @@ from modules.commons.common_layers import ( XavierUniformInitLinear as Linear, NormalInitEmbedding as Embedding, - SinusoidalPosEmb + SinusoidalPosEmb, AdamWLinear, ) from modules.core import ( GaussianDiffusion, PitchDiffusion, MultiVarianceDiffusion, @@ -160,9 +160,9 @@ def __init__(self, vocab_size): self.use_melody_encoder = hparams.get('use_melody_encoder', False) if self.use_melody_encoder: self.melody_encoder = MelodyEncoder(enc_hparams=hparams['melody_encoder_args']) - self.delta_pitch_embed = Linear(1, hparams['hidden_size']) + self.delta_pitch_embed = AdamWLinear(1, hparams['hidden_size']) else: - self.base_pitch_embed = Linear(1, hparams['hidden_size']) + self.base_pitch_embed = AdamWLinear(1, hparams['hidden_size']) self.pitch_retake_embed = Embedding(2, hparams['hidden_size']) pitch_hparams = hparams['pitch_prediction_args'] @@ -195,9 +195,9 @@ def __init__(self, vocab_size): raise ValueError(f"Invalid diffusion type: {self.diffusion_type}") if self.predict_variances: - self.pitch_embed = Linear(1, hparams['hidden_size']) + self.pitch_embed = AdamWLinear(1, hparams['hidden_size']) self.variance_embeds = nn.ModuleDict({ - v_name: Linear(1, hparams['hidden_size']) + v_name: AdamWLinear(1, hparams['hidden_size']) for v_name in self.variance_prediction_list }) From a6f0a3efa55e24456b751340e0e1af0b5a4fb8ea Mon Sep 17 00:00:00 2001 From: Kakaru <97896816+KakaruHayate@users.noreply.github.com> Date: Thu, 15 Jan 2026 13:58:03 +0800 Subject: [PATCH 2/2] Update excluded module classes in muon.py --- modules/optimizer/muon.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/optimizer/muon.py b/modules/optimizer/muon.py index 39c233b1c..678f22e65 100644 --- a/modules/optimizer/muon.py +++ b/modules/optimizer/muon.py @@ -132,7 +132,7 @@ def get_params_for_muon(model) -> List[Parameter]: Returns: A list of parameters that should be optimized with muon. """ - excluded_module_classes = (AdamWLinear, AdamWCov1d) + excluded_module_classes = (nn.Embedding, AdamWLinear, AdamWCov1d) muon_params = [] # BFS through all submodules and exclude parameters from certain module types queue = collections.deque([model]) @@ -143,7 +143,7 @@ def get_params_for_muon(model) -> List[Parameter]: for param in module.parameters(recurse=False): if not param.requires_grad: continue - if not isinstance(module, nn.Embedding) and param.ndim >= 2: + if param.ndim >= 2: muon_params.append(param) queue.extend(list(module.children())) return muon_params @@ -160,4 +160,4 @@ def __init__(self, model, lr=0.0005, weight_decay=0.0, muon_args={}, adamw_args= callback = lambda p, spec_idx: print( f"Adding param {p.shape} to optimizer{spec_idx} {str(specs[spec_idx].class_type)}" ) - super().__init__(model.parameters(), specs, lr=lr, weight_decay=weight_decay, optimizer_selection_callback=callback) \ No newline at end of file + super().__init__(model.parameters(), specs, lr=lr, weight_decay=weight_decay, optimizer_selection_callback=callback)