From 54592213f4047f5780e289766ea5527f50e4753d Mon Sep 17 00:00:00 2001 From: TITC <35098797+TITC@users.noreply.github.com> Date: Mon, 20 Nov 2023 18:57:05 +0800 Subject: [PATCH 1/2] load multilingual model by path --- TTS/api.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/TTS/api.py b/TTS/api.py index c8600dcd38..8d2731745d 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -72,7 +72,8 @@ def __init__( self.csapi = None self.cs_api_model = cs_api_model self.model_name = "" - + if model_path is not None and not model_name: + self.model_name = Path(model_path).name if gpu: warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.") From 5fbb6425e40c1ade607d4e503cfbd89f13c4d713 Mon Sep 17 00:00:00 2001 From: TITC <35098797+TITC@users.noreply.github.com> Date: Thu, 23 Nov 2023 15:45:19 +0800 Subject: [PATCH 2/2] use config to assert multi lingual or not --- TTS/api.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/TTS/api.py b/TTS/api.py index 8d2731745d..4797cc27f7 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -10,7 +10,7 @@ from TTS.utils.audio.numpy_transforms import save_wav from TTS.utils.manage import ModelManager from TTS.utils.synthesizer import Synthesizer - +from TTS.config import load_config class TTS(nn.Module): """TODO: Add voice conversion and Capacitron support.""" @@ -66,14 +66,12 @@ def __init__( """ super().__init__() self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar, verbose=False) - + self.config = load_config(config_path) if config_path else None self.synthesizer = None self.voice_converter = None self.csapi = None self.cs_api_model = cs_api_model self.model_name = "" - if model_path is not None and not model_name: - self.model_name = Path(model_path).name if gpu: warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.") @@ -107,7 +105,8 @@ def is_coqui_studio(self): @property def is_multi_lingual(self): # Not sure what sets this to None, but applied a fix to prevent crashing. - if isinstance(self.model_name, str) and "xtts" in self.model_name: + if (isinstance(self.model_name, str) and "xtts" in self.model_name or + self.config and ("xtts" in self.config.model or len(self.config.languages) > 1)): return True if hasattr(self.synthesizer.tts_model, "language_manager") and self.synthesizer.tts_model.language_manager: return self.synthesizer.tts_model.language_manager.num_languages > 1