diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml index c167f7ca44..b7c6393baa 100644 --- a/.github/workflows/style_check.yml +++ b/.github/workflows/style_check.yml @@ -42,6 +42,5 @@ jobs: run: | python3 -m pip install .[all] python3 setup.py egg_info - # - name: Lint check - # run: | - # make lint \ No newline at end of file + - name: Style check + run: make style diff --git a/TTS/api.py b/TTS/api.py index 39f53f577c..5d1fbb5a1c 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -264,7 +264,7 @@ def tts_coqui_studio( language: str = None, emotion: str = None, speed: float = 1.0, - pipe_out = None, + pipe_out=None, file_path: str = None, ) -> Union[np.ndarray, str]: """Convert text to speech using Coqui Studio models. Use `CS_API` class if you are only interested in the API. @@ -359,7 +359,7 @@ def tts_to_file( speaker_wav: str = None, emotion: str = None, speed: float = 1.0, - pipe_out = None, + pipe_out=None, file_path: str = "output.wav", **kwargs, ): @@ -460,7 +460,7 @@ def tts_with_vc(self, text: str, language: str = None, speaker_wav: str = None): """ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp: # Lazy code... 
save it to a temp file to resample it while reading it for VC - self.tts_to_file(text=text, speaker=None, language=language, file_path=fp.name,speaker_wav=speaker_wav) + self.tts_to_file(text=text, speaker=None, language=language, file_path=fp.name, speaker_wav=speaker_wav) if self.voice_converter is None: self.load_vc_model_by_name("voice_conversion_models/multilingual/vctk/freevc24") wav = self.voice_converter.voice_conversion(source_wav=fp.name, target_wav=speaker_wav) diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index 78a20c2566..ef41c8e13f 100755 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -427,7 +427,9 @@ def main(): tts_path = model_path tts_config_path = config_path if "default_vocoder" in model_item: - args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name + args.vocoder_name = ( + model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name + ) # voice conversion model if model_item["model_type"] == "voice_conversion_models": diff --git a/TTS/tts/layers/xtts/tokenizer.py b/TTS/tts/layers/xtts/tokenizer.py index c25d42963a..456f8081be 100644 --- a/TTS/tts/layers/xtts/tokenizer.py +++ b/TTS/tts/layers/xtts/tokenizer.py @@ -1,12 +1,12 @@ +import json import os import re -import json +import pypinyin import torch +from num2words import num2words from tokenizers import Tokenizer -import pypinyin -from num2words import num2words from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words _whitespace_re = re.compile(r"\s+") @@ -87,7 +87,7 @@ "it": [ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) for x in [ - #("sig.ra", "signora"), + # ("sig.ra", "signora"), ("sig", "signore"), ("dr", "dottore"), ("st", "santo"), @@ -121,49 +121,51 @@ "cs": [ (re.compile("\\b%s\\." 
% x[0], re.IGNORECASE), x[1]) for x in [ - ("dr", "doktor"), # doctor - ("ing", "inženýr"), # engineer - ("p", "pan"), # Could also map to pani for woman but no easy way to do it + ("dr", "doktor"), # doctor + ("ing", "inženýr"), # engineer + ("p", "pan"), # Could also map to pani for woman but no easy way to do it # Other abbreviations would be specialized and not as common. ] ], "ru": [ (re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1]) for x in [ - ("г-жа", "госпожа"), # Mrs. - ("г-н", "господин"), # Mr. - ("д-р", "доктор"), # doctor + ("г-жа", "госпожа"), # Mrs. + ("г-н", "господин"), # Mr. + ("д-р", "доктор"), # doctor # Other abbreviations are less common or specialized. ] ], "nl": [ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) for x in [ - ("dhr", "de heer"), # Mr. + ("dhr", "de heer"), # Mr. ("mevr", "mevrouw"), # Mrs. - ("dr", "dokter"), # doctor - ("jhr", "jonkheer"), # young lord or nobleman + ("dr", "dokter"), # doctor + ("jhr", "jonkheer"), # young lord or nobleman # Dutch uses more abbreviations, but these are the most common ones. ] ], "tr": [ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) for x in [ - ("b", "bay"), # Mr. + ("b", "bay"), # Mr. ("byk", "büyük"), # büyük - ("dr", "doktor"), # doctor + ("dr", "doktor"), # doctor # Add other Turkish abbreviations here if needed. 
] ], } -def expand_abbreviations_multilingual(text, lang='en'): + +def expand_abbreviations_multilingual(text, lang="en"): for regex, replacement in _abbreviations[lang]: text = re.sub(regex, replacement, text) return text + _symbols_multilingual = { - 'en': [ + "en": [ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) for x in [ ("&", " and "), @@ -172,10 +174,10 @@ def expand_abbreviations_multilingual(text, lang='en'): ("#", " hash "), ("$", " dollar "), ("£", " pound "), - ("°", " degree ") + ("°", " degree "), ] ], - 'es': [ + "es": [ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) for x in [ ("&", " y "), @@ -184,10 +186,10 @@ def expand_abbreviations_multilingual(text, lang='en'): ("#", " numeral "), ("$", " dolar "), ("£", " libra "), - ("°", " grados ") + ("°", " grados "), ] ], - 'fr': [ + "fr": [ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) for x in [ ("&", " et "), @@ -196,10 +198,10 @@ def expand_abbreviations_multilingual(text, lang='en'): ("#", " dièse "), ("$", " dollar "), ("£", " livre "), - ("°", " degrés ") + ("°", " degrés "), ] ], - 'de': [ + "de": [ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) for x in [ ("&", " und "), @@ -208,10 +210,10 @@ def expand_abbreviations_multilingual(text, lang='en'): ("#", " raute "), ("$", " dollar "), ("£", " pfund "), - ("°", " grad ") + ("°", " grad "), ] ], - 'pt': [ + "pt": [ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) for x in [ ("&", " e "), @@ -220,10 +222,10 @@ def expand_abbreviations_multilingual(text, lang='en'): ("#", " cardinal "), ("$", " dólar "), ("£", " libra "), - ("°", " graus ") + ("°", " graus "), ] ], - 'it': [ + "it": [ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) for x in [ ("&", " e "), @@ -232,10 +234,10 @@ def expand_abbreviations_multilingual(text, lang='en'): ("#", " cancelletto "), ("$", " dollaro "), ("£", " sterlina "), - ("°", " gradi ") + ("°", " gradi "), ] ], - 'pl': [ + "pl": [ (re.compile(r"%s" % 
re.escape(x[0]), re.IGNORECASE), x[1]) for x in [ ("&", " i "), @@ -244,7 +246,7 @@ def expand_abbreviations_multilingual(text, lang='en'): ("#", " krzyżyk "), ("$", " dolar "), ("£", " funt "), - ("°", " stopnie ") + ("°", " stopnie "), ] ], "ar": [ @@ -257,7 +259,7 @@ def expand_abbreviations_multilingual(text, lang='en'): ("#", " رقم "), ("$", " دولار "), ("£", " جنيه "), - ("°", " درجة ") + ("°", " درجة "), ] ], "zh-cn": [ @@ -270,7 +272,7 @@ def expand_abbreviations_multilingual(text, lang='en'): ("#", " 号 "), ("$", " 美元 "), ("£", " 英镑 "), - ("°", " 度 ") + ("°", " 度 "), ] ], "cs": [ @@ -283,7 +285,7 @@ def expand_abbreviations_multilingual(text, lang='en'): ("#", " křížek "), ("$", " dolar "), ("£", " libra "), - ("°", " stupně ") + ("°", " stupně "), ] ], "ru": [ @@ -296,7 +298,7 @@ def expand_abbreviations_multilingual(text, lang='en'): ("#", " номер "), ("$", " доллар "), ("£", " фунт "), - ("°", " градус ") + ("°", " градус "), ] ], "nl": [ @@ -309,7 +311,7 @@ def expand_abbreviations_multilingual(text, lang='en'): ("#", " hekje "), ("$", " dollar "), ("£", " pond "), - ("°", " graden ") + ("°", " graden "), ] ], "tr": [ @@ -321,15 +323,16 @@ def expand_abbreviations_multilingual(text, lang='en'): ("#", " diyez "), ("$", " dolar "), ("£", " sterlin "), - ("°", " derece ") + ("°", " derece "), ] ], } -def expand_symbols_multilingual(text, lang='en'): + +def expand_symbols_multilingual(text, lang="en"): for regex, replacement in _symbols_multilingual[lang]: text = re.sub(regex, replacement, text) - text = text.replace(' ', ' ') # Ensure there are no double spaces + text = text.replace(" ", " ") # Ensure there are no double spaces return text.strip() @@ -342,41 +345,45 @@ def expand_symbols_multilingual(text, lang='en'): "it": re.compile(r"([0-9]+)(º|°|ª|o|a|i|e)"), "pl": re.compile(r"([0-9]+)(º|ª|st|nd|rd|th)"), "ar": re.compile(r"([0-9]+)(ون|ين|ث|ر|ى)"), - "cs": re.compile(r"([0-9]+)\.(?=\s|$)"), # In Czech, a dot is often used after the number to indicate 
ordinals. + "cs": re.compile(r"([0-9]+)\.(?=\s|$)"), # In Czech, a dot is often used after the number to indicate ordinals. "ru": re.compile(r"([0-9]+)(-й|-я|-е|-ое|-ье|-го)"), "nl": re.compile(r"([0-9]+)(de|ste|e)"), "tr": re.compile(r"([0-9]+)(\.|inci|nci|uncu|üncü|\.)"), } _number_re = re.compile(r"[0-9]+") _currency_re = { - 'USD': re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"), - 'GBP': re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"), - 'EUR': re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))") + "USD": re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"), + "GBP": re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"), + "EUR": re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))"), } _comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b") _dot_number_re = re.compile(r"\b\d{1,3}(.\d{3})*(\,\d+)?\b") _decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)") + def _remove_commas(m): text = m.group(0) if "," in text: text = text.replace(",", "") return text + def _remove_dots(m): text = m.group(0) if "." 
in text: text = text.replace(".", "") return text -def _expand_decimal_point(m, lang='en'): + +def _expand_decimal_point(m, lang="en"): amount = m.group(1).replace(",", ".") return num2words(float(amount), lang=lang if lang != "cs" else "cz") -def _expand_currency(m, lang='en', currency='USD'): - amount = float((re.sub(r'[^\d.]', '', m.group(0).replace(",", ".")))) - full_amount = num2words(amount, to='currency', currency=currency, lang=lang if lang != "cs" else "cz") + +def _expand_currency(m, lang="en", currency="USD"): + amount = float((re.sub(r"[^\d.]", "", m.group(0).replace(",", ".")))) + full_amount = num2words(amount, to="currency", currency=currency, lang=lang if lang != "cs" else "cz") and_equivalents = { "en": ", ", @@ -400,13 +407,16 @@ def _expand_currency(m, lang='en', currency='USD'): return full_amount -def _expand_ordinal(m, lang='en'): + +def _expand_ordinal(m, lang="en"): return num2words(int(m.group(1)), ordinal=True, lang=lang if lang != "cs" else "cz") -def _expand_number(m, lang='en'): + +def _expand_number(m, lang="en"): return num2words(int(m.group(0)), lang=lang if lang != "cs" else "cz") -def expand_numbers_multilingual(text, lang='en'): + +def expand_numbers_multilingual(text, lang="en"): if lang == "zh-cn": text = zh_num2words()(text) else: @@ -415,9 +425,9 @@ def expand_numbers_multilingual(text, lang='en'): else: text = re.sub(_dot_number_re, _remove_dots, text) try: - text = re.sub(_currency_re['GBP'], lambda m: _expand_currency(m, lang, 'GBP'), text) - text = re.sub(_currency_re['USD'], lambda m: _expand_currency(m, lang, 'USD'), text) - text = re.sub(_currency_re['EUR'], lambda m: _expand_currency(m, lang, 'EUR'), text) + text = re.sub(_currency_re["GBP"], lambda m: _expand_currency(m, lang, "GBP"), text) + text = re.sub(_currency_re["USD"], lambda m: _expand_currency(m, lang, "USD"), text) + text = re.sub(_currency_re["EUR"], lambda m: _expand_currency(m, lang, "EUR"), text) except: pass if lang != "tr": @@ -426,15 +436,18 @@ def 
expand_numbers_multilingual(text, lang='en'): text = re.sub(_number_re, lambda m: _expand_number(m, lang), text) return text + def lowercase(text): return text.lower() + def collapse_whitespace(text): return re.sub(_whitespace_re, " ", text) + def multilingual_cleaners(text, lang): - text = text.replace('"', '') - if lang=="tr": + text = text.replace('"', "") + if lang == "tr": text = text.replace("İ", "i") text = text.replace("Ö", "ö") text = text.replace("Ü", "ü") @@ -445,20 +458,26 @@ def multilingual_cleaners(text, lang): text = collapse_whitespace(text) return text + def basic_cleaners(text): """Basic pipeline that lowercases and collapses whitespace without transliteration.""" text = lowercase(text) text = collapse_whitespace(text) return text + def chinese_transliterate(text): - return "".join([p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)]) + return "".join( + p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True) + ) + def japanese_cleaners(text, katsu): text = katsu.romaji(text) text = lowercase(text) return text + class VoiceBpeTokenizer: def __init__(self, vocab_file=None, preprocess=None): self.tokenizer = None @@ -485,6 +504,7 @@ def preprocess_text(self, txt, lang): elif lang == "ja": if self.katsu is None: import cutlet + self.katsu = cutlet.Cutlet() txt = japanese_cleaners(txt, self.katsu) else: diff --git a/TTS/tts/layers/xtts/zh_num2words.py b/TTS/tts/layers/xtts/zh_num2words.py index d51174746e..2c56e3bbeb 100644 --- a/TTS/tts/layers/xtts/zh_num2words.py +++ b/TTS/tts/layers/xtts/zh_num2words.py @@ -2,9 +2,14 @@ # 2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git) # 2019.9 - 2022 Jiayu DU -import sys, os, argparse -import string, re +import argparse import csv +import os +import re +import string +import sys + +# fmt: off # ================================================================================ # # 
basic constant diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 60af2d1e8e..c0532b36b1 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -2,10 +2,10 @@ from contextlib import contextmanager from dataclasses import dataclass +import librosa import torch import torch.nn.functional as F import torchaudio -import librosa from coqpit import Coqpit from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to_univnet_mel @@ -386,10 +386,12 @@ def get_diffusion_cond_latents(self, audio, sr): @torch.inference_mode() def get_speaker_embedding(self, audio, sr): audio_16k = torchaudio.functional.resample(audio, sr, 16000) - return self.hifigan_decoder.speaker_encoder.forward( - audio_16k.to(self.device), l2_norm=True - ).unsqueeze(-1).to(self.device) - + return ( + self.hifigan_decoder.speaker_encoder.forward(audio_16k.to(self.device), l2_norm=True) + .unsqueeze(-1) + .to(self.device) + ) + @torch.inference_mode() def get_conditioning_latents( self, @@ -398,7 +400,7 @@ def get_conditioning_latents( max_ref_length=10, librosa_trim_db=None, sound_norm_refs=False, - ): + ): speaker_embedding = None diffusion_cond_latents = None @@ -647,13 +649,19 @@ def inference( break if decoder == "hifigan": - assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`" + assert hasattr( + self, "hifigan_decoder" + ), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`" wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding) elif decoder == "ne_hifigan": - assert hasattr(self, "ne_hifigan_decoder"), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`" + assert hasattr( + self, "ne_hifigan_decoder" + ), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`" wav = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding) else: - assert hasattr(self, "diffusion_decoder"), "You 
must disable hifigan decoders to use difffusion by setting config `use_ne_hifigan: false` and `use_hifigan: false`" + assert hasattr( + self, "diffusion_decoder" + ), "You must disable hifigan decoders to use difffusion by setting config `use_ne_hifigan: false` and `use_hifigan: false`" mel = do_spectrogram_diffusion( self.diffusion_decoder, diffuser, @@ -742,10 +750,14 @@ def inference_stream( if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size): gpt_latents = torch.cat(all_latents, dim=0)[None, :] if decoder == "hifigan": - assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`" + assert hasattr( + self, "hifigan_decoder" + ), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`" wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device)) elif decoder == "ne_hifigan": - assert hasattr(self, "ne_hifigan_decoder"), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`" + assert hasattr( + self, "ne_hifigan_decoder" + ), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`" wav_gen = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device)) else: raise NotImplementedError("Diffusion for streaming inference not implemented.") @@ -756,10 +768,14 @@ def inference_stream( yield wav_chunk def forward(self): - raise NotImplementedError("XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training") + raise NotImplementedError( + "XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training" + ) def eval_step(self): - raise NotImplementedError("XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training") + raise NotImplementedError( + "XTTS has a dedicated trainer, please check the 
XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training" + ) @staticmethod def init_from_config(config: "XttsConfig", **kwargs): # pylint: disable=unused-argument @@ -835,12 +851,18 @@ def load_checkpoint( self.load_state_dict(checkpoint, strict=strict) if eval: - if hasattr(self, "hifigan_decoder"): self.hifigan_decoder.eval() - if hasattr(self, "ne_hifigan_decoder"): self.hifigan_decoder.eval() - if hasattr(self, "diffusion_decoder"): self.diffusion_decoder.eval() - if hasattr(self, "vocoder"): self.vocoder.eval() + if hasattr(self, "hifigan_decoder"): + self.hifigan_decoder.eval() + if hasattr(self, "ne_hifigan_decoder"): + self.ne_hifigan_decoder.eval() + if hasattr(self, "diffusion_decoder"): + self.diffusion_decoder.eval() + if hasattr(self, "vocoder"): + self.vocoder.eval() self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed) self.gpt.eval() def train_step(self): - raise NotImplementedError("XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training") + raise NotImplementedError( + "XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training" + ) diff --git a/TTS/utils/audio/numpy_transforms.py b/TTS/utils/audio/numpy_transforms.py index e2b71fb2fe..b701e76712 100644 --- a/TTS/utils/audio/numpy_transforms.py +++ b/TTS/utils/audio/numpy_transforms.py @@ -428,7 +428,7 @@ def load_wav(*, filename: str, sample_rate: int = None, resample: bool = False, return x -def save_wav(*, wav: np.ndarray, path: str, sample_rate: int = None, pipe_out = None, **kwargs) -> None: +def save_wav(*, wav: np.ndarray, path: str, sample_rate: int = None, pipe_out=None, **kwargs) -> None: """Save float waveform to a file using Scipy.
Args: diff --git a/TTS/utils/audio/processor.py b/TTS/utils/audio/processor.py index 248e15b888..4ceb7da4b3 100644 --- a/TTS/utils/audio/processor.py +++ b/TTS/utils/audio/processor.py @@ -694,7 +694,7 @@ def load_wav(self, filename: str, sr: int = None) -> np.ndarray: x = self.rms_volume_norm(x, self.db_level) return x - def save_wav(self, wav: np.ndarray, path: str, sr: int = None, pipe_out = None) -> None: + def save_wav(self, wav: np.ndarray, path: str, sr: int = None, pipe_out=None) -> None: """Save a waveform to a file using Scipy. Args: diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index a7370cd2c9..8efe608bac 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -235,7 +235,7 @@ def split_into_sentences(self, text) -> List[str]: """ return self.seg.segment(text) - def save_wav(self, wav: List[int], path: str, pipe_out = None) -> None: + def save_wav(self, wav: List[int], path: str, pipe_out=None) -> None: """Save the waveform as a file. Args: diff --git a/recipes/ljspeech/xtts_v1/train_gpt_xtts.py b/recipes/ljspeech/xtts_v1/train_gpt_xtts.py index 94f3975c2f..9134be0db2 100644 --- a/recipes/ljspeech/xtts_v1/train_gpt_xtts.py +++ b/recipes/ljspeech/xtts_v1/train_gpt_xtts.py @@ -7,7 +7,6 @@ from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig from TTS.utils.manage import ModelManager - # Logging parameters RUN_NAME = "GPT_XTTS_LJSpeech_FT" PROJECT_NAME = "XTTS_trainer" @@ -60,13 +59,15 @@ XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/model.pth" # XTTS transfer learning parameters: You we need to provide the paths of XTTS model checkpoint that you want to do the fine tuning. 
-TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, TOKENIZER_FILE_LINK.split("/")[-1]) # vocab.json file -XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, XTTS_CHECKPOINT_LINK.split("/")[-1]) # model.pth file +TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, TOKENIZER_FILE_LINK.split("/")[-1]) # vocab.json file +XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, XTTS_CHECKPOINT_LINK.split("/")[-1]) # model.pth file # download XTTS v1.1 files if needed if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT): print(" > Downloading XTTS v1.1 files!") - ModelManager._download_model_files([TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True) + ModelManager._download_model_files( + [TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True + ) # Training sentences generations @@ -93,7 +94,7 @@ def main(): gpt_num_audio_tokens=8194, gpt_start_audio_token=8192, gpt_stop_audio_token=8193, - use_ne_hifigan=True, # if it is true it will keep the non-enhanced keys on the output checkpoint + use_ne_hifigan=True, # if it is true it will keep the non-enhanced keys on the output checkpoint ) # define audio config audio_config = XttsAudioConfig( diff --git a/tests/api_tests/test_synthesize_api.py b/tests/api_tests/test_synthesize_api.py index 084f81d489..e7b4f12048 100644 --- a/tests/api_tests/test_synthesize_api.py +++ b/tests/api_tests/test_synthesize_api.py @@ -22,7 +22,4 @@ def test_synthesize(): ) # test pipe_out command - run_cli( - 'tts --text "test." --pipe_out ' - f'--out_path "{output_path}" | aplay' - ) + run_cli(f'tts --text "test." --pipe_out --out_path "{output_path}" | aplay')