Bug fixes and add support for multiple speaker references on XTTS inference #3149

Merged · 13 commits · Nov 7, 2023
24 changes: 6 additions & 18 deletions TTS/.models.json
@@ -10,34 +10,22 @@
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/hash.md5"
],
"model_hash": "ae9e4b39e095fd5728fe7f7931eccoqui",
"default_vocoder": null,
"commit": "480a6cdf7",
"license": "CPML",
"contact": "info@coqui.ai",
"tos_required": true
},
"xtts_v1": {
"description": "XTTS-v1 by Coqui with 13 languages and cross-language voice cloning.",
"hf_url": [
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/hifigan/model.pth",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/hifigan/config.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/hifigan/vocab.json"
],
"default_vocoder": null,
"commit": "e5140314",
"license": "CPML",
"contact": "info@coqui.ai",
"tos_required": true
},
"xtts_v1.1": {
"description": "XTTS-v1.1 by Coqui with 14 languages, cross-language voice cloning and reference leak fixed.",
"hf_url": [
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/model.pth",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/config.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/vocab.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/hash.md5"
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/model.pth",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/config.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/vocab.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/hash.md5"
],
"model_hash": "ae9e4b39e095fd5728fe7f7931ec66ad",
"model_hash": "7c62beaf58d39b729de287330dc254e7b515677416839b649a50e7cf74c3df59",
"default_vocoder": null,
"commit": "82910a63",
"license": "CPML",
30 changes: 1 addition & 29 deletions TTS/tts/layers/xtts/trainer/dataset.py
@@ -2,13 +2,10 @@
import random
import sys

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data
import torchaudio
from torchaudio.backend.soundfile_backend import load as torchaudio_soundfile_load
from torchaudio.backend.sox_io_backend import load as torchaudio_sox_load
from TTS.tts.models.xtts import load_audio

torch.set_num_threads(1)

@@ -50,31 +47,6 @@ def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate,
return rel_clip, rel_clip.shape[-1], cond_idxs


def load_audio(audiopath, sampling_rate):
# better load setting following: https://github.com/faroit/python_audio_loading_benchmark
if audiopath[-4:] == ".mp3":
# it uses torchaudio with sox backend to load mp3
audio, lsr = torchaudio_sox_load(audiopath)
else:
# it uses torchaudio soundfile backend to load all the others data type
audio, lsr = torchaudio_soundfile_load(audiopath)

# stereo to mono if needed
if audio.size(0) != 1:
audio = torch.mean(audio, dim=0, keepdim=True)

if lsr != sampling_rate:
audio = torchaudio.functional.resample(audio, lsr, sampling_rate)

# Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
# '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
if torch.any(audio > 10) or not torch.any(audio < 0):
print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
# clip audio invalid values
audio.clip_(-1, 1)
return audio


class XTTSDataset(torch.utils.data.Dataset):
def __init__(self, config, samples, tokenizer, sample_rate, is_eval=False):
self.config = config
1 change: 0 additions & 1 deletion TTS/tts/layers/xtts/trainer/gpt_trainer.py
@@ -238,7 +238,6 @@ def test_run(self, assets) -> Tuple[Dict, Dict]:  # pylint: disable=W0613
s_info["speaker_wav"],
s_info["language"],
gpt_cond_len=3,
decoder="ne_hifigan",
)["wav"]
test_audios["{}-audio".format(idx)] = wav

185 changes: 59 additions & 126 deletions TTS/tts/models/xtts.py
@@ -67,6 +67,31 @@ def wav_to_mel_cloning(
return mel


def load_audio(audiopath, sampling_rate):
# better load setting following: https://github.com/faroit/python_audio_loading_benchmark
if audiopath[-4:] == ".mp3":
# it uses torchaudio with sox backend to load mp3
audio, lsr = torchaudio.backend.sox_io_backend.load(audiopath)
else:
# it uses torchaudio soundfile backend to load all the others data type
audio, lsr = torchaudio.backend.soundfile_backend.load(audiopath)

# stereo to mono if needed
if audio.size(0) != 1:
audio = torch.mean(audio, dim=0, keepdim=True)

if lsr != sampling_rate:
audio = torchaudio.functional.resample(audio, lsr, sampling_rate)

# Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
# '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
if torch.any(audio > 10) or not torch.any(audio < 0):
print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
# clip audio invalid values
audio.clip_(-1, 1)
return audio
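
With `load_audio` now living in `TTS/tts/models/xtts.py` (and imported from there by the trainer dataset), a minimal, hedged usage sketch may help. The wav filename below is a placeholder:

```python
from TTS.tts.models.xtts import load_audio

# "reference.wav" is a placeholder path; any mp3/wav/flac readable by
# torchaudio's sox or soundfile backends should work.
audio = load_audio("reference.wav", sampling_rate=24000)
print(audio.shape)  # (1, num_samples): mono, resampled to 24 kHz, clipped to [-1, 1]
```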


def pad_or_truncate(t, length):
"""
Ensure a given tensor t has a specified sequence length by either padding it with zeros or clipping it.
@@ -86,78 +111,6 @@ def pad_or_truncate(t, length):
return tp
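
The docstring above states the contract of `pad_or_truncate`; the following is a hedged sketch of that behavior, not necessarily the exact body hidden by the collapsed hunk:

```python
import torch
import torch.nn.functional as F


def pad_or_truncate_sketch(t: torch.Tensor, length: int) -> torch.Tensor:
    """Zero-pad the last dimension up to `length`, or truncate it down to `length`."""
    if t.shape[-1] < length:
        return F.pad(t, (0, length - t.shape[-1]))  # pad on the right with zeros
    return t[..., :length]  # truncate (no-op if already the right length)
```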


def load_discrete_vocoder_diffuser(
trained_diffusion_steps=4000,
desired_diffusion_steps=200,
cond_free=True,
cond_free_k=1,
sampler="ddim",
):
"""
Load a GaussianDiffusion instance configured for use as a decoder.

Args:
trained_diffusion_steps (int): The number of diffusion steps used during training.
desired_diffusion_steps (int): The number of diffusion steps to use during inference.
cond_free (bool): Whether to use a conditioning-free model.
cond_free_k (int): The number of samples to use for conditioning-free models.
sampler (str): The name of the sampler to use.

Returns:
A SpacedDiffusion instance configured with the given parameters.
"""
return SpacedDiffusion(
use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]),
model_mean_type="epsilon",
model_var_type="learned_range",
loss_type="mse",
betas=get_named_beta_schedule("linear", trained_diffusion_steps),
conditioning_free=cond_free,
conditioning_free_k=cond_free_k,
sampler=sampler,
)


def do_spectrogram_diffusion(
diffusion_model,
diffuser,
latents,
conditioning_latents,
temperature=1,
):
"""
Generate a mel-spectrogram using a diffusion model and a diffuser.

Args:
diffusion_model (nn.Module): A diffusion model that converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
diffuser (Diffuser): A diffuser that generates a mel-spectrogram from noise.
latents (torch.Tensor): A tensor of shape (batch_size, seq_len, code_size) containing the input spectrogram codes.
conditioning_latents (torch.Tensor): A tensor of shape (batch_size, code_size) containing the conditioning codes.
temperature (float, optional): The temperature of the noise used by the diffuser. Defaults to 1.

Returns:
torch.Tensor: A tensor of shape (batch_size, mel_channels, mel_seq_len) containing the generated mel-spectrogram.
"""
with torch.no_grad():
output_seq_len = (
latents.shape[1] * 4 * 24000 // 22050
) # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
output_shape = (latents.shape[0], 100, output_seq_len)
precomputed_embeddings = diffusion_model.timestep_independent(
latents, conditioning_latents, output_seq_len, False
)

noise = torch.randn(output_shape, device=latents.device) * temperature
mel = diffuser.sample_loop(
diffusion_model,
output_shape,
noise=noise,
model_kwargs={"precomputed_aligned_embeddings": precomputed_embeddings},
progress=False,
)
return denormalize_tacotron_mel(mel)[:, :, :output_seq_len]


@dataclass
class XttsAudioConfig(Coqpit):
"""
@@ -336,7 +289,7 @@ def get_gpt_cond_latents(self, audio, sr, length: int = 3):
"""Compute the conditioning latents for the GPT model from the given audio.

Args:
audio_path (str): Path to the audio file.
audio (tensor): audio tensor.
sr (int): Sample rate of the audio.
length (int): Length of the audio in seconds. Defaults to 3.
"""
@@ -404,20 +357,41 @@ def get_conditioning_latents(
max_ref_length=10,
librosa_trim_db=None,
sound_norm_refs=False,
load_sr=24000,
):
# deal with multiple references
if not isinstance(audio_path, list):
audio_paths = [audio_path]
else:
audio_paths = audio_path

speaker_embeddings = []
audios = []
speaker_embedding = None
for file_path in audio_paths:
# load the audio in 24kHz to avoid issues with multiple sr references
audio = load_audio(file_path, load_sr)
audio = audio[:, : load_sr * max_ref_length].to(self.device)
if audio.shape[0] > 1:
audio = audio.mean(0, keepdim=True)
if sound_norm_refs:
audio = (audio / torch.abs(audio).max()) * 0.75
if librosa_trim_db is not None:
audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]

speaker_embedding = self.get_speaker_embedding(audio, load_sr)
speaker_embeddings.append(speaker_embedding)

audios.append(audio)

# use a merge of all references for gpt cond latents
full_audio = torch.cat(audios, dim=-1)
gpt_cond_latents = self.get_gpt_cond_latents(full_audio, load_sr, length=gpt_cond_len) # [1, 1024, T]

if speaker_embeddings:
speaker_embedding = torch.stack(speaker_embeddings)
speaker_embedding = speaker_embedding.mean(dim=0)

audio, sr = torchaudio.load(audio_path)
audio = audio[:, : sr * max_ref_length].to(self.device)
if audio.shape[0] > 1:
audio = audio.mean(0, keepdim=True)
if sound_norm_refs:
audio = (audio / torch.abs(audio).max()) * 0.75
if librosa_trim_db is not None:
audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]

speaker_embedding = self.get_speaker_embedding(audio, sr)
gpt_cond_latents = self.get_gpt_cond_latents(audio, sr, length=gpt_cond_len) # [1, 1024, T]
return gpt_cond_latents, speaker_embedding
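
Multi-reference conditioning is the core change of this PR: the speaker embedding becomes the mean over per-clip embeddings, while the GPT conditioning latents are computed from the concatenation of all references. A hedged end-to-end sketch follows; the checkpoint paths and wav names are placeholders, the loading boilerplate follows the usual Coqui XTTS pattern, and keyword arguments are used to avoid assumptions about positional order:

```python
import torch

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

# Load the model; all paths are placeholders.
config = XttsConfig()
config.load_json("/path/to/xtts/config.json")
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True)
if torch.cuda.is_available():
    model.cuda()

# audio_path may now be a single path or a list of reference clips.
references = ["speaker_ref_1.wav", "speaker_ref_2.wav", "speaker_ref_3.wav"]

gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
    audio_path=references,
    gpt_cond_len=6,       # seconds of (concatenated) audio used for GPT conditioning
    max_ref_length=10,    # per-clip cap, in seconds
    sound_norm_refs=False,
)
```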

def synthesize(self, text, config, speaker_wav, language, **kwargs):
@@ -426,7 +400,7 @@ def synthesize(self, text, config, speaker_wav, language, **kwargs):
Args:
text (str): Input text.
config (XttsConfig): Config with inference parameters.
speaker_wav (str): Path to the speaker audio file for cloning.
speaker_wav (list): List of paths to the speaker audio files to be used for cloning.
language (str): Language ID of the speaker.
**kwargs: Inference settings. See `inference()`.

Expand All @@ -436,11 +410,6 @@ def synthesize(self, text, config, speaker_wav, language, **kwargs):
as latents used at inference.

"""

# Make the synthesizer happy 🥳
if isinstance(speaker_wav, list):
speaker_wav = speaker_wav[0]

return self.inference_with_config(text, config, ref_audio_path=speaker_wav, language=language, **kwargs)
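
With the docstring change above, `synthesize` accepts a list of speaker wavs directly, so the old "take the first element" workaround is gone. A hedged sketch, reusing `model` and `config` from the previous snippet (filenames are placeholders):

```python
out = model.synthesize(
    "This sentence is cloned from several reference clips.",
    config,
    speaker_wav=["speaker_ref_1.wav", "speaker_ref_2.wav"],
    language="en",
)
wav = out["wav"]  # same key the trainer's test_run reads above
```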

def inference_with_config(self, text, config, ref_audio_path, language, **kwargs):
@@ -522,27 +491,6 @@ def full_inference(
gpt_cond_len: (int) Length of the audio used for cloning. If audio is shorter, then audio length is used
else the first `gpt_cond_len` secs is used. Defaults to 6 seconds.

decoder_iterations: (int) Number of diffusion steps to perform. [0,4000]. More steps means the network has
more chances to iteratively refine the output, which should theoretically mean a higher quality output.
Generally a value above 250 is not noticeably better, however. Defaults to 100.

cond_free: (bool) Whether or not to perform conditioning-free diffusion. Conditioning-free diffusion
performs two forward passes for each diffusion step: one with the outputs of the autoregressive model
and one with no conditioning priors. The output of the two is blended according to the cond_free_k
value below. Conditioning-free diffusion is the real deal, and dramatically improves realism.
Defaults to True.

cond_free_k: (float) Knob that determines how to balance the conditioning free signal with the
conditioning-present signal. [0,inf]. As cond_free_k increases, the output becomes dominated by the
conditioning-free signal. Defaults to 2.0.

diffusion_temperature: (float) Controls the variance of the noise fed into the diffusion model. [0,1].
Values at 0 are the "mean" prediction of the diffusion network and will sound bland and smeared.
Defaults to 1.0.

decoder: (str) Selects the decoder to use between ("hifigan", "diffusion")
Defaults to hifigan

hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive
transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation
here: https://huggingface.co/docs/transformers/internal/generation_utils
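
Because the diffusion decoder is removed, `full_inference` no longer takes `decoder_iterations`, `cond_free`, `cond_free_k`, `diffusion_temperature` or `decoder`. A hedged call sketch using keyword arguments only (parameter names follow the diff; `model` is the instance loaded earlier, wav paths are placeholders):

```python
out = model.full_inference(
    text="A quick smoke test for the HiFi-GAN-only pipeline.",
    ref_audio_path=["speaker_ref_1.wav", "speaker_ref_2.wav"],
    language="en",
    gpt_cond_len=6,
    temperature=0.65,
)
wav = out["wav"]
```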
@@ -569,12 +517,6 @@
top_k=top_k,
top_p=top_p,
do_sample=do_sample,
decoder_iterations=decoder_iterations,
cond_free=cond_free,
cond_free_k=cond_free_k,
diffusion_temperature=diffusion_temperature,
decoder_sampler=decoder_sampler,
decoder=decoder,
**hf_generate_kwargs,
)

@@ -592,13 +534,6 @@ def inference(
top_k=50,
top_p=0.85,
do_sample=True,
# Decoder inference
decoder_iterations=100,
cond_free=True,
cond_free_k=2,
diffusion_temperature=1.0,
decoder_sampler="ddim",
decoder="hifigan",
num_beams=1,
**hf_generate_kwargs,
):
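
The trimmed `inference` signature keeps only the GPT sampling knobs. A hedged sketch, assuming the leading positional arguments are text, language, gpt_cond_latent, speaker_embedding (as in the Coqui XTTS usage docs) and reusing the latents computed earlier:

```python
out = model.inference(
    "Direct inference with precomputed conditioning latents.",
    "en",
    gpt_cond_latent,
    speaker_embedding,
    temperature=0.65,
    top_k=50,
    top_p=0.85,
    do_sample=True,
)
wav = out["wav"]
```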
@@ -693,8 +628,6 @@ def inference_stream(
top_k=50,
top_p=0.85,
do_sample=True,
# Decoder inference
decoder="hifigan",
**hf_generate_kwargs,
):
text = text.strip().lower()
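
`inference_stream` loses the decoder selector as well. A hedged streaming sketch, under the same positional-argument assumption as above; the chunks are concatenated at the end purely for illustration:

```python
import torch

chunks = model.inference_stream(
    "Streaming synthesis without the diffusion decoder.",
    "en",
    gpt_cond_latent,
    speaker_embedding,
)

wav_chunks = [chunk for chunk in chunks]  # each chunk is a waveform tensor
wav = torch.cat(wav_chunks, dim=0)
```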