Skip to content

Commit

Permalink
en_bert
Browse files: browse the repository at this point in the history
  • Loading branch information
Artrajz committed Oct 28, 2023
1 parent a15b20b commit a09cc56
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 21 deletions.
16 changes: 9 additions & 7 deletions bert_vits2/bert_vits2.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,25 @@ def __init__(self, model_path, config, device=torch.device("cpu"), **kwargs):
self.lang = ["zh", "ja", "en"]

self.bert_model_names = {"zh": "CHINESE_ROBERTA_WWM_EXT_LARGE"}
self.ja_bert_embedding_dim = 1024
self.ja_bert_dim = 1024

if self.version in ["1.0", "1.0.0", "1.0.1"]:
self.symbols = symbols_legacy
self.hps_ms.model.n_layers_trans_flow = 3
self.lang = ["zh"]
self.ja_bert_embedding_dim = 768
self.ja_bert_dim = 768

elif self.version in ["1.1.0-transition"]:
self.hps_ms.model.n_layers_trans_flow = 3
self.lang = ["zh", "ja"]
self.bert_model_names["ja"] = "BERT_BASE_JAPANESE_V3"
self.ja_bert_embedding_dim = 768
self.ja_bert_dim = 768

elif self.version in ["1.1", "1.1.0", "1.1.1"]:
self.hps_ms.model.n_layers_trans_flow = 6
self.lang = ["zh", "ja"]
self.bert_model_names["ja"] = "BERT_BASE_JAPANESE_V3"
self.ja_bert_embedding_dim = 768
self.ja_bert_dim = 768

elif self.version in ["2.0", "2.0.0"]:
self.bert_model_names = {"zh": "CHINESE_ROBERTA_WWM_EXT_LARGE",
Expand All @@ -64,6 +64,7 @@ def load_model(self, bert_handler):
self.hps_ms.train.segment_size // self.hps_ms.data.hop_length,
n_speakers=self.hps_ms.data.n_speakers,
symbols=self.symbols,
ja_bert_dim=self.ja_bert_dim,
**self.hps_ms.model).to(self.device)
_ = self.net_g.eval()
bert_vits2_utils.load_checkpoint(self.model_path, self.net_g, None, skip_optimizer=True, version=self.version)
Expand Down Expand Up @@ -91,7 +92,7 @@ def get_text(self, text, language_str, hps):

if language_str == "zh":
zh_bert = bert
ja_bert = torch.zeros(self.ja_bert_embedding_dim, len(phone))
ja_bert = torch.zeros(self.ja_bert_dim, len(phone))
en_bert = torch.zeros(1024, len(phone))
elif language_str == "ja":
zh_bert = torch.zeros(1024, len(phone))
Expand All @@ -103,7 +104,7 @@ def get_text(self, text, language_str, hps):
en_bert = bert
else:
zh_bert = torch.zeros(1024, len(phone))
ja_bert = torch.zeros(self.ja_bert_embedding_dim, len(phone))
ja_bert = torch.zeros(self.ja_bert_dim, len(phone))
en_bert = torch.zeros(1024, len(phone))
assert bert.shape[-1] == len(
phone
Expand All @@ -122,9 +123,10 @@ def infer(self, text, id, lang, sdp_ratio, noise, noisew, length, **kwargs):
lang_ids = lang_ids.to(self.device).unsqueeze(0)
zh_bert = zh_bert.to(self.device).unsqueeze(0)
ja_bert = ja_bert.to(self.device).unsqueeze(0)
en_bert = en_bert.to(self.device).unsqueeze(0)
x_tst_lengths = torch.LongTensor([phones.size(0)]).to(self.device)
speakers = torch.LongTensor([int(id)]).to(self.device)
audio = self.net_g.infer(x_tst, x_tst_lengths, speakers, tones, lang_ids, zh_bert, ja_bert, sdp_ratio=sdp_ratio
audio = self.net_g.infer(x_tst, x_tst_lengths, speakers, tones, lang_ids, zh_bert, ja_bert,en_bert, sdp_ratio=sdp_ratio
, noise_scale=noise, noise_scale_w=noisew, length_scale=length)[
0][0, 0].data.cpu().float().numpy()

Expand Down
22 changes: 14 additions & 8 deletions bert_vits2/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ def __init__(self,
kernel_size,
p_dropout,
gin_channels=0,
symbols=None):
symbols=None,
ja_bert_dim=1024):
super().__init__()
self.n_vocab = n_vocab
self.out_channels = out_channels
Expand All @@ -275,7 +276,8 @@ def __init__(self,
self.language_emb = nn.Embedding(num_languages, hidden_channels)
nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels ** -0.5)
self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
self.ja_bert_proj = nn.Conv1d(768, hidden_channels, 1)
self.ja_bert_proj = nn.Conv1d(ja_bert_dim, hidden_channels, 1)
self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)

