Skip to content

Commit

Permalink
Fix XTTS GPT padding and inference issues (#3216)
Browse files Browse the repository at this point in the history
* Fix end artifact for fine tuning models

* Bug fix on zh-cn inference

* Remove unused code
  • Loading branch information
Edresson authored Nov 15, 2023
1 parent 15f0ac5 commit 73a5bd0
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 47 deletions.
11 changes: 1 addition & 10 deletions TTS/tts/layers/xtts/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,15 +426,6 @@ def forward(
if max_mel_len > audio_codes.shape[-1]:
audio_codes = F.pad(audio_codes, (0, max_mel_len - audio_codes.shape[-1]))

silence = True
for idx, l in enumerate(code_lengths):
length = l.item()
while silence:
if audio_codes[idx, length - 1] != 83:
break
length -= 1
code_lengths[idx] = length

# 💖 Lovely assertions
assert (
max_mel_len <= audio_codes.shape[-1]
Expand All @@ -450,7 +441,7 @@ def forward(
audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=self.stop_audio_token)

# Pad mel codes with stop_audio_token
audio_codes = self.set_mel_padding(audio_codes, code_lengths)
audio_codes = self.set_mel_padding(audio_codes, code_lengths - 3) # -3 to get the real code lengths, without counting the start and stop tokens, which have not been added yet

# Build input and target tensors
# Prepend start token to inputs and append stop token to targets
Expand Down
12 changes: 6 additions & 6 deletions TTS/tts/layers/xtts/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@
# There are not many common abbreviations in Arabic as in English.
]
],
"zh": [
"zh-cn": [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
# Chinese doesn't typically use abbreviations in the same way as Latin-based scripts.
Expand Down Expand Up @@ -280,7 +280,7 @@ def expand_abbreviations_multilingual(text, lang="en"):
("°", " درجة "),
]
],
"zh": [
"zh-cn": [
# Chinese
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
for x in [
Expand Down Expand Up @@ -571,7 +571,7 @@ def check_input_length(self, txt, lang):
)

def preprocess_text(self, txt, lang):
if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh", "zh-cn"}:
if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh-cn", "zh-cn"}:
txt = multilingual_cleaners(txt, lang)
if lang in {"zh", "zh-cn"}:
txt = chinese_transliterate(txt)
Expand Down Expand Up @@ -682,8 +682,8 @@ def test_expand_numbers_multilingual():
("Dat wordt dan $20 meneer.", "Dat wordt dan twintig dollar meneer.", "nl"),
("Dat wordt dan 20€ meneer.", "Dat wordt dan twintig euro meneer.", "nl"),
# Chinese (Simplified)
("在12.5秒内", "在十二点五秒内", "zh"),
("有50名士兵", "有五十名士兵", "zh"),
("在12.5秒内", "在十二点五秒内", "zh-cn"),
("有50名士兵", "有五十名士兵", "zh-cn"),
# ("那将是$20先生", '那将是二十美元先生', 'zh'), currency doesn't work
# ("那将是20€先生", '那将是二十欧元先生', 'zh'),
# Turkish
Expand Down Expand Up @@ -764,7 +764,7 @@ def test_symbols_multilingual():
("Ik heb 14% batterij", "Ik heb 14 procent batterij", "nl"),
("Ik zie je @ het feest", "Ik zie je bij het feest", "nl"),
("لدي 14% في البطارية", "لدي 14 في المئة في البطارية", "ar"),
("我的电量为 14%", "我的电量为 14 百分之", "zh"),
("我的电量为 14%", "我的电量为 14 百分之", "zh-cn"),
("Pilim %14 dolu.", "Pilim yüzde 14 dolu.", "tr"),
("Az akkumulátorom töltöttsége 14%", "Az akkumulátorom töltöttsége 14 százalék", "hu"),
("배터리 잔량이 14%입니다.", "배터리 잔량이 14 퍼센트입니다.", "ko"),
Expand Down
31 changes: 0 additions & 31 deletions TTS/tts/models/xtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import torchaudio
from coqpit import Coqpit

from TTS.tts.layers.tortoise.audio_utils import wav_to_univnet_mel
from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from TTS.tts.layers.xtts.stream_generator import init_stream_support
Expand Down Expand Up @@ -308,26 +307,6 @@ def get_gpt_cond_latents(self, audio, sr, length: int = 30, chunk_length: int =
cond_latent = self.gpt.get_style_emb(mel.to(self.device))
return cond_latent.transpose(1, 2)

@torch.inference_mode()
def get_diffusion_cond_latents(self, audio, sr):
from math import ceil

diffusion_conds = []
CHUNK_SIZE = 102400
audio_24k = torchaudio.functional.resample(audio, sr, 24000)
for chunk in range(ceil(audio_24k.shape[1] / CHUNK_SIZE)):
current_sample = audio_24k[:, chunk * CHUNK_SIZE : (chunk + 1) * CHUNK_SIZE]
current_sample = pad_or_truncate(current_sample, CHUNK_SIZE)
cond_mel = wav_to_univnet_mel(
current_sample.to(self.device),
do_normalization=False,
device=self.device,
)
diffusion_conds.append(cond_mel)
diffusion_conds = torch.stack(diffusion_conds, dim=1)
diffusion_latent = self.diffusion_decoder.get_conditioning(diffusion_conds)
return diffusion_latent

@torch.inference_mode()
def get_speaker_embedding(self, audio, sr):
audio_16k = torchaudio.functional.resample(audio, sr, 16000)
Expand Down Expand Up @@ -575,16 +554,6 @@ def inference(
return_attentions=False,
return_latent=True,
)
silence_token = 83
ctokens = 0
for k in range(gpt_codes.shape[-1]):
if gpt_codes[0, k] == silence_token:
ctokens += 1
else:
ctokens = 0
if ctokens > 8:
gpt_latents = gpt_latents[:, :k]
break

if length_scale != 1.0:
gpt_latents = F.interpolate(
Expand Down

0 comments on commit 73a5bd0

Please sign in to comment.