Skip to content

Commit

Permalink
Apply isort and black reformatting
Browse files Browse the repository at this point in the history
Signed-off-by: Kuray107 <Kuray107@users.noreply.github.com>
  • Loading branch information
Kuray107 authored and anteju committed Aug 6, 2024
1 parent 98a896c commit 6f9a9c6
Show file tree
Hide file tree
Showing 8 changed files with 196 additions and 195 deletions.
2 changes: 1 addition & 1 deletion examples/audio/audio_to_audio_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@

from nemo.collections.audio.models.enhancement import (
EncMaskDecAudioToAudioModel,
FlowMatchingAudioToAudioModel,
PredictiveAudioToAudioModel,
SchroedingerBridgeAudioToAudioModel,
ScoreBasedGenerativeAudioToAudioModel,
FlowMatchingAudioToAudioModel,
)
from nemo.core.config import hydra_runner
from nemo.utils import logging
Expand Down
4 changes: 3 additions & 1 deletion nemo/collections/audio/data/audio_to_audio_lhotse.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def __getitem__(self, cuts: CutSet) -> dict[str, torch.Tensor]:
"input_length": src_audio_lens,
}
# keep only the first non-padding cuts
retained_cuts = [cut._first_non_padding_cut if isinstance(cut, MixedCut) else cut for cut in retained_padded_cuts]
retained_cuts = [
cut._first_non_padding_cut if isinstance(cut, MixedCut) else cut for cut in retained_padded_cuts
]
retained_cuts = CutSet.from_cuts(retained_cuts)

if _key_available(retained_cuts, self.TARGET_KEY):
Expand Down
9 changes: 5 additions & 4 deletions nemo/collections/audio/models/audio_to_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def on_after_backward(self):
if valid_gradients < 1:
logging.warning('detected inf or nan values in gradients! Setting gradients to zero.')
self.zero_grad(set_to_none=False)

