diff --git a/audio_diffusion_pytorch/models.py b/audio_diffusion_pytorch/models.py
index 6247187..c3f9cc2 100644
--- a/audio_diffusion_pytorch/models.py
+++ b/audio_diffusion_pytorch/models.py
@@ -8,7 +8,15 @@
 
 from .components import AppendChannelsPlugin, MelSpectrogram
 from .diffusion import ARVDiffusion, ARVSampler, VDiffusion, VSampler
-from .utils import closest_power_2, default, downsample, groupby, randn_like, upsample
+from .utils import (
+    closest_power_2,
+    default,
+    downsample,
+    exists,
+    groupby,
+    randn_like,
+    upsample,
+)
 
 
 class DiffusionModel(nn.Module):
@@ -46,6 +54,18 @@ def __init__(self):
         self.downsample_factor = None
 
 
+class AdapterBase(nn.Module, ABC):
+    """Abstract class for DiffusionAE encoder"""
+
+    @abstractmethod
+    def encode(self, x: Tensor) -> Tensor:
+        pass
+
+    @abstractmethod
+    def decode(self, x: Tensor) -> Tensor:
+        pass
+
+
 class DiffusionAE(DiffusionModel):
     """Diffusion Auto Encoder"""
 
@@ -55,6 +75,8 @@ def __init__(
         channels: Sequence[int],
         encoder: EncoderBase,
         inject_depth: int,
+        latent_factor: Optional[int] = None,
+        adapter: Optional[AdapterBase] = None,
         **kwargs,
     ):
         context_channels = [0] * len(channels)
@@ -68,12 +90,19 @@ def __init__(
         self.in_channels = in_channels
         self.encoder = encoder
         self.inject_depth = inject_depth
+        # Optional custom latent factor and adapter
+        self.latent_factor = default(latent_factor, self.encoder.downsample_factor)
+        self.adapter = adapter.requires_grad_(False) if exists(adapter) else None
 
     def forward(  # type: ignore
        self, x: Tensor, with_info: bool = False, **kwargs
     ) -> Union[Tensor, Tuple[Tensor, Any]]:
+        # Encode input to latent channels
         latent, info = self.encode(x, with_info=True)
         channels = [None] * self.inject_depth + [latent]
+        # Adapt input to diffusion if adapter provided
+        x = self.adapter.encode(x) if exists(self.adapter) else x
+        # Compute diffusion loss
         loss = super().forward(x, channels=channels, **kwargs)
         return (loss, info) if with_info else loss
 
@@ -85,10 +114,10 @@ def decode(
         self, latent: Tensor, generator: Optional[Generator] = None, **kwargs
     ) -> Tensor:
         b = latent.shape[0]
-        length = closest_power_2(latent.shape[2] * self.encoder.downsample_factor)
+        noise_length = closest_power_2(latent.shape[2] * self.latent_factor)
         # Compute noise by inferring shape from latent length
         noise = torch.randn(
-            (b, self.in_channels, length),
+            (b, self.in_channels, noise_length),
             device=latent.device,
             dtype=latent.dtype,
             generator=generator,
@@ -96,7 +125,9 @@ def decode(
         # Compute context from latent
         channels = [None] * self.inject_depth + [latent]  # type: ignore
         # Decode by sampling while conditioning on latent channels
-        return super().sample(noise, channels=channels, **kwargs)
+        out = super().sample(noise, channels=channels, **kwargs)
+        # Decode output with adapter if provided
+        return self.adapter.decode(out) if exists(self.adapter) else out
 
 
 class DiffusionUpsampler(DiffusionModel):
diff --git a/setup.py b/setup.py
index 3c98e62..a31a95a 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
     name="audio-diffusion-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.1.1",
+    version="0.1.2",
     license="MIT",
     description="Audio Diffusion - PyTorch",
     long_description_content_type="text/markdown",
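
For reviewers, a minimal sketch of how the new `AdapterBase` hook could be used once this diff is applied. The `ScaleAdapter` class and the concrete values below are illustrative only and not part of this change:

```python
import torch
from torch import Tensor

from audio_diffusion_pytorch.models import AdapterBase


class ScaleAdapter(AdapterBase):
    """Toy adapter (illustrative only): rescales audio before/after diffusion."""

    def __init__(self, scale: float = 2.0):
        super().__init__()
        self.scale = scale

    def encode(self, x: Tensor) -> Tensor:
        # Map audio into the space the diffusion network operates on
        return x * self.scale

    def decode(self, x: Tensor) -> Tensor:
        # Map sampled output back to the original audio space
        return x / self.scale


adapter = ScaleAdapter()
x = torch.randn(1, 2, 2**16)  # [batch, channels, length]
assert torch.allclose(adapter.decode(adapter.encode(x)), x)

# Passed as DiffusionAE(..., adapter=adapter, latent_factor=512), the adapter is
# frozen via requires_grad_(False); encode() is applied to the input before the
# diffusion loss in forward(), decode() to the sampled output in decode(), and
# latent_factor overrides encoder.downsample_factor when inferring noise length.
```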