diff --git a/TTS/bin/compute_embeddings.py b/TTS/bin/compute_embeddings.py index 5081715432..68571fb440 100644 --- a/TTS/bin/compute_embeddings.py +++ b/TTS/bin/compute_embeddings.py @@ -42,33 +42,35 @@ meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=args.eval) wav_files = meta_data_train + meta_data_eval -speaker_manager = SpeakerManager( +encoder_manager = SpeakerManager( encoder_model_path=args.model_path, encoder_config_path=args.config_path, d_vectors_file_path=args.old_file, use_cuda=args.use_cuda, ) +class_name_key = encoder_manager.speaker_encoder_config.class_name_key + # compute speaker embeddings speaker_mapping = {} for idx, wav_file in enumerate(tqdm(wav_files)): - if isinstance(wav_file, list): - speaker_name = wav_file[2] - wav_file = wav_file[1] + if isinstance(wav_file, dict): + class_name = wav_file[class_name_key] + wav_file = wav_file["audio_file"] else: - speaker_name = None + class_name = None wav_file_name = os.path.basename(wav_file) - if args.old_file is not None and wav_file_name in speaker_manager.clip_ids: + if args.old_file is not None and wav_file_name in encoder_manager.clip_ids: # get the embedding from the old file - embedd = speaker_manager.get_d_vector_by_clip(wav_file_name) + embedd = encoder_manager.get_d_vector_by_clip(wav_file_name) else: # extract the embedding - embedd = speaker_manager.compute_d_vector_from_clip(wav_file) + embedd = encoder_manager.compute_d_vector_from_clip(wav_file) # create speaker_mapping if target dataset is defined speaker_mapping[wav_file_name] = {} - speaker_mapping[wav_file_name]["name"] = speaker_name + speaker_mapping[wav_file_name]["name"] = class_name speaker_mapping[wav_file_name]["embedding"] = embedd if speaker_mapping: @@ -81,5 +83,5 @@ os.makedirs(os.path.dirname(mapping_file_path), exist_ok=True) # pylint: disable=W0212 - speaker_manager._save_json(mapping_file_path, speaker_mapping) + encoder_manager._save_json(mapping_file_path, speaker_mapping) print("Speaker embeddings saved at:", mapping_file_path) diff --git a/TTS/bin/eval_encoder.py b/TTS/bin/eval_encoder.py new file mode 100644 index 0000000000..a03bfd824f --- /dev/null +++ b/TTS/bin/eval_encoder.py @@ -0,0 +1,88 @@ +import argparse +import torch +from argparse import RawTextHelpFormatter + +from tqdm import tqdm + +from TTS.config import load_config +from TTS.tts.datasets import load_tts_samples +from TTS.tts.utils.speakers import SpeakerManager + +def compute_encoder_accuracy(dataset_items, encoder_manager): + + class_name_key = encoder_manager.speaker_encoder_config.class_name_key + map_classid_to_classname = getattr(encoder_manager.speaker_encoder_config, 'map_classid_to_classname', None) + + class_acc_dict = {} + + # compute embeddings for all wav_files + for item in tqdm(dataset_items): + class_name = item[class_name_key] + wav_file = item["audio_file"] + + # extract the embedding + embedd = encoder_manager.compute_d_vector_from_clip(wav_file) + if encoder_manager.speaker_encoder_criterion is not None and map_classid_to_classname is not None: + embedding = torch.FloatTensor(embedd).unsqueeze(0) + if encoder_manager.use_cuda: + embedding = embedding.cuda() + + class_id = encoder_manager.speaker_encoder_criterion.softmax.inference(embedding).item() + predicted_label = map_classid_to_classname[str(class_id)] + else: + predicted_label = None + + if class_name is not None and predicted_label is not None: + is_equal = int(class_name == predicted_label) + if class_name not in class_acc_dict: + class_acc_dict[class_name] = 
[is_equal]
+                else:
+                    class_acc_dict[class_name].append(is_equal)
+            else:
+                raise RuntimeError("Error: class_name and/or predicted_label are None")
+
+    acc_avg = 0
+    for key, values in class_acc_dict.items():
+        acc = sum(values)/len(values)
+        print("Class", key, "Accuracy:", acc)
+        acc_avg += acc
+
+    print("Average Accuracy:", acc_avg/len(class_acc_dict))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="""Compute the accuracy of the encoder.\n\n"""
+        """
+        Example runs:
+        python TTS/bin/eval_encoder.py emotion_encoder_model.pth.tar emotion_encoder_config.json dataset_config.json
+        """,
+        formatter_class=RawTextHelpFormatter,
+    )
+    parser.add_argument("model_path", type=str, help="Path to model checkpoint file.")
+    parser.add_argument(
+        "config_path",
+        type=str,
+        help="Path to model config file.",
+    )
+
+    parser.add_argument(
+        "config_dataset_path",
+        type=str,
+        help="Path to dataset config file.",
+    )
+    parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True)
+    parser.add_argument("--eval", type=bool, help="compute eval.", default=True)
+
+    args = parser.parse_args()
+
+    c_dataset = load_config(args.config_dataset_path)
+
+    meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=args.eval)
+    items = meta_data_train + meta_data_eval
+
+    enc_manager = SpeakerManager(
+        encoder_model_path=args.model_path, encoder_config_path=args.config_path, use_cuda=args.use_cuda
+    )
+
+    compute_encoder_accuracy(items, enc_manager)
diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py
index b7424698f6..af3e6ec4a1 100644
--- a/TTS/bin/train_encoder.py
+++ b/TTS/bin/train_encoder.py
@@ -10,16 +10,16 @@ from torch.utils.data import DataLoader
 from trainer.torch import NoamLR
 
-from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
-from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
-from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_speaker_encoder_model
-from TTS.speaker_encoder.utils.training import init_training
-from TTS.speaker_encoder.utils.visual import plot_embeddings
+from TTS.encoder.dataset import EncoderDataset
+from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_speaker_encoder_model
+from TTS.encoder.utils.samplers import PerfectBatchSampler
+from TTS.encoder.utils.training import init_training
+from TTS.encoder.utils.visual import plot_embeddings
 from TTS.tts.datasets import load_tts_samples
 from TTS.utils.audio import AudioProcessor
-from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
-from TTS.utils.io import load_fsspec
-from TTS.utils.radam import RAdam
+from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
+from TTS.utils.io import copy_model_files
+from trainer.trainer_utils import get_optimizer
 from TTS.utils.training import check_update
 
 torch.backends.cudnn.enabled = True
@@ -32,164 +32,238 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False):
+    num_utter_per_class = c.num_utter_per_class if not is_val else c.eval_num_utter_per_class
+    num_classes_in_batch = c.num_classes_in_batch if not is_val else c.eval_num_classes_in_batch
+
+    dataset = EncoderDataset(
+        c,
+        ap,
+        meta_data_eval if is_val else meta_data_train,
+        voice_len=c.voice_len,
+        num_utter_per_class=num_utter_per_class,
+        num_classes_in_batch=num_classes_in_batch,
+        verbose=verbose,
+        augmentation_config=c.audio_augmentation if not is_val else None,
+        use_torch_spec=c.model_params.get("use_torch_spec", False),
+    )
+    # get classes list
+    classes = dataset.get_class_list()
+
+    sampler = PerfectBatchSampler(
+        dataset.items,
+        classes,
+        batch_size=num_classes_in_batch*num_utter_per_class,  # total batch size
+        num_classes_in_batch=num_classes_in_batch,
+        num_gpus=1,
+        shuffle=not is_val,
+        drop_last=True)
+
+    if len(classes) < num_classes_in_batch:
+        if is_val:
+            raise RuntimeError(f"config.eval_num_classes_in_batch ({num_classes_in_batch}) needs to be <= {len(classes)} (the total number of classes in the eval dataset)!")
+        raise RuntimeError(f"config.num_classes_in_batch ({num_classes_in_batch}) needs to be <= {len(classes)} (the total number of classes in the train dataset)!")
+
+    # set the classes to avoid getting a wrong class_id when the numbers of training and eval classes are not equal
     if is_val:
-        loader = None
-    else:
-        dataset = SpeakerEncoderDataset(
-            ap,
-            meta_data_eval if is_val else meta_data_train,
-            voice_len=c.voice_len,
-            num_utter_per_speaker=c.num_utters_per_speaker,
-            num_speakers_in_batch=c.num_speakers_in_batch,
-            skip_speakers=c.skip_speakers,
-            storage_size=c.storage["storage_size"],
-            sample_from_storage_p=c.storage["sample_from_storage_p"],
-            verbose=verbose,
-            augmentation_config=c.audio_augmentation,
-            use_torch_spec=c.model_params.get("use_torch_spec", False),
-        )
+        dataset.set_classes(train_classes)
 
