Commit 9328338

Merge pull request #3318 from coqui-ai/calling_hf_models

Run XTTS models by direct name with versions

erogol authored Nov 30, 2023
2 parents 11ec9f7 + bfbaffc

Showing 5 changed files with 165 additions and 17 deletions.
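In short, after this change the name passed to `TTS()` no longer has to be a full `type/lang/dataset/model` path: XTTS can be loaded by short name, optionally pinned to a version. A minimal sketch of the call patterns this PR enables, taken from the docs changes below:

```python
from TTS.api import TTS

# Full registry name, as before.
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2")

# New: short name pinned to a release version (resolved to a Hugging Face branch).
tts = TTS("xtts_v2.0.2")

# New: short name without a version resolves to the "main" branch.
tts = TTS("xtts")
```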
70 changes: 63 additions & 7 deletions TTS/api.py
@@ -12,6 +12,7 @@
from TTS.utils.synthesizer import Synthesizer
from TTS.config import load_config


class TTS(nn.Module):
"""TODO: Add voice conversion and Capacitron support."""

@@ -75,11 +76,13 @@ def __init__(
if gpu:
warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.")

if model_name is not None:
if model_name is not None and len(model_name) > 0:
if "tts_models" in model_name or "coqui_studio" in model_name:
self.load_tts_model_by_name(model_name, gpu)
elif "voice_conversion_models" in model_name:
self.load_vc_model_by_name(model_name, gpu)
else:
self.load_model_by_name(model_name, gpu)

if model_path:
self.load_tts_model_by_path(
@@ -105,8 +108,12 @@ def is_coqui_studio(self):
@property
def is_multi_lingual(self):
# Not sure what sets this to None, but applied a fix to prevent crashing.
if (isinstance(self.model_name, str) and "xtts" in self.model_name or
self.config and ("xtts" in self.config.model or len(self.config.languages) > 1)):
if (
isinstance(self.model_name, str)
and "xtts" in self.model_name
or self.config
and ("xtts" in self.config.model or len(self.config.languages) > 1)
):
return True
if hasattr(self.synthesizer.tts_model, "language_manager") and self.synthesizer.tts_model.language_manager:
return self.synthesizer.tts_model.language_manager.num_languages > 1
@@ -149,6 +156,15 @@ def download_model_by_name(self, model_name: str):
vocoder_path, vocoder_config_path, _ = self.manager.download_model(model_item["default_vocoder"])
return model_path, config_path, vocoder_path, vocoder_config_path, None

def load_model_by_name(self, model_name: str, gpu: bool = False):
"""Load one of the 🐸TTS models by name.
Args:
model_name (str): Model name to load. You can list models by ```tts.models```.
gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
"""
self.load_tts_model_by_name(model_name, gpu)

def load_vc_model_by_name(self, model_name: str, gpu: bool = False):
"""Load one of the voice conversion models by name.
@@ -310,6 +326,7 @@ def tts(
speaker_wav: str = None,
emotion: str = None,
speed: float = None,
split_sentences: bool = True,
**kwargs,
):
"""Convert text to speech.
@@ -330,6 +347,12 @@
speed (float, optional):
Speed factor to use for 🐸Coqui Studio models, between 0 and 2.0. If None, Studio models use 1.0.
Defaults to None.
split_sentences (bool, optional):
Split the text into sentences, synthesize them separately, and concatenate the resulting audio.
Setting it to False uses more VRAM and may hit model-specific text length or VRAM limits. Only
applicable to the 🐸TTS models. Defaults to True.
kwargs (dict, optional):
Additional arguments for the model.
"""
self._check_arguments(
speaker=speaker, language=language, speaker_wav=speaker_wav, emotion=emotion, speed=speed, **kwargs
@@ -347,6 +370,7 @@
style_wav=None,
style_text=None,
reference_speaker_name=None,
split_sentences=split_sentences,
**kwargs,
)
return wav
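For reference, the new flag is used like this (a minimal sketch; the model and file paths are illustrative):

```python
from TTS.api import TTS

tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2")

# Default behavior: split into sentences, synthesize each, concatenate.
wav = tts.tts(
    text="First sentence. Second sentence.",
    speaker_wav="/path/to/target/speaker.wav",
    language="en",
)

# One-pass synthesis; uses more VRAM and may hit the model's text length limit.
wav = tts.tts(
    text="First sentence. Second sentence.",
    speaker_wav="/path/to/target/speaker.wav",
    language="en",
    split_sentences=False,
)
```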
@@ -361,6 +385,7 @@ def tts_to_file(
speed: float = 1.0,
pipe_out=None,
file_path: str = "output.wav",
split_sentences: bool = True,
**kwargs,
):
"""Convert text to speech.
@@ -385,6 +410,10 @@
If set, also write the generated TTS wav to stdout for shell piping.
file_path (str, optional):
Output file path. Defaults to "output.wav".
split_sentences (bool, optional):
Split the text into sentences, synthesize them separately, and concatenate the resulting audio.
Setting it to False uses more VRAM and may hit model-specific text length or VRAM limits. Only
applicable to the 🐸TTS models. Defaults to True.
kwargs (dict, optional):
Additional arguments for the model.
"""
@@ -400,7 +429,14 @@
file_path=file_path,
pipe_out=pipe_out,
)
wav = self.tts(text=text, speaker=speaker, language=language, speaker_wav=speaker_wav, **kwargs)
wav = self.tts(
text=text,
speaker=speaker,
language=language,
speaker_wav=speaker_wav,
split_sentences=split_sentences,
**kwargs,
)
self.synthesizer.save_wav(wav=wav, path=file_path, pipe_out=pipe_out)
return file_path

@@ -440,7 +476,14 @@ def voice_conversion_to_file(
save_wav(wav=wav, path=file_path, sample_rate=self.voice_converter.vc_config.audio.output_sample_rate)
return file_path

def tts_with_vc(self, text: str, language: str = None, speaker_wav: str = None, speaker: str = None):
def tts_with_vc(
self,
text: str,
language: str = None,
speaker_wav: str = None,
speaker: str = None,
split_sentences: bool = True,
):
"""Convert text to speech with voice conversion.
It combines TTS with voice conversion to approximate voice cloning.
@@ -460,10 +503,16 @@ def tts_with_vc(self, text: str, language: str = None, speaker_wav: str = None,
speaker (str, optional):
Speaker name for multi-speaker. You can check whether loaded model is multi-speaker by
`tts.is_multi_speaker` and list speakers by `tts.speakers`. Defaults to None.
split_sentences (bool, optional):
Split the text into sentences, synthesize them separately, and concatenate the resulting audio.
Setting it to False uses more VRAM and may hit model-specific text length or VRAM limits. Only
applicable to the 🐸TTS models. Defaults to True.
"""
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
# Lazy code... save it to a temp file to resample it while reading it for VC
self.tts_to_file(text=text, speaker=speaker, language=language, file_path=fp.name)
self.tts_to_file(
text=text, speaker=speaker, language=language, file_path=fp.name, split_sentences=split_sentences
)
if self.voice_converter is None:
self.load_vc_model_by_name("voice_conversion_models/multilingual/vctk/freevc24")
wav = self.voice_converter.voice_conversion(source_wav=fp.name, target_wav=speaker_wav)
@@ -476,6 +525,7 @@ def tts_with_vc_to_file(
speaker_wav: str = None,
file_path: str = "output.wav",
speaker: str = None,
split_sentences: bool = True,
):
"""Convert text to speech with voice conversion and save to file.
@@ -495,6 +545,12 @@
speaker (str, optional):
Speaker name for multi-speaker. You can check whether loaded model is multi-speaker by
`tts.is_multi_speaker` and list speakers by `tts.speakers`. Defaults to None.
split_sentences (bool, optional):
Split the text into sentences, synthesize them separately, and concatenate the resulting audio.
Setting it to False uses more VRAM and may hit model-specific text length or VRAM limits. Only
applicable to the 🐸TTS models. Defaults to True.
"""
wav = self.tts_with_vc(text=text, language=language, speaker_wav=speaker_wav, speaker=speaker)
wav = self.tts_with_vc(
text=text, language=language, speaker_wav=speaker_wav, speaker=speaker, split_sentences=split_sentences
)
save_wav(wav=wav, path=file_path, sample_rate=self.voice_converter.vc_config.audio.output_sample_rate)
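A usage sketch for the updated signature, assuming `tts` is a loaded multilingual model and the paths are placeholders:

```python
tts.tts_with_vc_to_file(
    text="Hello there, this is a cloned voice.",
    language="en",
    speaker_wav="/path/to/target/speaker.wav",
    file_path="cloned_output.wav",
    split_sentences=True,  # new flag, forwarded to the underlying TTS call
)
```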
2 changes: 1 addition & 1 deletion TTS/tts/layers/xtts/trainer/gpt_trainer.py
@@ -319,7 +319,7 @@ def eval_step(self, batch, criterion):
return self.train_step(batch, criterion)

def on_train_epoch_start(self, trainer):
trainer.model.eval() # the whole model to eval
trainer.model.eval() # the whole model to eval
# put gpt model in training mode
trainer.model.xtts.gpt.train()

35 changes: 32 additions & 3 deletions TTS/utils/manage.py
@@ -1,5 +1,6 @@
import json
import os
import re
import tarfile
import zipfile
from pathlib import Path
@@ -26,7 +27,6 @@
}



class ModelManager(object):
tqdm_progress = None
"""Manage TTS models defined in .models.json.
@@ -276,13 +276,15 @@ def set_model_url(model_item: Dict):
model_item["model_url"] = model_item["hf_url"]
elif "fairseq" in model_item["model_name"]:
model_item["model_url"] = "https://coqui.gateway.scarf.sh/fairseq/"
elif "xtts" in model_item["model_name"]:
model_item["model_url"] = "https://coqui.gateway.scarf.sh/xtts/"
return model_item

def _set_model_item(self, model_name):
# fetch model info from the dict
model_type, lang, dataset, model = model_name.split("/")
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
if "fairseq" in model_name:
model_type = "tts_models"
lang = model_name.split("/")[1]
model_item = {
"model_type": "tts_models",
"license": "CC BY-NC 4.0",
@@ -291,10 +293,37 @@ def _set_model_item(self, model_name):
"description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.",
}
model_item["model_name"] = model_name
elif "xtts" in model_name and len(model_name.split("/")) != 4:
# loading xtts models with only model name (e.g. xtts_v2.0.2)
# check whether the model name includes a version number, via regex
version_regex = r"v\d+\.\d+\.\d+"
if re.search(version_regex, model_name):
model_version = model_name.split("_")[-1]
else:
model_version = "main"
model_type = "tts_models"
lang = "multilingual"
dataset = "multi-dataset"
model = model_name
model_item = {
"default_vocoder": None,
"license": "CPML",
"contact": "info@coqui.ai",
"tos_required": True,
"hf_url": [
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/model.pth",
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/config.json",
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/vocab.json",
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/hash.md5",
],
}
else:
# get model from models.json
model_type, lang, dataset, model = model_name.split("/")
model_item = self.models_dict[model_type][lang][dataset][model]
model_item["model_type"] = model_type

model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
md5hash = model_item["model_hash"] if "model_hash" in model_item else None
model_item = self.set_model_url(model_item)
return model_item, model_full_name, model, md5hash
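The version resolution added here can be illustrated in isolation. The helper below is hypothetical, written only to mirror the regex branch of `_set_model_item` above:

```python
import re

def resolve_xtts_version(model_name: str) -> str:
    """Hypothetical helper mirroring the version check in _set_model_item."""
    # "xtts_v2.0.2" carries a version suffix -> use that Hugging Face branch.
    if re.search(r"v\d+\.\d+\.\d+", model_name):
        return model_name.split("_")[-1]
    # Bare "xtts" (no version) -> track the "main" branch.
    return "main"

assert resolve_xtts_version("xtts_v2.0.2") == "v2.0.2"
assert resolve_xtts_version("xtts") == "main"
```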
9 changes: 7 additions & 2 deletions TTS/utils/synthesizer.py
@@ -264,6 +264,7 @@ def tts(
style_text=None,
reference_wav=None,
reference_speaker_name=None,
split_sentences: bool = True,
**kwargs,
) -> List[int]:
"""🐸 TTS magic. Run all the models and generate speech.
@@ -277,6 +278,8 @@
style_text ([type], optional): transcription of style_wav for Capacitron. Defaults to None.
reference_wav ([type], optional): reference waveform for voice conversion. Defaults to None.
reference_speaker_name ([type], optional): speaker id of reference waveform. Defaults to None.
split_sentences (bool, optional): split the input text into sentences. Defaults to True.
**kwargs: additional arguments to pass to the TTS model.
Returns:
List[int]: Synthesized audio as a list of sample values.
"""
@@ -289,8 +292,10 @@
)

if text:
sens = self.split_into_sentences(text)
print(" > Text splitted to sentences.")
sens = [text]
if split_sentences:
print(" > Text splitted to sentences.")
sens = self.split_into_sentences(text)
print(sens)

# handle multi-speaker
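The `split_sentences` flag threaded through `Synthesizer.tts()` above can also be exercised at this level. A sketch, assuming `synthesizer` was already built with an XTTS checkpoint:

```python
wav = synthesizer.tts(
    "First sentence. Second sentence.",
    speaker_wav="/path/to/target/speaker.wav",
    language_name="en",
    split_sentences=False,  # synthesize the whole text in one pass
)
```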
66 changes: 62 additions & 4 deletions docs/source/models/xtts.md
@@ -39,6 +39,10 @@ You can also mail us at info@coqui.ai.
#### 🐸TTS API

##### Single reference

By default, XTTS splits the text into sentences, generates audio for each sentence, and concatenates the results to produce the final audio.
You can optionally disable sentence splitting for better coherence, at the cost of more VRAM and possibly hitting the model's context length limit.

```python
from TTS.api import TTS
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=True)
@@ -47,21 +51,72 @@ tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=True)
tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
file_path="output.wav",
speaker_wav=["/path/to/target/speaker.wav"],
language="en")
language="en",
split_sentences=True
)
```

##### Multiple references

You can pass multiple audio files to the `speaker_wav` argument for better voice cloning.

```python
from TTS.api import TTS

# using the default version set in 🐸TTS
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=True)

# using a specific version
# 👀 see the branch names for versions on https://huggingface.co/coqui/XTTS-v2/tree/main
# ❗some versions might be incompatible with the API
tts = TTS("xtts_v2.0.2", gpu=True)

# getting the latest XTTS_v2
tts = TTS("xtts", gpu=True)

# generate speech by cloning a voice using default settings
tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
file_path="output.wav",
speaker_wav=["/path/to/target/speaker.wav", "/path/to/target/speaker_2.wav", "/path/to/target/speaker_3.wav"],
language="en")
```

##### Streaming inference

XTTS supports streaming inference. This is useful for real-time applications.

```python
import time

import torch
import torchaudio

from TTS.api import TTS

print("Loading model...")
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=True)
model = tts.synthesizer.tts_model

print("Computing speaker latents...")
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=["reference.wav"])

print("Inference...")
t0 = time.time()
stream_generator = model.inference_stream(
"It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
"en",
gpt_cond_latent,
speaker_embedding
)

wav_chunks = []
for i, chunk in enumerate(stream_generator):
    if i == 0:
        print(f"Time to first chunk: {time.time() - t0}")
    print(f"Received chunk {i} of audio length {chunk.shape[-1]}")
    wav_chunks.append(chunk)
wav = torch.cat(wav_chunks, dim=0)
torchaudio.save("xtts_streaming.wav", wav.squeeze().unsqueeze(0).cpu(), 24000)
```

#### 🐸TTS Command line

##### Single reference
@@ -91,10 +146,13 @@ or for all wav files in a directory you can use:
--use_cuda true
```

#### 🐸TTS Model API

To use the model API, you need to download the model files and pass config and model file paths manually.

#### model directly
##### Calling manually

If you want to be able to run with `use_deepspeed=True` and enjoy the speedup, you need to install deepspeed first.
If you want to be able to run with `use_deepspeed=True` and **enjoy the speedup**, you need to install deepspeed first.

```console
pip install deepspeed==0.10.3
Expand Down Expand Up @@ -129,7 +187,7 @@ torchaudio.save("xtts.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
```
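The manual-loading example is collapsed above. For orientation, here is a sketch of the pattern it follows (paths are placeholders; verify against the rendered docs page for the exact, current code):

```python
import torch
import torchaudio

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

# Load config and model from manually downloaded files.
config = XttsConfig()
config.load_json("/path/to/xtts/config.json")
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=True)
model.cuda()

# Synthesize, conditioning on a reference clip.
out = model.synthesize(
    "It took me quite a long time to develop a voice.",
    config,
    speaker_wav="/path/to/target/speaker.wav",
    gpt_cond_len=3,
    language="en",
)
torchaudio.save("xtts.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
```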


#### streaming inference
##### Streaming manually

Here the goal is to stream the audio as it is being generated. This is useful for real-time applications.
Streaming inference is typically slower than regular inference, but it lets you obtain the first chunk of audio sooner.
