Skip to content

Commit

Permalink
Inference API for 🐶Bark (#2685)
Browse files Browse the repository at this point in the history
* Add bark requirements

* Draft Bark implementation

* Download HF models

* Update synthesizer

* Add bark model

* Make style

* Update pylintrc

* Update model URLs

* Update Bark Config

* Fix here and ther

* Make style

* Make lint

* Update requirements

* Update requirements
  • Loading branch information
erogol authored Jun 28, 2023
1 parent 4cf8652 commit c844b65
Show file tree
Hide file tree
Showing 18 changed files with 1,757 additions and 101 deletions.
4 changes: 3 additions & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ disable=missing-docstring,
comprehension-escape,
duplicate-code,
not-callable,
import-outside-toplevel
import-outside-toplevel,
logging-fstring-interpolation,
logging-not-lazy

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
Expand Down
159 changes: 86 additions & 73 deletions TTS/.models.json

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion TTS/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def list_models():

def download_model_by_name(self, model_name: str):
model_path, config_path, model_item = self.manager.download_model(model_name)
if "fairseq" in model_name or (model_item is not None and isinstance(model_item["github_rls_url"], list)):
if "fairseq" in model_name or (model_item is not None and isinstance(model_item["model_url"], list)):
# return model directory if there are multiple files
# we assume that the model knows how to load itself
return None, None, None, None, model_path
Expand Down Expand Up @@ -584,6 +584,8 @@ def tts_to_file(
Speed factor to use for 🐸Coqui Studio models, between 0.0 and 2.0. Defaults to None.
file_path (str, optional):
Output file path. Defaults to "output.wav".
kwargs (dict, optional):
Additional arguments for the model.
"""
self._check_arguments(speaker=speaker, language=language, speaker_wav=speaker_wav, **kwargs)

Expand Down
2 changes: 1 addition & 1 deletion TTS/bin/synthesize.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def main():
vc_config_path = config_path

# tts model with multiple files to be loaded from the directory path
if model_item.get("author", None) == "fairseq" or isinstance(model_item["github_rls_url"], list):
if model_item.get("author", None) == "fairseq" or isinstance(model_item["model_url"], list):
model_dir = model_path
tts_path = None
tts_config_path = None
Expand Down
105 changes: 105 additions & 0 deletions TTS/tts/configs/bark_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import os
from dataclasses import dataclass
from typing import Dict

from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.layers.bark.model import GPTConfig
from TTS.tts.layers.bark.model_fine import FineGPTConfig
from TTS.tts.models.bark import BarkAudioConfig
from TTS.utils.generic_utils import get_user_data_dir


@dataclass
class BarkConfig(BaseTTSConfig):
"""Bark TTS configuration
Args:
model (str): model name that registers the model.
audio (BarkAudioConfig): audio configuration. Defaults to BarkAudioConfig().
num_chars (int): number of characters in the alphabet. Defaults to 0.
semantic_config (GPTConfig): semantic configuration. Defaults to GPTConfig().
fine_config (FineGPTConfig): fine configuration. Defaults to FineGPTConfig().
coarse_config (GPTConfig): coarse configuration. Defaults to GPTConfig().
CONTEXT_WINDOW_SIZE (int): GPT context window size. Defaults to 1024.
SEMANTIC_RATE_HZ (float): semantic tokens rate in Hz. Defaults to 49.9.
SEMANTIC_VOCAB_SIZE (int): semantic vocabulary size. Defaults to 10_000.
CODEBOOK_SIZE (int): encodec codebook size. Defaults to 1024.
N_COARSE_CODEBOOKS (int): number of coarse codebooks. Defaults to 2.
N_FINE_CODEBOOKS (int): number of fine codebooks. Defaults to 8.
COARSE_RATE_HZ (int): coarse tokens rate in Hz. Defaults to 75.
SAMPLE_RATE (int): sample rate. Defaults to 24_000.
USE_SMALLER_MODELS (bool): use smaller models. Defaults to False.
TEXT_ENCODING_OFFSET (int): text encoding offset. Defaults to 10_048.
SEMANTIC_PAD_TOKEN (int): semantic pad token. Defaults to 10_000.
TEXT_PAD_TOKEN ([type]): text pad token. Defaults to 10_048.
TEXT_EOS_TOKEN ([type]): text end of sentence token. Defaults to 10_049.
TEXT_SOS_TOKEN ([type]): text start of sentence token. Defaults to 10_050.
SEMANTIC_INFER_TOKEN (int): semantic infer token. Defaults to 10_051.
COARSE_SEMANTIC_PAD_TOKEN (int): coarse semantic pad token. Defaults to 12_048.
COARSE_INFER_TOKEN (int): coarse infer token. Defaults to 12_050.
REMOTE_BASE_URL ([type]): remote base url. Defaults to "https://huggingface.co/erogol/bark/tree".
REMOTE_MODEL_PATHS (Dict): remote model paths. Defaults to None.
LOCAL_MODEL_PATHS (Dict): local model paths. Defaults to None.
SMALL_REMOTE_MODEL_PATHS (Dict): small remote model paths. Defaults to None.
CACHE_DIR (str): local cache directory. Defaults to get_user_data_dir().
DEF_SPEAKER_DIR (str): default speaker directory to stoke speaker values for voice cloning. Defaults to get_user_data_dir().
"""

model: str = "bark"
audio: BarkAudioConfig = BarkAudioConfig()
num_chars: int = 0
semantic_config: GPTConfig = GPTConfig()
fine_config: FineGPTConfig = FineGPTConfig()
coarse_config: GPTConfig = GPTConfig()
CONTEXT_WINDOW_SIZE: int = 1024
SEMANTIC_RATE_HZ: float = 49.9
SEMANTIC_VOCAB_SIZE: int = 10_000
CODEBOOK_SIZE: int = 1024
N_COARSE_CODEBOOKS: int = 2
N_FINE_CODEBOOKS: int = 8
COARSE_RATE_HZ: int = 75
SAMPLE_RATE: int = 24_000
USE_SMALLER_MODELS: bool = False

TEXT_ENCODING_OFFSET: int = 10_048
SEMANTIC_PAD_TOKEN: int = 10_000
TEXT_PAD_TOKEN: int = 129_595
SEMANTIC_INFER_TOKEN: int = 129_599
COARSE_SEMANTIC_PAD_TOKEN: int = 12_048
COARSE_INFER_TOKEN: int = 12_050

REMOTE_BASE_URL = "https://huggingface.co/erogol/bark/tree/main/"
REMOTE_MODEL_PATHS: Dict = None
LOCAL_MODEL_PATHS: Dict = None
SMALL_REMOTE_MODEL_PATHS: Dict = None
CACHE_DIR: str = str(get_user_data_dir("tts/suno/bark_v0"))
DEF_SPEAKER_DIR: str = str(get_user_data_dir("tts/bark_v0/speakers"))

def __post_init__(self):
self.REMOTE_MODEL_PATHS = {
"text": {
"path": os.path.join(self.REMOTE_BASE_URL, "text_2.pt"),
"checksum": "54afa89d65e318d4f5f80e8e8799026a",
},
"coarse": {
"path": os.path.join(self.REMOTE_BASE_URL, "coarse_2.pt"),
"checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
},
"fine": {
"path": os.path.join(self.REMOTE_BASE_URL, "fine_2.pt"),
"checksum": "59d184ed44e3650774a2f0503a48a97b",
},
}
self.LOCAL_MODEL_PATHS = {
"text": os.path.join(self.CACHE_DIR, "text_2.pt"),
"coarse": os.path.join(self.CACHE_DIR, "coarse_2.pt"),
"fine": os.path.join(self.CACHE_DIR, "fine_2.pt"),
"hubert_tokenizer": os.path.join(self.CACHE_DIR, "tokenizer.pth"),
"hubert": os.path.join(self.CACHE_DIR, "hubert.pt"),
}
self.SMALL_REMOTE_MODEL_PATHS = {
"text": {"path": os.path.join(self.REMOTE_BASE_URL, "text.pt")},
"coarse": {"path": os.path.join(self.REMOTE_BASE_URL, "coarse.pt")},
"fine": {"path": os.path.join(self.REMOTE_BASE_URL, "fine.pt")},
}
self.sample_rate = self.SAMPLE_RATE # pylint: disable=attribute-defined-outside-init
Empty file added TTS/tts/layers/bark/__init__.py
Empty file.
Empty file.
35 changes: 35 additions & 0 deletions TTS/tts/layers/bark/hubert/hubert_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer

import os.path
import shutil
import urllib.request

import huggingface_hub


class HubertManager:
@staticmethod
def make_sure_hubert_installed(
download_url: str = "https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt", model_path: str = ""
):
if not os.path.isfile(model_path):
print("Downloading HuBERT base model")
urllib.request.urlretrieve(download_url, model_path)
print("Downloaded HuBERT")
return model_path
return None

@staticmethod
def make_sure_tokenizer_installed(
model: str = "quantifier_hubert_base_ls960_14.pth",
repo: str = "GitMylo/bark-voice-cloning",
model_path: str = "",
):
model_dir = os.path.dirname(model_path)
if not os.path.isfile(model_path):
print("Downloading HuBERT custom tokenizer")
huggingface_hub.hf_hub_download(repo, model, local_dir=model_dir, local_dir_use_symlinks=False)
shutil.move(os.path.join(model_dir, model), model_path)
print("Downloaded tokenizer")
return model_path
return None
101 changes: 101 additions & 0 deletions TTS/tts/layers/bark/hubert/kmeans_hubert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""
Modified HuBERT model without kmeans.
Original author: https://github.com/lucidrains/
Modified by: https://www.github.com/gitmylo/
License: MIT
"""

# Modified code from https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/hubert_kmeans.py

import logging
from pathlib import Path

import fairseq
import torch
from einops import pack, unpack
from torch import nn
from torchaudio.functional import resample

logging.root.setLevel(logging.ERROR)


def round_down_nearest_multiple(num, divisor):
return num // divisor * divisor


def curtail_to_multiple(t, mult, from_left=False):
data_len = t.shape[-1]
rounded_seq_len = round_down_nearest_multiple(data_len, mult)
seq_slice = slice(None, rounded_seq_len) if not from_left else slice(-rounded_seq_len, None)
return t[..., seq_slice]


def exists(val):
return val is not None


def default(val, d):
return val if exists(val) else d


class CustomHubert(nn.Module):
"""
checkpoint and kmeans can be downloaded at https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
or you can train your own
"""

def __init__(self, checkpoint_path, target_sample_hz=16000, seq_len_multiple_of=None, output_layer=9, device=None):
super().__init__()
self.target_sample_hz = target_sample_hz
self.seq_len_multiple_of = seq_len_multiple_of
self.output_layer = output_layer

if device is not None:
self.to(device)

model_path = Path(checkpoint_path)

assert model_path.exists(), f"path {checkpoint_path} does not exist"

checkpoint = torch.load(checkpoint_path)
load_model_input = {checkpoint_path: checkpoint}
model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input)

if device is not None:
model[0].to(device)

self.model = model[0]
self.model.eval()

@property
def groups(self):
return 1

@torch.no_grad()
def forward(self, wav_input, flatten=True, input_sample_hz=None):
device = wav_input.device

if exists(input_sample_hz):
wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz)

if exists(self.seq_len_multiple_of):
wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of)

embed = self.model(
wav_input,
features_only=True,
mask=False, # thanks to @maitycyrus for noticing that mask is defaulted to True in the fairseq code
output_layer=self.output_layer,
)

embed, packed_shape = pack([embed["x"]], "* d")

# codebook_indices = self.kmeans.predict(embed.cpu().detach().numpy())

codebook_indices = torch.from_numpy(embed.cpu().detach().numpy()).to(device) # .long()

if flatten:
return codebook_indices

(codebook_indices,) = unpack(codebook_indices, packed_shape, "*")
return codebook_indices
Loading

0 comments on commit c844b65

Please sign in to comment.