-
Notifications
You must be signed in to change notification settings - Fork 4.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add bark requirements * Draft Bark implementation * Download HF models * Update synthesizer * Add bark model * Make style * Update pylintrc * Update model URLs * Update Bark Config * Fix here and ther * Make style * Make lint * Update requirements * Update requirements
- Loading branch information
Showing
18 changed files
with
1,757 additions
and
101 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import os | ||
from dataclasses import dataclass | ||
from typing import Dict | ||
|
||
from TTS.tts.configs.shared_configs import BaseTTSConfig | ||
from TTS.tts.layers.bark.model import GPTConfig | ||
from TTS.tts.layers.bark.model_fine import FineGPTConfig | ||
from TTS.tts.models.bark import BarkAudioConfig | ||
from TTS.utils.generic_utils import get_user_data_dir | ||
|
||
|
||
@dataclass | ||
class BarkConfig(BaseTTSConfig): | ||
"""Bark TTS configuration | ||
Args: | ||
model (str): model name that registers the model. | ||
audio (BarkAudioConfig): audio configuration. Defaults to BarkAudioConfig(). | ||
num_chars (int): number of characters in the alphabet. Defaults to 0. | ||
semantic_config (GPTConfig): semantic configuration. Defaults to GPTConfig(). | ||
fine_config (FineGPTConfig): fine configuration. Defaults to FineGPTConfig(). | ||
coarse_config (GPTConfig): coarse configuration. Defaults to GPTConfig(). | ||
CONTEXT_WINDOW_SIZE (int): GPT context window size. Defaults to 1024. | ||
SEMANTIC_RATE_HZ (float): semantic tokens rate in Hz. Defaults to 49.9. | ||
SEMANTIC_VOCAB_SIZE (int): semantic vocabulary size. Defaults to 10_000. | ||
CODEBOOK_SIZE (int): encodec codebook size. Defaults to 1024. | ||
N_COARSE_CODEBOOKS (int): number of coarse codebooks. Defaults to 2. | ||
N_FINE_CODEBOOKS (int): number of fine codebooks. Defaults to 8. | ||
COARSE_RATE_HZ (int): coarse tokens rate in Hz. Defaults to 75. | ||
SAMPLE_RATE (int): sample rate. Defaults to 24_000. | ||
USE_SMALLER_MODELS (bool): use smaller models. Defaults to False. | ||
TEXT_ENCODING_OFFSET (int): text encoding offset. Defaults to 10_048. | ||
SEMANTIC_PAD_TOKEN (int): semantic pad token. Defaults to 10_000. | ||
TEXT_PAD_TOKEN ([type]): text pad token. Defaults to 10_048. | ||
TEXT_EOS_TOKEN ([type]): text end of sentence token. Defaults to 10_049. | ||
TEXT_SOS_TOKEN ([type]): text start of sentence token. Defaults to 10_050. | ||
SEMANTIC_INFER_TOKEN (int): semantic infer token. Defaults to 10_051. | ||
COARSE_SEMANTIC_PAD_TOKEN (int): coarse semantic pad token. Defaults to 12_048. | ||
COARSE_INFER_TOKEN (int): coarse infer token. Defaults to 12_050. | ||
REMOTE_BASE_URL ([type]): remote base url. Defaults to "https://huggingface.co/erogol/bark/tree". | ||
REMOTE_MODEL_PATHS (Dict): remote model paths. Defaults to None. | ||
LOCAL_MODEL_PATHS (Dict): local model paths. Defaults to None. | ||
SMALL_REMOTE_MODEL_PATHS (Dict): small remote model paths. Defaults to None. | ||
CACHE_DIR (str): local cache directory. Defaults to get_user_data_dir(). | ||
DEF_SPEAKER_DIR (str): default speaker directory to stoke speaker values for voice cloning. Defaults to get_user_data_dir(). | ||
""" | ||
|
||
model: str = "bark" | ||
audio: BarkAudioConfig = BarkAudioConfig() | ||
num_chars: int = 0 | ||
semantic_config: GPTConfig = GPTConfig() | ||
fine_config: FineGPTConfig = FineGPTConfig() | ||
coarse_config: GPTConfig = GPTConfig() | ||
CONTEXT_WINDOW_SIZE: int = 1024 | ||
SEMANTIC_RATE_HZ: float = 49.9 | ||
SEMANTIC_VOCAB_SIZE: int = 10_000 | ||
CODEBOOK_SIZE: int = 1024 | ||
N_COARSE_CODEBOOKS: int = 2 | ||
N_FINE_CODEBOOKS: int = 8 | ||
COARSE_RATE_HZ: int = 75 | ||
SAMPLE_RATE: int = 24_000 | ||
USE_SMALLER_MODELS: bool = False | ||
|
||
TEXT_ENCODING_OFFSET: int = 10_048 | ||
SEMANTIC_PAD_TOKEN: int = 10_000 | ||
TEXT_PAD_TOKEN: int = 129_595 | ||
SEMANTIC_INFER_TOKEN: int = 129_599 | ||
COARSE_SEMANTIC_PAD_TOKEN: int = 12_048 | ||
COARSE_INFER_TOKEN: int = 12_050 | ||
|
||
REMOTE_BASE_URL = "https://huggingface.co/erogol/bark/tree/main/" | ||
REMOTE_MODEL_PATHS: Dict = None | ||
LOCAL_MODEL_PATHS: Dict = None | ||
SMALL_REMOTE_MODEL_PATHS: Dict = None | ||
CACHE_DIR: str = str(get_user_data_dir("tts/suno/bark_v0")) | ||
DEF_SPEAKER_DIR: str = str(get_user_data_dir("tts/bark_v0/speakers")) | ||
|
||
def __post_init__(self): | ||
self.REMOTE_MODEL_PATHS = { | ||
"text": { | ||
"path": os.path.join(self.REMOTE_BASE_URL, "text_2.pt"), | ||
"checksum": "54afa89d65e318d4f5f80e8e8799026a", | ||
}, | ||
"coarse": { | ||
"path": os.path.join(self.REMOTE_BASE_URL, "coarse_2.pt"), | ||
"checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28", | ||
}, | ||
"fine": { | ||
"path": os.path.join(self.REMOTE_BASE_URL, "fine_2.pt"), | ||
"checksum": "59d184ed44e3650774a2f0503a48a97b", | ||
}, | ||
} | ||
self.LOCAL_MODEL_PATHS = { | ||
"text": os.path.join(self.CACHE_DIR, "text_2.pt"), | ||
"coarse": os.path.join(self.CACHE_DIR, "coarse_2.pt"), | ||
"fine": os.path.join(self.CACHE_DIR, "fine_2.pt"), | ||
"hubert_tokenizer": os.path.join(self.CACHE_DIR, "tokenizer.pth"), | ||
"hubert": os.path.join(self.CACHE_DIR, "hubert.pt"), | ||
} | ||
self.SMALL_REMOTE_MODEL_PATHS = { | ||
"text": {"path": os.path.join(self.REMOTE_BASE_URL, "text.pt")}, | ||
"coarse": {"path": os.path.join(self.REMOTE_BASE_URL, "coarse.pt")}, | ||
"fine": {"path": os.path.join(self.REMOTE_BASE_URL, "fine.pt")}, | ||
} | ||
self.sample_rate = self.SAMPLE_RATE # pylint: disable=attribute-defined-outside-init |
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer | ||
|
||
import os.path | ||
import shutil | ||
import urllib.request | ||
|
||
import huggingface_hub | ||
|
||
|
||
class HubertManager: | ||
@staticmethod | ||
def make_sure_hubert_installed( | ||
download_url: str = "https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt", model_path: str = "" | ||
): | ||
if not os.path.isfile(model_path): | ||
print("Downloading HuBERT base model") | ||
urllib.request.urlretrieve(download_url, model_path) | ||
print("Downloaded HuBERT") | ||
return model_path | ||
return None | ||
|
||
@staticmethod | ||
def make_sure_tokenizer_installed( | ||
model: str = "quantifier_hubert_base_ls960_14.pth", | ||
repo: str = "GitMylo/bark-voice-cloning", | ||
model_path: str = "", | ||
): | ||
model_dir = os.path.dirname(model_path) | ||
if not os.path.isfile(model_path): | ||
print("Downloading HuBERT custom tokenizer") | ||
huggingface_hub.hf_hub_download(repo, model, local_dir=model_dir, local_dir_use_symlinks=False) | ||
shutil.move(os.path.join(model_dir, model), model_path) | ||
print("Downloaded tokenizer") | ||
return model_path | ||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
""" | ||
Modified HuBERT model without kmeans. | ||
Original author: https://github.com/lucidrains/ | ||
Modified by: https://www.github.com/gitmylo/ | ||
License: MIT | ||
""" | ||
|
||
# Modified code from https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/hubert_kmeans.py | ||
|
||
import logging | ||
from pathlib import Path | ||
|
||
import fairseq | ||
import torch | ||
from einops import pack, unpack | ||
from torch import nn | ||
from torchaudio.functional import resample | ||
|
||
logging.root.setLevel(logging.ERROR) | ||
|
||
|
||
def round_down_nearest_multiple(num, divisor): | ||
return num // divisor * divisor | ||
|
||
|
||
def curtail_to_multiple(t, mult, from_left=False): | ||
data_len = t.shape[-1] | ||
rounded_seq_len = round_down_nearest_multiple(data_len, mult) | ||
seq_slice = slice(None, rounded_seq_len) if not from_left else slice(-rounded_seq_len, None) | ||
return t[..., seq_slice] | ||
|
||
|
||
def exists(val): | ||
return val is not None | ||
|
||
|
||
def default(val, d): | ||
return val if exists(val) else d | ||
|
||
|
||
class CustomHubert(nn.Module): | ||
""" | ||
checkpoint and kmeans can be downloaded at https://github.com/facebookresearch/fairseq/tree/main/examples/hubert | ||
or you can train your own | ||
""" | ||
|
||
def __init__(self, checkpoint_path, target_sample_hz=16000, seq_len_multiple_of=None, output_layer=9, device=None): | ||
super().__init__() | ||
self.target_sample_hz = target_sample_hz | ||
self.seq_len_multiple_of = seq_len_multiple_of | ||
self.output_layer = output_layer | ||
|
||
if device is not None: | ||
self.to(device) | ||
|
||
model_path = Path(checkpoint_path) | ||
|
||
assert model_path.exists(), f"path {checkpoint_path} does not exist" | ||
|
||
checkpoint = torch.load(checkpoint_path) | ||
load_model_input = {checkpoint_path: checkpoint} | ||
model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input) | ||
|
||
if device is not None: | ||
model[0].to(device) | ||
|
||
self.model = model[0] | ||
self.model.eval() | ||
|
||
@property | ||
def groups(self): | ||
return 1 | ||
|
||
@torch.no_grad() | ||
def forward(self, wav_input, flatten=True, input_sample_hz=None): | ||
device = wav_input.device | ||
|
||
if exists(input_sample_hz): | ||
wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz) | ||
|
||
if exists(self.seq_len_multiple_of): | ||
wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of) | ||
|
||
embed = self.model( | ||
wav_input, | ||
features_only=True, | ||
mask=False, # thanks to @maitycyrus for noticing that mask is defaulted to True in the fairseq code | ||
output_layer=self.output_layer, | ||
) | ||
|
||
embed, packed_shape = pack([embed["x"]], "* d") | ||
|
||
# codebook_indices = self.kmeans.predict(embed.cpu().detach().numpy()) | ||
|
||
codebook_indices = torch.from_numpy(embed.cpu().detach().numpy()).to(device) # .long() | ||
|
||
if flatten: | ||
return codebook_indices | ||
|
||
(codebook_indices,) = unpack(codebook_indices, packed_shape, "*") | ||
return codebook_indices |
Oops, something went wrong.