Commit d0b206a

feat: update diffusion ae, remove aes, stft unet1d, new vocoder

flavioschneider committed Oct 28, 2022
1 parent 21014f9 commit d0b206a

Showing 6 changed files with 129 additions and 592 deletions.
README.md (6 changes: 2 additions & 4 deletions)
@@ -66,8 +66,7 @@ from audio_diffusion_pytorch import AudioDiffusionAutoencoder

autoencoder = AudioDiffusionAutoencoder(
    in_channels=1,
-    encoder_depth=4,
-    encoder_channels=32
+    encoder_depth=4
)

# Train on audio samples
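A usage sketch of the updated autoencoder API, following the `encode`/`decode` changes in `model.py` below; the sample shapes and `num_steps` value are illustrative assumptions, not values from this commit:

```python
import torch
from audio_diffusion_pytorch import AudioDiffusionAutoencoder

autoencoder = AudioDiffusionAutoencoder(in_channels=1, encoder_depth=4)

x = torch.randn(2, 1, 2 ** 18)  # [batch, in_channels, time]
loss = autoencoder(x)           # diffusion loss, conditioned on the latent
loss.backward()

# Compress to a latent, then reconstruct by sampling
latent = autoencoder.encode(x)
sample = autoencoder.decode(latent, num_steps=10)
```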
@@ -157,8 +156,7 @@ unet = UNet1d(
    kernel_multiplier_downsample=2,
    use_nearest_upsample=False,
    use_skip_scale=True,
-    use_context_time=True,
-    use_magnitude_channels=False
+    use_context_time=True
)

x = torch.randn(3, 1, 2 ** 16)
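A hedged sketch of the forward pass continuing the README snippet above (the diffusion-time argument and its range are assumptions implied by `use_context_time=True`):

```python
t = torch.rand(3)  # one diffusion time per batch element, assumed in [0, 1]
y = unet(x, t)     # output keeps the input shape: [3, 1, 2 ** 16]
```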
audio_diffusion_pytorch/__init__.py (14 changes: 1 addition & 13 deletions)
@@ -32,16 +32,4 @@
    DiffusionVocoder1d,
    Model1d,
)
-from .modules import (
-    AutoEncoder1d,
-    Decoder1d,
-    Encoder1d,
-    MultiEncoder1d,
-    Noiser,
-    STFTAutoEncoder1d,
-    T5Embedder,
-    Tanh,
-    UNet1d,
-    UNetConditional1d,
-    Variational,
-)
+from .modules import T5Embedder, UNet1d, UNetConditional1d
audio_diffusion_pytorch/model.py (152 changes: 71 additions & 81 deletions)
@@ -2,6 +2,7 @@
from typing import Any, Optional, Sequence, Tuple, Union

import torch
+from audio_encoders_pytorch import Bottleneck, Encoder1d
from einops import rearrange
from torch import Tensor, nn
@@ -16,15 +17,18 @@
    VKDiffusion,
    VSampler,
)
-from .modules import (
-    STFT,
-    Bottleneck,
-    MultiEncoder1d,
-    SinusoidalEmbedding,
-    UNet1d,
-    UNetConditional1d,
+from .modules import STFT, SinusoidalEmbedding, UNet1d, UNetConditional1d
+from .utils import (
+    closest_power_2,
+    default,
+    downsample,
+    exists,
+    groupby,
+    prefix_dict,
+    prod,
+    to_list,
+    upsample,
)
-from .utils import default, downsample, exists, groupby_kwargs_prefix, to_list, upsample

"""
Diffusion Classes (generic for 1d data)
Expand All @@ -36,7 +40,7 @@ def __init__(
self, diffusion_type: str, use_classifier_free_guidance: bool = False, **kwargs
):
super().__init__()
diffusion_kwargs, kwargs = groupby_kwargs_prefix("diffusion_", kwargs)
diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)

UNet = UNetConditional1d if use_classifier_free_guidance else UNet1d
self.unet = UNet(**kwargs)
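The rename from `groupby_kwargs_prefix` to `groupby` keeps the same prefix-splitting behavior. A minimal sketch of what that helper does; this reimplementation is an assumption for illustration, not the library code:

```python
from typing import Any, Dict, Tuple

def groupby(prefix: str, d: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    # Split a kwargs dict into (prefixed keys with the prefix stripped, the rest)
    grouped = {k[len(prefix):]: v for k, v in d.items() if k.startswith(prefix)}
    rest = {k: v for k, v in d.items() if not k.startswith(prefix)}
    return grouped, rest

kwargs = {"diffusion_sigma_data": 0.1, "channels": 128}
diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
# diffusion_kwargs == {"sigma_data": 0.1}; kwargs == {"channels": 128}
```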
@@ -149,31 +153,25 @@ def __init__(
        resnet_groups: int,
        kernel_multiplier_downsample: int,
        encoder_depth: int,
-        encoder_channels: int,
-        bottleneck: Optional[Bottleneck] = None,
        encoder_num_blocks: Optional[Sequence[int]] = None,
-        encoder_out_layers: int = 0,
+        bottleneck: Union[Bottleneck, Sequence[Bottleneck]] = [],
+        bottleneck_channels: Optional[int] = None,
+        use_stft: bool = False,
        **kwargs,
    ):
        self.in_channels = in_channels
        encoder_num_blocks = default(encoder_num_blocks, num_blocks)
        assert_message = "The number of encoder_num_blocks must match encoder_depth"
        assert len(encoder_num_blocks) >= encoder_depth, assert_message
+        assert patch_blocks == 1, "patch_blocks != 1 not supported"
+        assert not use_stft, "use_stft not supported"
+        self.factor = patch_factor * prod(factors[0:encoder_depth])

-        multiencoder = MultiEncoder1d(
-            in_channels=in_channels,
-            channels=channels,
-            patch_blocks=patch_blocks,
-            patch_factor=patch_factor,
-            num_layers=encoder_depth,
-            num_layers_out=encoder_out_layers,
-            latent_channels=encoder_channels,
-            multipliers=multipliers,
-            factors=factors,
-            num_blocks=encoder_num_blocks,
-            kernel_multiplier_downsample=kernel_multiplier_downsample,
-            resnet_groups=resnet_groups,
-        )
+        context_channels = [0] * encoder_depth
+        if exists(bottleneck_channels):
+            context_channels += [bottleneck_channels]
+        else:
+            context_channels += [channels * multipliers[encoder_depth]]

        super().__init__(
            in_channels=in_channels,
@@ -185,89 +183,81 @@ def __init__(
            num_blocks=num_blocks,
            resnet_groups=resnet_groups,
            kernel_multiplier_downsample=kernel_multiplier_downsample,
-            context_channels=multiencoder.channels_list,
+            context_channels=context_channels,
            **kwargs,
        )

-        self.bottleneck = bottleneck
-        self.multiencoder = multiencoder
+        self.bottlenecks = nn.ModuleList(to_list(bottleneck))
+        self.encoder = Encoder1d(
+            in_channels=in_channels,
+            channels=channels,
+            patch_size=patch_factor,
+            multipliers=multipliers[0 : encoder_depth + 1],
+            factors=factors[0:encoder_depth],
+            num_blocks=encoder_num_blocks[0:encoder_depth],
+            resnet_groups=resnet_groups,
+            out_channels=bottleneck_channels,
+        )

+    def encode(
+        self, x: Tensor, with_info: bool = False
+    ) -> Union[Tensor, Tuple[Tensor, Any]]:
+        latent, info = self.encoder(x, with_info=True)
+        for bottleneck in self.bottlenecks:
+            latent, info_bottleneck = bottleneck(latent, with_info=True)
+            info = {**info, **prefix_dict("bottleneck_", info_bottleneck)}
+        return (latent, info) if with_info else latent
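The loop above accepts any sequence of bottlenecks from `audio_encoders_pytorch`. A toy sketch of the interface it assumes (a module called on the latent, returning `(latent, info)` when `with_info=True`); this class is illustrative, not the library's:

```python
import torch
from torch import Tensor, nn

class ToyTanhBottleneck(nn.Module):
    # Illustrative bottleneck: squashes the latent and reports no extra info
    def forward(self, x: Tensor, with_info: bool = False):
        x = torch.tanh(x)
        info: dict = {}
        return (x, info) if with_info else x
```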

    def forward(  # type: ignore
        self, x: Tensor, with_info: bool = False, **kwargs
    ) -> Union[Tensor, Tuple[Tensor, Any]]:
-        if with_info:
-            latent, info = self.encode(x, with_info=True)
-        else:
-            latent = self.encode(x)
-
-        channels_list = self.multiencoder.decode(latent)
-        loss = self.diffusion(x, channels_list=channels_list, **kwargs)
+        latent, info = self.encode(x, with_info=True)
+        loss = self.diffusion(x, channels_list=[latent], **kwargs)
        return (loss, info) if with_info else loss

-    def encode(
-        self, x: Tensor, with_info: bool = False
-    ) -> Union[Tensor, Tuple[Tensor, Any]]:
-        latent = self.multiencoder.encode(x)
-        latent = torch.tanh(latent)
-        # Apply bottleneck if provided (e.g. quantization module)
-        if exists(self.bottleneck):
-            latent, info = self.bottleneck(latent)
-            return (latent, info) if with_info else latent
-        return latent

    def decode(self, latent: Tensor, **kwargs) -> Tensor:
-        b, length = latent.shape[0], latent.shape[2] * self.multiencoder.factor
+        b, length = latent.shape[0], latent.shape[2] * self.factor
        # Compute noise by inferring shape from latent length
        noise = torch.randn(b, self.in_channels, length).to(latent)
-        # Compute context from latent
-        channels_list = self.multiencoder.decode(latent)
-        default_kwargs = dict(channels_list=channels_list)
+        default_kwargs = dict(channels_list=[latent])
        # Decode by sampling while conditioning on latent channels
        return super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore
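The total downsampling factor set in `__init__` determines the decoded length. A worked example with assumed hyperparameters (the numbers are illustrative, not this commit's defaults):

```python
from math import prod

patch_factor, factors, encoder_depth = 16, [4, 4, 4, 2, 2, 2], 4
factor = patch_factor * prod(factors[0:encoder_depth])  # 16 * 128 = 2048
latent_length = 32
decoded_length = latent_length * factor  # 65536 samples, as in decode() above
```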


class DiffusionVocoder1d(Model1d):
-    def __init__(
-        self,
-        in_channels: int,
-        vocoder_num_fft: int,
-        **kwargs,
-    ):
-        self.frequency_channels = vocoder_num_fft // 2 + 1
-        spectrogram_channels = in_channels * self.frequency_channels
-
-        vocoder_kwargs, kwargs = groupby_kwargs_prefix("vocoder_", kwargs)
+    def __init__(self, in_channels: int, stft_num_fft: int, **kwargs):
+        self.stft_num_fft = stft_num_fft
+        spectrogram_channels = stft_num_fft // 2 + 1
        default_kwargs = dict(
-            in_channels=spectrogram_channels, context_channels=[spectrogram_channels]
+            in_channels=in_channels,
+            use_stft=True,
+            stft_num_fft=stft_num_fft,
+            context_channels=[in_channels * spectrogram_channels],
        )

        super().__init__(**{**default_kwargs, **kwargs})  # type: ignore
-        self.stft = STFT(num_fft=vocoder_num_fft, **vocoder_kwargs)

    def forward(self, x: Tensor, **kwargs) -> Tensor:
-        # Get magnitude and phase of true wave
-        magnitude, phase = self.stft.encode(x)
+        # Get magnitude spectrogram from true wave
+        magnitude, _ = self.unet.stft.encode(x)
        magnitude = rearrange(magnitude, "b c f t -> b (c f) t")
-        phase = rearrange(phase, "b c f t -> b (c f) t")
-        # Get diffusion phase loss while conditioning on magnitude (/pi [-1,1] range)
-        return self.diffusion(phase / pi, channels_list=[magnitude], **kwargs)
+        # Get diffusion loss while conditioning on magnitude
+        return self.diffusion(x, channels_list=[magnitude], **kwargs)

    def sample(self, spectrogram: Tensor, **kwargs):  # type: ignore
-        b, c, f, t, device = *spectrogram.shape, spectrogram.device
+        b, c, _, t, device = *spectrogram.shape, spectrogram.device
        magnitude = rearrange(spectrogram, "b c f t -> b (c f) t")
-        noise = torch.randn((b, c * f, t), device=device)
+        timesteps = closest_power_2(self.unet.stft.hop_length * t)
+        noise = torch.randn((b, c, timesteps), device=device)
        default_kwargs = dict(channels_list=[magnitude])
-        phase = super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore  # noqa
-        phase = rearrange(phase, "b (c f) t -> b c f t", c=c)
-        wave = self.stft.decode(spectrogram, phase * pi)
-        return wave
+        return super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore  # noqa
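The sampled waveform length is now inferred from the spectrogram: `hop_length * t` samples, rounded to a power of two. A hedged sketch of that arithmetic (`closest_power_2`'s exact rounding rule is an assumption):

```python
import math

def closest_power_2(x: float) -> int:
    # Assumed behavior: round x to the nearest power of two
    return 2 ** round(math.log2(x))

# With the stft_hop_length=256 default below and t=256 frames:
assert closest_power_2(256 * 256) == 2 ** 16  # 65536 samples
```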


class DiffusionUpphaser1d(DiffusionUpsampler1d):
    def __init__(self, **kwargs):
-        vocoder_kwargs, kwargs = groupby_kwargs_prefix("vocoder_", kwargs)
+        stft_kwargs, kwargs = groupby("stft_", kwargs)
        super().__init__(**kwargs)
-        self.stft = STFT(**vocoder_kwargs)
+        self.stft = STFT(**stft_kwargs)

    def random_rephase(self, x: Tensor) -> Tensor:
        magnitude, phase = self.stft.encode(x)
@@ -305,7 +295,6 @@ def get_default_model_kwargs():
        use_nearest_upsample=False,
        use_skip_scale=True,
        use_context_time=True,
-        use_magnitude_channels=False,
        diffusion_type="v",
        diffusion_sigma_distribution=UniformDistribution(),
    )
@@ -380,12 +369,13 @@ class AudioDiffusionVocoder(DiffusionVocoder1d):
    def __init__(self, in_channels: int, **kwargs):
        default_kwargs = dict(
            in_channels=in_channels,
-            vocoder_num_fft=1023,
-            channels=32,
+            stft_num_fft=1023,
+            stft_hop_length=256,
+            channels=64,
            patch_blocks=1,
            patch_factor=1,
-            multipliers=[64, 32, 16, 8, 4, 2, 1],
-            factors=[1, 1, 1, 1, 1, 1],
+            multipliers=[48, 32, 16, 8, 8, 8, 8],
+            factors=[2, 2, 2, 1, 1, 1],
            num_blocks=[1, 1, 1, 1, 1, 1],
            attentions=[0, 0, 0, 1, 1, 1],
            attention_heads=8,
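A usage sketch for the updated vocoder defaults above; the batch size, frame count, and `num_steps` are illustrative assumptions, while the spectrogram shape follows `sample()` in `model.py` (with `stft_num_fft=1023` giving 1023 // 2 + 1 = 512 frequency bins):

```python
import torch
from audio_diffusion_pytorch import AudioDiffusionVocoder

vocoder = AudioDiffusionVocoder(in_channels=1)

# Train directly on waveforms; the magnitude spectrogram is computed internally
x = torch.randn(2, 1, 2 ** 16)
loss = vocoder(x)
loss.backward()

# Vocode a magnitude spectrogram back to audio
spectrogram = torch.randn(1, 1, 512, 256)  # [b, c, f, t]
wave = vocoder.sample(spectrogram, num_steps=10)  # [1, 1, 65536]
```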
