Skip to content

Commit

Permalink
added some modules, etc
Browse files Browse the repository at this point in the history
  • Loading branch information
p0p4k committed Aug 31, 2023
1 parent e65cab7 commit 865d2ea
Show file tree
Hide file tree
Showing 6 changed files with 1,087 additions and 128 deletions.
181 changes: 181 additions & 0 deletions TTS/tts/configs/vits2_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
from dataclasses import dataclass, field
from typing import List

from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.vits2 import Vits2Args, Vits2AudioConfig


@dataclass
class Vits2Config(BaseTTSConfig):
    """Defines parameters for VITS2 End2End TTS model.

    On construction, every entry of ``model_args`` whose key matches an attribute of this
    config is copied onto that attribute (see ``__post_init__``).

    Args:
        model (str):
            Model name. Do not change unless you know what you are doing.

        model_args (Vits2Args):
            Model architecture arguments. Defaults to `Vits2Args()`.

        audio (Vits2AudioConfig):
            Audio processing configuration. Defaults to `Vits2AudioConfig()`.

        grad_clip (List):
            Gradient clipping thresholds for each optimizer. Defaults to `[1000, 1000]`.

        lr_gen (float):
            Initial learning rate for the generator. Defaults to 0.0002.

        lr_disc (float):
            Initial learning rate for the discriminator. Defaults to 0.0002.

        lr_dur (float):
            Initial learning rate, presumably for the duration discriminator (TODO confirm against the
            model's optimizer setup). Defaults to 0.0002.

        lr_scheduler_gen (str):
            Name of the learning rate scheduler for the generator. One of the `torch.optim.lr_scheduler.*`. Defaults to
            `ExponentialLR`.

        lr_scheduler_gen_params (dict):
            Parameters for the learning rate scheduler of the generator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.

        lr_scheduler_disc (str):
            Name of the learning rate scheduler for the discriminator. One of the `torch.optim.lr_scheduler.*`. Defaults to
            `ExponentialLR`.

        lr_scheduler_disc_params (dict):
            Parameters for the learning rate scheduler of the discriminator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.

        lr_scheduler_dur (str):
            Name of a learning rate scheduler, presumably for the duration discriminator (TODO confirm).
            Defaults to `ExponentialLR`.

        lr_scheduler_dur_params (dict):
            Parameters for the `lr_scheduler_dur` scheduler. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.

        scheduler_after_epoch (bool):
            If true, step the schedulers after each epoch else after each step. Defaults to `True`.

        optimizer (str):
            Name of the optimizer to use with both the generator and the discriminator networks. One of the
            `torch.optim.*`. Defaults to `AdamW`.

        optimizer_params (dict):
            Optimizer parameters. Defaults to `{"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01}`.

        kl_loss_alpha (float):
            Loss weight for KL loss. Defaults to 1.0.

        disc_loss_alpha (float):
            Loss weight for the discriminator loss. Defaults to 1.0.

        gen_loss_alpha (float):
            Loss weight for the generator loss. Defaults to 1.0.

        feat_loss_alpha (float):
            Loss weight for the feature matching loss. Defaults to 1.0.

        mel_loss_alpha (float):
            Loss weight for the mel loss. Defaults to 45.0.

        dur_loss_alpha (float):
            Loss weight, presumably for the duration loss (TODO confirm). Defaults to 1.0.

        speaker_encoder_loss_alpha (float):
            Loss weight, presumably for the speaker encoder loss (TODO confirm). Defaults to 1.0.

        return_wav (bool):
            If true, data loader returns the waveform as well as the other outputs. Do not change. Defaults to `True`.

        compute_linear_spec (bool):
            If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.

        use_weighted_sampler (bool):
            If true, use weighted sampler with bucketing for balancing samples between datasets used in training. Defaults to `False`.

        weighted_sampler_attrs (dict):
            Key returned by the formatter to be used for weighted sampler. For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities
            by overweighting `root_path` by 2.0. Defaults to `{}`.

        weighted_sampler_multipliers (dict):
            Weight each unique value of a key returned by the formatter for weighted sampling.
            For example `{"root_path":{"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/":1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}`.
            It will sample instances from `train-clean-100` 2 times more than `train-clean-360`. Defaults to `{}`.

        r (int):
            Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.

        add_blank (bool):
            If true, a blank token is added in between every character. Defaults to `True`.

        test_sentences (List[List]):
            List of sentences with speaker and language information to be used for testing.

        num_speakers (int):
            Number of speakers. Defaults to 0.

        use_speaker_embedding (bool):
            If true, a speaker embedding layer is used. Defaults to `False`.

        speakers_file (str):
            Path to the speakers file. Defaults to `None`.

        speaker_embedding_channels (int):
            Number of speaker embedding channels. Defaults to 256.

        language_ids_file (str):
            Path to the language ids file.

        use_language_embedding (bool):
            If true, language embedding is used. Defaults to `False`.

        use_d_vector_file (bool):
            If true, d-vectors are used. Defaults to `False`.

        d_vector_file (List[str]):
            Path(s) to the d-vector file(s). Defaults to `None`.

        d_vector_dim (int):
            Dimension of the d-vectors. Defaults to `None`.

    Note:
        Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.

    Example:

        >>> from TTS.tts.configs.vits2_config import Vits2Config
        >>> config = Vits2Config()
    """

    model: str = "vits2"
    # model specific params
    model_args: Vits2Args = field(default_factory=Vits2Args)
    audio: Vits2AudioConfig = field(default_factory=Vits2AudioConfig)

    # optimizer
    # NOTE(review): two clipping thresholds are given here while three learning rates
    # (gen / disc / dur) are declared below - confirm how many optimizers consume this list.
    grad_clip: List[float] = field(default_factory=lambda: [1000, 1000])
    lr_gen: float = 0.0002
    lr_disc: float = 0.0002
    lr_dur: float = 0.0002

    lr_scheduler_gen: str = "ExponentialLR"
    lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
    lr_scheduler_disc: str = "ExponentialLR"
    lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
    lr_scheduler_dur: str = "ExponentialLR"
    lr_scheduler_dur_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})

    scheduler_after_epoch: bool = True
    optimizer: str = "AdamW"
    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01})

    # loss params
    kl_loss_alpha: float = 1.0
    disc_loss_alpha: float = 1.0
    gen_loss_alpha: float = 1.0
    feat_loss_alpha: float = 1.0
    mel_loss_alpha: float = 45.0
    dur_loss_alpha: float = 1.0
    speaker_encoder_loss_alpha: float = 1.0

    # data loader params
    return_wav: bool = True
    compute_linear_spec: bool = True

    # sampler params
    use_weighted_sampler: bool = False  # TODO: move it to the base config
    weighted_sampler_attrs: dict = field(default_factory=lambda: {})
    weighted_sampler_multipliers: dict = field(default_factory=lambda: {})

    # overrides
    r: int = 1  # DO NOT CHANGE
    add_blank: bool = True

    # testing
    test_sentences: List[List] = field(
        default_factory=lambda: [
            ["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent."],
            ["Be a voice, not an echo."],
            ["I'm sorry Dave. I'm afraid I can't do that."],
            ["This cake is great. It's so delicious and moist."],
            ["Prior to November 22, 1963."],
        ]
    )

    # multi-speaker settings
    # use speaker embedding layer
    num_speakers: int = 0
    use_speaker_embedding: bool = False
    speakers_file: str = None
    speaker_embedding_channels: int = 256
    language_ids_file: str = None
    use_language_embedding: bool = False

    # use d-vectors
    use_d_vector_file: bool = False
    d_vector_file: List[str] = None
    d_vector_dim: int = None

    def __post_init__(self):
        # Mirror model_args onto this config: any model_args entry whose key is also an
        # attribute here overwrites the config-level value, so model_args wins on conflicts.
        for key, val in self.model_args.items():
            if hasattr(self, key):
                self[key] = val
