Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add studio speakers to open source XTTS! #3405

Merged
merged 18 commits into from
Dec 12, 2023
Merged
5 changes: 3 additions & 2 deletions TTS/.models.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
"multilingual": {
"multi-dataset": {
"xtts_v2": {
"description": "XTTS-v2.0.2 by Coqui with 16 languages.",
"description": "XTTS-v2.0.3 by Coqui with 17 languages.",
"hf_url": [
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/config.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/hash.md5"
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/hash.md5",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/speakers_xtts.pth"
],
"model_hash": "10f92b55c512af7a8d39d650547a15a7",
"default_vocoder": null,
Expand Down
18 changes: 18 additions & 0 deletions TTS/tts/layers/xtts/xtts_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import torch

class SpeakerManager():
    """Look-up table of pre-computed XTTS studio speakers.

    Args:
        speaker_file_path (str): Path to a ``torch``-saved dict mapping each
            speaker name to its pre-computed conditioning values. Defaults to
            ``None`` (in which case loading raises, as there is no file).
    """

    def __init__(self, speaker_file_path=None):
        # map_location="cpu" so a speaker file serialized on a CUDA machine
        # still loads on CPU-only hosts; callers can move tensors to the
        # target device afterwards.
        self.speakers = torch.load(speaker_file_path, map_location=torch.device("cpu"))

    @property
    def name_to_id(self):
        """Return the available speaker names (a dict keys view)."""
        return self.speakers.keys()


class LanguageManager():
    """Expose the languages declared in an XTTS model config.

    Args:
        config: Indexable config object; ``config["languages"]`` holds the
            supported languages.
    """

    def __init__(self, config):
        # Keep a direct reference to the language collection from the config.
        self.langs = config["languages"]

    @property
    def name_to_id(self):
        """Return the language collection exactly as stored in the config."""
        return self.langs
27 changes: 17 additions & 10 deletions TTS/tts/models/xtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from TTS.tts.layers.xtts.stream_generator import init_stream_support
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
from TTS.tts.layers.xtts.xtts_manager import SpeakerManager, LanguageManager
from TTS.tts.models.base_tts import BaseTTS
from TTS.utils.io import load_fsspec

Expand Down Expand Up @@ -378,7 +379,7 @@ def get_conditioning_latents(

return gpt_cond_latents, speaker_embedding

def synthesize(self, text, config, speaker_wav, language, **kwargs):
def synthesize(self, text, config, speaker_wav, language, speaker_id, **kwargs):
WeberJulian marked this conversation as resolved.
Show resolved Hide resolved
"""Synthesize speech with the given input text.

Args:
Expand All @@ -393,12 +394,6 @@ def synthesize(self, text, config, speaker_wav, language, **kwargs):
`text_input` as text token IDs after tokenizer, `voice_samples` as samples used for cloning, `conditioning_latents`
as latents used at inference.

"""
return self.inference_with_config(text, config, ref_audio_path=speaker_wav, language=language, **kwargs)

def inference_with_config(self, text, config, ref_audio_path, language, **kwargs):
"""
inference with config
"""
assert (
"zh-cn" if language == "zh" else language in self.config.languages
Expand All @@ -410,13 +405,18 @@ def inference_with_config(self, text, config, ref_audio_path, language, **kwargs
"repetition_penalty": config.repetition_penalty,
"top_k": config.top_k,
"top_p": config.top_p,
}
settings.update(kwargs) # allow overriding of preset settings with kwargs
if speaker_id is not None:
gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values()
return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings)
settings.update({
"gpt_cond_len": config.gpt_cond_len,
"gpt_cond_chunk_len": config.gpt_cond_chunk_len,
"max_ref_len": config.max_ref_len,
"sound_norm_refs": config.sound_norm_refs,
}
settings.update(kwargs) # allow overriding of preset settings with kwargs
return self.full_inference(text, ref_audio_path, language, **settings)
})
return self.full_inference(text, speaker_wav, language, **settings)

@torch.inference_mode()
def full_inference(
Expand Down Expand Up @@ -733,6 +733,7 @@ def load_checkpoint(
eval=True,
strict=True,
use_deepspeed=False,
speaker_file_path=None,
):
"""
Loads a checkpoint from disk and initializes the model's state and tokenizer.
Expand All @@ -751,6 +752,12 @@ def load_checkpoint(

model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")
speaker_file_path = speaker_file_path or os.path.join(checkpoint_dir, "speakers_xtts.pth")

self.language_manager = LanguageManager(config)
self.speaker_manager = None
if os.path.exists(speaker_file_path):
self.speaker_manager = SpeakerManager(speaker_file_path)

if os.path.exists(vocab_path):
self.tokenizer = VoiceBpeTokenizer(vocab_file=vocab_path)
Expand Down
1 change: 1 addition & 0 deletions TTS/utils/manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ def _set_model_item(self, model_name):
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/config.json",
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/vocab.json",
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/hash.md5",
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/speakers_xtts.pth",
],
}
else:
Expand Down
7 changes: 5 additions & 2 deletions TTS/utils/synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def tts(
speaker_embedding = None
speaker_id = None
if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "name_to_id"):
if speaker_name and isinstance(speaker_name, str):
if speaker_name and isinstance(speaker_name, str) and not self.tts_config.model == "xtts":
if self.tts_config.use_d_vector_file:
# get the average speaker embedding from the saved d_vectors.
speaker_embedding = self.tts_model.speaker_manager.get_mean_embedding(
Expand Down Expand Up @@ -335,7 +335,9 @@ def tts(
# handle multi-lingual
language_id = None
if self.tts_languages_file or (
hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None
hasattr(self.tts_model, "language_manager")
and self.tts_model.language_manager is not None
and not self.tts_config.model == "xtts"
):
if len(self.tts_model.language_manager.name_to_id) == 1:
language_id = list(self.tts_model.language_manager.name_to_id.values())[0]
Expand Down Expand Up @@ -366,6 +368,7 @@ def tts(
if (
speaker_wav is not None
and self.tts_model.speaker_manager is not None
and hasattr(self.tts_model.speaker_manager, "encoder_ap")
and self.tts_model.speaker_manager.encoder_ap is not None
):
speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav)
Expand Down
Loading