From 212d330929c22d0cd970be2023770dc1e39449ab Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Fri, 29 Apr 2022 16:29:44 -0300
Subject: [PATCH] Fix unit test

---
 TTS/vocoder/models/gan.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py
index 367efdc24b..ed5b26dd93 100644
--- a/TTS/vocoder/models/gan.py
+++ b/TTS/vocoder/models/gan.py
@@ -155,6 +155,7 @@ def train_step(self, batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[
 
         if optimizer_idx == 1:
             # GENERATOR loss
+            scores_fake, feats_fake, feats_real = None, None, None
             if self.train_disc:
                 if len(signature(self.model_d.forward).parameters) == 2:
                     D_out_fake = self.model_d(self.y_hat_g, x)
@@ -182,7 +183,6 @@ def train_step(self, batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[
                 self.y_hat_g, y, scores_fake, feats_fake, feats_real, self.y_hat_sub, self.y_sub_g
             )
             outputs = {"model_outputs": self.y_hat_g}
-
             return outputs, loss_dict
 
     @staticmethod
@@ -216,6 +216,7 @@ train_log(
     @torch.no_grad()
     def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
         """Call `train_step()` with `no_grad()`"""
+        self.train_disc = True  # Avoid a bug in the Training with the missing discriminator loss
         return self.train_step(batch, criterion, optimizer_idx)
 
     def eval_log(