From 5bf183781e6b36c2ee072300e76a9eb97cf79f88 Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Fri, 25 Nov 2022 19:39:31 +0100 Subject: [PATCH] feat: add noise conditioning, move cfg, add unet1dall --- audio_diffusion_pytorch/__init__.py | 2 +- audio_diffusion_pytorch/model.py | 4 +- audio_diffusion_pytorch/modules.py | 83 ++++++++++++++++++++++++----- setup.py | 2 +- 4 files changed, 75 insertions(+), 16 deletions(-) diff --git a/audio_diffusion_pytorch/__init__.py b/audio_diffusion_pytorch/__init__.py index 3443f49..61f3b85 100644 --- a/audio_diffusion_pytorch/__init__.py +++ b/audio_diffusion_pytorch/__init__.py @@ -35,4 +35,4 @@ DiffusionVocoder1d, Model1d, ) -from .modules import NumberEmbedder, T5Embedder, UNet1d, UNetConditional1d +from .modules import NumberEmbedder, T5Embedder, UNet1d diff --git a/audio_diffusion_pytorch/model.py b/audio_diffusion_pytorch/model.py index 9e6eb0c..fbf9ac3 100644 --- a/audio_diffusion_pytorch/model.py +++ b/audio_diffusion_pytorch/model.py @@ -9,7 +9,7 @@ from tqdm import tqdm from .diffusion import LinearSchedule, UniformDistribution, VSampler, XDiffusion -from .modules import STFT, SinusoidalEmbedding, UNet1d, UNetConditional1d, rand_bool +from .modules import STFT, SinusoidalEmbedding, UNet1d, UNetCFG1d, rand_bool from .utils import ( closest_power_2, default, @@ -34,7 +34,7 @@ def __init__( super().__init__() diffusion_kwargs, kwargs = groupby("diffusion_", kwargs) - UNet = UNetConditional1d if use_classifier_free_guidance else UNet1d + UNet = UNetCFG1d if use_classifier_free_guidance else UNet1d self.unet = UNet(**kwargs) self.diffusion = XDiffusion( diff --git a/audio_diffusion_pytorch/modules.py b/audio_diffusion_pytorch/modules.py index 017ef48..e044507 100644 --- a/audio_diffusion_pytorch/modules.py +++ b/audio_diffusion_pytorch/modules.py @@ -8,7 +8,7 @@ from einops_exts import rearrange_many from torch import Tensor, einsum -from .utils import closest_power_2, default, exists, groupby +from .utils import closest_power_2, default, exists, groupby, is_sequence """ Utils @@ -909,9 +909,11 @@ def __init__( self.use_stft = use_stft self.use_stft_context = use_stft_context + self.context_features = context_features context_channels_pad_length = num_layers + 1 - len(context_channels) context_channels = context_channels + [0] * context_channels_pad_length self.context_channels = context_channels + self.context_embedding_features = context_embedding_features if use_context_channels: has_context = [c > 0 for c in context_channels] @@ -1140,22 +1142,21 @@ def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor: return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool) -class UNetConditional1d(UNet1d): - """ - UNet1d with classifier-free guidance on the token embeddings - """ +class UNetCFG1d(UNet1d): + + """UNet1d with Classifier-Free Guidance""" def __init__( self, - context_embedding_features: int, context_embedding_max_length: int, + context_embedding_features: int, **kwargs, ): super().__init__( context_embedding_features=context_embedding_features, **kwargs ) self.fixed_embedding = FixedEmbedding( - context_embedding_max_length, context_embedding_features + max_length=context_embedding_max_length, features=context_embedding_features ) def forward( # type: ignore @@ -1178,14 +1179,72 @@ def forward( # type: ignore ) embedding = torch.where(batch_mask, fixed_embedding, embedding) - out = super().forward(x, time, embedding=embedding, **kwargs) - if embedding_scale != 1.0: - # Scale conditional output using classifier-free guidance + # Compute both normal and fixed embedding outputs + out = super().forward(x, time, embedding=embedding, **kwargs) out_masked = super().forward(x, time, embedding=fixed_embedding, **kwargs) - out = out_masked + (out - out_masked) * embedding_scale + # Scale conditional output using classifier-free guidance + return out_masked + (out - out_masked) * embedding_scale + else: + return super().forward(x, time, embedding=embedding, **kwargs) + + +class UNetNCCA1d(UNet1d): + + """UNet1d with Noise Channel Conditioning Augmentation""" + + def __init__(self, context_features: int, **kwargs): + super().__init__(context_features=context_features, **kwargs) + self.embedder = NumberEmbedder(features=context_features) + + def forward( # type: ignore + self, + x: Tensor, + time: Tensor, + *, + channels_list: Sequence[Tensor], + channels_augmentation: bool = False, + channels_scale: Union[int, Sequence[int]] = 0, + **kwargs, + ) -> Tensor: + b, num_items = x.shape[0], len(channels_list) + + if channels_augmentation: + # Random noise augmentation for each item + channels_scale = torch.rand(num_items, b).to(x) # type: ignore + for i in range(num_items): + item = channels_list[i] + scale = rearrange(channels_scale[i], "b -> b 1 1") # type: ignore + channels_list[i] = torch.randn_like(item) * scale + item * (1 - scale) # type: ignore # noqa + else: + # Expand same scale to each batch element + if is_sequence(channels_scale): + assert_message = "len(channels_scale) must match len(channels_list)" + assert len(channels_scale) == num_items, assert_message + else: + channels_scale = num_items * [channels_scale] # type: ignore + channels_scale = torch.tensor(channels_scale).to(x) # type: ignore + channels_scale = repeat(channels_scale, "n -> n b", b=b) + + # Compute scale feature embedding + scale_embedding = self.embedder(channels_scale) + scale_embedding = reduce(scale_embedding, "n b d -> b d", "sum") + + return super().forward( + x=x, + time=time, + channels_list=channels_list, + features=scale_embedding, + **kwargs, + ) + + +class UNetAll1d(UNetCFG1d, UNetNCCA1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) - return out + def forward(self, *args, **kwargs): # type: ignore + return UNetCFG1d.forward(self, *args, **kwargs) class T5Embedder(nn.Module): diff --git a/setup.py b/setup.py index 8abafd1..b010eba 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="audio-diffusion-pytorch", packages=find_packages(exclude=[]), - version="0.0.93", + version="0.0.94", license="MIT", description="Audio Diffusion - PyTorch", long_description_content_type="text/markdown",