Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions configs/acoustic.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ max_beta: 0.02
enc_ffn_kernel_size: 3
use_rope: true
rope_interleaved: false
rope_theta: 8000
rotary_dim: 192
use_stretch_embed: true
use_variance_scaling: true
rel_pos: true
Expand Down
2 changes: 2 additions & 0 deletions configs/templates/config_acoustic.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ diffusion_type: reflow
enc_ffn_kernel_size: 3
use_rope: true
rope_interleaved: false
rope_theta: 8000
rotary_dim: 64
use_stretch_embed: true
use_variance_scaling: true
use_shallow_diffusion: true
Expand Down
13 changes: 11 additions & 2 deletions configs/templates/config_variance.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,14 @@ tension_logit_max: 10.0
enc_ffn_kernel_size: 3
use_rope: true
rope_interleaved: false
rope_theta: 8000
rotary_dim: 64
use_stretch_embed: false
use_variance_scaling: true
hidden_size: 384
dur_prediction_args:
arch: resnet
hidden_size: 256
arch: fs2
hidden_size: 512
dropout: 0.1
num_layers: 5
kernel_size: 3
Expand All @@ -80,11 +82,18 @@ dur_prediction_args:
lambda_pdur_loss: 0.3
lambda_wdur_loss: 1.0
lambda_sdur_loss: 3.0
use_sdp: true
sdp_ratio: 0.2
sdp_n_chans: 192
lambda_sdp_loss: 0.005
lambda_sdp_reg_loss: 0.1
sdp_reg_warmup_steps: 16000

use_melody_encoder: true
melody_encoder_args:
hidden_size: 128
enc_layers: 4
rotary_dim: 64
use_glide_embed: false
glide_types: [up, down]
glide_embed_scale: 11.313708498984760 # sqrt(128)
Expand Down
9 changes: 9 additions & 0 deletions configs/variance.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ predict_tension: false
enc_ffn_kernel_size: 3
use_rope: true
rope_interleaved: false
rope_theta: 8000
rotary_dim: 192
use_stretch_embed: false
use_variance_scaling: true
rel_pos: true
Expand All @@ -53,11 +55,18 @@ dur_prediction_args:
lambda_pdur_loss: 0.3
lambda_wdur_loss: 1.0
lambda_sdur_loss: 3.0
use_sdp: true
sdp_ratio: 0.2
sdp_n_chans: 192
lambda_sdp_loss: 0.005
lambda_sdp_reg_loss: 0.1
sdp_reg_warmup_steps: 16000

use_melody_encoder: true
melody_encoder_args:
hidden_size: 128
enc_layers: 4
rotary_dim: 64
use_glide_embed: false
glide_types: [up, down]
glide_embed_scale: 11.313708498984760 # sqrt(128)
Expand Down
5 changes: 3 additions & 2 deletions deployment/modules/fastspeech2.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,10 @@ def forward_encoder_phoneme(self, tokens, ph_dur, languages=None):
def forward_dur_predictor(self, encoder_out, x_masks, ph_midi, spk_embed=None):
midi_embed = self.midi_embed(ph_midi)
dur_cond = encoder_out + midi_embed
sdp_cond = dur_cond
if hparams['use_spk_id'] and spk_embed is not None:
dur_cond += spk_embed
ph_dur = self.dur_predictor(dur_cond, x_masks=x_masks)
dur_cond = dur_cond + spk_embed
ph_dur, _, _ = self.dur_predictor(dur_cond, x_masks=x_masks, sdp_cond=sdp_cond, spk_embed=spk_embed)
return ph_dur

def view_as_encoder(self):
Expand Down
2 changes: 1 addition & 1 deletion inference/ds_variance.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def forward_model(self, sample):
else:
ph_spk_mix_embed = spk_mix_embed = None