-    # sampler = DistributedSampler(dataset) if num_gpus > 1 else None
-    loader = DataLoader(
-        dataset,
-        batch_size=c.num_speakers_in_batch,
-        shuffle=False,
-        num_workers=c.num_loader_workers,
-        collate_fn=dataset.collate_fn,
-    )
-    return loader, dataset.get_num_speakers()
+    loader = DataLoader(
+        dataset,
+        num_workers=c.num_loader_workers,
+        batch_sampler=sampler,
+        collate_fn=dataset.collate_fn,
+    )
+    return loader, classes, dataset.get_map_classid_to_classname()
 
 
-def train(model, optimizer, scheduler, criterion, data_loader, global_step):
+def evaluation(model, criterion, data_loader, global_step):
+    eval_loss = 0
+    for _, data in enumerate(data_loader):
+        with torch.no_grad():
+            # setup input data
+            inputs, labels = data
+
+            # group samples of each class in the batch: the perfect sampler produces [3,2,1,3,2,1], but we need [3,3,2,2,1,1]
+            labels = torch.transpose(labels.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch), 0, 1).reshape(labels.shape)
+            inputs = torch.transpose(inputs.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch, -1), 0, 1).reshape(inputs.shape)
+
+            # dispatch data to GPU
+            if use_cuda:
+                inputs = inputs.cuda(non_blocking=True)
+                labels = labels.cuda(non_blocking=True)
+
+            # forward pass model
+            outputs = model(inputs)
+
+            # loss computation
+            loss = criterion(outputs.view(c.eval_num_classes_in_batch, outputs.shape[0] // c.eval_num_classes_in_batch, -1), labels)
+
+            eval_loss += loss.item()
+
+    eval_avg_loss = eval_loss/len(data_loader)
+    # save stats
+    dashboard_logger.eval_stats(global_step, {"loss": eval_avg_loss})
+    # plot the last batch in the evaluation
+    figures = {
+        "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch),
+    }
+    dashboard_logger.eval_figures(global_step, figures)
+    return eval_avg_loss
+
+def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader, global_step):
     model.train()
-    epoch_time = 0
     best_loss = float("inf")
-    avg_loss = 0
-    avg_loss_all = 0
     avg_loader_time = 0
     end_time = time.time()
-
-    for _, data in enumerate(data_loader):
-        start_time = time.time()
-
-        # setup input data
-        inputs, labels = data
-        loader_time = time.time() - end_time
-        global_step += 1
-
-        # setup lr
-        if c.lr_decay:
-            scheduler.step()
-        optimizer.zero_grad()
-
-        # dispatch data to GPU
-        if use_cuda:
-            inputs = inputs.cuda(non_blocking=True)
-            labels = labels.cuda(non_blocking=True)
-
-        # forward pass model
-        outputs = model(inputs)
-
-        # loss computation
-        loss = criterion(outputs.view(c.num_speakers_in_batch, outputs.shape[0] // c.num_speakers_in_batch, -1), labels)
-        loss.backward()
-        grad_norm, _ = check_update(model, c.grad_clip)
-        optimizer.step()
-
-        step_time = time.time() - start_time
-        epoch_time += step_time
-
-        # Averaged Loss and Averaged Loader Time
-        avg_loss = 0.01 * loss.item() + 0.99 * avg_loss if avg_loss != 0 else loss.item()
-        num_loader_workers = c.num_loader_workers if c.num_loader_workers > 0 else 1
-        avg_loader_time = (
-            1 / num_loader_workers * loader_time + (num_loader_workers - 1) / num_loader_workers * avg_loader_time
-            if avg_loader_time != 0
-            else loader_time
+    for epoch in range(c.epochs):
+        tot_loss = 0
+        epoch_time = 0
+        for _, data in enumerate(data_loader):
+            start_time = time.time()
+
+            # setup input data
+            inputs, labels = data
+            # group samples of each class in the batch: the perfect sampler produces [3,2,1,3,2,1], but we need [3,3,2,2,1,1]
+            labels = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape)
+            inputs = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape)
+            # ToDo: move it to a unit test
+            # labels_converted = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape)
+            # inputs_converted = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape)
+            # idx = 0
+            # for j in range(0, c.num_classes_in_batch, 1):
+            #     for i in range(j, len(labels), c.num_classes_in_batch):
+            #         if not torch.all(labels[i].eq(labels_converted[idx])) or not torch.all(inputs[i].eq(inputs_converted[idx])):
+            #             print("Invalid")
+            #             print(labels)
+            #             exit()
+            #         idx += 1
+            # labels = labels_converted
+            # inputs = inputs_converted
+
+            loader_time = time.time() - end_time
+            global_step += 1
+
+            # setup lr
+            if c.lr_decay:
+                scheduler.step()
+            optimizer.zero_grad()
+
+            # dispatch data to GPU
+            if use_cuda:
+                inputs = inputs.cuda(non_blocking=True)
+                labels = labels.cuda(non_blocking=True)
+
+            # forward pass model
+            outputs = model(inputs)
+
+            # loss computation
+            loss = criterion(outputs.view(c.num_classes_in_batch, outputs.shape[0] // c.num_classes_in_batch, -1), labels)
+            loss.backward()
+            grad_norm, _ = check_update(model, c.grad_clip)
+            optimizer.step()
+
+            step_time = time.time() - start_time
+            epoch_time += step_time
+
+            # accumulate the total epoch loss
+            tot_loss += loss.item()
+
+            # Averaged Loader Time
+            num_loader_workers = c.num_loader_workers if c.num_loader_workers > 0 else 1
+            avg_loader_time = (
+                1 / num_loader_workers * loader_time + (num_loader_workers - 1) / num_loader_workers * avg_loader_time
+                if avg_loader_time != 0
+                else loader_time
+            )
+            current_lr = optimizer.param_groups[0]["lr"]
+
+            if global_step % c.steps_plot_stats == 0:
+                # Plot Training Epoch Stats
+                train_stats = {
+                    "loss": loss.item(),
+                    "lr": current_lr,
+                    "grad_norm": grad_norm,
+                    "step_time": step_time,
+                    "avg_loader_time": avg_loader_time,
+                }
+                dashboard_logger.train_epoch_stats(global_step, train_stats)
+                figures = {
+                    "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch),
+                }
+                dashboard_logger.train_figures(global_step, figures)
+
+            if global_step % c.print_step == 0:
+                print(
+                    " | > Step:{} Loss:{:.5f} GradNorm:{:.5f} "
+                    "StepTime:{:.2f} LoaderTime:{:.2f} AvGLoaderTime:{:.2f} LR:{:.6f}".format(
+                        global_step, loss.item(), grad_norm, step_time, loader_time, avg_loader_time, current_lr
+                    ),
+                    flush=True,
+                )
+
+            if global_step % c.save_step == 0:
+                # save model
+                save_checkpoint(model, optimizer, criterion, loss.item(), OUT_PATH, global_step, epoch)
+
+            end_time = time.time()
+
+        print("")
+        print(
+            ">>> Epoch:{} AvgLoss: {:.5f} GradNorm:{:.5f} "
+            "EpochTime:{:.2f} AvGLoaderTime:{:.2f} ".format(
+                epoch, tot_loss/len(data_loader), grad_norm, epoch_time, avg_loader_time
             ),
+        flush=True,
     )
-    current_lr = optimizer.param_groups[0]["lr"]
-
-    if global_step % c.steps_plot_stats == 0:
-        # Plot Training Epoch Stats
-        train_stats = {
-            "loss": avg_loss,
-            "lr": current_lr,
-            "grad_norm": grad_norm,
-            "step_time": step_time,
-            "avg_loader_time": avg_loader_time,
-        }
-        dashboard_logger.train_epoch_stats(global_step, train_stats)
-        figures = {
-            # FIXME: not constant
-            "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), 10),
-        }
-
dashboard_logger.train_figures(global_step, figures) - - if global_step % c.print_step == 0: + # evaluation + if c.run_eval: + model.eval() + eval_loss = evaluation(model, criterion, eval_data_loader, global_step) + print("\n\n") + print("--> EVAL PERFORMANCE") print( - " | > Step:{} Loss:{:.5f} AvgLoss:{:.5f} GradNorm:{:.5f} " - "StepTime:{:.2f} LoaderTime:{:.2f} AvGLoaderTime:{:.2f} LR:{:.6f}".format( - global_step, loss.item(), avg_loss, grad_norm, step_time, loader_time, avg_loader_time, current_lr - ), - flush=True, + " | > Epoch:{} AvgLoss: {:.5f} ".format( + epoch, eval_loss + ), + flush=True, ) - avg_loss_all += avg_loss + # save the best checkpoint + best_loss = save_best_model(model, optimizer, criterion, eval_loss, best_loss, OUT_PATH, global_step, epoch) + model.train() - if global_step >= c.max_train_step or global_step % c.save_step == 0: - # save best model only - best_loss = save_best_model(model, optimizer, criterion, avg_loss, best_loss, OUT_PATH, global_step) - avg_loss_all = 0 - if global_step >= c.max_train_step: - break - - end_time = time.time() - - return avg_loss, global_step + return best_loss, global_step def main(args): # pylint: disable=redefined-outer-name # pylint: disable=global-variable-undefined global meta_data_train global meta_data_eval + global train_classes ap = AudioProcessor(**c.audio) model = setup_speaker_encoder_model(c) - optimizer = RAdam(model.parameters(), lr=c.lr) + optimizer = get_optimizer(c.optimizer, c.optimizer_params, c.lr, model) # pylint: disable=redefined-outer-name - meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=False) - - data_loader, num_speakers = setup_loader(ap, is_val=False, verbose=True) + meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=True) - if c.loss == "ge2e": - criterion = GE2ELoss(loss_method="softmax") - elif c.loss == "angleproto": - criterion = AngleProtoLoss() - elif c.loss == "softmaxproto": - criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_speakers) + train_data_loader, train_classes, map_classid_to_classname = setup_loader(ap, is_val=False, verbose=True) + if c.run_eval: + eval_data_loader, _, _ = setup_loader(ap, is_val=True, verbose=True) else: - raise Exception("The %s not is a loss supported" % c.loss) + eval_data_loader = None + + num_classes = len(train_classes) + criterion = model.get_criterion(c, num_classes) + + if c.loss == "softmaxproto" and c.model != "speaker_encoder": + c.map_classid_to_classname = map_classid_to_classname + copy_model_files(c, OUT_PATH) if args.restore_path: - checkpoint = load_fsspec(args.restore_path) - try: - model.load_state_dict(checkpoint["model"]) - - if "criterion" in checkpoint: - criterion.load_state_dict(checkpoint["criterion"]) - - except (KeyError, RuntimeError): - print(" > Partial model initialization.") - model_dict = model.state_dict() - model_dict = set_init_dict(model_dict, checkpoint["model"], c) - model.load_state_dict(model_dict) - del model_dict - for group in optimizer.param_groups: - group["lr"] = c.lr - - print(" > Model restored from step %d" % checkpoint["step"], flush=True) - args.restore_step = checkpoint["step"] + criterion, args.restore_step = model.load_checkpoint(c, args.restore_path, eval=False, use_cuda=use_cuda, criterion=criterion) + print(" > Model restored from step %d" % args.restore_step, flush=True) else: args.restore_step = 0 @@ -206,7 +280,7 @@ def main(args): # pylint: disable=redefined-outer-name criterion.cuda() global_step = args.restore_step - _, global_step = 
train(model, optimizer, scheduler, criterion, data_loader, global_step) + _, global_step = train(model, optimizer, scheduler, criterion, train_data_loader, eval_data_loader, global_step) if __name__ == "__main__": diff --git a/TTS/config/__init__.py b/TTS/config/__init__.py index 5c905295ec..6b0778c5a7 100644 --- a/TTS/config/__init__.py +++ b/TTS/config/__init__.py @@ -37,7 +37,7 @@ def register_config(model_name: str) -> Coqpit: """ config_class = None config_name = model_name + "_config" - paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.speaker_encoder"] + paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.encoder.configs"] for path in paths: try: config_class = find_module(path, config_name) diff --git a/TTS/speaker_encoder/README.md b/TTS/encoder/README.md similarity index 100% rename from TTS/speaker_encoder/README.md rename to TTS/encoder/README.md diff --git a/TTS/speaker_encoder/__init__.py b/TTS/encoder/__init__.py similarity index 100% rename from TTS/speaker_encoder/__init__.py rename to TTS/encoder/__init__.py diff --git a/TTS/speaker_encoder/speaker_encoder_config.py b/TTS/encoder/configs/base_encoder_config.py similarity index 66% rename from TTS/speaker_encoder/speaker_encoder_config.py rename to TTS/encoder/configs/base_encoder_config.py index 8212acc73b..02b88d6630 100644 --- a/TTS/speaker_encoder/speaker_encoder_config.py +++ b/TTS/encoder/configs/base_encoder_config.py @@ -7,10 +7,10 @@ @dataclass -class SpeakerEncoderConfig(BaseTrainingConfig): - """Defines parameters for Speaker Encoder model.""" +class BaseEncoderConfig(BaseTrainingConfig): + """Defines parameters for a Generic Encoder model.""" - model: str = "speaker_encoder" + model: str = None audio: BaseAudioConfig = field(default_factory=BaseAudioConfig) datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()]) # model params @@ -27,34 +27,33 @@ class SpeakerEncoderConfig(BaseTrainingConfig): audio_augmentation: Dict = field(default_factory=lambda: {}) - storage: Dict = field( - default_factory=lambda: { - "sample_from_storage_p": 0.66, # the probability with which we'll sample from the DataSet in-memory storage - "storage_size": 15, # the size of the in-memory storage with respect to a single batch - } - ) - # training params - max_train_step: int = 1000000 # end training when number of training steps reaches this value. 
+    epochs: int = 10000
     loss: str = "angleproto"
     grad_clip: float = 3.0
     lr: float = 0.0001
+    optimizer: str = "radam"
+    optimizer_params: Dict = field(default_factory=lambda: {
+        "betas": [0.9, 0.999],
+        "weight_decay": 0
+    })
     lr_decay: bool = False
     warmup_steps: int = 4000
-    wd: float = 1e-6
     # logging params
     tb_model_param_stats: bool = False
     steps_plot_stats: int = 10
-    checkpoint: bool = True
     save_step: int = 1000
     print_step: int = 20
+    run_eval: bool = False
     # data loader
-    num_speakers_in_batch: int = MISSING
-    num_utters_per_speaker: int = MISSING
+    num_classes_in_batch: int = MISSING
+    num_utter_per_class: int = MISSING
+    eval_num_classes_in_batch: int = None
+    eval_num_utter_per_class: int = None
+
     num_loader_workers: int = MISSING
-    skip_speakers: bool = False
     voice_len: float = 1.6
 
     def check_values(self):
diff --git a/TTS/encoder/configs/emotion_encoder_config.py b/TTS/encoder/configs/emotion_encoder_config.py
new file mode 100644
index 0000000000..5eda2671be
--- /dev/null
+++ b/TTS/encoder/configs/emotion_encoder_config.py
@@ -0,0 +1,12 @@
+from dataclasses import asdict, dataclass
+
+from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig
+
+
+@dataclass
+class EmotionEncoderConfig(BaseEncoderConfig):
+    """Defines parameters for Emotion Encoder model."""
+
+    model: str = "emotion_encoder"
+    map_classid_to_classname: dict = None
+    class_name_key: str = "emotion_name"
diff --git a/TTS/encoder/configs/speaker_encoder_config.py b/TTS/encoder/configs/speaker_encoder_config.py
new file mode 100644
index 0000000000..6dceb00277
--- /dev/null
+++ b/TTS/encoder/configs/speaker_encoder_config.py
@@ -0,0 +1,11 @@
+from dataclasses import asdict, dataclass
+
+from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig
+
+
+@dataclass
+class SpeakerEncoderConfig(BaseEncoderConfig):
+    """Defines parameters for Speaker Encoder model."""
+
+    model: str = "speaker_encoder"
+    class_name_key: str = "speaker_name"
diff --git a/TTS/encoder/dataset.py b/TTS/encoder/dataset.py
new file mode 100644
index 0000000000..a4db021bd0
--- /dev/null
+++ b/TTS/encoder/dataset.py
@@ -0,0 +1,149 @@
+import random
+
+import torch
+from torch.utils.data import Dataset
+
+from TTS.encoder.utils.generic_utils import AugmentWAV
+
+class EncoderDataset(Dataset):
+    def __init__(
+        self,
+        config,
+        ap,
+        meta_data,
+        voice_len=1.6,
+        num_classes_in_batch=64,
+        num_utter_per_class=10,
+        verbose=False,
+        augmentation_config=None,
+        use_torch_spec=None,
+    ):
+        """
+        Args:
+            ap (TTS.tts.utils.AudioProcessor): audio processor object.
+            meta_data (list): list of dataset instances.
+            voice_len (float): voice segment length in seconds.
+            verbose (bool): print diagnostic information.
+        """
+        super().__init__()
+        self.config = config
+        self.items = meta_data
+        self.sample_rate = ap.sample_rate
+        self.seq_len = int(voice_len * self.sample_rate)
+        self.num_utter_per_class = num_utter_per_class
+        self.ap = ap
+        self.verbose = verbose
+        self.use_torch_spec = use_torch_spec
+        self.classes, self.items = self.__parse_items()
+
+        self.classname_to_classid = {key: i for i, key in enumerate(self.classes)}
+
+        # Data Augmentation
+        self.augmentator = None
+        self.gaussian_augmentation_config = None
+        if augmentation_config:
+            self.data_augmentation_p = augmentation_config["p"]
+            if self.data_augmentation_p and ("additive" in augmentation_config or "rir" in augmentation_config):
+                self.augmentator = AugmentWAV(ap, augmentation_config)
+
+            if "gaussian" in augmentation_config.keys():
+                self.gaussian_augmentation_config = augmentation_config["gaussian"]
+
+        if self.verbose:
+            print("\n > DataLoader initialization")
+            print(f" | > Classes per Batch: {num_classes_in_batch}")
+            print(f" | > Number of instances : {len(self.items)}")
+            print(f" | > Sequence length: {self.seq_len}")
+            print(f" | > Num Classes: {len(self.classes)}")
+            print(f" | > Classes: {self.classes}")
+
+
+    def load_wav(self, filename):
+        audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
+        return audio
+
+    def __parse_items(self):
+        class_to_utters = {}
+        for item in self.items:
+            path_ = item["audio_file"]
+            class_name = item[self.config.class_name_key]
+            if class_name in class_to_utters.keys():
+                class_to_utters[class_name].append(path_)
+            else:
+                class_to_utters[class_name] = [
+                    path_,
+                ]
+
+        # skip classes with fewer than self.num_utter_per_class samples
+        class_to_utters = {
+            k: v for (k, v) in class_to_utters.items() if len(v) >= self.num_utter_per_class
+        }
+
+        classes = list(class_to_utters.keys())
+        classes.sort()
+
+        new_items = []
+        for item in self.items:
+            path_ = item["audio_file"]
+            class_name = item["emotion_name"] if self.config.model == "emotion_encoder" else item["speaker_name"]
+            # ignore filtered classes
+            if class_name not in classes:
+                continue
+            # ignore audios shorter than the sequence length
+            if self.load_wav(path_).shape[0] - self.seq_len <= 0:
+                continue
+
+            new_items.append({"wav_file_path": path_, "class_name": class_name})
+
+        return classes, new_items
+
+    def __len__(self):
+        return len(self.items)
+
+    def get_num_classes(self):
+        return len(self.classes)
+
+    def get_class_list(self):
+        return self.classes
+    def set_classes(self, classes):
+        self.classes = classes
+        self.classname_to_classid = {key: i for i, key in enumerate(self.classes)}
+
+
+    def get_map_classid_to_classname(self):
+        return dict((c_id, c_n) for c_n, c_id in self.classname_to_classid.items())
+
+    def __getitem__(self, idx):
+        return self.items[idx]
+
+    def collate_fn(self, batch):
+        # get the batch class_ids
+        labels = []
+        feats = []
+        for item in batch:
+            utter_path = item["wav_file_path"]
+            class_name = item["class_name"]
+
+            # get classid
+            class_id = self.classname_to_classid[class_name]
+            # load wav file
+            wav = self.load_wav(utter_path)
+            offset = random.randint(0, wav.shape[0] - self.seq_len)
+            wav = wav[offset : offset + self.seq_len]
+
+            if self.augmentator is not None and self.data_augmentation_p:
+                if random.random() < self.data_augmentation_p:
+                    wav = self.augmentator.apply_one(wav)
+
+            if not self.use_torch_spec:
+                mel = self.ap.melspectrogram(wav)
+                feats.append(torch.FloatTensor(mel))
+            else:
+                feats.append(torch.FloatTensor(wav))
+
+            labels.append(class_id)
+
+        feats = torch.stack(feats)
+        labels =
torch.LongTensor(labels) + + return feats, labels diff --git a/TTS/speaker_encoder/losses.py b/TTS/encoder/losses.py similarity index 97% rename from TTS/speaker_encoder/losses.py rename to TTS/encoder/losses.py index 8ba917b7e9..de65d8d66b 100644 --- a/TTS/speaker_encoder/losses.py +++ b/TTS/encoder/losses.py @@ -189,6 +189,11 @@ def forward(self, x, label=None): return L + def inference(self, embedding): + x = self.fc(embedding) + activations = torch.nn.functional.softmax(x, dim=1).squeeze(0) + class_id = torch.argmax(activations) + return class_id class SoftmaxAngleProtoLoss(nn.Module): """ diff --git a/TTS/encoder/models/base_encoder.py b/TTS/encoder/models/base_encoder.py new file mode 100644 index 0000000000..c35c636d74 --- /dev/null +++ b/TTS/encoder/models/base_encoder.py @@ -0,0 +1,145 @@ +import torch +import torchaudio +import numpy as np +from torch import nn + +from TTS.utils.io import load_fsspec +from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss +from TTS.utils.generic_utils import set_init_dict +from coqpit import Coqpit + +class PreEmphasis(nn.Module): + def __init__(self, coefficient=0.97): + super().__init__() + self.coefficient = coefficient + self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0)) + + def forward(self, x): + assert len(x.size()) == 2 + + x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect") + return torch.nn.functional.conv1d(x, self.filter).squeeze(1) + +class BaseEncoder(nn.Module): + """Base `encoder` class. Every new `encoder` model must inherit this. + + It defines common `encoder` specific functions. + """ + + # pylint: disable=W0102 + def __init__(self): + super(BaseEncoder, self).__init__() + + def get_torch_mel_spectrogram_class(self, audio_config): + return torch.nn.Sequential( + PreEmphasis(audio_config["preemphasis"]), + # TorchSTFT( + # n_fft=audio_config["fft_size"], + # hop_length=audio_config["hop_length"], + # win_length=audio_config["win_length"], + # sample_rate=audio_config["sample_rate"], + # window="hamming_window", + # mel_fmin=0.0, + # mel_fmax=None, + # use_htk=True, + # do_amp_to_db=False, + # n_mels=audio_config["num_mels"], + # power=2.0, + # use_mel=True, + # mel_norm=None, + # ) + torchaudio.transforms.MelSpectrogram( + sample_rate=audio_config["sample_rate"], + n_fft=audio_config["fft_size"], + win_length=audio_config["win_length"], + hop_length=audio_config["hop_length"], + window_fn=torch.hamming_window, + n_mels=audio_config["num_mels"], + ) + ) + + @torch.no_grad() + def inference(self, x, l2_norm=True): + return self.forward(x, l2_norm) + + @torch.no_grad() + def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True): + """ + Generate embeddings for a batch of utterances + x: 1xTxD + """ + # map to the waveform size + if self.use_torch_spec: + num_frames = num_frames * self.audio_config["hop_length"] + + max_len = x.shape[1] + + if max_len < num_frames: + num_frames = max_len + + offsets = np.linspace(0, max_len - num_frames, num=num_eval) + + frames_batch = [] + for offset in offsets: + offset = int(offset) + end_offset = int(offset + num_frames) + frames = x[:, offset:end_offset] + frames_batch.append(frames) + + frames_batch = torch.cat(frames_batch, dim=0) + embeddings = self.inference(frames_batch, l2_norm=l2_norm) + + if return_mean: + embeddings = torch.mean(embeddings, dim=0, keepdim=True) + return embeddings + + def get_criterion(self, c: Coqpit, num_classes=None): + if c.loss == "ge2e": + 
criterion = GE2ELoss(loss_method="softmax")
+        elif c.loss == "angleproto":
+            criterion = AngleProtoLoss()
+        elif c.loss == "softmaxproto":
+            criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_classes)
+        else:
+            raise Exception("%s is not a supported loss" % c.loss)
+        return criterion
+
+    def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, use_cuda: bool = False, criterion=None):
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        try:
+            self.load_state_dict(state["model"])
+        except (KeyError, RuntimeError) as error:
+            # if in eval mode, re-raise the error
+            if eval:
+                raise error
+
+            print(" > Partial model initialization.")
+            model_dict = self.state_dict()
+            model_dict = set_init_dict(model_dict, state["model"], config)
+            self.load_state_dict(model_dict)
+            del model_dict
+
+        # load the criterion for restore_path
+        if criterion is not None and "criterion" in state:
+            try:
+                criterion.load_state_dict(state["criterion"])
+            except (KeyError, RuntimeError) as error:
+                print(" > Criterion load ignored because of:", error)
+
+        # instantiate and load the criterion for the encoder classifier at inference time
+        if eval and criterion is None and "criterion" in state and getattr(config, 'map_classid_to_classname', None) is not None:
+            criterion = self.get_criterion(config, len(config.map_classid_to_classname))
+            criterion.load_state_dict(state["criterion"])
+
+        if use_cuda:
+            self.cuda()
+            if criterion is not None:
+                criterion = criterion.cuda()
+
+        if eval:
+            self.eval()
+            assert not self.training
+
+        if not eval:
+            return criterion, state["step"]
+        return criterion
diff --git a/TTS/encoder/models/lstm.py b/TTS/encoder/models/lstm.py
new file mode 100644
index 0000000000..51852b5b82
--- /dev/null
+++ b/TTS/encoder/models/lstm.py
@@ -0,0 +1,99 @@
+import torch
+from torch import nn
+
+from TTS.encoder.models.base_encoder import BaseEncoder
+
+
+class LSTMWithProjection(nn.Module):
+    def __init__(self, input_size, hidden_size, proj_size):
+        super().__init__()
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.proj_size = proj_size
+        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
+        self.linear = nn.Linear(hidden_size, proj_size, bias=False)
+
+    def forward(self, x):
+        self.lstm.flatten_parameters()
+        o, (_, _) = self.lstm(x)
+        return self.linear(o)
+
+
+class LSTMWithoutProjection(nn.Module):
+    def __init__(self, input_dim, lstm_dim, proj_dim, num_lstm_layers):
+        super().__init__()
+        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=lstm_dim, num_layers=num_lstm_layers, batch_first=True)
+        self.linear = nn.Linear(lstm_dim, proj_dim, bias=True)
+        self.relu = nn.ReLU()
+
+    def forward(self, x):
+        _, (hidden, _) = self.lstm(x)
+        return self.relu(self.linear(hidden[-1]))
+
+
+class LSTMSpeakerEncoder(BaseEncoder):
+    def __init__(
+        self,
+        input_dim,
+        proj_dim=256,
+        lstm_dim=768,
+        num_lstm_layers=3,
+        use_lstm_with_projection=True,
+        use_torch_spec=False,
+        audio_config=None,
+    ):
+        super().__init__()
+        self.use_lstm_with_projection = use_lstm_with_projection
+        self.use_torch_spec = use_torch_spec
+        self.audio_config = audio_config
+        self.proj_dim = proj_dim
+
+        layers = []
+        # choose the LSTM layer type
+        if use_lstm_with_projection:
+            layers.append(LSTMWithProjection(input_dim, lstm_dim, proj_dim))
+            for _ in range(num_lstm_layers - 1):
+                layers.append(LSTMWithProjection(proj_dim, lstm_dim, proj_dim))
+            self.layers = nn.Sequential(*layers)
+        else:
+            self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim,
num_lstm_layers) + + self.instancenorm = nn.InstanceNorm1d(input_dim) + + if self.use_torch_spec: + self.torch_spec = self.get_torch_mel_spectrogram_class(audio_config) + else: + self.torch_spec = None + + self._init_layers() + + def _init_layers(self): + for name, param in self.layers.named_parameters(): + if "bias" in name: + nn.init.constant_(param, 0.0) + elif "weight" in name: + nn.init.xavier_normal_(param) + + def forward(self, x, l2_norm=True): + """Forward pass of the model. + + Args: + x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True` + to compute the spectrogram on-the-fly. + l2_norm (bool): Whether to L2-normalize the outputs. + + Shapes: + - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})` + """ + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + if self.use_torch_spec: + x.squeeze_(1) + x = self.torch_spec(x) + x = self.instancenorm(x).transpose(1, 2) + d = self.layers(x) + if self.use_lstm_with_projection: + d = d[:, -1] + if l2_norm: + d = torch.nn.functional.normalize(d, p=2, dim=1) + return d diff --git a/TTS/speaker_encoder/models/resnet.py b/TTS/encoder/models/resnet.py similarity index 67% rename from TTS/speaker_encoder/models/resnet.py rename to TTS/encoder/models/resnet.py index a799fc5276..c4ba953774 100644 --- a/TTS/speaker_encoder/models/resnet.py +++ b/TTS/encoder/models/resnet.py @@ -1,24 +1,8 @@ -import numpy as np import torch -import torchaudio from torch import nn # from TTS.utils.audio import TorchSTFT -from TTS.utils.io import load_fsspec - - -class PreEmphasis(nn.Module): - def __init__(self, coefficient=0.97): - super().__init__() - self.coefficient = coefficient - self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0)) - - def forward(self, x): - assert len(x.size()) == 2 - - x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect") - return torch.nn.functional.conv1d(x, self.filter).squeeze(1) - +from TTS.encoder.models.base_encoder import BaseEncoder class SELayer(nn.Module): def __init__(self, channel, reduction=8): @@ -71,7 +55,7 @@ def forward(self, x): return out -class ResNetSpeakerEncoder(nn.Module): +class ResNetSpeakerEncoder(BaseEncoder): """Implementation of the model H/ASP without batch normalization in speaker embedding. 
This model was proposed in: https://arxiv.org/abs/2009.14153 Adapted from: https://github.com/clovaai/voxceleb_trainer """ @@ -110,32 +94,7 @@ def __init__( self.instancenorm = nn.InstanceNorm1d(input_dim) if self.use_torch_spec: - self.torch_spec = torch.nn.Sequential( - PreEmphasis(audio_config["preemphasis"]), - # TorchSTFT( - # n_fft=audio_config["fft_size"], - # hop_length=audio_config["hop_length"], - # win_length=audio_config["win_length"], - # sample_rate=audio_config["sample_rate"], - # window="hamming_window", - # mel_fmin=0.0, - # mel_fmax=None, - # use_htk=True, - # do_amp_to_db=False, - # n_mels=audio_config["num_mels"], - # power=2.0, - # use_mel=True, - # mel_norm=None, - # ) - torchaudio.transforms.MelSpectrogram( - sample_rate=audio_config["sample_rate"], - n_fft=audio_config["fft_size"], - win_length=audio_config["win_length"], - hop_length=audio_config["hop_length"], - window_fn=torch.hamming_window, - n_mels=audio_config["num_mels"], - ), - ) + self.torch_spec = self.get_torch_mel_spectrogram_class(audio_config) else: self.torch_spec = None @@ -238,47 +197,3 @@ def forward(self, x, l2_norm=False): if l2_norm: x = torch.nn.functional.normalize(x, p=2, dim=1) return x - - @torch.no_grad() - def inference(self, x, l2_norm=False): - return self.forward(x, l2_norm) - - @torch.no_grad() - def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True): - """ - Generate embeddings for a batch of utterances - x: 1xTxD - """ - # map to the waveform size - if self.use_torch_spec: - num_frames = num_frames * self.audio_config["hop_length"] - - max_len = x.shape[1] - - if max_len < num_frames: - num_frames = max_len - - offsets = np.linspace(0, max_len - num_frames, num=num_eval) - - frames_batch = [] - for offset in offsets: - offset = int(offset) - end_offset = int(offset + num_frames) - frames = x[:, offset:end_offset] - frames_batch.append(frames) - - frames_batch = torch.cat(frames_batch, dim=0) - embeddings = self.inference(frames_batch, l2_norm=l2_norm) - - if return_mean: - embeddings = torch.mean(embeddings, dim=0, keepdim=True) - return embeddings - - def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False): - state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) - self.load_state_dict(state["model"]) - if use_cuda: - self.cuda() - if eval: - self.eval() - assert not self.training diff --git a/TTS/speaker_encoder/requirements.txt b/TTS/encoder/requirements.txt similarity index 100% rename from TTS/speaker_encoder/requirements.txt rename to TTS/encoder/requirements.txt diff --git a/TTS/speaker_encoder/utils/__init__.py b/TTS/encoder/utils/__init__.py similarity index 100% rename from TTS/speaker_encoder/utils/__init__.py rename to TTS/encoder/utils/__init__.py diff --git a/TTS/speaker_encoder/utils/generic_utils.py b/TTS/encoder/utils/generic_utils.py similarity index 80% rename from TTS/speaker_encoder/utils/generic_utils.py rename to TTS/encoder/utils/generic_utils.py index 4ab4e92322..17f1c3d9ec 100644 --- a/TTS/speaker_encoder/utils/generic_utils.py +++ b/TTS/encoder/utils/generic_utils.py @@ -3,60 +3,15 @@ import os import random import re -from multiprocessing import Manager import numpy as np from scipy import signal -from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder -from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder +from TTS.encoder.models.lstm import LSTMSpeakerEncoder +from TTS.encoder.models.resnet import ResNetSpeakerEncoder from 
TTS.utils.io import save_fsspec
 
 
-class Storage(object):
-    def __init__(self, maxsize, storage_batchs, num_speakers_in_batch, num_threads=8):
-        # use multiprocessing for threading safe
-        self.storage = Manager().list()
-        self.maxsize = maxsize
-        self.num_speakers_in_batch = num_speakers_in_batch
-        self.num_threads = num_threads
-        self.ignore_last_batch = False
-
-        if storage_batchs >= 3:
-            self.ignore_last_batch = True
-
-        # used for fast random sample
-        self.safe_storage_size = self.maxsize - self.num_threads
-        if self.ignore_last_batch:
-            self.safe_storage_size -= self.num_speakers_in_batch
-
-    def __len__(self):
-        return len(self.storage)
-
-    def full(self):
-        return len(self.storage) >= self.maxsize
-
-    def append(self, item):
-        # if storage is full, remove an item
-        if self.full():
-            self.storage.pop(0)
-
-        self.storage.append(item)
-
-    def get_random_sample(self):
-        # safe storage size considering all threads remove one item from storage in same time
-        storage_size = len(self.storage) - self.num_threads
-
-        if self.ignore_last_batch:
-            storage_size -= self.num_speakers_in_batch
-
-        return self.storage[random.randint(0, storage_size)]
-
-    def get_random_sample_fast(self):
-        """Call this method only when storage is full"""
-        return self.storage[random.randint(0, self.safe_storage_size)]
-
-
 class AugmentWAV(object):
     def __init__(self, ap, augmentation_config):
@@ -209,7 +164,7 @@ def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_s
     save_fsspec(state, checkpoint_path)
 
 
-def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step):
+def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step, epoch):
     if model_loss < best_loss:
         new_state_dict = model.state_dict()
         state = {
@@ -217,6 +172,7 @@ def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path
             "optimizer": optimizer.state_dict(),
             "criterion": criterion.state_dict(),
             "step": current_step,
+            "epoch": epoch,
             "loss": model_loss,
             "date": datetime.date.today().strftime("%B %d, %Y"),
         }
