Skip to content

Commit

Permalink
Update: Support for limiting language range in Bert-VITS2
Browse files Browse the repository at this point in the history
  • Loading branch information
Artrajz committed Jul 24, 2024
1 parent eba46df commit d4de56c
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 26 deletions.
15 changes: 11 additions & 4 deletions bert_vits2/bert_vits2.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def load_model(self, model_handler):
def get_speakers(self):
    """Return the ``speakers`` attribute held by this instance."""
    return self.speakers

def get_text(self, text, language_str, hps, style_text=None, style_weight=0.7):
def get_text(self, text, language_str: str, hps, style_text=None, style_weight=0.7):
clean_text_lang_str = language_str + self.text_extra_str_map.get(language_str, "")
bert_feature_lang_str = language_str + self.bert_extra_str_map.get(language_str, "")

Expand Down Expand Up @@ -337,8 +337,9 @@ def _infer(self, id, phones, tones, lang_ids, zh_bert, ja_bert, en_bert, sdp_rat
torch.cuda.empty_cache()
return audio

def infer(self, text, id, lang, sdp_ratio, noise, noisew, length, reference_audio=None, emotion=None,
def infer(self, text, id, lang: list, sdp_ratio, noise, noisew, length, reference_audio=None, emotion=None,
text_prompt=None, style_text=None, style_weigth=0.7, **kwargs):
lang = lang[0]
zh_bert, ja_bert, en_bert, phones, tones, lang_ids = self.get_text(text, lang, self.hps_ms, style_text,
style_weigth)

Expand All @@ -351,9 +352,15 @@ def infer(self, text, id, lang, sdp_ratio, noise, noisew, length, reference_audi
return self._infer(id, phones, tones, lang_ids, zh_bert, ja_bert, en_bert, sdp_ratio, noise, noisew, length,
emo)

def infer_multilang(self, text, id, lang, sdp_ratio, noise, noisew, length, reference_audio=None, emotion=None,
def infer_multilang(self, text, id, lang: list, sdp_ratio, noise, noisew, length, reference_audio=None,
emotion=None,
text_prompt=None, style_text=None, style_weigth=0.7, **kwargs):
sentences_list = split_languages(text, self.lang, expand_abbreviations=True, expand_hyphens=True)
target_languages = lang
if len(lang) == 1 and lang[0] == "auto":
target_languages = self.lang

sentences_list = split_languages(text, target_languages=target_languages, expand_abbreviations=True,
expand_hyphens=True)

emo = None
if self.hps_ms.model.emotion_embedding == 1:
Expand Down
10 changes: 5 additions & 5 deletions manager/TTSManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,14 +392,14 @@ def bert_vits2_infer(self, state, encode=True):

if model.zh_bert_extra:
infer_func = model.infer
state["lang"] = "zh"
state["lang"] = ["zh"]
elif model.ja_bert_extra:
infer_func = model.infer
state["lang"] = "ja"
elif state["lang"].lower() == "auto":
infer_func = model.infer_multilang
else:
state["lang"] = ["ja"]
elif len(state["lang"]) == 1 and "auto" not in state["lang"]:
infer_func = model.infer
else:
infer_func = model.infer_multilang

audios = []
for sentences in sentences_list:
Expand Down
50 changes: 33 additions & 17 deletions tts_app/voice_api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,32 @@ def update_default_params(state):
return state


def get_lang_list(lang, speaker_lang):
    """Parse a language specification string into a list of language codes.

    ``lang`` is split on runs of commas and whitespace. Empty tokens (produced
    by leading/trailing separators) are dropped and the remaining tokens are
    lower-cased.

    Args:
        lang: Raw language string, e.g. ``"zh, ja"`` or ``"auto"``.
        speaker_lang: Languages supported by the selected speaker. Tokens are
            only validated against it when it contains more than one entry;
            ``"auto"`` and ``"mix"`` are always accepted.

    Returns:
        A ``(lang_list, status, msg)`` tuple. On failure ``status`` is
        ``"error"``, ``msg`` describes the problem and ``lang_list`` holds the
        tokens accepted up to that point. On success ``status`` and ``msg``
        are both ``""``.
    """
    lang_list = []

    for token in re.split(r'[,,\s]+', lang):
        # re.split yields "" when the string starts or ends with a separator.
        if not token:
            continue
        code = token.lower()
        if code not in ["auto", "mix"] and len(speaker_lang) > 1 and code not in speaker_lang:
            # NOTE(review): the log prefix is always BERT_VITS2, whichever
            # endpoint called this helper - confirm whether that is intended.
            logger.info(f"[{ModelType.BERT_VITS2.value}] lang \"{code}\" is not in {speaker_lang}")
            return lang_list, "error", f"lang '{code}' is not in {speaker_lang}"
        lang_list.append(code)

    # Check the cleaned list rather than the raw split result, so that stray
    # separators ("auto," or " auto ") are not mistaken for extra languages.
    if "auto" in lang_list and len(lang_list) > 1:
        return lang_list, "error", "Do not pass 'auto' along with other languages."

    return lang_list, "", ""


@voice_api.route('/default_parameter', methods=["GET", "POST"])
def default_parameter():
gpt_sovits_config = copy.deepcopy(config.gpt_sovits_config.asdict())
Expand Down Expand Up @@ -496,10 +522,9 @@ def voice_bert_vits2_api():

# 校验模型是否支持输入的语言
speaker_lang = model_manager.voice_speakers[ModelType.BERT_VITS2.value][id].get('lang')
if lang not in ["auto", "mix"] and len(speaker_lang) > 1 and lang not in speaker_lang:
logger.info(f"[{ModelType.BERT_VITS2.value}] lang \"{lang}\" is not in {speaker_lang}")
return make_response(jsonify({"status": "error", "message": f"lang '{lang}' is not in {speaker_lang}"}),
400)
lang_list, status, msg = get_lang_list(lang, speaker_lang)
if status == "error":
return make_response(jsonify({"status": status, "message": msg}), 400)

# 如果配置文件中设置了LANGUAGE_AUTOMATIC_DETECT则强制将speaker_lang设置为LANGUAGE_AUTOMATIC_DETECT
if (lang_detect := config.language_identification.language_automatic_detect) and isinstance(lang_detect, list):
Expand All @@ -522,7 +547,7 @@ def voice_bert_vits2_api():
"noisew": noisew,
"sdp_ratio": sdp_ratio,
"segment_size": segment_size,
"lang": lang,
"lang": lang_list,
"speaker_lang": speaker_lang,
"emotion": emotion,
"reference_audio": reference_audio,
Expand Down Expand Up @@ -606,20 +631,11 @@ def voice_gpt_sovits_api():
logger.info(f"[{ModelType.GPT_SOVITS.value}] speaker id {id} does not exist")
return make_response(jsonify({"status": "error", "message": f"id {id} does not exist"}), 400)

lang_list = re.split(r'[,,\s]+', lang)
# 校验模型是否支持输入的语言
speaker_lang = model_manager.voice_speakers[ModelType.GPT_SOVITS.value][id].get('lang')
for idx in range(len(lang_list)):
lang_list[idx] = lang_list[idx].lower()
lang = lang_list[idx]
if lang not in ["auto", "mix"] and len(speaker_lang) > 1 and lang not in speaker_lang:
logger.info(f"[{ModelType.GPT_SOVITS.value}] lang \"{lang}\" is not in {speaker_lang}")
return make_response(jsonify({"status": "error", "message": f"lang '{lang}' is not in {speaker_lang}"}),
400)

if "auto" in lang_list and len(lang_list) > 1:
return make_response(jsonify({"status": "error", "message": f"lang '{lang}' is not in {speaker_lang}"}),
400)
lang_list, status, msg = get_lang_list(lang, speaker_lang)
if status == "error":
return make_response(jsonify({"status": status, "message": msg}), 400)

# 如果配置文件中设置了LANGUAGE_AUTOMATIC_DETECT则强制将speaker_lang设置为LANGUAGE_AUTOMATIC_DETECT
if (lang_detect := config.language_identification.language_automatic_detect) and isinstance(lang_detect, list):
Expand Down

0 comments on commit d4de56c

Please sign in to comment.