Skip to content

Commit

Permalink
Fix style
Browse files Browse the repository at this point in the history
  • Loading branch information
Edresson committed Apr 29, 2022
1 parent d545bea commit 44456b0
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions TTS/vocoder/models/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,14 @@ def train_step(self, batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[
if optimizer_idx not in [0, 1]:
raise ValueError(" [!] Unexpected `optimizer_idx`.")


if optimizer_idx == 0:
# DISCRIMINATOR optimization

# generator pass
y_hat = self.model_g(x)[:, :, : y.size(2)]

# cache for generator loss
# pylint: disable=W0201
self.y_hat_g = y_hat
self.y_hat_sub = None
self.y_sub_g = None
Expand Down Expand Up @@ -178,7 +178,9 @@ def train_step(self, batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[
feats_fake, feats_real = None, None

# compute losses
loss_dict = criterion[optimizer_idx](self.y_hat_g, y, scores_fake, feats_fake, feats_real, self.y_hat_sub, self.y_sub_g)
loss_dict = criterion[optimizer_idx](
self.y_hat_g, y, scores_fake, feats_fake, feats_real, self.y_hat_sub, self.y_sub_g
)
outputs = {"model_outputs": self.y_hat_g}

return outputs, loss_dict
Expand Down

0 comments on commit 44456b0

Please sign in to comment.