From 7cadcbc2932097de73e753ae5d6444b476bbb596 Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Wed, 25 Mar 2026 22:35:47 +0800 Subject: [PATCH 1/2] update SDP and Partial RoPE --- configs/acoustic.yaml | 2 + configs/templates/config_acoustic.yaml | 2 + configs/templates/config_variance.yaml | 13 +- configs/variance.yaml | 9 + deployment/modules/fastspeech2.py | 5 +- inference/ds_variance.py | 2 +- modules/commons/rotary_embedding_torch.py | 8 +- modules/fastspeech/acoustic_encoder.py | 3 +- modules/fastspeech/tts_modules.py | 126 ++++++++-- modules/fastspeech/variance_encoder.py | 26 +- modules/optimizer/muon.py | 3 +- modules/sdp/sdp.py | 283 ++++++++++++++++++++++ modules/sdp/transforms.py | 214 ++++++++++++++++ modules/toplevel.py | 8 +- training/variance_task.py | 23 +- 15 files changed, 687 insertions(+), 40 deletions(-) create mode 100644 modules/sdp/sdp.py create mode 100644 modules/sdp/transforms.py diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml index 935d6e16..0ae23adc 100644 --- a/configs/acoustic.yaml +++ b/configs/acoustic.yaml @@ -65,6 +65,8 @@ max_beta: 0.02 enc_ffn_kernel_size: 3 use_rope: true rope_interleaved: false +rope_theta: 8000 +rotary_dim: 192 use_stretch_embed: true use_variance_scaling: true rel_pos: true diff --git a/configs/templates/config_acoustic.yaml b/configs/templates/config_acoustic.yaml index 9d63028f..9af2582e 100644 --- a/configs/templates/config_acoustic.yaml +++ b/configs/templates/config_acoustic.yaml @@ -72,6 +72,8 @@ diffusion_type: reflow enc_ffn_kernel_size: 3 use_rope: true rope_interleaved: false +rope_theta: 8000 +rotary_dim: 192 use_stretch_embed: true use_variance_scaling: true use_shallow_diffusion: true diff --git a/configs/templates/config_variance.yaml b/configs/templates/config_variance.yaml index 40f4c532..990591e7 100644 --- a/configs/templates/config_variance.yaml +++ b/configs/templates/config_variance.yaml @@ -66,12 +66,14 @@ tension_logit_max: 10.0 enc_ffn_kernel_size: 3 use_rope: true 
rope_interleaved: false +rope_theta: 8000 +rotary_dim: 192 use_stretch_embed: false use_variance_scaling: true hidden_size: 384 dur_prediction_args: - arch: resnet - hidden_size: 256 + arch: fs2 + hidden_size: 512 dropout: 0.1 num_layers: 5 kernel_size: 3 @@ -80,11 +82,18 @@ dur_prediction_args: lambda_pdur_loss: 0.3 lambda_wdur_loss: 1.0 lambda_sdur_loss: 3.0 +use_sdp: true +sdp_ratio: 0.2 +sdp_n_chans: 192 +lambda_sdp_loss: 0.005 +lambda_sdp_reg_loss: 0.1 +sdp_reg_warmup_steps: 16000 use_melody_encoder: true melody_encoder_args: hidden_size: 128 enc_layers: 4 + rotary_dim: 64 use_glide_embed: false glide_types: [up, down] glide_embed_scale: 11.313708498984760 # sqrt(128) diff --git a/configs/variance.yaml b/configs/variance.yaml index a819c1c4..e4a2dcc2 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -37,6 +37,8 @@ predict_tension: false enc_ffn_kernel_size: 3 use_rope: true rope_interleaved: false +rope_theta: 8000 +rotary_dim: 192 use_stretch_embed: false use_variance_scaling: true rel_pos: true @@ -53,11 +55,18 @@ dur_prediction_args: lambda_pdur_loss: 0.3 lambda_wdur_loss: 1.0 lambda_sdur_loss: 3.0 +use_sdp: true +sdp_ratio: 0.2 +sdp_n_chans: 192 +lambda_sdp_loss: 0.005 +lambda_sdp_reg_loss: 0.1 +sdp_reg_warmup_steps: 16000 use_melody_encoder: true melody_encoder_args: hidden_size: 128 enc_layers: 4 + rotary_dim: 64 use_glide_embed: false glide_types: [up, down] glide_embed_scale: 11.313708498984760 # sqrt(128) diff --git a/deployment/modules/fastspeech2.py b/deployment/modules/fastspeech2.py index b22590c2..58cb307e 100644 --- a/deployment/modules/fastspeech2.py +++ b/deployment/modules/fastspeech2.py @@ -193,9 +193,10 @@ def forward_encoder_phoneme(self, tokens, ph_dur, languages=None): def forward_dur_predictor(self, encoder_out, x_masks, ph_midi, spk_embed=None): midi_embed = self.midi_embed(ph_midi) dur_cond = encoder_out + midi_embed + sdp_cond = dur_cond if hparams['use_spk_id'] and spk_embed is not None: - dur_cond += spk_embed - 
ph_dur = self.dur_predictor(dur_cond, x_masks=x_masks) + dur_cond = dur_cond + spk_embed + ph_dur, _, _ = self.dur_predictor(dur_cond, x_masks=x_masks, sdp_cond=sdp_cond, spk_embed=spk_embed) return ph_dur def view_as_encoder(self): diff --git a/inference/ds_variance.py b/inference/ds_variance.py index da3d6e94..f3a977b3 100644 --- a/inference/ds_variance.py +++ b/inference/ds_variance.py @@ -327,7 +327,7 @@ def forward_model(self, sample): else: ph_spk_mix_embed = spk_mix_embed = None - dur_pred, pitch_pred, variance_pred = self.model( + dur_pred, pitch_pred, variance_pred, _, _ = self.model( txt_tokens, languages=sample.get('languages'), midi=midi, ph2word=ph2word, word_dur=word_dur, ph_dur=ph_dur, mel2ph=mel2ph, note_midi=note_midi, note_rest=note_rest, note_dur=note_dur, note_glide=note_glide, mel2note=mel2note, diff --git a/modules/commons/rotary_embedding_torch.py b/modules/commons/rotary_embedding_torch.py index 1a1fa193..bf039e4e 100644 --- a/modules/commons/rotary_embedding_torch.py +++ b/modules/commons/rotary_embedding_torch.py @@ -33,12 +33,16 @@ def __init__( dim, theta=10000, max_seq_len=8192, - interleaved: bool = True + interleaved: bool = True, + rotary_dim: int = None ): super().__init__() self.interleaved = interleaved self.cached_freqs_seq_len = max_seq_len - inv_freq = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim)) + self.rotary_dim = rotary_dim if rotary_dim is not None else dim + if self.rotary_dim > dim: + raise ValueError(f"rotary_dim ({self.rotary_dim}) cannot be larger than dim ({dim})") + inv_freq = 1. 
/ (theta ** (torch.arange(0, self.rotary_dim, 2).float() / dim)) self.register_buffer('inv_freq', inv_freq, persistent=False) self.register_buffer('cached_freqs', self._precompute_cache(max_seq_len), persistent=False) diff --git a/modules/fastspeech/acoustic_encoder.py b/modules/fastspeech/acoustic_encoder.py index f75ab2d5..13232bd7 100644 --- a/modules/fastspeech/acoustic_encoder.py +++ b/modules/fastspeech/acoustic_encoder.py @@ -39,7 +39,8 @@ def __init__(self, vocab_size): ffn_kernel_size=hparams['enc_ffn_kernel_size'], ffn_act=hparams['ffn_act'], dropout=hparams['dropout'], num_heads=hparams['num_heads'], use_pos_embed=hparams['use_pos_embed'], rel_pos=hparams.get('rel_pos', False), - use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True) + use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True), + rope_theta=hparams.get('rope_theta', 10000), rotary_dim=hparams.get('rotary_dim', None), ) self.pitch_embed = AdamWLinear(1, hparams['hidden_size']) diff --git a/modules/fastspeech/tts_modules.py b/modules/fastspeech/tts_modules.py index daf85127..fa1ee3ce 100644 --- a/modules/fastspeech/tts_modules.py +++ b/modules/fastspeech/tts_modules.py @@ -6,6 +6,7 @@ from modules.commons.rotary_embedding_torch import RotaryEmbedding from modules.commons.common_layers import SinusoidalPositionalEmbedding, EncSALayer, AdamWLinear from modules.commons.espnet_positional_embedding import RelPositionalEncoding +from modules.sdp.sdp import StochasticDurationPredictor DEFAULT_MAX_SOURCE_POSITIONS = 2000 DEFAULT_MAX_TARGET_POSITIONS = 2000 @@ -54,6 +55,9 @@ class DurationPredictor(torch.nn.Module): """Duration predictor module. This is a module of duration predictor described in `FastSpeech: Fast, Robust and Controllable Text to Speech`_. The duration predictor predicts a duration of each frame in log domain from the hidden embeddings of encoder. 
+ It combines a deterministic Duration Predictor (based on FastSpeech/ResNet) and + an optional Stochastic Duration Predictor (SDP, based on Normalizing Flows like VITS/Bert-VITS2) + to improve generation variance. .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`: https://arxiv.org/pdf/1905.09263.pdf Note: @@ -62,7 +66,9 @@ class DurationPredictor(torch.nn.Module): """ def __init__(self, in_dims, n_layers=2, n_chans=384, kernel_size=3, - dropout_rate=0.1, offset=1.0, dur_loss_type='mse', arch='resnet'): + dropout_rate=0.1, offset=1.0, dur_loss_type='mse', arch='fs2', + use_sdp=False, sdp_ratio=0.2, sdp_n_chans=192, gin_channels=0 + ): """Initialize duration predictor module. Args: in_dims (int): Input dimension. @@ -71,6 +77,12 @@ def __init__(self, in_dims, n_layers=2, n_chans=384, kernel_size=3, kernel_size (int, optional): Kernel size of convolutional layers. dropout_rate (float, optional): Dropout rate. offset (float, optional): Offset value to avoid nan in log domain. + dur_loss_type (str, optional): Loss type ('mse', 'huber', etc.). + arch (str, optional): Architecture type ('resnet' or standard CNN). + use_sdp (bool, optional): Whether to use Stochastic Duration Predictor. + sdp_ratio (float, optional): Interpolation ratio between SDP and DP outputs during inference. + sdp_n_chans (int, optional): Hidden channels for SDP. + gin_channels (int, optional): Speaker embedding channels for SDP conditioning. 
""" super(DurationPredictor, self).__init__() self.offset = offset @@ -99,6 +111,7 @@ def __init__(self, in_dims, n_layers=2, n_chans=384, kernel_size=3, self.res_conv = nn.Conv1d(in_dims, n_chans, 1) else: self.res_conv = None + self.loss_type = dur_loss_type if self.loss_type in ['mse', 'huber']: self.out_dims = 1 @@ -112,6 +125,21 @@ def __init__(self, in_dims, n_layers=2, n_chans=384, kernel_size=3, raise NotImplementedError() self.linear = AdamWLinear(n_chans, self.out_dims) + self.use_sdp = use_sdp + self.sdp_ratio = sdp_ratio + + if self.use_sdp: + self.sdp = StochasticDurationPredictor( + in_channels=in_dims, + filter_channels=sdp_n_chans, + kernel_size=3, + p_dropout=0.5, + n_flows=4, + gin_channels=gin_channels + ) + else: + self.sdp = None + def out2dur(self, xs): if self.loss_type in ['mse', 'huber']: # NOTE: calculate loss in log domain @@ -122,33 +150,86 @@ def out2dur(self, xs): raise NotImplementedError() return dur - def forward(self, xs, x_masks=None, infer=True): + def forward(self, xs, x_masks=None, infer=True, ph_dur=None, sdp_cond=None, spk_embed=None): """Calculate forward propagation. Args: xs (Tensor): Batch of input sequences (B, Tmax, idim). x_masks (BoolTensor, optional): Batch of masks indicating padded part (B, Tmax). infer (bool): Whether inference + ph_dur (Tensor, optional): Ground truth phoneme duration [B, Tmax]. Needed for SDP training. + sdp_cond (Tensor, optional): Conditioning sequence for SDP [B, Tmax, idim]. + spk_embed (Tensor, optional): Speaker embedding [B, gin_channels]. + Returns: - (train) FloatTensor, (infer) LongTensor: Batch of predicted durations in linear domain (B, Tmax). + tuple: + - dur_pred (Tensor): Final predicted linear durations [B, Tmax]. + - loss_sdp (Tensor or int): SDP negative log-likelihood flow loss (0 if not using SDP or inferring). + - sdp_pred (Tensor or None): SDP reverse prediction for MSE regression computing. 
""" - xs = xs.transpose(1, -1) # (B, idim, Tmax) - masks = 1 - x_masks.float() - masks_ = masks[:, None, :] + non_pad_mask = 1.0 - x_masks.float() # [B, Tmax] + non_pad_mask_1d = non_pad_mask.unsqueeze(1) # [B, 1, Tmax] + non_pad_mask_2d = non_pad_mask.unsqueeze(2) # [B, Tmax, 1] + g = spk_embed.transpose(1, -1) if spk_embed is not None else None # [B, gin_chans, 1] + dp_xs = xs.transpose(1, -1) # [B, Tmax, idim] -> [B, idim, Tmax] for idx, f in enumerate(self.conv): if self.use_resnet: - residual = self.res_conv(xs) if idx == 0 and self.res_conv is not None else xs - xs = residual + f(xs) + residual = self.res_conv(dp_xs) if (idx == 0 and self.res_conv is not None) else dp_xs + dp_xs = residual + f(dp_xs) else: - xs = f(xs) + dp_xs = f(dp_xs) + if x_masks is not None: - xs = xs * masks_ - xs = self.linear(xs.transpose(1, -1)) # [B, T, C] - xs = xs * masks[:, :, None] # (B, T, C) + dp_xs = dp_xs * non_pad_mask_1d + + dp_xs = self.linear(dp_xs.transpose(1, -1)) # [B, idim, Tmax] -> [B, Tmax, C] + dp_xs = dp_xs * non_pad_mask_2d # Mask padded areas [B, Tmax, C] + + dur_pred = self.out2dur(dp_xs) # Convert to linear domain [B, Tmax] + + loss_sdp = 0.0 + sdp_pred = None + + if self.use_sdp and sdp_cond is not None: + sdp_cond_t = sdp_cond.transpose(1, -1) # [B, idim, Tmax] + + if not infer: + nll_loss = self.sdp( + x=sdp_cond_t, + x_mask=non_pad_mask_1d, + w=ph_dur, + g=g, + reverse=False, + noise_scale=1.0 + ) + + l_length_sdp = nll_loss / torch.sum(non_pad_mask_1d) + loss_sdp = torch.sum(l_length_sdp.float()) + + logw_sdp = self.sdp( + x=sdp_cond_t, + x_mask=non_pad_mask_1d, + g=g, + reverse=True, + noise_scale=1.0 + ) + sdp_pred = torch.ceil(self.out2dur(logw_sdp.transpose(1, -1) * non_pad_mask_2d)) + + else: + logw_sdp = self.sdp( + x=sdp_cond_t, + x_mask=non_pad_mask_1d, + g=g, + reverse=True, + noise_scale=0.8 + ) + sdp_pred = torch.ceil(self.out2dur(logw_sdp.transpose(1, -1) * non_pad_mask_2d)) + + dur_pred = (sdp_pred * self.sdp_ratio) + (dur_pred * (1.0 - 
self.sdp_ratio)) - dur_pred = self.out2dur(xs) if infer: - dur_pred = dur_pred.clamp(min=0.) # avoid negative value - return dur_pred + return dur_pred.clamp(min=0.0), None, None + else: + return dur_pred, loss_sdp, sdp_pred class VariancePredictor(torch.nn.Module): @@ -369,19 +450,26 @@ def mel2ph_to_dur(mel2ph, T_txt, max_dur=None): class FastSpeech2Encoder(nn.Module): def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, ffn_act='gelu', - dropout=None, num_heads=2, use_pos_embed=True, rel_pos=True, use_rope=False, rope_interleaved=True): + dropout=None, num_heads=2, use_pos_embed=True, rel_pos=True, + use_rope=False, rope_interleaved=True, rope_theta=10000, rotary_dim=None): super().__init__() self.num_layers = num_layers embed_dim = self.hidden_size = hidden_size self.dropout = dropout self.use_pos_embed = use_pos_embed if use_pos_embed and use_rope: - if embed_dim % (num_heads * 2) != 0: + rotary_dim = rotary_dim if rotary_dim is not None else embed_dim + if rotary_dim % (num_heads * 2) != 0: raise ValueError( "RoPE requires the hidden size to be multiple of " - f"num_heads * 2 = {num_heads * 2}, but got {embed_dim}." + f"num_heads * 2 = {num_heads * 2}, but got {rotary_dim}." 
) - rotary_embed = RotaryEmbedding(dim=embed_dim // num_heads, interleaved=rope_interleaved) + rotary_embed = RotaryEmbedding( + dim=embed_dim // num_heads, + theta=rope_theta, + interleaved=rope_interleaved, + rotary_dim=rotary_dim + ) else: rotary_embed = None self.layers = nn.ModuleList([ diff --git a/modules/fastspeech/variance_encoder.py b/modules/fastspeech/variance_encoder.py index 71296484..c60cac1c 100644 --- a/modules/fastspeech/variance_encoder.py +++ b/modules/fastspeech/variance_encoder.py @@ -34,7 +34,8 @@ def __init__(self, vocab_size): ffn_kernel_size=hparams['enc_ffn_kernel_size'], ffn_act=hparams['ffn_act'], dropout=hparams['dropout'], num_heads=hparams['num_heads'], use_pos_embed=hparams['use_pos_embed'], rel_pos=hparams.get('rel_pos', False), - use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True) + use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True), + rope_theta=hparams.get('rope_theta', 10000), rotary_dim=hparams.get('rotary_dim', None), ) dur_hparams = hparams['dur_prediction_args'] @@ -48,7 +49,11 @@ def __init__(self, vocab_size): kernel_size=dur_hparams['kernel_size'], offset=dur_hparams['log_offset'], dur_loss_type=dur_hparams['loss_type'], - arch=dur_hparams['arch'] + arch=dur_hparams['arch'], + use_sdp = hparams.get('use_sdp', False), + sdp_ratio = hparams.get('sdp_ratio', 0.2), + sdp_n_chans = hparams.get('sdp_n_chans', 192), + gin_channels = hparams['hidden_size'] if hparams['use_spk_id'] else 0, ) def forward( @@ -91,16 +96,22 @@ def forward( extra_embed += lang_embed encoder_out = self.encoder(txt_embed, extra_embed, txt_tokens == 0) + sdp_loss = None + sdp_pred = None if self.predict_dur: midi_embed = self.midi_embed(midi) # => [B, T_ph, H] dur_cond = encoder_out + midi_embed + sdp_cond = dur_cond if spk_embed is not None: - dur_cond += spk_embed - ph_dur_pred = self.dur_predictor(dur_cond, x_masks=txt_tokens == PAD_INDEX, infer=infer) + dur_cond 
= dur_cond + spk_embed + g = spk_embed + else: + g = None + ph_dur_pred, sdp_loss, sdp_pred = self.dur_predictor(dur_cond, x_masks=txt_tokens == PAD_INDEX, infer=infer, ph_dur=ph_dur, sdp_cond=sdp_cond, spk_embed=g) - return encoder_out, ph_dur_pred + return encoder_out, ph_dur_pred, sdp_loss, sdp_pred else: - return encoder_out, None + return encoder_out, None, sdp_loss, sdp_pred class MelodyEncoder(nn.Module): @@ -128,7 +139,8 @@ def get_hparam(key): ffn_kernel_size=get_hparam('enc_ffn_kernel_size'), ffn_act=get_hparam('ffn_act'), dropout=get_hparam('dropout'), num_heads=get_hparam('num_heads'), use_pos_embed=get_hparam('use_pos_embed'), rel_pos=get_hparam('rel_pos'), - use_rope=get_hparam('use_rope'), rope_interleaved=hparams.get('rope_interleaved', True) + use_rope=get_hparam('use_rope'), rope_interleaved=hparams.get('rope_interleaved', True), + rope_theta=hparams.get('rope_theta', 10000), rotary_dim=get_hparam('rotary_dim'), ) self.out_proj = Linear(hidden_size, hparams['hidden_size']) diff --git a/modules/optimizer/muon.py b/modules/optimizer/muon.py index 678f22e6..ea4c450b 100644 --- a/modules/optimizer/muon.py +++ b/modules/optimizer/muon.py @@ -8,6 +8,7 @@ from .chained_optimizer import ChainedOptimizer, OptimizerSpec from modules.commons.common_layers import AdamWLinear, AdamWCov1d +from modules.fastspeech.tts_modules import DurationPredictor def get_bf16_support_map(): @@ -132,7 +133,7 @@ def get_params_for_muon(model) -> List[Parameter]: Returns: A list of parameters that should be optimized with muon. 
""" - excluded_module_classes = (nn.Embedding, AdamWLinear, AdamWCov1d) + excluded_module_classes = (nn.Embedding, AdamWLinear, AdamWCov1d, DurationPredictor) muon_params = [] # BFS through all submodules and exclude parameters from certain module types queue = collections.deque([model]) diff --git a/modules/sdp/sdp.py b/modules/sdp/sdp.py new file mode 100644 index 00000000..81376055 --- /dev/null +++ b/modules/sdp/sdp.py @@ -0,0 +1,283 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F + +from modules.sdp.transforms import piecewise_rational_quadratic_transform + + +class Log(nn.Module): + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = torch.log(torch.clamp_min(x, 0) + 1.0) * x_mask + logdet = torch.sum(-y, [1, 2]) + return y, logdet + else: + x = (torch.exp(x) - 1.0) * x_mask + return x + + +class ElementwiseAffine(nn.Module): + def __init__(self, channels): + super().__init__() + self.channels = channels + self.m = nn.Parameter(torch.zeros(channels, 1)) + self.logs = nn.Parameter(torch.zeros(channels, 1)) + + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = self.m + torch.exp(self.logs) * x + y = y * x_mask + logdet = torch.sum(self.logs * x_mask, [1, 2]) + return y, logdet + else: + x = (x - self.m) * torch.exp(-self.logs) * x_mask + return x + + +class ConvFlow(nn.Module): + def __init__( + self, + in_channels, + filter_channels, + kernel_size, + n_layers, + num_bins=16, + tail_bound=5.0, + ): + super().__init__() + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.num_bins = num_bins + self.tail_bound = tail_bound + self.half_channels = in_channels // 2 + + self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) + self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0) + self.proj = nn.Conv1d( + filter_channels, self.half_channels * (num_bins * 3 
- 1), 1 + ) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) + h = self.convs(h, x_mask, g=g) + h = self.proj(h) * x_mask + + b, c, t = x0.shape + h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] + + unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels) + unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt( + self.filter_channels + ) + unnormalized_derivatives = h[..., 2 * self.num_bins :] + + x1, logabsdet = piecewise_rational_quadratic_transform( + x1, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=reverse, + tails="linear", + tail_bound=self.tail_bound, + ) + + x = torch.cat([x0, x1], 1) * x_mask + logdet = torch.sum(logabsdet * x_mask, [1, 2]) + if not reverse: + return x, logdet + else: + return x + + +class Flip(nn.Module): + def forward(self, x, *args, reverse=False, **kwargs): + x = torch.flip(x, [1]) + if not reverse: + logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) + return x, logdet + else: + return x + + +class DDSConv(nn.Module): + """ + Dilated and Depth-Separable Convolution + """ + + def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + + self.drop = nn.Dropout(p_dropout) + self.convs_sep = nn.ModuleList() + self.convs_1x1 = nn.ModuleList() + self.norms_1 = nn.ModuleList() + self.norms_2 = nn.ModuleList() + for i in range(n_layers): + dilation = kernel_size**i + padding = (kernel_size * dilation - dilation) // 2 + self.convs_sep.append( + nn.Conv1d( + channels, + channels, + kernel_size, + groups=channels, + dilation=dilation, + padding=padding, + ) + ) + self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) + 
self.norms_1.append(LayerNorm(channels)) + self.norms_2.append(LayerNorm(channels)) + + def forward(self, x, x_mask, g=None): + if g is not None: + x = x + g + for i in range(self.n_layers): + y = self.convs_sep[i](x * x_mask) + y = self.norms_1[i](y) + y = F.gelu(y) + y = self.convs_1x1[i](y) + y = self.norms_2[i](y) + y = F.gelu(y) + y = self.drop(y) + x = x + y + return x * x_mask + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +class StochasticDurationPredictor(nn.Module): + def __init__( + self, + in_channels, + filter_channels=192, + kernel_size=3, + p_dropout=0.5, + n_flows=4, + gin_channels=0, + ): + super().__init__() + filter_channels = in_channels # it needs to be removed from future version. 
+ self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.log_flow = Log() + self.flows = nn.ModuleList() + self.flows.append(ElementwiseAffine(2)) + for i in range(n_flows): + self.flows.append( + ConvFlow(2, filter_channels, kernel_size, n_layers=3) + ) + self.flows.append(Flip()) + + self.post_pre = nn.Conv1d(1, filter_channels, 1) + self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.post_convs = DDSConv( + filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout + ) + self.post_flows = nn.ModuleList() + self.post_flows.append(ElementwiseAffine(2)) + for i in range(4): + self.post_flows.append( + ConvFlow(2, filter_channels, kernel_size, n_layers=3) + ) + self.post_flows.append(Flip()) + + self.pre = nn.Conv1d(in_channels, filter_channels, 1) + self.proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.convs = DDSConv( + filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout + ) + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, filter_channels, 1) + + def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): + if w is not None: + w = w.float()[:, None, :] + x = torch.detach(x) + x = self.pre(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.convs(x, x_mask) + x = self.proj(x) * x_mask + + if not reverse: + flows = self.flows + assert w is not None + + logdet_tot_q = 0 + # h_w = self.post_pre(w) + h_w = self.post_pre(torch.log(w + 1.0)) + h_w = self.post_convs(h_w, x_mask) + h_w = self.post_proj(h_w) * x_mask + e_q = ( + torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) + * x_mask + ) + z_q = e_q + for flow in self.post_flows: + z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) + logdet_tot_q += logdet_q + z_u, z1 = torch.split(z_q, [1, 1], 1) + u = torch.sigmoid(z_u) * x_mask + z0 = (w - u) * x_mask + 
logdet_tot_q += torch.sum( + (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2] + ) + logq = ( + torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) + - logdet_tot_q + ) + + logdet_tot = 0 + z0, logdet = self.log_flow(z0, x_mask) + logdet_tot += logdet + z = torch.cat([z0, z1], 1) + for flow in flows: + z, logdet = flow(z, x_mask, g=x, reverse=reverse) + logdet_tot = logdet_tot + logdet + nll = ( + torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) + - logdet_tot + ) + return nll + logq # [b] + else: + flows = list(reversed(self.flows)) + flows = flows[:-2] + [flows[-1]] # remove a useless vflow + z = ( + torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) + * noise_scale + ) + for flow in flows: + z = flow(z, x_mask, g=x, reverse=reverse) + z0, z1 = torch.split(z, [1, 1], 1) + logw = z0 + return logw diff --git a/modules/sdp/transforms.py b/modules/sdp/transforms.py new file mode 100644 index 00000000..d842485f --- /dev/null +++ b/modules/sdp/transforms.py @@ -0,0 +1,214 @@ +import torch +from torch.nn import functional as F + +import numpy as np + + +DEFAULT_MIN_BIN_WIDTH = 1e-3 +DEFAULT_MIN_BIN_HEIGHT = 1e-3 +DEFAULT_MIN_DERIVATIVE = 1e-3 + + +def piecewise_rational_quadratic_transform( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails=None, + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if tails is None: + spline_fn = rational_quadratic_spline + spline_kwargs = {} + else: + spline_fn = unconstrained_rational_quadratic_spline + spline_kwargs = {"tails": tails, "tail_bound": tail_bound} + + outputs, logabsdet = spline_fn( + inputs=inputs, + unnormalized_widths=unnormalized_widths, + unnormalized_heights=unnormalized_heights, + unnormalized_derivatives=unnormalized_derivatives, + inverse=inverse, + min_bin_width=min_bin_width, + 
min_bin_height=min_bin_height, + min_derivative=min_derivative, + **spline_kwargs + ) + return outputs, logabsdet + + +def searchsorted(bin_locations, inputs, eps=1e-6): + bin_locations[..., -1] += eps + return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 + + +def unconstrained_rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails="linear", + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) + outside_interval_mask = ~inside_interval_mask + + outputs = torch.zeros_like(inputs) + logabsdet = torch.zeros_like(inputs) + + if tails == "linear": + unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) + constant = np.log(np.exp(1 - min_derivative) - 1) + unnormalized_derivatives[..., 0] = constant + unnormalized_derivatives[..., -1] = constant + + outputs[outside_interval_mask] = inputs[outside_interval_mask] + logabsdet[outside_interval_mask] = 0 + else: + raise RuntimeError("{} tails are not implemented.".format(tails)) + + ( + outputs[inside_interval_mask], + logabsdet[inside_interval_mask], + ) = rational_quadratic_spline( + inputs=inputs[inside_interval_mask], + unnormalized_widths=unnormalized_widths[inside_interval_mask, :], + unnormalized_heights=unnormalized_heights[inside_interval_mask, :], + unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], + inverse=inverse, + left=-tail_bound, + right=tail_bound, + bottom=-tail_bound, + top=tail_bound, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + ) + + return outputs, logabsdet + + +def rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + left=0.0, + right=1.0, + bottom=0.0, + top=1.0, + 
min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if torch.min(inputs) < left or torch.max(inputs) > right: + raise ValueError("Input to a transform is not within its domain") + + num_bins = unnormalized_widths.shape[-1] + + if min_bin_width * num_bins > 1.0: + raise ValueError("Minimal bin width too large for the number of bins") + if min_bin_height * num_bins > 1.0: + raise ValueError("Minimal bin height too large for the number of bins") + + widths = F.softmax(unnormalized_widths, dim=-1) + widths = min_bin_width + (1 - min_bin_width * num_bins) * widths + cumwidths = torch.cumsum(widths, dim=-1) + cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) + cumwidths = (right - left) * cumwidths + left + cumwidths[..., 0] = left + cumwidths[..., -1] = right + widths = cumwidths[..., 1:] - cumwidths[..., :-1] + + derivatives = min_derivative + F.softplus(unnormalized_derivatives) + + heights = F.softmax(unnormalized_heights, dim=-1) + heights = min_bin_height + (1 - min_bin_height * num_bins) * heights + cumheights = torch.cumsum(heights, dim=-1) + cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) + cumheights = (top - bottom) * cumheights + bottom + cumheights[..., 0] = bottom + cumheights[..., -1] = top + heights = cumheights[..., 1:] - cumheights[..., :-1] + + if inverse: + bin_idx = searchsorted(cumheights, inputs)[..., None] + else: + bin_idx = searchsorted(cumwidths, inputs)[..., None] + + input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] + input_bin_widths = widths.gather(-1, bin_idx)[..., 0] + + input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] + delta = heights / widths + input_delta = delta.gather(-1, bin_idx)[..., 0] + + input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] + input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] + + input_heights = heights.gather(-1, bin_idx)[..., 0] + + if inverse: + a 
= (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + input_heights * (input_delta - input_derivatives) + b = input_heights * input_derivatives - (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + c = -input_delta * (inputs - input_cumheights) + + discriminant = b.pow(2) - 4 * a * c + if not torch.jit.is_tracing() and not torch.onnx.is_in_onnx_export(): + if (discriminant < 0).any(): + min_val = torch.min(discriminant).item() + raise RuntimeError(f"Flow crush: The discriminant yields a negative number (minimum value: {min_val}).") + + discriminant = torch.clamp_min(discriminant, 0.0) + + root = (2 * c) / (-b - torch.sqrt(discriminant)) + outputs = root * input_bin_widths + input_cumwidths + + theta_one_minus_theta = root * (1 - root) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta + ) + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * root.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - root).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, -logabsdet + else: + theta = (inputs - input_cumwidths) / input_bin_widths + theta_one_minus_theta = theta * (1 - theta) + + numerator = input_heights * ( + input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta + ) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta + ) + outputs = input_cumheights + numerator / denominator + + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * theta.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - theta).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, logabsdet diff --git a/modules/toplevel.py 
b/modules/toplevel.py index bc8029af..c749608f 100644 --- a/modules/toplevel.py +++ b/modules/toplevel.py @@ -249,7 +249,7 @@ def forward( else: ph_spk_embed = spk_embed = None - encoder_out, dur_pred_out = self.fs2( + encoder_out, dur_pred_out, sdp_loss, sdp_pred = self.fs2( txt_tokens, midi=midi, ph2word=ph2word, ph_dur=ph_dur, word_dur=word_dur, spk_embed=ph_spk_embed, languages=languages, @@ -257,7 +257,7 @@ def forward( ) if not self.predict_pitch and not self.predict_variances: - return dur_pred_out, None, ({} if infer else None) + return dur_pred_out, None, ({} if infer else None), sdp_loss, sdp_pred if mel2ph is None and word_dur is not None: # inference from file dur_pred_align = self.rr(dur_pred_out, ph2word, word_dur) @@ -339,7 +339,7 @@ def forward( pitch_pred_out = None if not self.predict_variances: - return dur_pred_out, pitch_pred_out, ({} if infer else None) + return dur_pred_out, pitch_pred_out, ({} if infer else None), sdp_loss, sdp_pred if pitch is None: pitch = base_pitch + pitch_pred_out @@ -364,4 +364,4 @@ def forward( else: variances_pred_out = variance_outputs - return dur_pred_out, pitch_pred_out, variances_pred_out + return dur_pred_out, pitch_pred_out, variances_pred_out, sdp_loss, sdp_pred diff --git a/training/variance_task.py b/training/variance_task.py index 646d9540..bd159b9e 100644 --- a/training/variance_task.py +++ b/training/variance_task.py @@ -93,6 +93,7 @@ def __init__(self): self.predict_dur = hparams['predict_dur'] if self.predict_dur: self.lambda_dur_loss = hparams['lambda_dur_loss'] + self.use_sdp = hparams.get('use_sdp', False) self.predict_pitch = hparams['predict_pitch'] if self.predict_pitch: @@ -134,6 +135,18 @@ def build_losses_and_metrics(self): self.register_validation_loss('dur_loss') self.register_validation_metric('rhythm_corr', RhythmCorrectness(tolerance=0.05)) self.register_validation_metric('ph_dur_acc', PhonemeDurationAccuracy(tolerance=0.2)) + if self.use_sdp: + self.dur_sdp_loss = DurationLoss( + 
offset=dur_hparams['log_offset'], + loss_type=dur_hparams['loss_type'], + lambda_pdur=dur_hparams['lambda_pdur_loss'], + lambda_wdur=dur_hparams['lambda_wdur_loss'], + lambda_sdur=dur_hparams['lambda_sdur_loss'] + ) + self.register_validation_loss('dur_sdp_loss') + self.lambda_sdp_loss = hparams.get('lambda_sdp_loss', 0.005) + self.sdp_flow_loss = torch.nn.Identity() + self.register_validation_loss('sdp_flow_loss') if self.predict_pitch: if self.diffusion_type == 'ddpm': self.pitch_loss = DiffusionLoss(loss_type=hparams['main_loss_type']) @@ -207,7 +220,7 @@ def run_model(self, sample, infer=False): spk_id=spk_ids, infer=infer ) - dur_pred, pitch_pred, variances_pred = output + dur_pred, pitch_pred, variances_pred, sdp_loss, sdp_pred = output if infer: if dur_pred is not None: dur_pred = dur_pred.round().long() @@ -216,6 +229,14 @@ def run_model(self, sample, infer=False): losses = {} if dur_pred is not None: losses['dur_loss'] = self.lambda_dur_loss * self.dur_loss(dur_pred, ph_dur, ph2word=ph2word) + if self.use_sdp: + losses['sdp_flow_loss'] = self.sdp_flow_loss(sdp_loss) * self.lambda_sdp_loss + lambda_sdp_reg_base = hparams.get('lambda_sdp_reg_loss', 0.1) + warmup_steps = hparams.get('sdp_reg_warmup_steps', 16000) + step = getattr(self, 'global_step', 1) + anneal_weight = min(1.0, step / warmup_steps) if warmup_steps > 0 else 1.0 + current_sdp_reg_weight = lambda_sdp_reg_base * anneal_weight + losses['dur_sdp_loss'] = current_sdp_reg_weight * self.dur_sdp_loss(sdp_pred, ph_dur, ph2word=ph2word) non_padding = (mel2ph > 0).unsqueeze(-1) if mel2ph is not None else None if pitch_pred is not None: if self.diffusion_type == 'ddpm': From 6502e3f63c45cf1e98b4c8893ff4db29cdecbbc6 Mon Sep 17 00:00:00 2001 From: KakaruHayate Date: Wed, 25 Mar 2026 22:55:53 +0800 Subject: [PATCH 2/2] Partial RoPE --- configs/templates/config_acoustic.yaml | 2 +- configs/templates/config_variance.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git 
a/configs/templates/config_acoustic.yaml b/configs/templates/config_acoustic.yaml index 9af2582e..14b3085f 100644 --- a/configs/templates/config_acoustic.yaml +++ b/configs/templates/config_acoustic.yaml @@ -73,7 +73,7 @@ enc_ffn_kernel_size: 3 use_rope: true rope_interleaved: false rope_theta: 8000 -rotary_dim: 192 +rotary_dim: 64 use_stretch_embed: true use_variance_scaling: true use_shallow_diffusion: true diff --git a/configs/templates/config_variance.yaml b/configs/templates/config_variance.yaml index 990591e7..e4c7c960 100644 --- a/configs/templates/config_variance.yaml +++ b/configs/templates/config_variance.yaml @@ -67,7 +67,7 @@ enc_ffn_kernel_size: 3 use_rope: true rope_interleaved: false rope_theta: 8000 -rotary_dim: 192 +rotary_dim: 64 use_stretch_embed: false use_variance_scaling: true hidden_size: 384