diff --git a/TTS/config/shared_configs.py b/TTS/config/shared_configs.py index 6394b2644b..3ea49796fc 100644 --- a/TTS/config/shared_configs.py +++ b/TTS/config/shared_configs.py @@ -258,4 +258,3 @@ class BaseTrainingConfig(TrainerConfig): num_loader_workers: int = 0 num_eval_loader_workers: int = 0 use_noise_augment: bool = False - use_language_weighted_sampler: bool = False diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py index f43c646473..a9b56ed497 100644 --- a/TTS/tts/configs/shared_configs.py +++ b/TTS/tts/configs/shared_configs.py @@ -220,6 +220,18 @@ class BaseTTSConfig(BaseTrainingConfig): eval_split_size (float): If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set. If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%). + + use_speaker_weighted_sampler (bool): + Enable / Disable the batch balancer by speaker. Defaults to ```False```. + + speaker_weighted_sampler_alpha (float): + Number that control the influence of the speaker sampler weights. Defaults to ```1.0```. + + use_language_weighted_sampler (bool): + Enable / Disable the batch balancer by language. Defaults to ```False```. + + language_weighted_sampler_alpha (float): + Number that control the influence of the language sampler weights. Defaults to ```1.0```. """ audio: BaseAudioConfig = field(default_factory=BaseAudioConfig) @@ -262,3 +274,8 @@ class BaseTTSConfig(BaseTrainingConfig): # evaluation eval_split_max_size: int = None eval_split_size: float = 0.01 + # weighted samplers + use_speaker_weighted_sampler: bool = False + speaker_weighted_sampler_alpha: float = 1.0 + use_language_weighted_sampler: bool = False + language_weighted_sampler_alpha: float = 1.0 diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 4e54b94704..222f851970 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -7,14 +7,15 @@ from coqpit import Coqpit from torch import nn from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler +from trainer.torch import DistributedSampler, DistributedSamplerWrapper from TTS.model import BaseTrainerModel from TTS.tts.datasets.dataset import TTSDataset -from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler -from TTS.tts.utils.speakers import SpeakerManager, get_speaker_weighted_sampler +from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights +from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_balancer_weights from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from torch.utils.data.sampler import WeightedRandomSampler # pylint: skip-file @@ -232,6 +233,36 @@ def format_batch(self, batch: Dict) -> Dict: "language_ids": language_ids, } + def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1): + weights = None + data_items = dataset.samples + + if getattr(config, "use_language_weighted_sampler", False): + alpha = getattr(config, "language_weighted_sampler_alpha", 1.0) + print(" > Using Language weighted sampler with alpha:", alpha) + weights = get_language_balancer_weights(data_items) * alpha + + if getattr(config, "use_speaker_weighted_sampler", False): + alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0) + print(" > Using Speaker weighted sampler with alpha:", alpha) + if weights is not None: + weights += get_speaker_balancer_weights(data_items) * alpha + else: + weights = get_speaker_balancer_weights(data_items) * alpha + + if weights is not None: + sampler = WeightedRandomSampler(weights, len(weights)) + else: + sampler = None + + # sampler for DDP + if sampler is None: + sampler = DistributedSampler(dataset) if num_gpus > 1 else None + else: # If a sampler is already defined use this sampler and DDP sampler together + sampler = DistributedSamplerWrapper(sampler) if num_gpus > 1 else sampler + + return sampler + def get_data_loader( self, config: Coqpit, @@ -300,25 +331,8 @@ def get_data_loader( # sort input sequences from short to long dataset.preprocess_samples() - # sampler for DDP - sampler = DistributedSampler(dataset) if num_gpus > 1 else None - - # Weighted samplers - # TODO: make this DDP amenable - assert not ( - num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False) - ), "language_weighted_sampler is not supported with DistributedSampler" - assert not ( - num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False) - ), "speaker_weighted_sampler is not supported with DistributedSampler" - - if sampler is None: - if getattr(config, "use_language_weighted_sampler", False): - print(" > Using Language weighted sampler") - sampler = get_language_weighted_sampler(dataset.samples) - elif getattr(config, "use_speaker_weighted_sampler", False): - print(" > Using Language weighted sampler") - sampler = get_speaker_weighted_sampler(dataset.samples) + # get samplers + sampler = self.get_sampler(config, dataset, num_gpus) loader = DataLoader( dataset, diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index a43e081c86..6aa30dfe6c 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -13,7 +13,6 @@ from torch.cuda.amp.autocast_mode import autocast from torch.nn import functional as F from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler from trainer.trainer_utils import get_optimizer, get_scheduler from TTS.tts.configs.shared_configs import CharactersConfig @@ -24,8 +23,8 @@ from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask -from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler -from TTS.tts.utils.speakers import SpeakerManager, get_speaker_weighted_sampler +from TTS.tts.utils.languages import LanguageManager +from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations from TTS.tts.utils.text.tokenizer import TTSTokenizer @@ -1354,31 +1353,15 @@ def get_data_loader( # sort input sequences from short to long dataset.preprocess_samples() - # sampler for DDP - sampler = DistributedSampler(dataset) if num_gpus > 1 else None - - # Weighted samplers - # TODO: make this DDP amenable - assert not ( - num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False) - ), "language_weighted_sampler is not supported with DistributedSampler" - assert not ( - num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False) - ), "speaker_weighted_sampler is not supported with DistributedSampler" - - if sampler is None: - if getattr(config, "use_language_weighted_sampler", False): - print(" > Using Language weighted sampler") - sampler = get_language_weighted_sampler(dataset.samples) - elif getattr(config, "use_speaker_weighted_sampler", False): - print(" > Using Language weighted sampler") - sampler = get_speaker_weighted_sampler(dataset.samples) + # get samplers + sampler = self.get_sampler(config, dataset, num_gpus) loader = DataLoader( dataset, batch_size=config.eval_batch_size if is_eval else config.batch_size, shuffle=False, # shuffle is done in the dataset. drop_last=False, # setting this False might cause issues in AMP training. + sampler=sampler, collate_fn=dataset.collate_fn, num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, pin_memory=False, diff --git a/TTS/tts/utils/languages.py b/TTS/tts/utils/languages.py index 19708c13eb..7decabb078 100644 --- a/TTS/tts/utils/languages.py +++ b/TTS/tts/utils/languages.py @@ -6,7 +6,6 @@ import numpy as np import torch from coqpit import Coqpit -from torch.utils.data.sampler import WeightedRandomSampler from TTS.config import check_config_and_model_args @@ -128,11 +127,14 @@ def _set_file_path(path): return None -def get_language_weighted_sampler(items: list): +def get_language_balancer_weights(items: list): language_names = np.array([item["language"] for item in items]) unique_language_names = np.unique(language_names).tolist() language_ids = [unique_language_names.index(l) for l in language_names] language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names]) weight_language = 1.0 / language_count - dataset_samples_weight = torch.from_numpy(np.array([weight_language[l] for l in language_ids])).double() - return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight)) + # get weight for each sample + dataset_samples_weight = np.array([weight_language[l] for l in language_ids]) + # normalize + dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight) + return torch.from_numpy(dataset_samples_weight).float() diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index 99d653e685..078ce3f1dd 100644 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -7,7 +7,6 @@ import numpy as np import torch from coqpit import Coqpit -from torch.utils.data.sampler import WeightedRandomSampler from TTS.config import get_from_config_or_model_args_with_default, load_config from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model @@ -449,11 +448,13 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None, return speaker_manager -def get_speaker_weighted_sampler(items: list): +def get_speaker_balancer_weights(items: list): speaker_names = np.array([item["speaker_name"] for item in items]) unique_speaker_names = np.unique(speaker_names).tolist() speaker_ids = [unique_speaker_names.index(l) for l in speaker_names] speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names]) weight_speaker = 1.0 / speaker_count - dataset_samples_weight = torch.from_numpy(np.array([weight_speaker[l] for l in speaker_ids])).double() - return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight)) + dataset_samples_weight = np.array([weight_speaker[l] for l in speaker_ids]) + # normalize + dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight) + return torch.from_numpy(dataset_samples_weight).float() diff --git a/tests/data_tests/test_samplers.py b/tests/data_tests/test_samplers.py index 497a3fb58f..12152fb812 100644 --- a/tests/data_tests/test_samplers.py +++ b/tests/data_tests/test_samplers.py @@ -1,10 +1,13 @@ import functools +import unittest + import torch from TTS.config.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples -from TTS.tts.utils.languages import get_language_weighted_sampler +from TTS.tts.utils.languages import get_language_balancer_weights +from TTS.tts.utils.speakers import get_speaker_balancer_weights # Fixing random state to avoid random fails torch.manual_seed(0) @@ -25,34 +28,57 @@ language="pt-br", ) -# Adding the EN samples twice to create an unbalanced dataset +# Adding the EN samples twice to create a language unbalanced dataset train_samples, eval_samples = load_tts_samples( [dataset_config_en, dataset_config_en, dataset_config_pt], eval_split=True ) +# gerenate a speaker unbalanced dataset +for i, sample in enumerate(train_samples): + if i < 5: + sample["speaker_name"] = "ljspeech-0" + else: + sample["speaker_name"] = "ljspeech-1" + def is_balanced(lang_1, lang_2): return 0.85 < lang_1 / lang_2 < 1.2 -random_sampler = torch.utils.data.RandomSampler(train_samples) -ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)]) -en, pt = 0, 0 -for index in ids: - if train_samples[index]["language"] == "en": - en += 1 - else: - pt += 1 +class TestSamplers(unittest.TestCase): + def test_language_random_sampler(self): # pylint: disable=no-self-use + random_sampler = torch.utils.data.RandomSampler(train_samples) + ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)]) + en, pt = 0, 0 + for index in ids: + if train_samples[index]["language"] == "en": + en += 1 + else: + pt += 1 -assert not is_balanced(en, pt), "Random sampler is supposed to be unbalanced" + assert not is_balanced(en, pt), "Random sampler is supposed to be unbalanced" -weighted_sampler = get_language_weighted_sampler(train_samples) -ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)]) -en, pt = 0, 0 -for index in ids: - if train_samples[index]["language"] == "en": - en += 1 - else: - pt += 1 + def test_language_weighted_random_sampler(self): # pylint: disable=no-self-use + weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(get_language_balancer_weights(train_samples), len(train_samples)) + ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)]) + en, pt = 0, 0 + for index in ids: + if train_samples[index]["language"] == "en": + en += 1 + else: + pt += 1 + + assert is_balanced(en, pt), "Language Weighted sampler is supposed to be balanced" + + def test_speaker_weighted_random_sampler(self): # pylint: disable=no-self-use + + weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(get_speaker_balancer_weights(train_samples), len(train_samples)) + ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)]) + spk1, spk2 = 0, 0 + for index in ids: + if train_samples[index]["speaker_name"] == "ljspeech-0": + spk1 += 1 + else: + spk2 += 1 -assert is_balanced(en, pt), "Weighted sampler is supposed to be balanced" + assert is_balanced(spk1, spk2), "Speaker Weighted sampler is supposed to be balanced" diff --git a/tests/tts_tests/test_vits_multilingual_train-d_vectors.py b/tests/tts_tests/test_vits_multilingual_train-d_vectors.py index a8e2020e35..e12661a506 100644 --- a/tests/tts_tests/test_vits_multilingual_train-d_vectors.py +++ b/tests/tts_tests/test_vits_multilingual_train-d_vectors.py @@ -45,7 +45,7 @@ ["Be a voice, not an echo.", "ljspeech-0", None, "en"], ["Be a voice, not an echo.", "ljspeech-1", None, "pt-br"], ], - datasets=[dataset_config_en, dataset_config_pt], + datasets=[dataset_config_en, dataset_config_en, dataset_config_en, dataset_config_pt], ) # set audio config config.audio.do_trim_silence = True @@ -71,8 +71,11 @@ config.model_args.use_sdp = True config.use_sdp = True -# deactivate language sampler -config.use_language_weighted_sampler = False +# activate language and speaker samplers +config.use_language_weighted_sampler = True +config.language_weighted_sampler_alpha = 10 +config.use_speaker_weighted_sampler = True +config.speaker_weighted_sampler_alpha = 5 config.save_json(config_path)