From 2150199f78dd284a1a55776fe71f167deb6dcea0 Mon Sep 17 00:00:00 2001 From: WeberJulian Date: Thu, 16 Nov 2023 17:25:02 +0100 Subject: [PATCH] Fix zh bug --- TTS/tts/layers/xtts/tokenizer.py | 14 ++++++++------ TTS/tts/models/xtts.py | 2 +- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/TTS/tts/layers/xtts/tokenizer.py b/TTS/tts/layers/xtts/tokenizer.py index 56eb78aed4..1ef655a3cc 100644 --- a/TTS/tts/layers/xtts/tokenizer.py +++ b/TTS/tts/layers/xtts/tokenizer.py @@ -170,7 +170,7 @@ def split_sentence(text, lang, text_split_length=250): # There are not many common abbreviations in Arabic as in English. ] ], - "zh-cn": [ + "zh": [ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) for x in [ # Chinese doesn't typically use abbreviations in the same way as Latin-based scripts. @@ -335,7 +335,7 @@ def expand_abbreviations_multilingual(text, lang="en"): ("°", " درجة "), ] ], - "zh-cn": [ + "zh": [ # Chinese (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) for x in [ @@ -619,6 +619,7 @@ def katsu(self): return cutlet.Cutlet() def check_input_length(self, txt, lang): + lang = lang.split("-")[0] # remove the region limit = self.char_limits.get(lang, 250) if len(txt) > limit: print( @@ -626,7 +627,7 @@ def check_input_length(self, txt, lang): ) def preprocess_text(self, txt, lang): - if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh-cn", "ko"}: + if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh", "ko"}: txt = multilingual_cleaners(txt, lang) if lang == "zh": txt = chinese_transliterate(txt) @@ -642,6 +643,7 @@ def encode(self, txt, lang): lang = lang.split("-")[0] # remove the region self.check_input_length(txt, lang) txt = self.preprocess_text(txt, lang) + lang = "zh-cn" if lang == "zh" else lang txt = f"[{lang}]{txt}" txt = txt.replace(" ", "[SPACE]") return self.tokenizer.encode(txt).ids @@ -738,8 +740,8 @@ def test_expand_numbers_multilingual(): ("Dat wordt dan $20 meneer.", "Dat wordt dan twintig dollar meneer.", "nl"), ("Dat wordt dan 20€ meneer.", "Dat wordt dan twintig euro meneer.", "nl"), # Chinese (Simplified) - ("在12.5秒内", "在十二点五秒内", "zh-cn"), - ("有50名士兵", "有五十名士兵", "zh-cn"), + ("在12.5秒内", "在十二点五秒内", "zh"), + ("有50名士兵", "有五十名士兵", "zh"), # ("那将是$20先生", '那将是二十美元先生', 'zh'), currency doesn't work # ("那将是20€先生", '那将是二十欧元先生', 'zh'), # Turkish @@ -820,7 +822,7 @@ def test_symbols_multilingual(): ("Ik heb 14% batterij", "Ik heb 14 procent batterij", "nl"), ("Ik zie je @ het feest", "Ik zie je bij het feest", "nl"), ("لدي 14% في البطارية", "لدي 14 في المئة في البطارية", "ar"), - ("我的电量为 14%", "我的电量为 14 百分之", "zh-cn"), + ("我的电量为 14%", "我的电量为 14 百分之", "zh"), ("Pilim %14 dolu.", "Pilim yüzde 14 dolu.", "tr"), ("Az akkumulátorom töltöttsége 14%", "Az akkumulátorom töltöttsége 14 százalék", "hu"), ("배터리 잔량이 14%입니다.", "배터리 잔량이 14 퍼센트입니다.", "ko"), diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 5ccb26c314..3583591f8b 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -396,7 +396,7 @@ def inference_with_config(self, text, config, ref_audio_path, language, **kwargs inference with config """ assert ( - language in self.config.languages + "zh-cn" if language == "zh" else language in self.config.languages ), f" ❗ Language {language} is not supported. Supported languages are {self.config.languages}" # Use generally found best tuning knobs for generation. settings = {