From 6075fa208c4f508bd9b629d13b99800724899502 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Fri, 17 Nov 2023 11:13:46 -0300 Subject: [PATCH] Ensures that only GPT model is in training mode during XTTS GPT training (#3241) * Ensures that only GPT model is in training mode during training * Fix parallel wavegan unit test --- TTS/tts/layers/xtts/trainer/gpt_trainer.py | 7 ++++--- TTS/vocoder/configs/parallel_wavegan_config.py | 1 + requirements.txt | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/TTS/tts/layers/xtts/trainer/gpt_trainer.py index 005b30bede..4789e1f43f 100644 --- a/TTS/tts/layers/xtts/trainer/gpt_trainer.py +++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py @@ -318,9 +318,10 @@ def eval_step(self, batch, criterion): batch["cond_idxs"] = None return self.train_step(batch, criterion) - def on_epoch_start(self, trainer): # pylint: disable=W0613 - # guarante that dvae will be in eval mode after .train() on evaluation end - self.dvae = self.dvae.eval() + def on_train_epoch_start(self, trainer): + trainer.model.eval() # the whole model to eval + # put gpt model in training mode + trainer.model.xtts.gpt.train() def on_init_end(self, trainer): # pylint: disable=W0613 # ignore similarities.pth on clearml save/upload diff --git a/TTS/vocoder/configs/parallel_wavegan_config.py b/TTS/vocoder/configs/parallel_wavegan_config.py index 7845dd6bf8..6059d7f04f 100644 --- a/TTS/vocoder/configs/parallel_wavegan_config.py +++ b/TTS/vocoder/configs/parallel_wavegan_config.py @@ -94,6 +94,7 @@ class ParallelWaveganConfig(BaseGANVocoderConfig): use_noise_augment: bool = False use_cache: bool = True steps_to_start_discriminator: int = 200000 + target_loss: str = "loss_1" # LOSS PARAMETERS - overrides use_stft_loss: bool = True diff --git a/requirements.txt b/requirements.txt index ce0e5d9207..1f7a44f6d8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,7 +27,7 @@ pandas>=1.4,<2.0 # deps for training matplotlib>=3.7.0 # coqui stack -trainer +trainer>=0.0.32 # config management coqpit>=0.0.16 # chinese g2p deps