Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

REBASED: Transform Speaker Encoder in a Generic Encoder and Implement Emotion Encoder training support #1349

Merged
merged 21 commits into from
Mar 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions TTS/bin/compute_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,33 +42,35 @@
meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=args.eval)
wav_files = meta_data_train + meta_data_eval

speaker_manager = SpeakerManager(
encoder_manager = SpeakerManager(
encoder_model_path=args.model_path,
encoder_config_path=args.config_path,
d_vectors_file_path=args.old_file,
use_cuda=args.use_cuda,
)

class_name_key = encoder_manager.speaker_encoder_config.class_name_key

# compute speaker embeddings
speaker_mapping = {}
for idx, wav_file in enumerate(tqdm(wav_files)):
if isinstance(wav_file, list):
speaker_name = wav_file[2]
wav_file = wav_file[1]
if isinstance(wav_file, dict):
class_name = wav_file[class_name_key]
wav_file = wav_file["audio_file"]
else:
speaker_name = None
class_name = None

wav_file_name = os.path.basename(wav_file)
if args.old_file is not None and wav_file_name in speaker_manager.clip_ids:
if args.old_file is not None and wav_file_name in encoder_manager.clip_ids:
# get the embedding from the old file
embedd = speaker_manager.get_d_vector_by_clip(wav_file_name)
embedd = encoder_manager.get_d_vector_by_clip(wav_file_name)
else:
# extract the embedding
embedd = speaker_manager.compute_d_vector_from_clip(wav_file)
embedd = encoder_manager.compute_d_vector_from_clip(wav_file)

# create speaker_mapping if target dataset is defined
speaker_mapping[wav_file_name] = {}
speaker_mapping[wav_file_name]["name"] = speaker_name
speaker_mapping[wav_file_name]["name"] = class_name
speaker_mapping[wav_file_name]["embedding"] = embedd

if speaker_mapping:
Expand All @@ -81,5 +83,5 @@
os.makedirs(os.path.dirname(mapping_file_path), exist_ok=True)

# pylint: disable=W0212
speaker_manager._save_json(mapping_file_path, speaker_mapping)
encoder_manager._save_json(mapping_file_path, speaker_mapping)
print("Speaker embeddings saved at:", mapping_file_path)
88 changes: 88 additions & 0 deletions TTS/bin/eval_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import argparse
import torch
from argparse import RawTextHelpFormatter

from tqdm import tqdm

from TTS.config import load_config
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.speakers import SpeakerManager

def compute_encoder_accuracy(dataset_items, encoder_manager):

class_name_key = encoder_manager.speaker_encoder_config.class_name_key
map_classid_to_classname = getattr(encoder_manager.speaker_encoder_config, 'map_classid_to_classname', None)

class_acc_dict = {}

# compute embeddings for all wav_files
for item in tqdm(dataset_items):
class_name = item[class_name_key]
wav_file = item["audio_file"]

# extract the embedding
embedd = encoder_manager.compute_d_vector_from_clip(wav_file)
if encoder_manager.speaker_encoder_criterion is not None and map_classid_to_classname is not None:
embedding = torch.FloatTensor(embedd).unsqueeze(0)
if encoder_manager.use_cuda:
embedding = embedding.cuda()

class_id = encoder_manager.speaker_encoder_criterion.softmax.inference(embedding).item()
predicted_label = map_classid_to_classname[str(class_id)]
else:
predicted_label = None

if class_name is not None and predicted_label is not None:
is_equal = int(class_name == predicted_label)
if class_name not in class_acc_dict:
class_acc_dict[class_name] = [is_equal]
else:
class_acc_dict[class_name].append(is_equal)
else:
raise RuntimeError("Error: class_name or/and predicted_label are None")

acc_avg = 0
for key, values in class_acc_dict.items():
acc = sum(values)/len(values)
print("Class", key, "Accuracy:", acc)
acc_avg += acc

print("Average Accuracy:", acc_avg/len(class_acc_dict))


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="""Compute the accuracy of the encoder.\n\n"""
"""
Example runs:
python TTS/bin/eval_encoder.py emotion_encoder_model.pth.tar emotion_encoder_config.json dataset_config.json
""",
formatter_class=RawTextHelpFormatter,
)
parser.add_argument("model_path", type=str, help="Path to model checkpoint file.")
parser.add_argument(
"config_path",
type=str,
help="Path to model config file.",
)

parser.add_argument(
"config_dataset_path",
type=str,
help="Path to dataset config file.",
)
parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True)
parser.add_argument("--eval", type=bool, help="compute eval.", default=True)

args = parser.parse_args()

c_dataset = load_config(args.config_dataset_path)

meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=args.eval)
items = meta_data_train + meta_data_eval

enc_manager = SpeakerManager(
encoder_model_path=args.model_path, encoder_config_path=args.config_path, use_cuda=args.use_cuda
)

compute_encoder_accuracy(items, enc_manager)
Loading