feat: add adapter option to DiffusionAE
flavioschneider committed Jan 27, 2023
1 parent a34014f commit 514b8f7
Showing 2 changed files with 36 additions and 5 deletions.
39 changes: 35 additions & 4 deletions audio_diffusion_pytorch/models.py
@@ -8,7 +8,15 @@
 
 from .components import AppendChannelsPlugin, MelSpectrogram
 from .diffusion import ARVDiffusion, ARVSampler, VDiffusion, VSampler
-from .utils import closest_power_2, default, downsample, groupby, randn_like, upsample
+from .utils import (
+    closest_power_2,
+    default,
+    downsample,
+    exists,
+    groupby,
+    randn_like,
+    upsample,
+)
 
 
 class DiffusionModel(nn.Module):
@@ -46,6 +54,18 @@ def __init__(self):
         self.downsample_factor = None
 
 
+class AdapterBase(nn.Module, ABC):
+    """Abstract class for DiffusionAE encoder"""
+
+    @abstractmethod
+    def encode(self, x: Tensor) -> Tensor:
+        pass
+
+    @abstractmethod
+    def decode(self, x: Tensor) -> Tensor:
+        pass
+
+
 class DiffusionAE(DiffusionModel):
     """Diffusion Auto Encoder"""
 
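For context, the new AdapterBase contract is just an encode/decode pair wrapped around the diffusion network. A minimal concrete adapter might look like the sketch below; it is not part of this commit, and ChannelFoldAdapter is a hypothetical name. It folds time into channels so the diffusion network runs on a shorter sequence, then unfolds it again on decode:

from torch import Tensor

from audio_diffusion_pytorch.models import AdapterBase


class ChannelFoldAdapter(AdapterBase):
    """Hypothetical adapter: trades sequence length for channel count."""

    def __init__(self, factor: int = 2):
        super().__init__()
        self.factor = factor

    def encode(self, x: Tensor) -> Tensor:
        # (b, c, t) -> (b, c * factor, t // factor)
        b, c, t = x.shape
        x = x.reshape(b, c, t // self.factor, self.factor)
        return x.permute(0, 1, 3, 2).reshape(b, c * self.factor, t // self.factor)

    def decode(self, x: Tensor) -> Tensor:
        # (b, c * factor, t) -> (b, c, t * factor), the exact inverse of encode
        b, cf, t = x.shape
        x = x.reshape(b, cf // self.factor, self.factor, t).permute(0, 1, 3, 2)
        return x.reshape(b, cf // self.factor, t * self.factor)

Because AdapterBase inherits from nn.Module, the requires_grad_(False) call added to DiffusionAE.__init__ in the next hunk freezes any adapter parameters; this toy adapter happens to have none.
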
@@ -55,6 +75,8 @@ def __init__(
         channels: Sequence[int],
         encoder: EncoderBase,
         inject_depth: int,
+        latent_factor: Optional[int] = None,
+        adapter: Optional[AdapterBase] = None,
         **kwargs,
     ):
         context_channels = [0] * len(channels)
@@ -68,12 +90,19 @@
         self.in_channels = in_channels
         self.encoder = encoder
         self.inject_depth = inject_depth
+        # Optional custom latent factor and adapter
+        self.latent_factor = default(latent_factor, self.encoder.downsample_factor)
+        self.adapter = adapter.requires_grad_(False) if exists(adapter) else None
 
     def forward(  # type: ignore
         self, x: Tensor, with_info: bool = False, **kwargs
     ) -> Union[Tensor, Tuple[Tensor, Any]]:
+        # Encode input to latent channels
         latent, info = self.encode(x, with_info=True)
         channels = [None] * self.inject_depth + [latent]
+        # Adapt input to diffusion if adapter provided
+        x = self.adapter.encode(x) if exists(self.adapter) else x
+        # Compute diffusion loss
         loss = super().forward(x, channels=channels, **kwargs)
         return (loss, info) if with_info else loss
 
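With the new constructor arguments, wiring an adapter in might look as follows. This is a hedged sketch: UNetV0, VDiffusion, VSampler, and the hyperparameter names follow the project README's conventions; my_encoder is a placeholder for any EncoderBase, and ChannelFoldAdapter is the hypothetical class above. Since forward encodes the latent from the raw input but runs diffusion on adapter.encode(x), in_channels must describe the adapted signal:

import torch

from audio_diffusion_pytorch import DiffusionAE, UNetV0, VDiffusion, VSampler

autoencoder = DiffusionAE(
    net_t=UNetV0,
    in_channels=4,  # stereo folded to 4 channels by the adapter
    channels=[32, 64, 128, 256],
    factors=[2, 2, 2, 2],
    items=[2, 2, 2, 2],
    encoder=my_encoder,  # placeholder: any EncoderBase, e.g. MelE1d from audio-encoders-pytorch
    inject_depth=2,
    adapter=ChannelFoldAdapter(factor=2),
    # latent_length * latent_factor must recover the *adapted* length, hence
    # the encoder's downsample factor divided by the fold factor:
    latent_factor=my_encoder.downsample_factor // 2,
    diffusion_t=VDiffusion,
    sampler_t=VSampler,
)

x = torch.randn(1, 2, 2**16)  # [batch, channels, samples]
loss = autoencoder(x)  # encoder sees x; the UNet sees adapter.encode(x)
loss.backward()
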
@@ -85,18 +114,20 @@ def decode(
         self, latent: Tensor, generator: Optional[Generator] = None, **kwargs
     ) -> Tensor:
         b = latent.shape[0]
-        length = closest_power_2(latent.shape[2] * self.encoder.downsample_factor)
+        noise_length = closest_power_2(latent.shape[2] * self.latent_factor)
         # Compute noise by inferring shape from latent length
         noise = torch.randn(
-            (b, self.in_channels, length),
+            (b, self.in_channels, noise_length),
             device=latent.device,
             dtype=latent.dtype,
             generator=generator,
         )
         # Compute context from latent
         channels = [None] * self.inject_depth + [latent]  # type: ignore
         # Decode by sampling while conditioning on latent channels
-        return super().sample(noise, channels=channels, **kwargs)
+        out = super().sample(noise, channels=channels, **kwargs)
+        # Decode output with adapter if provided
+        return self.adapter.decode(out) if exists(self.adapter) else out
 
 
 class DiffusionUpsampler(DiffusionModel):
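To make the new decode path concrete: assuming my_encoder.downsample_factor == 512 in the sketch above, a 2**16-sample stereo input yields a latent of length 65536 / 512 = 128, latent_factor is 512 // 2 = 256, and decode draws noise of length closest_power_2(128 * 256) = 32768 in the adapted domain before the adapter unfolds the sampled audio:

latent = autoencoder.encode(x)  # (1, latent_channels, 128)
y = autoencoder.decode(latent, num_steps=50)  # sampled at length 32768, then unfolded
assert y.shape == x.shape  # ChannelFoldAdapter.decode restores (1, 2, 65536)
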
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
     name="audio-diffusion-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.1.1",
+    version="0.1.2",
     license="MIT",
     description="Audio Diffusion - PyTorch",
     long_description_content_type="text/markdown",
