-
Notifications
You must be signed in to change notification settings - Fork 921
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* initial encodec * works * nits * use fast group norm * fix for rnn layer * fix mlx version * use custom LSTM kernel * audio encodec * fix example, support batched inference * nits
- Loading branch information
Showing
10 changed files
with
1,267 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# EnCodec | ||
|
||
An example of Meta's EnCodec model in MLX.[^1] EnCodec is used to compress and | ||
generate audio. | ||
|
||
### Setup | ||
|
||
Install the requirements: | ||
|
||
``` | ||
pip install -r requirements.txt | ||
``` | ||
|
||
Optionally install FFmpeg and SciPy for loading and saving audio files, | ||
respectively. | ||
|
||
Install [FFmpeg](https://ffmpeg.org/): | ||
|
||
``` | ||
# on macOS using Homebrew (https://brew.sh/) | ||
brew install ffmpeg | ||
``` | ||
|
||
Install SciPy: | ||
|
||
``` | ||
pip install scipy | ||
``` | ||
|
||
### Example | ||
|
||
An example using the model: | ||
|
||
```python | ||
import mlx.core as mx | ||
from utils import load, load_audio, save_audio | ||
|
||
# Load the 48 kHz model and preprocessor. | ||
model, processor = load("mlx-community/encodec-48khz-float32") | ||
|
||
# Load an audio file | ||
audio = load_audio("path/to/audio", model.sampling_rate, model.channels) | ||
|
||
# Preprocess the audio (this can also be a list of arrays for batched | ||
# processing). | ||
feats, mask = processor(audio) | ||
|
||
# Encode at the given bandwidth. A lower bandwidth results in more | ||
# compression but lower reconstruction quality. | ||
@mx.compile | ||
def encode(feats, mask): | ||
return model.encode(feats, mask, bandwidth=3) | ||
|
||
# Decode to reconstruct the audio | ||
@mx.compile | ||
def decode(codes, scales, mask): | ||
return model.decode(codes, scales, mask) | ||
|
||
|
||
codes, scales = encode(feats, mask) | ||
reconstructed = decode(codes, scales, mask) | ||
|
||
# Trim any padding: | ||
reconstructed = reconstructed[0, : len(audio)] | ||
|
||
# Save the audio as a wave file | ||
save_audio("reconstructed.wav", reconstructed, model.sampling_rate) | ||
``` | ||
|
||
The 24 kHz, 32 kHz, and 48 kHz MLX-formatted models are available in the | ||
[Hugging Face MLX Community](https://huggingface.co/collections/mlx-community/encodec-66e62334038300b07a43b164) | ||
in several data types. | ||
|
||
### Optional | ||
|
||
To convert models, use the `convert.py` script. To see the options, run: | ||
|
||
```bash | ||
python convert.py -h | ||
``` | ||
|
||
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2210.13438) and | ||
[code](https://github.com/facebookresearch/encodec) for more details. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# Copyright © 2024 Apple Inc. | ||
|
||
import time | ||
|
||
import mlx.core as mx | ||
from utils import load | ||
|
||
# Benchmark: time a compiled encode + decode round trip of the MLX EnCodec
# model on random input.
NUM_WARMUP = 5
NUM_ITERS = 10

model, processor = load("mlx-community/encodec-48khz-float32")

# Random input of shape (288000, 2); the values are irrelevant for timing.
audio = mx.random.uniform(shape=(288000, 2))
feats, mask = processor(audio)
# Evaluate the weights and inputs up front so that work is not attributed to
# the timed loop below.
mx.eval(model, feats, mask)


@mx.compile
def fun():
    """Encode the preprocessed features at bandwidth 3 and decode them back."""
    codes, scales = model.encode(feats, mask, bandwidth=3)
    reconstructed = model.decode(codes, scales, mask)
    return reconstructed


# Warm-up iterations (not timed).
for _ in range(NUM_WARMUP):
    mx.eval(fun())

# perf_counter is monotonic and high resolution, unlike time.time(), which can
# jump if the system clock is adjusted mid-benchmark.
tic = time.perf_counter()
for _ in range(NUM_ITERS):
    mx.eval(fun())
toc = time.perf_counter()
ms = 1000 * (toc - tic) / NUM_ITERS
print(f"Time per it: {ms:.3f}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# Copyright © 2024 Apple Inc. | ||
|
||
import time | ||
|
||
import numpy as np | ||
import torch | ||
from transformers import AutoProcessor, EncodecModel | ||
|
||
# Benchmark: time an encode + decode round trip of the Hugging Face PyTorch
# EnCodec model on the "mps" device with random input.
NUM_WARMUP = 5
NUM_ITERS = 10

processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
# Random float32 input of shape (2, 288000); the values are irrelevant for
# timing.
audio = np.random.uniform(size=(2, 288000)).astype(np.float32)

pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz").to("mps")
pt_inputs = processor(
    raw_audio=audio, sampling_rate=processor.sampling_rate, return_tensors="pt"
).to("mps")


def fun():
    """Run one encode + decode pass and wait for the device to finish."""
    pt_encoded = pt_model.encode(pt_inputs["input_values"], pt_inputs["padding_mask"])
    # The decoded audio itself is not needed; only the time to produce it.
    pt_model.decode(
        pt_encoded.audio_codes, pt_encoded.audio_scales, pt_inputs["padding_mask"]
    )
    # Block until queued MPS work completes so the timing covers the compute.
    torch.mps.synchronize()


# Warm-up iterations (not timed).
for _ in range(NUM_WARMUP):
    fun()

# perf_counter is monotonic and high resolution, unlike time.time(), which can
# jump if the system clock is adjusted mid-benchmark.
tic = time.perf_counter()
for _ in range(NUM_ITERS):
    fun()
toc = time.perf_counter()
ms = 1000 * (toc - tic) / NUM_ITERS
print(f"Time per it: {ms:.3f}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,213 @@ | ||
# Copyright © 2024 Apple Inc. | ||
|
||
import argparse | ||
import json | ||
from pathlib import Path | ||
from textwrap import dedent | ||
from types import SimpleNamespace | ||
from typing import Any, Dict, Union | ||
|
||
import mlx.core as mx | ||
import mlx.nn as nn | ||
from huggingface_hub import snapshot_download | ||
from mlx.utils import tree_flatten | ||
|
||
import encodec | ||
|
||
|
||
def fetch_from_hub(hf_repo: str) -> Path:
    """Download the JSON and safetensors files of ``hf_repo``.

    Args:
        hf_repo (str): Hugging Face repo id to download from.

    Returns:
        Path: Local path returned by ``snapshot_download``.
    """
    local_dir = snapshot_download(
        repo_id=hf_repo,
        allow_patterns=["*.json", "*.safetensors"],
    )
    return Path(local_dir)
|
||
|
||
def upload_to_hub(path: str, upload_repo: str, hf_path: str):
    """
    Write a model card into ``path`` and upload the folder to the Hugging Face hub.

    Args:
        path (str): Local directory containing the converted model.
        upload_repo (str): Hub repo id to create (if missing) and upload to.
        hf_path (str): Hub repo id of the original model, linked from the card.
    """
    import os

    from huggingface_hub import HfApi, ModelCard, logging

    # dedent strips the common leading indentation, so the card text starts
    # at column zero regardless of how this literal is indented.
    card_text = dedent(
        f"""
        ---
        language: en
        license: other
        library: mlx
        tags:
        - mlx
        ---
        The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was
        converted to MLX format from
        [{hf_path}](https://huggingface.co/{hf_path}).
        This model is intended to be used with the [EnCodec MLX
        example](https://github.com/ml-explore/mlx-examples/tree/main/encodec).
        """
    )
    readme_path = os.path.join(path, "README.md")
    ModelCard(card_text).save(readme_path)

    logging.set_verbosity_info()

    hub = HfApi()
    hub.create_repo(repo_id=upload_repo, exist_ok=True)
    hub.upload_folder(
        folder_path=path,
        repo_id=upload_repo,
        repo_type="model",
        multi_commits=True,
        multi_commits_verbose=True,
    )
    print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")
|
||
|
||
def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
    """Write ``weights`` to ``model.safetensors`` plus a JSON index file.

    Args:
        save_path (Union[str, Path]): Output directory; created if missing.
        weights (Dict[str, Any]): Mapping from parameter name to array. Each
            value must expose ``nbytes``.
    """
    out_dir = Path(save_path) if isinstance(save_path, str) else save_path
    out_dir.mkdir(parents=True, exist_ok=True)

    shard_name = "model.safetensors"
    total_size = sum(arr.nbytes for arr in weights.values())
    mx.save_safetensors(
        str(out_dir / shard_name), weights, metadata={"format": "mlx"}
    )

    # Every weight lives in the single shard; keys are sorted for stable output.
    index_data = {
        "metadata": {"total_size": total_size},
        "weight_map": {name: shard_name for name in sorted(weights)},
    }

    with open(out_dir / "model.safetensors.index.json", "w") as f:
        json.dump(index_data, f, indent=4)
|
||
|
||
def save_config(
    config: dict,
    config_path: Union[str, Path],
) -> None:
    """Save the model configuration to the ``config_path``.

    The ``_name_or_path`` key is dropped and the remaining keys are sorted
    before saving for better readability. The ``config`` dict passed in is
    left unmodified.

    Args:
        config (dict): The model configuration.
        config_path (Union[str, Path]): Model configuration file path.
    """
    # Filter into a new dict instead of popping, so the caller's dict (for
    # example the ``vars()`` of a live namespace) is not mutated as a side
    # effect of saving.
    cleaned = {k: v for k, v in config.items() if k != "_name_or_path"}

    # Sort the config for better readability.
    cleaned = dict(sorted(cleaned.items()))

    with open(config_path, "w") as fid:
        json.dump(cleaned, fid, indent=4)
|
||
|
||
def convert(
    upload: bool,
    model: str,
    dtype: str = None,
):
    """Convert a Hugging Face EnCodec checkpoint to MLX weights and save them.

    Downloads ``facebook/encodec_{model}``, remaps the parameter names and
    layouts, checks that they load into ``encodec.EncodecModel``, and writes
    the weights and config to the local ``mlx_models`` directory.

    Args:
        upload (bool): If true, upload the saved folder with ``upload_to_hub``.
        model (str): Variant suffix used to build the repo ids (the CLI
            offers "24khz", "32khz" and "48khz").
        dtype (str): Name of an ``mx`` dtype attribute to cast the weights
            to. ``None`` skips the cast.
    """
    hf_repo = f"facebook/encodec_{model}"
    # NOTE(review): with dtype=None this repo id ends in "-None".
    mlx_repo = f"mlx-community/encodec-{model}-{dtype}"
    path = fetch_from_hub(hf_repo)
    save_path = Path("mlx_models")

    weights = mx.load(str(Path(path) / "model.safetensors"))

    with open(path / "config.json", "r") as fid:
        config = SimpleNamespace(**json.load(fid))

    # From here on ``model`` is the MLX module, not the variant string.
    model = encodec.EncodecModel(config)

    new_weights = {}
    for k, v in weights.items():
        basename, pname = k.rsplit(".", 1)
        if pname == "weight_v":
            # Combine the weight_g / weight_v pair into a single ``weight``:
            # g * v / ||v||, with the norm taken over axes 1 and 2.
            g = weights[basename + ".weight_g"]
            v = g * (v / mx.linalg.norm(v, axis=(1, 2), keepdims=True))
            k = basename + ".weight"
        elif pname in ["weight_g", "embed_avg", "cluster_size", "inited"]:
            # weight_g is consumed in the weight_v branch; the other entries
            # are not carried over to the MLX weights.
            continue
        elif "lstm" in basename:
            # pname is split into three "_"-separated pieces, presumably
            # like "weight_ih_l0" -- TODO confirm against the checkpoint.
            w_or_b, ih_or_hh, ln = pname.split("_")
            if w_or_b == "weight":
                new_pname = "Wx" if ih_or_hh == "ih" else "Wh"
            elif w_or_b == "bias" and ih_or_hh == "ih":
                # The ih bias is added into the hh bias in the branch below.
                continue
            else:
                # hh bias: sum it with the matching ih bias into one ``bias``.
                v = v + weights[k.replace("_hh_", "_ih_")]
                new_pname = "bias"
            # ln[1:] drops the first character of the third piece.
            k = basename + "." + ln[1:] + "." + new_pname
        # Checked against the (possibly renamed) key, so weights produced by
        # the weight_v branch above are covered too.
        if "conv.weight" in k:
            # Possibly a transposed conv which has a different order
            if "decoder" in k:
                # The third dot-separated piece of the key is used as the
                # index into the decoder's layers.
                ln = int(k.split(".")[2])
                if "conv" in model.decoder.layers[ln] and isinstance(
                    model.decoder.layers[ln].conv, nn.ConvTranspose1d
                ):
                    v = mx.moveaxis(v, 0, 2)
                else:
                    v = mx.moveaxis(v, 1, 2)
            else:
                v = mx.moveaxis(v, 1, 2)

        new_weights[k] = v
    weights = new_weights

    # Load the remapped weights into the model before anything is saved.
    model.load_weights(list(weights.items()))

    if dtype is not None:
        t = getattr(mx, dtype)
        weights = {k: v.astype(t) for k, v in weights.items()}

    # NOTE(review): save_path is always a Path here, so this check never fires.
    if isinstance(save_path, str):
        save_path = Path(save_path)

    save_weights(save_path, weights)

    save_config(vars(config), config_path=save_path / "config.json")

    if upload:
        upload_to_hub(save_path, mlx_repo, hf_repo)
|
||
|
||
# Command-line entry point: parse the options and run the conversion.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert EnCodec weights to MLX.")
    parser.add_argument(
        "--model",
        type=str,
        default="48khz",
        # Previously an empty string, which left the option undescribed in -h.
        help="EnCodec model variant to convert.",
        choices=["24khz", "32khz", "48khz"],
    )
    parser.add_argument(
        "--upload",
        action="store_true",
        help="Upload the weights to Hugging Face.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        help="Data type to convert the model to.",
        default="float32",
        choices=["float32", "bfloat16", "float16"],
    )
    args = parser.parse_args()
    convert(upload=args.upload, model=args.model, dtype=args.dtype)
Oops, something went wrong.