From 726ccca1ddbdd4aa056be07fb8a077383f7e0c12 Mon Sep 17 00:00:00 2001 From: Edresson Date: Fri, 11 Feb 2022 09:31:34 -0300 Subject: [PATCH] Change the test_samplers to unittest format --- tests/data_tests/test_samplers.py | 84 ++++++++++++++++--------------- 1 file changed, 44 insertions(+), 40 deletions(-) diff --git a/tests/data_tests/test_samplers.py b/tests/data_tests/test_samplers.py index 1d1c0418d5..c9ce89a43a 100644 --- a/tests/data_tests/test_samplers.py +++ b/tests/data_tests/test_samplers.py @@ -1,5 +1,7 @@ import functools +import unittest + import torch from TTS.config.shared_configs import BaseDatasetConfig @@ -26,41 +28,11 @@ language="pt-br", ) -# Adding the EN samples twice to create an unbalanced dataset +# Adding the EN samples twice to create a language unbalanced dataset train_samples, eval_samples = load_tts_samples( [dataset_config_en, dataset_config_en, dataset_config_pt], eval_split=True ) - -def is_balanced(lang_1, lang_2): - return 0.85 < lang_1 / lang_2 < 1.2 - - -random_sampler = torch.utils.data.RandomSampler(train_samples) -ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)]) -en, pt = 0, 0 -for index in ids: - if train_samples[index][3] == "en": - en += 1 - else: - pt += 1 - -assert not is_balanced(en, pt), "Random sampler is supposed to be unbalanced" - - -weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(get_language_balancer_weights(train_samples), len(train_samples)) -ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)]) -en, pt = 0, 0 -for index in ids: - if train_samples[index][3] == "en": - en += 1 - else: - pt += 1 - -assert is_balanced(en, pt), "Language Weighted sampler is supposed to be balanced" - -# test speaker weighted sampler - # gerenate a speaker unbalanced dataset for i, sample in enumerate(train_samples): if i < 5: @@ -68,13 +40,45 @@ def is_balanced(lang_1, lang_2): else: sample[2] = "ljspeech-1" -weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(get_speaker_balancer_weights(train_samples), len(train_samples)) -ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)]) -spk1, spk2 = 0, 0 -for index in ids: - if train_samples[index][2] == "ljspeech-0": - spk1 += 1 - else: - spk2 += 1 -assert is_balanced(spk1, spk2), "Speaker Weighted sampler is supposed to be balanced" +def is_balanced(lang_1, lang_2): + return 0.85 < lang_1 / lang_2 < 1.2 + + +class TestSamplers(unittest.TestCase): + def test_language_random_sampler(self): # pylint: disable=no-self-use + random_sampler = torch.utils.data.RandomSampler(train_samples) + ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)]) + en, pt = 0, 0 + for index in ids: + if train_samples[index][3] == "en": + en += 1 + else: + pt += 1 + + assert not is_balanced(en, pt), "Random sampler is supposed to be unbalanced" + + def test_language_weighted_random_sampler(self): # pylint: disable=no-self-use + weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(get_language_balancer_weights(train_samples), len(train_samples)) + ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)]) + en, pt = 0, 0 + for index in ids: + if train_samples[index][3] == "en": + en += 1 + else: + pt += 1 + + assert is_balanced(en, pt), "Language Weighted sampler is supposed to be balanced" + + def test_speaker_weighted_random_sampler(self): # pylint: disable=no-self-use + + weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(get_speaker_balancer_weights(train_samples), len(train_samples)) + ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)]) + spk1, spk2 = 0, 0 + for index in ids: + if train_samples[index][2] == "ljspeech-0": + spk1 += 1 + else: + spk2 += 1 + + assert is_balanced(spk1, spk2), "Speaker Weighted sampler is supposed to be balanced"