diff --git a/examples/audio/audio_to_audio_train.py b/examples/audio/audio_to_audio_train.py index b197d2084144..cef46dcf20b6 100644 --- a/examples/audio/audio_to_audio_train.py +++ b/examples/audio/audio_to_audio_train.py @@ -34,6 +34,7 @@ from nemo.collections.audio.models.enhancement import ( EncMaskDecAudioToAudioModel, + FlowMatchingAudioToAudioModel, PredictiveAudioToAudioModel, SchroedingerBridgeAudioToAudioModel, ScoreBasedGenerativeAudioToAudioModel, @@ -50,6 +51,7 @@ class ModelType(str, Enum): Predictive = 'predictive' ScoreBased = 'score_based' SchroedingerBridge = 'schroedinger_bridge' + FlowMatching = 'flow_matching' def get_model_class(model_type: ModelType): @@ -62,6 +64,8 @@ def get_model_class(model_type: ModelType): return ScoreBasedGenerativeAudioToAudioModel elif model_type == ModelType.SchroedingerBridge: return SchroedingerBridgeAudioToAudioModel + elif model_type == ModelType.FlowMatching: + return FlowMatchingAudioToAudioModel else: raise ValueError(f'Unknown model type: {model_type}') diff --git a/examples/audio/conf/flow_matching_generative.yaml b/examples/audio/conf/flow_matching_generative.yaml new file mode 100644 index 000000000000..5f644f328e6d --- /dev/null +++ b/examples/audio/conf/flow_matching_generative.yaml @@ -0,0 +1,164 @@ +name: flow_matching_generative + +model: + type: flow_matching + sample_rate: 16000 + skip_nan_grad: false + num_outputs: 1 + p_cond: 0.9 # Proability of feeding the conditional input into the model. + normalize_input: true # normalize the input signal to 0dBFS + max_utts_evaluation_metrics: 500 + + train_ds: + manifest_filepath: ??? + input_key: noisy_filepath + target_key: clean_filepath + audio_duration: 6.14 # Number of STFT time frames = 1 + audio_duration // encoder.hop_length = 768 + random_offset: true + batch_size: 8 # batch size may be increased based on the available memory + shuffle: true + num_workers: 8 + pin_memory: true + + validation_ds: + manifest_filepath: ??? 
+ input_key: noisy_filepath + target_key: clean_filepath + batch_size: 8 + shuffle: false + num_workers: 4 + pin_memory: true + + log_config: + log_tensorboard: true + log_wandb: false + max_utts: 8 + + encoder: + _target_: nemo.collections.audio.modules.transforms.AudioToSpectrogram + fft_length: 510 # Number of subbands in the STFT = fft_length // 2 + 1 = 256 + hop_length: 128 + magnitude_power: 0.5 + scale: 0.33 + + decoder: + _target_: nemo.collections.audio.modules.transforms.SpectrogramToAudio + fft_length: ${model.encoder.fft_length} + hop_length: ${model.encoder.hop_length} + magnitude_power: ${model.encoder.magnitude_power} + scale: ${model.encoder.scale} + + estimator: + _target_: nemo.collections.audio.parts.submodules.transformerunet.SpectrogramTransformerUNet + in_channels: 2 # concatenation of single-channel perturbed and noisy + out_channels: 1 # single-channel score estimate + depth: 24 + ff_dropout: 0.1 + time_hidden_dim: 1024 + + flow: + _target_: nemo.collections.audio.parts.submodules.flow.OptimalTransportFlow + sigma_start: 1.0 + sigma_end: 1e-4 + + sampler: + _target_: nemo.collections.audio.parts.submodules.flow.ConditionalFlowMatchingEulerSampler + num_steps: 20 + time_min: 1e-8 + time_max: 1.0 + + loss: + _target_: nemo.collections.audio.losses.MSELoss + ndim: 4 # loss is calculated on the score in the encoded domain (batch, channel, dimension, time) + + metrics: + val: + sisdr: # output SI-SDR + _target_: torchmetrics.audio.ScaleInvariantSignalDistortionRatio + estoi: # output ESTOI + _target_: torchmetrics.audio.ShortTimeObjectiveIntelligibility + fs: ${model.sample_rate} + extended: true + pesq: # output PESQ + _target_: torchmetrics.audio.PerceptualEvaluationSpeechQuality + fs: ${model.sample_rate} + mode: wb + + optim: + name: adam + lr: 1e-4 + # optimizer arguments + betas: [0.9, 0.999] + weight_decay: 0.0 + + # scheduler setup + sched: + name: CosineAnnealing + # scheduler config override + warmup_steps: 5000 + warmup_ratio: null + min_lr: 0 + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: -1 + max_steps: -1 # computed at runtime if not set + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: ddp + accumulate_grad_batches: 1 + gradient_clip_val: 0.2 + precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. + log_every_n_steps: 25 # Interval of logging. + enable_progress_bar: true + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + enable_checkpointing: false # Provided by exp_manager + logger: false # Provided by exp_manager + +exp_manager: + exp_dir: null + name: ${name} + + # use exponential moving average for model parameters + ema: + enable: true + decay: 0.999 # decay rate + cpu_offload: false # offload EMA parameters to CPU to save GPU memory + every_n_steps: 1 # how often to update EMA weights + validate_original_weights: false # use original weights for validation calculation? 
+ + # logging + create_tensorboard_logger: true + + # checkpointing + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: val_pesq + mode: max + save_top_k: 3 + always_save_nemo: true # saves the checkpoints as nemo files instead of PTL checkpoints + + # early stopping + create_early_stopping_callback: true + early_stopping_callback_params: + monitor: val_sisdr + mode: max + min_delta: 0.0 + patience: 20 # patience in terms of check_val_every_n_epoch + verbose: true + strict: false # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. + + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + # you need to set these two to true to continue the training + resume_if_exists: false + resume_ignore_no_checkpoint: false + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: test + project: gense diff --git a/examples/audio/conf/flow_matching_generative_finetuning.yaml b/examples/audio/conf/flow_matching_generative_finetuning.yaml new file mode 100644 index 000000000000..c7ba19aee466 --- /dev/null +++ b/examples/audio/conf/flow_matching_generative_finetuning.yaml @@ -0,0 +1,167 @@ +name: flow_matching_generative_finetuning + +init_from_nemo_model: null +init_strict: false + +model: + type: flow_matching + sample_rate: 16000 + skip_nan_grad: false + num_outputs: 1 + p_cond: 0.9 # Proability of feeding the conditional input into the model. + normalize_input: true # normalize the input signal to 0dBFS + max_utts_evaluation_metrics: 500 + + train_ds: + manifest_filepath: ??? + input_key: noisy_filepath + target_key: clean_filepath + audio_duration: 6.14 # Number of STFT time frames = 1 + audio_duration // encoder.hop_length = 768 + random_offset: true + batch_size: 8 # batch size may be increased based on the available memory + shuffle: true + num_workers: 8 + pin_memory: true + + validation_ds: + manifest_filepath: ??? 
+ input_key: noisy_filepath + target_key: clean_filepath + batch_size: 8 + shuffle: false + num_workers: 4 + pin_memory: true + + log_config: + log_tensorboard: true + log_wandb: false + max_utts: 8 + + encoder: + _target_: nemo.collections.audio.modules.transforms.AudioToSpectrogram + fft_length: 510 # Number of subbands in the STFT = fft_length // 2 + 1 = 256 + hop_length: 128 + magnitude_power: 0.5 + scale: 0.33 + + decoder: + _target_: nemo.collections.audio.modules.transforms.SpectrogramToAudio + fft_length: ${model.encoder.fft_length} + hop_length: ${model.encoder.hop_length} + magnitude_power: ${model.encoder.magnitude_power} + scale: ${model.encoder.scale} + + estimator: + _target_: nemo.collections.audio.parts.submodules.transformerunet.SpectrogramTransformerUNet + in_channels: 2 # concatenation of single-channel perturbed and noisy + out_channels: 1 # single-channel score estimate + depth: 24 + ff_dropout: 0.1 + time_hidden_dim: 1024 + + flow: + _target_: nemo.collections.audio.parts.submodules.flow.OptimalTransportFlow + sigma_start: 1.0 + sigma_end: 1e-4 + + sampler: + _target_: nemo.collections.audio.parts.submodules.flow.ConditionalFlowMatchingEulerSampler + num_steps: 20 + time_min: 1e-8 + time_max: 1.0 + + loss: + _target_: nemo.collections.audio.losses.MSELoss + ndim: 4 # loss is calculated on the score in the encoded domain (batch, channel, dimension, time) + + metrics: + val: + sisdr: # output SI-SDR + _target_: torchmetrics.audio.ScaleInvariantSignalDistortionRatio + estoi: # output ESTOI + _target_: torchmetrics.audio.ShortTimeObjectiveIntelligibility + fs: ${model.sample_rate} + extended: true + pesq: # output PESQ + _target_: torchmetrics.audio.PerceptualEvaluationSpeechQuality + fs: ${model.sample_rate} + mode: wb + + optim: + name: adam + lr: 1e-4 + # optimizer arguments + betas: [0.9, 0.999] + weight_decay: 0.0 + + # scheduler setup + sched: + name: CosineAnnealing + # scheduler config override + warmup_steps: 5000 + warmup_ratio: null + min_lr: 0 + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: -1 + max_steps: -1 # computed at runtime if not set + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: ddp + accumulate_grad_batches: 1 + gradient_clip_val: 0.2 + precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. + log_every_n_steps: 25 # Interval of logging. + enable_progress_bar: true + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + enable_checkpointing: false # Provided by exp_manager + logger: false # Provided by exp_manager + +exp_manager: + exp_dir: null + name: ${name} + + # use exponential moving average for model parameters + ema: + enable: true + decay: 0.999 # decay rate + cpu_offload: false # offload EMA parameters to CPU to save GPU memory + every_n_steps: 1 # how often to update EMA weights + validate_original_weights: false # use original weights for validation calculation? 
+ + # logging + create_tensorboard_logger: true + + # checkpointing + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: val_pesq + mode: max + save_top_k: 3 + always_save_nemo: true # saves the checkpoints as nemo files instead of PTL checkpoints + + # early stopping + create_early_stopping_callback: true + early_stopping_callback_params: + monitor: val_sisdr + mode: max + min_delta: 0.0 + patience: 20 # patience in terms of check_val_every_n_epoch + verbose: true + strict: false # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. + + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + # you need to set these two to true to continue the training + resume_if_exists: false + resume_ignore_no_checkpoint: false + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: test + project: gense diff --git a/examples/audio/conf/flow_matching_generative_ssl_pretraining.yaml b/examples/audio/conf/flow_matching_generative_ssl_pretraining.yaml new file mode 100644 index 000000000000..7813a9473644 --- /dev/null +++ b/examples/audio/conf/flow_matching_generative_ssl_pretraining.yaml @@ -0,0 +1,171 @@ +name: flow_matching_generative_ssl_pretraining + +model: + type: flow_matching + sample_rate: 16000 + skip_nan_grad: true + num_outputs: 1 + p_cond: 0.9 # Proability of feeding the conditional input into the model. + normalize_input: true # normalize the input signal to 0dBFS + max_utts_evaluation_metrics: 125 + + train_ds: + shar_path: ??? + use_lhotse: true + truncate_duration: 4.09 # Number of STFT time frames = 1 + audio_duration // encoder.hop_length = 512 + truncate_offset_type: random + batch_size: 8 # batch size may be increased based on the available memory + shuffle: true + num_workers: 8 + pin_memory: true + + validation_ds: + manifest_filepath: ??? 
+ input_key: clean_filepath + target_key: clean_filepath + random_offset: false + batch_size: 8 + shuffle: false + num_workers: 4 + pin_memory: true + + log_config: + log_tensorboard: true + log_wandb: false + max_utts: 8 + + encoder: + _target_: nemo.collections.audio.modules.transforms.AudioToSpectrogram + fft_length: 510 # Number of subbands in the STFT = fft_length // 2 + 1 = 256 + hop_length: 128 + magnitude_power: 0.5 + scale: 0.33 + + decoder: + _target_: nemo.collections.audio.modules.transforms.SpectrogramToAudio + fft_length: ${model.encoder.fft_length} + hop_length: ${model.encoder.hop_length} + magnitude_power: ${model.encoder.magnitude_power} + scale: ${model.encoder.scale} + + estimator: + _target_: nemo.collections.audio.parts.submodules.transformerunet.SpectrogramTransformerUNet + in_channels: 2 # concatenation of single-channel perturbed and noisy + out_channels: 1 # single-channel score estimate + depth: 24 + ff_dropout: 0.1 + time_hidden_dim: 1024 + + flow: + _target_: nemo.collections.audio.parts.submodules.flow.OptimalTransportFlow + sigma_start: 1.0 + sigma_end: 1e-4 + + sampler: + _target_: nemo.collections.audio.parts.submodules.flow.ConditionalFlowMatchingEulerSampler + num_steps: 20 + time_min: 1e-8 + time_max: 1.0 + + ssl_pretrain_masking: + _target_: nemo.collections.audio.modules.ssl_pretrain_masking.SSLPretrainWithMaskedPatch + patch_size: 10 + mask_fraction: 0.7 + + loss: + _target_: nemo.collections.audio.losses.MSELoss + ndim: 4 # loss is calculated on the score in the encoded domain (batch, channel, dimension, time) + + metrics: + val: + sisdr: # output SI-SDR + _target_: torchmetrics.audio.ScaleInvariantSignalDistortionRatio + estoi: # output ESTOI + _target_: torchmetrics.audio.ShortTimeObjectiveIntelligibility + fs: ${model.sample_rate} + extended: true + pesq: # output PESQ + _target_: torchmetrics.audio.PerceptualEvaluationSpeechQuality + fs: ${model.sample_rate} + mode: wb + + optim: + name: adam + lr: 5e-5 + # optimizer arguments + betas: [0.9, 0.999] + weight_decay: 0.0 + + # scheduler setup + sched: + name: CosineAnnealing + # scheduler config override + warmup_steps: 5000 + warmup_ratio: null + min_lr: 1e-5 + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: -1 + max_steps: 10000 # needs to be set for shar datasets + limit_train_batches: 1000 # number of batches to train on in each pseudo-epoch + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: ddp + use_distributed_sampler: false # required for lhotse + accumulate_grad_batches: 1 + gradient_clip_val: 0.2 + precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. + log_every_n_steps: 25 # Interval of logging. 
+ enable_progress_bar: true + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + enable_checkpointing: false # Provided by exp_manager + logger: false # Provided by exp_manager + +exp_manager: + exp_dir: null + name: ${name} + + # use exponential moving average for model parameters + ema: + enable: true + decay: 0.999 # decay rate + cpu_offload: false # offload EMA parameters to CPU to save GPU memory + every_n_steps: 1 # how often to update EMA weights + validate_original_weights: false # use original weights for validation calculation? + + # logging + create_tensorboard_logger: true + + # checkpointing + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: val_pesq + mode: max + save_top_k: 3 + always_save_nemo: true # saves the checkpoints as nemo files instead of PTL checkpoints + + # early stopping + create_early_stopping_callback: true + early_stopping_callback_params: + monitor: val_sisdr + mode: max + min_delta: 0.0 + patience: 20 # patience in terms of check_val_every_n_epoch + verbose: true + strict: false # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. + + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + # you need to set these two to true to continue the training + resume_if_exists: false + resume_ignore_no_checkpoint: false + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null diff --git a/nemo/collections/audio/data/audio_to_audio_lhotse.py b/nemo/collections/audio/data/audio_to_audio_lhotse.py index 27d8a0ed28d7..d8978c19d692 100644 --- a/nemo/collections/audio/data/audio_to_audio_lhotse.py +++ b/nemo/collections/audio/data/audio_to_audio_lhotse.py @@ -44,19 +44,29 @@ class LhotseAudioToTargetDataset(torch.utils.data.Dataset): EMBEDDING_KEY = "embedding_vector" def __getitem__(self, cuts: CutSet) -> dict[str, torch.Tensor]: - src_audio, src_audio_lens = collate_audio(cuts) + # In the rare case, the collate_audio function would raise the FileSeek error when loading .flac (https://github.com/bastibe/python-soundfile/issues/274) + # A workaround is to use fault_tolerant and skip failed data, resulting in a smaller batch size for the few problematic cases. 
+ src_audio, src_audio_lens, retained_padded_cuts = collate_audio(cuts, fault_tolerant=True) ans = { "input_signal": src_audio, "input_length": src_audio_lens, } - if _key_available(cuts, self.TARGET_KEY): - tgt_audio, tgt_audio_lens = collate_audio(cuts, recording_field=self.TARGET_KEY) + # keep only the first non-padding cuts + retained_cuts = [ + cut._first_non_padding_cut if isinstance(cut, MixedCut) else cut for cut in retained_padded_cuts + ] + retained_cuts = CutSet.from_cuts(retained_cuts) + + if _key_available(retained_cuts, self.TARGET_KEY): + # TODO: use fault_tolerant=True for robust loading of target + tgt_audio, tgt_audio_lens = collate_audio(retained_cuts, recording_field=self.TARGET_KEY) ans.update(target_signal=tgt_audio, target_length=tgt_audio_lens) - if _key_available(cuts, self.REFERENCE_KEY): - ref_audio, ref_audio_lens = collate_audio(cuts, recording_field=self.REFERENCE_KEY) + if _key_available(retained_cuts, self.REFERENCE_KEY): + # TODO: use fault_tolerant=True for robust loading of target + ref_audio, ref_audio_lens = collate_audio(retained_cuts, recording_field=self.REFERENCE_KEY) ans.update(reference_signal=ref_audio, reference_length=ref_audio_lens) if _key_available(cuts, self.EMBEDDING_KEY): - emb = collate_custom_field(cuts, field=self.EMBEDDING_KEY) + emb = collate_custom_field(retained_cuts, field=self.EMBEDDING_KEY) ans.update(embedding_signal=emb) return ans diff --git a/nemo/collections/audio/models/audio_to_audio.py b/nemo/collections/audio/models/audio_to_audio.py index ef9ce648f1a2..e1732c1658b7 100644 --- a/nemo/collections/audio/models/audio_to_audio.py +++ b/nemo/collections/audio/models/audio_to_audio.py @@ -483,4 +483,35 @@ def on_after_backward(self): if valid_gradients < 1: logging.warning('detected inf or nan values in gradients! Setting gradients to zero.') - self.zero_grad() + self.zero_grad(set_to_none=False) + + def configure_callbacks(self): + """ + Create an callback to add audio/spectrogram into tensorboard & wandb. + """ + self.log_config = self.cfg.get("log_config", None) + if not self.log_config: + return [] + + log_callbacks = [] + from nemo.collections.audio.parts.utils.callbacks import SpeechEnhancementLoggingCallback + + if isinstance(self._validation_dl, List): + data_loaders = self._validation_dl + else: + data_loaders = [self._validation_dl] + + for data_loader_idx, data_loader in enumerate(data_loaders): + log_callbacks.append( + SpeechEnhancementLoggingCallback( + data_loader=data_loader, + data_loader_idx=data_loader_idx, + loggers=self.trainer.loggers, + log_tensorboard=self.log_config.log_tensorboard, + log_wandb=self.log_config.log_wandb, + sample_rate=self.sample_rate, + max_utts=self.log_config.get("max_utts", None), + ) + ) + + return log_callbacks diff --git a/nemo/collections/audio/models/enhancement.py b/nemo/collections/audio/models/enhancement.py index e7fbc9023117..cd9f47b98096 100644 --- a/nemo/collections/audio/models/enhancement.py +++ b/nemo/collections/audio/models/enhancement.py @@ -30,6 +30,7 @@ 'ScoreBasedGenerativeAudioToAudioModel', 'PredictiveAudioToAudioModel', 'SchroedingerBridgeAudioToAudioModel', + 'FlowMatchingAudioToAudioModel', ] @@ -618,6 +619,274 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = return {f'{tag}_loss': loss} +class FlowMatchingAudioToAudioModel(AudioToAudioModel): + """This models uses a flow matching process to generate + an encoded representation of the enhanced signal. 
+ + The model consists of the following blocks: + - encoder: transforms input multi-channel audio signal into an encoded representation (analysis transform) + - estimator: neural model, estimates a score for the diffusion process + - flow: ordinary differential equation (ODE) defining a flow and a vector field. + - sampler: sampler for the inference process, estimates coefficients of the target signal + - decoder: transforms sampler output into the time domain (synthesis transform) + - ssl_pretrain_masking: if it is defined, perform the ssl pretrain masking for self reconstruction in the training process + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + super().__init__(cfg=cfg, trainer=trainer) + self.sample_rate = self._cfg.sample_rate + + # Setup processing modules + self.encoder = self.from_config_dict(self._cfg.encoder) + self.decoder = self.from_config_dict(self._cfg.decoder) + + # Neural estimator + self.estimator = self.from_config_dict(self._cfg.estimator) + + # Flow + self.flow = self.from_config_dict(self._cfg.flow) + + # Sampler + self.sampler = hydra.utils.instantiate(self._cfg.sampler, estimator=self.estimator) + + # probability that the conditional input will be feed into the + # estimator in the training stage + self.p_cond = self._cfg.get('p_cond', 1.0) + + # Self-Supervised Pretraining + if self._cfg.get('ssl_pretrain_masking') is not None: + logging.debug('SSL-pretrain_masking is found and will be initialized') + self.ssl_pretrain_masking = self.from_config_dict(self._cfg.ssl_pretrain_masking) + else: + self.ssl_pretrain_masking = None + + # Normalization + self.normalize_input = self._cfg.get('normalize_input', False) + + # Metric evaluation + self.max_utts_evaluation_metrics = self._cfg.get('max_utts_evaluation_metrics') + + if self.max_utts_evaluation_metrics is not None: + logging.warning( + 'Metrics will be evaluated on first %d examples of the evaluation datasets.', + self.max_utts_evaluation_metrics, + ) + + # Regularization + self.eps = self._cfg.get('eps', 1e-8) + + # Setup optional Optimization flags + self.setup_optimization_flags() + + logging.debug('Initialized %s', self.__class__.__name__) + logging.debug('\tdoing SSL-pretraining: %s', (self.ssl_pretrain_masking is not None)) + logging.debug('\tp_cond: %s', self.p_cond) + logging.debug('\tnormalize_input: %s', self.normalize_input) + logging.debug('\tloss: %s', self.loss) + logging.debug('\teps: %s', self.eps) + + @property + def input_types(self) -> Dict[str, NeuralType]: + return { + "input_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)), + "input_length": NeuralType(tuple('B'), LengthsType(), optional=True), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + return { + "output_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)), + "output_length": NeuralType(tuple('B'), LengthsType(), optional=True), + } + + @typecheck() + @torch.inference_mode() + def forward(self, input_signal, input_length=None): + """Forward pass of the model to generate samples from the target distribution. + + Args: + input_signal: Tensor that represents a batch of raw audio signals, + of shape [B, T] or [B, T, C]. T here represents timesteps, with 1 second of audio represented as + `self.sample_rate` number of floating point values. + input_signal_length: Vector of length B, that contains the individual lengths of the audio + sequences. 
+ + Returns: + Output signal `output` in the time domain and the length of the output signal `output_length`. + """ + batch_length = input_signal.size(-1) + + if self.normalize_input: + # max for each example in the batch + norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True) + # scale input signal + input_signal = input_signal / (norm_scale + self.eps) + + # Encoder + encoded, encoded_length = self.encoder(input=input_signal, input_length=input_length) + + if self.p_cond == 0: + encoded = torch.zeros_like(encoded) + elif self.ssl_pretrain_masking is not None: + encoded = self.ssl_pretrain_masking(input_spec=encoded, length=encoded_length) + + init_state = torch.randn_like(encoded) * self.flow.sigma_start + + # Sampler + generated, generated_length = self.sampler( + state=init_state, estimator_condition=encoded, state_length=encoded_length + ) + + # Decoder + output, output_length = self.decoder(input=generated, input_length=generated_length) + + if self.normalize_input: + # rescale to the original scale + output = output * norm_scale + + # Trim or pad the estimated signal to match input length + output = self.match_batch_length(input=output, batch_length=batch_length) + + return output, output_length + + @typecheck( + input_types={ + "target_signal": NeuralType(('B', 'C', 'T'), AudioSignal()), + "input_signal": NeuralType(('B', 'C', 'T'), AudioSignal()), + "input_length": NeuralType(tuple('B'), LengthsType()), + }, + output_types={ + "loss": NeuralType(None, LossType()), + }, + ) + def _step(self, target_signal, input_signal, input_length=None): + batch_size = target_signal.size(0) + + if self.normalize_input: + # max for each example in the batch + norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True) + # scale input signal + input_signal = input_signal / (norm_scale + self.eps) + # scale the target signal + target_signal = target_signal / (norm_scale + self.eps) + + # Apply encoder to both target and the input + input_enc, input_enc_len = self.encoder(input=input_signal, input_length=input_length) + target_enc, _ = self.encoder(input=target_signal, input_length=input_length) + + # Self-Supervised Pretraining + if self.ssl_pretrain_masking is not None: + input_enc = self.ssl_pretrain_masking(input_spec=input_enc, length=input_enc_len) + + # Drop off conditional inputs (input_enc) with (1 - p_cond) probability. 
+ # The dropped conditions will be set to zeros + keep_conditions = einops.rearrange((torch.rand(batch_size) < self.p_cond).float(), 'B -> B 1 1 1') + input_enc = input_enc * keep_conditions.to(input_enc.device) + + x_start = torch.zeros_like(input_enc) + + time = self.flow.generate_time(batch_size=batch_size).to(device=input_enc.device) + sample = self.flow.sample(time=time, x_start=x_start, x_end=target_enc) + + # we want to get a vector field estimate given current state + # at training time, current state is sampled from the conditional path + # the vector field model is also conditioned on input signal + estimator_input = torch.cat([sample, input_enc], dim=-3) + + # Estimate the vector using the neural estimator + estimate, estimate_len = self.estimator(input=estimator_input, input_length=input_enc_len, condition=time) + + conditional_vector_field = self.flow.vector_field(time=time, x_start=x_start, x_end=target_enc, point=sample) + + return self.loss(estimate=estimate, target=conditional_vector_field, input_length=input_enc_len) + + # PTL-specific methods + def training_step(self, batch, batch_idx): + if isinstance(batch, dict): + # lhotse batches are dictionaries + input_signal = batch['input_signal'] + input_length = batch['input_length'] + target_signal = batch.get('target_signal', input_signal.clone()) + else: + input_signal, input_length, target_signal, _ = batch + + # For consistency, the model uses multi-channel format, even if the channel dimension is 1 + if input_signal.ndim == 2: + input_signal = einops.rearrange(input_signal, "B T -> B 1 T") + if target_signal.ndim == 2: + target_signal = einops.rearrange(target_signal, "B T -> B 1 T") + + # Calculate the loss + loss = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) + + # Logs + self.log('train_loss', loss) + self.log('learning_rate', self._optimizer.param_groups[0]['lr']) + self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) + + return loss + + def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): + + if isinstance(batch, dict): + # lhotse batches are dictionaries + input_signal = batch['input_signal'] + input_length = batch['input_length'] + target_signal = batch.get('target_signal', input_signal.clone()) + else: + input_signal, input_length, target_signal, _ = batch + + # For consistency, the model uses multi-channel format, even if the channel dimension is 1 + if input_signal.ndim == 2: + input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') + if target_signal.ndim == 2: + target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + + # Calculate loss + loss = self._step( + target_signal=target_signal, + input_signal=input_signal, + input_length=input_length, + ) + + # Update metrics + update_metrics = False + if self.max_utts_evaluation_metrics is None: + # Always update if max is not configured + update_metrics = True + # Number of examples to process + num_examples = input_signal.size(0) # batch size + else: + # Check how many examples have been used for metric calculation + first_metric_name = next(iter(self.metrics[tag][dataloader_idx])) + num_examples_evaluated = self.metrics[tag][dataloader_idx][first_metric_name].num_examples + # Update metrics if some examples were not processed + update_metrics = num_examples_evaluated < self.max_utts_evaluation_metrics + # Number of examples to process + num_examples = min(self.max_utts_evaluation_metrics - num_examples_evaluated, 
input_signal.size(0)) + + if update_metrics: + # Generate output signal + output_signal, _ = self.forward( + input_signal=input_signal[:num_examples, ...], input_length=input_length[:num_examples] + ) + + # Update metrics + if hasattr(self, 'metrics') and tag in self.metrics: + # Update metrics for this (tag, dataloader_idx) + for name, metric in self.metrics[tag][dataloader_idx].items(): + metric.update( + preds=output_signal, + target=target_signal[:num_examples, ...], + input_length=input_length[:num_examples], + ) + + # Log global step + self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) + + return {f'{tag}_loss': loss} + + class SchroedingerBridgeAudioToAudioModel(AudioToAudioModel): """This models is using a Schrödinger Bridge process to generate an encoded representation of the enhanced signal. diff --git a/nemo/collections/audio/modules/ssl_pretrain_masking.py b/nemo/collections/audio/modules/ssl_pretrain_masking.py new file mode 100644 index 000000000000..ba0722f180d8 --- /dev/null +++ b/nemo/collections/audio/modules/ssl_pretrain_masking.py @@ -0,0 +1,106 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +import einops +import torch + +from nemo.core.classes import NeuralModule, typecheck +from nemo.core.neural_types import LengthsType, NeuralType, SpectrogramType + +__all__ = ['SSLPretrainWithMaskedPatch'] + + +class SSLPretrainWithMaskedPatch(NeuralModule): + """ + Zeroes out fixed size time patches of the spectrogram. + All samples in batch are guaranteed to have the same amount of masked time steps. + Note that this may be problematic when we do pretraining on a unbalanced dataset. + + For example, say a batch contains two spectrograms of length 87 and 276. + With mask_fraction=0.7 and patch_size=10, we'll obrain mask_patches=7. + Each of the two data will then have 7 patches of 10-frame mask. + + Args: + patch_size (int): up to how many time steps does one patch consist of. + Defaults to 10. + mask_fraction (float): how much fraction in each sample to be masked (number of patches is rounded up). + Range from 0.0 to 1.0. Defaults to 0.7. + """ + + @property + def input_types(self): + """Returns definitions of module input types""" + return { + "input_spec": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + "length": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + """Returns definitions of module output types""" + return {"augmented_spec": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType())} + + def __init__( + self, + patch_size: int = 10, + mask_fraction: float = 0.7, + ): + super().__init__() + self.patch_size = patch_size + if mask_fraction > 1.0 or mask_fraction < 0.0: + raise ValueError('mask_patches cannot be negative') + else: + self.mask_fraction = mask_fraction + + @typecheck() + def forward(self, input_spec, length): + """ + Apply Patched masking on the input_spec. 
+ + + During the training stage, the mask is generated randomly, with + approximately `self.mask_fraction` of the time frames being masked out. + + In the validation stage, the masking pattern is fixed to ensure + consistent evaluation of checkpoints and to prevent overfitting. Note + that the same masking pattern is applied to all data, regardless of + their lengths. On average, approximately `self.mask_fraction` of the + time frames will be masked out. + + """ + augmented_spec = input_spec + + min_len = torch.min(length) + if self.training: + len_fraction = int(min_len * self.mask_fraction) + mask_patches = len_fraction // self.patch_size + int(len_fraction % self.patch_size != 0) + + if min_len < self.patch_size * mask_patches: + mask_patches = min_len // self.patch_size + + for idx, cur_len in enumerate(length.tolist()): + patches = range(cur_len // self.patch_size) + masked_patches = random.sample(patches, mask_patches) + for mp in masked_patches: + augmented_spec[idx, :, :, mp * self.patch_size : (mp + 1) * self.patch_size] = 0.0 + else: + chunk_length = self.patch_size // self.mask_fraction + mask = torch.arange(augmented_spec.size(-1), device=augmented_spec.device) + mask = (mask % chunk_length) >= self.patch_size + mask = einops.rearrange(mask, 'T -> 1 1 1 T').float() + augmented_spec = augmented_spec * mask + + return augmented_spec diff --git a/nemo/collections/audio/parts/submodules/flow.py b/nemo/collections/audio/parts/submodules/flow.py new file mode 100644 index 000000000000..748d4c6c6d3b --- /dev/null +++ b/nemo/collections/audio/parts/submodules/flow.py @@ -0,0 +1,252 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
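As a quick illustration of the two masking modes implemented by SSLPretrainWithMaskedPatch above, a minimal sketch (using only the module added in this PR; tensor sizes are made up) could look like:

import torch
from nemo.collections.audio.modules.ssl_pretrain_masking import SSLPretrainWithMaskedPatch

masking = SSLPretrainWithMaskedPatch(patch_size=10, mask_fraction=0.7)
spec = torch.randn(2, 1, 256, 120)  # (batch, channel, subbands, time frames)
length = torch.tensor([120, 87])    # valid frames per example

masking.train()                     # random patches; every example gets the same number of masked patches
masked_train = masking(input_spec=spec, length=length)

masking.eval()                      # fixed periodic pattern, independent of example length
masked_eval = masking(input_spec=spec, length=length)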
+from abc import ABC, abstractmethod +from typing import Tuple + +import einops +import torch + +from nemo.collections.common.parts.utils import mask_sequence_tensor +from nemo.utils import logging + + +class ConditionalFlow(ABC): + """ + Abstract class for different conditional flow-matching (CFM) classes + + Time horizon is [time_min, time_max (should be 1)] + + every path is "conditioned" on endpoints of the path + endpoints are just our paired data samples + subclasses need to implement mean, std, and vector_field + + """ + + def __init__(self, time_min: float = 1e-8, time_max: float = 1.0): + self.time_min = time_min + self.time_max = time_max + + @abstractmethod + def mean(self, *, time: torch.Tensor, x_start: torch.Tensor, x_end: torch.Tensor) -> torch.Tensor: + """ + Return the mean of p_t(x | x_start, x_end) at time t + """ + pass + + @abstractmethod + def std(self, *, time: torch.Tensor, x_start: torch.Tensor, x_end: torch.Tensor) -> torch.Tensor: + """ + Return the standard deviation of p_t(x | x_start, x_end) at time t + """ + pass + + @abstractmethod + def vector_field( + self, *, time: torch.Tensor, x_start: torch.Tensor, x_end: torch.Tensor, point: torch.Tensor + ) -> torch.Tensor: + """ + Compute the conditional vector field v_t( point | x_start, x_end) + """ + pass + + @staticmethod + def _broadcast_time(time: torch.Tensor, n_dim: int) -> torch.Tensor: + """ + Broadcast time tensor to the desired number of dimensions + """ + if time.ndim == 1: + target_shape = ' '.join(['B'] + ['1'] * (n_dim - 1)) + time = einops.rearrange(time, f'B -> {target_shape}') + + return time + + def generate_time(self, batch_size: int) -> torch.Tensor: + """ + Randomly sample a batchsize of time_steps from U[0~1] + """ + return torch.clamp(torch.rand((batch_size,)), self.time_min, self.time_max) + + def sample(self, *, time: torch.Tensor, x_start: torch.Tensor, x_end: torch.Tensor) -> torch.Tensor: + """ + Generate a sample from p_t(x | x_start, x_end) at time t. + Note that this implementation assumes all path marginals are normally distributed. + """ + time = self._broadcast_time(time, n_dim=x_start.ndim) + + mean = self.mean(time=time, x_start=x_start, x_end=x_end) + std = self.std(time=time, x_start=x_start, x_end=x_end) + return mean + std * torch.randn_like(mean) + + def flow( + self, *, time: torch.Tensor, x_start: torch.Tensor, x_end: torch.Tensor, point: torch.Tensor + ) -> torch.Tensor: + """ + Compute the conditional flow phi_t( point | x_start, x_end). + This is an affine flow. 
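+        Concretely, phi_t(point) = mean(t) + std(t) * (point - x_start).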
+ """ + mean = self.mean(time=time, x_start=x_start, x_end=x_end) + std = self.std(time=time, x_start=x_start, x_end=x_end) + return mean + std * (point - x_start) + + +class OptimalTransportFlow(ConditionalFlow): + """The OT-CFM model from [Lipman et at, 2023] + + Every conditional path the following holds: + p_0 = N(x_start, sigma_start) + p_1 = N(x_end, sigma_end), + + mean(x, t) = (time_max - t) * x_start + t * x_end + (linear interpolation between x_start and x_end) + + std(x, t) = (time_max - t) * sigma_start + t * sigma_end + + Every conditional path is optimal transport map from p_0(x_start, x_end) to p_1(x_start, x_end) + Marginal path is not guaranteed to be an optimal transport map from p_0 to p_1 + + To get the OT-CFM model from [Lipman et at, 2023] just pass zeroes for x_start + To get the I-CFM model, set sigma_min=sigma_max + To get the rectified flow model, set sigma_min=sigma_max=0 + + Args: + time_min: minimum time value used in the process + time_max: maximum time value used in the process + sigma_start: the standard deviation of the initial distribution + sigma_end: the standard deviation of the target distribution + """ + + def __init__( + self, time_min: float = 1e-8, time_max: float = 1.0, sigma_start: float = 1.0, sigma_end: float = 1e-4 + ): + super().__init__(time_min=time_min, time_max=time_max) + self.sigma_start = sigma_start + self.sigma_end = sigma_end + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\ttime_min: %s', self.time_min) + logging.debug('\ttime_max: %s', self.time_max) + logging.debug('\tsgima_start: %s', self.sigma_start) + logging.debug('\tsigma_end: %s', self.sigma_end) + + def mean(self, *, x_start: torch.Tensor, x_end: torch.Tensor, time: torch.Tensor) -> torch.Tensor: + return (self.time_max - time) * x_start + time * x_end + + def std(self, *, x_start: torch.Tensor, x_end: torch.Tensor, time: torch.Tensor) -> torch.Tensor: + return (self.time_max - time) * self.sigma_start + time * self.sigma_end + + def vector_field( + self, + *, + x_start: torch.Tensor, + x_end: torch.Tensor, + time: torch.Tensor, + point: torch.Tensor, + eps: float = 1e-6, + ) -> torch.Tensor: + time = self._broadcast_time(time, n_dim=x_start.ndim) + + if self.sigma_start == self.sigma_end: + return x_end - x_start + + num = self.sigma_end * (point - x_start) - self.sigma_start * (point - x_end) + denom = (1 - time) * self.sigma_start + time * self.sigma_end + return num / (denom + eps) + + +class ConditionalFlowMatchingSampler(ABC): + """ + Abstract class for different sampler to solve the ODE in CFM + + Args: + estimator: the NN-based conditional vector field estimator + num_steps: How many time steps to iterate in the process + time_min: minimum time value used in the process + time_max: maximum time value used in the process + + """ + + def __init__( + self, + estimator: torch.nn.Module, + num_steps: int = 5, + time_min: float = 1e-8, + time_max: float = 1.0, + ): + self.estimator = estimator + self.num_steps = num_steps + self.time_min = time_min + self.time_max = time_max + + @property + def time_step(self): + return (self.time_max - self.time_min) / self.num_steps + + @abstractmethod + def forward( + self, state: torch.Tensor, estimator_condition: torch.Tensor, state_length: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + pass + + +class ConditionalFlowMatchingEulerSampler(ConditionalFlowMatchingSampler): + """ + The Euler Sampler for solving the ODE in CFM on a uniform time grid + """ + + def __init__( + self, + 
estimator: torch.nn.Module, + num_steps: int = 5, + time_min: float = 1e-8, + time_max: float = 1.0, + ): + super().__init__( + estimator=estimator, + num_steps=num_steps, + time_min=time_min, + time_max=time_max, + ) + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tnum_steps: %s', self.num_steps) + logging.debug('\ttime_min: %s', self.time_min) + logging.debug('\ttime_max: %s', self.time_max) + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + @torch.inference_mode() + def forward( + self, state: torch.Tensor, estimator_condition: torch.Tensor, state_length: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + time_steps = torch.linspace(self.time_min, self.time_max, self.num_steps + 1) + + if state_length is not None: + state = mask_sequence_tensor(state, state_length) + + for t in time_steps: + time = t * torch.ones(state.shape[0], device=state.device) + + if estimator_condition is None: + estimator_input = state + else: + estimator_input = torch.cat([state, estimator_condition], dim=1) + + vector_field, _ = self.estimator(input=estimator_input, input_length=state_length, condition=time) + + state = state + vector_field * self.time_step + + if state_length is not None: + state = mask_sequence_tensor(state, state_length) + + return state, state_length diff --git a/nemo/collections/audio/parts/submodules/transformerunet.py b/nemo/collections/audio/parts/submodules/transformerunet.py new file mode 100644 index 000000000000..b7c14d513bab --- /dev/null +++ b/nemo/collections/audio/parts/submodules/transformerunet.py @@ -0,0 +1,507 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# MIT License +# +# Copyright (c) 2023 Phil Wang +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
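To see how the flow and the Euler sampler defined in flow.py fit together, here is a minimal sketch mirroring what FlowMatchingAudioToAudioModel._step and forward do with them (tensor sizes are made up):

import torch
from nemo.collections.audio.parts.submodules.flow import OptimalTransportFlow

flow = OptimalTransportFlow(sigma_start=1.0, sigma_end=1e-4)

# stand-ins for encoded (batch, channel, subbands, frames) spectrograms
x_start = torch.zeros(4, 1, 256, 64)     # conditional paths start from zeros, as in _step()
target_enc = torch.randn(4, 1, 256, 64)  # encoded target (clean) signal

time = flow.generate_time(batch_size=4)  # random times in [time_min, time_max]
sample = flow.sample(time=time, x_start=x_start, x_end=target_enc)
v_target = flow.vector_field(time=time, x_start=x_start, x_end=target_enc, point=sample)
# the configured MSELoss regresses the estimator output onto v_target;
# at inference, ConditionalFlowMatchingEulerSampler starts from
# sigma_start * randn_like(encoded) and takes num_steps uniform Euler steps.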
+ +import math +from functools import partial +from typing import Dict, Optional + +import einops +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import Module + +from nemo.core.classes import NeuralModule, typecheck +from nemo.core.neural_types import BoolType, FloatType, LengthsType, NeuralType, SpectrogramType +from nemo.utils import logging + +__all__ = ['TransformerUNet'] + + +class LearnedSinusoidalPosEmb(Module): + """The sinusoidal Embedding to encode time conditional information""" + + def __init__(self, dim: int): + super().__init__() + if (dim % 2) != 0: + raise ValueError(f"Input dimension {dim} is not divisible by 2!") + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, t: torch.Tensor) -> torch.Tensor: + """ + Args: + t: input time tensor, shape (B) + + Return: + fouriered: the encoded time conditional embedding, shape (B, D) + """ + t = einops.rearrange(t, 'b -> b 1') + freqs = t * einops.rearrange(self.weights, 'd -> 1 d') * 2 * math.pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + return fouriered + + +class ConvPositionEmbed(Module): + """The Convolutional Embedding to encode time information of each frame""" + + def __init__(self, dim: int, kernel_size: int, groups: Optional[int] = None): + super().__init__() + if (kernel_size % 2) == 0: + raise ValueError(f"Kernel size {kernel_size} is divisible by 2!") + + if groups is None: + groups = dim + + self.dw_conv1d = nn.Sequential( + nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), nn.GELU() + ) + + def forward(self, x, mask=None): + """ + Args: + x: input tensor, shape (B, T, D) + + Return: + out: output tensor with the same shape (B, T, D) + """ + + if mask is not None: + mask = mask[..., None] + x = x.masked_fill(mask, 0.0) + + x = einops.rearrange(x, 'b n c -> b c n') + x = self.dw_conv1d(x) + out = einops.rearrange(x, 'b c n -> b n c') + + if mask is not None: + out = out.masked_fill(mask, 0.0) + + return out + + +class RMSNorm(Module): + """The Root Mean Square Layer Normalization + + References: + - Zhang et al., Root Mean Square Layer Normalization, 2019 + """ + + def __init__(self, dim): + super().__init__() + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(dim)) + + def forward(self, x: torch.Tensor): + return F.normalize(x, dim=-1) * self.scale * self.gamma + + +class AdaptiveRMSNorm(Module): + """ + Adaptive Root Mean Square Layer Normalization given a conditional embedding. + This enables the model to consider the conditional input during normalization. 
+ """ + + def __init__(self, dim: int, cond_dim: Optional[int] = None): + super().__init__() + if cond_dim is None: + cond_dim = dim + self.scale = dim**0.5 + + self.to_gamma = nn.Linear(cond_dim, dim) + self.to_beta = nn.Linear(cond_dim, dim) + + # init adaptive normalization to identity + + nn.init.zeros_(self.to_gamma.weight) + nn.init.ones_(self.to_gamma.bias) + + nn.init.zeros_(self.to_beta.weight) + nn.init.zeros_(self.to_beta.bias) + + def forward(self, x: torch.Tensor, cond: torch.Tensor): + normed = F.normalize(x, dim=-1) * self.scale + + gamma, beta = self.to_gamma(cond), self.to_beta(cond) + gamma = einops.rearrange(gamma, 'B D -> B 1 D') + beta = einops.rearrange(beta, 'B D -> B 1 D') + + return normed * gamma + beta + + +class GEGLU(Module): + """The GeGLU activation implementation""" + + def forward(self, x: torch.Tensor): + x, gate = x.chunk(2, dim=-1) + return F.gelu(gate) * x + + +def get_feedforward_layer(dim: int, mult: int = 4, dropout: float = 0.0): + """ + Return a Feed-Forward layer for the Transformer Layer. + GeGLU activation is used in this FF layer + """ + dim_inner = int(dim * mult * 2 / 3) + return nn.Sequential(nn.Linear(dim, dim_inner * 2), GEGLU(), nn.Dropout(dropout), nn.Linear(dim_inner, dim)) + + +class TransformerUNet(NeuralModule): + """ + Implementation of the transformer Encoder Model with U-Net structure used in + VoiceBox and AudioBox + + References: + Le et al., Voicebox: Text-Guided Multilingual Universal Speech Generation at Scale, 2023 + Vyas et al., Audiobox: Unified Audio Generation with Natural Language Prompts, 2023 + """ + + def __init__( + self, + dim: int, + depth: int, + heads: int = 8, + ff_mult: int = 4, + attn_dropout: float = 0.0, + ff_dropout: float = 0.0, + max_positions: int = 6000, + adaptive_rmsnorm: bool = False, + adaptive_rmsnorm_cond_dim_in: Optional[int] = None, + use_unet_skip_connection: bool = True, + skip_connect_scale: Optional[int] = None, + ): + """ + Args: + dim: Embedding dimension + depth: Number of Transformer Encoder Layers + heads: Number of heads in MHA + ff_mult: The multiplier for the feedforward dimension (ff_dim = ff_mult * dim) + attn_dropout: dropout rate for the MHA layer + ff_dropout: droupout rate for the feedforward layer + max_positions: The maximum time length of the input during training and inference + adaptive_rmsnorm: Whether to use AdaptiveRMS layer. + Set to True if the model has a conditional embedding in forward() + adaptive_rms_cond_dim_in: Dimension of the conditional embedding + use_unet_skip_connection: Whether to use U-Net or not + skip_connect_scale: The scale of the U-Net connection. 
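+                Defaults to 2**-0.5 when not provided.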
+ """ + super().__init__() + if (depth % 2) != 0: + raise ValueError(f"Number of layers {depth} is not divisible by 2!") + self.layers = nn.ModuleList([]) + self.init_alibi(max_positions=max_positions, heads=heads) + + if adaptive_rmsnorm: + rmsnorm_class = partial(AdaptiveRMSNorm, cond_dim=adaptive_rmsnorm_cond_dim_in) + else: + rmsnorm_class = RMSNorm + + if skip_connect_scale is None: + self.skip_connect_scale = 2**-0.5 + else: + self.skip_connect_scale = skip_connect_scale + + for ind in range(depth): + layer = ind + 1 + has_skip = use_unet_skip_connection and layer > (depth // 2) + + self.layers.append( + nn.ModuleList( + [ + nn.Linear(dim * 2, dim) if has_skip else None, + rmsnorm_class(dim=dim), + nn.MultiheadAttention( + embed_dim=dim, + num_heads=heads, + dropout=attn_dropout, + batch_first=True, + ), + rmsnorm_class(dim=dim), + get_feedforward_layer(dim=dim, mult=ff_mult, dropout=ff_dropout), + ] + ) + ) + + self.final_norm = RMSNorm(dim) + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tembedding dim: %s', dim) + logging.debug('\tNumber of Layer: %s', depth) + logging.debug('\tfeedforward dim: %s', dim * ff_mult) + logging.debug('\tnumber of heads: %s', heads) + logging.debug('\tDropout rate of MHA: %s', attn_dropout) + logging.debug('\tDropout rate of FF: %s', ff_dropout) + logging.debug('\tnumber of heads: %s', heads) + logging.debug('\tmaximun time length: %s', max_positions) + logging.debug('\tuse AdaptiveRMS: %s', adaptive_rmsnorm) + logging.debug('\tConditional dim: %s', adaptive_rmsnorm_cond_dim_in) + logging.debug('\tUse UNet connection: %s', use_unet_skip_connection) + logging.debug('\tskip connect scale: %s', self.skip_connect_scale) + + def init_alibi( + self, + max_positions: int, + heads: int, + ): + """Initialize the Alibi bias parameters + + References: + - Press et al., Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation, 2021 + """ + + def get_slopes(n): + ratio = 2 ** (-8 / n) + return ratio ** torch.arange(1, n + 1) + + if not math.log2(heads).is_integer(): + logging.warning( + "It is recommend to set number of attention heads to be the power of 2 for the Alibi bias!" + ) + logging.warning(f"Current value of heads: {heads}") + + self.slopes = nn.Parameter(einops.rearrange(get_slopes(heads), "B -> B 1 1")) + + pos_matrix = ( + -1 * torch.abs(torch.arange(max_positions).unsqueeze(0) - torch.arange(max_positions).unsqueeze(1)).float() + ) + pos_matrix = einops.rearrange(pos_matrix, "T1 T2 -> 1 T1 T2") + self.register_buffer('pos_matrix', pos_matrix, persistent=False) + + @property + def input_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports.""" + return { + "x": NeuralType(('B', 'T', 'D'), FloatType()), + "key_padding_mask": NeuralType(('B', 'T'), BoolType(), optional=True), + "adaptive_rmsnorm_cond": NeuralType(('B', 'D'), FloatType(), optional=True), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports.""" + return { + "output": NeuralType(('B', 'T', 'D'), FloatType()), + } + + @typecheck() + def forward(self, x, key_padding_mask: Optional[torch.Tensor] = None, adaptive_rmsnorm_cond=None): + """Forward pass of the model. 
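+        # Learnable per-head slopes (2 ** (-8 / heads)) ** k for k = 1..heads; multiplied with the
+        # matrix of negative absolute position distances to form the additive ALiBi attention bias.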
+ + Args: + input: input tensor, shape (B, C, D, T) + key_padding_mask: mask tensor indicating the padding parts, shape (B, T) + adaptive_rmsnorm_cond: conditional input for the model, shape (B, D) + """ + batch_size, seq_len, *_ = x.shape + skip_connects = [] + alibi_bias = self.get_alibi_bias(batch_size=batch_size, seq_len=seq_len) + + rmsnorm_kwargs = dict() + if adaptive_rmsnorm_cond is not None: + rmsnorm_kwargs = dict(cond=adaptive_rmsnorm_cond) + + for skip_combiner, attn_prenorm, attn, ff_prenorm, ff in self.layers: + + if skip_combiner is None: + skip_connects.append(x) + else: + skip_connect = skip_connects.pop() * self.skip_connect_scale + x = torch.cat((x, skip_connect), dim=-1) + x = skip_combiner(x) + + attn_input = attn_prenorm(x, **rmsnorm_kwargs) + if key_padding_mask is not None: + # Since Alibi_bias is a float-type attn_mask, the padding_mask need to be float-type. + float_key_padding_mask = key_padding_mask.float() + float_key_padding_mask = float_key_padding_mask.masked_fill(key_padding_mask, float('-inf')) + else: + float_key_padding_mask = None + + attn_output, _ = attn( + query=attn_input, + key=attn_input, + value=attn_input, + key_padding_mask=float_key_padding_mask, + need_weights=False, + attn_mask=alibi_bias, + ) + x = x + attn_output + + ff_input = ff_prenorm(x, **rmsnorm_kwargs) + x = ff(ff_input) + x + + return self.final_norm(x) + + def get_alibi_bias(self, batch_size: int, seq_len: int): + """ + Return the alibi_bias given batch size and seqence length + """ + pos_matrix = self.pos_matrix[:, :seq_len, :seq_len] + alibi_bias = pos_matrix * self.slopes + alibi_bias = alibi_bias.repeat(batch_size, 1, 1) + + return alibi_bias + + +class SpectrogramTransformerUNet(NeuralModule): + """This model handles complex-valued inputs by stacking real and imaginary components. + Stacked tensor is processed using TransformerUNet and the output is projected to generate real + and imaginary components of the output channels. 
+ + Convolutional Positional Embedding is applied for the input sequence + """ + + def __init__( + self, + in_channels: int = 1, + out_channels: int = 1, + freq_dim: int = 256, + dim: int = 1024, + depth: int = 24, + heads: int = 16, + ff_mult: int = 4, + ff_dropout: float = 0.0, + attn_dropout: float = 0.0, + max_positions: int = 6000, + time_hidden_dim: Optional[int] = None, + conv_pos_embed_kernel_size: int = 31, + conv_pos_embed_groups: Optional[int] = None, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + dim_in = freq_dim * in_channels * 2 + + if time_hidden_dim is None: + time_hidden_dim = dim * 4 + + self.proj_in = nn.Linear(dim_in, dim) + + self.sinu_pos_emb = nn.Sequential(LearnedSinusoidalPosEmb(dim), nn.Linear(dim, time_hidden_dim), nn.SiLU()) + + self.conv_embed = ConvPositionEmbed( + dim=dim, kernel_size=conv_pos_embed_kernel_size, groups=conv_pos_embed_groups + ) + + self.transformerunet = TransformerUNet( + dim=dim, + depth=depth, + heads=heads, + ff_mult=ff_mult, + ff_dropout=ff_dropout, + attn_dropout=attn_dropout, + max_positions=max_positions, + adaptive_rmsnorm=True, + adaptive_rmsnorm_cond_dim_in=time_hidden_dim, + use_unet_skip_connection=True, + ) + + # 2x the frequency dimension as the model operates in the complex-value domain + dim_out = freq_dim * out_channels * 2 + + self.proj_out = nn.Linear(dim, dim_out) + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tin_channels: %s', self.in_channels) + logging.debug('\tout_channels: %s', self.out_channels) + logging.debug('\tInput frequency dimension: %s', freq_dim) + + @property + def input_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports.""" + return { + "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + "input_length": NeuralType(('B',), LengthsType(), optional=True), + "condition": NeuralType(('B',), FloatType(), optional=True), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports.""" + return { + "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + "output_length": NeuralType(('B',), LengthsType(), optional=True), + } + + @staticmethod + def _get_key_padding_mask(input_length: torch.Tensor, max_length: int): + """ + Return the self_attention masking according to the input length. + 0 indicates the frame is in the valid range, while 1 indicates the frame is a padding frame. + Args: + input_length: shape (B) + max_length (int): The maximum length of the input sequence + + return: + key_padding_mask: shape (B, T) + """ + key_padding_mask = torch.arange(max_length).expand(len(input_length), max_length).to(input_length.device) + key_padding_mask = key_padding_mask >= input_length.unsqueeze(1) + return key_padding_mask + + @typecheck() + def forward(self, input, input_length=None, condition=None): + """Forward pass of the model. 
+
+    @typecheck()
+    def forward(self, input, input_length=None, condition=None):
+        """Forward pass of the model.
+
+        Args:
+            input: input tensor, shape (B, C, D, T)
+            input_length: length of the valid time steps for each example in the batch, shape (B,)
+            condition: scalar condition (time) for the model, will be embedded using `self.sinu_pos_emb`
+        """
+        # Stack real and imaginary components
+        B, C_in, D, T = input.shape
+        if C_in != self.in_channels:
+            raise RuntimeError(f'Unexpected input channel size {C_in}, expected {self.in_channels}')
+
+        input_real_imag = torch.stack([input.real, input.imag], dim=2)
+        input = einops.rearrange(input_real_imag, 'B C RI D T -> B T (C RI D)')
+
+        x = self.proj_in(input)
+        key_padding_mask = self._get_key_padding_mask(input_length, max_length=T)
+        x = self.conv_embed(x, mask=key_padding_mask) + x
+
+        if condition is None:
+            raise NotImplementedError('This model requires a scalar time condition')
+
+        time_emb = self.sinu_pos_emb(condition)
+
+        x = self.transformerunet(x=x, key_padding_mask=key_padding_mask, adaptive_rmsnorm_cond=time_emb)
+
+        output = self.proj_out(x)
+        output = einops.rearrange(output, "B T (C RI D) -> B C D T RI", C=self.out_channels, RI=2, D=D)
+        output = torch.view_as_complex(output.contiguous())
+
+        return output, input_length
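+
+
+# Example (illustrative sketch, not part of the module): running the model on a batch of complex
+# spectrograms of shape (B, C, D, T). With in_channels=2 and freq_dim=256 this matches the estimator
+# and encoder settings in the flow_matching_generative config; `spec`, `spec_len` and `t` are
+# placeholder names, and depth is reduced for a quick check (the config uses depth=24):
+#
+#     model = SpectrogramTransformerUNet(in_channels=2, out_channels=1, freq_dim=256, depth=2)
+#     spec = torch.randn(2, 2, 256, 768, dtype=torch.cfloat)  # (B, C, D, T) complex-valued input
+#     spec_len = torch.tensor([768, 512])                     # number of valid frames per example
+#     t = torch.rand(2)                                       # scalar (time) condition per example
+#     out, out_len = model(input=spec, input_length=spec_len, condition=t)
+#     # out is complex-valued with shape (2, 1, 256, 768); out_len equals spec_len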
diff --git a/nemo/collections/audio/parts/utils/callbacks.py b/nemo/collections/audio/parts/utils/callbacks.py
new file mode 100644
index 000000000000..093d5a11f419
--- /dev/null
+++ b/nemo/collections/audio/parts/utils/callbacks.py
@@ -0,0 +1,177 @@
+# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Type
+
+import einops
+import torch
+from pytorch_lightning import Callback, LightningModule, Trainer
+from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.loggers.logger import Logger
+from pytorch_lightning.loggers.wandb import WandbLogger
+
+from nemo.utils import logging
+from nemo.utils.decorators import experimental
+
+HAVE_WANDB = True
+try:
+    import wandb
+except ModuleNotFoundError:
+    HAVE_WANDB = False
+
+
+def _get_logger(loggers: List[Logger], logger_type: Type[Logger]):
+    for logger in loggers:
+        if isinstance(logger, logger_type):
+            if hasattr(logger, "experiment"):
+                return logger.experiment
+            else:
+                return logger
+    raise ValueError(f"Could not find {logger_type} logger in {loggers}.")
+
+
+@experimental
+class SpeechEnhancementLoggingCallback(Callback):
+    """
+    Callback which logs processed audio examples (model predictions) to Tensorboard and/or WandB.
+
+    Args:
+        data_loader: Data loader providing the examples to process and log.
+        data_loader_idx: Index of the validation data loader, used to label the logged audio.
+        loggers: Optional list of loggers to use if logging to tensorboard or wandb.
+        log_tensorboard: Whether to log artifacts to tensorboard.
+        log_wandb: Whether to log artifacts to WandB.
+        sample_rate: Sample rate of the audio signals, used when logging audio.
+        max_utts: Optional limit on the number of utterances logged per data loader.
+    """
+
+    def __init__(
+        self,
+        data_loader,
+        data_loader_idx: int,
+        loggers: Optional[List[Logger]] = None,
+        log_tensorboard: bool = False,
+        log_wandb: bool = False,
+        sample_rate: int = 16000,
+        max_utts: Optional[int] = None,
+    ):
+        self.data_loader = data_loader
+        self.data_loader_idx = data_loader_idx
+        self.loggers = loggers if loggers else []
+        self.log_tensorboard = log_tensorboard
+        self.log_wandb = log_wandb
+        self.sample_rate = sample_rate
+        self.max_utts = max_utts
+
+        if log_tensorboard:
+            logging.info('Creating tensorboard logger')
+            self.tensorboard_logger = _get_logger(self.loggers, TensorBoardLogger)
+        else:
+            logging.debug('Not using tensorboard logger')
+            self.tensorboard_logger = None
+
+        if log_wandb:
+            if not HAVE_WANDB:
+                raise ValueError("Wandb not installed.")
+            logging.info('Creating wandb logger')
+            self.wandb_logger = _get_logger(self.loggers, WandbLogger)
+        else:
+            logging.debug('Not using wandb logger')
+            self.wandb_logger = None
+
+        logging.debug('Initialized %s with', self.__class__.__name__)
+        logging.debug('\tlog_tensorboard: %s', self.log_tensorboard)
+        logging.debug('\tlog_wandb: %s', self.log_wandb)
+
+    def _log_audio(self, audios: torch.Tensor, lengths: torch.Tensor, step: int, label: str = "input"):
+
+        num_utts = audios.size(0)
+        for audio_idx in range(num_utts):
+            length = lengths[audio_idx]
+            if self.tensorboard_logger:
+                self.tensorboard_logger.add_audio(
+                    tag=f"{label}_{audio_idx}",
+                    snd_tensor=audios[audio_idx, :length],
+                    global_step=step,
+                    sample_rate=self.sample_rate,
+                )
+
+            if self.wandb_logger:
+                wandb_audio = (
+                    wandb.Audio(audios[audio_idx, :length], sample_rate=self.sample_rate, caption=f"{label}_{audio_idx}"),
+                )
+                self.wandb_logger.log({f"{label}_{audio_idx}": wandb_audio})
+
+    def on_validation_epoch_end(self, trainer: Trainer, model: LightningModule):
+        """Log artifacts at the end of an epoch."""
+        epoch = 1 + model.current_epoch
+        output_signal_list = []
+        output_length_list = []
+        num_examples_uploaded = 0
+
+        logging.info(f"Logging processed speech for validation dataset {self.data_loader_idx}...")
+        for batch in self.data_loader:
+            if isinstance(batch, dict):
+                # lhotse batches are dictionaries
+                input_signal = batch['input_signal']
+                input_length = batch['input_length']
+                target_signal = batch.get('target_signal', input_signal.clone())
+            else:
+                input_signal, input_length, target_signal, _ = batch
+
+            if self.max_utts is None:
+                num_examples = input_signal.size(0)  # batch size
+                do_upload = True
+            else:
+                do_upload = num_examples_uploaded < self.max_utts
+                num_examples = min(self.max_utts - num_examples_uploaded, input_signal.size(0))
+                num_examples_uploaded += num_examples
+
+            if do_upload:
+                # Only send the required number of examples to the logger
+                input_signal = input_signal[:num_examples, ...]
+                target_signal = target_signal[:num_examples, ...]
+                input_length = input_length[:num_examples]
+
+                # For consistency, the model uses the multi-channel format, even if the channel dimension is 1
+                if input_signal.ndim == 2:
+                    input_signal = einops.rearrange(input_signal, 'B T -> B 1 T')
+                if target_signal.ndim == 2:
+                    target_signal = einops.rearrange(target_signal, 'B T -> B 1 T')
+
+                input_signal = input_signal.to(model.device)
+                input_length = input_length.to(model.device)
+
+                output_signal, output_length = model(input_signal=input_signal, input_length=input_length)
+                output_signal_list.append(output_signal.to(target_signal.device))
+                output_length_list.append(output_length.to(target_signal.device))
+
+        if len(output_signal_list) == 0:
+            logging.debug('Lists are empty, no artifacts to log at epoch %d.', epoch)
+            return
+
+        output_signals = torch.concat(output_signal_list, dim=0)
+        output_lengths = torch.concat(output_length_list, dim=0)
+        if output_signals.size(1) != 1:
+            logging.error(
+                f"Currently only single-channel audio is supported! Current output shape: {output_signals.shape}"
+            )
+            raise NotImplementedError
+
+        output_signals = einops.rearrange(output_signals, "B 1 T -> B T")
+
+        self._log_audio(
+            audios=output_signals,
+            lengths=output_lengths,
+            step=model.global_step,
+            label=f"dataloader_{self.data_loader_idx}_processed",
+        )
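+
+
+# Example usage (illustrative sketch; in this PR the callback is presumably constructed from the
+# model's `log_config` settings, so manual wiring is normally not required). `val_dataloader` is a
+# placeholder for a data loader yielding (input_signal, input_length, target_signal, ...) batches:
+#
+#     from pytorch_lightning import Trainer
+#     from pytorch_lightning.loggers import TensorBoardLogger
+#
+#     tb_logger = TensorBoardLogger(save_dir='logs')
+#     callback = SpeechEnhancementLoggingCallback(
+#         data_loader=val_dataloader,
+#         data_loader_idx=0,
+#         loggers=[tb_logger],
+#         log_tensorboard=True,
+#         sample_rate=16000,
+#         max_utts=8,
+#     )
+#     trainer = Trainer(logger=tb_logger, callbacks=[callback])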