33 changes: 33 additions & 0 deletions TTS/tts/layers/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,39 @@ def forward(
return_dict["loss"] = loss
return return_dict

class Vits2DurationLoss(nn.Module):
    """Least-squares adversarial loss over lists of real / generated duration scores.

    The config object only needs to expose ``disc_loss_alpha``, the weight applied
    to the summed discriminator loss.
    """

    def __init__(self, c: Coqpit):
        super().__init__()
        self.disc_loss_alpha = c.disc_loss_alpha

    @staticmethod
    def discriminator_loss(scores_real, scores_fake):
        """Sum ``mean((1 - real)^2) + mean(fake^2)`` over paired score tensors.

        Args:
            scores_real: iterable of score tensors computed on real durations.
            scores_fake: iterable of score tensors computed on generated durations.

        Returns:
            Tuple ``(total, real_terms, fake_terms)`` where ``total`` is the summed loss
            (the plain int ``0`` when no pairs are given) and the two lists hold the
            per-pair real / fake terms as Python floats.
        """
        real_terms = []
        fake_terms = []
        total = 0
        for real_score, fake_score in zip(scores_real, scores_fake):
            # Scores are cast to float32 before squaring.
            real_term = torch.mean((1 - real_score.float()) ** 2)
            fake_term = torch.mean(fake_score.float() ** 2)
            real_terms.append(real_term.item())
            fake_terms.append(fake_term.item())
            total = total + (real_term + fake_term)
        return total, real_terms, fake_terms

    def forward(self, scores_disc_real, scores_disc_fake):
        """Compute the weighted duration-discriminator loss.

        Returns:
            dict with ``loss_dur_disc`` (weighted loss), ``loss`` (same value, used as
            the total) and one ``loss_dur_disc_real_{i}`` float per score pair. The
            per-pair generated terms are computed but not reported.
        """
        summed, real_terms, _ = self.discriminator_loss(
            scores_real=scores_disc_real, scores_fake=scores_disc_fake
        )
        weighted = summed * self.disc_loss_alpha
        outputs = {
            "loss_dur_disc": weighted,
            "loss": 0.0 + weighted,
        }
        for idx, real_term in enumerate(real_terms):
            outputs[f"loss_dur_disc_real_{idx}"] = real_term
        return outputs

class VitsDiscriminatorLoss(nn.Module):
def __init__(self, c: Coqpit):
Expand Down
84 changes: 84 additions & 0 deletions TTS/tts/layers/vits2/duration_discriminator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import torch
from torch import nn
from TTS.tts.layers.generic.normalization import LayerNorm2


class DurationDiscriminator(nn.Module):  # vits2
    """VITS-2 Duration Discriminator.

    ::
        dur_r, dur_hat -> DurationDiscriminator() -> output_probs

    Scores a real and a predicted duration sequence against the same (detached)
    input features and returns one sigmoid score tensor per duration sequence.

    Args:
        in_channels (int): number of input channels.
        filter_channels (int): number of filter channels.
        kernel_size (int): kernel size of the convolutional layers.
        p_dropout (float): dropout probability.
        gin_channels (int): number of global conditioning channels.
            Unused for now.

    Returns:
        List[Tensor]: list of discriminator scores. Real, Predicted/Generated.
    """

    # TODO : not using "spk conditioning" for now according to the paper.
    # Can be a better discriminator if we use it.
    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        # Padding of kernel_size // 2 on every k-wide convolution.
        same_pad = kernel_size // 2

        # Layers are created in a fixed order; the dropout and the two norm layers are
        # registered here but are not called by forward() / forward_probability().
        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=same_pad)
        self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=same_pad)
        self.dur_proj = nn.Conv1d(1, filter_channels, 1)

        # Input is the channel-wise concat of features and projected durations.
        self.pre_out_conv_1 = nn.Conv1d(2 * filter_channels, filter_channels, kernel_size, padding=same_pad)
        self.pre_out_norm_1 = LayerNorm2(filter_channels)
        self.pre_out_conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=same_pad)
        self.pre_out_norm_2 = LayerNorm2(filter_channels)

        # No conditioning layer is built from gin_channels (speaker conditioning is a TODO).

        self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())

    def forward_probability(self, x, x_mask, dur, g=None):
        """Score one duration sequence against already-encoded features.

        Args:
            x: encoded features with ``filter_channels`` channels on dim 1.
            x_mask: mask multiplied into the activations before each conv.
            dur: duration tensor with a single channel on dim 1.
            g: accepted for interface compatibility; not used.

        Returns:
            Tensor of sigmoid scores with the channel axis moved last (size 1).
        """
        dur_features = self.dur_proj(dur)
        hidden = torch.cat([x, dur_features], dim=1)
        for conv in (self.pre_out_conv_1, self.pre_out_conv_2):
            hidden = conv(hidden * x_mask)
        # Mask once more, then put channels last for the linear output layer.
        hidden = (hidden * x_mask).transpose(1, 2)
        return self.output_layer(hidden)

    def forward(self, x, x_mask, dur_r, dur_hat, g=None):
        """Return ``[score(dur_r), score(dur_hat)]`` for the given input features.

        ``x`` is detached first, so no gradient flows back into whatever produced it.
        ``g`` is forwarded to ``forward_probability`` but currently has no effect.
        """
        hidden = x.detach()
        hidden = self.conv_1(hidden * x_mask)
        hidden = self.conv_2(hidden * x_mask)
        return [self.forward_probability(hidden, x_mask, dur, g) for dur in (dur_r, dur_hat)]
Loading

0 comments on commit 865d2ea

Please sign in to comment.