From 89a061f1d1ba3a97f5176fdca16eb2d0a2f1f0b6 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Tue, 12 Mar 2024 18:06:50 +0100 Subject: [PATCH] docs(tts.models.vits): clarify use of discriminator/generator [ci skip] --- TTS/tts/models/vits.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index e91d26b9ed..b376f74204 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1233,7 +1233,7 @@ def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> T Args: batch (Dict): Input tensors. criterion (nn.Module): Loss layer designed for the model. - optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks. + optimizer_idx (int): Index of optimizer to use. 0 for the discriminator and 1 for the generator networks. Returns: Tuple[Dict, Dict]: Model ouputs and computed losses. @@ -1651,13 +1651,16 @@ def get_data_loader( def get_optimizer(self) -> List: """Initiate and return the GAN optimizers based on the config parameters. - It returnes 2 optimizers in a list. First one is for the generator and the second one is for the discriminator. + + It returns 2 optimizers in a list. First one is for the discriminator + and the second one is for the generator. + Returns: List: optimizers. """ - # select generator parameters optimizer0 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc) + # select generator parameters gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc.")) optimizer1 = get_optimizer( self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters