Pull Request of the SSL pretraining framework on the flow matching generative model. Signed-off-by: Kuray107 <Kuray107@users.noreply.github.com>
Showing 11 changed files with 1,882 additions and 7 deletions.
examples/audio/conf/flow_matching_generative.yaml (164 additions)
@@ -0,0 +1,164 @@
name: flow_matching_generative

model:
  type: flow_matching
  sample_rate: 16000
  skip_nan_grad: false
  num_outputs: 1
  p_cond: 0.9 # Probability of feeding the conditional input into the model.
  normalize_input: true # normalize the input signal to 0 dBFS
  max_utts_evaluation_metrics: 500
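
The p_cond entry above implements conditional dropout during training: with probability 1 - p_cond the conditional (noisy) input is withheld, so the model also learns an unconditional path. A minimal sketch of that idea, assuming the conditioning is simply zeroed per example when dropped (the exact masking in the NeMo implementation may differ):

```python
import torch

def maybe_drop_conditioning(noisy_spec: torch.Tensor, p_cond: float = 0.9) -> torch.Tensor:
    """Keep the conditional (noisy) input with probability p_cond, otherwise zero it per example."""
    keep = torch.rand(noisy_spec.shape[0], device=noisy_spec.device) < p_cond
    # broadcast the per-example keep mask over (channel, dimension, time)
    return noisy_spec * keep.view(-1, 1, 1, 1).to(noisy_spec.dtype)
```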

  train_ds:
    manifest_filepath: ???
    input_key: noisy_filepath
    target_key: clean_filepath
    audio_duration: 6.14 # Number of STFT time frames = 1 + (audio_duration * sample_rate) // encoder.hop_length = 768
    random_offset: true
    batch_size: 8 # batch size may be increased based on the available memory
    shuffle: true
    num_workers: 8
    pin_memory: true
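
Both manifests follow the usual NeMo JSON-lines layout, one entry per utterance, with the keys named above (noisy_filepath as input, clean_filepath as target). A hedged example of producing such a manifest; the file paths and the duration field are illustrative only:

```python
import json

entries = [
    {"noisy_filepath": "/data/noisy/utt001.wav", "clean_filepath": "/data/clean/utt001.wav", "duration": 6.2},
    {"noisy_filepath": "/data/noisy/utt002.wav", "clean_filepath": "/data/clean/utt002.wav", "duration": 7.8},
]
with open("train_manifest.json", "w") as f:
    for entry in entries:
        f.write(json.dumps(entry) + "\n")
```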

  validation_ds:
    manifest_filepath: ???
    input_key: noisy_filepath
    target_key: clean_filepath
    batch_size: 8
    shuffle: false
    num_workers: 4
    pin_memory: true

  log_config:
    log_tensorboard: true
    log_wandb: false
    max_utts: 8

  encoder:
    _target_: nemo.collections.audio.modules.transforms.AudioToSpectrogram
    fft_length: 510 # Number of subbands in the STFT = fft_length // 2 + 1 = 256
    hop_length: 128
    magnitude_power: 0.5
    scale: 0.33
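
The comments on audio_duration and fft_length pin down the spectrogram shape the estimator will see; the arithmetic can be checked directly from the configured values:

```python
fft_length, hop_length = 510, 128
sample_rate, audio_duration = 16000, 6.14

subbands = fft_length // 2 + 1                                  # 510 // 2 + 1 = 256 frequency subbands
frames = 1 + round(audio_duration * sample_rate) // hop_length  # 1 + 98240 // 128 = 768 time frames
print(subbands, frames)                                         # 256 768
```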

  decoder:
    _target_: nemo.collections.audio.modules.transforms.SpectrogramToAudio
    fft_length: ${model.encoder.fft_length}
    hop_length: ${model.encoder.hop_length}
    magnitude_power: ${model.encoder.magnitude_power}
    scale: ${model.encoder.scale}

  estimator:
    _target_: nemo.collections.audio.parts.submodules.transformerunet.SpectrogramTransformerUNet
    in_channels: 2 # concatenation of single-channel perturbed and noisy
    out_channels: 1 # single-channel score estimate
    depth: 24
    ff_dropout: 0.1
    time_hidden_dim: 1024
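
in_channels: 2 reflects how the estimator is fed: the single-channel perturbed spectrogram and the single-channel noisy conditioning are stacked along the channel axis, and a single-channel estimate comes back. A shape-only sketch of that interface (the actual SpectrogramTransformerUNet call signature is not shown in this diff and may differ):

```python
import torch

batch, subbands, frames = 8, 256, 768              # shapes implied by the encoder and train_ds settings above
x_t = torch.randn(batch, 1, subbands, frames)      # single-channel perturbed spectrogram
noisy = torch.randn(batch, 1, subbands, frames)    # single-channel noisy conditioning

estimator_input = torch.cat([x_t, noisy], dim=1)   # (batch, 2, subbands, frames) -> in_channels: 2
```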

  flow:
    _target_: nemo.collections.audio.parts.submodules.flow.OptimalTransportFlow
    sigma_start: 1.0
    sigma_end: 1e-4

  sampler:
    _target_: nemo.collections.audio.parts.submodules.flow.ConditionalFlowMatchingEulerSampler
    num_steps: 20
    time_min: 1e-8
    time_max: 1.0
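
The sampler integrates the learned velocity field with fixed Euler steps from time_min to time_max in num_steps increments. A generic sketch of such a loop, with velocity_fn standing in for the trained estimator (an assumption, not the NeMo class itself):

```python
import torch

def euler_sample(velocity_fn, x_init, cond, num_steps=20, time_min=1e-8, time_max=1.0):
    """Fixed-step Euler integration of dx/dt = v(x, t, cond) over [time_min, time_max]."""
    times = torch.linspace(time_min, time_max, num_steps + 1)
    x = x_init
    for i in range(num_steps):
        dt = times[i + 1] - times[i]
        x = x + dt * velocity_fn(x, times[i], cond)
    return x
```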

  loss:
    _target_: nemo.collections.audio.losses.MSELoss
    ndim: 4 # loss is calculated on the score in the encoded domain (batch, channel, dimension, time)

  metrics:
    val:
      sisdr: # output SI-SDR
        _target_: torchmetrics.audio.ScaleInvariantSignalDistortionRatio
      estoi: # output ESTOI
        _target_: torchmetrics.audio.ShortTimeObjectiveIntelligibility
        fs: ${model.sample_rate}
        extended: true
      pesq: # output PESQ
        _target_: torchmetrics.audio.PerceptualEvaluationSpeechQuality
        fs: ${model.sample_rate}
        mode: wb
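
The validation metrics are the standard torchmetrics audio classes with the same arguments as above. Outside of training they can be applied to enhanced/reference waveform pairs roughly as follows (PESQ additionally requires the pesq package to be installed):

```python
import torch
from torchmetrics.audio import (
    PerceptualEvaluationSpeechQuality,
    ScaleInvariantSignalDistortionRatio,
    ShortTimeObjectiveIntelligibility,
)

fs = 16000
sisdr = ScaleInvariantSignalDistortionRatio()
estoi = ShortTimeObjectiveIntelligibility(fs=fs, extended=True)
pesq = PerceptualEvaluationSpeechQuality(fs=fs, mode="wb")

def evaluate(enhanced: torch.Tensor, reference: torch.Tensor) -> dict:
    """enhanced/reference are waveforms at 16 kHz, e.g. loaded with torchaudio."""
    return {
        "sisdr": float(sisdr(enhanced, reference)),
        "estoi": float(estoi(enhanced, reference)),
        "pesq": float(pesq(enhanced, reference)),
    }
```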

  optim:
    name: adam
    lr: 1e-4
    # optimizer arguments
    betas: [0.9, 0.999]
    weight_decay: 0.0

    # scheduler setup
    sched:
      name: CosineAnnealing
      # scheduler config override
      warmup_steps: 5000
      warmup_ratio: null
      min_lr: 0
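
The schedule is a linear warmup over warmup_steps followed by cosine annealing down to min_lr. A sketch of the resulting learning-rate shape (max_steps is illustrative here, since the config leaves it to be computed at runtime; the exact NeMo CosineAnnealing implementation may differ in details):

```python
import math

def lr_at_step(step: int, base_lr: float = 1e-4, warmup_steps: int = 5000,
               max_steps: int = 100_000, min_lr: float = 0.0) -> float:
    """Linear warmup, then cosine decay from base_lr to min_lr."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```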

trainer:
  devices: -1 # number of GPUs; -1 uses all available GPUs
  num_nodes: 1
  max_epochs: -1
  max_steps: -1 # computed at runtime if not set
  val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
  accelerator: auto
  strategy: ddp
  accumulate_grad_batches: 1
  gradient_clip_val: 0.2
  precision: 32 # Should be set to 16 for O1 and O2 to enable AMP.
  log_every_n_steps: 25 # Interval of logging.
  enable_progress_bar: true
  num_sanity_val_steps: 0 # number of validation steps to run as a sanity check before training starts; setting to 0 disables it
  check_val_every_n_epoch: 1 # run validation every n epochs
  sync_batchnorm: true
  enable_checkpointing: false # Provided by exp_manager
  logger: false # Provided by exp_manager

exp_manager:
  exp_dir: null
  name: ${name}

  # use exponential moving average for model parameters
  ema:
    enable: true
    decay: 0.999 # decay rate
    cpu_offload: false # offload EMA parameters to CPU to save GPU memory
    every_n_steps: 1 # how often to update EMA weights
    validate_original_weights: false # use original weights for validation calculation?
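
EMA keeps a shadow copy of the model weights that is nudged toward the live weights every every_n_steps optimizer steps; at decay = 0.999 the shadow averages over roughly the last thousand updates. A compact sketch of the parameter-wise update:

```python
import torch

@torch.no_grad()
def ema_update(ema_params, model_params, decay: float = 0.999) -> None:
    """ema <- decay * ema + (1 - decay) * current, applied parameter-wise."""
    for ema_p, p in zip(ema_params, model_params):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)
```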

  # logging
  create_tensorboard_logger: true

  # checkpointing
  create_checkpoint_callback: true
  checkpoint_callback_params:
    # in case of multiple validation sets, first one is used
    monitor: val_pesq
    mode: max
    save_top_k: 3
    always_save_nemo: true # saves the checkpoints as nemo files instead of PTL checkpoints

  # early stopping
  create_early_stopping_callback: true
  early_stopping_callback_params:
    monitor: val_sisdr
    mode: max
    min_delta: 0.0
    patience: 20 # patience in terms of check_val_every_n_epoch
    verbose: true
    strict: false # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training.

  resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
  # you need to set these two to true to continue the training
  resume_if_exists: false
  resume_ignore_no_checkpoint: false

  # You may use this section to create a W&B logger
  create_wandb_logger: false
  wandb_logger_kwargs:
    name: test
    project: gense
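
Both configs are plain Hydra/OmegaConf files: the ??? values mark mandatory fields, so the train and validation manifest paths (model.train_ds.manifest_filepath and model.validation_ds.manifest_filepath) are expected to be supplied as command-line overrides by whichever training script consumes the config; that script is not part of this diff.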

examples/audio/conf/flow_matching_generative_finetuning.yaml (167 additions)
@@ -0,0 +1,167 @@
name: flow_matching_generative_finetuning

init_from_nemo_model: null
init_strict: false
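
Apart from its name, this file differs from the base config above only in this header: init_from_nemo_model is meant to point at a .nemo checkpoint produced by the pretraining run (it is left null here and would be overridden for an actual finetuning job), and init_strict: false allows the restored weights to be loaded even if the state dict does not match the finetuning model exactly.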

model:
  type: flow_matching
  sample_rate: 16000
  skip_nan_grad: false
  num_outputs: 1
  p_cond: 0.9 # Probability of feeding the conditional input into the model.
  normalize_input: true # normalize the input signal to 0 dBFS
  max_utts_evaluation_metrics: 500

  train_ds:
    manifest_filepath: ???
    input_key: noisy_filepath
    target_key: clean_filepath
    audio_duration: 6.14 # Number of STFT time frames = 1 + (audio_duration * sample_rate) // encoder.hop_length = 768
    random_offset: true
    batch_size: 8 # batch size may be increased based on the available memory
    shuffle: true
    num_workers: 8
    pin_memory: true

  validation_ds:
    manifest_filepath: ???
    input_key: noisy_filepath
    target_key: clean_filepath
    batch_size: 8
    shuffle: false
    num_workers: 4
    pin_memory: true

  log_config:
    log_tensorboard: true
    log_wandb: false
    max_utts: 8

  encoder:
    _target_: nemo.collections.audio.modules.transforms.AudioToSpectrogram
    fft_length: 510 # Number of subbands in the STFT = fft_length // 2 + 1 = 256
    hop_length: 128
    magnitude_power: 0.5
    scale: 0.33

  decoder:
    _target_: nemo.collections.audio.modules.transforms.SpectrogramToAudio
    fft_length: ${model.encoder.fft_length}
    hop_length: ${model.encoder.hop_length}
    magnitude_power: ${model.encoder.magnitude_power}
    scale: ${model.encoder.scale}

  estimator:
    _target_: nemo.collections.audio.parts.submodules.transformerunet.SpectrogramTransformerUNet
    in_channels: 2 # concatenation of single-channel perturbed and noisy
    out_channels: 1 # single-channel score estimate
    depth: 24
    ff_dropout: 0.1
    time_hidden_dim: 1024

  flow:
    _target_: nemo.collections.audio.parts.submodules.flow.OptimalTransportFlow
    sigma_start: 1.0
    sigma_end: 1e-4

  sampler:
    _target_: nemo.collections.audio.parts.submodules.flow.ConditionalFlowMatchingEulerSampler
    num_steps: 20
    time_min: 1e-8
    time_max: 1.0

  loss:
    _target_: nemo.collections.audio.losses.MSELoss
    ndim: 4 # loss is calculated on the score in the encoded domain (batch, channel, dimension, time)

  metrics:
    val:
      sisdr: # output SI-SDR
        _target_: torchmetrics.audio.ScaleInvariantSignalDistortionRatio
      estoi: # output ESTOI
        _target_: torchmetrics.audio.ShortTimeObjectiveIntelligibility
        fs: ${model.sample_rate}
        extended: true
      pesq: # output PESQ
        _target_: torchmetrics.audio.PerceptualEvaluationSpeechQuality
        fs: ${model.sample_rate}
        mode: wb

  optim:
    name: adam
    lr: 1e-4
    # optimizer arguments
    betas: [0.9, 0.999]
    weight_decay: 0.0

    # scheduler setup
    sched:
      name: CosineAnnealing
      # scheduler config override
      warmup_steps: 5000
      warmup_ratio: null
      min_lr: 0

trainer:
  devices: -1 # number of GPUs; -1 uses all available GPUs
  num_nodes: 1
  max_epochs: -1
  max_steps: -1 # computed at runtime if not set
  val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
  accelerator: auto
  strategy: ddp
  accumulate_grad_batches: 1
  gradient_clip_val: 0.2
  precision: 32 # Should be set to 16 for O1 and O2 to enable AMP.
  log_every_n_steps: 25 # Interval of logging.
  enable_progress_bar: true
  num_sanity_val_steps: 0 # number of validation steps to run as a sanity check before training starts; setting to 0 disables it
  check_val_every_n_epoch: 1 # run validation every n epochs
  sync_batchnorm: true
  enable_checkpointing: false # Provided by exp_manager
  logger: false # Provided by exp_manager

exp_manager:
  exp_dir: null
  name: ${name}

  # use exponential moving average for model parameters
  ema:
    enable: true
    decay: 0.999 # decay rate
    cpu_offload: false # offload EMA parameters to CPU to save GPU memory
    every_n_steps: 1 # how often to update EMA weights
    validate_original_weights: false # use original weights for validation calculation?

  # logging
  create_tensorboard_logger: true

  # checkpointing
  create_checkpoint_callback: true
  checkpoint_callback_params:
    # in case of multiple validation sets, first one is used
    monitor: val_pesq
    mode: max
    save_top_k: 3
    always_save_nemo: true # saves the checkpoints as nemo files instead of PTL checkpoints

  # early stopping
  create_early_stopping_callback: true
  early_stopping_callback_params:
    monitor: val_sisdr
    mode: max
    min_delta: 0.0
    patience: 20 # patience in terms of check_val_every_n_epoch
    verbose: true
    strict: false # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training.

  resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
  # you need to set these two to true to continue the training
  resume_if_exists: false
  resume_ignore_no_checkpoint: false

  # You may use this section to create a W&B logger
  create_wandb_logger: false
  wandb_logger_kwargs:
    name: test
    project: gense