Skip to content

Commit

Permalink
feat: add diffusion vocoder, add diffusion upphaser
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Oct 28, 2022
1 parent 22e5d75 commit 21014f9
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 18 deletions.
4 changes: 4 additions & 0 deletions audio_diffusion_pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,13 @@
AudioDiffusionAutoencoder,
AudioDiffusionConditional,
AudioDiffusionModel,
AudioDiffusionUpphaser,
AudioDiffusionUpsampler,
AudioDiffusionVocoder,
DiffusionAutoencoder1d,
DiffusionUpphaser1d,
DiffusionUpsampler1d,
DiffusionVocoder1d,
Model1d,
)
from .modules import (
Expand Down
105 changes: 105 additions & 0 deletions audio_diffusion_pytorch/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from math import pi
from typing import Any, Optional, Sequence, Tuple, Union

import torch
from einops import rearrange
from torch import Tensor, nn

from .diffusion import (
Expand All @@ -15,6 +17,7 @@
VSampler,
)
from .modules import (
STFT,
Bottleneck,
MultiEncoder1d,
SinusoidalEmbedding,
Expand Down Expand Up @@ -223,6 +226,62 @@ def decode(self, latent: Tensor, **kwargs) -> Tensor:
return super().sample(noise, **{**default_kwargs, **kwargs}) # type: ignore


class DiffusionVocoder1d(Model1d):
    """Diffusion model that generates the STFT phase of a waveform while
    conditioning on its magnitude spectrogram, then reconstructs the wave
    with an inverse STFT.
    """

    def __init__(
        self,
        in_channels: int,
        vocoder_num_fft: int,
        **kwargs,
    ):
        # rFFT of size num_fft yields num_fft // 2 + 1 frequency bins.
        self.frequency_channels = vocoder_num_fft // 2 + 1
        spectrogram_channels = in_channels * self.frequency_channels
        # Split off "vocoder_"-prefixed kwargs for the STFT module.
        vocoder_kwargs, kwargs = groupby_kwargs_prefix("vocoder_", kwargs)
        base_kwargs = dict(
            in_channels=spectrogram_channels, context_channels=[spectrogram_channels]
        )
        super().__init__(**{**base_kwargs, **kwargs})  # type: ignore
        self.stft = STFT(num_fft=vocoder_num_fft, **vocoder_kwargs)

    def forward(self, x: Tensor, **kwargs) -> Tensor:
        """Return the diffusion loss on the normalized phase of `x`,
        conditioned on its magnitude spectrogram."""
        magnitude, phase = self.stft.encode(x)
        # Fold the frequency axis into channels for the 1d diffusion network.
        flat_magnitude = rearrange(magnitude, "b c f t -> b (c f) t")
        flat_phase = rearrange(phase, "b c f t -> b (c f) t")
        # Phase is divided by pi so it lies in the [-1, 1] range.
        return self.diffusion(flat_phase / pi, channels_list=[flat_magnitude], **kwargs)

    def sample(self, spectrogram: Tensor, **kwargs):  # type: ignore
        """Sample a waveform from a magnitude spectrogram of shape
        (batch, channels, frequencies, time) by diffusing its phase."""
        b, c, f, t = spectrogram.shape
        device = spectrogram.device
        magnitude = rearrange(spectrogram, "b c f t -> b (c f) t")
        noise = torch.randn((b, c * f, t), device=device)
        merged_kwargs = {**dict(channels_list=[magnitude]), **kwargs}
        phase = super().sample(noise, **merged_kwargs)  # type: ignore # noqa
        phase = rearrange(phase, "b (c f) t -> b c f t", c=c)
        # Undo the [-1, 1] normalization before inverting the STFT.
        return self.stft.decode(spectrogram, phase * pi)


class DiffusionUpphaser1d(DiffusionUpsampler1d):
    """Upsampler variant that additionally scrambles phase: the model is
    trained to restore both sample rate and phase of the input wave."""

    def __init__(self, **kwargs):
        # "vocoder_"-prefixed kwargs configure the internal STFT.
        stft_kwargs, kwargs = groupby_kwargs_prefix("vocoder_", kwargs)
        super().__init__(**kwargs)
        self.stft = STFT(**stft_kwargs)

    def random_rephase(self, x: Tensor) -> Tensor:
        """Rebuild `x` keeping its magnitude but with uniformly random phase."""
        magnitude, phase = self.stft.encode(x)
        # Uniform random phase in [-pi, pi).
        randomized = (torch.rand_like(phase) - 0.5) * 2 * pi
        return self.stft.decode(magnitude, randomized)

    def forward(self, x: Tensor, **kwargs) -> Tensor:
        """Diffusion loss for recovering `x` from its rephased, resampled copy."""
        scrambled = self.random_rephase(x)
        resampled, factors = self.random_reupsample(scrambled)
        features = self.to_features(factors) if self.use_conditioning else None
        return self.diffusion(x, channels_list=[resampled], features=features, **kwargs)


"""
Audio Diffusion Classes (specific for 1d audio data)
"""
Expand Down Expand Up @@ -315,3 +374,49 @@ def sample(self, *args, **kwargs):
embedding_scale=5.0,
)
return super().sample(*args, **{**default_kwargs, **kwargs})


class AudioDiffusionVocoder(DiffusionVocoder1d):
    """DiffusionVocoder1d preconfigured with defaults suited to audio data.

    Any keyword argument passed by the caller overrides the default of the
    same name.
    """

    def __init__(self, in_channels: int, **kwargs):
        defaults = dict(
            in_channels=in_channels,
            vocoder_num_fft=1023,
            channels=32,
            patch_blocks=1,
            patch_factor=1,
            multipliers=[64, 32, 16, 8, 4, 2, 1],
            factors=[1, 1, 1, 1, 1, 1],
            num_blocks=[1, 1, 1, 1, 1, 1],
            attentions=[0, 0, 0, 1, 1, 1],
            attention_heads=8,
            attention_features=64,
            attention_multiplier=2,
            attention_use_rel_pos=False,
            resnet_groups=8,
            kernel_multiplier_downsample=2,
            use_nearest_upsample=False,
            use_skip_scale=True,
            use_context_time=True,
            use_magnitude_channels=False,
            diffusion_type="v",
            diffusion_sigma_distribution=UniformDistribution(),
        )
        super().__init__(**{**defaults, **kwargs})  # type: ignore

    def sample(self, *args, **kwargs):
        """Sample with the library's default sampling settings, overridable
        via **kwargs."""
        return super().sample(*args, **{**get_default_sampling_kwargs(), **kwargs})


class AudioDiffusionUpphaser(DiffusionUpphaser1d):
    """DiffusionUpphaser1d wired up with the default audio model settings.

    Caller-supplied keyword arguments take precedence over the defaults.
    """

    def __init__(self, in_channels: int, **kwargs):
        defaults = dict(
            **get_default_model_kwargs(),
            in_channels=in_channels,
            context_channels=[in_channels],
            factor=1,
        )
        super().__init__(**{**defaults, **kwargs})  # type: ignore

    def sample(self, *args, **kwargs):
        """Sample using the default sampling settings, overridable via kwargs."""
        defaults = get_default_sampling_kwargs()
        return super().sample(*args, **{**defaults, **kwargs})
28 changes: 14 additions & 14 deletions audio_diffusion_pytorch/modules.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import math
from math import pi
from math import floor, log, pi
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import torch
Expand All @@ -9,7 +8,7 @@
from einops_exts import rearrange_many
from torch import Tensor, einsum

from .utils import default, exists, prod, to_list
from .utils import closest_power_2, default, exists, prod, to_list

"""
Utils
Expand Down Expand Up @@ -338,7 +337,7 @@ def _relative_position_bucket(
max_exact
+ (
torch.log(n.float() / max_exact)
/ math.log(max_distance / max_exact)
/ log(max_distance / max_exact)
* (num_buckets - max_exact)
).long()
)
Expand Down Expand Up @@ -587,7 +586,7 @@ def __init__(self, dim: int):

def forward(self, x: Tensor) -> Tensor:
device, half_dim = x.device, self.dim // 2
emb = torch.tensor(math.log(10000) / (half_dim - 1), device=device)
emb = torch.tensor(log(10000) / (half_dim - 1), device=device)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j")
return torch.cat((emb.sin(), emb.cos()), dim=-1)
Expand Down Expand Up @@ -1692,17 +1691,17 @@ def decode(self, latent: Tensor) -> List[Tensor]:
class STFT(nn.Module):
def __init__(
self,
length: int,
num_fft: int = 1024,
hop_length: int = 256,
window_length: int = 1024,
num_fft: int = 1023,
hop_length: Optional[int] = None,
window_length: Optional[int] = None,
length: Optional[int] = None,
):
super().__init__()
self.num_fft = num_fft
self.hop_length = hop_length
self.window_length = window_length
self.hop_length = default(hop_length, floor(num_fft // 4))
self.window_length = default(window_length, num_fft)
self.length = length
self.register_buffer("window", torch.hann_window(window_length))
self.register_buffer("window", torch.hann_window(self.window_length))

def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]:
b = wave.shape[0]
Expand All @@ -1725,19 +1724,20 @@ def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]:
return mag, phase

def decode(self, magnitude: Tensor, phase: Tensor) -> Tensor:
b = magnitude.shape[0]
b, l = magnitude.shape[0], magnitude.shape[-1] # noqa
assert magnitude.shape == phase.shape, "magnitude and phase must be same shape"
real = rearrange(magnitude * torch.cos(phase), "b c f l -> (b c) f l")
imag = rearrange(magnitude * torch.sin(phase), "b c f l -> (b c) f l")
stft = torch.stack([real, imag], dim=-1)
length = closest_power_2(l * self.hop_length)

wave = torch.istft(
stft,
n_fft=self.num_fft,
hop_length=self.hop_length,
win_length=self.window_length,
window=self.window, # type: ignore
length=self.length,
length=default(self.length, length),
)
wave = rearrange(wave, "(b c) t -> b c t", b=b)
return wave
Expand Down
13 changes: 10 additions & 3 deletions audio_diffusion_pytorch/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math
from functools import reduce
from inspect import isfunction
from math import ceil, floor, log2, pi
from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union

import torch
Expand Down Expand Up @@ -42,6 +42,13 @@ def prod(vals: Sequence[int]) -> int:
return reduce(lambda x, y: x * y, vals)


def closest_power_2(x: float) -> int:
    """Return the integer power of two closest to ``x``.

    Args:
        x: A positive number (``log2(x)`` must be defined; ``x <= 0``
            raises ``ValueError`` from ``log2``, as before).

    Returns:
        ``2 ** k`` minimizing ``abs(x - 2 ** k)``. On an exact tie the
        lower power wins, matching the original floor-first ordering.
    """
    exponent = log2(x)
    # The nearest power of two is one of the two bracketing x; min() keeps
    # the first candidate on ties, so the lower power is preferred.
    candidates = (2 ** floor(exponent), 2 ** ceil(exponent))
    return int(min(candidates, key=lambda power: abs(x - power)))


"""
Kwargs Utils
"""
Expand Down Expand Up @@ -79,10 +86,10 @@ def resample(
d = dict(device=waveforms.device, dtype=waveforms.dtype)

base_factor = min(factor_in, factor_out) * rolloff
width = math.ceil(lowpass_filter_width * factor_in / base_factor)
width = ceil(lowpass_filter_width * factor_in / base_factor)
idx = torch.arange(-width, width + factor_in, **d)[None, None] / factor_in # type: ignore # noqa
t = torch.arange(0, -factor_out, step=-1, **d)[:, None, None] / factor_out + idx # type: ignore # noqa
t = (t * base_factor).clamp(-lowpass_filter_width, lowpass_filter_width) * math.pi
t = (t * base_factor).clamp(-lowpass_filter_width, lowpass_filter_width) * pi

window = torch.cos(t / lowpass_filter_width / 2) ** 2
scale = base_factor / factor_in
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="audio-diffusion-pytorch",
packages=find_packages(exclude=[]),
version="0.0.79",
version="0.0.80",
license="MIT",
description="Audio Diffusion - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit 21014f9

Please sign in to comment.