From a9167cf239f3dc2e1662966f5686f3a1ec5b896f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 15 Dec 2022 00:56:48 +0100 Subject: [PATCH] Fixup overflow (#2218) * Update overflow config * Pulling shuffle and drop_last from config * Print training stats for overflow --- TTS/tts/configs/overflow_config.py | 5 +++-- TTS/tts/configs/shared_configs.py | 9 +++++++++ TTS/tts/models/base_tts.py | 4 ++-- TTS/tts/models/overflow.py | 14 ++++++++++++-- 4 files changed, 26 insertions(+), 6 deletions(-) diff --git a/TTS/tts/configs/overflow_config.py b/TTS/tts/configs/overflow_config.py index ffa0a8a0f3..dc3e5548b8 100644 --- a/TTS/tts/configs/overflow_config.py +++ b/TTS/tts/configs/overflow_config.py @@ -169,8 +169,9 @@ class OverflowConfig(BaseTTSConfig): # The classname has to be camel case lr_scheduler: str = None # overrides - min_seq_len: int = 3 - max_seq_len: int = 500 + min_text_len: int = 10 + max_text_len: int = 500 + min_audio_len: int = 512 # testing test_sentences: List[str] = field( diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py index 676ea4a9eb..16b77c385b 100644 --- a/TTS/tts/configs/shared_configs.py +++ b/TTS/tts/configs/shared_configs.py @@ -230,6 +230,13 @@ class BaseTTSConfig(BaseTrainingConfig): If True, the data loader will start loading the longest batch first. It is useful for checking OOM issues. Defaults to False. + shuffle (bool): + If True, the data loader will shuffle the dataset when there is not sampler defined. Defaults to True. + + drop_last (bool): + If True, the data loader will drop the last batch if it is not complete. It helps to prevent + issues that emerge from the partial batch statistics. Defaults to True. + add_blank (bool): Add blank characters between each other two characters. It improves performance for some models at expense of slower run-time due to the longer input sequence. @@ -309,6 +316,8 @@ class BaseTTSConfig(BaseTrainingConfig): precompute_num_workers: int = 0 use_noise_augment: bool = False start_by_longest: bool = False + shuffle: bool = False + drop_last: bool = False # dataset datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()]) # optimizer diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index fffb55b0cc..58d740d218 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -345,9 +345,9 @@ def get_data_loader( loader = DataLoader( dataset, batch_size=config.eval_batch_size if is_eval else config.batch_size, - shuffle=True, # if there is no other sampler + shuffle=config.shuffle if sampler is not None else False, # if there is no other sampler collate_fn=dataset.collate_fn, - drop_last=False, # setting this False might cause issues in AMP training. + drop_last=config.drop_last, # setting this False might cause issues in AMP training. sampler=sampler, num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, pin_memory=False, diff --git a/TTS/tts/models/overflow.py b/TTS/tts/models/overflow.py index ffb68c0bee..ee5ff411e1 100644 --- a/TTS/tts/models/overflow.py +++ b/TTS/tts/models/overflow.py @@ -159,6 +159,15 @@ def forward(self, text, text_len, mels, mel_len): return outputs + @staticmethod + def _training_stats(batch): + stats = {} + stats["avg_text_length"] = batch["text_lengths"].float().mean() + stats["avg_spec_length"] = batch["mel_lengths"].float().mean() + stats["avg_text_batch_occupancy"] = (batch["text_lengths"].float() / batch["text_lengths"].float().max()).mean() + stats["avg_spec_batch_occupancy"] = (batch["mel_lengths"].float() / batch["mel_lengths"].float().max()).mean() + return stats + def train_step(self, batch: dict, criterion: nn.Module): text_input = batch["text_input"] text_lengths = batch["text_lengths"] @@ -171,9 +180,10 @@ def train_step(self, batch: dict, criterion: nn.Module): mels=mel_input, mel_len=mel_lengths, ) + loss_dict = criterion(outputs["log_probs"] / (mel_lengths.sum() + text_lengths.sum())) - loss_dict = criterion(outputs["log_probs"]) - + # for printing useful statistics on terminal + loss_dict.update(self._training_stats(batch)) return outputs, loss_dict def eval_step(self, batch: Dict, criterion: nn.Module):