Skip to content

Commit

Permalink
Fix unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
Edresson committed Apr 29, 2022
1 parent 44456b0 commit 212d330
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion TTS/vocoder/models/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def train_step(self, batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[

if optimizer_idx == 1:
# GENERATOR loss
scores_fake, feats_fake, feats_real = None, None, None
if self.train_disc:
if len(signature(self.model_d.forward).parameters) == 2:
D_out_fake = self.model_d(self.y_hat_g, x)
Expand Down Expand Up @@ -182,7 +183,6 @@ def train_step(self, batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[
self.y_hat_g, y, scores_fake, feats_fake, feats_real, self.y_hat_sub, self.y_sub_g
)
outputs = {"model_outputs": self.y_hat_g}

return outputs, loss_dict

@staticmethod
Expand Down Expand Up @@ -216,6 +216,7 @@ def train_log(
@torch.no_grad()
def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
"""Call `train_step()` with `no_grad()`"""
self.train_disc = True # Avoid a bug in the Training with the missing discriminator loss
return self.train_step(batch, criterion, optimizer_idx)

def eval_log(
Expand Down

0 comments on commit 212d330

Please sign in to comment.