Skip to content

Commit

Permalink
feat: add true v-diffusion, separate vk-diffusion
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Oct 17, 2022
1 parent 531dcee commit 6c9b5d4
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 50 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,8 @@ y = unet(x, t) # [3, 1, 32768], compute 3 samples of ~1.5 seconds at 22050Hz wit

#### Training
```python
from audio_diffusion_pytorch import KDiffusion, VDiffusion, LogNormalDistribution, VDistribution
from audio_diffusion_pytorch import KDiffusion, LogNormalDistribution
from audio_diffusion_pytorch import VDiffusion, UniformDistribution

# Either use KDiffusion
diffusion = KDiffusion(
Expand All @@ -184,7 +185,7 @@ diffusion = KDiffusion(
# Or use VDiffusion
diffusion = VDiffusion(
net=unet,
sigma_distribution=VDistribution()
sigma_distribution=UniformDistribution()
)

x = torch.randn(3, 1, 2 ** 18) # Batch of training audio samples
Expand Down
5 changes: 4 additions & 1 deletion audio_diffusion_pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
Sampler,
Schedule,
SpanBySpanComposer,
UniformDistribution,
VDiffusion,
VDistribution,
VKDiffusion,
VKDistribution,
VSampler,
)
from .model import (
AudioDiffusionAutoencoder,
Expand Down
175 changes: 142 additions & 33 deletions audio_diffusion_pytorch/diffusion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from math import atan, pi, sqrt
from typing import Any, Callable, Optional, Tuple
from math import atan, cos, pi, sin, sqrt
from typing import Any, Callable, List, Optional, Tuple, Type

import torch
import torch.nn as nn
Expand Down Expand Up @@ -33,7 +33,12 @@ def __call__(
return normal.exp()


class VDistribution(Distribution):
class UniformDistribution(Distribution):
    """Distribution drawing each sample uniformly at random via `torch.rand`."""

    def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
        # One value per requested sample, allocated directly on `device`.
        samples = torch.rand(num_samples, device=device)
        return samples


class VKDistribution(Distribution):
def __init__(
self,
min_value: float = 0.0,
Expand Down Expand Up @@ -94,6 +99,8 @@ def to_batch(

class Diffusion(nn.Module):

alias: str = ""

"""Base diffusion class"""

def denoise_fn(
Expand All @@ -110,24 +117,19 @@ def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:


class VDiffusion(Diffusion):

alias = "v"

def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
super().__init__()
self.net = net
self.sigma_distribution = sigma_distribution

def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
sigma_data = 1.0
sigmas = rearrange(sigmas, "b -> b 1 1")
c_skip = (sigma_data ** 2) / (sigmas ** 2 + sigma_data ** 2)
c_out = -sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5
c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5
return c_skip, c_out, c_in

def sigma_to_t(self, sigmas: Tensor) -> Tensor:
return sigmas.atan() / pi * 2

def t_to_sigma(self, t: Tensor) -> Tensor:
return (t * pi / 2).tan()
def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
angle = sigmas * pi / 2
alpha = torch.cos(angle)
beta = torch.sin(angle)
return alpha, beta

def denoise_fn(
self,
Expand All @@ -138,12 +140,7 @@ def denoise_fn(
) -> Tensor:
batch_size, device = x_noisy.shape[0], x_noisy.device
sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)

# Predict network output and add skip connection
c_skip, c_out, c_in = self.get_scale_weights(sigmas)
x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
x_denoised = c_skip * x_noisy + c_out * x_pred
return x_denoised
return self.net(x_noisy, sigmas, **kwargs)

def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
batch_size, device = x.shape[0], x.device
Expand All @@ -152,25 +149,24 @@ def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
sigmas_padded = rearrange(sigmas, "b -> b 1 1")

# Add noise to input
# Get noise
noise = default(noise, lambda: torch.randn_like(x))
x_noisy = x + sigmas_padded * noise

# Compute model output
c_skip, c_out, c_in = self.get_scale_weights(sigmas)
x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
# Combine input and noise weighted by half-circle
alpha, beta = self.get_alpha_beta(sigmas_padded)
x_noisy = x * alpha + noise * beta
x_target = noise * alpha - x * beta

# Compute v-objective target
v_target = (x - c_skip * x_noisy) / (c_out + 1e-7)

# Compute loss
loss = F.mse_loss(x_pred, v_target)
return loss
# Denoise and return loss
x_denoised = self.denoise_fn(x_noisy, sigmas, **kwargs)
return F.mse_loss(x_denoised, x_target)


class KDiffusion(Diffusion):
"""Elucidated Diffusion (Karras et al. 2022): https://arxiv.org/abs/2206.00364"""

alias = "k"

def __init__(
self,
net: nn.Module,
Expand Down Expand Up @@ -235,7 +231,68 @@ def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
losses = reduce(losses, "b ... -> b", "mean")
losses = losses * self.loss_weight(sigmas)
loss = losses.mean()
return loss


class VKDiffusion(Diffusion):
    """Diffusion with a v-objective target on additively noised inputs.

    The input is noised as `x + sigma * noise`, and the wrapped network is
    trained to predict a target derived from the `c_skip` / `c_out` scale
    weights computed in `get_scale_weights`.
    """

    alias = "vk"

    def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
        super().__init__()
        self.net = net
        self.sigma_distribution = sigma_distribution

    def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
        """Return `(c_skip, c_out, c_in)` for a batch of noise levels.

        `sigmas` is a 1-D tensor of shape `[b]`; every returned weight has
        shape `[b, 1, 1]`.
        """
        sigma_data = 1.0
        sigmas = rearrange(sigmas, "b -> b 1 1")
        # Total variance term shared by all three weights.
        variance = sigmas ** 2 + sigma_data ** 2
        c_skip = (sigma_data ** 2) / variance
        c_out = -sigmas * sigma_data * variance ** -0.5
        c_in = variance ** -0.5
        return c_skip, c_out, c_in

    def sigma_to_t(self, sigmas: Tensor) -> Tensor:
        """Map noise levels to time values using `atan(sigma) * 2 / pi`."""
        return torch.atan(sigmas) / pi * 2

    def t_to_sigma(self, t: Tensor) -> Tensor:
        """Inverse of `sigma_to_t`: `tan(t * pi / 2)`."""
        return torch.tan(t * pi / 2)

    def denoise_fn(
        self,
        x_noisy: Tensor,
        sigmas: Optional[Tensor] = None,
        sigma: Optional[float] = None,
        **kwargs,
    ) -> Tensor:
        """Denoise `x_noisy` given per-item `sigmas` or a single `sigma`.

        Extra keyword arguments are forwarded to the wrapped network.
        """
        device = x_noisy.device
        batch_size = x_noisy.shape[0]
        sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)

        # Scale the input, run the network, then blend its output with the
        # noisy input through the skip connection.
        c_skip, c_out, c_in = self.get_scale_weights(sigmas)
        net_out = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
        return c_skip * x_noisy + c_out * net_out

    def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
        """Compute the training loss for a batch `x`.

        `noise` may be supplied by the caller; otherwise standard normal
        noise shaped like `x` is drawn. Returns the mean squared error
        between the network output and the v-objective target.
        """
        device = x.device
        batch_size = x.shape[0]

        # Draw one noise level per batch element (sampled before the noise
        # itself, so the random stream is consumed in a fixed order).
        sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
        sigmas_padded = rearrange(sigmas, "b -> b 1 1")

        # Corrupt the input additively.
        noise = default(noise, lambda: torch.randn_like(x))
        x_noisy = x + sigmas_padded * noise

        # Network prediction on the scaled noisy input.
        c_skip, c_out, c_in = self.get_scale_weights(sigmas)
        x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)

        # Target obtained by solving `x = c_skip * x_noisy + c_out * v` for v;
        # the small constant guards the division when c_out is close to zero.
        v_target = (x - c_skip * x_noisy) / (c_out + 1e-7)

        return F.mse_loss(x_pred, v_target)


Expand All @@ -253,6 +310,12 @@ def forward(self, num_steps: int, device: torch.device) -> Tensor:
raise NotImplementedError()


class LinearSchedule(Schedule):
    """Schedule of `num_steps` noise levels decreasing linearly from 1.

    The values are `1, 1 - 1/num_steps, ..., 1/num_steps`; the final 0 of the
    underlying linspace is dropped.
    """

    def forward(self, num_steps: int, device: Any) -> Tensor:
        """Return a 1-D tensor of `num_steps` sigmas allocated on `device`."""
        # Build num_steps + 1 evenly spaced points on [1, 0] and drop the last
        # one (0) so that exactly num_steps values remain. The tensor is
        # created on the requested device rather than always on the default.
        sigmas = torch.linspace(1, 0, num_steps + 1, device=device)[:-1]
        return sigmas


class KarrasSchedule(Schedule):
"""https://arxiv.org/abs/2206.00364 equation 5"""

Expand All @@ -278,6 +341,9 @@ def forward(self, num_steps: int, device: Any) -> Tensor:


class Sampler(nn.Module):

diffusion_types: List[Type[Diffusion]] = []

def forward(
self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
) -> Tensor:
Expand All @@ -295,9 +361,41 @@ def inpaint(
raise NotImplementedError("Inpainting not available with current sampler")


class VSampler(Sampler):
    """Sampler for v-objective diffusion.

    A noise level `sigma` is mapped to an angle `sigma * pi / 2`, with
    `alpha = cos(angle)` weighting the signal and `beta = sin(angle)`
    weighting the noise. Given a noisy input `x` and a network output `v`,
    the clean-signal estimate is `x * alpha - v * beta` and the noise
    estimate is `x * beta + v * alpha`.
    """

    diffusion_types = [VDiffusion]

    def get_alpha_beta(self, sigma: float) -> Tuple[float, float]:
        """Return the `(alpha, beta)` signal/noise weights for `sigma`."""
        angle = sigma * pi / 2
        return cos(angle), sin(angle)

    def forward(
        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
    ) -> Tensor:
        """Run `num_steps` denoising steps and return the final estimate.

        Args:
            noise: Starting noise tensor.
            fn: Callable invoked as `fn(x, sigma=...)` returning the
                network output for the current noisy sample.
            sigmas: Noise levels; entries `0 .. num_steps - 1` are read.
            num_steps: Number of network evaluations to perform.

        Returns:
            The clean-signal estimate from the last step. If `num_steps`
            is not positive, the scaled input noise is returned unchanged.
        """
        x = sigmas[0] * noise
        # Fallback so that a non-positive num_steps still returns a tensor.
        x_pred = x
        alpha, beta = self.get_alpha_beta(sigmas[0].item())

        # Iterate over every step. The previous `range(num_steps - 1)` bound
        # made the last-step check unreachable and left the result noisy.
        for i in range(num_steps):
            v = fn(x, sigma=sigmas[i])
            x_pred = x * alpha - v * beta
            x_eps = x * beta + v * alpha

            # Re-noise to the next level on every step except the last one.
            if i < num_steps - 1:
                alpha, beta = self.get_alpha_beta(sigmas[i + 1].item())
                x = x_pred * alpha + x_eps * beta

        return x_pred


class KarrasSampler(Sampler):
"""https://arxiv.org/abs/2206.00364 algorithm 1"""

diffusion_types = [KDiffusion, VKDiffusion]

def __init__(
self,
s_tmin: float = 0,
Expand Down Expand Up @@ -351,6 +449,9 @@ def forward(


class AEulerSampler(Sampler):

diffusion_types = [KDiffusion, VKDiffusion]

def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float]:
sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
Expand Down Expand Up @@ -380,6 +481,8 @@ def forward(
class ADPM2Sampler(Sampler):
"""https://www.desmos.com/calculator/jbxjlqd9mb"""

diffusion_types = [KDiffusion, VKDiffusion]

def __init__(self, rho: float = 1.0):
super().__init__()
self.rho = rho
Expand Down Expand Up @@ -459,6 +562,12 @@ def __init__(
self.sigma_schedule = sigma_schedule
self.num_steps = num_steps

# Check sampler is compatible with diffusion type
sampler_class = sampler.__class__.__name__
diffusion_class = diffusion.__class__.__name__
message = f"{sampler_class} incompatible with {diffusion_class}"
assert diffusion.alias in [t.alias for t in sampler.diffusion_types], message

@torch.no_grad()
def forward(
self, noise: Tensor, num_steps: Optional[int] = None, **kwargs
Expand Down
29 changes: 16 additions & 13 deletions audio_diffusion_pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
from torch import Tensor, nn

from .diffusion import (
AEulerSampler,
Diffusion,
DiffusionSampler,
KarrasSchedule,
KDiffusion,
LinearSchedule,
Sampler,
Schedule,
UniformDistribution,
VDiffusion,
VDistribution,
VKDiffusion,
VSampler,
)
from .modules import (
Bottleneck,
Expand All @@ -38,12 +38,15 @@ def __init__(
UNet = UNetConditional1d if use_classifier_free_guidance else UNet1d
self.unet = UNet(**kwargs)

if diffusion_type == "v":
self.diffusion: Diffusion = VDiffusion(net=self.unet, **diffusion_kwargs)
elif diffusion_type == "k":
self.diffusion = KDiffusion(net=self.unet, **diffusion_kwargs)
else:
raise ValueError(f"diffusion_type must be v or k, found {diffusion_type}")
# Check valid diffusion type
diffusion_classes = [VDiffusion, KDiffusion, VKDiffusion]
aliases = [t.alias for t in diffusion_classes] # type: ignore
message = f"diffusion_type='{diffusion_type}' must be one of {*aliases,}"
assert diffusion_type in aliases, message

for XDiffusion in diffusion_classes:
if XDiffusion.alias == diffusion_type: # type: ignore
self.diffusion = XDiffusion(net=self.unet, **diffusion_kwargs)

def forward(self, x: Tensor, **kwargs) -> Tensor:
return self.diffusion(x, **kwargs)
Expand Down Expand Up @@ -242,14 +245,14 @@ def get_default_model_kwargs():
use_context_time=True,
use_magnitude_channels=False,
diffusion_type="v",
diffusion_sigma_distribution=VDistribution(),
diffusion_sigma_distribution=UniformDistribution(),
)


def get_default_sampling_kwargs():
return dict(
sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
sampler=AEulerSampler(),
sigma_schedule=LinearSchedule(),
sampler=VSampler(),
)


Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="audio-diffusion-pytorch",
packages=find_packages(exclude=[]),
version="0.0.66",
version="0.0.67",
license="MIT",
description="Audio Diffusion - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit 6c9b5d4

Please sign in to comment.