Skip to content

Commit

Permalink
feat: add noise conditioning, move cfg, add unet1dall
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Nov 25, 2022
1 parent c937e49 commit 5bf1837
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 16 deletions.
2 changes: 1 addition & 1 deletion audio_diffusion_pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,4 @@
DiffusionVocoder1d,
Model1d,
)
from .modules import NumberEmbedder, T5Embedder, UNet1d, UNetConditional1d
from .modules import NumberEmbedder, T5Embedder, UNet1d
4 changes: 2 additions & 2 deletions audio_diffusion_pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from tqdm import tqdm

from .diffusion import LinearSchedule, UniformDistribution, VSampler, XDiffusion
from .modules import STFT, SinusoidalEmbedding, UNet1d, UNetConditional1d, rand_bool
from .modules import STFT, SinusoidalEmbedding, UNet1d, UNetCFG1d, rand_bool
from .utils import (
closest_power_2,
default,
Expand All @@ -34,7 +34,7 @@ def __init__(
super().__init__()
diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)

UNet = UNetConditional1d if use_classifier_free_guidance else UNet1d
UNet = UNetCFG1d if use_classifier_free_guidance else UNet1d
self.unet = UNet(**kwargs)

self.diffusion = XDiffusion(
Expand Down
83 changes: 71 additions & 12 deletions audio_diffusion_pytorch/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from einops_exts import rearrange_many
from torch import Tensor, einsum

from .utils import closest_power_2, default, exists, groupby
from .utils import closest_power_2, default, exists, groupby, is_sequence

"""
Utils
Expand Down Expand Up @@ -909,9 +909,11 @@ def __init__(
self.use_stft = use_stft
self.use_stft_context = use_stft_context

self.context_features = context_features
context_channels_pad_length = num_layers + 1 - len(context_channels)
context_channels = context_channels + [0] * context_channels_pad_length
self.context_channels = context_channels
self.context_embedding_features = context_embedding_features

if use_context_channels:
has_context = [c > 0 for c in context_channels]
Expand Down Expand Up @@ -1140,22 +1142,21 @@ def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor:
return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)


class UNetConditional1d(UNet1d):
"""
UNet1d with classifier-free guidance on the token embeddings
"""
class UNetCFG1d(UNet1d):

"""UNet1d with Classifier-Free Guidance"""

def __init__(
self,
context_embedding_features: int,
context_embedding_max_length: int,
context_embedding_features: int,
**kwargs,
):
super().__init__(
context_embedding_features=context_embedding_features, **kwargs
)
self.fixed_embedding = FixedEmbedding(
context_embedding_max_length, context_embedding_features
max_length=context_embedding_max_length, features=context_embedding_features
)

def forward( # type: ignore
Expand All @@ -1178,14 +1179,72 @@ def forward( # type: ignore
)
embedding = torch.where(batch_mask, fixed_embedding, embedding)

out = super().forward(x, time, embedding=embedding, **kwargs)

if embedding_scale != 1.0:
# Scale conditional output using classifier-free guidance
# Compute both normal and fixed embedding outputs
out = super().forward(x, time, embedding=embedding, **kwargs)
out_masked = super().forward(x, time, embedding=fixed_embedding, **kwargs)
out = out_masked + (out - out_masked) * embedding_scale
# Scale conditional output using classifier-free guidance
return out_masked + (out - out_masked) * embedding_scale
else:
return super().forward(x, time, embedding=embedding, **kwargs)


class UNetNCCA1d(UNet1d):

"""UNet1d with Noise Channel Conditioning Augmentation"""

def __init__(self, context_features: int, **kwargs):
super().__init__(context_features=context_features, **kwargs)
self.embedder = NumberEmbedder(features=context_features)

def forward( # type: ignore
self,
x: Tensor,
time: Tensor,
*,
channels_list: Sequence[Tensor],
channels_augmentation: bool = False,
channels_scale: Union[int, Sequence[int]] = 0,
**kwargs,
) -> Tensor:
b, num_items = x.shape[0], len(channels_list)

if channels_augmentation:
# Random noise augmentation for each item
channels_scale = torch.rand(num_items, b).to(x) # type: ignore
for i in range(num_items):
item = channels_list[i]
scale = rearrange(channels_scale[i], "b -> b 1 1") # type: ignore
channels_list[i] = torch.randn_like(item) * scale + item * (1 - scale) # type: ignore # noqa
else:
# Expand same scale to each batch element
if is_sequence(channels_scale):
assert_message = "len(channels_scale) must match len(channels_list)"
assert len(channels_scale) == num_items, assert_message
else:
channels_scale = num_items * [channels_scale] # type: ignore
channels_scale = torch.tensor(channels_scale).to(x) # type: ignore
channels_scale = repeat(channels_scale, "n -> n b", b=b)

# Compute scale feature embedding
scale_embedding = self.embedder(channels_scale)
scale_embedding = reduce(scale_embedding, "n b d -> b d", "sum")

return super().forward(
x=x,
time=time,
channels_list=channels_list,
features=scale_embedding,
**kwargs,
)


class UNetAll1d(UNetCFG1d, UNetNCCA1d):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

return out
def forward(self, *args, **kwargs): # type: ignore
return UNetCFG1d.forward(self, *args, **kwargs)


class T5Embedder(nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="audio-diffusion-pytorch",
packages=find_packages(exclude=[]),
version="0.0.93",
version="0.0.94",
license="MIT",
description="Audio Diffusion - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit 5bf1837

Please sign in to comment.