
[Audio] SSL Pretraining framework for flow-matching model for audio processing #10052

Merged
merged 11 commits into from
Aug 24, 2024

Conversation

@Kuray107 Kuray107 (Collaborator) commented Aug 6, 2024

What does this PR do ?

This PR adds an SSL pretraining framework and a flow-matching AudioToAudio model.

Collection: audio

Changelog

  • Add a flow-matching generative model in audio/models/enhancement.py
  • Create audio/modules/ssl_pretrain_masking to partially mask out the input during the SSL pre-training stage.
  • Create audio/parts/submodules/transformerunet.py to implement a Transformer U-Net structure.
  • Create audio/parts/submodules/flow.py for training and inference with the flow-matching generative model.
  • Create audio/parts/utils/callbacks.py to log processed audio to the training logger.
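
As a rough illustration of the partial-masking idea behind ssl_pretrain_masking, here is a minimal sketch; the function name, block size, and masking scheme are hypothetical and not this PR's actual implementation:

```python
import torch

def ssl_pretrain_mask(spec: torch.Tensor, mask_prob: float = 0.3, block_size: int = 10) -> torch.Tensor:
    """Zero out random time blocks of a (batch, freq, time) spectrogram,
    leaving the original tensor untouched."""
    batch, _, time = spec.shape
    mask = torch.zeros(batch, time, dtype=torch.bool)
    # Cover roughly mask_prob of the time axis with contiguous blocks.
    num_blocks = max(1, int(time * mask_prob / block_size))
    for b in range(batch):
        starts = torch.randint(0, max(1, time - block_size), (num_blocks,))
        for s in starts:
            mask[b, s : s + block_size] = True
    masked = spec.clone()
    masked[mask.unsqueeze(1).expand_as(spec)] = 0.0
    return masked
```

During SSL pre-training the model would then be asked to reconstruct the masked regions from the surrounding context.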

Usage

The model can first be pre-trained using

python examples/audio/audio_to_audio_train.py --config-name flow_matching_generative_ssl_pretraining.yaml
  • Please update train_ds.manifest_filepath and validation_ds.manifest_filepath with the correct file paths.
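
The manifest paths can also be overridden on the command line via Hydra instead of editing the config; the exact key names (model.train_ds.manifest_filepath, model.validation_ds.manifest_filepath) and the paths below are assumptions, so verify them against the actual config file:

```shell
# Hypothetical Hydra overrides; confirm the key names against the config.
python examples/audio/audio_to_audio_train.py \
    --config-name flow_matching_generative_ssl_pretraining.yaml \
    model.train_ds.manifest_filepath=/data/train_manifest.json \
    model.validation_ds.manifest_filepath=/data/val_manifest.json
```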

After the SSL pre-training stage is completed, the pretrained model can be fine-tuned using

python examples/audio/audio_to_audio_train.py --config-name flow_matching_generative_finetuning.yaml
  • Remember to update init_from_nemo_model in the config file with the pretrained checkpoint path.
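
If supported by the config, the checkpoint path could alternatively be passed as a command-line override; the key name and path below are assumptions for illustration:

```shell
# Hypothetical override; confirm the exact key in the finetuning config.
python examples/audio/audio_to_audio_train.py \
    --config-name flow_matching_generative_finetuning.yaml \
    ++init_from_nemo_model=/path/to/pretrained_checkpoint.nemo
```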

Alternatively, one can train a flow-matching generative model from scratch:

python examples/audio/audio_to_audio_train.py --config-name flow_matching_generative.yaml

The model can be evaluated using

python examples/audio/audio_to_audio_eval.py \
    model_path=${PATH_TO_NEMO_FILE} \
    dataset_manifest=${PATH_TO_DATASET_MANIFEST} \
    output_dir=${PATH_TO_OUTPUT_DIR} \
    input_key=noisy_filepath \
    target_key=clean_filepath \
    metrics=[sdr,estoi]

Sampler settings can be easily overridden, e.g., to change the number of sampling steps:

python examples/audio/audio_to_audio_eval.py ... ++sampler.num_steps=${NUM_STEPS}

GitHub Actions CI

The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.

The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI, remove and re-add the label.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items, you can still open a "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines list specific people who can review PRs to various areas.

@github-actions github-actions bot added the audio label Aug 6, 2024
@Kuray107 Kuray107 marked this pull request as ready for review August 6, 2024 16:16
@anteju anteju force-pushed the pr/ssl_generative_pretraining branch from e5c076c to 6f9a9c6 Compare August 6, 2024 16:30
@anteju anteju added the Run CICD label Aug 6, 2024
@Kuray107 Kuray107 force-pushed the pr/ssl_generative_pretraining branch from ffacac2 to 792a2f8 Compare August 6, 2024 17:32
@anteju anteju added Run CICD and removed Run CICD labels Aug 6, 2024
@Kuray107 Kuray107 force-pushed the pr/ssl_generative_pretraining branch 3 times, most recently from c8bc5ad to 14107c6 Compare August 6, 2024 21:11
@anteju anteju added Run CICD and removed Run CICD labels Aug 6, 2024
@Kuray107 Kuray107 force-pushed the pr/ssl_generative_pretraining branch 3 times, most recently from 429e21e to 47893a9 Compare August 6, 2024 22:42
@anteju anteju added Run CICD and removed Run CICD labels Aug 6, 2024
Comment on lines +153 to +157
# init adaptive normalization to identity

nn.init.zeros_(self.to_gamma.weight)
nn.init.ones_(self.to_gamma.bias)

nn.init.zeros_(self.to_beta.weight)
nn.init.zeros_(self.to_beta.bias)
Collaborator:

Is this identity? I would expect identity to be weight = torch.eye and bias = torch.zeros

Collaborator:

Oh, this is identity in the sense that output of the whole module is identity
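
A hedged sketch of that point: with zeroed weights, each projection outputs only its bias, so gamma is all ones and beta all zeros, and the adaptive normalization maps its input to itself regardless of the conditioning. The layer sizes below are illustrative, not taken from the PR:

```python
import torch
import torch.nn as nn

# Hypothetical minimal adaptive normalization: a conditioning vector predicts
# a per-feature scale (gamma) and shift (beta). Dimensions are illustrative.
cond_dim, feat_dim = 16, 8
to_gamma = nn.Linear(cond_dim, feat_dim)
to_beta = nn.Linear(cond_dim, feat_dim)

# The initialization discussed above: zeroed weights make each projection
# output only its bias, so gamma == 1 and beta == 0 for any conditioning.
nn.init.zeros_(to_gamma.weight)
nn.init.ones_(to_gamma.bias)
nn.init.zeros_(to_beta.weight)
nn.init.zeros_(to_beta.bias)

x = torch.randn(4, feat_dim)
cond = torch.randn(4, cond_dim)
out = to_gamma(cond) * x + to_beta(cond)  # identity map at initialization
```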

sample_rate: 16000
skip_nan_grad: false
num_outputs: 1
p_cond: 0.9 # Probability of feeding the conditional input into the model.
Collaborator:

Is this classifier-free guidance? Is there a reason to have it enabled in the default "generative model" config?

Collaborator (Author):

p_cond only takes effect during training. At inference time, one can choose whether or not to use classifier-free guidance.

Signed-off-by: Pin-Jui Ku <pku9@gatech.edu>
@anteju anteju added Run CICD and removed Run CICD labels Aug 15, 2024
@anteju anteju force-pushed the pr/ssl_generative_pretraining branch from 3297c02 to f7fc2f3 Compare August 15, 2024 20:26
@anteju anteju added Run CICD and removed Run CICD labels Aug 15, 2024
Pin-Jui Ku and others added 4 commits August 16, 2024 14:52
…tive model.

Signed-off-by: Pin-Jui Ku <pku@nvidia.com>
Signed-off-by: anteju <anteju@users.noreply.github.com>
Signed-off-by: Pin-Jui Ku <pku9@gatech.edu>
Signed-off-by: Kuray107 <Kuray107@users.noreply.github.com>
@anteju anteju force-pushed the pr/ssl_generative_pretraining branch from f7fc2f3 to 1f2e5b8 Compare August 16, 2024 21:52
@anteju anteju added Run CICD and removed Run CICD labels Aug 16, 2024
Kuray107 and others added 3 commits August 19, 2024 05:10
Signed-off-by: Pin-Jui Ku <pku9@gatech.edu>
Signed-off-by: Kuray107 <Kuray107@users.noreply.github.com>
@anteju anteju added Run CICD and removed Run CICD labels Aug 19, 2024
@anteju anteju added Run CICD and removed Run CICD labels Aug 21, 2024
@anteju anteju requested review from racoiaws and anteju August 21, 2024 20:25
@anteju anteju (Collaborator) left a comment

LGTM

@anteju anteju merged commit 7cc99e9 into NVIDIA:main Aug 24, 2024
126 of 127 checks passed
WoodieDudy pushed a commit to WoodieDudy/NeMo that referenced this pull request Aug 26, 2024
…rocessing (NVIDIA#10052)

Flow matching generative model with SSL pretraining framework

Signed-off-by: Pin-Jui Ku <pku@nvidia.com>
Co-authored-by: Kuray107 <Kuray107@users.noreply.github.com>
shanmugamr1992 pushed a commit that referenced this pull request Aug 27, 2024
hemildesai pushed a commit that referenced this pull request Aug 28, 2024
JimmyZhang12 pushed a commit that referenced this pull request Aug 30, 2024
adityavavre pushed a commit to adityavavre/NeMo that referenced this pull request Sep 15, 2024