From ee99a6c1e270a78d6fb755f2478ede73399b4fd6 Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Fri, 20 May 2022 08:16:01 -0300
Subject: [PATCH] Fix voice conversion inference (#1583)

* Add voice conversion zoo test

* Fix style

* Fix unit test
---
 TTS/bin/synthesize.py          |  2 +-
 TTS/tts/models/vits.py         |  4 ++--
 TTS/utils/synthesizer.py       |  2 +-
 tests/tts_tests/test_vits.py   |  2 +-
 tests/zoo_tests/test_models.py | 15 ++++++++++++++-
 5 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py
index 6247b2a467..dc6e30b404 100755
--- a/TTS/bin/synthesize.py
+++ b/TTS/bin/synthesize.py
@@ -171,7 +171,7 @@ def main():
         help="wav file(s) to condition a multi-speaker TTS model with a Speaker Encoder. You can give multiple file paths. The d_vectors is computed as their average.",
         default=None,
     )
-    parser.add_argument("--gst_style", help="Wav path file for GST stylereference.", default=None)
+    parser.add_argument("--gst_style", help="Wav path file for GST style reference.", default=None)
     parser.add_argument(
         "--list_speaker_idxs",
         help="List available speaker ids for the defined multi-speaker model.",
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 2c1c2bc67b..a6b1c74332 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -1127,7 +1127,7 @@ def inference_voice_conversion(
             self.config.audio.hop_length,
             self.config.audio.win_length,
             center=False,
-        ).transpose(1, 2)
+        )
         y_lengths = torch.tensor([y.size(-1)]).to(y.device)
         speaker_cond_src = reference_speaker_id if reference_speaker_id is not None else reference_d_vector
         speaker_cond_tgt = speaker_id if speaker_id is not None else d_vector
@@ -1157,7 +1157,7 @@ def voice_conversion(self, y, y_lengths, speaker_cond_src, speaker_cond_tgt):
         else:
             raise RuntimeError(" [!] Voice conversion is only supported on multi-speaker models.")
 
-        z, _, _, y_mask = self.posterior_encoder(y.transpose(1, 2), y_lengths, g=g_src)
+        z, _, _, y_mask = self.posterior_encoder(y, y_lengths, g=g_src)
         z_p = self.flow(z, y_mask, g=g_src)
         z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
         o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py
index 1f33b53e77..2c28861324 100644
--- a/TTS/utils/synthesizer.py
+++ b/TTS/utils/synthesizer.py
@@ -315,7 +315,7 @@ def tts(
         # get the speaker embedding or speaker id for the reference wav file
         reference_speaker_embedding = None
         reference_speaker_id = None
-        if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "speaker_ids"):
+        if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "ids"):
             if reference_speaker_name and isinstance(reference_speaker_name, str):
                 if self.tts_config.use_d_vector_file:
                     # get the speaker embedding from the saved d_vectors.
diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py
index 5694fe4dd8..b9cebb5a65 100644
--- a/tests/tts_tests/test_vits.py
+++ b/tests/tts_tests/test_vits.py
@@ -122,7 +122,7 @@ def test_voice_conversion(self):
         args = VitsArgs(num_speakers=num_speakers, use_speaker_embedding=True)
         model = Vits(args)
 
-        ref_inp = torch.randn(1, spec_len, 513)
+        ref_inp = torch.randn(1, 513, spec_len)
         ref_inp_len = torch.randint(1, spec_effective_len, (1,))
         ref_spk_id = torch.randint(1, num_speakers, (1,))
         tgt_spk_id = torch.randint(1, num_speakers, (1,))
diff --git a/tests/zoo_tests/test_models.py b/tests/zoo_tests/test_models.py
index e614ce7491..8c32895f33 100644
--- a/tests/zoo_tests/test_models.py
+++ b/tests/zoo_tests/test_models.py
@@ -3,7 +3,7 @@
 import os
 import shutil
 
-from tests import get_tests_output_path, run_cli
+from tests import get_tests_data_path, get_tests_output_path, run_cli
 from TTS.tts.utils.languages import LanguageManager
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.utils.generic_utils import get_user_data_dir
@@ -56,3 +56,16 @@ def test_run_all_models():
     folders = glob.glob(os.path.join(manager.output_prefix, "*"))
     assert len(folders) == len(model_names)
     shutil.rmtree(manager.output_prefix)
+
+
+def test_voice_conversion():
+    print(" > Run voice conversion inference using YourTTS model.")
+    model_name = "tts_models/multilingual/multi-dataset/your_tts"
+    language_id = "en"
+    speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
+    reference_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0032.wav")
+    output_path = os.path.join(get_tests_output_path(), "output.wav")
+    run_cli(
+        f"tts --model_name {model_name}"
+        f" --out_path {output_path} --speaker_wav {speaker_wav} --reference_wav {reference_wav} --language_idx {language_id} "
+    )
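
After this change, Vits.voice_conversion takes the reference spectrogram in
[batch, freq_bins, time] order, with no transpose on either side of the call,
as the updated unit test reflects. Below is a minimal sketch of driving the
method directly, adapted from that test; the num_speakers, spec_len, and
spec_effective_len values are illustrative (the test defines its own outside
the hunk shown above), and the return value is left unpacked since its exact
structure is not visible in this diff:

    import torch

    from TTS.tts.models.vits import Vits, VitsArgs

    num_speakers = 10         # illustrative speaker count
    spec_len = 101            # illustrative number of spectrogram frames
    spec_effective_len = 50   # illustrative upper bound on the unpadded length

    args = VitsArgs(num_speakers=num_speakers, use_speaker_embedding=True)
    model = Vits(args)

    # Reference spectrogram in the fixed [batch, freq_bins, time] layout;
    # 513 bins corresponds to the default fft_size of 1024 (1024 // 2 + 1).
    ref_inp = torch.randn(1, 513, spec_len)
    ref_inp_len = torch.randint(1, spec_effective_len, (1,))
    ref_spk_id = torch.randint(1, num_speakers, (1,))  # speaker of the reference wav
    tgt_spk_id = torch.randint(1, num_speakers, (1,))  # speaker to convert into
    outputs = model.voice_conversion(ref_inp, ref_inp_len, ref_spk_id, tgt_spk_id)

End to end, the new zoo test exercises the same path through the CLI along
these lines (paths abbreviated; the test resolves them under its data and
output directories):

    tts --model_name tts_models/multilingual/multi-dataset/your_tts \
        --out_path output.wav --speaker_wav LJ001-0001.wav \
        --reference_wav LJ001-0032.wav --language_idx en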