Fix voice conversion inference #1583

Merged: 4 commits, May 20, 2022
TTS/bin/synthesize.py (2 changes: 1 addition & 1 deletion)
@@ -171,7 +171,7 @@ def main():
help="wav file(s) to condition a multi-speaker TTS model with a Speaker Encoder. You can give multiple file paths. The d_vectors is computed as their average.",
default=None,
)
parser.add_argument("--gst_style", help="Wav path file for GST stylereference.", default=None)
parser.add_argument("--gst_style", help="Wav path file for GST style reference.", default=None)
parser.add_argument(
"--list_speaker_idxs",
help="List available speaker ids for the defined multi-speaker model.",
TTS/tts/models/vits.py (4 changes: 2 additions & 2 deletions)
@@ -1127,7 +1127,7 @@ def inference_voice_conversion(
             self.config.audio.hop_length,
             self.config.audio.win_length,
             center=False,
-        ).transpose(1, 2)
+        )
         y_lengths = torch.tensor([y.size(-1)]).to(y.device)
         speaker_cond_src = reference_speaker_id if reference_speaker_id is not None else reference_d_vector
         speaker_cond_tgt = speaker_id if speaker_id is not None else d_vector
@@ -1157,7 +1157,7 @@ def voice_conversion(self, y, y_lengths, speaker_cond_src, speaker_cond_tgt):
         else:
             raise RuntimeError(" [!] Voice conversion is only supported on multi-speaker models.")

-        z, _, _, y_mask = self.posterior_encoder(y.transpose(1, 2), y_lengths, g=g_src)
+        z, _, _, y_mask = self.posterior_encoder(y, y_lengths, g=g_src)
         z_p = self.flow(z, y_mask, g=g_src)
         z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
         o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
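Read together, the two vits.py hunks point at a shape-convention bug: inference_voice_conversion transposed the spectrogram to [batch, time, channels] and voice_conversion transposed it back, but the y_lengths computed in between from y.size(-1) picked up the channel axis instead of the frame count. Dropping both transposes keeps the [batch, channels, time] layout the posterior encoder expects throughout. A minimal shape sketch, using assumed STFT parameters rather than any real model config:

import torch

# Assumed, illustrative STFT settings (not taken from a model config).
n_fft, hop_length, win_length = 1024, 256, 1024
wav = torch.randn(1, 32 * hop_length)  # [batch, samples], dummy audio

# Linear magnitude spectrogram, channel-first: [batch, n_fft // 2 + 1, frames].
spec = torch.stft(
    wav,
    n_fft,
    hop_length=hop_length,
    win_length=win_length,
    window=torch.hann_window(win_length),
    center=False,
    return_complex=True,
).abs()

assert spec.shape[1] == n_fft // 2 + 1  # 513 channels sit on axis 1 ...
y_lengths = torch.tensor([spec.size(-1)])  # ... so size(-1) is the frame count

# With the extra .transpose(1, 2) that this PR removes, size(-1) would have
# been the channel count (513) for every input, regardless of its length.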
TTS/utils/synthesizer.py (2 changes: 1 addition & 1 deletion)
@@ -315,7 +315,7 @@ def tts(
         # get the speaker embedding or speaker id for the reference wav file
         reference_speaker_embedding = None
         reference_speaker_id = None
-        if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "speaker_ids"):
+        if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "ids"):
             if reference_speaker_name and isinstance(reference_speaker_name, str):
                 if self.tts_config.use_d_vector_file:
                     # get the speaker embedding from the saved d_vectors.
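The synthesizer.py guard looks like a catch-up with a SpeakerManager attribute rename (apparently speaker_ids became ids): hasattr() on the stale name quietly returned False, so the reference-speaker branch was never entered. A toy illustration with a hypothetical stand-in class, not the real SpeakerManager:

class FakeSpeakerManager:  # hypothetical stand-in, for illustration only
    ids = {"speaker_0": 0}  # the mapping after the assumed rename

manager = FakeSpeakerManager()
assert not hasattr(manager, "speaker_ids")  # old guard: silently False
assert hasattr(manager, "ids")              # fixed guard: True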
tests/tts_tests/test_vits.py (2 changes: 1 addition & 1 deletion)
@@ -122,7 +122,7 @@ def test_voice_conversion(self):
         args = VitsArgs(num_speakers=num_speakers, use_speaker_embedding=True)
         model = Vits(args)

-        ref_inp = torch.randn(1, spec_len, 513)
+        ref_inp = torch.randn(1, 513, spec_len)
         ref_inp_len = torch.randint(1, spec_effective_len, (1,))
         ref_spk_id = torch.randint(1, num_speakers, (1,))
         tgt_spk_id = torch.randint(1, num_speakers, (1,))
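(513 here is presumably n_fft // 2 + 1 for a 1024-point FFT, i.e. the spectrogram channel count; the test input now puts channels before time, matching what voice_conversion receives after the fix.)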
tests/zoo_tests/test_models.py (15 changes: 14 additions & 1 deletion)
@@ -3,7 +3,7 @@
 import os
 import shutil

-from tests import get_tests_output_path, run_cli
+from tests import get_tests_data_path, get_tests_output_path, run_cli
 from TTS.tts.utils.languages import LanguageManager
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.utils.generic_utils import get_user_data_dir
@@ -56,3 +56,16 @@ def test_run_all_models():
     folders = glob.glob(os.path.join(manager.output_prefix, "*"))
     assert len(folders) == len(model_names)
     shutil.rmtree(manager.output_prefix)
+
+
+def test_voice_conversion():
+    print(" > Run voice conversion inference using YourTTS model.")
+    model_name = "tts_models/multilingual/multi-dataset/your_tts"
+    language_id = "en"
+    speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
+    reference_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0032.wav")
+    output_path = os.path.join(get_tests_output_path(), "output.wav")
+    run_cli(
+        f"tts --model_name {model_name}"
+        f" --out_path {output_path} --speaker_wav {speaker_wav} --reference_wav {reference_wav} --language_idx {language_id} "
+    )
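For reference, the command this test issues expands to roughly: tts --model_name tts_models/multilingual/multi-dataset/your_tts --out_path .../output.wav --speaker_wav .../LJ001-0001.wav --reference_wav .../LJ001-0032.wav --language_idx en. Judging by the --speaker_wav help string in the synthesize.py diff above, --speaker_wav conditions the target voice through the speaker encoder while --reference_wav supplies the utterance to convert, so the test exercises the full wav-to-wav path through the fixed inference_voice_conversion.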