From 0b707061a9c2c46d8cfb71051d974fc11f4c7676 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 16 Mar 2022 11:34:32 +0100 Subject: [PATCH] =?UTF-8?q?Update=20base=20model=20wrt=20=F0=9F=91=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- TTS/model.py | 142 +++++---------------------------------------------- 1 file changed, 14 insertions(+), 128 deletions(-) diff --git a/TTS/model.py b/TTS/model.py index 39cbeabcbe..a53b916a3f 100644 --- a/TTS/model.py +++ b/TTS/model.py @@ -1,46 +1,34 @@ -from abc import ABC, abstractmethod -from typing import Dict, List, Tuple +from abc import abstractmethod +from typing import Dict import torch from coqpit import Coqpit -from torch import nn +from trainer import TrainerModel # pylint: skip-file -class BaseTrainerModel(ABC, nn.Module): - """Abstract 🐸TTS class. Every new 🐸TTS model must inherit this.""" +class BaseTrainerModel(TrainerModel): + """BaseTrainerModel model expanding TrainerModel with required functions by 🐸TTS. + + Every new 🐸TTS model must inherit it. + """ @staticmethod @abstractmethod def init_from_config(config: Coqpit): - """Init the model from given config. + """Init the model and all its attributes from the given config. Override this depending on your model. """ ... - @abstractmethod - def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict: - """Forward ... for the model mainly used in training. - - You can be flexible here and use different number of arguments and argument names since it is intended to be - used by `train_step()` without exposing it out of the model. - - Args: - input (torch.Tensor): Input tensor. - aux_input (Dict): Auxiliary model inputs like embeddings, durations or any other sorts of inputs. - - Returns: - Dict: Model outputs. Main model output must be named as "model_outputs". - """ - outputs_dict = {"model_outputs": None} - ... - return outputs_dict - @abstractmethod def inference(self, input: torch.Tensor, aux_input={}) -> Dict: - """Forward ... for inference. + """Forward pass for inference. + + It must return a dictionary with the main model output and all the auxiliary outputs. The key ```model_outputs``` + is considered to be the main output and you can add any other auxiliary outputs as you want. We don't use `*kwargs` since it is problematic with the TorchScript API. @@ -55,78 +43,9 @@ def inference(self, input: torch.Tensor, aux_input={}) -> Dict: ... return outputs_dict - def format_batch(self, batch: Dict) -> Dict: - """Format batch returned by the data loader before sending it to the model. - - If not implemented, model uses the batch as is. - Can be used for data augmentation, feature ectraction, etc. - """ - return batch - - def format_batch_on_device(self, batch: Dict) -> Dict: - """Format batch on device before sending it to the model. - - If not implemented, model uses the batch as is. - Can be used for data augmentation, feature ectraction, etc. - """ - return batch - - @abstractmethod - def train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]: - """Perform a single training step. Run the model forward ... and compute losses. - - Args: - batch (Dict): Input tensors. - criterion (nn.Module): Loss layer designed for the model. - - Returns: - Tuple[Dict, Dict]: Model ouputs and computed losses. - """ - outputs_dict = {} - loss_dict = {} # this returns from the criterion - ... - return outputs_dict, loss_dict - - def train_log(self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int) -> None: - """Create visualizations and waveform examples for training. - - For example, here you can plot spectrograms and generate sample sample waveforms from these spectrograms to - be projected onto Tensorboard. - - Args: - ap (AudioProcessor): audio processor used at training. - batch (Dict): Model inputs used at the previous training step. - outputs (Dict): Model outputs generated at the previoud training step. - - Returns: - Tuple[Dict, np.ndarray]: training plots and output waveform. - """ - ... - - @abstractmethod - def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]: - """Perform a single evaluation step. Run the model forward ... and compute losses. In most cases, you can - call `train_step()` with no changes. - - Args: - batch (Dict): Input tensors. - criterion (nn.Module): Loss layer designed for the model. - - Returns: - Tuple[Dict, Dict]: Model ouputs and computed losses. - """ - outputs_dict = {} - loss_dict = {} # this returns from the criterion - ... - return outputs_dict, loss_dict - - def eval_log(self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int) -> None: - """The same as `train_log()`""" - ... - @abstractmethod def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True) -> None: - """Load a checkpoint and get ready for training or inference. + """Load a model checkpoint gile and get ready for training or inference. Args: config (Coqpit): Model configuration. @@ -135,36 +54,3 @@ def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = Fal strcit (bool, optional): Match all checkpoint keys to model's keys. Defaults to True. """ ... - - @staticmethod - @abstractmethod - def init_from_config(config: Coqpit, samples: List[Dict] = None, verbose=False) -> "BaseTrainerModel": - """Init the model from given config. - - Override this depending on your model. - """ - ... - - @abstractmethod - def get_data_loader( - self, config: Coqpit, assets: Dict, is_eval: True, data_items: List, verbose: bool, num_gpus: int - ): - ... - - # def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]: - # """Setup an return optimizer or optimizers.""" - # ... - - # def get_lr(self) -> Union[float, List[float]]: - # """Return learning rate(s). - - # Returns: - # Union[float, List[float]]: Model's initial learning rates. - # """ - # ... - - # def get_scheduler(self, optimizer: torch.optim.Optimizer): - # ... - - # def get_criterion(self): - # ...