diff --git a/TTS/speaker_encoder/utils/io.py b/TTS/encoder/utils/io.py
similarity index 100%
rename from TTS/speaker_encoder/utils/io.py
rename to TTS/encoder/utils/io.py
diff --git a/TTS/speaker_encoder/utils/prepare_voxceleb.py b/TTS/encoder/utils/prepare_voxceleb.py
similarity index 100%
rename from TTS/speaker_encoder/utils/prepare_voxceleb.py
rename to TTS/encoder/utils/prepare_voxceleb.py
diff --git a/TTS/encoder/utils/samplers.py b/TTS/encoder/utils/samplers.py
new file mode 100644
index 0000000000..947f5da029
--- /dev/null
+++ b/TTS/encoder/utils/samplers.py
@@ -0,0 +1,102 @@
+import random
+from torch.utils.data.sampler import Sampler, SubsetRandomSampler
+
+
+class SubsetSampler(Sampler):
+    """
+    Samples elements sequentially from a given list of indices.
+
+    Args:
+        indices (list): a sequence of indices
+    """
+
+    def __init__(self, indices):
+        super().__init__(indices)
+        self.indices = indices
+
+    def __iter__(self):
+        return (self.indices[i] for i in range(len(self.indices)))
+
+    def __len__(self):
+        return len(self.indices)
+
+
+class PerfectBatchSampler(Sampler):
+    """
+    Samples a mini-batch of indices for balanced class batching.
+
+    Args:
+        dataset_items (list): dataset items to sample from.
+        classes (list): list of classes of dataset_items to sample from.
+        batch_size (int): total number of samples to be sampled in a mini-batch.
+        num_gpus (int): number of GPUs used in data parallel mode.
+ shuffle (bool): if True, samples randomly, otherwise samples sequentially. + drop_last (bool): if True, drops last incomplete batch. + """ + + def __init__(self, dataset_items, classes, batch_size, num_classes_in_batch, num_gpus=1, shuffle=True, drop_last=False, label_key="class_name"): + super().__init__(dataset_items) + assert batch_size % (num_classes_in_batch * num_gpus) == 0, ( + 'Batch size must be divisible by number of classes times the number of data parallel devices (if enabled).') + + label_indices = {} + for idx, item in enumerate(dataset_items): + label = item[label_key] + if label not in label_indices.keys(): + label_indices[label] = [idx] + else: + label_indices[label].append(idx) + + if shuffle: + self._samplers = [SubsetRandomSampler(label_indices[key]) for key in classes] + else: + self._samplers = [SubsetSampler(label_indices[key]) for key in classes] + + self._batch_size = batch_size + self._drop_last = drop_last + self._dp_devices = num_gpus + self._num_classes_in_batch = num_classes_in_batch + + def __iter__(self): + + batch = [] + if self._num_classes_in_batch != len(self._samplers): + valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch) + else: + valid_samplers_idx = None + + iters = [iter(s) for s in self._samplers] + done = False + + while True: + b = [] + for i, it in enumerate(iters): + if valid_samplers_idx is not None and i not in valid_samplers_idx: + continue + idx = next(it, None) + if idx is None: + done = True + break + b.append(idx) + if done: + break + batch += b + if len(batch) == self._batch_size: + yield batch + batch = [] + if valid_samplers_idx is not None: + valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch) + + if not self._drop_last: + if len(batch) > 0: + groups = len(batch) // self._num_classes_in_batch + if groups % self._dp_devices == 0: + yield batch + else: + batch = batch[:(groups // self._dp_devices) * self._dp_devices * self._num_classes_in_batch] + if len(batch) > 0: + yield batch + + def __len__(self): + class_batch_size = self._batch_size // self._num_classes_in_batch + return min(((len(s) + class_batch_size - 1) // class_batch_size) for s in self._samplers) diff --git a/TTS/speaker_encoder/utils/training.py b/TTS/encoder/utils/training.py similarity index 100% rename from TTS/speaker_encoder/utils/training.py rename to TTS/encoder/utils/training.py diff --git a/TTS/speaker_encoder/utils/visual.py b/TTS/encoder/utils/visual.py similarity index 69% rename from TTS/speaker_encoder/utils/visual.py rename to TTS/encoder/utils/visual.py index 4f40f68c9d..f2db2f3fa3 100644 --- a/TTS/speaker_encoder/utils/visual.py +++ b/TTS/encoder/utils/visual.py @@ -29,14 +29,18 @@ ) -def plot_embeddings(embeddings, num_utter_per_speaker): - embeddings = embeddings[: 10 * num_utter_per_speaker] +def plot_embeddings(embeddings, num_classes_in_batch): + num_utter_per_class = embeddings.shape[0] // num_classes_in_batch + + # if necessary get just the first 10 classes + if num_classes_in_batch > 10: + num_classes_in_batch = 10 + embeddings = embeddings[: num_classes_in_batch * num_utter_per_class] + model = umap.UMAP() projection = model.fit_transform(embeddings) - num_speakers = embeddings.shape[0] // num_utter_per_speaker - ground_truth = np.repeat(np.arange(num_speakers), num_utter_per_speaker) + ground_truth = np.repeat(np.arange(num_classes_in_batch), num_utter_per_class) colors = [colormap[i] for i in ground_truth] - fig, ax = plt.subplots(figsize=(16, 10)) _ = 
ax.scatter(projection[:, 0], projection[:, 1], c=colors) plt.gca().set_aspect("equal", "datalim") diff --git a/TTS/speaker_encoder/configs/config.json b/TTS/speaker_encoder/configs/config.json deleted file mode 100644 index 30d83e5198..0000000000 --- a/TTS/speaker_encoder/configs/config.json +++ /dev/null @@ -1,118 +0,0 @@ - -{ - "model_name": "lstm", - "run_name": "mueller91", - "run_description": "train speaker encoder with voxceleb1, voxceleb2 and libriSpeech ", - "audio":{ - // Audio processing parameters - "num_mels": 40, // size of the mel spec frame. - "fft_size": 400, // number of stft frequency levels. Size of the linear spectogram frame. - "sample_rate": 16000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled. - "win_length": 400, // stft window length in ms. - "hop_length": 160, // stft window hop-lengh in ms. - "frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used. - "frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used. - "preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis. - "min_level_db": -100, // normalization range - "ref_level_db": 20, // reference level db, theoretically 20db is the sound of air. - "power": 1.5, // value to sharpen wav signals after GL algorithm. - "griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation. - // Normalization parameters - "signal_norm": true, // normalize the spec values in range [0, 1] - "symmetric_norm": true, // move normalization to range [-1, 1] - "max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm] - "clip_norm": true, // clip normalized values into the range. - "mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!! - "mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!! - "do_trim_silence": true, // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true) - "trim_db": 60, // threshold for timming silence. Set this according to your dataset. - "stats_path": null // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored - }, - "reinit_layers": [], - "loss": "angleproto", // "ge2e" to use Generalized End-to-End loss and "angleproto" to use Angular Prototypical loss (new SOTA) - "grad_clip": 3.0, // upper limit for gradients for clipping. - "epochs": 1000, // total number of epochs to train. - "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate. - "lr_decay": false, // if true, Noam learning rate decaying is applied through training. - "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr" - "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging. - "steps_plot_stats": 10, // number of steps to plot embeddings. - "num_speakers_in_batch": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. 
- "num_utters_per_speaker": 10, // - "skip_speakers": false, // skip speakers with samples less than "num_utters_per_speaker" - - "voice_len": 1.6, // number of seconds for each training instance - "num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values. - "wd": 0.000001, // Weight decay weight. - "checkpoint": true, // If true, it saves checkpoints per "save_step" - "save_step": 1000, // Number of training steps expected to save traning stats and checkpoints. - "print_step": 20, // Number of steps to log traning on console. - "output_path": "../../MozillaTTSOutput/checkpoints/voxceleb_librispeech/speaker_encoder/", // DATASET-RELATED: output path for all training outputs. - "model": { - "input_dim": 40, - "proj_dim": 256, - "lstm_dim": 768, - "num_lstm_layers": 3, - "use_lstm_with_projection": true - }, - - "audio_augmentation": { - "p": 0, - //add a gaussian noise to the data in order to increase robustness - "gaussian":{ // as the insertion of Gaussian noise is quick to be calculated, we added it after loading the wav file, this way, even audios that were reused with the cache can receive this noise - "p": 1, // propability of apply this method, 0 is disable - "min_amplitude": 0.0, - "max_amplitude": 1e-5 - } - }, - "storage": { - "sample_from_storage_p": 0.66, // the probability with which we'll sample from the DataSet in-memory storage - "storage_size": 15, // the size of the in-memory storage with respect to a single batch - "additive_noise": 1e-5 // add very small gaussian noise to the data in order to increase robustness - }, - "datasets": - [ - { - "name": "vctk_slim", - "path": "../../../audio-datasets/en/VCTK-Corpus/", - "meta_file_train": null, - "meta_file_val": null - }, - { - "name": "libri_tts", - "path": "../../../audio-datasets/en/LibriTTS/train-clean-100", - "meta_file_train": null, - "meta_file_val": null - }, - { - "name": "libri_tts", - "path": "../../../audio-datasets/en/LibriTTS/train-clean-360", - "meta_file_train": null, - "meta_file_val": null - }, - { - "name": "libri_tts", - "path": "../../../audio-datasets/en/LibriTTS/train-other-500", - "meta_file_train": null, - "meta_file_val": null - }, - { - "name": "voxceleb1", - "path": "../../../audio-datasets/en/voxceleb1/", - "meta_file_train": null, - "meta_file_val": null - }, - { - "name": "voxceleb2", - "path": "../../../audio-datasets/en/voxceleb2/", - "meta_file_train": null, - "meta_file_val": null - }, - { - "name": "common_voice", - "path": "../../../audio-datasets/en/MozillaCommonVoice", - "meta_file_train": "train.tsv", - "meta_file_val": "test.tsv" - } - ] -} \ No newline at end of file diff --git a/TTS/speaker_encoder/configs/config_resnet_angleproto.json b/TTS/speaker_encoder/configs/config_resnet_angleproto.json deleted file mode 100644 index c26d29cebc..0000000000 --- a/TTS/speaker_encoder/configs/config_resnet_angleproto.json +++ /dev/null @@ -1,956 +0,0 @@ -{ - "model": "speaker_encoder", - "run_name": "speaker_encoder", - "run_description": "resnet speaker encoder trained with commonvoice all languages dev and train, Voxceleb 1 dev and Voxceleb 2 dev", - // AUDIO PARAMETERS - "audio":{ - // Audio processing parameters - "num_mels": 80, // size of the mel spec frame. - "fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame. - "sample_rate": 16000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled. - "win_length": 1024, // stft window length in ms. 
- "hop_length": 256, // stft window hop-lengh in ms. - "frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used. - "frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used. - "preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis. - "min_level_db": -100, // normalization range - "ref_level_db": 20, // reference level db, theoretically 20db is the sound of air. - "power": 1.5, // value to sharpen wav signals after GL algorithm. - "griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation. - "stft_pad_mode": "reflect", - // Normalization parameters - "signal_norm": true, // normalize the spec values in range [0, 1] - "symmetric_norm": true, // move normalization to range [-1, 1] - "max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm] - "clip_norm": true, // clip normalized values into the range. - "mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!! - "mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!! - "spec_gain": 20.0, - "do_trim_silence": false, // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true) - "trim_db": 60, // threshold for timming silence. Set this according to your dataset. - "stats_path": null // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored - }, - "reinit_layers": [], - - "loss": "angleproto", // "ge2e" to use Generalized End-to-End loss, "angleproto" to use Angular Prototypical loss and "softmaxproto" to use Softmax with Angular Prototypical loss - "grad_clip": 3.0, // upper limit for gradients for clipping. - "max_train_step": 1000000, // total number of steps to train. - "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate. - "lr_decay": false, // if true, Noam learning rate decaying is applied through training. - "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr" - "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging. - "steps_plot_stats": 100, // number of steps to plot embeddings. - - // Speakers config - "num_speakers_in_batch": 200, // Batch size for training. - "num_utters_per_speaker": 2, // - "skip_speakers": true, // skip speakers with samples less than "num_utters_per_speaker" - "voice_len": 2, // number of seconds for each training instance - - "num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values. - "wd": 0.000001, // Weight decay weight. - "checkpoint": true, // If true, it saves checkpoints per "save_step" - "save_step": 1000, // Number of training steps expected to save the best checkpoints in training. - "print_step": 50, // Number of steps to log traning on console. - "output_path": "../checkpoints/speaker_encoder/angleproto/resnet_voxceleb1_and_voxceleb2-and-common-voice-all-using-angleproto/", // DATASET-RELATED: output path for all training outputs. 
- - "audio_augmentation": { - "p": 0.5, // propability of apply this method, 0 is disable rir and additive noise augmentation - "rir":{ - "rir_path": "/workspace/store/ecasanova/ComParE/RIRS_NOISES/simulated_rirs/", - "conv_mode": "full" - }, - "additive":{ - "sounds_path": "/workspace/store/ecasanova/ComParE/musan/", - // list of each of the directories in your data augmentation, if a directory is in "sounds_path" but is not listed here it will be ignored - "speech":{ - "min_snr_in_db": 13, - "max_snr_in_db": 20, - "min_num_noises": 2, - "max_num_noises": 3 - }, - "noise":{ - "min_snr_in_db": 0, - "max_snr_in_db": 15, - "min_num_noises": 1, - "max_num_noises": 1 - }, - "music":{ - "min_snr_in_db": 5, - "max_snr_in_db": 15, - "min_num_noises": 1, - "max_num_noises": 1 - } - }, - //add a gaussian noise to the data in order to increase robustness - "gaussian":{ // as the insertion of Gaussian noise is quick to be calculated, we added it after loading the wav file, this way, even audios that were reused with the cache can receive this noise - "p": 0.5, // propability of apply this method, 0 is disable - "min_amplitude": 0.0, - "max_amplitude": 1e-5 - } - }, - "model_params": { - "model_name": "resnet", - "input_dim": 80, - "proj_dim": 512 - }, - "storage": { - "sample_from_storage_p": 0.5, // the probability with which we'll sample from the DataSet in-memory storage - "storage_size": 35 // the size of the in-memory storage with respect to a single batch - }, - "datasets": - [ - { - "name": "voxceleb2", - "path": "/workspace/scratch/ecasanova/datasets/VoxCeleb/vox2_dev_aac/", - "meta_file_train": null, - "meta_file_val": null - }, - { - "name": "voxceleb1", - "path": "/workspace/scratch/ecasanova/datasets/VoxCeleb/vox1_dev_wav/", - "meta_file_train": null, - "meta_file_val": null - }, - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fi", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fi", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-CN", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-CN", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rm-sursilv", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rm-sursilv", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lt", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lt", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ka", - "meta_file_train": "train.tsv", - "meta_file_val": 
null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ka", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sv-SE", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sv-SE", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pl", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pl", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ru", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ru", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/mn", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/mn", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/nl", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/nl", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sl", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sl", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/es", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/es", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pt", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pt", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hi", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": 
"/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hi", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ja", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ja", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ia", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ia", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/br", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/br", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/id", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/id", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/dv", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/dv", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ta", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ta", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/or", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/or", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-HK", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-HK", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/de", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/de", - 
"meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/uk", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/uk", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/en", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/en", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fa", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fa", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/vi", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/vi", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ab", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ab", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sah", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sah", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/vot", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/vot", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fr", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fr", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/tr", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/tr", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - 
"path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lg", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lg", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/mt", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/mt", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rw", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rw", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hu", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hu", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rm-vallader", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rm-vallader", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/el", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/el", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/tt", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/tt", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-TW", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-TW", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/et", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/et", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": 
"/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fy-NL", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fy-NL", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cs", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cs", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/as", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/as", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ro", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ro", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/eo", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/eo", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pa-IN", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pa-IN", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/th", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/th", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/it", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/it", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ga-IE", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ga-IE", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": 
"/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cnh", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cnh", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ky", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ky", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ar", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ar", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/eu", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/eu", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ca", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ca", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/kab", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/kab", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cy", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cy", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cv", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cv", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hsb", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hsb", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lv", - 
"meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lv", - "meta_file_train": "dev.tsv", - "meta_file_val": null - } - - ] -} \ No newline at end of file diff --git a/TTS/speaker_encoder/configs/config_resnet_softmax_angleproto.json b/TTS/speaker_encoder/configs/config_resnet_softmax_angleproto.json deleted file mode 100644 index ccbd751ad9..0000000000 --- a/TTS/speaker_encoder/configs/config_resnet_softmax_angleproto.json +++ /dev/null @@ -1,957 +0,0 @@ - -{ - "model": "speaker_encoder", - "run_name": "speaker_encoder", - "run_description": "resnet speaker encoder trained with commonvoice all languages dev and train, Voxceleb 1 dev and Voxceleb 2 dev", - // AUDIO PARAMETERS - "audio":{ - // Audio processing parameters - "num_mels": 80, // size of the mel spec frame. - "fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame. - "sample_rate": 16000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled. - "win_length": 1024, // stft window length in ms. - "hop_length": 256, // stft window hop-lengh in ms. - "frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used. - "frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used. - "preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis. - "min_level_db": -100, // normalization range - "ref_level_db": 20, // reference level db, theoretically 20db is the sound of air. - "power": 1.5, // value to sharpen wav signals after GL algorithm. - "griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation. - "stft_pad_mode": "reflect", - // Normalization parameters - "signal_norm": true, // normalize the spec values in range [0, 1] - "symmetric_norm": true, // move normalization to range [-1, 1] - "max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm] - "clip_norm": true, // clip normalized values into the range. - "mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!! - "mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!! - "spec_gain": 20.0, - "do_trim_silence": false, // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true) - "trim_db": 60, // threshold for timming silence. Set this according to your dataset. - "stats_path": null // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored - }, - "reinit_layers": [], - - "loss": "softmaxproto", // "ge2e" to use Generalized End-to-End loss, "angleproto" to use Angular Prototypical loss and "softmaxproto" to use Softmax with Angular Prototypical loss - "grad_clip": 3.0, // upper limit for gradients for clipping. - "max_train_step": 1000000, // total number of steps to train. - "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate. - "lr_decay": false, // if true, Noam learning rate decaying is applied through training. - "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr" - "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. 
Might be memory consuming, but good for debugging. - "steps_plot_stats": 100, // number of steps to plot embeddings. - - // Speakers config - "num_speakers_in_batch": 200, // Batch size for training. - "num_utters_per_speaker": 2, // number of utterances sampled per speaker in each batch - "skip_speakers": true, // skip speakers with fewer samples than "num_utters_per_speaker" - "voice_len": 2, // number of seconds for each training instance - - "num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values. - "wd": 0.000001, // Weight decay weight. - "checkpoint": true, // If true, it saves checkpoints per "save_step" - "save_step": 1000, // Number of training steps expected to save the best checkpoints in training. - "print_step": 50, // Number of steps to log training on console. - "output_path": "../../../checkpoints/speaker_encoder/resnet_voxceleb1_and_voxceleb2-and-common-voice-all/", // DATASET-RELATED: output path for all training outputs. - - "audio_augmentation": { - "p": 0.5, // probability of applying augmentation; 0 disables RIR and additive noise augmentation - "rir":{ - "rir_path": "/workspace/store/ecasanova/ComParE/RIRS_NOISES/simulated_rirs/", - "conv_mode": "full" - }, - "additive":{ - "sounds_path": "/workspace/store/ecasanova/ComParE/musan/", - // list the noise-type subdirectories of "sounds_path" to use; a directory that is in "sounds_path" but not listed here will be ignored - "speech":{ - "min_snr_in_db": 13, - "max_snr_in_db": 20, - "min_num_noises": 2, - "max_num_noises": 3 - }, - "noise":{ - "min_snr_in_db": 0, - "max_snr_in_db": 15, - "min_num_noises": 1, - "max_num_noises": 1 - }, - "music":{ - "min_snr_in_db": 5, - "max_snr_in_db": 15, - "min_num_noises": 1, - "max_num_noises": 1 - } - }, - //add a gaussian noise to the data in order to increase robustness - "gaussian":{ // since Gaussian noise is cheap to compute, it is applied after the wav file is loaded, so even audio reused from the cache can receive this noise - "p": 0.5, // probability of applying this method; 0 disables it - "min_amplitude": 0.0, - "max_amplitude": 1e-5 - } - }, - "model_params": { - "model_name": "resnet", - "input_dim": 80, - "proj_dim": 512 - }, - "storage": { - "sample_from_storage_p": 0.66, // the probability with which we'll sample from the DataSet in-memory storage - "storage_size": 35 // the size of the in-memory storage with respect to a single batch - }, - "datasets": - [ - { - "name": "voxceleb2", - "path": "/workspace/scratch/ecasanova/datasets/VoxCeleb/vox2_dev_aac/", - "meta_file_train": null, - "meta_file_val": null - }, - { - "name": "voxceleb1", - "path": "/workspace/scratch/ecasanova/datasets/VoxCeleb/vox1_dev_wav/", - "meta_file_train": null, - "meta_file_val": null - }, - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fi", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fi", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-CN", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-CN", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name":
"common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rm-sursilv", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rm-sursilv", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lt", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lt", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ka", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ka", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sv-SE", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sv-SE", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pl", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pl", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ru", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ru", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/mn", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/mn", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/nl", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/nl", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sl", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sl", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": 
"/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/es", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/es", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pt", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pt", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hi", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hi", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ja", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ja", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ia", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ia", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/br", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/br", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/id", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/id", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/dv", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/dv", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ta", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ta", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/or", - 
"meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/or", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-HK", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-HK", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/de", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/de", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/uk", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/uk", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/en", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/en", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fa", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fa", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/vi", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/vi", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ab", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ab", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sah", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/sah", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/vot", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": 
"common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/vot", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fr", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fr", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/tr", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/tr", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lg", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lg", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/mt", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/mt", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rw", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rw", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hu", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hu", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rm-vallader", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/rm-vallader", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/el", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/el", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/tt", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": 
"/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/tt", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-TW", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/zh-TW", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/et", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/et", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fy-NL", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/fy-NL", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cs", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cs", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/as", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/as", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ro", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ro", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/eo", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/eo", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pa-IN", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/pa-IN", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/th", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": 
"/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/th", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/it", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/it", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ga-IE", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ga-IE", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cnh", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cnh", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ky", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ky", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ar", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ar", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/eu", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/eu", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ca", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/ca", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/kab", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/kab", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cy", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cy", 
- "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cv", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/cv", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hsb", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/hsb", - "meta_file_train": "dev.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lv", - "meta_file_train": "train.tsv", - "meta_file_val": null - }, - - { - "name": "common_voice", - "path": "/workspace/scratch/ecasanova/datasets/common-voice/cv-corpus-6.1-2020-12-11_16khz/lv", - "meta_file_train": "dev.tsv", - "meta_file_val": null - } - - ] -} \ No newline at end of file diff --git a/TTS/speaker_encoder/dataset.py b/TTS/speaker_encoder/dataset.py deleted file mode 100644 index 07fa924660..0000000000 --- a/TTS/speaker_encoder/dataset.py +++ /dev/null @@ -1,243 +0,0 @@ -import random - -import numpy as np -import torch -from torch.utils.data import Dataset - -from TTS.speaker_encoder.utils.generic_utils import AugmentWAV, Storage - - -class SpeakerEncoderDataset(Dataset): - def __init__( - self, - ap, - meta_data, - voice_len=1.6, - num_speakers_in_batch=64, - storage_size=1, - sample_from_storage_p=0.5, - num_utter_per_speaker=10, - skip_speakers=False, - verbose=False, - augmentation_config=None, - use_torch_spec=None, - ): - """ - Args: - ap (TTS.tts.utils.AudioProcessor): audio processor object. - meta_data (list): list of dataset instances. - seq_len (int): voice segment length in seconds. - verbose (bool): print diagnostic information. 
- """ - super().__init__() - self.items = meta_data - self.sample_rate = ap.sample_rate - self.seq_len = int(voice_len * self.sample_rate) - self.num_speakers_in_batch = num_speakers_in_batch - self.num_utter_per_speaker = num_utter_per_speaker - self.skip_speakers = skip_speakers - self.ap = ap - self.verbose = verbose - self.use_torch_spec = use_torch_spec - self.__parse_items() - storage_max_size = storage_size * num_speakers_in_batch - self.storage = Storage( - maxsize=storage_max_size, storage_batchs=storage_size, num_speakers_in_batch=num_speakers_in_batch - ) - self.sample_from_storage_p = float(sample_from_storage_p) - - speakers_aux = list(self.speakers) - speakers_aux.sort() - self.speakerid_to_classid = {key: i for i, key in enumerate(speakers_aux)} - - # Augmentation - self.augmentator = None - self.gaussian_augmentation_config = None - if augmentation_config: - self.data_augmentation_p = augmentation_config["p"] - if self.data_augmentation_p and ("additive" in augmentation_config or "rir" in augmentation_config): - self.augmentator = AugmentWAV(ap, augmentation_config) - - if "gaussian" in augmentation_config.keys(): - self.gaussian_augmentation_config = augmentation_config["gaussian"] - - if self.verbose: - print("\n > DataLoader initialization") - print(f" | > Speakers per Batch: {num_speakers_in_batch}") - print(f" | > Storage Size: {storage_max_size} instances, each with {num_utter_per_speaker} utters") - print(f" | > Sample_from_storage_p : {self.sample_from_storage_p}") - print(f" | > Number of instances : {len(self.items)}") - print(f" | > Sequence length: {self.seq_len}") - print(f" | > Num speakers: {len(self.speakers)}") - - def load_wav(self, filename): - audio = self.ap.load_wav(filename, sr=self.ap.sample_rate) - return audio - - def __parse_items(self): - self.speaker_to_utters = {} - for i in self.items: - path_ = i["audio_file"] - speaker_ = i["speaker_name"] - if speaker_ in self.speaker_to_utters.keys(): - self.speaker_to_utters[speaker_].append(path_) - else: - self.speaker_to_utters[speaker_] = [ - path_, - ] - - if self.skip_speakers: - self.speaker_to_utters = { - k: v for (k, v) in self.speaker_to_utters.items() if len(v) >= self.num_utter_per_speaker - } - - self.speakers = [k for (k, v) in self.speaker_to_utters.items()] - - def __len__(self): - return int(1e10) - - def get_num_speakers(self): - return len(self.speakers) - - def __sample_speaker(self, ignore_speakers=None): - speaker = random.sample(self.speakers, 1)[0] - # if list of speakers_id is provide make sure that it's will be ignored - if ignore_speakers and self.speakerid_to_classid[speaker] in ignore_speakers: - while True: - speaker = random.sample(self.speakers, 1)[0] - if self.speakerid_to_classid[speaker] not in ignore_speakers: - break - - if self.num_utter_per_speaker > len(self.speaker_to_utters[speaker]): - utters = random.choices(self.speaker_to_utters[speaker], k=self.num_utter_per_speaker) - else: - utters = random.sample(self.speaker_to_utters[speaker], self.num_utter_per_speaker) - return speaker, utters - - def __sample_speaker_utterances(self, speaker): - """ - Sample all M utterances for the given speaker. 
- """ - wavs = [] - labels = [] - for _ in range(self.num_utter_per_speaker): - # TODO:dummy but works - while True: - # remove speakers that have num_utter less than 2 - if len(self.speaker_to_utters[speaker]) > 1: - utter = random.sample(self.speaker_to_utters[speaker], 1)[0] - else: - if speaker in self.speakers: - self.speakers.remove(speaker) - - speaker, _ = self.__sample_speaker() - continue - - wav = self.load_wav(utter) - if wav.shape[0] - self.seq_len > 0: - break - - if utter in self.speaker_to_utters[speaker]: - self.speaker_to_utters[speaker].remove(utter) - - if self.augmentator is not None and self.data_augmentation_p: - if random.random() < self.data_augmentation_p: - wav = self.augmentator.apply_one(wav) - - wavs.append(wav) - labels.append(self.speakerid_to_classid[speaker]) - return wavs, labels - - def __getitem__(self, idx): - speaker, _ = self.__sample_speaker() - speaker_id = self.speakerid_to_classid[speaker] - return speaker, speaker_id - - def __load_from_disk_and_storage(self, speaker): - # don't sample from storage, but from HDD - wavs_, labels_ = self.__sample_speaker_utterances(speaker) - # put the newly loaded item into storage - self.storage.append((wavs_, labels_)) - return wavs_, labels_ - - def collate_fn(self, batch): - # get the batch speaker_ids - batch = np.array(batch) - speakers_id_in_batch = set(batch[:, 1].astype(np.int32)) - - labels = [] - feats = [] - speakers = set() - - for speaker, speaker_id in batch: - speaker_id = int(speaker_id) - - # ensure that an speaker appears only once in the batch - if speaker_id in speakers: - - # remove current speaker - if speaker_id in speakers_id_in_batch: - speakers_id_in_batch.remove(speaker_id) - - speaker, _ = self.__sample_speaker(ignore_speakers=speakers_id_in_batch) - speaker_id = self.speakerid_to_classid[speaker] - speakers_id_in_batch.add(speaker_id) - - if random.random() < self.sample_from_storage_p and self.storage.full(): - # sample from storage (if full) - wavs_, labels_ = self.storage.get_random_sample_fast() - - # force choose the current speaker or other not in batch - # It's necessary for ideal training with AngleProto and GE2E losses - if labels_[0] in speakers_id_in_batch and labels_[0] != speaker_id: - attempts = 0 - while True: - wavs_, labels_ = self.storage.get_random_sample_fast() - if labels_[0] == speaker_id or labels_[0] not in speakers_id_in_batch: - break - - attempts += 1 - # Try 5 times after that load from disk - if attempts >= 5: - wavs_, labels_ = self.__load_from_disk_and_storage(speaker) - break - else: - # don't sample from storage, but from HDD - wavs_, labels_ = self.__load_from_disk_and_storage(speaker) - - # append speaker for control - speakers.add(labels_[0]) - - # remove current speaker and append other - if speaker_id in speakers_id_in_batch: - speakers_id_in_batch.remove(speaker_id) - - speakers_id_in_batch.add(labels_[0]) - - # get a random subset of each of the wavs and extract mel spectrograms. 
- feats_ = [] - for wav in wavs_: - offset = random.randint(0, wav.shape[0] - self.seq_len) - wav = wav[offset : offset + self.seq_len] - # add random gaussian noise - if self.gaussian_augmentation_config and self.gaussian_augmentation_config["p"]: - if random.random() < self.gaussian_augmentation_config["p"]: - wav += np.random.normal( - self.gaussian_augmentation_config["min_amplitude"], - self.gaussian_augmentation_config["max_amplitude"], - size=len(wav), - ) - - if not self.use_torch_spec: - mel = self.ap.melspectrogram(wav) - feats_.append(torch.FloatTensor(mel)) - else: - feats_.append(torch.FloatTensor(wav)) - - labels.append(torch.LongTensor(labels_)) - feats.extend(feats_) - - feats = torch.stack(feats) - labels = torch.stack(labels) - - return feats, labels diff --git a/TTS/speaker_encoder/models/lstm.py b/TTS/speaker_encoder/models/lstm.py deleted file mode 100644 index ec394cdbf8..0000000000 --- a/TTS/speaker_encoder/models/lstm.py +++ /dev/null @@ -1,189 +0,0 @@ -import numpy as np -import torch -import torchaudio -from torch import nn - -from TTS.speaker_encoder.models.resnet import PreEmphasis -from TTS.utils.io import load_fsspec - - -class LSTMWithProjection(nn.Module): - def __init__(self, input_size, hidden_size, proj_size): - super().__init__() - self.input_size = input_size - self.hidden_size = hidden_size - self.proj_size = proj_size - self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True) - self.linear = nn.Linear(hidden_size, proj_size, bias=False) - - def forward(self, x): - self.lstm.flatten_parameters() - o, (_, _) = self.lstm(x) - return self.linear(o) - - -class LSTMWithoutProjection(nn.Module): - def __init__(self, input_dim, lstm_dim, proj_dim, num_lstm_layers): - super().__init__() - self.lstm = nn.LSTM(input_size=input_dim, hidden_size=lstm_dim, num_layers=num_lstm_layers, batch_first=True) - self.linear = nn.Linear(lstm_dim, proj_dim, bias=True) - self.relu = nn.ReLU() - - def forward(self, x): - _, (hidden, _) = self.lstm(x) - return self.relu(self.linear(hidden[-1])) - - -class LSTMSpeakerEncoder(nn.Module): - def __init__( - self, - input_dim, - proj_dim=256, - lstm_dim=768, - num_lstm_layers=3, - use_lstm_with_projection=True, - use_torch_spec=False, - audio_config=None, - ): - super().__init__() - self.use_lstm_with_projection = use_lstm_with_projection - self.use_torch_spec = use_torch_spec - self.audio_config = audio_config - self.proj_dim = proj_dim - - layers = [] - # choose the LSTM layer type - if use_lstm_with_projection: - layers.append(LSTMWithProjection(input_dim, lstm_dim, proj_dim)) - for _ in range(num_lstm_layers - 1): - layers.append(LSTMWithProjection(proj_dim, lstm_dim, proj_dim)) - self.layers = nn.Sequential(*layers) - else: - self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers) - - self.instancenorm = nn.InstanceNorm1d(input_dim) - - if self.use_torch_spec: - self.torch_spec = torch.nn.Sequential( - PreEmphasis(audio_config["preemphasis"]), - # TorchSTFT( - # n_fft=audio_config["fft_size"], - # hop_length=audio_config["hop_length"], - # win_length=audio_config["win_length"], - # sample_rate=audio_config["sample_rate"], - # window="hamming_window", - # mel_fmin=0.0, - # mel_fmax=None, - # use_htk=True, - # do_amp_to_db=False, - # n_mels=audio_config["num_mels"], - # power=2.0, - # use_mel=True, - # mel_norm=None, - # ) - torchaudio.transforms.MelSpectrogram( - sample_rate=audio_config["sample_rate"], - n_fft=audio_config["fft_size"], - win_length=audio_config["win_length"], -
hop_length=audio_config["hop_length"], - window_fn=torch.hamming_window, - n_mels=audio_config["num_mels"], - ), - ) - else: - self.torch_spec = None - - self._init_layers() - - def _init_layers(self): - for name, param in self.layers.named_parameters(): - if "bias" in name: - nn.init.constant_(param, 0.0) - elif "weight" in name: - nn.init.xavier_normal_(param) - - def forward(self, x, l2_norm=True): - """Forward pass of the model. - - Args: - x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True` - to compute the spectrogram on-the-fly. - l2_norm (bool): Whether to L2-normalize the outputs. - - Shapes: - - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})` - """ - with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): - if self.use_torch_spec: - x.squeeze_(1) - x = self.torch_spec(x) - x = self.instancenorm(x).transpose(1, 2) - d = self.layers(x) - if self.use_lstm_with_projection: - d = d[:, -1] - if l2_norm: - d = torch.nn.functional.normalize(d, p=2, dim=1) - return d - - @torch.no_grad() - def inference(self, x, l2_norm=True): - d = self.forward(x, l2_norm=l2_norm) - return d - - def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True): - """ - Generate embeddings for a batch of utterances - x: 1xTxD - """ - max_len = x.shape[1] - - if max_len < num_frames: - num_frames = max_len - - offsets = np.linspace(0, max_len - num_frames, num=num_eval) - - frames_batch = [] - for offset in offsets: - offset = int(offset) - end_offset = int(offset + num_frames) - frames = x[:, offset:end_offset] - frames_batch.append(frames) - - frames_batch = torch.cat(frames_batch, dim=0) - embeddings = self.inference(frames_batch) - - if return_mean: - embeddings = torch.mean(embeddings, dim=0, keepdim=True) - - return embeddings - - def batch_compute_embedding(self, x, seq_lens, num_frames=160, overlap=0.5): - """ - Generate embeddings for a batch of utterances - x: BxTxD - """ - num_overlap = num_frames * overlap - max_len = x.shape[1] - embed = None - num_iters = seq_lens / (num_frames - num_overlap) - cur_iter = 0 - for offset in range(0, max_len, num_frames - num_overlap): - cur_iter += 1 - end_offset = min(x.shape[1], offset + num_frames) - frames = x[:, offset:end_offset] - if embed is None: - embed = self.inference(frames) - else: - embed[cur_iter <= num_iters, :] += self.inference(frames[cur_iter <= num_iters, :, :]) - return embed / num_iters - - # pylint: disable=unused-argument, redefined-builtin - def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False): - state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) - self.load_state_dict(state["model"]) - if use_cuda: - self.cuda() - if eval: - self.eval() - assert not self.training diff --git a/TTS/speaker_encoder/umap.png b/TTS/speaker_encoder/umap.png deleted file mode 100644 index ca8aefeac8..0000000000 Binary files a/TTS/speaker_encoder/umap.png and /dev/null differ diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py index a9b56ed497..dcc862e85a 100644 --- a/TTS/tts/configs/shared_configs.py +++ b/TTS/tts/configs/shared_configs.py @@ -264,7 +264,7 @@ class BaseTTSConfig(BaseTrainingConfig): # dataset datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()]) # optimizer - optimizer: str = None + optimizer: str = "radam" optimizer_params: dict = None # scheduler lr_scheduler: str = "" diff --git a/TTS/tts/datasets/formatters.py 
b/TTS/tts/datasets/formatters.py index 573a5debba..3e963d0c45 100644 --- a/TTS/tts/datasets/formatters.py +++ b/TTS/tts/datasets/formatters.py @@ -441,6 +441,26 @@ def _voxcel_x(root_path, meta_file, voxcel_idx): return [x.strip().split("|") for x in f.readlines()] +def emotion(root_path, meta_file, ignored_speakers=None): + """Generic emotion dataset""" + txt_file = os.path.join(root_path, meta_file) + items = [] + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + if line.startswith("file_path"): + continue + cols = line.split(",") + wav_file = os.path.join(root_path, cols[0]) + speaker_id = cols[1] + emotion_id = cols[2].replace("\n", "") + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_id in ignored_speakers: + continue + items.append({"audio_file": wav_file, "speaker_name": speaker_id, "emotion_name": emotion_id}) + return items + + def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]: # pylint: disable=unused-argument """Normalizes the Baker meta data file to TTS format diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index c15a3abf4f..1a5da94a24 100644 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -9,7 +9,7 @@ from coqpit import Coqpit from TTS.config import get_from_config_or_model_args_with_default, load_config -from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model +from TTS.encoder.utils.generic_utils import setup_speaker_encoder_model from TTS.utils.audio import AudioProcessor @@ -269,7 +269,7 @@ def init_speaker_encoder(self, model_path: str, config_path: str) -> None: """ self.speaker_encoder_config = load_config(config_path) self.speaker_encoder = setup_speaker_encoder_model(self.speaker_encoder_config) - self.speaker_encoder.load_checkpoint(config_path, model_path, eval=True, use_cuda=self.use_cuda) + self.speaker_encoder_criterion = self.speaker_encoder.load_checkpoint(self.speaker_encoder_config, model_path, eval=True, use_cuda=self.use_cuda) self.speaker_encoder_ap = AudioProcessor(**self.speaker_encoder_config.audio) def compute_d_vector_from_clip(self, wav_file: Union[str, List[str]]) -> list: diff --git a/tests/aux_tests/test_speaker_encoder.py b/tests/aux_tests/test_speaker_encoder.py index 97b3b92f90..f2875cc188 100644 --- a/tests/aux_tests/test_speaker_encoder.py +++ b/tests/aux_tests/test_speaker_encoder.py @@ -3,9 +3,9 @@ import torch as T from tests import get_tests_input_path -from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss -from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder -from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder +from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss +from TTS.encoder.models.lstm import LSTMSpeakerEncoder +from TTS.encoder.models.resnet import ResNetSpeakerEncoder file_path = get_tests_input_path() diff --git a/tests/aux_tests/test_speaker_encoder_train.py b/tests/aux_tests/test_speaker_encoder_train.py index 7901fe5a64..d9d6d71e77 100644 --- a/tests/aux_tests/test_speaker_encoder_train.py +++ b/tests/aux_tests/test_speaker_encoder_train.py @@ -4,14 +4,14 @@ from tests import get_device_id, get_tests_output_path, run_cli from TTS.config.shared_configs import BaseAudioConfig -from TTS.speaker_encoder.speaker_encoder_config import SpeakerEncoderConfig +from TTS.encoder.configs.speaker_encoder_config import SpeakerEncoderConfig def run_test_train(): command = ( f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python 
TTS/bin/train_encoder.py --config_path {config_path} " f"--coqpit.output_path {output_path} " - "--coqpit.datasets.0.name ljspeech " + "--coqpit.datasets.0.name ljspeech_test " "--coqpit.datasets.0.meta_file_train metadata.csv " "--coqpit.datasets.0.meta_file_val metadata.csv " "--coqpit.datasets.0.path tests/data/ljspeech " @@ -24,17 +24,21 @@ def run_test_train(): config = SpeakerEncoderConfig( batch_size=4, - num_speakers_in_batch=1, - num_utters_per_speaker=10, - num_loader_workers=0, - max_train_step=2, + num_classes_in_batch=4, + num_utter_per_class=2, + eval_num_classes_in_batch=4, + eval_num_utter_per_class=2, + num_loader_workers=1, + epochs=1, print_step=1, - save_step=1, + save_step=2, print_eval=True, + run_eval=True, audio=BaseAudioConfig(num_mels=80), ) config.audio.do_trim_silence = True config.audio.trim_db = 60 +config.loss = "ge2e" config.save_json(config_path) print(config) @@ -69,14 +73,14 @@ def run_test_train(): shutil.rmtree(continue_path) # test model with ge2e loss function -config.loss = "ge2e" -config.save_json(config_path) -run_test_train() +# config.loss = "ge2e" +# config.save_json(config_path) +# run_test_train() # test model with angleproto loss function -config.loss = "angleproto" -config.save_json(config_path) -run_test_train() +# config.loss = "angleproto" +# config.save_json(config_path) +# run_test_train() # test model with softmaxproto loss function config.loss = "softmaxproto" diff --git a/tests/aux_tests/test_speaker_manager.py b/tests/aux_tests/test_speaker_manager.py index fff49b131f..5fafb56a83 100644 --- a/tests/aux_tests/test_speaker_manager.py +++ b/tests/aux_tests/test_speaker_manager.py @@ -6,8 +6,8 @@ from tests import get_tests_input_path from TTS.config import load_config -from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model -from TTS.speaker_encoder.utils.io import save_checkpoint +from TTS.encoder.utils.generic_utils import setup_speaker_encoder_model +from TTS.encoder.utils.io import save_checkpoint from TTS.tts.utils.speakers import SpeakerManager from TTS.utils.audio import AudioProcessor diff --git a/tests/data_tests/test_samplers.py b/tests/data_tests/test_samplers.py index 12152fb812..c888c629cd 100644 --- a/tests/data_tests/test_samplers.py +++ b/tests/data_tests/test_samplers.py @@ -8,6 +8,7 @@ from TTS.tts.datasets import load_tts_samples from TTS.tts.utils.languages import get_language_balancer_weights from TTS.tts.utils.speakers import get_speaker_balancer_weights +from TTS.encoder.utils.samplers import PerfectBatchSampler # Fixing random state to avoid random fails torch.manual_seed(0) @@ -82,3 +83,51 @@ def test_speaker_weighted_random_sampler(self): # pylint: disable=no-self-use spk2 += 1 assert is_balanced(spk1, spk2), "Speaker Weighted sampler is supposed to be balanced" + + def test_perfect_sampler(self): # pylint: disable=no-self-use + classes = set() + for item in train_samples: + classes.add(item["speaker_name"]) + + sampler = PerfectBatchSampler( + train_samples, + classes, + batch_size=2 * 3, # total batch size + num_classes_in_batch=2, + label_key="speaker_name", + shuffle=False, + drop_last=True) + batches = functools.reduce(lambda a, b: a + b, [list(sampler) for _ in range(100)]) + for batch in batches: + spk1, spk2 = 0, 0 + # count per-speaker occurrences in each batch + for index in batch: + if train_samples[index]["speaker_name"] == "ljspeech-0": + spk1 += 1 + else: + spk2 += 1 + assert spk1 == spk2, "PerfectBatchSampler is supposed to be perfectly balanced" + + def test_perfect_sampler_shuffle(self): # pylint: disable=no-self-use + classes = set() + for item in train_samples: + classes.add(item["speaker_name"]) + + sampler = PerfectBatchSampler( + train_samples, + classes, + batch_size=2 * 3, # total batch size + num_classes_in_batch=2, + label_key="speaker_name", + shuffle=True, + drop_last=False) + batches = functools.reduce(lambda a, b: a + b, [list(sampler) for _ in range(100)]) + for batch in batches: + spk1, spk2 = 0, 0 + # count per-speaker occurrences in each batch + for index in batch: + if train_samples[index]["speaker_name"] == "ljspeech-0": + spk1 += 1 + else: + spk2 += 1 + assert spk1 == spk2, "PerfectBatchSampler is supposed to be perfectly balanced" diff --git a/tests/inputs/test_glow_tts.json b/tests/inputs/test_glow_tts.json index 6dd86057ee..64b0982879 100644 --- a/tests/inputs/test_glow_tts.json +++ b/tests/inputs/test_glow_tts.json @@ -66,8 +66,8 @@ "use_mas": false, // use Monotonic Alignment Search if true. Otherwise use pre-computed attention alignments. // TRAINING - "batch_size": 2, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. - "eval_batch_size":1, + "batch_size": 8, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. + "eval_batch_size": 8, "r": 1, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled. "loss_masking": true, // enable / disable loss masking against the sequence padding. "data_dep_init_iter": 1, diff --git a/tests/inputs/test_speaker_encoder_config.json b/tests/inputs/test_speaker_encoder_config.json index 09a2f6a4aa..bfcc17ab0e 100644 --- a/tests/inputs/test_speaker_encoder_config.json +++ b/tests/inputs/test_speaker_encoder_config.json @@ -36,8 +36,8 @@ "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr" "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging. "steps_plot_stats": 10, // number of steps to plot embeddings. - "num_speakers_in_batch": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. - "num_utters_per_speaker": 10, // + "num_classes_in_batch": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. + "num_utter_per_class": 10, // "num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values. "wd": 0.000001, // Weight decay weight. "checkpoint": true, // If true, it saves checkpoints per "save_step" diff --git a/tests/inputs/test_tacotron2_config.json b/tests/inputs/test_tacotron2_config.json index 6c82891d73..69b235609c 100644 --- a/tests/inputs/test_tacotron2_config.json +++ b/tests/inputs/test_tacotron2_config.json @@ -61,8 +61,8 @@ "reinit_layers": [], // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers. // TRAINING - "batch_size": 1, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. - "eval_batch_size":1, + "batch_size": 8, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. + "eval_batch_size": 8, "r": 7, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
"gradual_training": [[0, 7, 4], [1, 5, 2]], //set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled. For Tacotron, you might need to reduce the 'batch_size' as you proceeed. "loss_masking": true, // enable / disable loss masking against the sequence padding. diff --git a/tests/inputs/test_tacotron_config.json b/tests/inputs/test_tacotron_config.json index b60ed35e3c..90e07fc7c9 100644 --- a/tests/inputs/test_tacotron_config.json +++ b/tests/inputs/test_tacotron_config.json @@ -61,8 +61,8 @@ "reinit_layers": [], // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers. // TRAINING - "batch_size": 1, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. - "eval_batch_size":1, + "batch_size": 8, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. + "eval_batch_size": 8, "r": 7, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled. "gradual_training": [[0, 7, 4], [1, 5, 2]], //set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled. For Tacotron, you might need to reduce the 'batch_size' as you proceeed. "loss_masking": true, // enable / disable loss masking against the sequence padding. diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py index 384234e510..81d2ebbd62 100644 --- a/tests/tts_tests/test_vits.py +++ b/tests/tts_tests/test_vits.py @@ -7,7 +7,7 @@ from tests import assertHasAttr, assertHasNotAttr, get_tests_data_path, get_tests_input_path, get_tests_output_path from TTS.config import load_config -from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model +from TTS.encoder.utils.generic_utils import setup_speaker_encoder_model from TTS.tts.configs.vits_config import VitsConfig from TTS.tts.models.vits import Vits, VitsArgs, amp_to_db, db_to_amp, load_audio, spec_to_mel, wav_to_mel, wav_to_spec from TTS.tts.utils.speakers import SpeakerManager