Merge pull request #3103 from coqui-ai/fix_xttsv1.1_again
Second round of issue fixing for XTTS v1.1
erogol authored Oct 28, 2023
2 parents edd3a28 + 1c98821 commit 788959d
Showing 2 changed files with 40 additions and 61 deletions.
4 changes: 1 addition & 3 deletions TTS/tts/layers/xtts/tokenizer.py
@@ -483,20 +483,18 @@ def preprocess_text(self, txt, lang):
if lang == "zh-cn":
txt = chinese_transliterate(txt)
elif lang == "ja":
assert txt[:4] == "[ja]", "Japanese speech should start with the [ja] token."
txt = txt[4:]
if self.katsu is None:
import cutlet
self.katsu = cutlet.Cutlet()
txt = japanese_cleaners(txt, self.katsu)
txt = "[ja]" + txt
else:
raise NotImplementedError()
return txt

def encode(self, txt, lang):
if self.preprocess:
txt = self.preprocess_text(txt, lang)
txt = f"[{lang}]{txt}"
txt = txt.replace(" ", "[SPACE]")
return self.tokenizer.encode(txt).ids

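For orientation, the net effect of this tokenizer change: `encode()` now prepends the `[lang]` token itself (and Japanese text no longer needs a leading `[ja]`), so callers such as `inference` and `inference_stream` in xtts.py below pass plain text. A minimal sketch, assuming the `VoiceBpeTokenizer` class defined in this file accepts a local vocab file path; the constructor details and the path are assumptions, not part of this diff:

from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer

tok = VoiceBpeTokenizer(vocab_file="vocab.json")  # assumed local vocab path
# encode() now adds the language token and maps spaces to [SPACE] internally,
# so the caller no longer supplies a "[ja]"/"[en]" prefix.
ids_en = tok.encode("Hello world", lang="en")
ids_ja = tok.encode("こんにちは。", lang="ja")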
97 changes: 39 additions & 58 deletions TTS/tts/models/xtts.py
@@ -5,6 +5,7 @@
import torch
import torch.nn.functional as F
import torchaudio
import librosa
from coqpit import Coqpit

from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to_univnet_mel
@@ -21,34 +22,6 @@
init_stream_support()


def load_audio(audiopath, sr=22050):
"""
Load an audio file from disk and resample it to the specified sampling rate.
Args:
audiopath (str): Path to the audio file.
sr (int): Target sampling rate.
Returns:
Tensor: Audio waveform tensor with shape (1, T), where T is the number of samples.
"""
audio, sampling_rate = torchaudio.load(audiopath)

if len(audio.shape) > 1:
if audio.shape[0] < 5:
audio = audio[0]
else:
assert audio.shape[1] < 5
audio = audio[:, 0]

if sampling_rate != sr:
resampler = torchaudio.transforms.Resample(sampling_rate, sr)
audio = resampler(audio)

audio = audio.clamp_(-1, 1)
return audio.unsqueeze(0)


def wav_to_mel_cloning(
wav, mel_norms_file="../experiments/clips_mel_norms.pth", mel_norms=None, device=torch.device("cpu")
):
@@ -376,32 +349,29 @@ def device(self):
return next(self.parameters()).device

@torch.inference_mode()
def get_gpt_cond_latents(self, audio_path: str, length: int = 3):
def get_gpt_cond_latents(self, audio, sr, length: int = 3):
"""Compute the conditioning latents for the GPT model from the given audio.
Args:
audio_path (str): Path to the audio file.
length (int): Length of the audio in seconds. Defaults to 3.
"""

audio = load_audio(audio_path)
audio = audio[:, : 22050 * length]
mel = wav_to_mel_cloning(audio, mel_norms=self.mel_stats.cpu())
audio_22k = torchaudio.functional.resample(audio, sr, 22050)
audio_22k = audio_22k[:, : 22050 * length]
mel = wav_to_mel_cloning(audio_22k, mel_norms=self.mel_stats.cpu())
cond_latent = self.gpt.get_style_emb(mel.to(self.device))
return cond_latent.transpose(1, 2)

@torch.inference_mode()
def get_diffusion_cond_latents(
self,
audio_path,
):
def get_diffusion_cond_latents(self, audio, sr):
from math import ceil

diffusion_conds = []
CHUNK_SIZE = 102400
audio = load_audio(audio_path, 24000)
for chunk in range(ceil(audio.shape[1] / CHUNK_SIZE)):
current_sample = audio[:, chunk * CHUNK_SIZE : (chunk + 1) * CHUNK_SIZE]
audio_24k = torchaudio.functional.resample(audio, sr, 24000)
for chunk in range(ceil(audio_24k.shape[1] / CHUNK_SIZE)):
current_sample = audio_24k[:, chunk * CHUNK_SIZE : (chunk + 1) * CHUNK_SIZE]
current_sample = pad_or_truncate(current_sample, CHUNK_SIZE)
cond_mel = wav_to_univnet_mel(
current_sample.to(self.device),
@@ -414,27 +384,38 @@ def get_diffusion_cond_latents(self, audio, sr):
return diffusion_latent

@torch.inference_mode()
def get_speaker_embedding(self, audio_path):
audio = load_audio(audio_path, self.hifigan_decoder.speaker_encoder_audio_config["sample_rate"])
speaker_embedding = (
self.hifigan_decoder.speaker_encoder.forward(audio.to(self.device), l2_norm=True)
.unsqueeze(-1)
.to(self.device)
)
return speaker_embedding

def get_speaker_embedding(self, audio, sr):
audio_16k = torchaudio.functional.resample(audio, sr, 16000)
return self.hifigan_decoder.speaker_encoder.forward(
audio_16k.to(self.device), l2_norm=True
).unsqueeze(-1).to(self.device)

@torch.inference_mode()
def get_conditioning_latents(
self,
audio_path,
gpt_cond_len=3,
):
gpt_cond_len=6,
max_ref_length=10,
librosa_trim_db=None,
sound_norm_refs=False,
):
speaker_embedding = None
diffusion_cond_latents = None
if self.args.use_hifigan:
speaker_embedding = self.get_speaker_embedding(audio_path)

audio, sr = torchaudio.load(audio_path)
audio = audio[:, : sr * max_ref_length].to(self.device)
if audio.shape[0] > 1:
audio = audio.mean(0, keepdim=True)
if sound_norm_refs:
audio = (audio / torch.abs(audio).max()) * 0.75
if librosa_trim_db is not None:
audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]

if self.args.use_hifigan or self.args.use_ne_hifigan:
speaker_embedding = self.get_speaker_embedding(audio, sr)
else:
diffusion_cond_latents = self.get_diffusion_cond_latents(audio_path)
gpt_cond_latents = self.get_gpt_cond_latents(audio_path, length=gpt_cond_len) # [1, 1024, T]
diffusion_cond_latents = self.get_diffusion_cond_latents(audio, sr)
gpt_cond_latents = self.get_gpt_cond_latents(audio, sr, length=gpt_cond_len) # [1, 1024, T]
return gpt_cond_latents, diffusion_cond_latents, speaker_embedding

def synthesize(self, text, config, speaker_wav, language, **kwargs):
@@ -494,7 +475,7 @@ def full_inference(
repetition_penalty=2.0,
top_k=50,
top_p=0.85,
gpt_cond_len=4,
gpt_cond_len=6,
do_sample=True,
# Decoder inference
decoder_iterations=100,
@@ -531,7 +512,7 @@ def full_inference(
(aka boring) outputs. Defaults to 0.8.
gpt_cond_len: (int) Length of the audio used for cloning. If audio is shorter, then audio length is used
else the first `gpt_cond_len` secs is used. Defaults to 3 seconds.
else the first `gpt_cond_len` secs is used. Defaults to 6 seconds.
decoder_iterations: (int) Number of diffusion steps to perform. [0,4000]. More steps means the network has
more chances to iteratively refine the output, which should theoretically mean a higher quality output.
@@ -610,7 +591,7 @@ def inference(
decoder="hifigan",
**hf_generate_kwargs,
):
text = f"[{language}]{text.strip().lower()}"
text = text.strip().lower()
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)

assert (
@@ -722,7 +703,7 @@ def inference_stream(
assert hasattr(
self, "hifigan_decoder"
), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream."
text = f"[{language}]{text.strip().lower()}"
text = text.strip().lower()
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)

fake_inputs = self.gpt.compute_embeddings(
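Taken together, the xtts.py changes drop the old `load_audio` helper: `get_conditioning_latents` now loads the reference clip once with `torchaudio.load`, truncates it to `max_ref_length` seconds, downmixes to mono, optionally peak-normalizes, and passes the waveform plus its sample rate to the latent extractors, each of which resamples for itself (22.05 kHz for the GPT mel, 24 kHz for the diffusion conditioner, 16 kHz for the speaker encoder). A minimal usage sketch following the usual XTTS loading pattern; the config, checkpoint directory, and reference wav paths are placeholders:

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

config = XttsConfig()
config.load_json("config.json")  # placeholder path
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="./xtts_v1.1/", eval=True)  # placeholder dir

# Only the path is passed; loading, trimming, and per-consumer resampling now
# happen inside get_conditioning_latents. With a HiFi-GAN decoder enabled
# (model_args.use_hifigan / use_ne_hifigan), diffusion_latents stays None;
# otherwise speaker_emb does.
gpt_latents, diffusion_latents, speaker_emb = model.get_conditioning_latents(
    audio_path="reference.wav",  # placeholder reference clip
    gpt_cond_len=6,              # first 6 s feed the GPT conditioning latent (new default)
    max_ref_length=10,           # reference clip truncated to 10 s
    sound_norm_refs=False,       # set True to peak-normalize the reference to 0.75
)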
