Fix dataset handling with the new embedding file keys (#1991)
Edresson authored Sep 19, 2022
1 parent 0a112f7 commit 3faccbd
Showing 5 changed files with 34 additions and 20 deletions.
TTS/bin/compute_embeddings.py (4 changes: 1 addition & 3 deletions)
@@ -102,11 +102,9 @@
for idx, fields in enumerate(tqdm(samples)):
class_name = fields[class_name_key]
audio_file = fields["audio_file"]
- dataset_name = fields["dataset_name"]
+ embedding_key = fields["audio_unique_name"]
root_path = fields["root_path"]

- relfilepath = os.path.splitext(audio_file.replace(root_path, ""))[0]
- embedding_key = f"{dataset_name}#{relfilepath}"
if args.old_file is not None and embedding_key in encoder_manager.clip_ids:
# get the embedding from the old file
embedd = encoder_manager.get_embedding_by_clip(embedding_key)
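
Note on the change above: the embedding key is no longer rebuilt here from dataset_name and the relative file path; it is read from the sample's precomputed audio_unique_name (added in TTS/tts/datasets/__init__.py below). A minimal sketch of the reuse path against an --old_file, with hypothetical paths and a plain dict standing in for the encoder manager:

    # Hypothetical sample as produced by load_tts_samples() after this change.
    fields = {
        "audio_file": "/data/LJSpeech-1.1/wavs/LJ001-0001.wav",
        "root_path": "/data/LJSpeech-1.1",
        "audio_unique_name": "ljspeech#/wavs/LJ001-0001",
    }
    # Stand-in for embeddings loaded from --old_file, keyed the same way.
    old_embeddings = {"ljspeech#/wavs/LJ001-0001": {"name": "ljspeech-0", "embedding": [0.055, -0.094]}}

    embedding_key = fields["audio_unique_name"]
    if embedding_key in old_embeddings:
        embedd = old_embeddings[embedding_key]["embedding"]  # reuse instead of recomputing
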
TTS/tts/datasets/__init__.py (20 changes: 16 additions & 4 deletions)
@@ -59,6 +59,18 @@ def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01):
return items[:eval_split_size], items[eval_split_size:]


+ def add_extra_keys(metadata, language, dataset_name):
+     for item in metadata:
+         # add language name
+         item["language"] = language
+         # add unique audio name
+         relfilepath = os.path.splitext(item["audio_file"].replace(item["root_path"], ""))[0]
+         audio_unique_name = f"{dataset_name}#{relfilepath}"
+         item["audio_unique_name"] = audio_unique_name
+
+     return metadata
+
+
def load_tts_samples(
datasets: Union[List[Dict], Dict],
eval_split=True,
@@ -111,15 +123,15 @@ def load_tts_samples(
# load train set
meta_data_train = formatter(root_path, meta_file_train, ignored_speakers=ignored_speakers)
assert len(meta_data_train) > 0, f" [!] No training samples found in {root_path}/{meta_file_train}"
- meta_data_train = [{**item, **{"language": language, "dataset_name": dataset_name}} for item in meta_data_train]

+ meta_data_train = add_extra_keys(meta_data_train, language, dataset_name)

print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
# load evaluation split if set
if eval_split:
if meta_file_val:
meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers)
- meta_data_eval = [
-     {**item, **{"language": language, "dataset_name": dataset_name}} for item in meta_data_eval
- ]
+ meta_data_eval = add_extra_keys(meta_data_eval, language, dataset_name)
else:
meta_data_eval, meta_data_train = split_dataset(meta_data_train, eval_split_max_size, eval_split_size)
meta_data_eval_all += meta_data_eval
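
To make the new key format concrete, here is a small worked example around add_extra_keys; the dataset name, language, and paths are made up for illustration:

    import os

    def add_extra_keys(metadata, language, dataset_name):
        for item in metadata:
            item["language"] = language
            relfilepath = os.path.splitext(item["audio_file"].replace(item["root_path"], ""))[0]
            item["audio_unique_name"] = f"{dataset_name}#{relfilepath}"
        return metadata

    samples = [{"audio_file": "/datasets/LJSpeech-1.1/wavs/LJ001-0001.wav", "root_path": "/datasets/LJSpeech-1.1"}]
    samples = add_extra_keys(samples, language="en", dataset_name="ljspeech")
    print(samples[0]["audio_unique_name"])  # ljspeech#/wavs/LJ001-0001

Note that the key keeps the leading separator from the root-relative path, which is why the fixture keys in tests/data/ljspeech/speakers.json below start with "#/wavs/".
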
TTS/tts/datasets/dataset.py (5 changes: 3 additions & 2 deletions)
@@ -256,6 +256,7 @@ def load_data(self, idx):
"speaker_name": item["speaker_name"],
"language_name": item["language"],
"wav_file_name": os.path.basename(item["audio_file"]),
"audio_unique_name": item["audio_unique_name"],
}
return sample

@@ -397,8 +398,8 @@ def collate_fn(self, batch):
language_ids = None
# get pre-computed d-vectors
if self.d_vector_mapping is not None:
- wav_files_names = list(batch["wav_file_name"])
- d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
+ embedding_keys = list(batch["audio_unique_name"])
+ d_vectors = [self.d_vector_mapping[w]["embedding"] for w in embedding_keys]
else:
d_vectors = None

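With this change, self.d_vector_mapping is expected to be keyed by audio_unique_name (the format of the updated tests/data/ljspeech/speakers.json below) rather than by the bare wav file name. A sketch of the lookup with illustrative values, where batch is already a dict of lists as in collate_fn:

    # Illustrative mapping in the shape of the updated speakers.json fixture.
    d_vector_mapping = {
        "#/wavs/LJ001-0001": {"name": "ljspeech-0", "embedding": [0.055, -0.094]},
        "#/wavs/LJ001-0002": {"name": "ljspeech-1", "embedding": [0.055, -0.094]},
    }
    batch = {"audio_unique_name": ["#/wavs/LJ001-0001", "#/wavs/LJ001-0002"]}

    embedding_keys = list(batch["audio_unique_name"])
    d_vectors = [d_vector_mapping[w]["embedding"] for w in embedding_keys]
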
TTS/tts/models/vits.py (5 changes: 4 additions & 1 deletion)
@@ -284,6 +284,7 @@ def __getitem__(self, idx):
"wav_file": wav_filename,
"speaker_name": item["speaker_name"],
"language_name": item["language"],
"audio_unique_name": item["audio_unique_name"],
}

@property
@@ -308,6 +309,7 @@ def collate_fn(self, batch):
- language_names: :math:`[B]`
- audiofile_paths: :math:`[B]`
- raw_texts: :math:`[B]`
+ - audio_unique_names: :math:`[B]`
"""
# convert list of dicts to dict of lists
B = len(batch)
@@ -348,6 +350,7 @@
"language_names": batch["language_name"],
"audio_files": batch["wav_file"],
"raw_text": batch["raw_text"],
"audio_unique_names": batch["audio_unique_name"],
}


@@ -1470,7 +1473,7 @@ def format_batch(self, batch: Dict) -> Dict:
# get d_vectors from audio file names
if self.speaker_manager is not None and self.speaker_manager.embeddings and self.args.use_d_vector_file:
d_vector_mapping = self.speaker_manager.embeddings
- d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_files"]]
+ d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_unique_names"]]
d_vectors = torch.FloatTensor(d_vectors)

# get language ids from language names
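
The Vits changes follow the same thread: __getitem__ carries audio_unique_name, collate_fn exposes it as audio_unique_names, and format_batch uses it to index speaker_manager.embeddings when use_d_vector_file is set. A short sketch of that last step, with a dummy mapping standing in for the real speaker manager:

    import torch

    # Illustrative stand-in for self.speaker_manager.embeddings.
    d_vector_mapping = {
        "#/wavs/LJ001-0001": {"embedding": [0.055, -0.094]},
        "#/wavs/LJ001-0002": {"embedding": [0.055, -0.094]},
    }
    batch = {"audio_unique_names": ["#/wavs/LJ001-0001", "#/wavs/LJ001-0002"]}

    d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_unique_names"]]
    d_vectors = torch.FloatTensor(d_vectors)  # shape [B, embedding_dim]
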
tests/data/ljspeech/speakers.json (20 changes: 10 additions & 10 deletions)
@@ -1,5 +1,5 @@
{
"LJ001-0001.wav": {
"#/wavs/LJ001-0001": {
"name": "ljspeech-0",
"embedding": [
0.05539746582508087,
@@ -260,7 +260,7 @@
-0.09469571709632874
]
},
"LJ001-0002.wav": {
"#/wavs/LJ001-0002": {
"name": "ljspeech-1",
"embedding": [
0.05539746582508087,
@@ -521,7 +521,7 @@
-0.09469571709632874
]
},
"LJ001-0003.wav": {
"#/wavs/LJ001-0003": {
"name": "ljspeech-2",
"embedding": [
0.05539746582508087,
@@ -782,7 +782,7 @@
-0.09469571709632874
]
},
"LJ001-0004.wav": {
"#/wavs/LJ001-0004": {
"name": "ljspeech-3",
"embedding": [
0.05539746582508087,
@@ -1043,7 +1043,7 @@
-0.09469571709632874
]
},
"LJ001-0005.wav": {
"#/wavs/LJ001-0005": {
"name": "ljspeech-4",
"embedding": [
0.05539746582508087,
@@ -1304,7 +1304,7 @@
-0.09469571709632874
]
},
"LJ001-0006.wav": {
"#/wavs/LJ001-0006": {
"name": "ljspeech-5",
"embedding": [
0.05539746582508087,
@@ -1565,7 +1565,7 @@
-0.09469571709632874
]
},
"LJ001-0007.wav": {
"#/wavs/LJ001-0007": {
"name": "ljspeech-6",
"embedding": [
0.05539746582508087,
@@ -1826,7 +1826,7 @@
-0.09469571709632874
]
},
"LJ001-0008.wav": {
"#/wavs/LJ001-0008": {
"name": "ljspeech-7",
"embedding": [
0.05539746582508087,
@@ -2087,7 +2087,7 @@
-0.09469571709632874
]
},
"LJ001-0009.wav": {
"#/wavs/LJ001-0009": {
"name": "ljspeech-8",
"embedding": [
0.05539746582508087,
@@ -2348,7 +2348,7 @@
-0.09469571709632874
]
},
"LJ001-0010.wav": {
"#/wavs/LJ001-0010": {
"name": "ljspeech-9",
"embedding": [
0.05539746582508087,
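
The fixture keys follow the f"{dataset_name}#{relfilepath}" format introduced above; the leading "#" suggests the test config leaves the dataset name empty. A quick check of the format with a hypothetical path:

    import os

    root_path = "tests/data/ljspeech"       # hypothetical fixture root
    audio_file = "tests/data/ljspeech/wavs/LJ001-0001.wav"
    dataset_name = ""                       # apparently unset in the test config

    relfilepath = os.path.splitext(audio_file.replace(root_path, ""))[0]
    print(f"{dataset_name}#{relfilepath}")  # -> #/wavs/LJ001-0001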
