Training HiFiGan -- avg loss not decreasing #1003

Closed
skol101 opened this issue Dec 8, 2021 · 24 comments

@skol101

skol101 commented Dec 8, 2021

Describe the bug
After running for 240k steps, there is no improvement in avg loss when training HiFiGAN.

To Reproduce
Steps to reproduce the behavior:

  1. Run the following command: CUDA_VISIBLE_DEVICES=0,1 python ../../TTS/TTS/bin/distribute.py --script train_hifigan.py
  2. Training:

TRAINING (2021-11-29 10:06:07)

--> STEP: 24/352 -- GLOBAL_STEP: 244300
| > G_l1_spec_loss: 0.36788 (0.35471)
| > G_gen_loss: 16.55468 (15.96176)
| > G_adv_loss: 0.00000 (0.00000)
| > loss_0: 16.55468 (15.96176)
| > grad_norm_0: 22.85464 (28.87626)
| > current_lr_0: 7.0524350586068e-111
| > current_lr_1: 0.00010
| > step_time: 0.29070 (0.29190)
| > loader_time: 0.00150 (0.00135)
  3. Evaluation:
--> EVAL PERFORMANCE
| > avg_loader_time: 0.00034 (-0.00003)
| > avg_G_l1_spec_loss: 0.35621 (+0.00000)
| > avg_G_gen_loss: 16.02957 (+0.00000)
| > avg_G_adv_loss: 0.00000 (+0.00000)
| > avg_loss_0: 16.02957 (+0.00000)
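A quick arithmetic check on the logs above, assuming G_gen_loss is a weighted sum of the enabled generator losses and taking l1_spec_loss_weight = 45 from the config.json posted below. If 45 × G_l1_spec_loss alone reproduces G_gen_loss, the adversarial and feature-matching terms are contributing essentially nothing, which is consistent with the G_adv_loss: 0.00000 line. The helper name `residuals` is hypothetical and only for illustration.

```python
# (G_l1_spec_loss, G_gen_loss) pairs copied from the training and eval logs
LOGGED = {
    "train, current step": (0.36788, 16.55468),
    "train, running average": (0.35471, 15.96176),
    "eval, average": (0.35621, 16.02957),
}
L1_SPEC_LOSS_WEIGHT = 45  # "l1_spec_loss_weight" in config.json


def residuals(entries: dict, weight: float = L1_SPEC_LOSS_WEIGHT) -> dict:
    """Part of each logged total NOT explained by the weighted L1 mel-spectrogram loss."""
    return {name: total - weight * l1 for name, (l1, total) in entries.items()}


for name, r in residuals(LOGGED).items():
    print(f"{name}: G_gen_loss - {L1_SPEC_LOSS_WEIGHT} * G_l1_spec_loss = {r:+.5f}")
```

With these numbers every residual stays below 0.001 in magnitude, i.e. at the level of the five-decimal rounding in the log, so the logged total is effectively just the weighted mel-spectrogram L1 loss.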

Expected behavior
Improvement in loss during training.

Environment (please complete the following information):
OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 20.04
PyTorch or TensorFlow version (use command below): pytorch 1.10.0
Python version: 3.8.11
CUDA/cuDNN version: py3.8_cuda11.3_cudnn8.2.0_0
GPU model and memory: 2xRTX 3090

Additional context
Screenshot from 2021-11-29 10-13-26

import os

from TTS.trainer import Trainer, TrainingArgs
from TTS.utils.audio import AudioProcessor
from TTS.vocoder.configs import HifiganConfig
from TTS.vocoder.datasets.preprocess import load_wav_data
from TTS.vocoder.models.gan import GAN

output_path = os.path.dirname(os.path.abspath(__file__))

config = HifiganConfig(
    batch_size=64,
    eval_batch_size=16,
    num_loader_workers=4,
    num_eval_loader_workers=4,
    run_eval=True,
    test_delay_epochs=5,
    epochs=1000,
    seq_len=8192,
    pad_short=2000,
    use_noise_augment=True,
    eval_split_size=10,
    print_step=25,
    print_eval=False,
    mixed_precision=False,
    lr_gen=1e-4,
    lr_disc=1e-4,
    data_path=os.path.join(output_path, "../datasets/vctk_all_22"),
    output_path=output_path,
)

# init audio processor
ap = AudioProcessor(**config.audio.to_dict())

# load training samples
eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size)


# init model
model = GAN(config)

# init the trainer and 🚀
trainer = Trainer(
    TrainingArgs(),
    config,
    output_path,
    model=model,
    train_samples=train_samples,
    eval_samples=eval_samples,
    training_assets={"audio_processor": ap},
)
trainer.fit()

The script also generates config.json in the directory where train_hifigan.py resides, as well as in the generated run directory.

{
    "model": "hifigan",
    "run_name": "coqui_tts",
    "run_description": "",
    "epochs": 1000,
    "batch_size": 64,
    "eval_batch_size": 16,
    "mixed_precision": false,
    "scheduler_after_epoch": false,
    "run_eval": true,
    "test_delay_epochs": 5,
    "print_eval": false,
    "dashboard_logger": "tensorboard",
    "print_step": 25,
    "plot_step": 100,
    "model_param_stats": false,
    "project_name": null,
    "log_model_step": null,
    "wandb_entity": null,
    "save_step": 10000,
    "checkpoint": true,
    "keep_all_best": false,
    "keep_after": 10000,
    "num_loader_workers": 4,
    "num_eval_loader_workers": 4,
    "use_noise_augment": true,
    "output_path": "/home/sk/work/hifigan",
    "distributed_backend": "nccl",
    "distributed_url": "tcp://localhost:54321",
    "audio": {
        "fft_size": 1024,
        "win_length": 1024,
        "hop_length": 256,
        "frame_shift_ms": null,
        "frame_length_ms": null,
        "stft_pad_mode": "reflect",
        "sample_rate": 22050,
        "resample": false,
        "preemphasis": 0.0,
        "ref_level_db": 20,
        "do_sound_norm": false,
        "log_func": "np.log10",
        "do_trim_silence": true,
        "trim_db": 45,
        "power": 1.5,
        "griffin_lim_iters": 60,
        "num_mels": 80,
        "mel_fmin": 0.0,
        "mel_fmax": null,
        "spec_gain": 20,
        "do_amp_to_db_linear": true,
        "do_amp_to_db_mel": true,
        "signal_norm": true,
        "min_level_db": -100,
        "symmetric_norm": true,
        "max_norm": 4.0,
        "clip_norm": true,
        "stats_path": null
    },
    "eval_split_size": 10,
    "data_path": "/home/sk/work/hifigan/../datasets/vctk_all_wavs",
    "feature_path": null,
    "seq_len": 8192,
    "pad_short": 2000,
    "conv_pad": 0,
    "use_cache": false,
    "wd": 1e-06,
    "optimizer": "AdamW",
    "optimizer_params": {
        "betas": [
            0.8,
            0.99
        ],
        "weight_decay": 0.0
    },
    "use_stft_loss": false,
    "use_subband_stft_loss": false,
    "use_mse_gan_loss": true,
    "use_hinge_gan_loss": false,
    "use_feat_match_loss": true,
    "use_l1_spec_loss": true,
    "stft_loss_weight": 0,
    "subband_stft_loss_weight": 0,
    "mse_G_loss_weight": 1,
    "hinge_G_loss_weight": 0,
    "feat_match_loss_weight": 108,
    "l1_spec_loss_weight": 45,
    "stft_loss_params": {
        "n_ffts": [
            1024,
            2048,
            512
        ],
        "hop_lengths": [
            120,
            240,
            50
        ],
        "win_lengths": [
            600,
            1200,
            240
        ]
    },
    "l1_spec_loss_params": {
        "use_mel": true,
        "sample_rate": 22050,
        "n_fft": 1024,
        "hop_length": 256,
        "win_length": 1024,
        "n_mels": 80,
        "mel_fmin": 0.0,
        "mel_fmax": null
    },
    "target_loss": "loss_0",
    "grad_clip": [
        5,
        5
    ],
    "lr_gen": 0.0001,
    "lr_disc": 0.0001,
    "lr_scheduler_gen": "ExponentialLR",
    "lr_scheduler_gen_params": {
        "gamma": 0.999,
        "last_epoch": -1
    },
    "lr_scheduler_disc": "ExponentialLR",
    "lr_scheduler_disc_params": {
        "gamma": 0.999,
        "last_epoch": -1
    },
    "use_pqmf": false,
    "diff_samples_for_G_and_D": false,
    "discriminator_model": "hifigan_discriminator",
    "generator_model": "hifigan_generator",
    "generator_model_params": {
        "upsample_factors": [
            8,
            8,
            2,
            2
        ],
        "upsample_kernel_sizes": [
            16,
            16,
            4,
            4
        ],
        "upsample_initial_channel": 512,
        "resblock_kernel_sizes": [
            3,
            7,
            11
        ],
        "resblock_dilation_sizes": [
            [
                1,
                3,
                5
            ],
            [
                1,
                3,
                5
            ],
            [
                1,
                3,
                5
            ]
        ],
        "resblock_type": "1"
    },
    "lr": 0.0001
}

_Originally posted by @skol101 in https://github.com/coqui-ai/TTS/discussions/975_
@skol101
Author

skol101 commented Dec 18, 2021

Forgot to add: there's no alignment section in TensorBoard.

@erogol
Member

erogol commented Dec 20, 2021

You don't need alignments with HifiGAN. There is no attention or duration prediction.
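Because a vocoder only converts mel frames into waveform samples, its output length is fixed by the generator's upsampling stack rather than learned, so there is nothing to align. A small sketch checking this against the posted config, assuming the generator's total upsampling factor is the product of upsample_factors (the usual HiFiGAN design, where each transposed convolution upsamples by one factor). The function name `output_samples` is hypothetical.

```python
from math import prod

# Values copied from the posted config.json
UPSAMPLE_FACTORS = (8, 8, 2, 2)  # generator_model_params["upsample_factors"]
HOP_LENGTH = 256                 # audio["hop_length"]
SEQ_LEN = 8192                   # "seq_len": training segment length in samples


def output_samples(num_mel_frames: int, factors=UPSAMPLE_FACTORS) -> int:
    """Every mel frame becomes a fixed number of samples; no alignment is learned."""
    return num_mel_frames * prod(factors)


# The total upsampling has to equal the hop length used to compute the mels
assert prod(UPSAMPLE_FACTORS) == HOP_LENGTH

frames_per_segment = SEQ_LEN // HOP_LENGTH
print(f"{frames_per_segment} mel frames -> {output_samples(frames_per_segment)} samples")
```

With the posted values the product is 256, equal to hop_length, so each 8192-sample training segment corresponds to 32 mel frames.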

@alievilya

> You don't need alignments with HifiGAN. There is no attention or duration prediction.

Hello! I was running HifiGAN on LJSpeech with the default config and had the same issue.
Could you please clarify how to fix it?

@loganhart02
Contributor

loganhart02 commented Dec 21, 2021 via email

@andreibezborodov

andreibezborodov commented Dec 23, 2021

> > You don't need alignments with HifiGAN. There is no attention or duration prediction.
>
> Hello! I was running HifiGAN on LJSpeech with the default config and had the same issue. Could you please clarify how to fix it?

@erogol I'm also having the same problem: the model is not fitting. I was running the config from 'recipes/ljspeech/hifigan/train_hifigan.py' on the LJSpeech dataset. What could cause this?

image

@erogol
Member

erogol commented Dec 23, 2021

I'm aware of the problem and will check when I have time for it.

@erogol
Copy link
Member

erogol commented Dec 23, 2021

Can someone upload the TB to https://tensorboard.dev/ and share?

@erogol
Copy link
Member

erogol commented Dec 23, 2021

> > You don't need alignments with HifiGAN. There is no attention or duration prediction.
>
> Hello! I was running HifiGAN on LJSpeech with the default config and had the same issue. Could you please clarify how to fix it?

What is the same issue?

@andreibezborodov

andreibezborodov commented Dec 23, 2021

@erogol
Copy link
Member

erogol commented Dec 23, 2021

Here is how it looks if I run TTS/recipes/ljspeech/hifigan

And the model optimizes without any problem.

I can't reproduce the problem.

Screenshot 2021-12-23 at 15 59 04

@erogol
Copy link
Member

erogol commented Dec 23, 2021

> > Can someone upload the TB to https://tensorboard.dev/ and share?
>
> https://tensorboard.dev/experiment/9pdI9rvbSYO4eMCWLH1Tqg/#scalars

Your config is the same as my config. The only difference is the file paths. I suspect it is about the system setup.

Can you do this?

wget https://raw.githubusercontent.com/coqui-ai/TTS/main/TTS/bin/collect_env_details.py
python collect_env_details.py

and post the output here?

@erogol
Member

erogol commented Dec 23, 2021

@skol101 did you try with a single GPU?

@andreibezborodov
Copy link

> > > Can someone upload the TB to https://tensorboard.dev/ and share?
> >
> > https://tensorboard.dev/experiment/9pdI9rvbSYO4eMCWLH1Tqg/#scalars
>
> Your config is the same as my config. The only difference is the file paths. I suspect it is about the system setup.
>
> Can you do this?
>
> wget https://raw.githubusercontent.com/coqui-ai/TTS/main/TTS/bin/collect_env_details.py
> python collect_env_details.py
>
> and post the output here?

I get a 404: Not Found error after running the wget command.

@erogol
Member

erogol commented Dec 23, 2021

Try this:

wget https://raw.githubusercontent.com/coqui-ai/TTS/main/TTS/bin/collect_env_info.py 
python collect_env_info.py
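If the script moves again (the first URL above returned a 404), a rough stand-in can collect similar details. This is a hypothetical sketch, not the contents of TTS/bin/collect_env_info.py: the dictionary keys only mimic the JSON output posted further down, PyTorch is reported only if it can be imported, and `importlib.metadata` requires Python 3.8 or newer.

```python
import json
import platform
from importlib.metadata import PackageNotFoundError, version


def collect_env_info() -> dict:
    """Hypothetical stand-in for TTS/bin/collect_env_info.py; keys mimic its output."""
    info = {
        "System": {
            "OS": platform.system(),
            "architecture": list(platform.architecture()),
            "processor": platform.processor(),
            "python": platform.python_version(),
            "version": platform.version(),
        },
        "Packages": {},
        "CUDA": {},
    }
    # Installed package versions; None if the package is not installed
    for pkg in ("TTS", "numpy"):
        try:
            info["Packages"][pkg] = version(pkg)
        except PackageNotFoundError:
            info["Packages"][pkg] = None
    try:
        import torch  # optional dependency, imported lazily
    except ImportError:
        info["Packages"]["PyTorch_version"] = None
        return info
    info["Packages"]["PyTorch_version"] = str(torch.__version__)
    info["Packages"]["PyTorch_debug"] = torch.version.debug
    info["CUDA"]["available"] = torch.cuda.is_available()
    info["CUDA"]["version"] = torch.version.cuda
    info["CUDA"]["GPU"] = [
        torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())
    ]
    return info


if __name__ == "__main__":
    print(json.dumps(collect_env_info(), indent=4, sort_keys=True))
```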

@loganhart02
Copy link
Contributor

> Here is how it looks if I run TTS/recipes/ljspeech/hifigan
>
> And the model optimizes without any problem.
>
> I can't reproduce the problem.
>
> Screenshot 2021-12-23 at 15 59 04

Is it normal for the model to stop improving after 10k steps? I've experimented with the hyperparameters, and every time the eval audio and spectrograms stay the same, as expected when the loss doesn't change. I'm just curious whether this is normal behavior.

@andreibezborodov
Copy link

> Try this:
>
> wget https://raw.githubusercontent.com/coqui-ai/TTS/main/TTS/bin/collect_env_info.py
> python collect_env_info.py

{
    "CUDA": {
        "GPU": [
            "Tesla V100-PCIE-32GB",
            "Tesla V100-PCIE-32GB"
        ],
        "available": true,
        "version": "10.2"
    },
    "Packages": {
        "PyTorch_debug": false,
        "PyTorch_version": "1.8.1+cu102",
        "TTS": "0.4.2",
        "numpy": "1.19.5"
    },
    "System": {
        "OS": "Windows",
        "architecture": [
            "64bit",
            "WindowsPE"
        ],
        "processor": "Intel64 Family 6 Model 85 Stepping 7, GenuineIntel",
        "python": "3.6.6",
        "version": "6.3.9600"
    }
}

@andreibezborodov

andreibezborodov commented Dec 24, 2021

> Your config is the same as my config. The only difference is the file paths. I suspect it is about the system setup.

Here is the resulting wav that I get after using the whole model for generating speech: Tacotron-2 DDC + HiFiGAN. Tacotron-2 DDC was trained for 120k steps and produces good results with Griffin-Lim, so the problem seems to be with the vocoder training process.

result_1.mp4

@michaellin99999

Hello, we also recently trained HiFiGAN on our own dataset for a similar number of steps, paired it with Glow-TTS, and got the same-sounding result.

@erogol
Copy link
Member

erogol commented Dec 28, 2021

> > Your config is the same as my config. The only difference is the file paths. I suspect it is about the system setup.
>
> Here is the resulting wav that I get after using the whole model for generating speech: Tacotron-2 DDC + HiFiGAN. Tacotron-2 DDC was trained for 120k steps and produces good results with Griffin-Lim, so the problem seems to be with the vocoder training process.
>
> result_1.mp4

There are multiple HiFiGAN models.

  • What command did you run?
  • Which HiFiGAN did you use? Not all models are compatible with each other.
  • How did you run it with Griffin-Lim?

@erogol
Copy link
Member

erogol commented Dec 29, 2021

I realized that the learning rate decays too quickly. Maybe this is why the models stop learning. I'll update the recipe in the dev branch.

In the meantime, you can set scheduler_after_epoch = True in your own config to try it out yourself.
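The logged current_lr_0: 7.0524350586068e-111 at GLOBAL_STEP 244300 in the original report fits this explanation. A back-of-the-envelope sketch, assuming PyTorch's ExponentialLR multiplies the learning rate by gamma on every scheduler update, and assuming scheduler_after_epoch = false means one update per training step while true means one per epoch. The inputs come from the posted config (lr_gen = 0.0001, gamma = 0.999) and log (352 steps per epoch); the helper name `exponential_lr` is hypothetical.

```python
# Values from the posted config.json and training log
LR_INIT = 1e-4          # "lr_gen"
GAMMA = 0.999           # "lr_scheduler_gen_params": {"gamma": 0.999}
STEPS_PER_EPOCH = 352   # from "STEP: 24/352"
GLOBAL_STEP = 244_300   # from "GLOBAL_STEP: 244300"
LOGGED_LR = 7.0524350586068e-111  # "current_lr_0" in the log


def exponential_lr(lr0: float, gamma: float, n_updates: int) -> float:
    """Learning rate after n_updates calls to an ExponentialLR-style scheduler."""
    return lr0 * gamma ** n_updates


# Assumed behaviour with scheduler_after_epoch = false: one update per training step
lr_per_step = exponential_lr(LR_INIT, GAMMA, GLOBAL_STEP)

# Assumed behaviour with scheduler_after_epoch = true: one update per completed epoch
lr_per_epoch = exponential_lr(LR_INIT, GAMMA, GLOBAL_STEP // STEPS_PER_EPOCH)

print(f"per-step updates:  {lr_per_step:.4e}")
print(f"per-epoch updates: {lr_per_epoch:.4e}")
print(f"per-step result / logged current_lr_0: {lr_per_step / LOGGED_LR:.4f}")
```

Stepping per training step drives the rate to roughly 7e-111, essentially the logged value, while stepping per epoch leaves it near 5e-5 after the same number of steps. Under the same assumption, the per-step rate is already about 4.5e-9 by step 10,000, which may explain the reports of training stalling around 10k steps.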

@loganhart02
Copy link
Contributor

> I realized that the learning rate decays too quickly. Maybe this is why the models stop learning. I'll update the recipe in the dev branch.
>
> In the meantime, you can set scheduler_after_epoch = True in your own config to try it out yourself.

0.5 fixed my issue; I was able to train HiFiGAN just fine. This can probably be closed.

@michaellin99999
Copy link

> > I realized that the learning rate decays too quickly. Maybe this is why the models stop learning. I'll update the recipe in the dev branch.
> > In the meantime, you can set scheduler_after_epoch = True in your own config to try it out yourself.
>
> 0.5 fixed my issue; I was able to train HiFiGAN just fine. This can probably be closed.

Hi, what do you mean by 0.5? Could you clarify?
Thanks.

@loganhart02
Copy link
Contributor

loganhart02 commented Jan 7, 2022 via email

@skol101
Copy link
Author

skol101 commented Jan 7, 2022

Cheers, guys!
