Skip to content

Commit

Permalink
Merge pull request #3149 from coqui-ai/fixup_xtts_v2
Browse files Browse the repository at this point in the history
Bug fixes and add support for multiples speaker references on XTTS inference
  • Loading branch information
erogol authored Nov 7, 2023
2 parents f0cb19e + 5f9ab6c commit 5e992d8
Show file tree
Hide file tree
Showing 10 changed files with 130 additions and 200 deletions.
24 changes: 6 additions & 18 deletions TTS/.models.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,34 +10,22 @@
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/hash.md5"
],
"model_hash": "ae9e4b39e095fd5728fe7f7931eccoqui",
"default_vocoder": null,
"commit": "480a6cdf7",
"license": "CPML",
"contact": "info@coqui.ai",
"tos_required": true
},
"xtts_v1": {
"description": "XTTS-v1 by Coqui with 13 languages and cross-language voice cloning.",
"hf_url": [
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/hifigan/model.pth",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/hifigan/config.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/hifigan/vocab.json"
],
"default_vocoder": null,
"commit": "e5140314",
"license": "CPML",
"contact": "info@coqui.ai",
"tos_required": true
},
"xtts_v1.1": {
"description": "XTTS-v1.1 by Coqui with 14 languages, cross-language voice cloning and reference leak fixed.",
"hf_url": [
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/model.pth",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/config.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/vocab.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/hash.md5"
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/model.pth",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/config.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/vocab.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/hash.md5"
],
"model_hash": "ae9e4b39e095fd5728fe7f7931ec66ad",
"model_hash": "7c62beaf58d39b729de287330dc254e7b515677416839b649a50e7cf74c3df59",
"default_vocoder": null,
"commit": "82910a63",
"license": "CPML",
Expand Down
30 changes: 1 addition & 29 deletions TTS/tts/layers/xtts/trainer/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@
import random
import sys

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data
import torchaudio
from torchaudio.backend.soundfile_backend import load as torchaudio_soundfile_load
from torchaudio.backend.sox_io_backend import load as torchaudio_sox_load
from TTS.tts.models.xtts import load_audio

torch.set_num_threads(1)

Expand Down Expand Up @@ -50,31 +47,6 @@ def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate,
return rel_clip, rel_clip.shape[-1], cond_idxs


def load_audio(audiopath, sampling_rate):
# better load setting following: https://github.com/faroit/python_audio_loading_benchmark
if audiopath[-4:] == ".mp3":
# it uses torchaudio with sox backend to load mp3
audio, lsr = torchaudio_sox_load(audiopath)
else:
# it uses torchaudio soundfile backend to load all the others data type
audio, lsr = torchaudio_soundfile_load(audiopath)

# stereo to mono if needed
if audio.size(0) != 1:
audio = torch.mean(audio, dim=0, keepdim=True)

if lsr != sampling_rate:
audio = torchaudio.functional.resample(audio, lsr, sampling_rate)

# Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
# '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
if torch.any(audio > 10) or not torch.any(audio < 0):
print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
# clip audio invalid values
audio.clip_(-1, 1)
return audio


class XTTSDataset(torch.utils.data.Dataset):
def __init__(self, config, samples, tokenizer, sample_rate, is_eval=False):
self.config = config
Expand Down
1 change: 0 additions & 1 deletion TTS/tts/layers/xtts/trainer/gpt_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,6 @@ def test_run(self, assets) -> Tuple[Dict, Dict]: # pylint: disable=W0613
s_info["speaker_wav"],
s_info["language"],
gpt_cond_len=3,
decoder="ne_hifigan",
)["wav"]
test_audios["{}-audio".format(idx)] = wav

