From f5ecbf33cc8ff03de5e3473df5baca1b4e849741 Mon Sep 17 00:00:00 2001
From: Edresson
Date: Tue, 8 Feb 2022 12:28:22 -0300
Subject: [PATCH 01/11] Add alphas to control language and speaker balancers

---
 TTS/config/shared_configs.py                  |  4 +++
 TTS/tts/models/base_tts.py                    | 25 ++++++++++++-----
 TTS/tts/utils/languages.py                    |  8 ++++--
 TTS/tts/utils/speakers.py                     |  7 +++--
 tests/data_tests/test_samplers.py             | 28 +++++++++++++++++--
 .../test_vits_multilingual_train-d_vectors.py |  7 +++--
 6 files changed, 61 insertions(+), 18 deletions(-)

diff --git a/TTS/config/shared_configs.py b/TTS/config/shared_configs.py
index 6394b2644b..72b5733076 100644
--- a/TTS/config/shared_configs.py
+++ b/TTS/config/shared_configs.py
@@ -258,4 +258,8 @@ class BaseTrainingConfig(TrainerConfig):
     num_loader_workers: int = 0
     num_eval_loader_workers: int = 0
     use_noise_augment: bool = False
+    # weighted samplers
+    use_speaker_weighted_sampler: bool = False
+    speaker_weighted_sampler_alpha: float = 1.0
     use_language_weighted_sampler: bool = False
+    language_weighted_sampler_alpha: float = 1.0
diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index 4e54b94704..2ff9d339d0 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -11,10 +11,11 @@
 
 from TTS.model import BaseTrainerModel
 from TTS.tts.datasets.dataset import TTSDataset
-from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
-from TTS.tts.utils.speakers import SpeakerManager, get_speaker_weighted_sampler
+from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
+from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_balancer_weights
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
+from torch.utils.data.sampler import WeightedRandomSampler
 
 # pylint: skip-file
 
@@ -313,12 +314,22 @@ def get_data_loader(
             ), "speaker_weighted_sampler is not supported with DistributedSampler"
 
             if sampler is None:
+                weights = None
                 if getattr(config, "use_language_weighted_sampler", False):
-                    print(" > Using Language weighted sampler")
-                    sampler = get_language_weighted_sampler(dataset.samples)
-                elif getattr(config, "use_speaker_weighted_sampler", False):
-                    print(" > Using Language weighted sampler")
-                    sampler = get_speaker_weighted_sampler(dataset.samples)
+                    alpha = getattr(config, "language_weighted_sampler_alpha", 1.0)
+                    print(" > Using Language weighted sampler with alpha:", alpha)
+                    weights = get_language_balancer_weights(dataset.items) * alpha
+
+                if getattr(config, "use_speaker_weighted_sampler", False):
+                    alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0)
+                    print(" > Using Speaker weighted sampler with alpha:", alpha)
+                    if weights is not None:
+                        weights += get_speaker_balancer_weights(dataset.items) * alpha
+                    else:
+                        weights = get_speaker_balancer_weights(dataset.items) * alpha
+
+                if weights is not None:
+                    sampler = WeightedRandomSampler(weights, len(weights))
 
             loader = DataLoader(
                 dataset,
diff --git a/TTS/tts/utils/languages.py b/TTS/tts/utils/languages.py
index 19708c13eb..d8d4c70b79 100644
--- a/TTS/tts/utils/languages.py
+++ b/TTS/tts/utils/languages.py
@@ -6,7 +6,6 @@
 import numpy as np
 import torch
 from coqpit import Coqpit
-from torch.utils.data.sampler import WeightedRandomSampler
 
 from TTS.config import check_config_and_model_args
 
@@ -134,5 +133,8 @@ def get_language_weighted_sampler(items: list):
     language_ids = [unique_language_names.index(l) for l in language_names]
     language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names])
     weight_language = 1.0 / language_count
-    dataset_samples_weight = torch.from_numpy(np.array([weight_language[l] for l in language_ids])).double()
-    return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))
+    # get weight for each sample
+    dataset_samples_weight = np.array([weight_language[l] for l in language_ids])
+    # normalize
+    dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
+    return torch.from_numpy(dataset_samples_weight).double()
diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py
index 99d653e685..46aa3b9b72 100644
--- a/TTS/tts/utils/speakers.py
+++ b/TTS/tts/utils/speakers.py
@@ -7,7 +7,6 @@
 import numpy as np
 import torch
 from coqpit import Coqpit
-from torch.utils.data.sampler import WeightedRandomSampler
 
 from TTS.config import get_from_config_or_model_args_with_default, load_config
 from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
 
@@ -455,5 +454,7 @@ def get_speaker_weighted_sampler(items: list):
     speaker_ids = [unique_speaker_names.index(l) for l in speaker_names]
     speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names])
     weight_speaker = 1.0 / speaker_count
-    dataset_samples_weight = torch.from_numpy(np.array([weight_speaker[l] for l in speaker_ids])).double()
-    return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))
+    dataset_samples_weight = np.array([weight_speaker[l] for l in speaker_ids])
+    # normalize
+    dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
+    return torch.from_numpy(dataset_samples_weight).double()
diff --git a/tests/data_tests/test_samplers.py b/tests/data_tests/test_samplers.py
index 497a3fb58f..80331eacd6 100644
--- a/tests/data_tests/test_samplers.py
+++ b/tests/data_tests/test_samplers.py
@@ -4,7 +4,8 @@
 
 from TTS.config.shared_configs import BaseDatasetConfig
 from TTS.tts.datasets import load_tts_samples
-from TTS.tts.utils.languages import get_language_weighted_sampler
+from TTS.tts.utils.languages import get_language_balancer_weights
+from TTS.tts.utils.speakers import get_speaker_balancer_weights
 
 # Fixing random state to avoid random fails
 torch.manual_seed(0)
@@ -46,7 +47,8 @@ def is_balanced(lang_1, lang_2):
 
 assert not is_balanced(en, pt), "Random sampler is supposed to be unbalanced"
 
-weighted_sampler = get_language_weighted_sampler(train_samples)
+
+weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(get_language_balancer_weights(train_samples), len(train_samples))
 ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
 en, pt = 0, 0
 for index in ids:
@@ -55,4 +57,24 @@ def is_balanced(lang_1, lang_2):
     else:
         pt += 1
 
-assert is_balanced(en, pt), "Weighted sampler is supposed to be balanced"
+assert is_balanced(en, pt), "Language Weighted sampler is supposed to be balanced"
+
+# test speaker weighted sampler
+
+# generate a speaker unbalanced dataset
+for i in range(0, len(train_samples)):
+    if i < 5:
+        train_samples[i][2] = "ljspeech-0"
+    else:
+        train_samples[i][2] = "ljspeech-1"
+
+weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(get_speaker_balancer_weights(train_samples), len(train_samples))
+ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
+spk1, spk2 = 0, 0
+for index in ids:
+    if train_samples[index][2] == "ljspeech-0":
+        spk1 += 1
+    else:
+        spk2 += 1
+
+assert is_balanced(spk1, spk2), "Speaker Weighted sampler is supposed to be balanced"
diff --git a/tests/tts_tests/test_vits_multilingual_train-d_vectors.py b/tests/tts_tests/test_vits_multilingual_train-d_vectors.py
index a8e2020e35..cb292c1d3d 100644
--- a/tests/tts_tests/test_vits_multilingual_train-d_vectors.py
+++ b/tests/tts_tests/test_vits_multilingual_train-d_vectors.py
@@ -45,7 +45,7 @@
         ["Be a voice, not an echo.", "ljspeech-0", None, "en"],
         ["Be a voice, not an echo.", "ljspeech-1", None, "pt-br"],
     ],
-    datasets=[dataset_config_en, dataset_config_pt],
+    datasets=[dataset_config_en, dataset_config_en, dataset_config_en, dataset_config_pt],
 )
 # set audio config
 config.audio.do_trim_silence = True
@@ -72,7 +72,10 @@
 config.use_sdp = True
 
 # deactivate language sampler
-config.use_language_weighted_sampler = False
+config.use_language_weighted_sampler = True
+config.language_weighted_sampler_alpha = 10
+config.use_speaker_weighted_sampler = True
+config.speaker_weighted_sampler_alpha = 1
 
 config.save_json(config_path)
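Note: the two balancer helpers patched above assign every sample the inverse frequency of its class (language or speaker) and L2-normalize the result before it reaches WeightedRandomSampler. A minimal standalone sketch of that computation — the helper name and the toy labels below are illustrative, not part of the patch:

    import numpy as np
    import torch

    def balancer_weights(labels):
        # One weight per sample: 1 / (number of samples sharing its label),
        # then L2-normalized, mirroring get_language_balancer_weights() above.
        labels = np.array(labels)
        unique = np.unique(labels).tolist()
        counts = np.array([np.sum(labels == u) for u in unique])
        weight_per_class = 1.0 / counts
        weights = np.array([weight_per_class[unique.index(l)] for l in labels])
        weights = weights / np.linalg.norm(weights)
        return torch.from_numpy(weights).double()

    # Two "en" samples and one "pt" sample: the lone "pt" sample gets twice
    # the weight of each "en" sample, so both languages are drawn about
    # equally often.
    print(balancer_weights(["en", "en", "pt"]))  # ~[0.4082, 0.4082, 0.8165]

WeightedRandomSampler only uses relative weights, so the normalization does not change the sampling distribution; it just keeps the values in a predictable range before the alphas are applied.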
""" model: str = None diff --git a/tests/data_tests/test_samplers.py b/tests/data_tests/test_samplers.py index 80331eacd6..044c0de917 100644 --- a/tests/data_tests/test_samplers.py +++ b/tests/data_tests/test_samplers.py @@ -62,11 +62,11 @@ def is_balanced(lang_1, lang_2): # test speaker weighted sampler # gerenate a speaker unbalanced dataset -for i in range(0, len(train_samples)): +for i, sample in enumerate(train_samples): if i < 5: - train_samples[i][2] = "ljspeech-0" + sample[2] = "ljspeech-0" else: - train_samples[i][2] = "ljspeech-1" + sample[2] = "ljspeech-1" weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(get_speaker_balancer_weights(train_samples), len(train_samples)) ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)]) diff --git a/tests/tts_tests/test_vits_multilingual_train-d_vectors.py b/tests/tts_tests/test_vits_multilingual_train-d_vectors.py index cb292c1d3d..e12661a506 100644 --- a/tests/tts_tests/test_vits_multilingual_train-d_vectors.py +++ b/tests/tts_tests/test_vits_multilingual_train-d_vectors.py @@ -71,11 +71,11 @@ config.model_args.use_sdp = True config.use_sdp = True -# deactivate language sampler +# activate language and speaker samplers config.use_language_weighted_sampler = True config.language_weighted_sampler_alpha = 10 config.use_speaker_weighted_sampler = True -config.speaker_weighted_sampler_alpha = 1 +config.speaker_weighted_sampler_alpha = 5 config.save_json(config_path) From 4c3c19806a4156b127a791350cd6304c2cd7a449 Mon Sep 17 00:00:00 2001 From: Edresson Date: Fri, 11 Feb 2022 08:47:38 -0300 Subject: [PATCH 03/11] Change the Samplers weights to float for save memory --- TTS/tts/utils/languages.py | 2 +- TTS/tts/utils/speakers.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/TTS/tts/utils/languages.py b/TTS/tts/utils/languages.py index d8d4c70b79..3016f05b02 100644 --- a/TTS/tts/utils/languages.py +++ b/TTS/tts/utils/languages.py @@ -137,4 +137,4 @@ def get_language_weighted_sampler(items: list): dataset_samples_weight = np.array([weight_language[l] for l in language_ids]) # normalize dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight) - return torch.from_numpy(dataset_samples_weight).double() + return torch.from_numpy(dataset_samples_weight).float() diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index 46aa3b9b72..e4fc568b3e 100644 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -457,4 +457,4 @@ def get_speaker_weighted_sampler(items: list): dataset_samples_weight = np.array([weight_speaker[l] for l in speaker_ids]) # normalize dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight) - return torch.from_numpy(dataset_samples_weight).double() + return torch.from_numpy(dataset_samples_weight).float() From ebfa2b9b0a1f91718309f3dc29d38f3fed8e2467 Mon Sep 17 00:00:00 2001 From: Edresson Date: Fri, 11 Feb 2022 09:31:34 -0300 Subject: [PATCH 04/11] Change the test_samplers to unittest format --- tests/data_tests/test_samplers.py | 84 ++++++++++++++++--------------- 1 file changed, 44 insertions(+), 40 deletions(-) diff --git a/tests/data_tests/test_samplers.py b/tests/data_tests/test_samplers.py index 044c0de917..c9ce89a43a 100644 --- a/tests/data_tests/test_samplers.py +++ b/tests/data_tests/test_samplers.py @@ -1,5 +1,7 @@ import functools +import unittest + import torch from TTS.config.shared_configs import BaseDatasetConfig @@ -26,41 +28,11 @@ language="pt-br", ) -# Adding 
From 4c3c19806a4156b127a791350cd6304c2cd7a449 Mon Sep 17 00:00:00 2001
From: Edresson
Date: Fri, 11 Feb 2022 08:47:38 -0300
Subject: [PATCH 03/11] Change the sampler weights to float to save memory

---
 TTS/tts/utils/languages.py | 2 +-
 TTS/tts/utils/speakers.py  | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/TTS/tts/utils/languages.py b/TTS/tts/utils/languages.py
index d8d4c70b79..3016f05b02 100644
--- a/TTS/tts/utils/languages.py
+++ b/TTS/tts/utils/languages.py
@@ -137,4 +137,4 @@ def get_language_weighted_sampler(items: list):
     dataset_samples_weight = np.array([weight_language[l] for l in language_ids])
     # normalize
     dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
-    return torch.from_numpy(dataset_samples_weight).double()
+    return torch.from_numpy(dataset_samples_weight).float()
diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py
index 46aa3b9b72..e4fc568b3e 100644
--- a/TTS/tts/utils/speakers.py
+++ b/TTS/tts/utils/speakers.py
@@ -457,4 +457,4 @@ def get_speaker_weighted_sampler(items: list):
     dataset_samples_weight = np.array([weight_speaker[l] for l in speaker_ids])
     # normalize
     dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
-    return torch.from_numpy(dataset_samples_weight).double()
+    return torch.from_numpy(dataset_samples_weight).float()

From ebfa2b9b0a1f91718309f3dc29d38f3fed8e2467 Mon Sep 17 00:00:00 2001
From: Edresson
Date: Fri, 11 Feb 2022 09:31:34 -0300
Subject: [PATCH 04/11] Change test_samplers to unittest format

---
 tests/data_tests/test_samplers.py | 84 ++++++++++++++++---------------
 1 file changed, 44 insertions(+), 40 deletions(-)

diff --git a/tests/data_tests/test_samplers.py b/tests/data_tests/test_samplers.py
index 044c0de917..c9ce89a43a 100644
--- a/tests/data_tests/test_samplers.py
+++ b/tests/data_tests/test_samplers.py
@@ -1,5 +1,7 @@
 import functools
+import unittest
+
 import torch
 
 from TTS.config.shared_configs import BaseDatasetConfig
 from TTS.tts.datasets import load_tts_samples
@@ -26,41 +28,11 @@
     language="pt-br",
 )
 
-# Adding the EN samples twice to create an unbalanced dataset
+# Adding the EN samples twice to create a language unbalanced dataset
 train_samples, eval_samples = load_tts_samples(
     [dataset_config_en, dataset_config_en, dataset_config_pt], eval_split=True
 )
 
-
-def is_balanced(lang_1, lang_2):
-    return 0.85 < lang_1 / lang_2 < 1.2
-
-
-random_sampler = torch.utils.data.RandomSampler(train_samples)
-ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)])
-en, pt = 0, 0
-for index in ids:
-    if train_samples[index]["language"] == "en":
-        en += 1
-    else:
-        pt += 1
-
-assert not is_balanced(en, pt), "Random sampler is supposed to be unbalanced"
-
-
-weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(get_language_balancer_weights(train_samples), len(train_samples))
-ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
-en, pt = 0, 0
-for index in ids:
-    if train_samples[index]["language"] == "en":
-        en += 1
-    else:
-        pt += 1
-
-assert is_balanced(en, pt), "Language Weighted sampler is supposed to be balanced"
-
-# test speaker weighted sampler
-
 # generate a speaker unbalanced dataset
 for i, sample in enumerate(train_samples):
     if i < 5:
@@ -68,13 +40,45 @@ def is_balanced(lang_1, lang_2):
     else:
         sample[2] = "ljspeech-1"
 
-weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(get_speaker_balancer_weights(train_samples), len(train_samples))
-ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
-spk1, spk2 = 0, 0
-for index in ids:
-    if train_samples[index][2] == "ljspeech-0":
-        spk1 += 1
-    else:
-        spk2 += 1
-
-assert is_balanced(spk1, spk2), "Speaker Weighted sampler is supposed to be balanced"
+def is_balanced(lang_1, lang_2):
+    return 0.85 < lang_1 / lang_2 < 1.2
+
+
+class TestSamplers(unittest.TestCase):
+    def test_language_random_sampler(self):  # pylint: disable=no-self-use
+        random_sampler = torch.utils.data.RandomSampler(train_samples)
+        ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)])
+        en, pt = 0, 0
+        for index in ids:
+            if train_samples[index][3] == "en":
+                en += 1
+            else:
+                pt += 1
+
+        assert not is_balanced(en, pt), "Random sampler is supposed to be unbalanced"
+
+    def test_language_weighted_random_sampler(self):  # pylint: disable=no-self-use
+        weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(get_language_balancer_weights(train_samples), len(train_samples))
+        ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
+        en, pt = 0, 0
+        for index in ids:
+            if train_samples[index][3] == "en":
+                en += 1
+            else:
+                pt += 1
+
+        assert is_balanced(en, pt), "Language Weighted sampler is supposed to be balanced"
+
+    def test_speaker_weighted_random_sampler(self):  # pylint: disable=no-self-use
+
+        weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(get_speaker_balancer_weights(train_samples), len(train_samples))
+        ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
+        spk1, spk2 = 0, 0
+        for index in ids:
+            if train_samples[index][2] == "ljspeech-0":
+                spk1 += 1
+            else:
+                spk2 += 1
+
+        assert is_balanced(spk1, spk2), "Speaker Weighted sampler is supposed to be balanced"
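Note: the unittest above calls a sampler "balanced" when, over 100 passes of drawn indices, the two class counts stay within a 0.85-1.2 ratio. The same check as a reusable helper — a sketch only, assuming the dict-style samples that load_tts_samples() returns after the PATCH 06 rebase fix:

    import functools

    from torch.utils.data.sampler import WeightedRandomSampler

    def class_ratio(samples, weights, key, label, epochs=100):
        # Draw `epochs` full passes of indices and count how often `label`
        # is selected versus everything else; ~1.0 means balanced.
        sampler = WeightedRandomSampler(weights, len(weights))
        ids = functools.reduce(lambda a, b: a + b, [list(sampler) for _ in range(epochs)])
        hits = sum(1 for i in ids if samples[i][key] == label)
        return hits / (len(ids) - hits)

    # e.g. assert 0.85 < class_ratio(train_samples, lang_weights, "language", "en") < 1.2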
From 9af8073faeebe5c9d1c39c4ff5ad0c7f76a75d7c Mon Sep 17 00:00:00 2001
From: Edresson
Date: Fri, 11 Feb 2022 12:07:12 -0300
Subject: [PATCH 05/11] Add get_sampler method in BaseTTS

---
 TTS/tts/models/base_tts.py | 38 +++++++++++++++++++++-----------------
 1 file changed, 21 insertions(+), 17 deletions(-)

diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index 2ff9d339d0..1d230273f8 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -233,6 +233,26 @@ def format_batch(self, batch: Dict) -> Dict:
             "language_ids": language_ids,
         }
 
+    def get_sampler(self, config: Coqpit, data_items: List, sampler: bool = None):
+        weights = None
+        if getattr(config, "use_language_weighted_sampler", False):
+            alpha = getattr(config, "language_weighted_sampler_alpha", 1.0)
+            print(" > Using Language weighted sampler with alpha:", alpha)
+            weights = get_language_balancer_weights(data_items) * alpha
+
+        if getattr(config, "use_speaker_weighted_sampler", False):
+            alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0)
+            print(" > Using Speaker weighted sampler with alpha:", alpha)
+            if weights is not None:
+                weights += get_speaker_balancer_weights(data_items) * alpha
+            else:
+                weights = get_speaker_balancer_weights(data_items) * alpha
+
+        if weights is not None:
+            sampler = WeightedRandomSampler(weights, len(weights))
+
+        return sampler
+
     def get_data_loader(
         self,
         config: Coqpit,
@@ -313,23 +333,7 @@ def get_data_loader(
                 num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)
             ), "speaker_weighted_sampler is not supported with DistributedSampler"
 
-            if sampler is None:
-                weights = None
-                if getattr(config, "use_language_weighted_sampler", False):
-                    alpha = getattr(config, "language_weighted_sampler_alpha", 1.0)
-                    print(" > Using Language weighted sampler with alpha:", alpha)
-                    weights = get_language_balancer_weights(dataset.items) * alpha
-
-                if getattr(config, "use_speaker_weighted_sampler", False):
-                    alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0)
-                    print(" > Using Speaker weighted sampler with alpha:", alpha)
-                    if weights is not None:
-                        weights += get_speaker_balancer_weights(dataset.items) * alpha
-                    else:
-                        weights = get_speaker_balancer_weights(dataset.items) * alpha
-
-                if weights is not None:
-                    sampler = WeightedRandomSampler(weights, len(weights))
+            sampler = self.get_sampler(config, dataset.items, sampler) if sampler is None else sampler
 
             loader = DataLoader(
                 dataset,
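Note: get_sampler() returns either None or a WeightedRandomSampler, and the DataLoader must be built to match, since PyTorch rejects shuffle=True together with a custom sampler — which is why the loaders in this series always pass shuffle=False. A sketch of the call site, mirroring the DataLoader keywords visible in the diffs (a sketch with assumed model/dataset/config objects, not the exact library code):

    from torch.utils.data import DataLoader

    def make_loader(model, dataset, config):
        sampler = model.get_sampler(config, dataset.items)  # may be None
        return DataLoader(
            dataset,
            batch_size=config.batch_size,
            shuffle=False,  # shuffling is done in the dataset / sampler
            sampler=sampler,  # None falls back to sequential iteration
            drop_last=False,  # False avoids issues in AMP training
            collate_fn=dataset.collate_fn,
            num_workers=config.num_loader_workers,
            pin_memory=False,
        )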
From 77f527a1d748b21c7878ac6f29620f387f3c3568 Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Mon, 21 Feb 2022 14:34:45 +0000
Subject: [PATCH 06/11] Fix rebase issues

---
 TTS/tts/utils/languages.py        |  2 +-
 TTS/tts/utils/speakers.py         |  2 +-
 tests/data_tests/test_samplers.py | 10 +++++-----
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/TTS/tts/utils/languages.py b/TTS/tts/utils/languages.py
index 3016f05b02..7decabb078 100644
--- a/TTS/tts/utils/languages.py
+++ b/TTS/tts/utils/languages.py
@@ -127,7 +127,7 @@ def _set_file_path(path):
     return None
 
 
-def get_language_weighted_sampler(items: list):
+def get_language_balancer_weights(items: list):
     language_names = np.array([item["language"] for item in items])
     unique_language_names = np.unique(language_names).tolist()
     language_ids = [unique_language_names.index(l) for l in language_names]
diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py
index e4fc568b3e..078ce3f1dd 100644
--- a/TTS/tts/utils/speakers.py
+++ b/TTS/tts/utils/speakers.py
@@ -448,7 +448,7 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
     return speaker_manager
 
 
-def get_speaker_weighted_sampler(items: list):
+def get_speaker_balancer_weights(items: list):
     speaker_names = np.array([item["speaker_name"] for item in items])
     unique_speaker_names = np.unique(speaker_names).tolist()
     speaker_ids = [unique_speaker_names.index(l) for l in speaker_names]
diff --git a/tests/data_tests/test_samplers.py b/tests/data_tests/test_samplers.py
index c9ce89a43a..12152fb812 100644
--- a/tests/data_tests/test_samplers.py
+++ b/tests/data_tests/test_samplers.py
@@ -36,9 +36,9 @@
 # generate a speaker unbalanced dataset
 for i, sample in enumerate(train_samples):
     if i < 5:
-        sample[2] = "ljspeech-0"
+        sample["speaker_name"] = "ljspeech-0"
     else:
-        sample[2] = "ljspeech-1"
+        sample["speaker_name"] = "ljspeech-1"
 
 
 def is_balanced(lang_1, lang_2):
@@ -51,7 +51,7 @@ def test_language_random_sampler(self):  # pylint: disable=no-self-use
         ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)])
         en, pt = 0, 0
         for index in ids:
-            if train_samples[index][3] == "en":
+            if train_samples[index]["language"] == "en":
                 en += 1
             else:
                 pt += 1
@@ -63,7 +63,7 @@ def test_language_weighted_random_sampler(self):  # pylint: disable=no-self-use
         ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
         en, pt = 0, 0
         for index in ids:
-            if train_samples[index][3] == "en":
+            if train_samples[index]["language"] == "en":
                 en += 1
             else:
                 pt += 1
@@ -76,7 +76,7 @@ def test_speaker_weighted_random_sampler(self):  # pylint: disable=no-self-use
         ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
         spk1, spk2 = 0, 0
         for index in ids:
-            if train_samples[index][2] == "ljspeech-0":
+            if train_samples[index]["speaker_name"] == "ljspeech-0":
                 spk1 += 1
             else:
                 spk2 += 1
From c34f65f6c9512cfe1d61f4ba7171d14e5f0306b8 Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Tue, 22 Feb 2022 21:16:24 +0000
Subject: [PATCH 07/11] Add language and speaker sampler support for DDP
 training

---
 TTS/tts/models/base_tts.py | 29 ++++++++--------
 TTS/utils/samplers.py      | 68 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 82 insertions(+), 15 deletions(-)
 create mode 100644 TTS/utils/samplers.py

diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index 1d230273f8..e7dbae8e2a 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -7,7 +7,7 @@
 from coqpit import Coqpit
 from torch import nn
 from torch.utils.data import DataLoader
-from torch.utils.data.distributed import DistributedSampler
+from TTS.utils.samplers import DistributedSampler, DistributedSamplerWithSampler
 
 from TTS.model import BaseTrainerModel
 from TTS.tts.datasets.dataset import TTSDataset
@@ -233,8 +233,10 @@ def format_batch(self, batch: Dict) -> Dict:
             "language_ids": language_ids,
         }
 
-    def get_sampler(self, config: Coqpit, data_items: List, sampler: bool = None):
+    def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1):
         weights = None
+        data_items = dataset.items
+
         if getattr(config, "use_language_weighted_sampler", False):
             alpha = getattr(config, "language_weighted_sampler_alpha", 1.0)
             print(" > Using Language weighted sampler with alpha:", alpha)
@@ -250,6 +252,14 @@ def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1):
 
         if weights is not None:
             sampler = WeightedRandomSampler(weights, len(weights))
+        else:
+            sampler = None
+
+        # sampler for DDP
+        if sampler is None:
+            sampler = DistributedSampler(dataset) if num_gpus > 1 else None
+        else:  # If a sampler is already defined use this sampler and DDP sampler together
+            sampler = DistributedSamplerWithSampler(sampler) if num_gpus > 1 else sampler
 
         return sampler
 
@@ -321,19 +331,8 @@ def get_data_loader(
             # sort input sequences from short to long
            dataset.preprocess_samples()
 
-            # sampler for DDP
-            sampler = DistributedSampler(dataset) if num_gpus > 1 else None
-
-            # Weighted samplers
-            # TODO: make this DDP amenable
-            assert not (
-                num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)
-            ), "language_weighted_sampler is not supported with DistributedSampler"
-            assert not (
-                num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)
-            ), "speaker_weighted_sampler is not supported with DistributedSampler"
-
-            sampler = self.get_sampler(config, dataset.items, sampler) if sampler is None else sampler
+            # get samplers
+            sampler = self.get_sampler(config, dataset, num_gpus)
 
             loader = DataLoader(
                 dataset,
diff --git a/TTS/utils/samplers.py b/TTS/utils/samplers.py
new file mode 100644
index 0000000000..5ecd7fb04a
--- /dev/null
+++ b/TTS/utils/samplers.py
@@ -0,0 +1,68 @@
+import torch
+from torch.utils.data.distributed import DistributedSampler
+
+class DistributedSamplerWithSampler(DistributedSampler):
+    """ Wrapper over Sampler for distributed training. It allows you to use any sampler in distributed mode.
+    It is especially useful in conjunction with torch.nn.parallel.DistributedDataParallel. In such a case, each
+    process can pass a torch.utils.data.DistributedSampler instance as a torch.utils.data.DataLoader sampler,
+    and load a subset of the original dataset that is exclusive to it.
+
+    .. note:
+        Dataset is assumed to be of constant size.
+
+    Args:
+        sampler: Sampler used for subsampling.
+        num_replicas (int, optional): Number of processes participating in distributed training. By default,
+            world_size is retrieved from the current distributed group.
+        rank (int, optional): Rank of the current process within num_replicas. By default, rank is retrieved
+            from the current distributed group.
+        shuffle (bool, optional): If True, sampler will shuffle the indices. Default: True.
+        seed (int, optional): random seed used to shuffle the sampler if shuffle=True. This number should be
+            identical across all processes in the distributed group. Default: 0.
+
+    Reference: https://github.com/pytorch/pytorch/issues/23430
+
+    """
+
+    def __init__(
+        self,
+        sampler,
+        num_replicas: int = None,
+        rank: int = None,
+        shuffle: bool = True,
+        seed: int = 0
+    ):
+        super().__init__(
+            sampler,
+            num_replicas=num_replicas,
+            rank=rank,
+            shuffle=shuffle,
+            seed=seed
+        )
+
+    def __iter__(self):
+        indices = list(self.dataset)[:self.total_size]
+
+        # Add extra samples to make it evenly divisible
+        indices += indices[:(self.total_size - len(indices))]
+        assert len(indices) == self.total_size, f"{len(indices)} != {self.total_size}"
+
+        # Subsample
+        offset = self.num_samples * self.rank
+        indices = indices[offset : offset + self.num_samples]
+        assert len(indices) == self.num_samples, f"{len(indices)} != {self.num_samples}"
+
+        return iter(indices)
+
+    def set_epoch(self, epoch):
+        super().set_epoch(epoch)
+        if hasattr(self.dataset, 'set_epoch'):
+            self.dataset.set_epoch(epoch)
+        elif hasattr(self.dataset, 'generator'):
+            self.dataset.generator = torch.Generator().manual_seed(self.seed + epoch)
+
+    def state_dict(self):
+        return self.dataset.state_dict()
+
+    def load_state_dict(self, state_dict):
+        self.dataset.load_state_dict(state_dict)

From 11aa2fc407501bbd9d0aa60f417d1bfb467845da Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Wed, 23 Feb 2022 13:56:50 +0000
Subject: [PATCH 08/11] Rename distributed sampler wrapper

---
 TTS/tts/models/base_tts.py | 4 ++--
 TTS/utils/samplers.py      | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index e7dbae8e2a..32b253b223 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -7,7 +7,7 @@
 from coqpit import Coqpit
 from torch import nn
 from torch.utils.data import DataLoader
-from TTS.utils.samplers import DistributedSampler, DistributedSamplerWithSampler
+from TTS.utils.samplers import DistributedSampler, DistributedSamplerWrapper
 
 from TTS.model import BaseTrainerModel
 from TTS.tts.datasets.dataset import TTSDataset
@@ -259,7 +259,7 @@ def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1):
         if sampler is None:
             sampler = DistributedSampler(dataset) if num_gpus > 1 else None
         else:  # If a sampler is already defined use this sampler and DDP sampler together
-            sampler = DistributedSamplerWithSampler(sampler) if num_gpus > 1 else sampler
+            sampler = DistributedSamplerWrapper(sampler) if num_gpus > 1 else sampler
 
         return sampler
diff --git a/TTS/utils/samplers.py b/TTS/utils/samplers.py
index 5ecd7fb04a..560a26d8d5 100644
--- a/TTS/utils/samplers.py
+++ b/TTS/utils/samplers.py
@@ -1,7 +1,7 @@
 import torch
 from torch.utils.data.distributed import DistributedSampler
 
-class DistributedSamplerWithSampler(DistributedSampler):
+class DistributedSamplerWrapper(DistributedSampler):
     """ Wrapper over Sampler for distributed training. It allows you to use any sampler in distributed mode.
     It is especially useful in conjunction with torch.nn.parallel.DistributedDataParallel. In such a case, each
     process can pass a torch.utils.data.DistributedSampler instance as a torch.utils.data.DataLoader sampler,
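Note: the new wrapper makes the weighted sampler usable under DDP. Each rank iterates the wrapped sampler and keeps only its own slice of the drawn indices, and set_epoch() reseeds the wrapped sampler's generator so all ranks draw the same sequence each epoch. A toy usage sketch under the PATCH 07 naming (the two-rank values are hard-coded for illustration; in real DDP training num_replicas and rank come from the process group):

    import torch
    from torch.utils.data.sampler import WeightedRandomSampler

    from TTS.utils.samplers import DistributedSamplerWithSampler

    weights = torch.ones(8)  # toy weights, one per sample
    inner = WeightedRandomSampler(weights, len(weights))

    # Each rank wraps the shared sampler but keeps a disjoint slice of a
    # draw; without identical seeding (see set_epoch) the ranks' draws differ.
    rank0 = DistributedSamplerWithSampler(inner, num_replicas=2, rank=0)
    rank1 = DistributedSamplerWithSampler(inner, num_replicas=2, rank=1)
    print(list(rank0), list(rank1))  # 4 indices per rank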
From d212cd92212a615e592bea579592c4cce0448556 Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Mon, 7 Mar 2022 16:03:13 -0300
Subject: [PATCH 09/11] Remove the DistributedSamplerWrapper and use the one
 from Trainer

---
 TTS/tts/models/base_tts.py |  2 +-
 TTS/utils/samplers.py      | 68 --------------------------------------
 2 files changed, 1 insertion(+), 69 deletions(-)
 delete mode 100644 TTS/utils/samplers.py

diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index 32b253b223..4b557c9b7b 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -7,7 +7,7 @@
 from coqpit import Coqpit
 from torch import nn
 from torch.utils.data import DataLoader
-from TTS.utils.samplers import DistributedSampler, DistributedSamplerWrapper
+from trainer.torch import DistributedSampler, DistributedSamplerWrapper
 
 from TTS.model import BaseTrainerModel
 from TTS.tts.datasets.dataset import TTSDataset
diff --git a/TTS/utils/samplers.py b/TTS/utils/samplers.py
deleted file mode 100644
index 560a26d8d5..0000000000
--- a/TTS/utils/samplers.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import torch
-from torch.utils.data.distributed import DistributedSampler
-
-class DistributedSamplerWrapper(DistributedSampler):
-    """ Wrapper over Sampler for distributed training. It allows you to use any sampler in distributed mode.
-    It is especially useful in conjunction with torch.nn.parallel.DistributedDataParallel. In such a case, each
-    process can pass a torch.utils.data.DistributedSampler instance as a torch.utils.data.DataLoader sampler,
-    and load a subset of the original dataset that is exclusive to it.
-
-    .. note:
-        Dataset is assumed to be of constant size.
-
-    Args:
-        sampler: Sampler used for subsampling.
-        num_replicas (int, optional): Number of processes participating in distributed training. By default,
-            world_size is retrieved from the current distributed group.
-        rank (int, optional): Rank of the current process within num_replicas. By default, rank is retrieved
-            from the current distributed group.
-        shuffle (bool, optional): If True, sampler will shuffle the indices. Default: True.
-        seed (int, optional): random seed used to shuffle the sampler if shuffle=True. This number should be
-            identical across all processes in the distributed group. Default: 0.
-
-    Reference: https://github.com/pytorch/pytorch/issues/23430
-
-    """
-
-    def __init__(
-        self,
-        sampler,
-        num_replicas: int = None,
-        rank: int = None,
-        shuffle: bool = True,
-        seed: int = 0
-    ):
-        super().__init__(
-            sampler,
-            num_replicas=num_replicas,
-            rank=rank,
-            shuffle=shuffle,
-            seed=seed
-        )
-
-    def __iter__(self):
-        indices = list(self.dataset)[:self.total_size]
-
-        # Add extra samples to make it evenly divisible
-        indices += indices[:(self.total_size - len(indices))]
-        assert len(indices) == self.total_size, f"{len(indices)} != {self.total_size}"
-
-        # Subsample
-        offset = self.num_samples * self.rank
-        indices = indices[offset : offset + self.num_samples]
-        assert len(indices) == self.num_samples, f"{len(indices)} != {self.num_samples}"
-
-        return iter(indices)
-
-    def set_epoch(self, epoch):
-        super().set_epoch(epoch)
-        if hasattr(self.dataset, 'set_epoch'):
-            self.dataset.set_epoch(epoch)
-        elif hasattr(self.dataset, 'generator'):
-            self.dataset.generator = torch.Generator().manual_seed(self.seed + epoch)
-
-    def state_dict(self):
-        return self.dataset.state_dict()
-
-    def load_state_dict(self, state_dict):
-        self.dataset.load_state_dict(state_dict)

From 0898f6ed92e04dda18edafa5d9fbd4bab9a77a3b Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Mon, 7 Mar 2022 16:44:05 -0300
Subject: [PATCH 10/11] Bugfix after rebase

---
 TTS/tts/models/base_tts.py |  2 +-
 TTS/tts/models/vits.py     | 27 +++++----------------------
 2 files changed, 6 insertions(+), 23 deletions(-)

diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index 4b557c9b7b..222f851970 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -235,7 +235,7 @@ def format_batch(self, batch: Dict) -> Dict:
 
     def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1):
         weights = None
-        data_items = dataset.items
+        data_items = dataset.samples
 
         if getattr(config, "use_language_weighted_sampler", False):
             alpha = getattr(config, "language_weighted_sampler_alpha", 1.0)
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index a43e081c86..6aa30dfe6c 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -13,7 +13,6 @@
 from torch.cuda.amp.autocast_mode import autocast
 from torch.nn import functional as F
 from torch.utils.data import DataLoader
-from torch.utils.data.distributed import DistributedSampler
 from trainer.trainer_utils import get_optimizer, get_scheduler
 
 from TTS.tts.configs.shared_configs import CharactersConfig
@@ -24,8 +23,8 @@
 from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
-from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
-from TTS.tts.utils.speakers import SpeakerManager, get_speaker_weighted_sampler
+from TTS.tts.utils.languages import LanguageManager
+from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
@@ -1354,31 +1353,15 @@ def get_data_loader(
             # sort input sequences from short to long
             dataset.preprocess_samples()
 
-            # sampler for DDP
-            sampler = DistributedSampler(dataset) if num_gpus > 1 else None
-
-            # Weighted samplers
-            # TODO: make this DDP amenable
-            assert not (
-                num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)
-            ), "language_weighted_sampler is not supported with DistributedSampler"
-            assert not (
-                num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)
-            ), "speaker_weighted_sampler is not supported with DistributedSampler"
-
-            if sampler is None:
-                if getattr(config, "use_language_weighted_sampler", False):
-                    print(" > Using Language weighted sampler")
-                    sampler = get_language_weighted_sampler(dataset.samples)
-                elif getattr(config, "use_speaker_weighted_sampler", False):
-                    print(" > Using Language weighted sampler")
-                    sampler = get_speaker_weighted_sampler(dataset.samples)
+            # get samplers
+            sampler = self.get_sampler(config, dataset, num_gpus)
 
             loader = DataLoader(
                 dataset,
                 batch_size=config.eval_batch_size if is_eval else config.batch_size,
                 shuffle=False,  # shuffle is done in the dataset.
                 drop_last=False,  # setting this False might cause issues in AMP training.
+                sampler=sampler,
                 collate_fn=dataset.collate_fn,
                 num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
                 pin_memory=False,
"use_language_weighted_sampler", False) - ), "language_weighted_sampler is not supported with DistributedSampler" - assert not ( - num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False) - ), "speaker_weighted_sampler is not supported with DistributedSampler" - - if sampler is None: - if getattr(config, "use_language_weighted_sampler", False): - print(" > Using Language weighted sampler") - sampler = get_language_weighted_sampler(dataset.samples) - elif getattr(config, "use_speaker_weighted_sampler", False): - print(" > Using Language weighted sampler") - sampler = get_speaker_weighted_sampler(dataset.samples) + # get samplers + sampler = self.get_sampler(config, dataset, num_gpus) loader = DataLoader( dataset, batch_size=config.eval_batch_size if is_eval else config.batch_size, shuffle=False, # shuffle is done in the dataset. drop_last=False, # setting this False might cause issues in AMP training. + sampler=sampler, collate_fn=dataset.collate_fn, num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, pin_memory=False, From 60e51b53c2ce464c0df07b365b3941ea20f56d12 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 10 Mar 2022 10:51:41 -0300 Subject: [PATCH 11/11] Move the samplers config to tts config --- TTS/config/shared_configs.py | 17 ----------------- TTS/tts/configs/shared_configs.py | 17 +++++++++++++++++ 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/TTS/config/shared_configs.py b/TTS/config/shared_configs.py index f746e26875..3ea49796fc 100644 --- a/TTS/config/shared_configs.py +++ b/TTS/config/shared_configs.py @@ -251,18 +251,6 @@ class BaseTrainingConfig(TrainerConfig): num_eval_loader_workers (int): Number of workers for evaluation time dataloader. - - use_speaker_weighted_sampler (bool): - Enable / Disable the batch balancer by speaker. Defaults to ```False```. - - speaker_weighted_sampler_alpha (float): - Number that control the influence of the speaker sampler weights. Defaults to ```1.0```. - - use_language_weighted_sampler (bool): - Enable / Disable the batch balancer by language. Defaults to ```False```. - - language_weighted_sampler_alpha (float): - Number that control the influence of the language sampler weights. Defaults to ```1.0```. """ model: str = None @@ -270,8 +258,3 @@ class BaseTrainingConfig(TrainerConfig): num_loader_workers: int = 0 num_eval_loader_workers: int = 0 use_noise_augment: bool = False - # weighted samplers - use_speaker_weighted_sampler: bool = False - speaker_weighted_sampler_alpha: float = 1.0 - use_language_weighted_sampler: bool = False - language_weighted_sampler_alpha: float = 1.0 diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py index f43c646473..a9b56ed497 100644 --- a/TTS/tts/configs/shared_configs.py +++ b/TTS/tts/configs/shared_configs.py @@ -220,6 +220,18 @@ class BaseTTSConfig(BaseTrainingConfig): eval_split_size (float): If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set. If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%). + + use_speaker_weighted_sampler (bool): + Enable / Disable the batch balancer by speaker. Defaults to ```False```. + + speaker_weighted_sampler_alpha (float): + Number that control the influence of the speaker sampler weights. Defaults to ```1.0```. + + use_language_weighted_sampler (bool): + Enable / Disable the batch balancer by language. Defaults to ```False```. 
diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py
index f43c646473..a9b56ed497 100644
--- a/TTS/tts/configs/shared_configs.py
+++ b/TTS/tts/configs/shared_configs.py
@@ -220,6 +220,18 @@ class BaseTTSConfig(BaseTrainingConfig):
         eval_split_size (float):
             If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set.
             If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%).
+
+        use_speaker_weighted_sampler (bool):
+            Enable / Disable the batch balancer by speaker. Defaults to ```False```.
+
+        speaker_weighted_sampler_alpha (float):
+            Number that controls the influence of the speaker sampler weights. Defaults to ```1.0```.
+
+        use_language_weighted_sampler (bool):
+            Enable / Disable the batch balancer by language. Defaults to ```False```.
+
+        language_weighted_sampler_alpha (float):
+            Number that controls the influence of the language sampler weights. Defaults to ```1.0```.
     """
 
     audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
@@ -262,3 +274,8 @@ class BaseTTSConfig(BaseTrainingConfig):
     # evaluation
     eval_split_max_size: int = None
     eval_split_size: float = 0.01
+    # weighted samplers
+    use_speaker_weighted_sampler: bool = False
+    speaker_weighted_sampler_alpha: float = 1.0
+    use_language_weighted_sampler: bool = False
+    language_weighted_sampler_alpha: float = 1.0
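Note: with the four fields on BaseTTSConfig, any TTS model config can enable the balancers directly. A usage sketch — the config class and alpha values are one illustrative choice, echoing the multilingual VITS test earlier in the series:

    from TTS.tts.configs.vits_config import VitsConfig

    config = VitsConfig()
    config.use_language_weighted_sampler = True
    config.language_weighted_sampler_alpha = 10.0  # language balance dominates
    config.use_speaker_weighted_sampler = True
    config.speaker_weighted_sampler_alpha = 5.0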