Skip to content

Commit

Permalink
Fix GAN optimizer order
Browse files Browse the repository at this point in the history
commit 212d330
Author: Edresson Casanova <edresson1@gmail.com>
Date:   Fri Apr 29 16:29:44 2022 -0300

    Fix unit test

commit 44456b0
Author: Edresson Casanova <edresson1@gmail.com>
Date:   Fri Apr 29 07:28:39 2022 -0300

    Fix style

commit d545bea
Author: Edresson Casanova <edresson1@gmail.com>
Date:   Thu Apr 28 17:08:04 2022 -0300

    Change order of HIFI-GAN optimizers to be equal than the original repository

commit 657c544
Author: Edresson Casanova <edresson1@gmail.com>
Date:   Thu Apr 28 15:40:16 2022 -0300

    Remove audio padding before mel spec extraction

commit 76b274e
Merge: 379ccd7 6233f4f
Author: Edresson Casanova <edresson1@gmail.com>
Date:   Wed Apr 27 07:28:48 2022 -0300

    Merge pull request #1541 from coqui-ai/comp_emb_fix

    Bug fix in compute embedding without eval partition

commit 379ccd7
Author: WeberJulian <julian.weber@hotmail.fr>
Date:   Wed Apr 27 10:42:26 2022 +0200

    returns y_mask in VITS inference (#1540)

    * returns y_mask

    * make style
  • Loading branch information
erogol committed May 7, 2022
1 parent 6003467 commit a0a9279
Showing 1 changed file with 46 additions and 39 deletions.
85 changes: 46 additions & 39 deletions TTS/vocoder/models/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,50 +90,26 @@ def train_step(self, batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[
raise ValueError(" [!] Unexpected `optimizer_idx`.")

if optimizer_idx == 0:
# GENERATOR
# DISCRIMINATOR optimization

# generator pass
y_hat = self.model_g(x)[:, :, : y.size(2)]
self.y_hat_g = y_hat # save for discriminator
y_hat_sub = None
y_sub = None

# cache for generator loss
# pylint: disable=W0201
self.y_hat_g = y_hat
self.y_hat_sub = None
self.y_sub_g = None

# PQMF formatting
if y_hat.shape[1] > 1:
y_hat_sub = y_hat
self.y_hat_sub = y_hat
y_hat = self.model_g.pqmf_synthesis(y_hat)
self.y_hat_g = y_hat # save for discriminator
y_sub = self.model_g.pqmf_analysis(y)
self.y_hat_g = y_hat # save for generator loss
self.y_sub_g = self.model_g.pqmf_analysis(y)

scores_fake, feats_fake, feats_real = None, None, None
if self.train_disc:

if len(signature(self.model_d.forward).parameters) == 2:
D_out_fake = self.model_d(y_hat, x)
else:
D_out_fake = self.model_d(y_hat)
D_out_real = None

if self.config.use_feat_match_loss:
with torch.no_grad():
D_out_real = self.model_d(y)

# format D outputs
if isinstance(D_out_fake, tuple):
scores_fake, feats_fake = D_out_fake
if D_out_real is None:
feats_real = None
else:
_, feats_real = D_out_real
else:
scores_fake = D_out_fake
feats_fake, feats_real = None, None

# compute losses
loss_dict = criterion[optimizer_idx](y_hat, y, scores_fake, feats_fake, feats_real, y_hat_sub, y_sub)
outputs = {"model_outputs": y_hat}

if optimizer_idx == 1:
# DISCRIMINATOR
if self.train_disc:
# use different samples for G and D trainings
if self.config.diff_samples_for_G_and_D:
Expand Down Expand Up @@ -177,6 +153,36 @@ def train_step(self, batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[
loss_dict = criterion[optimizer_idx](scores_fake, scores_real)
outputs = {"model_outputs": y_hat}

if optimizer_idx == 1:
# GENERATOR loss
scores_fake, feats_fake, feats_real = None, None, None
if self.train_disc:
if len(signature(self.model_d.forward).parameters) == 2:
D_out_fake = self.model_d(self.y_hat_g, x)
else:
D_out_fake = self.model_d(self.y_hat_g)
D_out_real = None

if self.config.use_feat_match_loss:
with torch.no_grad():
D_out_real = self.model_d(y)

# format D outputs
if isinstance(D_out_fake, tuple):
scores_fake, feats_fake = D_out_fake
if D_out_real is None:
feats_real = None
else:
_, feats_real = D_out_real
else:
scores_fake = D_out_fake
feats_fake, feats_real = None, None

# compute losses
loss_dict = criterion[optimizer_idx](
self.y_hat_g, y, scores_fake, feats_fake, feats_real, self.y_hat_sub, self.y_sub_g
)
outputs = {"model_outputs": self.y_hat_g}
return outputs, loss_dict

@staticmethod
Expand Down Expand Up @@ -210,6 +216,7 @@ def train_log(
@torch.no_grad()
def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
"""Call `train_step()` with `no_grad()`"""
self.train_disc = True # Avoid a bug in the Training with the missing discriminator loss
return self.train_step(batch, criterion, optimizer_idx)

def eval_log(
Expand Down Expand Up @@ -266,15 +273,15 @@ def get_optimizer(self) -> List:
optimizer2 = get_optimizer(
self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.model_d
)
return [optimizer1, optimizer2]
return [optimizer2, optimizer1]

def get_lr(self) -> List:
"""Set the initial learning rates for each optimizer.
Returns:
List: learning rates for each optimizer.
"""
return [self.config.lr_gen, self.config.lr_disc]
return [self.config.lr_disc, self.config.lr_gen]

def get_scheduler(self, optimizer) -> List:
"""Set the schedulers for each optimizer.
Expand All @@ -287,7 +294,7 @@ def get_scheduler(self, optimizer) -> List:
"""
scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
return [scheduler1, scheduler2]
return [scheduler2, scheduler1]

@staticmethod
def format_batch(batch: List) -> Dict:
Expand Down Expand Up @@ -359,7 +366,7 @@ def get_data_loader( # pylint: disable=no-self-use, unused-argument

def get_criterion(self):
"""Return criterions for the optimizers"""
return [GeneratorLoss(self.config), DiscriminatorLoss(self.config)]
return [DiscriminatorLoss(self.config), GeneratorLoss(self.config)]

@staticmethod
def init_from_config(config: Coqpit, verbose=True) -> "GAN":
Expand Down

0 comments on commit a0a9279

Please sign in to comment.