Expand Down
185 changes: 59 additions & 126 deletions TTS/tts/models/xtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,31 @@ def wav_to_mel_cloning(
return mel


def load_audio(audiopath, sampling_rate):
# better load setting following: https://github.com/faroit/python_audio_loading_benchmark
if audiopath[-4:] == ".mp3":
# it uses torchaudio with sox backend to load mp3
audio, lsr = torchaudio.backend.sox_io_backend.load(audiopath)
else:
# it uses torchaudio soundfile backend to load all the others data type
audio, lsr = torchaudio.backend.soundfile_backend.load(audiopath)

# stereo to mono if needed
if audio.size(0) != 1:
audio = torch.mean(audio, dim=0, keepdim=True)

if lsr != sampling_rate:
audio = torchaudio.functional.resample(audio, lsr, sampling_rate)

# Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
# '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
if torch.any(audio > 10) or not torch.any(audio < 0):
print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
# clip audio invalid values
audio.clip_(-1, 1)
return audio


def pad_or_truncate(t, length):
"""
Ensure a given tensor t has a specified sequence length by either padding it with zeros or clipping it.
Expand All @@ -86,78 +111,6 @@ def pad_or_truncate(t, length):
return tp


def load_discrete_vocoder_diffuser(
trained_diffusion_steps=4000,
desired_diffusion_steps=200,
cond_free=True,
cond_free_k=1,
sampler="ddim",
):
"""
Load a GaussianDiffusion instance configured for use as a decoder.
Args:
trained_diffusion_steps (int): The number of diffusion steps used during training.
desired_diffusion_steps (int): The number of diffusion steps to use during inference.
cond_free (bool): Whether to use a conditioning-free model.
cond_free_k (int): The number of samples to use for conditioning-free models.
sampler (str): The name of the sampler to use.
Returns:
A SpacedDiffusion instance configured with the given parameters.
"""
return SpacedDiffusion(
use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]),
model_mean_type="epsilon",
model_var_type="learned_range",
loss_type="mse",
betas=get_named_beta_schedule("linear", trained_diffusion_steps),
conditioning_free=cond_free,
conditioning_free_k=cond_free_k,
sampler=sampler,
)


def do_spectrogram_diffusion(
diffusion_model,
diffuser,
latents,
conditioning_latents,
temperature=1,
):
"""
Generate a mel-spectrogram using a diffusion model and a diffuser.
Args:
diffusion_model (nn.Module): A diffusion model that converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
diffuser (Diffuser): A diffuser that generates a mel-spectrogram from noise.
latents (torch.Tensor): A tensor of shape (batch_size, seq_len, code_size) containing the input spectrogram codes.
conditioning_latents (torch.Tensor): A tensor of shape (batch_size, code_size) containing the conditioning codes.
temperature (float, optional): The temperature of the noise used by the diffuser. Defaults to 1.
Returns:
torch.Tensor: A tensor of shape (batch_size, mel_channels, mel_seq_len) containing the generated mel-spectrogram.
"""
with torch.no_grad():
output_seq_len = (
latents.shape[1] * 4 * 24000 // 22050
) # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
output_shape = (latents.shape[0], 100, output_seq_len)
precomputed_embeddings = diffusion_model.timestep_independent(
latents, conditioning_latents, output_seq_len, False
)

noise = torch.randn(output_shape, device=latents.device) * temperature
mel = diffuser.sample_loop(
diffusion_model,
output_shape,
noise=noise,
model_kwargs={"precomputed_aligned_embeddings": precomputed_embeddings},
progress=False,
)
return denormalize_tacotron_mel(mel)[:, :, :output_seq_len]


@dataclass
class XttsAudioConfig(Coqpit):
"""
Expand Down Expand Up @@ -336,7 +289,7 @@ def get_gpt_cond_latents(self, audio, sr, length: int = 3):
"""Compute the conditioning latents for the GPT model from the given audio.
Args:
audio_path (str): Path to the audio file.
audio (tensor): audio tensor.
sr (int): Sample rate of the audio.
length (int): Length of the audio in seconds. Defaults to 3.
"""
Expand Down Expand Up @@ -404,20 +357,41 @@ def get_conditioning_latents(
max_ref_length=10,
librosa_trim_db=None,
sound_norm_refs=False,
load_sr=24000,
):
# deal with multiples references
if not isinstance(audio_path, list):
audio_paths = [audio_path]
else:
audio_paths = audio_path

speaker_embeddings = []
audios = []
speaker_embedding = None
for file_path in audio_paths:
# load the audio in 24khz to avoid issued with multiple sr references
audio = load_audio(file_path, load_sr)
audio = audio[:, : load_sr * max_ref_length].to(self.device)
if audio.shape[0] > 1:
audio = audio.mean(0, keepdim=True)
if sound_norm_refs:
audio = (audio / torch.abs(audio).max()) * 0.75
if librosa_trim_db is not None:
audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]

speaker_embedding = self.get_speaker_embedding(audio, load_sr)
speaker_embeddings.append(speaker_embedding)

audios.append(audio)

