Add support for multiple speaker references on XTTS inference
Edresson committed Nov 6, 2023
1 parent 81416f0 commit 72397ef
Showing 1 changed file with 42 additions and 22 deletions.
64 changes: 42 additions & 22 deletions TTS/tts/models/xtts.py
@@ -442,24 +442,49 @@ def get_conditioning_latents(
         librosa_trim_db=None,
         sound_norm_refs=False,
     ):
-        speaker_embedding = None
-        diffusion_cond_latents = None
-
-        audio, sr = torchaudio.load(audio_path)
-        audio = audio[:, : sr * max_ref_length].to(self.device)
-        if audio.shape[0] > 1:
-            audio = audio.mean(0, keepdim=True)
-        if sound_norm_refs:
-            audio = (audio / torch.abs(audio).max()) * 0.75
-        if librosa_trim_db is not None:
-            audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]
-
-        if self.args.use_hifigan or self.args.use_ne_hifigan:
-            speaker_embedding = self.get_speaker_embedding(audio, sr)
-        else:
-            diffusion_cond_latents = self.get_diffusion_cond_latents(audio, sr)
-        gpt_cond_latents = self.get_gpt_cond_latents(audio, sr, length=gpt_cond_len)  # [1, 1024, T]
-        return gpt_cond_latents, diffusion_cond_latents, speaker_embedding
+        # deal with multiple references
+        if not isinstance(audio_path, list):
+            audio_paths = [audio_path]
+        else:
+            audio_paths = audio_path
+
+        speaker_embeddings = []
+        diffusion_cond_latents = []
+        audios = []
+        speaker_embedding = None
+        diffusion_cond_latent = None
+        for file_path in audio_paths:
+            audio, sr = torchaudio.load(file_path)
+            audio = audio[:, : sr * max_ref_length].to(self.device)
+            if audio.shape[0] > 1:
+                audio = audio.mean(0, keepdim=True)
+            if sound_norm_refs:
+                audio = (audio / torch.abs(audio).max()) * 0.75
+            if librosa_trim_db is not None:
+                audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]
+
+            if self.args.use_hifigan or self.args.use_ne_hifigan:
+                speaker_embedding = self.get_speaker_embedding(audio, sr)
+                speaker_embeddings.append(speaker_embedding)
+            else:
+                diffusion_cond_latent = self.get_diffusion_cond_latents(audio, sr)
+                diffusion_cond_latents.append(diffusion_cond_latent)
+
+            audios.append(audio)
+
+        # use a merge of all references for gpt cond latents
+        full_audio = torch.cat(audios, dim=-1)
+        gpt_cond_latents = self.get_gpt_cond_latents(full_audio, sr, length=gpt_cond_len)  # [1, 1024, T]
+
+        if diffusion_cond_latents:
+            diffusion_cond_latent = torch.stack(diffusion_cond_latents)
+            diffusion_cond_latent = diffusion_cond_latent.mean(dim=0)
+
+        if speaker_embeddings:
+            speaker_embedding = torch.stack(speaker_embeddings)
+            speaker_embedding = speaker_embedding.mean(dim=0)
+
+        return gpt_cond_latents, diffusion_cond_latent, speaker_embedding

     def synthesize(self, text, config, speaker_wav, language, **kwargs):
         """Synthesize speech with the given input text.
@@ -477,11 +502,6 @@ def synthesize(self, text, config, speaker_wav, language, **kwargs):
         as latents used at inference.
         """

-        # Make the synthesizer happy 🥳
-        if isinstance(speaker_wav, list):
-            speaker_wav = speaker_wav[0]
-
         return self.inference_with_config(text, config, ref_audio_path=speaker_wav, language=language, **kwargs)

     def inference_with_config(self, text, config, ref_audio_path, language, **kwargs):
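With this change, speaker_wav can be a list of reference files end to end, so the old single-file workaround is removed. A hedged usage sketch (paths and config values are illustrative; the config/checkpoint calls follow the public XTTS API, but verify them against your installed version):

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

config = XttsConfig()
config.load_json("/path/to/config.json")  # illustrative path
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/checkpoint/", eval=True)

# Pass several reference wavs; conditioning latents are built from all of them
outputs = model.synthesize(
    "Hello world.",
    config,
    speaker_wav=["ref_1.wav", "ref_2.wav"],
    language="en",
)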
