Skip to content

Commit

Permalink
Update Bert-VITS2
Browse files (browse the repository at this point in the history)
  • Loading branch information
Artrajz committed Nov 6, 2023
1 parent 576a8c3 commit 940bfe6
Show file tree
Hide file tree
Showing 8 changed files with 833 additions and 376 deletions.
18 changes: 13 additions & 5 deletions bert_vits2/bert_vits2.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(self, model_path, config, device=torch.device("cpu"), **kwargs):

self.bert_model_names = {"zh": "CHINESE_ROBERTA_WWM_EXT_LARGE"}
self.ja_bert_dim = 1024
self.ja_extra_str = ""

if self.version in ["1.0", "1.0.0", "1.0.1"]:
self.symbols = symbols_legacy
Expand All @@ -38,14 +39,17 @@ def __init__(self, model_path, config, device=torch.device("cpu"), **kwargs):
self.lang = ["zh", "ja"]
self.bert_model_names["ja"] = "BERT_BASE_JAPANESE_V3"
self.ja_bert_dim = 768
self.ja_extra_str = "_v111"

elif self.version in ["1.1", "1.1.0", "1.1.1"]:
self.hps_ms.model.n_layers_trans_flow = 6
self.lang = ["zh", "ja"]
self.bert_model_names["ja"] = "BERT_BASE_JAPANESE_V3"
self.ja_bert_dim = 768
self.ja_extra_str = "_v111"

elif self.version in ["2.0", "2.0.0"]:
self.hps_ms.model.n_layers_trans_flow = 4
self.bert_model_names = {"zh": "CHINESE_ROBERTA_WWM_EXT_LARGE",
"ja": "DEBERTA_V2_LARGE_JAPANESE",
"en": "DEBERTA_V3_LARGE"}
Expand Down Expand Up @@ -74,7 +78,10 @@ def get_speakers(self):

def get_text(self, text, language_str, hps):
tokenizer, _ = self.bert_handler.get_bert_model(self.bert_model_names[language_str])
norm_text, phone, tone, word2ph = clean_text(text, language_str, tokenizer)
clean_bert_lang_str = language_str
if language_str == 'ja':
clean_bert_lang_str += self.ja_extra_str
norm_text, phone, tone, word2ph = clean_text(text, clean_bert_lang_str, tokenizer)

phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str, self._symbol_to_id)

Expand All @@ -86,7 +93,8 @@ def get_text(self, text, language_str, hps):
word2ph[i] = word2ph[i] * 2
word2ph[0] += 1

bert = self.bert_handler.get_bert_feature(norm_text, word2ph, language_str, self.bert_model_names[language_str])
bert = self.bert_handler.get_bert_feature(norm_text, word2ph, clean_bert_lang_str,
self.bert_model_names[language_str])
del word2ph
assert bert.shape[-1] == len(phone), phone

Expand All @@ -100,7 +108,7 @@ def get_text(self, text, language_str, hps):
en_bert = torch.zeros(1024, len(phone))
elif language_str == "en":
zh_bert = torch.zeros(1024, len(phone))
ja_bert = torch.zeros(1024, len(phone))
ja_bert = torch.zeros(self.ja_bert_dim, len(phone))
en_bert = bert
else:
zh_bert = torch.zeros(1024, len(phone))
Expand All @@ -113,7 +121,6 @@ def get_text(self, text, language_str, hps):
tone = torch.LongTensor(tone)
language = torch.LongTensor(language)
return zh_bert, ja_bert, en_bert, phone, tone, language


def infer(self, text, id, lang, sdp_ratio, noise, noisew, length, **kwargs):
zh_bert, ja_bert, en_bert, phones, tones, lang_ids = self.get_text(text, lang, self.hps_ms)
Expand All @@ -126,7 +133,8 @@ def infer(self, text, id, lang, sdp_ratio, noise, noisew, length, **kwargs):
en_bert = en_bert.to(self.device).unsqueeze(0)
x_tst_lengths = torch.LongTensor([phones.size(0)]).to(self.device)
speakers = torch.LongTensor([int(id)]).to(self.device)
audio = self.net_g.infer(x_tst, x_tst_lengths, speakers, tones, lang_ids, zh_bert, ja_bert,en_bert, sdp_ratio=sdp_ratio
audio = self.net_g.infer(x_tst, x_tst_lengths, speakers, tones, lang_ids, zh_bert, ja_bert, en_bert,
sdp_ratio=sdp_ratio
, noise_scale=noise, noise_scale_w=noisew, length_scale=length)[
0][0, 0].data.cpu().float().numpy()

Expand Down
3 changes: 2 additions & 1 deletion bert_vits2/text/bert_handler.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -10,6 +10,7 @@
from .chinese_bert import get_bert_feature as zh_bert
from .english_bert_mock import get_bert_feature as en_bert
from .japanese_bert import get_bert_feature as ja_bert
from .japanese_bert_v111 import get_bert_feature as ja_bert_v111


class BertHandler:
Expand All @@ -22,7 +23,7 @@ def __init__(self, device):
"DEBERTA_V2_LARGE_JAPANESE": os.path.join(config.ABS_PATH, "bert_vits2/bert/deberta-v2-large-japanese"),
"DEBERTA_V3_LARGE": os.path.join(config.ABS_PATH, "bert_vits2/bert/deberta-v3-large")
}
self.lang_bert_func_map = {"zh": zh_bert, "en": en_bert, "ja": ja_bert}
self.lang_bert_func_map = {"zh": zh_bert, "en": en_bert, "ja": ja_bert, "ja_v111": ja_bert_v111}

self.bert_models = {} # Value: (tokenizer, model, reference_count)
self.device = device
Expand Down
5 changes: 3 additions & 2 deletions bert_vits2/text/cleaner.py
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,10 @@
from bert_vits2.text import chinese, japanese, english, cleaned_text_to_sequence
from bert_vits2.text import chinese, japanese, english, cleaned_text_to_sequence, japanese_v111

language_module_map = {
'zh': chinese,
'ja': japanese,
'en': english
'en': english,
'ja_v111': japanese_v111
}


Expand Down
Loading

0 comments on commit 940bfe6

Please sign in to comment.