Skip to content

Commit

Permalink
Encodec (#991)
Browse files Browse the repository at this point in the history
* initial encodec

* works

* nits

* use fast group norm

* fix for rnn layer

* fix mlx version

* use custom LSTM kernel

* audio encodec

* fix example, support batched inference

* nits
  • Loading branch information
awni authored Sep 23, 2024
1 parent 796d5e4 commit 9bb2dd6
Show file tree
Hide file tree
Showing 10 changed files with 1,267 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Some more useful examples are listed below.
### Audio Models

- Speech recognition with [OpenAI's Whisper](whisper).
- Audio compression and generation with [Meta's EnCodec](encodec).

### Multimodal models

Expand Down
83 changes: 83 additions & 0 deletions encodec/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# EnCodec

An example of Meta's EnCodec model in MLX.[^1] EnCodec is used to compress and
generate audio.

### Setup

Install the requirements:

```
pip install -r requirements.txt
```

Optionally install FFmpeg and SciPy for loading and saving audio files,
respectively.

Install [FFmpeg](https://ffmpeg.org/):

```
# on macOS using Homebrew (https://brew.sh/)
brew install ffmpeg
```

Install SciPy:

```
pip install scipy
```

### Example

An example using the model:

```python
import mlx.core as mx
from utils import load, load_audio, save_audio

# Load the 48 kHz model and preprocessor.
model, processor = load("mlx-community/encodec-48khz-float32")

# Load an audio file
audio = load_audio("path/to/audio", model.sampling_rate, model.channels)

# Preprocess the audio (this can also be a list of arrays for batched
# processing).
feats, mask = processor(audio)

# Encode at the given bandwidth. A lower bandwidth results in more
# compression but lower reconstruction quality.
@mx.compile
def encode(feats, mask):
    return model.encode(feats, mask, bandwidth=3)

# Decode to reconstruct the audio
@mx.compile
def decode(codes, scales, mask):
    return model.decode(codes, scales, mask)


codes, scales = encode(feats, mask)
reconstructed = decode(codes, scales, mask)

# Trim any padding:
reconstructed = reconstructed[0, : len(audio)]

# Save the audio as a wave file
save_audio("reconstructed.wav", reconstructed, model.sampling_rate)
```

The 24 kHz, 32 kHz, and 48 kHz MLX-formatted models are available in the
[Hugging Face MLX Community](https://huggingface.co/collections/mlx-community/encodec-66e62334038300b07a43b164)
in several data types.

### Optional

To convert models, use the `convert.py` script. To see the options, run:

```bash
python convert.py -h
```

[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2210.13438) and
[code](https://github.com/facebookresearch/encodec) for more details.
30 changes: 30 additions & 0 deletions encodec/benchmarks/bench_mx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright © 2024 Apple Inc.

import time

import mlx.core as mx
from utils import load

# Load the pretrained model and its preprocessor (the 48 kHz float32
# checkpoint, going by the repo name).
model, processor = load("mlx-community/encodec-48khz-float32")

# Random stand-in audio; presumably laid out as (samples, channels), i.e.
# 288000 two-channel samples -- TODO confirm against the processor.
audio = mx.random.uniform(shape=(288000, 2))
feats, mask = processor(audio)
# Evaluate the model and the inputs up front. NOTE(review): presumably this
# keeps weight loading and preprocessing out of the timed loop below, since
# MLX evaluates lazily -- confirm.
mx.eval(model, feats, mask)


@mx.compile
def fun():
    """Encode the prepared features at bandwidth 3 and decode them again.

    Reads the module-level ``model``, ``feats`` and ``mask`` and returns the
    output of ``model.decode``.
    """
    encoded = model.encode(feats, mask, bandwidth=3)
    codes, scales = encoded
    return model.decode(codes, scales, mask)


# Untimed runs first, so one-off costs of the first calls are not measured.
for _ in range(5):
    mx.eval(fun())

# Timed runs. perf_counter() is monotonic and high resolution, so the result
# is not affected by system clock adjustments the way time.time() can be.
num_iters = 10
tic = time.perf_counter()
for _ in range(num_iters):
    mx.eval(fun())
toc = time.perf_counter()
# Average time per iteration, in milliseconds.
ms = 1000 * (toc - tic) / num_iters
print(f"Time per it: {ms:.3f}")
34 changes: 34 additions & 0 deletions encodec/benchmarks/bench_pt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright © 2024 Apple Inc.

import time

import numpy as np
import torch
from transformers import AutoProcessor, EncodecModel

# Preprocessor for the reference Hugging Face checkpoint.
processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
# Random stand-in audio as float32; presumably laid out as
# (channels, samples) -- TODO confirm against the processor.
audio = np.random.uniform(size=(2, 288000)).astype(np.float32)

# Load the PyTorch model and move both it and the preprocessed inputs to the
# "mps" device.
pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz").to("mps")
pt_inputs = processor(
raw_audio=audio, sampling_rate=processor.sampling_rate, return_tensors="pt"
).to("mps")


def fun():
    """Run one encode/decode round trip and wait for the device to finish.

    Reads the module-level ``pt_model`` and ``pt_inputs``. Gradient tracking
    is disabled because this is an inference-only benchmark; without it the
    measurement also includes autograd bookkeeping. The decoded audio is
    discarded, only the elapsed time matters.
    """
    with torch.no_grad():
        pt_encoded = pt_model.encode(
            pt_inputs["input_values"], pt_inputs["padding_mask"]
        )
        pt_model.decode(
            pt_encoded.audio_codes,
            pt_encoded.audio_scales,
            pt_inputs["padding_mask"],
        )
    # Block until all queued MPS work has completed so the caller's timer
    # covers the actual computation.
    torch.mps.synchronize()


# Untimed runs first, so one-off costs of the first calls are not measured.
for _ in range(5):
    fun()

# Timed runs. perf_counter() is monotonic and high resolution, so the result
# is not affected by system clock adjustments the way time.time() can be.
num_iters = 10
tic = time.perf_counter()
for _ in range(num_iters):
    fun()
toc = time.perf_counter()
# Average time per iteration, in milliseconds.
ms = 1000 * (toc - tic) / num_iters
print(f"Time per it: {ms:.3f}")
213 changes: 213 additions & 0 deletions encodec/convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
# Copyright © 2024 Apple Inc.

import argparse
import json
from pathlib import Path
from textwrap import dedent
from types import SimpleNamespace
from typing import Any, Dict, Union

import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten

import encodec


def fetch_from_hub(hf_repo: str) -> Path:
    """Download the JSON and safetensors files of ``hf_repo``.

    Args:
        hf_repo (str): Hugging Face repo id to download from.

    Returns:
        Path: The value returned by ``snapshot_download``, wrapped in a
        ``Path``.
    """
    wanted_files = ["*.json", "*.safetensors"]
    local_dir = snapshot_download(repo_id=hf_repo, allow_patterns=wanted_files)
    return Path(local_dir)


def upload_to_hub(path: str, upload_repo: str, hf_path: str):
    """
    Write a model card into ``path`` and upload the folder to the Hugging
    Face hub.

    Args:
        path (str): Local directory containing the model files.
        upload_repo (str): Name of the HF repo to upload to.
        hf_path (str): Repo of the original Hugging Face model, linked from
            the generated model card.
    """
    import os

    from huggingface_hub import HfApi, ModelCard, logging

    readme_path = os.path.join(path, "README.md")

    # dedent() strips the common indentation, so the card text starts at
    # column 0 regardless of how this literal is indented here.
    card_text = dedent(
        f"""
        ---
        language: en
        license: other
        library: mlx
        tags:
        - mlx
        ---
        The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was
        converted to MLX format from
        [{hf_path}](https://huggingface.co/{hf_path}).
        This model is intended to be used with the [EnCodec MLX
        example](https://github.com/ml-explore/mlx-examples/tree/main/encodec).
        """
    )
    ModelCard(card_text).save(readme_path)

    logging.set_verbosity_info()

    hub = HfApi()
    # exist_ok=True so re-uploading to an existing repo does not fail here.
    hub.create_repo(repo_id=upload_repo, exist_ok=True)
    upload_options = {
        "folder_path": path,
        "repo_id": upload_repo,
        "repo_type": "model",
        "multi_commits": True,
        "multi_commits_verbose": True,
    }
    hub.upload_folder(**upload_options)

    repo_url = f"https://huggingface.co/{upload_repo}"
    print(f"Upload successful, go to {repo_url} for details.")


def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
    """Write ``weights`` and a matching index file into ``save_path``.

    The directory is created if needed. All weights go into a single
    ``model.safetensors`` file; ``model.safetensors.index.json`` records the
    total byte size and maps every weight name (sorted) to that file.

    Args:
        save_path (Union[str, Path]): Output directory.
        weights (Dict[str, Any]): Mapping from weight name to array; each
            value must expose ``nbytes``.
    """
    out_dir = Path(save_path) if isinstance(save_path, str) else save_path
    out_dir.mkdir(parents=True, exist_ok=True)

    shard_name = "model.safetensors"
    total_size = sum(array.nbytes for array in weights.values())

    mx.save_safetensors(
        str(out_dir / shard_name), weights, metadata={"format": "mlx"}
    )

    index = {
        "metadata": {"total_size": total_size},
        # Every weight lives in the same single file.
        "weight_map": {name: shard_name for name in sorted(weights)},
    }
    with open(out_dir / "model.safetensors.index.json", "w") as f:
        json.dump(index, f, indent=4)


def save_config(
    config: dict,
    config_path: Union[str, Path],
) -> None:
    """Save the model configuration to ``config_path`` as JSON.

    The ``_name_or_path`` key is left out and the remaining keys are sorted
    before saving for better readability. The ``config`` dict passed in is
    not modified.

    Args:
        config (dict): The model configuration.
        config_path (Union[str, Path]): Model configuration file path.
    """
    # Build a cleaned, sorted copy instead of popping from the caller's dict,
    # so callers that keep using their config afterwards are not affected.
    cleaned = {
        key: config[key] for key in sorted(config) if key != "_name_or_path"
    }

    with open(config_path, "w") as fid:
        json.dump(cleaned, fid, indent=4)


def convert(
    upload: bool,
    model: str,
    dtype: str = None,
):
    """Convert a Hugging Face EnCodec checkpoint to MLX format.

    Downloads ``facebook/encodec_{model}``, renames and reshapes its weights
    so they load into ``encodec.EncodecModel``, then writes the weights and
    the config into the ``mlx_models`` directory.

    Args:
        upload (bool): If true, upload the result to
            ``mlx-community/encodec-{model}-{dtype}``.
        model (str): Variant suffix of the source repo (e.g. ``"48khz"``).
        dtype (str): Name of the ``mx`` dtype to cast the weights to;
            ``None`` skips the cast.
    """
    hf_repo = f"facebook/encodec_{model}"
    mlx_repo = f"mlx-community/encodec-{model}-{dtype}"
    path = fetch_from_hub(hf_repo)
    save_path = Path("mlx_models")

    weights = mx.load(str(Path(path) / "model.safetensors"))

    with open(path / "config.json", "r") as fid:
        config = SimpleNamespace(**json.load(fid))

    net = encodec.EncodecModel(config)

    converted = {}
    for name, value in weights.items():
        prefix, leaf = name.rsplit(".", 1)

        # weight_g is consumed together with its weight_v partner below; the
        # other entries are not carried over to the converted weights.
        if leaf in ("weight_g", "embed_avg", "cluster_size", "inited"):
            continue

        if leaf == "weight_v":
            # Merge the (g, v) pair into one tensor: weight = g * v / ||v||.
            gain = weights[prefix + ".weight_g"]
            direction = value / mx.linalg.norm(value, axis=(1, 2), keepdims=True)
            value = gain * direction
            name = prefix + ".weight"
        elif "lstm" in prefix:
            # Leaf names look like "<weight|bias>_<ih|hh>_l<N>".
            kind, gate_input, layer_tag = leaf.split("_")
            if kind == "weight":
                new_leaf = "Wx" if gate_input == "ih" else "Wh"
            elif kind == "bias" and gate_input == "ih":
                # Added into the matching "hh" bias instead (see below).
                continue
            else:
                # Sum the "hh" and "ih" biases into a single bias.
                value = value + weights[name.replace("_hh_", "_ih_")]
                new_leaf = "bias"
            name = prefix + "." + layer_tag[1:] + "." + new_leaf

        if "conv.weight" in name:
            # Axis 1 moves to the end by default; decoder layers whose conv
            # is a transposed conv have a different order and move axis 0.
            source_axis = 1
            if "decoder" in name:
                layer = net.decoder.layers[int(name.split(".")[2])]
                if "conv" in layer and isinstance(layer.conv, nn.ConvTranspose1d):
                    source_axis = 0
            value = mx.moveaxis(value, source_axis, 2)

        converted[name] = value
    weights = converted

    # Load into the model before casting or saving.
    net.load_weights(list(weights.items()))

    if dtype is not None:
        target_type = getattr(mx, dtype)
        weights = {key: array.astype(target_type) for key, array in weights.items()}

    save_weights(save_path, weights)

    save_config(vars(config), config_path=save_path / "config.json")

    if upload:
        upload_to_hub(save_path, mlx_repo, hf_repo)


if __name__ == "__main__":
    # Command-line entry point: parse the options and run the conversion.
    parser = argparse.ArgumentParser(description="Convert EnCodec weights to MLX.")
    parser.add_argument(
        "--model",
        type=str,
        default="48khz",
        # Previously an empty string, which left the option undocumented
        # in the -h output.
        help="EnCodec variant to convert (selects facebook/encodec_<model>).",
        choices=["24khz", "32khz", "48khz"],
    )
    parser.add_argument(
        "--upload",
        action="store_true",
        help="Upload the weights to Hugging Face.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        help="Data type to convert the model to.",
        default="float32",
        choices=["float32", "bfloat16", "float16"],
    )
    args = parser.parse_args()
    convert(upload=args.upload, model=args.model, dtype=args.dtype)
Loading

0 comments on commit 9bb2dd6

Please sign in to comment.