def configure_callbacks(self):
"""
Create a callback to add audio/spectrogram into tensorboard & wandb.
Expand All @@ -495,11 +495,12 @@ def configure_callbacks(self):

log_callbacks = []
from nemo.collections.audio.parts.utils.callbacks import SpeechEnhancementLoggingCallback

if isinstance(self._validation_dl, List):
data_loaders = self._validation_dl
else:
data_loaders = [self._validation_dl]

for data_loader_idx, data_loader in enumerate(data_loaders):
log_callbacks.append(
SpeechEnhancementLoggingCallback(
Expand All @@ -509,8 +510,8 @@ def configure_callbacks(self):
log_tensorboard=self.log_config.log_tensorboard,
log_wandb=self.log_config.log_wandb,
sample_rate=self.sample_rate,
max_utts=self.log_config.get("max_utts", None)
max_utts=self.log_config.get("max_utts", None),
)
)

return log_callbacks
return log_callbacks
13 changes: 8 additions & 5 deletions nemo/collections/audio/models/enhancement.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
'ScoreBasedGenerativeAudioToAudioModel',
'PredictiveAudioToAudioModel',
'SchroedingerBridgeAudioToAudioModel',
'FlowMatchingAudioToAudioModel'
'FlowMatchingAudioToAudioModel',
]


Expand Down Expand Up @@ -619,6 +619,7 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str =

return {f'{tag}_loss': loss}


class FlowMatchingAudioToAudioModel(AudioToAudioModel):
def __init__(self, cfg: DictConfig, trainer: Trainer = None):
super().__init__(cfg=cfg, trainer=trainer)
Expand All @@ -637,7 +638,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
# Sampler
self.sampler = hydra.utils.instantiate(self._cfg.sampler, estimator=self.estimator)

# probability that the conditional input will be fed into the
# probability that the conditional input will be fed into the
# estimator in the training stage
self.p_cond = self._cfg.get('p_cond', 1.0)

Expand Down Expand Up @@ -715,7 +716,7 @@ def forward(self, input_signal, input_length=None):

if self.p_cond == 0:
encoded = torch.zeros_like(encoded)

if self.ssl_pretrain_masking is not None:
encoded = self.ssl_pretrain_masking(input_spec=encoded, length=encoded_length)

Expand Down Expand Up @@ -744,7 +745,9 @@ def forward(self, input_signal, input_length=None):
"input_signal": NeuralType(('B', 'C', 'T'), AudioSignal()),
"input_length": NeuralType(tuple('B'), LengthsType()),
},
output_types={"loss": NeuralType(None, LossType()),},
output_types={
"loss": NeuralType(None, LossType()),
},
)
def _step(self, target_signal, input_signal, input_length=None):
batch_size = target_signal.size(0)
Expand Down Expand Up @@ -816,7 +819,7 @@ def training_step(self, batch, batch_idx):
return loss

def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'):

if isinstance(batch, dict):
# lhotse batches are dictionaries
input_signal = batch['input_signal']
Expand Down
14 changes: 5 additions & 9 deletions nemo/collections/audio/modules/ssl_pretrain_masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,20 @@
import torch

from nemo.core.classes import NeuralModule, typecheck
from nemo.core.neural_types import (
LengthsType,
NeuralType,
SpectrogramType,
)
from nemo.core.neural_types import LengthsType, NeuralType, SpectrogramType

__all__ = ['SSLPretrainWithMaskedPatch']


class SSLPretrainWithMaskedPatch(NeuralModule):
"""
Zeroes out fixed size time patches of the spectrogram.
All samples in batch are guaranteed to have the same amount of masked time steps.
Args:
patch_size (int): maximum number of time steps that one patch consists of.
Defaults to 48.
mask_fraction (float): fraction of each sample to be masked (number of patches is rounded up).
mask_fraction (float): fraction of each sample to be masked (number of patches is rounded up).
Range from 0.0 to 1.0. Defaults to 0.7.
"""

Expand Down Expand Up @@ -63,7 +60,6 @@ def __init__(
else:
self.mask_fraction = mask_fraction


@typecheck()
def forward(self, input_spec, length):
augmented_spec = input_spec
Expand All @@ -89,4 +85,4 @@ def forward(self, input_spec, length):
mask = einops.rearrange(mask, 'T -> 1 1 1 T').float()
augmented_spec = augmented_spec * mask

return augmented_spec
return augmented_spec
66 changes: 41 additions & 25 deletions nemo/collections/audio/parts/submodules/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,29 @@
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Tuple
import torch

import einops
import torch

from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor
from nemo.utils import logging


class ConditionalFlow(ABC):
"""
Abstract class for different conditional flow-matching (CFM) classes
Time horizon is [time_min, time_max (should be 1)]
every path is "conditioned" on endpoints of the path
endpoints are just our paired data samples
subclasses need to either:
1. implement mean, std, d_mean, d_std
2. implement mean, std; override vector_field, flow (also d_flow if needed)
"""
def __init__(self, time_min: float=1e-8, time_max: float = 1.0):

def __init__(self, time_min: float = 1e-8, time_max: float = 1.0):
self.time_min = time_min
self.time_max = time_max

Expand All @@ -42,21 +45,21 @@ def mean(self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torc
Return the mean of p_t(x | start_state, end_state) at time t
"""
pass

@abstractmethod
def std(self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor) -> torch.Tensor:
"""
Return the standard deviation of p_t(x | start_state, end_state) at time t
"""
pass

@abstractmethod
def d_mean(self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor) -> torch.Tensor:
"""
Return the time derivatives of mean of p_t(x | start_state, end_state) at time t
"""
pass

