diff --git a/README.md b/README.md
index e8f186b..1341c43 100644
--- a/README.md
+++ b/README.md
@@ -241,7 +241,26 @@ composer = SpanBySpanComposer(
 y_long = composer(y, keep_start=True) # [1, 1, 98304]
 ```
 
+## Pretrained Models
 
+### Diffusion (Magnitude) AutoEncoder ([`dmae1d-ATC64-v1`](https://huggingface.co/archinetai/dmae1d-ATC64-v1/tree/main))
+```py
+from audio_diffusion_pytorch import AudioModel
+
+autoencoder = AudioModel.from_pretrained("dmae1d-ATC64-v1")
+
+x = torch.randn(1, 2, 2**18)
+z = autoencoder.encode(x) # [1, 32, 256]
+y = autoencoder.decode(z, num_steps=20) # [1, 2, 262144]
+```
+
+| Info | |
+| ------------- | ------------- |
+| Input type | Audio (stereo @ 48kHz) |
+| Number of parameters | 234.2M |
+| Compression Factor | 64x |
+| Downsampling Factor | 1024x |
+| Bottleneck Type | Tanh |
 
 ## Experiments
 
diff --git a/audio_diffusion_pytorch/__init__.py b/audio_diffusion_pytorch/__init__.py
index 0c3b14a..83d56f7 100644
--- a/audio_diffusion_pytorch/__init__.py
+++ b/audio_diffusion_pytorch/__init__.py
@@ -27,6 +27,7 @@
     AudioDiffusionUpphaser,
     AudioDiffusionUpsampler,
     AudioDiffusionVocoder,
+    AudioModel,
     DiffusionAutoencoder1d,
     DiffusionMAE1d,
     DiffusionUpphaser1d,
diff --git a/audio_diffusion_pytorch/model.py b/audio_diffusion_pytorch/model.py
index 53f831b..7f71754 100644
--- a/audio_diffusion_pytorch/model.py
+++ b/audio_diffusion_pytorch/model.py
@@ -500,3 +500,18 @@ def __init__(self, in_channels: int, **kwargs):
 
     def sample(self, *args, **kwargs):
         return super().sample(*args, **{**get_default_sampling_kwargs(), **kwargs})
+
+
+""" Pretrained Models Helper """
+
+REVISION = {"dmae1d-ATC64-v1": "07885065867977af43b460bb9c1422bdc90c29a0"}
+
+
+class AudioModel:
+    @staticmethod
+    def from_pretrained(name: str) -> nn.Module:
+        from transformers import AutoModel
+
+        return AutoModel.from_pretrained(
+            f"archinetai/{name}", trust_remote_code=True, revision=REVISION[name]
+        )
diff --git a/setup.py b/setup.py
index 9643a1c..b671ac9 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
     name="audio-diffusion-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.0.91",
+    version="0.0.92",
     license="MIT",
     description="Audio Diffusion - PyTorch",
     long_description_content_type="text/markdown",