diff --git a/TTS/demos/xtts_ft_demo/requirements.txt b/TTS/demos/xtts_ft_demo/requirements.txt new file mode 100644 index 0000000000..cb5b16f66e --- /dev/null +++ b/TTS/demos/xtts_ft_demo/requirements.txt @@ -0,0 +1,2 @@ +faster_whisper==0.9.0 +gradio==4.7.1 \ No newline at end of file diff --git a/TTS/demos/xtts_ft_demo/utils/formatter.py b/TTS/demos/xtts_ft_demo/utils/formatter.py new file mode 100644 index 0000000000..536faa0108 --- /dev/null +++ b/TTS/demos/xtts_ft_demo/utils/formatter.py @@ -0,0 +1,160 @@ +import os +import gc +import torchaudio +import pandas +from faster_whisper import WhisperModel +from glob import glob + +from tqdm import tqdm + +import torch +import torchaudio +# torch.set_num_threads(1) + +from TTS.tts.layers.xtts.tokenizer import multilingual_cleaners + +torch.set_num_threads(16) + + +import os + +audio_types = (".wav", ".mp3", ".flac") + + +def list_audios(basePath, contains=None): + # return the set of files that are valid + return list_files(basePath, validExts=audio_types, contains=contains) + +def list_files(basePath, validExts=None, contains=None): + # loop over the directory structure + for (rootDir, dirNames, filenames) in os.walk(basePath): + # loop over the filenames in the current directory + for filename in filenames: + # if the contains string is not none and the filename does not contain + # the supplied string, then ignore the file + if contains is not None and filename.find(contains) == -1: + continue + + # determine the file extension of the current file + ext = filename[filename.rfind("."):].lower() + + # check to see if the file is an audio and should be processed + if validExts is None or ext.endswith(validExts): + # construct the path to the audio and yield it + audioPath = os.path.join(rootDir, filename) + yield audioPath + +def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0.2, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None): + audio_total_size = 0 + # make sure that ooutput file exists + os.makedirs(out_path, exist_ok=True) + + # Loading Whisper + device = "cuda" if torch.cuda.is_available() else "cpu" + + print("Loading Whisper Model!") + asr_model = WhisperModel("large-v2", device=device, compute_type="float16") + + metadata = {"audio_file": [], "text": [], "speaker_name": []} + + if gradio_progress is not None: + tqdm_object = gradio_progress.tqdm(audio_files, desc="Formatting...") + else: + tqdm_object = tqdm(audio_files) + + for audio_path in tqdm_object: + wav, sr = torchaudio.load(audio_path) + # stereo to mono if needed + if wav.size(0) != 1: + wav = torch.mean(wav, dim=0, keepdim=True) + + wav = wav.squeeze() + audio_total_size += (wav.size(-1) / sr) + + segments, _ = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language) + segments = list(segments) + i = 0 + sentence = "" + sentence_start = None + first_word = True + # added all segments words in a unique list + words_list = [] + for _, segment in enumerate(segments): + words = list(segment.words) + words_list.extend(words) + + # process each word + for word_idx, word in enumerate(words_list): + if first_word: + sentence_start = word.start + # If it is the first sentence, add buffer or get the begining of the file + if word_idx == 0: + sentence_start = max(sentence_start - buffer, 0) # Add buffer to the sentence start + else: + # get previous sentence end + previous_word_end = words_list[word_idx - 1].end + # add buffer or get the silence midle between the previous sentence and the current one + sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start)/2) + + sentence = word.word + first_word = False + else: + sentence += word.word + + if word.word[-1] in ["!", ".", "?"]: + sentence = sentence[1:] + # Expand number and abbreviations plus normalization + sentence = multilingual_cleaners(sentence, target_language) + audio_file_name, _ = os.path.splitext(os.path.basename(audio_path)) + + audio_file = f"wavs/{audio_file_name}_{str(i).zfill(8)}.wav" + + # Check for the next word's existence + if word_idx + 1 < len(words_list): + next_word_start = words_list[word_idx + 1].start + else: + # If don't have more words it means that it is the last sentence then use the audio len as next word start + next_word_start = (wav.shape[0] - 1) / sr + + # Average the current word end and next word start + word_end = min((word.end + next_word_start) / 2, word.end + buffer) + + absoulte_path = os.path.join(out_path, audio_file) + os.makedirs(os.path.dirname(absoulte_path), exist_ok=True) + i += 1 + first_word = True + + audio = wav[int(sr*sentence_start):int(sr*word_end)].unsqueeze(0) + # if the audio is too short ignore it (i.e < 0.33 seconds) + if audio.size(-1) >= sr/3: + torchaudio.save(absoulte_path, + audio, + sr + ) + else: + continue + + metadata["audio_file"].append(audio_file) + metadata["text"].append(sentence) + metadata["speaker_name"].append(speaker_name) + + df = pandas.DataFrame(metadata) + df = df.sample(frac=1) + num_val_samples = int(len(df)*eval_percentage) + + df_eval = df[:num_val_samples] + df_train = df[num_val_samples:] + + df_train = df_train.sort_values('audio_file') + train_metadata_path = os.path.join(out_path, "metadata_train.csv") + df_train.to_csv(train_metadata_path, sep="|", index=False) + + eval_metadata_path = os.path.join(out_path, "metadata_eval.csv") + df_eval = df_eval.sort_values('audio_file') + df_eval.to_csv(eval_metadata_path, sep="|", index=False) + + # deallocate VRAM and RAM + del asr_model, df_train, df_eval, df, metadata + gc.collect() + + return train_metadata_path, eval_metadata_path, audio_total_size \ No newline at end of file diff --git a/TTS/demos/xtts_ft_demo/utils/gpt_train.py b/TTS/demos/xtts_ft_demo/utils/gpt_train.py new file mode 100644 index 0000000000..a98765c3e7 --- /dev/null +++ b/TTS/demos/xtts_ft_demo/utils/gpt_train.py @@ -0,0 +1,172 @@ +import os +import gc + +from trainer import Trainer, TrainerArgs + +from TTS.config.shared_configs import BaseDatasetConfig +from TTS.tts.datasets import load_tts_samples +from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig +from TTS.utils.manage import ModelManager + + +def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path, max_audio_length=255995): + # Logging parameters + RUN_NAME = "GPT_XTTS_FT" + PROJECT_NAME = "XTTS_trainer" + DASHBOARD_LOGGER = "tensorboard" + LOGGER_URI = None + + # Set here the path that the checkpoints will be saved. Default: ./run/training/ + OUT_PATH = os.path.join(output_path, "run", "training") + + # Training Parameters + OPTIMIZER_WD_ONLY_ON_WEIGHTS = True # for multi-gpu training please make it False + START_WITH_EVAL = False # if True it will star with evaluation + BATCH_SIZE = batch_size # set here the batch size + GRAD_ACUMM_STEPS = grad_acumm # set here the grad accumulation steps + + + # Define here the dataset that you want to use for the fine-tuning on. + config_dataset = BaseDatasetConfig( + formatter="coqui", + dataset_name="ft_dataset", + path=os.path.dirname(train_csv), + meta_file_train=train_csv, + meta_file_val=eval_csv, + language=language, + ) + + # Add here the configs of the datasets + DATASETS_CONFIG_LIST = [config_dataset] + + # Define the path where XTTS v2.0.1 files will be downloaded + CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/") + os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True) + + + # DVAE files + DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth" + MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth" + + # Set the path to the downloaded files + DVAE_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(DVAE_CHECKPOINT_LINK)) + MEL_NORM_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(MEL_NORM_LINK)) + + # download DVAE files if needed + if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE): + print(" > Downloading DVAE files!") + ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True) + + + # Download XTTS v2.0 checkpoint if needed + TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json" + XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth" + XTTS_CONFIG_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/config.json" + + # XTTS transfer learning parameters: You we need to provide the paths of XTTS model checkpoint that you want to do the fine tuning. + TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(TOKENIZER_FILE_LINK)) # vocab.json file + XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CHECKPOINT_LINK)) # model.pth file + XTTS_CONFIG_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CONFIG_LINK)) # config.json file + + # download XTTS v2.0 files if needed + if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT): + print(" > Downloading XTTS v2.0 files!") + ModelManager._download_model_files( + [TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK, XTTS_CONFIG_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True + ) + + # init args and config + model_args = GPTArgs( + max_conditioning_length=132300, # 6 secs + min_conditioning_length=66150, # 3 secs + debug_loading_failures=False, + max_wav_length=max_audio_length, # ~11.6 seconds + max_text_length=200, + mel_norm_file=MEL_NORM_FILE, + dvae_checkpoint=DVAE_CHECKPOINT, + xtts_checkpoint=XTTS_CHECKPOINT, # checkpoint path of the model that you want to fine-tune + tokenizer_file=TOKENIZER_FILE, + gpt_num_audio_tokens=1026, + gpt_start_audio_token=1024, + gpt_stop_audio_token=1025, + gpt_use_masking_gt_prompt_approach=True, + gpt_use_perceiver_resampler=True, + ) + # define audio config + audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000) + # training parameters config + config = GPTTrainerConfig( + epochs=num_epochs, + output_path=OUT_PATH, + model_args=model_args, + run_name=RUN_NAME, + project_name=PROJECT_NAME, + run_description=""" + GPT XTTS training + """, + dashboard_logger=DASHBOARD_LOGGER, + logger_uri=LOGGER_URI, + audio=audio_config, + batch_size=BATCH_SIZE, + batch_group_size=48, + eval_batch_size=BATCH_SIZE, + num_loader_workers=8, + eval_split_max_size=256, + print_step=50, + plot_step=100, + log_model_step=100, + save_step=1000, + save_n_checkpoints=1, + save_checkpoints=True, + # target_loss="loss", + print_eval=False, + # Optimizer values like tortoise, pytorch implementation with modifications to not apply WD to non-weight parameters. + optimizer="AdamW", + optimizer_wd_only_on_weights=OPTIMIZER_WD_ONLY_ON_WEIGHTS, + optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2}, + lr=5e-06, # learning rate + lr_scheduler="MultiStepLR", + # it was adjusted accordly for the new step scheme + lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1}, + test_sentences=[], + ) + + # init the model from config + model = GPTTrainer.init_from_config(config) + + # load training samples + train_samples, eval_samples = load_tts_samples( + DATASETS_CONFIG_LIST, + eval_split=True, + eval_split_max_size=config.eval_split_max_size, + eval_split_size=config.eval_split_size, + ) + + # init the trainer and 🚀 + trainer = Trainer( + TrainerArgs( + restore_path=None, # xtts checkpoint is restored via xtts_checkpoint key so no need of restore it using Trainer restore_path parameter + skip_train_epoch=False, + start_with_eval=START_WITH_EVAL, + grad_accum_steps=GRAD_ACUMM_STEPS, + ), + config, + output_path=OUT_PATH, + model=model, + train_samples=train_samples, + eval_samples=eval_samples, + ) + trainer.fit() + + # get the longest text audio file to use as speaker reference + samples_len = [len(item["text"].split(" ")) for item in train_samples] + longest_text_idx = samples_len.index(max(samples_len)) + speaker_ref = train_samples[longest_text_idx]["audio_file"] + + trainer_out_path = trainer.output_path + + # deallocate VRAM and RAM + del model, trainer, train_samples, eval_samples + gc.collect() + + return XTTS_CONFIG_FILE, XTTS_CHECKPOINT, TOKENIZER_FILE, trainer_out_path, speaker_ref diff --git a/TTS/demos/xtts_ft_demo/xtts_demo.py b/TTS/demos/xtts_ft_demo/xtts_demo.py new file mode 100644 index 0000000000..ebb11f29d1 --- /dev/null +++ b/TTS/demos/xtts_ft_demo/xtts_demo.py @@ -0,0 +1,415 @@ +import argparse +import os +import sys +import tempfile + +import gradio as gr +import librosa.display +import numpy as np + +import os +import torch +import torchaudio +import traceback +from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list +from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt + +from TTS.tts.configs.xtts_config import XttsConfig +from TTS.tts.models.xtts import Xtts + + +def clear_gpu_cache(): + # clear the GPU cache + if torch.cuda.is_available(): + torch.cuda.empty_cache() + +XTTS_MODEL = None +def load_model(xtts_checkpoint, xtts_config, xtts_vocab): + global XTTS_MODEL + clear_gpu_cache() + if not xtts_checkpoint or not xtts_config or not xtts_vocab: + return "You need to run the previous steps or manually set the `XTTS checkpoint path`, `XTTS config path`, and `XTTS vocab path` fields !!" + config = XttsConfig() + config.load_json(xtts_config) + XTTS_MODEL = Xtts.init_from_config(config) + print("Loading XTTS model! ") + XTTS_MODEL.load_checkpoint(config, checkpoint_path=xtts_checkpoint, vocab_path=xtts_vocab, use_deepspeed=False) + if torch.cuda.is_available(): + XTTS_MODEL.cuda() + + print("Model Loaded!") + return "Model Loaded!" + +def run_tts(lang, tts_text, speaker_audio_file): + if XTTS_MODEL is None or not speaker_audio_file: + return "You need to run the previous step to load the model !!", None, None + + gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=XTTS_MODEL.config.gpt_cond_len, max_ref_length=XTTS_MODEL.config.max_ref_len, sound_norm_refs=XTTS_MODEL.config.sound_norm_refs) + out = XTTS_MODEL.inference( + text=tts_text, + language=lang, + gpt_cond_latent=gpt_cond_latent, + speaker_embedding=speaker_embedding, + temperature=XTTS_MODEL.config.temperature, # Add custom parameters here + length_penalty=XTTS_MODEL.config.length_penalty, + repetition_penalty=XTTS_MODEL.config.repetition_penalty, + top_k=XTTS_MODEL.config.top_k, + top_p=XTTS_MODEL.config.top_p, + ) + + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp: + out["wav"] = torch.tensor(out["wav"]).unsqueeze(0) + out_path = fp.name + torchaudio.save(out_path, out["wav"], 24000) + + return "Speech generated !", out_path, speaker_audio_file + + + + +# define a logger to redirect +class Logger: + def __init__(self, filename="log.out"): + self.log_file = filename + self.terminal = sys.stdout + self.log = open(self.log_file, "w") + + def write(self, message): + self.terminal.write(message) + self.log.write(message) + + def flush(self): + self.terminal.flush() + self.log.flush() + + def isatty(self): + return False + +# redirect stdout and stderr to a file +sys.stdout = Logger() +sys.stderr = sys.stdout + + +# logging.basicConfig(stream=sys.stdout, level=logging.INFO) +import logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[ + logging.StreamHandler(sys.stdout) + ] +) + +def read_logs(): + sys.stdout.flush() + with open(sys.stdout.log_file, "r") as f: + return f.read() + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description="""XTTS fine-tuning demo\n\n""" + """ + Example runs: + python3 TTS/demos/xtts_ft_demo/xtts_demo.py --port + """, + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + "--port", + type=int, + help="Port to run the gradio demo. Default: 5003", + default=5003, + ) + parser.add_argument( + "--out_path", + type=str, + help="Output path (where data and checkpoints will be saved) Default: /tmp/xtts_ft/", + default="/tmp/xtts_ft/", + ) + + parser.add_argument( + "--num_epochs", + type=int, + help="Number of epochs to train. Default: 10", + default=10, + ) + parser.add_argument( + "--batch_size", + type=int, + help="Batch size. Default: 4", + default=4, + ) + parser.add_argument( + "--grad_acumm", + type=int, + help="Grad accumulation steps. Default: 1", + default=1, + ) + parser.add_argument( + "--max_audio_length", + type=int, + help="Max permitted audio size in seconds. Default: 11", + default=11, + ) + + args = parser.parse_args() + + with gr.Blocks() as demo: + with gr.Tab("1 - Data processing"): + out_path = gr.Textbox( + label="Output path (where data and checkpoints will be saved):", + value=args.out_path, + ) + # upload_file = gr.Audio( + # sources="upload", + # label="Select here the audio files that you want to use for XTTS trainining !", + # type="filepath", + # ) + upload_file = gr.File( + file_count="multiple", + label="Select here the audio files that you want to use for XTTS trainining (Supported formats: wav, mp3, and flac)", + ) + lang = gr.Dropdown( + label="Dataset Language", + value="en", + choices=[ + "en", + "es", + "fr", + "de", + "it", + "pt", + "pl", + "tr", + "ru", + "nl", + "cs", + "ar", + "zh", + "hu", + "ko", + "ja" + ], + ) + progress_data = gr.Label( + label="Progress:" + ) + logs = gr.Textbox( + label="Logs:", + interactive=False, + ) + demo.load(read_logs, None, logs, every=1) + + prompt_compute_btn = gr.Button(value="Step 1 - Create dataset") + + def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(track_tqdm=True)): + clear_gpu_cache() + out_path = os.path.join(out_path, "dataset") + os.makedirs(out_path, exist_ok=True) + if audio_path is None: + return "You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!", "", "" + else: + try: + train_meta, eval_meta, audio_total_size = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress) + except: + traceback.print_exc() + error = traceback.format_exc() + return f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}", "", "" + + clear_gpu_cache() + + # if audio total len is less than 2 minutes raise an error + if audio_total_size < 120: + message = "The sum of the duration of the audios that you provided should be at least 2 minutes!" + print(message) + return message, "", "" + + print("Dataset Processed!") + return "Dataset Processed!", train_meta, eval_meta + + with gr.Tab("2 - Fine-tuning XTTS Encoder"): + train_csv = gr.Textbox( + label="Train CSV:", + ) + eval_csv = gr.Textbox( + label="Eval CSV:", + ) + num_epochs = gr.Slider( + label="Number of epochs:", + minimum=1, + maximum=100, + step=1, + value=args.num_epochs, + ) + batch_size = gr.Slider( + label="Batch size:", + minimum=2, + maximum=512, + step=1, + value=args.batch_size, + ) + grad_acumm = gr.Slider( + label="Grad accumulation steps:", + minimum=2, + maximum=128, + step=1, + value=args.grad_acumm, + ) + max_audio_length = gr.Slider( + label="Max permitted audio size in seconds:", + minimum=2, + maximum=20, + step=1, + value=args.max_audio_length, + ) + progress_train = gr.Label( + label="Progress:" + ) + logs_tts_train = gr.Textbox( + label="Logs:", + interactive=False, + ) + demo.load(read_logs, None, logs_tts_train, every=1) + train_btn = gr.Button(value="Step 2 - Run the training") + + def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length): + clear_gpu_cache() + if not train_csv or not eval_csv: + return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", "" + try: + # convert seconds to waveform frames + max_audio_length = int(max_audio_length * 22050) + config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path, max_audio_length=max_audio_length) + except: + traceback.print_exc() + error = traceback.format_exc() + return f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}", "", "", "", "" + + # copy original files to avoid parameters changes issues + os.system(f"cp {config_path} {exp_path}") + os.system(f"cp {vocab_file} {exp_path}") + + ft_xtts_checkpoint = os.path.join(exp_path, "best_model.pth") + print("Model training done!") + clear_gpu_cache() + return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint, speaker_wav + + with gr.Tab("3 - Inference"): + with gr.Row(): + with gr.Column() as col1: + xtts_checkpoint = gr.Textbox( + label="XTTS checkpoint path:", + value="", + ) + xtts_config = gr.Textbox( + label="XTTS config path:", + value="", + ) + + xtts_vocab = gr.Textbox( + label="XTTS vocab path:", + value="", + ) + progress_load = gr.Label( + label="Progress:" + ) + load_btn = gr.Button(value="Step 3 - Load Fine-tuned XTTS model") + + with gr.Column() as col2: + speaker_reference_audio = gr.Textbox( + label="Speaker reference audio:", + value="", + ) + tts_language = gr.Dropdown( + label="Language", + value="en", + choices=[ + "en", + "es", + "fr", + "de", + "it", + "pt", + "pl", + "tr", + "ru", + "nl", + "cs", + "ar", + "zh", + "hu", + "ko", + "ja", + ] + ) + tts_text = gr.Textbox( + label="Input Text.", + value="This model sounds really good and above all, it's reasonably fast.", + ) + tts_btn = gr.Button(value="Step 4 - Inference") + + with gr.Column() as col3: + progress_gen = gr.Label( + label="Progress:" + ) + tts_output_audio = gr.Audio(label="Generated Audio.") + reference_audio = gr.Audio(label="Reference audio used.") + + prompt_compute_btn.click( + fn=preprocess_dataset, + inputs=[ + upload_file, + lang, + out_path, + ], + outputs=[ + progress_data, + train_csv, + eval_csv, + ], + ) + + + train_btn.click( + fn=train_model, + inputs=[ + lang, + train_csv, + eval_csv, + num_epochs, + batch_size, + grad_acumm, + out_path, + max_audio_length, + ], + outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint, speaker_reference_audio], + ) + + load_btn.click( + fn=load_model, + inputs=[ + xtts_checkpoint, + xtts_config, + xtts_vocab + ], + outputs=[progress_load], + ) + + tts_btn.click( + fn=run_tts, + inputs=[ + tts_language, + tts_text, + speaker_reference_audio, + ], + outputs=[progress_gen, tts_output_audio, reference_audio], + ) + + demo.launch( + share=True, + debug=False, + server_port=args.port, + server_name="0.0.0.0" + ) diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/TTS/tts/layers/xtts/trainer/gpt_trainer.py index 4789e1f43f..671be8eb7f 100644 --- a/TTS/tts/layers/xtts/trainer/gpt_trainer.py +++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py @@ -225,11 +225,11 @@ def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels @torch.no_grad() def test_run(self, assets) -> Tuple[Dict, Dict]: # pylint: disable=W0613 + test_audios = {} if self.config.test_sentences: # init gpt for inference mode self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False) self.xtts.gpt.eval() - test_audios = {} print(" | > Synthesizing test sentences.") for idx, s_info in enumerate(self.config.test_sentences): wav = self.xtts.synthesize( diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 208ec4d561..6b8cc59101 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -272,6 +272,11 @@ def get_gpt_cond_latents(self, audio, sr, length: int = 30, chunk_length: int = style_embs = [] for i in range(0, audio.shape[1], 22050 * chunk_length): audio_chunk = audio[:, i : i + 22050 * chunk_length] + + # if the chunk is too short ignore it + if audio_chunk.size(-1) < 22050 * 0.33: + continue + mel_chunk = wav_to_mel_cloning( audio_chunk, mel_norms=self.mel_stats.cpu(), diff --git a/docs/source/models/xtts.md b/docs/source/models/xtts.md index 03e44af170..527dd3d068 100644 --- a/docs/source/models/xtts.md +++ b/docs/source/models/xtts.md @@ -175,6 +175,50 @@ torchaudio.save("xtts_streaming.wav", wav.squeeze().unsqueeze(0).cpu(), 24000) ### Training +#### Easy training +To make `XTTS_v2` GPT encoder training easier for beginner users we did a gradio demo that implements the whole fine-tuning pipeline. The gradio demo enables the user to easily do the following steps: + +- Preprocessing of the uploaded audio or audio files in 🐸 TTS coqui formatter +- Train the XTTS GPT encoder with the processed data +- Inference support using the fine-tuned model + +The user can run this gradio demo locally or remotely using a Colab Notebook. + +##### Run demo on Colab +To make the `XTTS_v2` fine-tuning more accessible for users that do not have good GPUs available we did a Google Colab Notebook. + +The Colab Notebook is available [here](https://colab.research.google.com/drive/1GiI4_X724M8q2W-zZ-jXo7cWTV7RfaH-?usp=sharing). + +To learn how to use this Colab Notebook please check the [XTTS fine-tuning video](). + +If you are not able to acess the video you need to follow the steps: + +1. Open the Colab notebook and start the demo by runining the first two cells (ignore pip install errors in the first one). +2. Click on the link "Running on public URL:" on the second cell output. +3. On the first Tab (1 - Data processing) you need to select the audio file or files, wait for upload, and then click on the button "Step 1 - Create dataset" and then wait until the dataset processing is done. +4. Soon as the dataset processing is done you need to go to the second Tab (2 - Fine-tuning XTTS Encoder) and press the button "Step 2 - Run the training" and then wait until the training is finished. Note that it can take up to 40 minutes. +5. Soon the training is done you can go to the third Tab (3 - Inference) and then click on the button "Step 3 - Load Fine-tuned XTTS model" and wait until the fine-tuned model is loaded. Then you can do the inference on the model by clicking on the button "Step 4 - Inference". + + +##### Run demo locally + +To run the demo locally you need to do the following steps: +1. Install 🐸 TTS following the instructions available [here](https://tts.readthedocs.io/en/dev/installation.html#installation). +2. Install the Gradio demo requirements with the command `python3 -m pip install -r TTS/demos/xtts_ft_demo/requirements.txt` +3. Run the Gradio demo using the command `python3 TTS/demos/xtts_ft_demo/xtts_demo.py` +4. Follow the steps presented in the [tutorial video](https://www.youtube.com/watch?v=8tpDiiouGxc&feature=youtu.be) to be able to fine-tune and test the fine-tuned model. + + +If you are not able to access the video, here is what you need to do: + +1. On the first Tab (1 - Data processing) select the audio file or files, wait for upload +2. Click on the button "Step 1 - Create dataset" and then wait until the dataset processing is done. +3. Go to the second Tab (2 - Fine-tuning XTTS Encoder) and press the button "Step 2 - Run the training" and then wait until the training is finished. it will take some time. +4. Go to the third Tab (3 - Inference) and then click on the button "Step 3 - Load Fine-tuned XTTS model" and wait until the fine-tuned model is loaded. +5. Now you can run inference with the model by clicking on the button "Step 4 - Inference". + +#### Advanced training + A recipe for `XTTS_v2` GPT encoder training using `LJSpeech` dataset is available at https://github.com/coqui-ai/TTS/tree/dev/recipes/ljspeech/xtts_v1/train_gpt_xtts.py You need to change the fields of the `BaseDatasetConfig` to match your dataset and then update `GPTArgs` and `GPTTrainerConfig` fields as you need. By default, it will use the same parameters that XTTS v1.1 model was trained with. To speed up the model convergence, as default, it will also download the XTTS v1.1 checkpoint and load it. @@ -222,6 +266,7 @@ torchaudio.save(OUTPUT_WAV_PATH, torch.tensor(out["wav"]).unsqueeze(0), 24000) ``` + ## References and Acknowledgements - VallE: https://arxiv.org/abs/2301.02111 - Tortoise Repo: https://github.com/neonbjb/tortoise-tts