
Add alphas to control language and speaker balancer #1216

Merged
11 commits merged on Mar 10, 2022
16 changes: 16 additions & 0 deletions TTS/config/shared_configs.py
@@ -251,11 +251,27 @@ class BaseTrainingConfig(TrainerConfig):

num_eval_loader_workers (int):
Number of workers for evaluation time dataloader.

use_speaker_weighted_sampler (bool):
Enable / Disable the batch balancer by speaker. Defaults to ```False```.

speaker_weighted_sampler_alpha (float):
Number that controls the influence of the speaker sampler weights. Defaults to ```1.0```.

use_language_weighted_sampler (bool):
Enable / Disable the batch balancer by language. Defaults to ```False```.

language_weighted_sampler_alpha (float):
Number that controls the influence of the language sampler weights. Defaults to ```1.0```.
"""

model: str = None
# dataloading
num_loader_workers: int = 0
num_eval_loader_workers: int = 0
use_noise_augment: bool = False
# weighted samplers
use_speaker_weighted_sampler: bool = False
speaker_weighted_sampler_alpha: float = 1.0
use_language_weighted_sampler: bool = False
language_weighted_sampler_alpha: float = 1.0
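
For context, a minimal sketch of how these new options might be set in a training config. The `VitsConfig` import is only an example; any config that inherits `BaseTrainingConfig` should expose the same fields:

```python
from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig()
# Enable both balancers; each alpha scales how strongly that weight vector
# contributes when the language and speaker weights are combined.
config.use_language_weighted_sampler = True
config.language_weighted_sampler_alpha = 2.0  # emphasize language balancing
config.use_speaker_weighted_sampler = True
config.speaker_weighted_sampler_alpha = 1.0
```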
58 changes: 36 additions & 22 deletions TTS/tts/models/base_tts.py
@@ -7,14 +7,15 @@
from coqpit import Coqpit
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from trainer.torch import DistributedSampler, DistributedSamplerWrapper

from TTS.model import BaseTrainerModel
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_weighted_sampler
from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_balancer_weights
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from torch.utils.data.sampler import WeightedRandomSampler

# pylint: skip-file

@@ -232,6 +233,36 @@ def format_batch(self, batch: Dict) -> Dict:
"language_ids": language_ids,
}

def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1):
weights = None
data_items = dataset.samples

if getattr(config, "use_language_weighted_sampler", False):
alpha = getattr(config, "language_weighted_sampler_alpha", 1.0)
print(" > Using Language weighted sampler with alpha:", alpha)
weights = get_language_balancer_weights(data_items) * alpha

if getattr(config, "use_speaker_weighted_sampler", False):
alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0)
print(" > Using Speaker weighted sampler with alpha:", alpha)
if weights is not None:
weights += get_speaker_balancer_weights(data_items) * alpha
else:
weights = get_speaker_balancer_weights(data_items) * alpha

if weights is not None:
sampler = WeightedRandomSampler(weights, len(weights))
else:
sampler = None

# sampler for DDP
if sampler is None:
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
else: # If a sampler is already defined use this sampler and DDP sampler together
sampler = DistributedSamplerWrapper(sampler) if num_gpus > 1 else sampler

return sampler

def get_data_loader(
self,
config: Coqpit,
@@ -300,25 +331,8 @@ def get_data_loader(
# sort input sequences from short to long
dataset.preprocess_samples()

# sampler for DDP
sampler = DistributedSampler(dataset) if num_gpus > 1 else None

# Weighted samplers
# TODO: make this DDP amenable
assert not (
num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)
), "language_weighted_sampler is not supported with DistributedSampler"
assert not (
num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)
), "speaker_weighted_sampler is not supported with DistributedSampler"

if sampler is None:
if getattr(config, "use_language_weighted_sampler", False):
print(" > Using Language weighted sampler")
sampler = get_language_weighted_sampler(dataset.samples)
elif getattr(config, "use_speaker_weighted_sampler", False):
print(" > Using Language weighted sampler")
sampler = get_speaker_weighted_sampler(dataset.samples)
# get samplers
sampler = self.get_sampler(config, dataset, num_gpus)

loader = DataLoader(
dataset,
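As a sanity check of the combination logic in `get_sampler`, the same weighting can be reproduced outside the trainer. A self-contained sketch follows; the sample dicts are made up for illustration, and only the `language` and `speaker_name` keys are used by the helpers. In the multi-GPU branch, the resulting `WeightedRandomSampler` would additionally be wrapped in `DistributedSamplerWrapper`, as in the diff above.

```python
import torch
from torch.utils.data.sampler import WeightedRandomSampler

from TTS.tts.utils.languages import get_language_balancer_weights
from TTS.tts.utils.speakers import get_speaker_balancer_weights

# Hypothetical, unbalanced toy dataset: 3 English samples vs. 1 Portuguese sample.
samples = [
    {"language": "en", "speaker_name": "spk-0"},
    {"language": "en", "speaker_name": "spk-0"},
    {"language": "en", "speaker_name": "spk-1"},
    {"language": "pt-br", "speaker_name": "spk-1"},
]

lang_alpha, spk_alpha = 1.0, 1.0
# Each helper returns one weight per sample; the alphas scale their relative influence.
weights = get_language_balancer_weights(samples) * lang_alpha
weights = weights + get_speaker_balancer_weights(samples) * spk_alpha

sampler = WeightedRandomSampler(weights, len(weights))
indices = list(sampler)  # minority-language / minority-speaker samples are drawn more often
```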
27 changes: 5 additions & 22 deletions TTS/tts/models/vits.py
@@ -13,7 +13,6 @@
from torch.cuda.amp.autocast_mode import autocast
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from trainer.trainer_utils import get_optimizer, get_scheduler

from TTS.tts.configs.shared_configs import CharactersConfig
@@ -24,8 +23,8 @@
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_weighted_sampler
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
from TTS.tts.utils.text.tokenizer import TTSTokenizer
@@ -1354,31 +1353,15 @@ def get_data_loader(
# sort input sequences from short to long
dataset.preprocess_samples()

# sampler for DDP
sampler = DistributedSampler(dataset) if num_gpus > 1 else None

# Weighted samplers
# TODO: make this DDP amenable
assert not (
num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)
), "language_weighted_sampler is not supported with DistributedSampler"
assert not (
num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)
), "speaker_weighted_sampler is not supported with DistributedSampler"

if sampler is None:
if getattr(config, "use_language_weighted_sampler", False):
print(" > Using Language weighted sampler")
sampler = get_language_weighted_sampler(dataset.samples)
elif getattr(config, "use_speaker_weighted_sampler", False):
print(" > Using Language weighted sampler")
sampler = get_speaker_weighted_sampler(dataset.samples)
# get samplers
sampler = self.get_sampler(config, dataset, num_gpus)

loader = DataLoader(
dataset,
batch_size=config.eval_batch_size if is_eval else config.batch_size,
shuffle=False, # shuffle is done in the dataset.
drop_last=False, # setting this False might cause issues in AMP training.
sampler=sampler,
collate_fn=dataset.collate_fn,
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
pin_memory=False,
10 changes: 6 additions & 4 deletions TTS/tts/utils/languages.py
@@ -6,7 +6,6 @@
import numpy as np
import torch
from coqpit import Coqpit
from torch.utils.data.sampler import WeightedRandomSampler

from TTS.config import check_config_and_model_args

@@ -128,11 +127,14 @@ def _set_file_path(path):
return None


def get_language_weighted_sampler(items: list):
def get_language_balancer_weights(items: list):
language_names = np.array([item["language"] for item in items])
unique_language_names = np.unique(language_names).tolist()
language_ids = [unique_language_names.index(l) for l in language_names]
language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names])
weight_language = 1.0 / language_count
dataset_samples_weight = torch.from_numpy(np.array([weight_language[l] for l in language_ids])).double()
return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))
# get weight for each sample
dataset_samples_weight = np.array([weight_language[l] for l in language_ids])
# normalize
dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
return torch.from_numpy(dataset_samples_weight).float()
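
A quick check of the arithmetic in `get_language_balancer_weights` with plain numpy, for a hypothetical 3-to-1 en/pt-br split (values rounded):

```python
import numpy as np

languages = np.array(["en", "en", "en", "pt-br"])
unique, counts = np.unique(languages, return_counts=True)  # ['en', 'pt-br'], [3, 1]
class_weight = 1.0 / counts                                 # [0.333, 1.0]
sample_weight = np.array([class_weight[unique.tolist().index(l)] for l in languages])
sample_weight = sample_weight / np.linalg.norm(sample_weight)
# -> roughly [0.289, 0.289, 0.289, 0.866]: the lone pt-br sample is three times
#    as likely to be drawn as any single en sample.
```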
9 changes: 5 additions & 4 deletions TTS/tts/utils/speakers.py
@@ -7,7 +7,6 @@
import numpy as np
import torch
from coqpit import Coqpit
from torch.utils.data.sampler import WeightedRandomSampler

from TTS.config import get_from_config_or_model_args_with_default, load_config
from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
@@ -449,11 +448,13 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
return speaker_manager


def get_speaker_weighted_sampler(items: list):
def get_speaker_balancer_weights(items: list):
speaker_names = np.array([item["speaker_name"] for item in items])
unique_speaker_names = np.unique(speaker_names).tolist()
speaker_ids = [unique_speaker_names.index(l) for l in speaker_names]
speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names])
weight_speaker = 1.0 / speaker_count
dataset_samples_weight = torch.from_numpy(np.array([weight_speaker[l] for l in speaker_ids])).double()
return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))
dataset_samples_weight = np.array([weight_speaker[l] for l in speaker_ids])
# normalize
dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
return torch.from_numpy(dataset_samples_weight).float()
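
Normalizing each weight vector before combination keeps the language and speaker terms on comparable scales, so the alphas act as relative strengths rather than being dominated by whichever attribute has the larger raw inverse-frequency values. A small illustration with hypothetical, already-normalized vectors for the same four samples (3/1 language split, 2/2 speaker split):

```python
import torch

lang_w = torch.tensor([0.289, 0.289, 0.289, 0.866])  # unbalanced languages
spk_w = torch.tensor([0.500, 0.500, 0.500, 0.500])   # already balanced speakers

# A larger language alpha shifts the combined weights towards fixing the
# language imbalance; the balanced speaker term only adds a constant offset.
combined = 2.0 * lang_w + 1.0 * spk_w
```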
66 changes: 46 additions & 20 deletions tests/data_tests/test_samplers.py
@@ -1,10 +1,13 @@
import functools

import unittest

import torch

from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.languages import get_language_weighted_sampler
from TTS.tts.utils.languages import get_language_balancer_weights
from TTS.tts.utils.speakers import get_speaker_balancer_weights

# Fixing random state to avoid random fails
torch.manual_seed(0)
@@ -25,34 +28,57 @@
language="pt-br",
)

# Adding the EN samples twice to create an unbalanced dataset
# Adding the EN samples twice to create a language-unbalanced dataset
train_samples, eval_samples = load_tts_samples(
[dataset_config_en, dataset_config_en, dataset_config_pt], eval_split=True
)

# generate a speaker-unbalanced dataset
for i, sample in enumerate(train_samples):
if i < 5:
sample["speaker_name"] = "ljspeech-0"
else:
sample["speaker_name"] = "ljspeech-1"


def is_balanced(lang_1, lang_2):
return 0.85 < lang_1 / lang_2 < 1.2


random_sampler = torch.utils.data.RandomSampler(train_samples)
ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)])
en, pt = 0, 0
for index in ids:
if train_samples[index]["language"] == "en":
en += 1
else:
pt += 1
class TestSamplers(unittest.TestCase):
def test_language_random_sampler(self): # pylint: disable=no-self-use
random_sampler = torch.utils.data.RandomSampler(train_samples)
ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)])
en, pt = 0, 0
for index in ids:
if train_samples[index]["language"] == "en":
en += 1
else:
pt += 1

assert not is_balanced(en, pt), "Random sampler is supposed to be unbalanced"
assert not is_balanced(en, pt), "Random sampler is supposed to be unbalanced"

weighted_sampler = get_language_weighted_sampler(train_samples)
ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
en, pt = 0, 0
for index in ids:
if train_samples[index]["language"] == "en":
en += 1
else:
pt += 1
def test_language_weighted_random_sampler(self): # pylint: disable=no-self-use
weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(get_language_balancer_weights(train_samples), len(train_samples))
ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
en, pt = 0, 0
for index in ids:
if train_samples[index]["language"] == "en":
en += 1
else:
pt += 1

assert is_balanced(en, pt), "Language Weighted sampler is supposed to be balanced"

def test_speaker_weighted_random_sampler(self): # pylint: disable=no-self-use

weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(get_speaker_balancer_weights(train_samples), len(train_samples))
ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
spk1, spk2 = 0, 0
for index in ids:
if train_samples[index]["speaker_name"] == "ljspeech-0":
spk1 += 1
else:
spk2 += 1

assert is_balanced(en, pt), "Weighted sampler is supposed to be balanced"
assert is_balanced(spk1, spk2), "Speaker Weighted sampler is supposed to be balanced"
9 changes: 6 additions & 3 deletions tests/tts_tests/test_vits_multilingual_train-d_vectors.py
@@ -45,7 +45,7 @@
["Be a voice, not an echo.", "ljspeech-0", None, "en"],
["Be a voice, not an echo.", "ljspeech-1", None, "pt-br"],
],
datasets=[dataset_config_en, dataset_config_pt],
datasets=[dataset_config_en, dataset_config_en, dataset_config_en, dataset_config_pt],
)
# set audio config
config.audio.do_trim_silence = True
@@ -71,8 +71,11 @@
config.model_args.use_sdp = True
config.use_sdp = True

# deactivate language sampler
config.use_language_weighted_sampler = False
# activate language and speaker samplers
config.use_language_weighted_sampler = True
config.language_weighted_sampler_alpha = 10
config.use_speaker_weighted_sampler = True
config.speaker_weighted_sampler_alpha = 5

config.save_json(config_path)