dur_pred, pitch_pred, variance_pred = self.model(
dur_pred, pitch_pred, variance_pred, _, _ = self.model(
txt_tokens, languages=sample.get('languages'),
midi=midi, ph2word=ph2word, word_dur=word_dur, ph_dur=ph_dur, mel2ph=mel2ph,
note_midi=note_midi, note_rest=note_rest, note_dur=note_dur, note_glide=note_glide, mel2note=mel2note,
Expand Down
8 changes: 6 additions & 2 deletions modules/commons/rotary_embedding_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,16 @@ def __init__(
dim,
theta=10000,
max_seq_len=8192,
interleaved: bool = True
interleaved: bool = True,
rotary_dim: int = None
):
super().__init__()
self.interleaved = interleaved
self.cached_freqs_seq_len = max_seq_len
inv_freq = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
self.rotary_dim = rotary_dim if rotary_dim is not None else dim
if self.rotary_dim > dim:
raise ValueError(f"rotary_dim ({self.rotary_dim}) cannot be larger than dim ({dim})")
# NOTE(review): the exponent divides by the full `dim`, not `self.rotary_dim`;
# GPT-NeoX-style partial RoPE divides by the rotary dim — confirm this is intended.
inv_freq = 1. / (theta ** (torch.arange(0, self.rotary_dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq, persistent=False)
self.register_buffer('cached_freqs', self._precompute_cache(max_seq_len), persistent=False)

Expand Down
3 changes: 2 additions & 1 deletion modules/fastspeech/acoustic_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def __init__(self, vocab_size):
ffn_kernel_size=hparams['enc_ffn_kernel_size'], ffn_act=hparams['ffn_act'],
dropout=hparams['dropout'], num_heads=hparams['num_heads'],
use_pos_embed=hparams['use_pos_embed'], rel_pos=hparams.get('rel_pos', False),
use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True)
use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True),
rope_theta=hparams.get('rope_theta', 10000), rotary_dim=hparams.get('rotary_dim', None),
)

self.pitch_embed = AdamWLinear(1, hparams['hidden_size'])
Expand Down
126 changes: 107 additions & 19 deletions modules/fastspeech/tts_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from modules.commons.rotary_embedding_torch import RotaryEmbedding
from modules.commons.common_layers import SinusoidalPositionalEmbedding, EncSALayer, AdamWLinear
from modules.commons.espnet_positional_embedding import RelPositionalEncoding
from modules.sdp.sdp import StochasticDurationPredictor

DEFAULT_MAX_SOURCE_POSITIONS = 2000
DEFAULT_MAX_TARGET_POSITIONS = 2000
Expand Down Expand Up @@ -54,6 +55,9 @@ class DurationPredictor(torch.nn.Module):
"""Duration predictor module.
This is a module of duration predictor described in `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
The duration predictor predicts a duration of each frame in log domain from the hidden embeddings of encoder.
It combines a deterministic Duration Predictor (based on FastSpeech/ResNet) and
an optional Stochastic Duration Predictor (SDP, based on Normalizing Flows like VITS/Bert-VITS2)
to improve generation variance.
.. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
https://arxiv.org/pdf/1905.09263.pdf
Note:
Expand All @@ -62,7 +66,9 @@ class DurationPredictor(torch.nn.Module):
"""

def __init__(self, in_dims, n_layers=2, n_chans=384, kernel_size=3,
dropout_rate=0.1, offset=1.0, dur_loss_type='mse', arch='resnet'):
dropout_rate=0.1, offset=1.0, dur_loss_type='mse', arch='fs2',
use_sdp=False, sdp_ratio=0.2, sdp_n_chans=192, gin_channels=0
):
"""Initialize duration predictor module.
Args:
in_dims (int): Input dimension.
Expand All @@ -71,6 +77,12 @@ def __init__(self, in_dims, n_layers=2, n_chans=384, kernel_size=3,
kernel_size (int, optional): Kernel size of convolutional layers.
dropout_rate (float, optional): Dropout rate.
offset (float, optional): Offset value to avoid nan in log domain.
dur_loss_type (str, optional): Loss type ('mse', 'huber', etc.).
arch (str, optional): Architecture type ('resnet' or standard CNN).
use_sdp (bool, optional): Whether to use Stochastic Duration Predictor.
sdp_ratio (float, optional): Interpolation ratio between SDP and DP outputs during inference.
sdp_n_chans (int, optional): Hidden channels for SDP.
gin_channels (int, optional): Speaker embedding channels for SDP conditioning.
"""
super(DurationPredictor, self).__init__()
self.offset = offset
Expand Down Expand Up @@ -99,6 +111,7 @@ def __init__(self, in_dims, n_layers=2, n_chans=384, kernel_size=3,
self.res_conv = nn.Conv1d(in_dims, n_chans, 1)
else:
self.res_conv = None

self.loss_type = dur_loss_type
if self.loss_type in ['mse', 'huber']:
self.out_dims = 1
Expand All @@ -112,6 +125,21 @@ def __init__(self, in_dims, n_layers=2, n_chans=384, kernel_size=3,
raise NotImplementedError()
self.linear = AdamWLinear(n_chans, self.out_dims)

self.use_sdp = use_sdp
self.sdp_ratio = sdp_ratio

if self.use_sdp:
self.sdp = StochasticDurationPredictor(
in_channels=in_dims,
filter_channels=sdp_n_chans,
kernel_size=3,
p_dropout=0.5,
n_flows=4,
gin_channels=gin_channels
)
else:
self.sdp = None

def out2dur(self, xs):
if self.loss_type in ['mse', 'huber']:
# NOTE: calculate loss in log domain
Expand All @@ -122,33 +150,86 @@ def out2dur(self, xs):
raise NotImplementedError()
return dur

def forward(self, xs, x_masks=None, infer=True):
def forward(self, xs, x_masks=None, infer=True, ph_dur=None, sdp_cond=None, spk_embed=None):
"""Calculate forward propagation.
Args:
xs (Tensor): Batch of input sequences (B, Tmax, idim).
x_masks (BoolTensor, optional): Batch of masks indicating padded part (B, Tmax).
infer (bool): Whether inference
ph_dur (Tensor, optional): Ground truth phoneme duration [B, Tmax]. Needed for SDP training.
sdp_cond (Tensor, optional): Conditioning sequence for SDP [B, Tmax, idim].
spk_embed (Tensor, optional): Speaker embedding [B, gin_channels].

Returns:
(train) FloatTensor, (infer) LongTensor: Batch of predicted durations in linear domain (B, Tmax).
tuple:
- dur_pred (Tensor): Final predicted linear durations [B, Tmax].
- loss_sdp (Tensor or int): SDP negative log-likelihood flow loss (0 if not using SDP or inferring).
- sdp_pred (Tensor or None): SDP reverse prediction for MSE regression computing.
"""
xs = xs.transpose(1, -1) # (B, idim, Tmax)
masks = 1 - x_masks.float()
masks_ = masks[:, None, :]
non_pad_mask = 1.0 - x_masks.float() # [B, Tmax]
non_pad_mask_1d = non_pad_mask.unsqueeze(1) # [B, 1, Tmax]
non_pad_mask_2d = non_pad_mask.unsqueeze(2) # [B, Tmax, 1]
g = spk_embed.transpose(1, -1) if spk_embed is not None else None # [B, gin_chans, 1]
dp_xs = xs.transpose(1, -1) # [B, Tmax, idim] -> [B, idim, Tmax]
for idx, f in enumerate(self.conv):
if self.use_resnet:
residual = self.res_conv(xs) if idx == 0 and self.res_conv is not None else xs
xs = residual + f(xs)
residual = self.res_conv(dp_xs) if (idx == 0 and self.res_conv is not None) else dp_xs
dp_xs = residual + f(dp_xs)
else:
xs = f(xs)
dp_xs = f(dp_xs)

if x_masks is not None:
xs = xs * masks_
xs = self.linear(xs.transpose(1, -1)) # [B, T, C]
xs = xs * masks[:, :, None] # (B, T, C)
dp_xs = dp_xs * non_pad_mask_1d

dp_xs = self.linear(dp_xs.transpose(1, -1)) # [B, idim, Tmax] -> [B, Tmax, C]
dp_xs = dp_xs * non_pad_mask_2d # Mask padded areas [B, Tmax, C]

dur_pred = self.out2dur(dp_xs) # Convert to linear domain [B, Tmax]

loss_sdp = 0.0
sdp_pred = None

if self.use_sdp and sdp_cond is not None:
sdp_cond_t = sdp_cond.transpose(1, -1) # [B, idim, Tmax]

if not infer:
nll_loss = self.sdp(
x=sdp_cond_t,
x_mask=non_pad_mask_1d,
w=ph_dur,
g=g,
reverse=False,
noise_scale=1.0
)

l_length_sdp = nll_loss / torch.sum(non_pad_mask_1d)
loss_sdp = torch.sum(l_length_sdp.float())

logw_sdp = self.sdp(
x=sdp_cond_t,
x_mask=non_pad_mask_1d,
g=g,
reverse=True,
noise_scale=1.0
)
sdp_pred = torch.ceil(self.out2dur(logw_sdp.transpose(1, -1) * non_pad_mask_2d))

else:
logw_sdp = self.sdp(
x=sdp_cond_t,
x_mask=non_pad_mask_1d,
g=g,
reverse=True,
noise_scale=0.8
)
sdp_pred = torch.ceil(self.out2dur(logw_sdp.transpose(1, -1) * non_pad_mask_2d))

dur_pred = (sdp_pred * self.sdp_ratio) + (dur_pred * (1.0 - self.sdp_ratio))

dur_pred = self.out2dur(xs)
if infer:
dur_pred = dur_pred.clamp(min=0.) # avoid negative value
return dur_pred
return dur_pred.clamp(min=0.0), None, None
else:
return dur_pred, loss_sdp, sdp_pred


class VariancePredictor(torch.nn.Module):
Expand Down Expand Up @@ -369,19 +450,26 @@ def mel2ph_to_dur(mel2ph, T_txt, max_dur=None):
class FastSpeech2Encoder(nn.Module):
def __init__(self, hidden_size, num_layers,
ffn_kernel_size=9, ffn_act='gelu',
dropout=None, num_heads=2, use_pos_embed=True, rel_pos=True, use_rope=False, rope_interleaved=True):
dropout=None, num_heads=2, use_pos_embed=True, rel_pos=True,
use_rope=False, rope_interleaved=True, rope_theta=10000, rotary_dim=None):
super().__init__()
self.num_layers = num_layers
embed_dim = self.hidden_size = hidden_size
self.dropout = dropout
self.use_pos_embed = use_pos_embed
if use_pos_embed and use_rope:
if embed_dim % (num_heads * 2) != 0:
rotary_dim = rotary_dim if rotary_dim is not None else embed_dim
if rotary_dim % (num_heads * 2) != 0:
raise ValueError(
"RoPE requires rotary_dim to be a multiple of "
f"num_heads * 2 = {num_heads * 2}, but got {embed_dim}."
f"num_heads * 2 = {num_heads * 2}, but got {rotary_dim}."
)
rotary_embed = RotaryEmbedding(dim=embed_dim // num_heads, interleaved=rope_interleaved)
rotary_embed = RotaryEmbedding(
dim=embed_dim // num_heads,
theta=rope_theta,
interleaved=rope_interleaved,
rotary_dim=rotary_dim
)
else:
rotary_embed = None
self.layers = nn.ModuleList([
Expand Down
26 changes: 19 additions & 7 deletions modules/fastspeech/variance_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def __init__(self, vocab_size):
ffn_kernel_size=hparams['enc_ffn_kernel_size'], ffn_act=hparams['ffn_act'],
dropout=hparams['dropout'], num_heads=hparams['num_heads'],
use_pos_embed=hparams['use_pos_embed'], rel_pos=hparams.get('rel_pos', False),
use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True)
use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True),
rope_theta=hparams.get('rope_theta', 10000), rotary_dim=hparams.get('rotary_dim', None),
)

dur_hparams = hparams['dur_prediction_args']
Expand All @@ -48,7 +49,11 @@ def __init__(self, vocab_size):
kernel_size=dur_hparams['kernel_size'],
offset=dur_hparams['log_offset'],
dur_loss_type=dur_hparams['loss_type'],
arch=dur_hparams['arch']
arch=dur_hparams['arch'],
use_sdp=hparams.get('use_sdp', False),
sdp_ratio=hparams.get('sdp_ratio', 0.2),
sdp_n_chans=hparams.get('sdp_n_chans', 192),
gin_channels=hparams['hidden_size'] if hparams['use_spk_id'] else 0,
)

