[Audio] SSL Pretraining framework for flow-matching model for audio processing #10052
Conversation
# init adaptive normalization to identity
nn.init.zeros_(self.to_gamma.weight)
nn.init.ones_(self.to_gamma.bias)
nn.init.zeros_(self.to_beta.weight)
nn.init.zeros_(self.to_beta.bias)
Is this identity? I would expect identity to be weight = torch.eye and bias = torch.zeros.
Oh, this is identity in the sense that the output of the whole module is the identity.
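To illustrate the point, here is a minimal sketch of a FiLM-style adaptive normalization module with this initialization. The class name, layer shapes, and forward pass are assumptions for illustration, not the PR's actual implementation; only the four init lines come from the diff. Because the weights are zeroed and the biases give gamma = 1 and beta = 0, the module initially returns its input unchanged for any conditioning vector:

```python
import torch
import torch.nn as nn

class AdaptiveNorm(nn.Module):
    """Hypothetical FiLM-style adaptive modulation sketch."""
    def __init__(self, dim, cond_dim):
        super().__init__()
        self.to_gamma = nn.Linear(cond_dim, dim)
        self.to_beta = nn.Linear(cond_dim, dim)
        # init adaptive normalization to identity:
        # zero weights remove the dependence on the condition,
        # and the biases make gamma = 1 and beta = 0.
        nn.init.zeros_(self.to_gamma.weight)
        nn.init.ones_(self.to_gamma.bias)
        nn.init.zeros_(self.to_beta.weight)
        nn.init.zeros_(self.to_beta.bias)

    def forward(self, x, cond):
        gamma = self.to_gamma(cond)  # == 1 at initialization
        beta = self.to_beta(cond)    # == 0 at initialization
        return gamma * x + beta

x = torch.randn(2, 8)
cond = torch.randn(2, 16)
norm = AdaptiveNorm(dim=8, cond_dim=16)
assert torch.allclose(norm(x, cond), x)  # identity at initialization
```

So "identity" here refers to the module's input-output map, not to the weight matrices themselves being identity matrices.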
sample_rate: 16000
skip_nan_grad: false
num_outputs: 1
p_cond: 0.9 # Probability of feeding the conditional input into the model.
Is this classifier-free guidance? Is there a reason to have it enabled in the default "generative model" config?
p_cond is only effective during training. At inference time, one can choose whether or not to use classifier-free guidance.
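For context, the following is a minimal sketch of how a training-time conditioning-dropout probability like p_cond typically pairs with classifier-free guidance at inference. The function names and the zero-vector "unconditional" placeholder are assumptions for illustration, not the PR's implementation:

```python
import torch

def apply_p_cond(cond, p_cond=0.9, training=True):
    """During training, keep the conditional input with probability
    p_cond per example and replace it with zeros otherwise, so the
    model also learns an unconditional prediction. At inference the
    condition is passed through unchanged."""
    if not training:
        return cond
    keep = (torch.rand(cond.shape[0], 1) < p_cond).float()
    return cond * keep

def cfg_combine(pred_cond, pred_uncond, guidance_scale=2.0):
    """Classifier-free guidance at inference: extrapolate from the
    unconditional prediction toward the conditional one."""
    return pred_uncond + guidance_scale * (pred_cond - pred_uncond)
```

With guidance_scale = 1.0 the combination reduces to the plain conditional prediction, i.e. guidance disabled.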
…tive model. Signed-off-by: Pin-Jui Ku <pku@nvidia.com>
Signed-off-by: anteju <anteju@users.noreply.github.com>
Signed-off-by: Pin-Jui Ku <pku9@gatech.edu>
Signed-off-by: Kuray107 <Kuray107@users.noreply.github.com>
LGTM
…rocessing (NVIDIA#10052) Flow matching generative model with SSL pretraining framework Signed-off-by: Pin-Jui Ku <pku@nvidia.com> Co-authored-by: Kuray107 <Kuray107@users.noreply.github.com>
What does this PR do?
This PR adds an SSL pre-training framework and a flow matching AudioToAudio model.
Collection: audio
Changelog
- audio/models/enhancement.py
- audio/modules/ssl_pretrain_masking to partially mask out the input in the SSL pre-training stage.
- audio/parts/submodules/transformerunet.py to implement a Transformer U-Net structure.
- audio/parts/submodules/flow.py for training/inference with the flow matching generative model.
- audio/parts/utils/callbacks.py to add processed audio to the training logger.

Usage
The model can first be pre-trained by setting train_ds.manifest_filepath and validation_ds.manifest_filepath to the correct file paths. After the SSL pre-training stage is completed, the pre-trained model can be fine-tuned by setting init_from_nemo_model to the pre-trained checkpoint path in the config file. Alternatively, one can also train a flow matching generative model from scratch.
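Concretely, fine-tuning from the SSL checkpoint might look like the following config fragment. The field names mirror those mentioned above, but the paths are placeholders and the exact nesting may differ in the actual configs:

```yaml
init_from_nemo_model: /path/to/ssl_pretrained.nemo
train_ds:
  manifest_filepath: /path/to/train_manifest.json
validation_ds:
  manifest_filepath: /path/to/val_manifest.json
```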
The model can then be evaluated.
Sampler settings can be easily overridden, e.g., to change the number of sampling steps.
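To see why the number of sampling steps is worth exposing: flow matching inference numerically integrates a learned velocity field from noise to data, so the step count trades speed against integration accuracy. A minimal fixed-step Euler integrator sketch (not the PR's sampler; the velocity function here is a stand-in for the trained model):

```python
import torch

def euler_sample(velocity_fn, x0, num_steps=10):
    """Integrate dx/dt = v(x, t) from t = 0 to t = 1 with fixed-step Euler.

    velocity_fn: callable (x, t) -> velocity, standing in for the learned
    flow matching model. More steps give a finer integration at higher cost.
    """
    x = x0
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = i * dt
        x = x + dt * velocity_fn(x, t)
    return x

# Sanity check: the constant field v(x, t) = 1 integrated over [0, 1]
# adds exactly 1 regardless of the step count.
out = euler_sample(lambda x, t: torch.ones_like(x), torch.zeros(3), num_steps=4)
assert torch.allclose(out, torch.ones(3))
```

For a nontrivial learned field the result does depend on num_steps, which is why overriding it at evaluation time is useful.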
GitHub Actions CI
The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.
The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".
Before your PR is "Ready for review"
Pre checks:
PR Type:
If you haven't finished some of the above items, you can still open a "Draft" PR.
Who can review?
Anyone in the community is free to review the PR once the checks have passed.
The contributor guidelines list specific people who can review PRs in various areas.