self.encoder = attentions.Encoder(
hidden_channels,
Expand All @@ -287,10 +289,12 @@ def __init__(self,
gin_channels=self.gin_channels)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

def forward(self, x, x_lengths, tone, language, bert, ja_bert, g=None):
bert_emb = self.bert_proj(bert).transpose(1, 2)
def forward(self, x, x_lengths, tone, language, zh_bert, ja_bert, en_bert, g=None):
zh_bert_emb = self.bert_proj(zh_bert).transpose(1, 2)
ja_bert_emb = self.ja_bert_proj(ja_bert).transpose(1, 2)
x = (self.emb(x) + self.tone_emb(tone) + self.language_emb(language) + bert_emb + ja_bert_emb) * math.sqrt(
en_bert_emb = self.en_bert_proj(en_bert).transpose(1, 2)
x = (self.emb(x) + self.tone_emb(tone) + self.language_emb(
language) + zh_bert_emb + ja_bert_emb + en_bert_emb) * math.sqrt(
self.hidden_channels) # [b, t, h]
x = torch.transpose(x, 1, -1) # [b, h, t]
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
Expand Down Expand Up @@ -595,6 +599,8 @@ def __init__(self,
n_layers_trans_flow=6,
flow_share_parameter=False,
use_transformer_flow=True,
symbols=None,
ja_bert_dim=1024,
**kwargs):

super().__init__()
Expand Down Expand Up @@ -625,7 +631,6 @@ def __init__(self,
self.current_mas_noise_scale = self.mas_noise_scale_initial
if self.use_spk_conditioned_encoder and gin_channels > 0:
self.enc_gin_channels = gin_channels
symbols = kwargs.get("symbols")
self.enc_p = TextEncoder(n_vocab,
inter_channels,
hidden_channels,
Expand All @@ -636,6 +641,7 @@ def __init__(self,
p_dropout,
gin_channels=self.enc_gin_channels,
symbols=symbols,
ja_bert_dim=ja_bert_dim
)
self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
Expand All @@ -656,7 +662,7 @@ def __init__(self,
else:
self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)

def infer(self, x, x_lengths, sid, tone, language, bert, ja_bert, noise_scale=.667, length_scale=1,
def infer(self, x, x_lengths, sid, tone, language, zh_bert, ja_bert, en_bert, noise_scale=.667, length_scale=1,
noise_scale_w=0.8,
max_len=None, sdp_ratio=0, y=None):
# x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert)
Expand All @@ -665,7 +671,7 @@ def infer(self, x, x_lengths, sid, tone, language, bert, ja_bert, noise_scale=.6
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
else:
g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert, ja_bert, g=g)
x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, zh_bert, ja_bert, en_bert, g=g)
logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (sdp_ratio) + self.dp(x, x_mask,
g=g) * (
1 - sdp_ratio)
Expand Down
10 changes: 4 additions & 6 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,14 @@
# VITS
# [ABS_PATH + "/Model/Nene_Nanami_Rong_Tang/1374_epochs.pth", ABS_PATH + "/Model/Nene_Nanami_Rong_Tang/config.json"],
# [ABS_PATH + "/Model/Zero_no_tsukaima/1158_epochs.pth", ABS_PATH + "/Model/Zero_no_tsukaima/config.json"],
[ABS_PATH + "/Model/g/G_953000.pth", ABS_PATH + "/Model/g/config.json"],
[ABS_PATH + "/Model/vits_chinese/vits_bert_model.pth", ABS_PATH + "/Model/vits_chinese/bert_vits.json"],
# [ABS_PATH + "/Model/g/G_953000.pth", ABS_PATH + "/Model/g/config.json"],
# [ABS_PATH + "/Model/vits_chinese/vits_bert_model.pth", ABS_PATH + "/Model/vits_chinese/bert_vits.json"],
# HuBert-VITS (Need to configure HUBERT_SOFT_MODEL)
[ABS_PATH + "/Model/louise/360_epochs.pth", ABS_PATH + "/Model/louise/config.json"],
# [ABS_PATH + "/Model/louise/360_epochs.pth", ABS_PATH + "/Model/louise/config.json"],
# W2V2-VITS (Need to configure DIMENSIONAL_EMOTION_NPY)
[ABS_PATH + "/Model/w2v2-vits/1026_epochs.pth", ABS_PATH + "/Model/w2v2-vits/config.json"],
# [ABS_PATH + "/Model/w2v2-vits/1026_epochs.pth", ABS_PATH + "/Model/w2v2-vits/config.json"],
# Bert-VITS2
# [ABS_PATH + "/Model/bert_vits2/G_9000.pth", ABS_PATH + "/Model/bert_vits2/config.json"],
[r"H:\git\vits-simple-api\Model\tri\G_latest_104_inference.pth",r"H:\git\vits-simple-api\Model\tri\config.json"],
[r"H:\git\vits-simple-api\Model\taffy\G_15800_taffy_inference.pth",r"H:\git\vits-simple-api\Model\taffy\config_taffy.json"],
]

# hubert-vits: hubert soft model
Expand Down

0 comments on commit a09cc56

Please sign in to comment.