From 04901fb2e4b74953bb5733205adcb7cb13655a06 Mon Sep 17 00:00:00 2001
From: Julian Weber
Date: Tue, 14 Nov 2023 16:07:17 +0100
Subject: [PATCH] Add speed control for inference (#3214)

* Add speed control for inference

* Fix XTTS tests

* Add speed control tests
---
 TTS/tts/models/xtts.py         | 17 +++++++++++++++++
 tests/zoo_tests/test_models.py | 34 +++++++++++++++++++++++++++++++---
 2 files changed, 48 insertions(+), 3 deletions(-)

diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index b277c3ac72..9198591273 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -530,8 +530,10 @@ def inference(
         top_p=0.85,
         do_sample=True,
         num_beams=1,
+        speed=1.0,
         **hf_generate_kwargs,
     ):
+        length_scale = 1.0 / max(speed, 0.05)
         text = text.strip().lower()
         text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
 
@@ -584,6 +586,13 @@ def inference(
                     gpt_latents = gpt_latents[:, :k]
                     break
 
+            if length_scale != 1.0:
+                gpt_latents = F.interpolate(
+                    gpt_latents.transpose(1, 2),
+                    scale_factor=length_scale,
+                    mode="linear"
+                ).transpose(1, 2)
+
             wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
 
         return {
@@ -634,8 +643,10 @@ def inference_stream(
         top_k=50,
         top_p=0.85,
         do_sample=True,
+        speed=1.0,
         **hf_generate_kwargs,
     ):
+        length_scale = 1.0 / max(speed, 0.05)
         text = text.strip().lower()
         text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
 
@@ -674,6 +685,12 @@ def inference_stream(
                 if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
                     gpt_latents = torch.cat(all_latents, dim=0)[None, :]
+                    if length_scale != 1.0:
+                        gpt_latents = F.interpolate(
+                            gpt_latents.transpose(1, 2),
+                            scale_factor=length_scale,
+                            mode="linear"
+                        ).transpose(1, 2)
                     wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
                     wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
                         wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
                     )
diff --git a/tests/zoo_tests/test_models.py b/tests/zoo_tests/test_models.py
index d1c6b67c39..a5aad5c1ea 100644
--- a/tests/zoo_tests/test_models.py
+++ b/tests/zoo_tests/test_models.py
@@ -111,7 +111,7 @@ def test_xtts_streaming():
     model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
 
     print("Computing speaker latents...")
-    gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
+    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
 
     print("Inference...")
     chunks = model.inference_stream(
@@ -139,7 +139,7 @@ def test_xtts_v2():
             "yes | "
             f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 "
             f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True '
-            f'--speaker_wav "{speaker_wav}" "{speaker_wav_2}" "--language_idx "en"'
+            f'--speaker_wav "{speaker_wav}" "{speaker_wav_2}" --language_idx "en"'
         )
     else:
         run_cli(
@@ -164,7 +164,7 @@ def test_xtts_v2_streaming():
     model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
 
     print("Computing speaker latents...")
-    gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
+    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
 
     print("Inference...")
     chunks = model.inference_stream(
@@ -179,6 +179,34 @@ def test_xtts_v2_streaming():
         assert chunk.shape[-1] > 5000
         wav_chuncks.append(chunk)
     assert len(wav_chuncks) > 1
+    normal_len = sum([len(chunk) for chunk in wav_chuncks])
+
+    chunks = model.inference_stream(
+        "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
+        "en",
+        gpt_cond_latent,
+        speaker_embedding,
+        speed=1.5
+    )
+    wav_chuncks = []
+    for i, chunk in enumerate(chunks):
+        wav_chuncks.append(chunk)
+    fast_len = sum([len(chunk) for chunk in wav_chuncks])
+
+    chunks = model.inference_stream(
+        "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
+        "en",
+        gpt_cond_latent,
+        speaker_embedding,
+        speed=0.66
+    )
+    wav_chuncks = []
+    for i, chunk in enumerate(chunks):
+        wav_chuncks.append(chunk)
+    slow_len = sum([len(chunk) for chunk in wav_chuncks])
+
+    assert slow_len > normal_len
+    assert normal_len > fast_len
 
 
 def test_tortoise():
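
How the patch works: length_scale = 1.0 / max(speed, 0.05) clamps the
requested speed away from zero, and the GPT latent sequence is linearly
interpolated along its time axis by that factor before the HiFi-GAN decoder
runs. speed > 1.0 therefore yields a shorter latent sequence (faster speech)
and speed < 1.0 a longer one (slower speech), without re-running the GPT;
the same block is applied per chunk in inference_stream. Below is a minimal
usage sketch of the new argument, assuming the standard XTTS loading flow;
the config/checkpoint paths and the speaker reference file are placeholders,
not part of this patch:

import torch

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

# Placeholder paths -- substitute a real XTTS v2 checkpoint directory.
config = XttsConfig()
config.load_json("/path/to/xtts/config.json")
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/")
model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

# Two-value return matches the test fix in this patch.
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path="speaker.wav")

# speed > 1.0 -> shorter interpolated latents -> faster speech;
# speed < 1.0 -> longer interpolated latents -> slower speech.
out = model.inference(
    "This is an example.",
    "en",
    gpt_cond_latent,
    speaker_embedding,
    speed=1.5,
)
wav = out["wav"]  # synthesized waveform array

Because duration scales with the latent length, the streaming test above can
validate the feature by comparing total waveform lengths: the speed=0.66 run
must be longer than the default, and the speed=1.5 run shorter.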