# use a merge of all references for gpt cond latents
full_audio = torch.cat(audios, dim=-1)
gpt_cond_latents = self.get_gpt_cond_latents(full_audio, load_sr, length=gpt_cond_len) # [1, 1024, T]

if speaker_embeddings:
speaker_embedding = torch.stack(speaker_embeddings)
speaker_embedding = speaker_embedding.mean(dim=0)

audio, sr = torchaudio.load(audio_path)
audio = audio[:, : sr * max_ref_length].to(self.device)
if audio.shape[0] > 1:
audio = audio.mean(0, keepdim=True)
if sound_norm_refs:
audio = (audio / torch.abs(audio).max()) * 0.75
if librosa_trim_db is not None:
audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]

speaker_embedding = self.get_speaker_embedding(audio, sr)
gpt_cond_latents = self.get_gpt_cond_latents(audio, sr, length=gpt_cond_len) # [1, 1024, T]
return gpt_cond_latents, speaker_embedding

def synthesize(self, text, config, speaker_wav, language, **kwargs):
Expand All @@ -426,7 +400,7 @@ def synthesize(self, text, config, speaker_wav, language, **kwargs):
Args:
text (str): Input text.
config (XttsConfig): Config with inference parameters.
speaker_wav (str): Path to the speaker audio file for cloning.
speaker_wav (list): List of paths to the speaker audio files to be used for cloning.
language (str): Language ID of the speaker.
**kwargs: Inference settings. See `inference()`.
Expand All @@ -436,11 +410,6 @@ def synthesize(self, text, config, speaker_wav, language, **kwargs):
as latents used at inference.
"""

# Make the synthesizer happy 🥳
if isinstance(speaker_wav, list):
speaker_wav = speaker_wav[0]

return self.inference_with_config(text, config, ref_audio_path=speaker_wav, language=language, **kwargs)

def inference_with_config(self, text, config, ref_audio_path, language, **kwargs):
Expand Down Expand Up @@ -522,27 +491,6 @@ def full_inference(
gpt_cond_len: (int) Length of the audio used for cloning. If audio is shorter, then audio length is used
else the first `gpt_cond_len` secs is used. Defaults to 6 seconds.
decoder_iterations: (int) Number of diffusion steps to perform. [0,4000]. More steps means the network has
more chances to iteratively refine the output, which should theoretically mean a higher quality output.
Generally a value above 250 is not noticeably better, however. Defaults to 100.
cond_free: (bool) Whether or not to perform conditioning-free diffusion. Conditioning-free diffusion
performs two forward passes for each diffusion step: one with the outputs of the autoregressive model
and one with no conditioning priors. The output of the two is blended according to the cond_free_k
value below. Conditioning-free diffusion is the real deal, and dramatically improves realism.
Defaults to True.
cond_free_k: (float) Knob that determines how to balance the conditioning free signal with the
conditioning-present signal. [0,inf]. As cond_free_k increases, the output becomes dominated by the
conditioning-free signal. Defaults to 2.0.
diffusion_temperature: (float) Controls the variance of the noise fed into the diffusion model. [0,1].
Values at 0 re the "mean" prediction of the diffusion network and will sound bland and smeared.
Defaults to 1.0.
decoder: (str) Selects the decoder to use between ("hifigan", "diffusion")
Defaults to hifigan
hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive
transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation
here: https://huggingface.co/docs/transformers/internal/generation_utils
Expand All @@ -569,12 +517,6 @@ def full_inference(
top_k=top_k,
top_p=top_p,
do_sample=do_sample,
decoder_iterations=decoder_iterations,
cond_free=cond_free,
cond_free_k=cond_free_k,
diffusion_temperature=diffusion_temperature,
decoder_sampler=decoder_sampler,
decoder=decoder,
**hf_generate_kwargs,
)

Expand All @@ -592,13 +534,6 @@ def inference(
top_k=50,
top_p=0.85,
do_sample=True,
# Decoder inference
decoder_iterations=100,
cond_free=True,
cond_free_k=2,
diffusion_temperature=1.0,
decoder_sampler="ddim",
decoder="hifigan",
num_beams=1,
**hf_generate_kwargs,
):
Expand Down Expand Up @@ -693,8 +628,6 @@ def inference_stream(
top_k=50,
top_p=0.85,
do_sample=True,
# Decoder inference
decoder="hifigan",
**hf_generate_kwargs,
):
text = text.strip().lower()
Expand Down
Loading

0 comments on commit 5e992d8

Please sign in to comment.