Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Audio] SSL Pretraining framework for flow-matching model for audio processing #10052

Merged
merged 11 commits into from
Aug 24, 2024
4 changes: 4 additions & 0 deletions examples/audio/audio_to_audio_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

from nemo.collections.audio.models.enhancement import (
EncMaskDecAudioToAudioModel,
FlowMatchingAudioToAudioModel,
PredictiveAudioToAudioModel,
SchroedingerBridgeAudioToAudioModel,
ScoreBasedGenerativeAudioToAudioModel,
Expand All @@ -50,6 +51,7 @@ class ModelType(str, Enum):
Predictive = 'predictive'
ScoreBased = 'score_based'
SchroedingerBridge = 'schroedinger_bridge'
FlowMatching = 'flow_matching'


def get_model_class(model_type: ModelType):
Expand All @@ -62,6 +64,8 @@ def get_model_class(model_type: ModelType):
return ScoreBasedGenerativeAudioToAudioModel
elif model_type == ModelType.SchroedingerBridge:
return SchroedingerBridgeAudioToAudioModel
elif model_type == ModelType.FlowMatching:
return FlowMatchingAudioToAudioModel
else:
raise ValueError(f'Unknown model type: {model_type}')

Expand Down
164 changes: 164 additions & 0 deletions examples/audio/conf/flow_matching_generative.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
name: flow_matching_generative

model:
type: flow_matching
sample_rate: 16000
skip_nan_grad: false
num_outputs: 1
p_cond: 0.9 # Proability of feeding the conditional input into the model.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this classifier-free guidance? Is there a reason to have it enabled in the default "generative model" config?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

p_cond is only effective during the training process. In the inference stage one could choose whether to use classifier-free guidance or not.

normalize_input: true # normalize the input signal to 0dBFS
max_utts_evaluation_metrics: 500

train_ds:
manifest_filepath: ???
input_key: noisy_filepath
target_key: clean_filepath
audio_duration: 6.14 # Number of STFT time frames = 1 + audio_duration // encoder.hop_length = 768
random_offset: true
batch_size: 8 # batch size may be increased based on the available memory
shuffle: true
num_workers: 8
pin_memory: true

validation_ds:
manifest_filepath: ???
input_key: noisy_filepath
target_key: clean_filepath
batch_size: 8
shuffle: false
num_workers: 4
pin_memory: true

log_config:
log_tensorboard: true
log_wandb: false
max_utts: 8

encoder:
_target_: nemo.collections.audio.modules.transforms.AudioToSpectrogram
fft_length: 510 # Number of subbands in the STFT = fft_length // 2 + 1 = 256
hop_length: 128
magnitude_power: 0.5
scale: 0.33

decoder:
_target_: nemo.collections.audio.modules.transforms.SpectrogramToAudio
fft_length: ${model.encoder.fft_length}
hop_length: ${model.encoder.hop_length}
magnitude_power: ${model.encoder.magnitude_power}
scale: ${model.encoder.scale}

estimator:
_target_: nemo.collections.audio.parts.submodules.transformerunet.SpectrogramTransformerUNet
in_channels: 2 # concatenation of single-channel perturbed and noisy
out_channels: 1 # single-channel score estimate
depth: 24
ff_dropout: 0.1
time_hidden_dim: 1024

flow:
_target_: nemo.collections.audio.parts.submodules.flow.OptimalTransportFlow
sigma_start: 1.0
sigma_end: 1e-4

sampler:
_target_: nemo.collections.audio.parts.submodules.flow.ConditionalFlowMatchingEulerSampler
num_steps: 20
time_min: 1e-8
time_max: 1.0

loss:
_target_: nemo.collections.audio.losses.MSELoss
ndim: 4 # loss is calculated on the score in the encoded domain (batch, channel, dimension, time)

metrics:
val:
sisdr: # output SI-SDR
_target_: torchmetrics.audio.ScaleInvariantSignalDistortionRatio
estoi: # output ESTOI
_target_: torchmetrics.audio.ShortTimeObjectiveIntelligibility
fs: ${model.sample_rate}
extended: true
pesq: # output PESQ
_target_: torchmetrics.audio.PerceptualEvaluationSpeechQuality
fs: ${model.sample_rate}
mode: wb

optim:
name: adam
lr: 1e-4
# optimizer arguments
betas: [0.9, 0.999]
weight_decay: 0.0

# scheduler setup
sched:
name: CosineAnnealing
# scheduler config override
warmup_steps: 5000
warmup_ratio: null
min_lr: 0

trainer:
devices: -1 # number of GPUs, -1 would use all available GPUs
num_nodes: 1
max_epochs: -1
max_steps: -1 # computed at runtime if not set
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
accelerator: auto
strategy: ddp
accumulate_grad_batches: 1
gradient_clip_val: 0.2
precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
log_every_n_steps: 25 # Interval of logging.
enable_progress_bar: true
num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it
check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
sync_batchnorm: true
enable_checkpointing: false # Provided by exp_manager
logger: false # Provided by exp_manager

exp_manager:
exp_dir: null
name: ${name}

# use exponential moving average for model parameters
ema:
enable: true
decay: 0.999 # decay rate
cpu_offload: false # offload EMA parameters to CPU to save GPU memory
every_n_steps: 1 # how often to update EMA weights
validate_original_weights: false # use original weights for validation calculation?

# logging
create_tensorboard_logger: true

# checkpointing
create_checkpoint_callback: true
checkpoint_callback_params:
# in case of multiple validation sets, first one is used
monitor: val_pesq
mode: max
save_top_k: 3
always_save_nemo: true # saves the checkpoints as nemo files instead of PTL checkpoints

# early stopping
create_early_stopping_callback: true
early_stopping_callback_params:
monitor: val_sisdr
mode: max
min_delta: 0.0
patience: 20 # patience in terms of check_val_every_n_epoch
verbose: true
strict: false # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training.

resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
# you need to set these two to true to continue the training
resume_if_exists: false
resume_ignore_no_checkpoint: false

# You may use this section to create a W&B logger
create_wandb_logger: false
wandb_logger_kwargs:
name: test
project: gense
167 changes: 167 additions & 0 deletions examples/audio/conf/flow_matching_generative_finetuning.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
name: flow_matching_generative_finetuning

init_from_nemo_model: null
init_strict: false

model:
type: flow_matching
sample_rate: 16000
skip_nan_grad: false
num_outputs: 1
p_cond: 0.9 # Proability of feeding the conditional input into the model.
normalize_input: true # normalize the input signal to 0dBFS
max_utts_evaluation_metrics: 500

train_ds:
manifest_filepath: ???
input_key: noisy_filepath
target_key: clean_filepath
audio_duration: 6.14 # Number of STFT time frames = 1 + audio_duration // encoder.hop_length = 768
random_offset: true
batch_size: 8 # batch size may be increased based on the available memory
shuffle: true
num_workers: 8
pin_memory: true

validation_ds:
manifest_filepath: ???
input_key: noisy_filepath
target_key: clean_filepath
batch_size: 8
shuffle: false
num_workers: 4
pin_memory: true

log_config:
log_tensorboard: true
log_wandb: false
max_utts: 8

encoder:
_target_: nemo.collections.audio.modules.transforms.AudioToSpectrogram
fft_length: 510 # Number of subbands in the STFT = fft_length // 2 + 1 = 256
hop_length: 128
magnitude_power: 0.5
scale: 0.33

decoder:
_target_: nemo.collections.audio.modules.transforms.SpectrogramToAudio
fft_length: ${model.encoder.fft_length}
hop_length: ${model.encoder.hop_length}
magnitude_power: ${model.encoder.magnitude_power}
scale: ${model.encoder.scale}

estimator:
_target_: nemo.collections.audio.parts.submodules.transformerunet.SpectrogramTransformerUNet
in_channels: 2 # concatenation of single-channel perturbed and noisy
out_channels: 1 # single-channel score estimate
depth: 24
ff_dropout: 0.1
time_hidden_dim: 1024

flow:
_target_: nemo.collections.audio.parts.submodules.flow.OptimalTransportFlow
sigma_start: 1.0
sigma_end: 1e-4

sampler:
_target_: nemo.collections.audio.parts.submodules.flow.ConditionalFlowMatchingEulerSampler
num_steps: 20
time_min: 1e-8
time_max: 1.0

loss:
_target_: nemo.collections.audio.losses.MSELoss
ndim: 4 # loss is calculated on the score in the encoded domain (batch, channel, dimension, time)

metrics:
val:
sisdr: # output SI-SDR
_target_: torchmetrics.audio.ScaleInvariantSignalDistortionRatio
estoi: # output ESTOI
_target_: torchmetrics.audio.ShortTimeObjectiveIntelligibility
fs: ${model.sample_rate}
extended: true
pesq: # output PESQ
_target_: torchmetrics.audio.PerceptualEvaluationSpeechQuality
fs: ${model.sample_rate}
mode: wb

optim:
name: adam
lr: 1e-4
# optimizer arguments
betas: [0.9, 0.999]
weight_decay: 0.0

# scheduler setup
sched:
name: CosineAnnealing
# scheduler config override
warmup_steps: 5000
warmup_ratio: null
min_lr: 0

trainer:
devices: -1 # number of GPUs, -1 would use all available GPUs
num_nodes: 1
max_epochs: -1
max_steps: -1 # computed at runtime if not set
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
accelerator: auto
strategy: ddp
accumulate_grad_batches: 1
gradient_clip_val: 0.2
precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
log_every_n_steps: 25 # Interval of logging.
enable_progress_bar: true
num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it
check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
sync_batchnorm: true
enable_checkpointing: false # Provided by exp_manager
logger: false # Provided by exp_manager

exp_manager:
exp_dir: null
name: ${name}

# use exponential moving average for model parameters
ema:
enable: true
decay: 0.999 # decay rate
cpu_offload: false # offload EMA parameters to CPU to save GPU memory
every_n_steps: 1 # how often to update EMA weights
validate_original_weights: false # use original weights for validation calculation?

# logging
create_tensorboard_logger: true

# checkpointing
create_checkpoint_callback: true
checkpoint_callback_params:
# in case of multiple validation sets, first one is used
monitor: val_pesq
mode: max
save_top_k: 3
always_save_nemo: true # saves the checkpoints as nemo files instead of PTL checkpoints

# early stopping
create_early_stopping_callback: true
early_stopping_callback_params:
monitor: val_sisdr
mode: max
min_delta: 0.0
patience: 20 # patience in terms of check_val_every_n_epoch
verbose: true
strict: false # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training.

resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
# you need to set these two to true to continue the training
resume_if_exists: false
resume_ignore_no_checkpoint: false

# You may use this section to create a W&B logger
create_wandb_logger: false
wandb_logger_kwargs:
name: test
project: gense
Loading
Loading