
Commit

feat: decouple ae from diffae/diffmae, update ncca
flavioschneider committed Nov 26, 2022
1 parent d730653 commit 7fa0ad2
Showing 5 changed files with 60 additions and 225 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -62,9 +62,9 @@ upsampled = upsampler.sample(

### Autoencoding
```py
from audio_diffusion_pytorch import AudioDiffusionAutoencoder
from audio_diffusion_pytorch import AudioDiffusionAE

autoencoder = AudioDiffusionAutoencoder(in_channels=1)
autoencoder = AudioDiffusionAE(in_channels=1)

# Train on audio samples
x = torch.randn(2, 1, 2 ** 18)
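# (Sketch of how the rest of this README example reads with the renamed class;
# num_steps and the shapes above are illustrative values, not prescribed defaults.)
loss = autoencoder(x)
loss.backward()

# Encode into a compressed latent, then decode back to audio by sampling
latent = autoencoder.encode(x)
sample = autoencoder.decode(latent, num_steps=10)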
7 changes: 4 additions & 3 deletions audio_diffusion_pytorch/__init__.py
@@ -1,3 +1,5 @@
from audio_encoders_pytorch import Encoder1d, ME1d

from .diffusion import (
ADPM2Sampler,
AEulerSampler,
@@ -21,15 +23,14 @@
XDiffusion,
)
from .model import (
AudioDiffusionAutoencoder,
AudioDiffusionAE,
AudioDiffusionConditional,
AudioDiffusionModel,
AudioDiffusionUpphaser,
AudioDiffusionUpsampler,
AudioDiffusionVocoder,
DiffusionAE1d,
DiffusionAR1d,
DiffusionAutoencoder1d,
DiffusionMAE1d,
DiffusionUpphaser1d,
DiffusionUpsampler1d,
DiffusionVocoder1d,
220 changes: 27 additions & 193 deletions audio_diffusion_pytorch/model.py
@@ -3,7 +3,7 @@
from typing import Any, Optional, Sequence, Tuple, Union

import torch
from audio_encoders_pytorch import Bottleneck, Encoder1d
from audio_encoders_pytorch import Encoder1d
from einops import rearrange
from torch import Tensor, nn
from tqdm import tqdm
@@ -16,8 +16,6 @@
downsample,
exists,
groupby,
prefix_dict,
prod,
to_list,
upsample,
)
@@ -104,194 +102,40 @@ def sample( # type: ignore
return super().sample(noise, **{**default_kwargs, **kwargs}) # type: ignore


class DiffusionAutoencoder1d(nn.Module):
def __init__(
self,
in_channels: int,
encoder_inject_depth: int,
encoder_channels: int,
encoder_factors: Sequence[int],
encoder_multipliers: Sequence[int],
encoder_patch_size: int = 1,
bottleneck: Union[Bottleneck, Sequence[Bottleneck]] = [],
bottleneck_channels: Optional[int] = None,
unet_type: str = "base",
**kwargs,
):
super().__init__()
self.in_channels = in_channels

encoder_kwargs, kwargs = groupby("encoder_", kwargs)
diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)

# Compute context channels
context_channels = [0] * encoder_inject_depth
if exists(bottleneck_channels):
context_channels += [bottleneck_channels]
else:
context_channels += [encoder_channels * encoder_multipliers[-1]]

self.unet = XUNet1d(
type=unet_type,
in_channels=in_channels,
context_channels=context_channels,
**kwargs,
)

self.diffusion = XDiffusion(net=self.unet, **diffusion_kwargs)

self.encoder = Encoder1d(
in_channels=in_channels,
channels=encoder_channels,
patch_size=encoder_patch_size,
factors=encoder_factors,
multipliers=encoder_multipliers,
out_channels=bottleneck_channels,
**encoder_kwargs,
)

self.encoder_downsample_factor = encoder_patch_size * prod(encoder_factors)
self.bottleneck_channels = bottleneck_channels
self.bottlenecks = nn.ModuleList(to_list(bottleneck))
class DiffusionAE1d(Model1d):
"""Diffusion Auto Encoder"""

def encode(
self, x: Tensor, with_info: bool = False
) -> Union[Tensor, Tuple[Tensor, Any]]:
latent, info = self.encoder(x, with_info=True)
# Apply bottlenecks if present
for bottleneck in self.bottlenecks:
latent, info_bottleneck = bottleneck(latent, with_info=True)
info = {**info, **prefix_dict("bottleneck_", info_bottleneck)}
return (latent, info) if with_info else latent

def forward( # type: ignore
self, x: Tensor, with_info: bool = False, **kwargs
) -> Union[Tensor, Tuple[Tensor, Any]]:
latent, info = self.encode(x, with_info=True)
loss = self.diffusion(x, channels_list=[latent], **kwargs)
return (loss, info) if with_info else loss

def decode(self, latent: Tensor, **kwargs) -> Tensor:
b = latent.shape[0]
length = latent.shape[2] * self.encoder_downsample_factor
# Compute noise by inferring shape from latent length
noise = torch.randn(b, self.in_channels, length, device=latent.device)
# Compute context from latent
default_kwargs = dict(channels_list=[latent])
# Decode by sampling while conditioning on latent channels
return self.sample(noise, **{**default_kwargs, **kwargs})

def sample(self, *args, **kwargs) -> Tensor:
return self.diffusion.sample(*args, **kwargs)


class DiffusionMAE1d(nn.Module):
def __init__(
self,
in_channels: int,
encoder_inject_depth: int,
encoder_channels: int,
encoder_factors: Sequence[int],
encoder_multipliers: Sequence[int],
diffusion_type: str,
stft_num_fft: int,
stft_hop_length: int,
stft_use_complex: bool,
stft_window_length: Optional[int] = None,
encoder_patch_size: int = 1,
bottleneck: Union[Bottleneck, Sequence[Bottleneck]] = [],
bottleneck_channels: Optional[int] = None,
unet_type: str = "base",
**kwargs,
self, in_channels: int, encoder: Encoder1d, encoder_inject_depth: int, **kwargs
):
super().__init__()
self.in_channels = in_channels

encoder_kwargs, kwargs = groupby("encoder_", kwargs)
diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
stft_kwargs, kwargs = groupby("stft_", kwargs)

# Compute context channels
context_channels = [0] * encoder_inject_depth
if exists(bottleneck_channels):
context_channels += [bottleneck_channels]
else:
context_channels += [encoder_channels * encoder_multipliers[-1]]

self.spectrogram_channels = stft_num_fft // 2 + 1
self.stft_hop_length = stft_hop_length

self.encoder_stft = STFT(
num_fft=stft_num_fft,
hop_length=stft_hop_length,
window_length=stft_window_length,
use_complex=False, # Magnitude encoding
)

self.unet = XUNet1d(
type=unet_type,
super().__init__(
in_channels=in_channels,
context_channels=context_channels,
use_stft=True,
stft_use_complex=stft_use_complex,
stft_num_fft=stft_num_fft,
stft_hop_length=stft_hop_length,
stft_window_length=stft_window_length,
context_channels=[0] * encoder_inject_depth + [encoder.out_channels],
**kwargs,
)

self.diffusion = XDiffusion(
type=diffusion_type, net=self.unet, **diffusion_kwargs
)

self.encoder = Encoder1d(
in_channels=in_channels * self.spectrogram_channels,
channels=encoder_channels,
patch_size=encoder_patch_size,
factors=encoder_factors,
multipliers=encoder_multipliers,
out_channels=bottleneck_channels,
**encoder_kwargs,
)

self.encoder_downsample_factor = encoder_patch_size * prod(encoder_factors)
self.bottleneck_channels = bottleneck_channels
self.bottlenecks = nn.ModuleList(to_list(bottleneck))

def encode(
self, x: Tensor, with_info: bool = False
) -> Union[Tensor, Tuple[Tensor, Any]]:
# Extract magnitude and encode
magnitude, _ = self.encoder_stft.encode(x)
magnitude_flat = rearrange(magnitude, "b c f t -> b (c f) t")
latent, info = self.encoder(magnitude_flat, with_info=True)
# Apply bottlenecks if present
for bottleneck in self.bottlenecks:
latent, info_bottleneck = bottleneck(latent, with_info=True)
info = {**info, **prefix_dict("bottleneck_", info_bottleneck)}
return (latent, info) if with_info else latent
self.in_channels = in_channels
self.encoder = encoder

def forward( # type: ignore
self, x: Tensor, with_info: bool = False, **kwargs
) -> Union[Tensor, Tuple[Tensor, Any]]:
latent, info = self.encode(x, with_info=True)
loss = self.diffusion(x, channels_list=[latent], **kwargs)
print(latent.shape)
loss = super().forward(x, channels_list=[latent], **kwargs)
return (loss, info) if with_info else loss

def encode(self, *args, **kwargs):
return self.encoder(*args, **kwargs)

def decode(self, latent: Tensor, **kwargs) -> Tensor:
b = latent.shape[0]
length = closest_power_2(
self.stft_hop_length * latent.shape[2] * self.encoder_downsample_factor
)
length = closest_power_2(latent.shape[2] * self.encoder.downsample_factor)
# Compute noise by inferring shape from latent length
noise = torch.randn(b, self.in_channels, length, device=latent.device)
# Compute context from latent
default_kwargs = dict(channels_list=[latent])
# Decode by sampling while conditioning on latent channels
return self.sample(noise, **{**default_kwargs, **kwargs}) # type: ignore

def sample(self, *args, **kwargs) -> Tensor:
return self.diffusion.sample(*args, **kwargs)
return super().sample(noise, **{**default_kwargs, **kwargs})
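
A quick note on the new `decode`: the output length is now inferred from the latent alone via the encoder's `downsample_factor`, rather than recomputed from stored `encoder_patch_size`/`encoder_factors`. A back-of-the-envelope check, assuming `Encoder1d.downsample_factor` equals patch_size × prod(factors) as in the removed code above (numbers match the `AudioDiffusionAE` defaults further down and are purely illustrative):

```py
downsample_factor = 16 * (4 * 4 * 4 * 2 * 2 * 2)    # 8192 with the default encoder
latent_length = 2 ** 18 // downsample_factor        # 32 frames for a 2**18-sample input
decoded_length = latent_length * downsample_factor  # 262144 == 2**18, already a power of two
```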


class DiffusionVocoder1d(Model1d):
@@ -499,31 +343,21 @@ def sample(self, *args, **kwargs):
return super().sample(*args, **{**get_default_sampling_kwargs(), **kwargs})


class AudioDiffusionAutoencoder(DiffusionAutoencoder1d):
def __init__(self, *args, **kwargs):
class AudioDiffusionAE(DiffusionAE1d):
def __init__(self, in_channels: int, *args, **kwargs):
default_kwargs = dict(
**get_default_model_kwargs(),
in_channels=in_channels,
encoder=Encoder1d(
in_channels=in_channels,
patch_size=16,
channels=16,
multipliers=[1, 2, 4, 4, 4, 4, 4],
factors=[4, 4, 4, 2, 2, 2],
num_blocks=[2, 2, 2, 2, 2, 2],
out_channels=64,
),
encoder_inject_depth=6,
encoder_channels=16,
encoder_patch_size=16,
encoder_multipliers=[1, 2, 4, 4, 4, 4, 4],
encoder_factors=[4, 4, 4, 2, 2, 2],
encoder_num_blocks=[2, 2, 2, 2, 2, 2],
bottleneck_channels=64,
)
super().__init__(*args, **{**default_kwargs, **kwargs}) # type: ignore

def decode(self, *args, **kwargs):
return super().decode(*args, **{**get_default_sampling_kwargs(), **kwargs})


class AudioDiffusionMAE(DiffusionMAE1d):
def __init__(self, *args, **kwargs):
default_kwargs = dict(
diffusion_type="v",
diffusion_sigma_distribution=UniformDistribution(),
stft_num_fft=1023,
stft_hop_length=256,
)
super().__init__(*args, **{**default_kwargs, **kwargs}) # type: ignore
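
The practical upshot of the decoupling: the encoder is now an explicit `Encoder1d` argument instead of a bundle of `encoder_*` kwargs, so a caller can swap in its own encoder. A hedged sketch of that override (hyperparameters are illustrative; it relies on `encoder` being a plain entry in `default_kwargs` that the `{**default_kwargs, **kwargs}` merge above replaces):

```py
from audio_encoders_pytorch import Encoder1d
from audio_diffusion_pytorch import AudioDiffusionAE

# Custom encoder with a narrower 32-channel latent, replacing the default one
# built inside AudioDiffusionAE (values below are illustrative)
encoder = Encoder1d(
    in_channels=1,
    patch_size=16,
    channels=16,
    multipliers=[1, 2, 4, 4, 4, 4, 4],
    factors=[4, 4, 4, 2, 2, 2],
    num_blocks=[2, 2, 2, 2, 2, 2],
    out_channels=32,
)

autoencoder = AudioDiffusionAE(in_channels=1, encoder=encoder)
```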

52 changes: 26 additions & 26 deletions audio_diffusion_pytorch/modules.py
@@ -8,7 +8,7 @@
from einops_exts import rearrange_many
from torch import Tensor, einsum

from .utils import closest_power_2, default, exists, groupby, is_sequence
from .utils import closest_power_2, default, exists, groupby

"""
Utils
@@ -1197,44 +1197,44 @@ def __init__(self, context_features: int, **kwargs):
super().__init__(context_features=context_features, **kwargs)
self.embedder = NumberEmbedder(features=context_features)

def expand(self, x: Any, shape: Tuple[int, ...]) -> Tensor:
x = x if torch.is_tensor(x) else torch.tensor(x)
return x.expand(shape)

def forward( # type: ignore
self,
x: Tensor,
time: Tensor,
*,
channels_list: Sequence[Tensor],
channels_augmentation: bool = False,
channels_scale: Union[int, Sequence[int]] = 0,
channels_augmentation: Union[
bool, Sequence[bool], Sequence[Sequence[bool]], Tensor
] = False,
channels_scale: Union[
float, Sequence[float], Sequence[Sequence[float]], Tensor
] = 0,
**kwargs,
) -> Tensor:
b, num_items = x.shape[0], len(channels_list)

if channels_augmentation:
# Random noise augmentation for each item
channels_scale = torch.rand(num_items, b).to(x) # type: ignore
for i in range(num_items):
item = channels_list[i]
scale = rearrange(channels_scale[i], "b -> b 1 1") # type: ignore
channels_list[i] = torch.randn_like(item) * scale + item * (1 - scale) # type: ignore # noqa
else:
# Expand same scale to each batch element
if is_sequence(channels_scale):
assert_message = "len(channels_scale) must match len(channels_list)"
assert len(channels_scale) == num_items, assert_message
else:
channels_scale = num_items * [channels_scale] # type: ignore
channels_scale = torch.tensor(channels_scale).to(x) # type: ignore
channels_scale = repeat(channels_scale, "n -> n b", b=b)

# Compute scale feature embedding
scale_embedding = self.embedder(channels_scale)
scale_embedding = reduce(scale_embedding, "n b d -> b d", "sum")
b, n = x.shape[0], len(channels_list)
channels_augmentation = self.expand(channels_augmentation, shape=(b, n)).to(x)
channels_scale = self.expand(channels_scale, shape=(b, n)).to(x)

# Augmentation (for each channel list item)
for i in range(n):
scale = channels_scale[:, i] * channels_augmentation[:, i]
scale = rearrange(scale, "b -> b 1 1")
item = channels_list[i]
channels_list[i] = torch.randn_like(item) * scale + item * (1 - scale) # type: ignore # noqa

# Scale embedding (sum reduction if more than one channel list item)
channels_scale_emb = self.embedder(channels_scale)
channels_scale_emb = reduce(channels_scale_emb, "b n d -> b d", "sum")

return super().forward(
x=x,
time=time,
channels_list=channels_list,
features=scale_embedding,
features=channels_scale_emb,
**kwargs,
)

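Finally, the modules.py change generalizes `channels_augmentation` and `channels_scale` from scalars to anything broadcastable to a `(batch, num_items)` tensor, folding the previous two code paths into one. A standalone sketch of that broadcasting and the per-item noise mix (names mirror the diff's new `expand` helper; this snippet is not library API):

```py
import torch

def expand(x, shape):
    # Promote bools, floats, or sequences to a full (batch, num_items) tensor
    x = x if torch.is_tensor(x) else torch.tensor(x)
    return x.expand(shape)

b, n = 4, 2                                   # batch size, number of conditioning items
channels_augmentation = expand(True, (b, n)).float()
channels_scale = expand([0.0, 0.5], (b, n))   # per-item scale, shared across the batch

item = torch.randn(b, 8, 16)                  # conditioning tensor for item i = 1
scale = (channels_scale[:, 1] * channels_augmentation[:, 1]).reshape(b, 1, 1)
noisy = torch.randn_like(item) * scale + item * (1 - scale)
```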
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name="audio-diffusion-pytorch",
packages=find_packages(exclude=[]),
version="0.0.95",
version="0.0.96",
license="MIT",
description="Audio Diffusion - PyTorch",
long_description_content_type="text/markdown",