Pull request: SSL pretraining framework on a flow matching generative model.

Signed-off-by: Kuray107 <Kuray107@users.noreply.github.com>
Pin-Jui Ku committed Aug 6, 2024
1 parent 874a1ea commit 792a2f8
Showing 11 changed files with 1,882 additions and 7 deletions.
4 changes: 4 additions & 0 deletions examples/audio/audio_to_audio_train.py
@@ -34,6 +34,7 @@

from nemo.collections.audio.models.enhancement import (
    EncMaskDecAudioToAudioModel,
    FlowMatchingAudioToAudioModel,
    PredictiveAudioToAudioModel,
    SchroedingerBridgeAudioToAudioModel,
    ScoreBasedGenerativeAudioToAudioModel,
@@ -50,6 +51,7 @@ class ModelType(str, Enum):
    Predictive = 'predictive'
    ScoreBased = 'score_based'
    SchroedingerBridge = 'schroedinger_bridge'
    FlowMatching = 'flow_matching'


def get_model_class(model_type: ModelType):
@@ -62,6 +64,8 @@ def get_model_class(model_type: ModelType):
        return ScoreBasedGenerativeAudioToAudioModel
    elif model_type == ModelType.SchroedingerBridge:
        return SchroedingerBridgeAudioToAudioModel
    elif model_type == ModelType.FlowMatching:
        return FlowMatchingAudioToAudioModel
    else:
        raise ValueError(f'Unknown model type: {model_type}')

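The ModelType(str, Enum) pattern above is what lets the plain string 'flow_matching' from the YAML model.type field select the new model class. Below is a minimal, self-contained sketch of that round trip using only the Python standard library; the enum is an illustrative mirror of the one in audio_to_audio_train.py, not part of the commit.

from enum import Enum


class ModelType(str, Enum):
    # Illustrative mirror of the enum defined in audio_to_audio_train.py.
    Predictive = 'predictive'
    ScoreBased = 'score_based'
    SchroedingerBridge = 'schroedinger_bridge'
    FlowMatching = 'flow_matching'


# The YAML field model.type arrives as a plain string; because the enum is
# str-backed, the string converts directly into the member that
# get_model_class() dispatches on, and it still compares equal to the raw string.
selected = ModelType('flow_matching')
assert selected is ModelType.FlowMatching
assert selected == 'flow_matching'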
164 changes: 164 additions & 0 deletions examples/audio/conf/flow_matching_generative.yaml
@@ -0,0 +1,164 @@
name: flow_matching_generative

model:
  type: flow_matching
  sample_rate: 16000
  skip_nan_grad: false
  num_outputs: 1
  p_cond: 0.9 # Probability of feeding the conditional input into the model.
  normalize_input: true # normalize the input signal to 0 dBFS
  max_utts_evaluation_metrics: 500

  train_ds:
    manifest_filepath: ???
    input_key: noisy_filepath
    target_key: clean_filepath
    audio_duration: 6.14 # Number of STFT time frames = 1 + audio_duration * sample_rate // encoder.hop_length = 768
    random_offset: true
    batch_size: 8 # batch size may be increased based on the available memory
    shuffle: true
    num_workers: 8
    pin_memory: true

  validation_ds:
    manifest_filepath: ???
    input_key: noisy_filepath
    target_key: clean_filepath
    batch_size: 8
    shuffle: false
    num_workers: 4
    pin_memory: true

  log_config:
    log_tensorboard: true
    log_wandb: false
    max_utts: 8

  encoder:
    _target_: nemo.collections.audio.modules.transforms.AudioToSpectrogram
    fft_length: 510 # Number of subbands in the STFT = fft_length // 2 + 1 = 256
    hop_length: 128
    magnitude_power: 0.5
    scale: 0.33

  decoder:
    _target_: nemo.collections.audio.modules.transforms.SpectrogramToAudio
    fft_length: ${model.encoder.fft_length}
    hop_length: ${model.encoder.hop_length}
    magnitude_power: ${model.encoder.magnitude_power}
    scale: ${model.encoder.scale}

  estimator:
    _target_: nemo.collections.audio.parts.submodules.transformerunet.SpectrogramTransformerUNet
    in_channels: 2 # concatenation of single-channel perturbed and noisy
    out_channels: 1 # single-channel estimate of the flow vector field
    depth: 24
    ff_dropout: 0.1
    time_hidden_dim: 1024

  flow:
    _target_: nemo.collections.audio.parts.submodules.flow.OptimalTransportFlow
    sigma_start: 1.0
    sigma_end: 1e-4

  sampler:
    _target_: nemo.collections.audio.parts.submodules.flow.ConditionalFlowMatchingEulerSampler
    num_steps: 20
    time_min: 1e-8
    time_max: 1.0

  loss:
    _target_: nemo.collections.audio.losses.MSELoss
    ndim: 4 # loss is calculated on the estimated flow vector field in the encoded domain (batch, channel, dimension, time)

  metrics:
    val:
      sisdr: # output SI-SDR
        _target_: torchmetrics.audio.ScaleInvariantSignalDistortionRatio
      estoi: # output ESTOI
        _target_: torchmetrics.audio.ShortTimeObjectiveIntelligibility
        fs: ${model.sample_rate}
        extended: true
      pesq: # output PESQ
        _target_: torchmetrics.audio.PerceptualEvaluationSpeechQuality
        fs: ${model.sample_rate}
        mode: wb

  optim:
    name: adam
    lr: 1e-4
    # optimizer arguments
    betas: [0.9, 0.999]
    weight_decay: 0.0

    # scheduler setup
    sched:
      name: CosineAnnealing
      # scheduler config override
      warmup_steps: 5000
      warmup_ratio: null
      min_lr: 0

trainer:
  devices: -1 # number of GPUs, -1 would use all available GPUs
  num_nodes: 1
  max_epochs: -1
  max_steps: -1 # computed at runtime if not set
  val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
  accelerator: auto
  strategy: ddp
  accumulate_grad_batches: 1
  gradient_clip_val: 0.2
  precision: 32 # Should be set to 16 for O1 and O2 to enable AMP.
  log_every_n_steps: 25 # Interval of logging.
  enable_progress_bar: true
  num_sanity_val_steps: 0 # number of validation steps to run as a sanity check before training starts; setting it to 0 disables the check
  check_val_every_n_epoch: 1 # run validation every n epochs
  sync_batchnorm: true
  enable_checkpointing: false # Provided by exp_manager
  logger: false # Provided by exp_manager

exp_manager:
  exp_dir: null
  name: ${name}

  # use exponential moving average for model parameters
  ema:
    enable: true
    decay: 0.999 # decay rate
    cpu_offload: false # offload EMA parameters to CPU to save GPU memory
    every_n_steps: 1 # how often to update EMA weights
    validate_original_weights: false # whether to use the original (non-EMA) weights for validation

  # logging
  create_tensorboard_logger: true

  # checkpointing
  create_checkpoint_callback: true
  checkpoint_callback_params:
    # in case of multiple validation sets, first one is used
    monitor: val_pesq
    mode: max
    save_top_k: 3
    always_save_nemo: true # saves the checkpoints as nemo files instead of PTL checkpoints

  # early stopping
  create_early_stopping_callback: true
  early_stopping_callback_params:
    monitor: val_sisdr
    mode: max
    min_delta: 0.0
    patience: 20 # patience in terms of check_val_every_n_epoch
    verbose: true
    strict: false # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training.

  resume_from_checkpoint: null # path to a checkpoint file to continue training from; restores the full state, including epoch, step, LR schedulers, apex, etc.
  # set both of the following to true to resume training
  resume_if_exists: false
  resume_ignore_no_checkpoint: false

  # You may use this section to create a W&B logger
  create_wandb_logger: false
  wandb_logger_kwargs:
    name: test
    project: gense
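The frame and subband counts quoted in the comments of this config follow directly from the encoder settings. A small stand-alone arithmetic check in plain Python (no NeMo imports; the variable names are illustrative only):

# Encoder settings copied from the config above.
sample_rate = 16000
fft_length = 510
hop_length = 128
audio_duration = 6.14  # seconds per training example

# One-sided STFT: number of subbands = fft_length // 2 + 1.
num_subbands = fft_length // 2 + 1
assert num_subbands == 256

# Number of STFT time frames = 1 + num_samples // hop_length.
num_samples = round(audio_duration * sample_rate)  # 98240 samples
num_frames = 1 + num_samples // hop_length  # 1 + 767
assert num_frames == 768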
167 changes: 167 additions & 0 deletions examples/audio/conf/flow_matching_generative_finetuning.yaml
@@ -0,0 +1,167 @@
name: flow_matching_generative_finetuning

init_from_nemo_model: null
init_strict: false

model:
  type: flow_matching
  sample_rate: 16000
  skip_nan_grad: false
  num_outputs: 1
  p_cond: 0.9 # Probability of feeding the conditional input into the model.
  normalize_input: true # normalize the input signal to 0 dBFS
  max_utts_evaluation_metrics: 500

  train_ds:
    manifest_filepath: ???
    input_key: noisy_filepath
    target_key: clean_filepath
    audio_duration: 6.14 # Number of STFT time frames = 1 + audio_duration * sample_rate // encoder.hop_length = 768
    random_offset: true
    batch_size: 8 # batch size may be increased based on the available memory
    shuffle: true
    num_workers: 8
    pin_memory: true

  validation_ds:
    manifest_filepath: ???
    input_key: noisy_filepath
    target_key: clean_filepath
    batch_size: 8
    shuffle: false
    num_workers: 4
    pin_memory: true

  log_config:
    log_tensorboard: true
    log_wandb: false
    max_utts: 8

  encoder:
    _target_: nemo.collections.audio.modules.transforms.AudioToSpectrogram
    fft_length: 510 # Number of subbands in the STFT = fft_length // 2 + 1 = 256
    hop_length: 128
    magnitude_power: 0.5
    scale: 0.33

  decoder:
    _target_: nemo.collections.audio.modules.transforms.SpectrogramToAudio
    fft_length: ${model.encoder.fft_length}
    hop_length: ${model.encoder.hop_length}
    magnitude_power: ${model.encoder.magnitude_power}
    scale: ${model.encoder.scale}

  estimator:
    _target_: nemo.collections.audio.parts.submodules.transformerunet.SpectrogramTransformerUNet
    in_channels: 2 # concatenation of single-channel perturbed and noisy
    out_channels: 1 # single-channel estimate of the flow vector field
    depth: 24
    ff_dropout: 0.1
    time_hidden_dim: 1024

  flow:
    _target_: nemo.collections.audio.parts.submodules.flow.OptimalTransportFlow
    sigma_start: 1.0
    sigma_end: 1e-4

  sampler:
    _target_: nemo.collections.audio.parts.submodules.flow.ConditionalFlowMatchingEulerSampler
    num_steps: 20
    time_min: 1e-8
    time_max: 1.0

  loss:
    _target_: nemo.collections.audio.losses.MSELoss
    ndim: 4 # loss is calculated on the estimated flow vector field in the encoded domain (batch, channel, dimension, time)

  metrics:
    val:
      sisdr: # output SI-SDR
        _target_: torchmetrics.audio.ScaleInvariantSignalDistortionRatio
      estoi: # output ESTOI
        _target_: torchmetrics.audio.ShortTimeObjectiveIntelligibility
        fs: ${model.sample_rate}
        extended: true
      pesq: # output PESQ
        _target_: torchmetrics.audio.PerceptualEvaluationSpeechQuality
        fs: ${model.sample_rate}
        mode: wb

  optim:
    name: adam
    lr: 1e-4
    # optimizer arguments
    betas: [0.9, 0.999]
    weight_decay: 0.0

    # scheduler setup
    sched:
      name: CosineAnnealing
      # scheduler config override
      warmup_steps: 5000
      warmup_ratio: null
      min_lr: 0

trainer:
  devices: -1 # number of GPUs, -1 would use all available GPUs
  num_nodes: 1
  max_epochs: -1
  max_steps: -1 # computed at runtime if not set
  val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
  accelerator: auto
  strategy: ddp
  accumulate_grad_batches: 1
  gradient_clip_val: 0.2
  precision: 32 # Should be set to 16 for O1 and O2 to enable AMP.
  log_every_n_steps: 25 # Interval of logging.
  enable_progress_bar: true
  num_sanity_val_steps: 0 # number of validation steps to run as a sanity check before training starts; setting it to 0 disables the check
  check_val_every_n_epoch: 1 # run validation every n epochs
  sync_batchnorm: true
  enable_checkpointing: false # Provided by exp_manager
  logger: false # Provided by exp_manager

exp_manager:
  exp_dir: null
  name: ${name}

  # use exponential moving average for model parameters
  ema:
    enable: true
    decay: 0.999 # decay rate
    cpu_offload: false # offload EMA parameters to CPU to save GPU memory
    every_n_steps: 1 # how often to update EMA weights
    validate_original_weights: false # whether to use the original (non-EMA) weights for validation

  # logging
  create_tensorboard_logger: true

  # checkpointing
  create_checkpoint_callback: true
  checkpoint_callback_params:
    # in case of multiple validation sets, first one is used
    monitor: val_pesq
    mode: max
    save_top_k: 3
    always_save_nemo: true # saves the checkpoints as nemo files instead of PTL checkpoints

  # early stopping
  create_early_stopping_callback: true
  early_stopping_callback_params:
    monitor: val_sisdr
    mode: max
    min_delta: 0.0
    patience: 20 # patience in terms of check_val_every_n_epoch
    verbose: true
    strict: false # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training.

  resume_from_checkpoint: null # path to a checkpoint file to continue training from; restores the full state, including epoch, step, LR schedulers, apex, etc.
  # set both of the following to true to resume training
  resume_if_exists: false
  resume_ignore_no_checkpoint: false

  # You may use this section to create a W&B logger
  create_wandb_logger: false
  wandb_logger_kwargs:
    name: test
    project: gense
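This finetuning config is identical to flow_matching_generative.yaml apart from its name and the two warm-start fields at the top (init_from_nemo_model and init_strict). A hedged sketch of filling in those fields programmatically instead of via Hydra command-line overrides, assuming only the omegaconf package that NeMo already depends on; the checkpoint path is a placeholder, not a file shipped with this commit.

from omegaconf import OmegaConf

# Load the finetuning config added in this commit.
cfg = OmegaConf.load("examples/audio/conf/flow_matching_generative_finetuning.yaml")

# Warm-start from a previously trained checkpoint (placeholder path).
cfg.init_from_nemo_model = "/path/to/flow_matching_generative.nemo"
cfg.init_strict = False  # tolerate missing or unexpected keys when restoring weights

# The training script builds the flow matching model from cfg.model and is then
# expected to restore weights from init_from_nemo_model before finetuning starts;
# the exact restore call is handled inside the NeMo training utilities.
print(OmegaConf.to_yaml(cfg, resolve=False))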