def forward(
Expand Down Expand Up @@ -91,16 +96,22 @@ def forward(
extra_embed += lang_embed
encoder_out = self.encoder(txt_embed, extra_embed, txt_tokens == 0)

sdp_loss = None
sdp_pred = None
if self.predict_dur:
midi_embed = self.midi_embed(midi) # => [B, T_ph, H]
dur_cond = encoder_out + midi_embed
sdp_cond = dur_cond
if spk_embed is not None:
dur_cond += spk_embed
ph_dur_pred = self.dur_predictor(dur_cond, x_masks=txt_tokens == PAD_INDEX, infer=infer)
dur_cond = dur_cond + spk_embed
g = spk_embed
else:
g = None
ph_dur_pred, sdp_loss, sdp_pred = self.dur_predictor(dur_cond, x_masks=txt_tokens == PAD_INDEX, infer=infer, ph_dur=ph_dur, sdp_cond=sdp_cond, spk_embed=g)

return encoder_out, ph_dur_pred
return encoder_out, ph_dur_pred, sdp_loss, sdp_pred
else:
return encoder_out, None
return encoder_out, None, sdp_loss, sdp_pred


class MelodyEncoder(nn.Module):
Expand Down Expand Up @@ -128,7 +139,8 @@ def get_hparam(key):
ffn_kernel_size=get_hparam('enc_ffn_kernel_size'), ffn_act=get_hparam('ffn_act'),
dropout=get_hparam('dropout'), num_heads=get_hparam('num_heads'),
use_pos_embed=get_hparam('use_pos_embed'), rel_pos=get_hparam('rel_pos'),
use_rope=get_hparam('use_rope'), rope_interleaved=hparams.get('rope_interleaved', True)
use_rope=get_hparam('use_rope'), rope_interleaved=hparams.get('rope_interleaved', True),
rope_theta=hparams.get('rope_theta', 10000), rotary_dim=get_hparam('rotary_dim'),
)
self.out_proj = Linear(hidden_size, hparams['hidden_size'])

Expand Down
Loading