@abstractmethod
def d_std(self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -86,12 +89,14 @@ def sample(self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: to
Generate a sample from p_t(x | start_state, end_state) at time t
"""
time = self._broadcast_time(time, n_dim=start_state.ndim)

mean = self.mean(time=time, start_state=start_state, end_state=end_state)
std = self.std(time=time, start_state=start_state, end_state=end_state)
return mean + std * torch.randn_like(mean)

def vector_field(self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor, point: torch.Tensor) -> torch.Tensor:
def vector_field(
self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor, point: torch.Tensor
) -> torch.Tensor:
"""
Compute the conditional vector field v_t( point | start_state, end_state)
"""
Expand All @@ -106,15 +111,19 @@ def vector_field(self, *, time: torch.Tensor, start_state: torch.Tensor, end_sta
d_std = self.d_std(time=time, start_state=start_state, end_state=end_state)
return d_std * (point - mean) / std + d_mean

def flow(self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor, point: torch.Tensor) -> torch.Tensor:
def flow(
self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor, point: torch.Tensor
) -> torch.Tensor:
"""
Compute the conditional flow phi_t( point | start_state, end_state)
"""
mean = self.mean(time=time, start_state=start_state, end_state=end_state)
std = self.std(time=time, start_state=start_state, end_state=end_state)
return mean + std * (point - start_state)

def d_flow(self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor, point: torch.Tensor) -> torch.Tensor:
def d_flow(
self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: torch.Tensor, point: torch.Tensor
) -> torch.Tensor:
"""
Compute the time derivatives of conditional flow
"""
Expand All @@ -124,16 +133,16 @@ def d_flow(self, *, time: torch.Tensor, start_state: torch.Tensor, end_state: to


class OptimalTransportFlow(ConditionalFlow):
""" The OT-CFM model from [Lipman et at, 2023]
"""The OT-CFM model from [Lipman et at, 2023]
For every conditional path the following holds:
p_0 = N(start_state, sigma_start)
p_1 = N(end_state, sigma_end),
mean(x, t) = (time_max - t) * start_state + t * end_state
(linear interpolation between start_state and end_state)
std(x, t) = (time_max - t) * sigma_start + t * sigma_end
std(x, t) = (time_max - t) * sigma_start + t * sigma_end
Every conditional path is an optimal transport map from p_0(start_state, end_state) to p_1(start_state, end_state)
Marginal path is not guaranteed to be an optimal transport map from p_0 to p_1
Expand All @@ -149,16 +158,13 @@ class OptimalTransportFlow(ConditionalFlow):
sigma_end: the standard deviation of the target distribution
"""

def __init__(self,
time_min: float = 1e-8,
time_max: float = 1.0,
sigma_start: float = 1.0,
sigma_end: float = 1e-4
):
def __init__(
self, time_min: float = 1e-8, time_max: float = 1.0, sigma_start: float = 1.0, sigma_end: float = 1e-4
):
super().__init__(time_min=time_min, time_max=time_max)
self.sigma_start = sigma_start
self.sigma_end = sigma_end

logging.debug('Initialized %s with', self.__class__.__name__)
logging.debug('\ttime_min: %s', self.time_min)
logging.debug('\ttime_max: %s', self.time_max)
Expand All @@ -177,7 +183,15 @@ def d_mean(self, *, start_state: torch.Tensor, end_state: torch.Tensor, time: to
def d_std(self, *, start_state: torch.Tensor, end_state: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
return self.sigma_end - self.sigma_start

def vector_field(self, *, start_state: torch.Tensor, end_state: torch.Tensor, time: torch.Tensor, point: torch.Tensor, eps: float=1e-6) -> torch.Tensor:
def vector_field(
self,
*,
start_state: torch.Tensor,
end_state: torch.Tensor,
time: torch.Tensor,
point: torch.Tensor,
eps: float = 1e-6,
) -> torch.Tensor:
time = self._broadcast_time(time, n_dim=start_state.ndim)

if self.sigma_start == self.sigma_end:
Expand All @@ -199,6 +213,7 @@ class ConditionalFlowMatchingSampler(ABC):
time_max: maximum time value used in the process
"""

def __init__(
self,
estimator: torch.nn.Module,
Expand Down Expand Up @@ -226,6 +241,7 @@ class ConditionalFlowMatchingEulerSampler(ConditionalFlowMatchingSampler):
"""
The Euler Sampler for solving the ODE in CFM on a uniform time grid
"""

def __init__(
self,
estimator: torch.nn.Module,
Expand All @@ -234,10 +250,10 @@ def __init__(
time_max: float = 1.0,
):
super().__init__(
estimator = estimator,
num_steps = num_steps,
time_min = time_min,
time_max = time_max,
estimator=estimator,
num_steps=num_steps,
time_min=time_min,
time_max=time_max,
)
logging.debug('Initialized %s with', self.__class__.__name__)
logging.debug('\tnum_steps: %s', self.num_steps)
Expand Down
Loading

0 comments on commit 6f9a9c6

Please sign in to comment.