diff --git a/parallel_wavegan/bin/train.py b/parallel_wavegan/bin/train.py index 5f417880..529338b9 100755 --- a/parallel_wavegan/bin/train.py +++ b/parallel_wavegan/bin/train.py @@ -219,6 +219,9 @@ def _train_step(self, batch): # Discriminator # ####################### if self.steps > self.config["discriminator_train_start_steps"]: + # re-compute y_ which leads better quality + with torch.no_grad(): + y_ = self.model["generator"](*x) # calculate discriminator loss p = self.model["discriminator"](y.unsqueeze(1)) p_ = self.model["discriminator"](y_.unsqueeze(1).detach())