Skip to content

Commit

Permalink
Add Perfect Batch Sampler unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
Edresson committed Mar 10, 2022
1 parent 9c8b820 commit b0bad56
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 2 deletions.
4 changes: 2 additions & 2 deletions TTS/encoder/utils/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ class PerfectBatchSampler(Sampler):
drop_last (bool): if True, drops last incomplete batch.
"""

def __init__(self, dataset_items, classes, batch_size, num_classes_in_batch, num_gpus=1, shuffle=True, drop_last=False):
def __init__(self, dataset_items, classes, batch_size, num_classes_in_batch, num_gpus=1, shuffle=True, drop_last=False, label_key="class_name"):
super().__init__(dataset_items)
assert batch_size % (num_classes_in_batch * num_gpus) == 0, (
'Batch size must be divisible by number of classes times the number of data parallel devices (if enabled).')

label_indices = {}
for idx, item in enumerate(dataset_items):
label = item['class_name']
label = item[label_key]
if label not in label_indices.keys():
label_indices[label] = [idx]
else:
Expand Down
49 changes: 49 additions & 0 deletions tests/data_tests/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.languages import get_language_balancer_weights
from TTS.tts.utils.speakers import get_speaker_balancer_weights
from TTS.encoder.utils.samplers import PerfectBatchSampler

# Fixing random state to avoid random fails
torch.manual_seed(0)
Expand Down Expand Up @@ -82,3 +83,51 @@ def test_speaker_weighted_random_sampler(self): # pylint: disable=no-self-use
spk2 += 1

assert is_balanced(spk1, spk2), "Speaker Weighted sampler is supposed to be balanced"

def test_perfect_sampler(self): # pylint: disable=no-self-use
classes = set()
for item in train_samples:
classes.add(item["speaker_name"])

sampler = PerfectBatchSampler(
train_samples,
classes,
batch_size=2 * 3, # total batch size
num_classes_in_batch=2,
label_key="speaker_name",
shuffle=False,
drop_last=True)
batchs = functools.reduce(lambda a, b: a + b, [list(sampler) for i in range(100)])
for batch in batchs:
spk1, spk2 = 0, 0
# for in each batch
for index in batch:
if train_samples[index]["speaker_name"] == "ljspeech-0":
spk1 += 1
else:
spk2 += 1
assert spk1 == spk2, "PerfectBatchSampler is supposed to be perfectly balanced"

def test_perfect_sampler_shuffle(self): # pylint: disable=no-self-use
classes = set()
for item in train_samples:
classes.add(item["speaker_name"])

sampler = PerfectBatchSampler(
train_samples,
classes,
batch_size=2 * 3, # total batch size
num_classes_in_batch=2,
label_key="speaker_name",
shuffle=True,
drop_last=False)
batchs = functools.reduce(lambda a, b: a + b, [list(sampler) for i in range(100)])
for batch in batchs:
spk1, spk2 = 0, 0
# for in each batch
for index in batch:
if train_samples[index]["speaker_name"] == "ljspeech-0":
spk1 += 1
else:
spk2 += 1
assert spk1 == spk2, "PerfectBatchSampler is supposed to be perfectly balanced"

0 comments on commit b0bad56

Please sign in to comment.