-
Notifications
You must be signed in to change notification settings - Fork 4.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
REBASED: Transform Speaker Encoder in a Generic Encoder and Implement…
… Emotion Encoder training support (#1349) * Rename Speaker encoder module to encoder * Add a generic emotion dataset formatter * Transform the Speaker Encoder dataset to a generic dataset and create emotion encoder config * Add class map in emotion config * Add Base encoder config * Add evaluation encoder script * Fix the bug in plot_embeddings * Enable Weight decay for encoder training * Add argument to disable storage * Add Perfect Sampler and remove storage * Add evaluation during encoder training * Fix lint checks * Remove useless config parameter * Activate evaluation in speaker encoder test and use multispeaker dataset for this test * Unit tests fixes * Remove useless tests to speed up the aux_tests * Use get_optimizer in Encoder * Add BaseEncoder Class * Fix the unit tests * Add Perfect Batch Sampler unit test * Add compute encoder accuracy in a function
- Loading branch information
Showing
40 changed files
with
971 additions
and
2,800 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import argparse | ||
import torch | ||
from argparse import RawTextHelpFormatter | ||
|
||
from tqdm import tqdm | ||
|
||
from TTS.config import load_config | ||
from TTS.tts.datasets import load_tts_samples | ||
from TTS.tts.utils.speakers import SpeakerManager | ||
|
||
def compute_encoder_accuracy(dataset_items, encoder_manager):
    """Print the per-class and class-averaged classification accuracy of the encoder.

    For every sample, the audio clip is embedded with
    ``encoder_manager.compute_d_vector_from_clip``, the embedding is classified
    by the softmax head of ``encoder_manager.speaker_encoder_criterion``, and
    the predicted class id is translated to a class name through the encoder
    config's ``map_classid_to_classname``. The prediction is then compared with
    the sample's ground-truth class name.

    The reported "Average Accuracy" is the unweighted mean of the per-class
    accuracies, so every class counts equally regardless of its sample count.

    Args:
        dataset_items: Samples to evaluate. Each sample is indexed with
            ``"audio_file"`` and with the key named by the encoder config's
            ``class_name_key``.
        encoder_manager: Object exposing ``speaker_encoder_config``,
            ``speaker_encoder_criterion``, ``use_cuda`` and
            ``compute_d_vector_from_clip``.

    Raises:
        RuntimeError: If a sample's class name is ``None``, or if no prediction
            can be made because the criterion or the class-id map is missing.
    """
    encoder_config = encoder_manager.speaker_encoder_config
    class_name_key = encoder_config.class_name_key
    map_classid_to_classname = getattr(encoder_config, "map_classid_to_classname", None)

    # Nothing to evaluate: bail out instead of dividing by zero further down.
    if not dataset_items:
        print("No samples to evaluate: accuracy is undefined.")
        return

    # Whether a label can be predicted does not depend on the sample, so
    # decide it once instead of on every iteration.
    criterion = encoder_manager.speaker_encoder_criterion
    can_classify = criterion is not None and map_classid_to_classname is not None

    # class name -> list of 0/1 flags, one per sample (1 == correct prediction)
    class_acc_dict = {}

    # compute embeddings for all wav_files
    for item in tqdm(dataset_items):
        class_name = item[class_name_key]
        wav_file = item["audio_file"]

        # extract the embedding
        embedd = encoder_manager.compute_d_vector_from_clip(wav_file)

        predicted_label = None
        if can_classify:
            # unsqueeze(0) adds a leading dimension before calling the classifier
            embedding = torch.FloatTensor(embedd).unsqueeze(0)
            if encoder_manager.use_cuda:
                embedding = embedding.cuda()

            class_id = criterion.softmax.inference(embedding).item()
            # the class-id map is keyed by the string form of the id
            predicted_label = map_classid_to_classname[str(class_id)]

        if class_name is None or predicted_label is None:
            raise RuntimeError("Error: class_name or/and predicted_label are None")

        class_acc_dict.setdefault(class_name, []).append(int(class_name == predicted_label))

    # Reached only when dataset_items was a truthy but empty iterable
    # (e.g. an exhausted generator); avoids ZeroDivisionError below.
    if not class_acc_dict:
        print("No samples to evaluate: accuracy is undefined.")
        return

    acc_avg = 0
    for key, values in class_acc_dict.items():
        acc = sum(values) / len(values)
        print("Class", key, "Accuracy:", acc)
        acc_avg += acc

    print("Average Accuracy:", acc_avg / len(class_acc_dict))
|
||
|
||
def str2bool(value):
    """Convert a command-line string to a real boolean.

    ``argparse``'s ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value (including ``"False"`` and ``"0"``) becomes ``True``. This
    converter interprets the text instead.

    Args:
        value: The raw argument; a ``bool`` is passed through unchanged.

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string stays False, as it was with ``type=bool``.
    if lowered in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="""Compute the accuracy of the encoder.\n\n"""
        """
        Example runs:
        python TTS/bin/eval_encoder.py emotion_encoder_model.pth.tar emotion_encoder_config.json dataset_config.json
        """,
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument("model_path", type=str, help="Path to model checkpoint file.")
    parser.add_argument(
        "config_path",
        type=str,
        help="Path to model config file.",
    )

    parser.add_argument(
        "config_dataset_path",
        type=str,
        help="Path to dataset config file.",
    )
    # str2bool (not bool) so that "--use_cuda False" / "--eval False" really disable the flag.
    parser.add_argument("--use_cuda", type=str2bool, help="flag to set cuda.", default=True)
    parser.add_argument("--eval", type=str2bool, help="compute eval.", default=True)

    args = parser.parse_args()

    c_dataset = load_config(args.config_dataset_path)

    meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=args.eval)
    # NOTE(review): the eval split is presumably None (not an empty list) when
    # eval_split is False -- confirm against load_tts_samples. Fall back to an
    # empty list so the concatenation cannot fail either way.
    items = meta_data_train + (meta_data_eval or [])

    enc_manager = SpeakerManager(
        encoder_model_path=args.model_path, encoder_config_path=args.config_path, use_cuda=args.use_cuda
    )

    compute_encoder_accuracy(items, enc_manager)
Oops, something went wrong.