Skip to content

Commit

Permalink
feat: add pretrained dmae1d
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Nov 23, 2022
1 parent 4433c96 commit c0020d5
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 1 deletion.
19 changes: 19 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,26 @@ composer = SpanBySpanComposer(
y_long = composer(y, keep_start=True) # [1, 1, 98304]
```

## Pretrained Models

### Diffusion (Magnitude) AutoEncoder ([`dmae1d-ATC64-v1`](https://huggingface.co/archinetai/dmae1d-ATC64-v1/tree/main))
```py
from audio_diffusion_pytorch import AudioModel

autoencoder = AudioModel.from_pretrained("dmae1d-ATC64-v1")

x = torch.randn(1, 2, 2**18)
z = autoencoder.encode(x) # [1, 32, 256]
y = autoencoder.decode(z, num_steps=20) # [1, 2, 262144]
```

| Info | |
| ------------- | ------------- |
| Input type | Audio (stereo @ 48kHz) |
| Number of parameters | 234.2M |
| Compression Factor | 64x |
| Downsampling Factor | 1024x |
| Bottleneck Type | Tanh |


## Experiments
Expand Down
1 change: 1 addition & 0 deletions audio_diffusion_pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
AudioDiffusionUpphaser,
AudioDiffusionUpsampler,
AudioDiffusionVocoder,
AudioModel,
DiffusionAutoencoder1d,
DiffusionMAE1d,
DiffusionUpphaser1d,
Expand Down
15 changes: 15 additions & 0 deletions audio_diffusion_pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,3 +500,18 @@ def __init__(self, in_channels: int, **kwargs):

def sample(self, *args, **kwargs):
return super().sample(*args, **{**get_default_sampling_kwargs(), **kwargs})


""" Pretrained Models Helper """

REVISION = {"dmae1d-ATC64-v1": "07885065867977af43b460bb9c1422bdc90c29a0"}


class AudioModel:
@staticmethod
def from_pretrained(name: str) -> nn.Module:
from transformers import AutoModel

return AutoModel.from_pretrained(
f"archinetai/{name}", trust_remote_code=True, revision=REVISION[name]
)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="audio-diffusion-pytorch",
packages=find_packages(exclude=[]),
version="0.0.91",
version="0.0.92",
license="MIT",
description="Audio Diffusion - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit c0020d5

Please sign in to comment.