diff --git a/README.md b/README.md index 1341c43..c356395 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ Unconditional audio generation using diffusion models, in PyTorch. The goal of this repository is to explore different architectures and diffusion models to generate audio (speech and music) directly from/to the waveform. -Progress will be documented in the [experiments](#experiments) section. You can use the [`audio-diffusion-pytorch-trainer`](https://github.com/archinetai/audio-diffusion-pytorch-trainer) to run your own experiments – please share your findings in the [discussions](https://github.com/archinetai/audio-diffusion-pytorch/discussions) page! +Progress will be documented in the [experiments](#experiments) section. You can use the [`audio-diffusion-pytorch-trainer`](https://github.com/archinetai/audio-diffusion-pytorch-trainer) to run your own experiments – please share your findings in the [discussions](https://github.com/archinetai/audio-diffusion-pytorch/discussions) page! Pretrained models can be found at [`archisound`](https://github.com/archinetai/archisound). ## Install @@ -241,27 +241,6 @@ composer = SpanBySpanComposer( y_long = composer(y, keep_start=True) # [1, 1, 98304] ``` -## Pretrained Models - -### Diffusion (Magnitude) AutoEncoder ([`dmae1d-ATC64-v1`](https://huggingface.co/archinetai/dmae1d-ATC64-v1/tree/main)) -```py -from audio_diffusion_pytorch import AudioModel - -autoencoder = AudioModel.from_pretrained("dmae1d-ATC64-v1") - -x = torch.randn(1, 2, 2**18) -z = autoencoder.encode(x) # [1, 32, 256] -y = autoencoder.decode(z, num_steps=20) # [1, 2, 262144] -``` - -| Info | | -| ------------- | ------------- | -| Input type | Audio (stereo @ 48kHz) | -| Number of parameters | 234.2M | -| Compression Factor | 64x | -| Downsampling Factor | 1024x | -| Bottleneck Type | Tanh | - ## Experiments diff --git a/audio_diffusion_pytorch/__init__.py b/audio_diffusion_pytorch/__init__.py index 83d56f7..3443f49 100644 --- a/audio_diffusion_pytorch/__init__.py +++ b/audio_diffusion_pytorch/__init__.py @@ -27,7 +27,7 @@ AudioDiffusionUpphaser, AudioDiffusionUpsampler, AudioDiffusionVocoder, - AudioModel, + DiffusionAR1d, DiffusionAutoencoder1d, DiffusionMAE1d, DiffusionUpphaser1d, diff --git a/audio_diffusion_pytorch/model.py b/audio_diffusion_pytorch/model.py index 7f71754..9e6eb0c 100644 --- a/audio_diffusion_pytorch/model.py +++ b/audio_diffusion_pytorch/model.py @@ -1,13 +1,15 @@ from math import pi +from random import randint from typing import Any, Optional, Sequence, Tuple, Union import torch from audio_encoders_pytorch import Bottleneck, Encoder1d from einops import rearrange from torch import Tensor, nn +from tqdm import tqdm from .diffusion import LinearSchedule, UniformDistribution, VSampler, XDiffusion -from .modules import STFT, SinusoidalEmbedding, UNet1d, UNetConditional1d +from .modules import STFT, SinusoidalEmbedding, UNet1d, UNetConditional1d, rand_bool from .utils import ( closest_power_2, default, @@ -355,6 +357,105 @@ def forward(self, x: Tensor, **kwargs) -> Tensor: return self.diffusion(x, channels_list=[resampled], features=features, **kwargs) +class DiffusionAR1d(Model1d): + def __init__( + self, + in_channels: int, + chunk_length: int, + upsample: int = 0, + dropout: float = 0.05, + verbose: int = 0, + **kwargs, + ): + self.in_channels = in_channels + self.chunk_length = chunk_length + self.dropout = dropout + self.upsample = upsample + self.verbose = verbose + super().__init__( + in_channels=in_channels, + context_channels=[in_channels * 
(2 if upsample > 0 else 1)], + **kwargs, + ) + + def reupsample(self, x: Tensor) -> Tensor: + x = x.clone() + x = downsample(x, factor=self.upsample) + x = upsample(x, factor=self.upsample) + return x + + def forward(self, x: Tensor, **kwargs) -> Tensor: + b, _, t, device = *x.shape, x.device + cl, num_chunks = self.chunk_length, t // self.chunk_length + assert num_chunks >= 2, "Input tensor length must be >= chunk_length * 2" + + # Get prev and current target chunks + chunk_index = randint(0, num_chunks - 2) + chunk_pos = cl * (chunk_index + 1) + chunk_prev = x[:, :, cl * chunk_index : chunk_pos] + chunk_curr = x[:, :, chunk_pos : cl * (chunk_index + 2)] + + # Randomly dropout source chunks to allow for zero AR start + if self.dropout > 0: + batch_mask = rand_bool(shape=(b, 1, 1), proba=self.dropout, device=device) + chunk_zeros = torch.zeros_like(chunk_prev) + chunk_prev = torch.where(batch_mask, chunk_zeros, chunk_prev) + + # Condition on previous chunk and reupsampled current if required + if self.upsample > 0: + chunk_reupsampled = self.reupsample(chunk_curr) + channels_list = [torch.cat([chunk_prev, chunk_reupsampled], dim=1)] + else: + channels_list = [chunk_prev] + + # Diffuse current current chunk + return self.diffusion(chunk_curr, channels_list=channels_list, **kwargs) + + def sample(self, x: Tensor, start: Optional[Tensor] = None, **kwargs) -> Tensor: # type: ignore # noqa + noise = x + + if self.upsample > 0: + # In this case we assume that x is the downsampled audio instead of noise + upsampled = upsample(x, factor=self.upsample) + noise = torch.randn_like(upsampled) + + b, c, t, device = *noise.shape, noise.device + cl, num_chunks = self.chunk_length, t // self.chunk_length + assert c == self.in_channels + assert t % cl == 0, "noise must be divisible by chunk_length" + + # Initialize previous chunk + if exists(start): + chunk_prev = start[:, :, -cl:] + else: + chunk_prev = torch.zeros(b, c, cl).to(device) + + # Computed chunks + chunks = [] + + for i in tqdm(range(num_chunks), disable=(self.verbose == 0)): + # Chunk noise + chunk_start, chunk_end = cl * i, cl * (i + 1) + noise_curr = noise[:, :, chunk_start:chunk_end] + + # Condition on previous chunk and artifically upsampled current if required + if self.upsample > 0: + chunk_upsampled = upsampled[:, :, chunk_start:chunk_end] + channels_list = [torch.cat([chunk_prev, chunk_upsampled], dim=1)] + else: + channels_list = [chunk_prev] + default_kwargs = dict(channels_list=channels_list) + + # Sample current chunk + chunk_curr = super().sample(noise_curr, **{**default_kwargs, **kwargs}) + + # Save chunk and use current as prev + chunks += [chunk_curr] + chunk_prev = chunk_curr + + return rearrange(chunks, "l b c t -> b c (l t)") + + """ Audio Diffusion Classes (specific for 1d audio data) """ @@ -363,7 +464,7 @@ def forward(self, x: Tensor, **kwargs) -> Tensor: def get_default_model_kwargs(): return dict( channels=128, - patch_factor=16, + patch_size=16, multipliers=[1, 2, 4, 4, 4, 4, 4], factors=[4, 4, 4, 2, 2, 2], num_blocks=[2, 2, 2, 2, 2, 2], @@ -500,18 +601,3 @@ def __init__(self, in_channels: int, **kwargs): def sample(self, *args, **kwargs): return super().sample(*args, **{**get_default_sampling_kwargs(), **kwargs}) - - -""" Pretrained Models Helper """ - -REVISION = {"dmae1d-ATC64-v1": "07885065867977af43b460bb9c1422bdc90c29a0"} - - -class AudioModel: - @staticmethod - def from_pretrained(name: str) -> nn.Module: - from transformers import AutoModel - - return AutoModel.from_pretrained( - f"archinetai/{name}", 
trust_remote_code=True, revision=REVISION[name] - ) diff --git a/audio_diffusion_pytorch/modules.py b/audio_diffusion_pytorch/modules.py index 12e695b..017ef48 100644 --- a/audio_diffusion_pytorch/modules.py +++ b/audio_diffusion_pytorch/modules.py @@ -207,12 +207,12 @@ def forward(self, x: Tensor, mapping: Optional[Tensor] = None) -> Tensor: return h + self.to_out(x) -class PatchBlock(nn.Module): +class Patcher(nn.Module): def __init__( self, in_channels: int, out_channels: int, - patch_size: int = 2, + patch_size: int, context_mapping_features: Optional[int] = None, ): super().__init__() @@ -223,7 +223,7 @@ def __init__( self.block = ResnetBlock1d( in_channels=in_channels, out_channels=out_channels // patch_size, - num_groups=min(patch_size, in_channels), + num_groups=1, context_mapping_features=context_mapping_features, ) @@ -233,12 +233,12 @@ def forward(self, x: Tensor, mapping: Optional[Tensor] = None) -> Tensor: return x -class UnpatchBlock(nn.Module): +class Unpatcher(nn.Module): def __init__( self, in_channels: int, out_channels: int, - patch_size: int = 2, + patch_size: int, context_mapping_features: Optional[int] = None, ): super().__init__() @@ -249,7 +249,7 @@ def __init__( self.block = ResnetBlock1d( in_channels=in_channels // patch_size, out_channels=out_channels, - num_groups=min(patch_size, out_channels), + num_groups=1, context_mapping_features=context_mapping_features, ) @@ -259,56 +259,6 @@ def forward(self, x: Tensor, mapping: Optional[Tensor] = None) -> Tensor: return x -class Patcher(ConditionedSequential): - def __init__( - self, - in_channels: int, - out_channels: int, - blocks: int, - factor: int, - context_mapping_features: Optional[int] = None, - ): - channels_pre = [in_channels * (factor ** i) for i in range(blocks)] - channels_post = [in_channels * (factor ** (i + 1)) for i in range(blocks - 1)] - channels_post += [out_channels] - - super().__init__( - PatchBlock( - in_channels=channels_pre[i], - out_channels=channels_post[i], - patch_size=factor, - context_mapping_features=context_mapping_features, - ) - for i in range(blocks) - ) - - -class Unpatcher(ConditionedSequential): - def __init__( - self, - in_channels: int, - out_channels: int, - blocks: int, - factor: int, - context_mapping_features: Optional[int] = None, - ): - channels_pre = [in_channels] - channels_pre += [ - out_channels * (factor ** (i + 1)) for i in reversed(range(blocks - 1)) - ] - channels_post = [out_channels * (factor ** i) for i in reversed(range(blocks))] - - super().__init__( - UnpatchBlock( - in_channels=channels_pre[i], - out_channels=channels_post[i], - patch_size=factor, - context_mapping_features=context_mapping_features, - ) - for i in range(blocks) - ) - - """ Attention Components """ @@ -927,8 +877,7 @@ def __init__( factors: Sequence[int], num_blocks: Sequence[int], attentions: Sequence[int], - patch_blocks: int = 1, - patch_factor: int = 1, + patch_size: int = 1, resnet_groups: int = 8, use_context_time: bool = True, kernel_multiplier_downsample: int = 2, @@ -1013,11 +962,12 @@ def __init__( assert exists(in_channels) and exists(out_channels) self.stft = STFT(**stft_kwargs) + assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}" + self.to_in = Patcher( in_channels=in_channels + context_channels[0], out_channels=channels * multipliers[0], - blocks=patch_blocks, - factor=patch_factor, + patch_size=patch_size, context_mapping_features=context_mapping_features, ) @@ -1076,8 +1026,7 @@ def __init__( self.to_out = Unpatcher( in_channels=channels * 
multipliers[0], out_channels=out_channels, - blocks=patch_blocks, - factor=patch_factor, + patch_size=patch_size, context_mapping_features=context_mapping_features, ) diff --git a/setup.py b/setup.py index b671ac9..8abafd1 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="audio-diffusion-pytorch", packages=find_packages(exclude=[]), - version="0.0.92", + version="0.0.93", license="MIT", description="Audio Diffusion - PyTorch", long_description_content_type="text/markdown", @@ -12,6 +12,7 @@ url="https://github.com/archinetai/audio-diffusion-pytorch", keywords=["artificial intelligence", "deep learning", "audio generation"], install_requires=[ + "tqdm", "torch>=1.6", "data-science-types>=0.2", "einops>=0.4",
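
Usage sketch for the new `DiffusionAR1d` added above — a rough, unofficial example rather than documented API. Each chunk is diffused conditioned (via context channels) on the previously generated chunk, so long signals are sampled autoregressively chunk by chunk. The UNet and diffusion settings are taken from `get_default_model_kwargs()` in `model.py`, the sampling kwargs mirror `get_default_sampling_kwargs()`, and the chunk length, tensor shapes and step count below are illustrative assumptions.

```py
import torch

from audio_diffusion_pytorch.diffusion import LinearSchedule, VSampler
from audio_diffusion_pytorch.model import DiffusionAR1d, get_default_model_kwargs

# Illustrative configuration: chunk_length, shapes and num_steps are assumptions,
# not values from this diff. All remaining kwargs are forwarded to Model1d
# (UNet1d / diffusion settings), here taken from get_default_model_kwargs().
model = DiffusionAR1d(
    in_channels=1,
    chunk_length=2**15,  # length of each autoregressive chunk, in samples
    upsample=0,          # if > 0, also condition on a re-upsampled copy of the current chunk
    dropout=0.05,        # probability of zeroing the previous chunk (allows a silent AR start)
    **get_default_model_kwargs(),
)

# Training: the input must contain at least two chunks; a random (previous, current)
# chunk pair is drawn and the current chunk is diffused conditioned on the previous one.
x = torch.randn(2, 1, 2**17)  # [batch, in_channels, length], length >= 2 * chunk_length
loss = model(x)
loss.backward()

# Sampling: chunks are generated left to right, each conditioned on the previously
# sampled chunk; pass start=... to continue from existing audio instead of silence.
noise = torch.randn(1, 1, 2**17)  # length must be a multiple of chunk_length
y = model.sample(
    noise,
    num_steps=50,
    sigma_schedule=LinearSchedule(),
    sampler=VSampler(),
)  # [1, 1, 2**17]
```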