Skip to content

Commit

Permalink
Fix colliding dataset cache file names (#1994)
Browse files Browse the repository at this point in the history
* Fix colliding dataset cache file names

* Remove unused code
  • Loading branch information
Edresson authored Sep 21, 2022
1 parent 3faccbd commit d6ad9a0
Show file tree
Hide file tree
Showing 34 changed files with 17 additions and 12 deletions.
29 changes: 17 additions & 12 deletions TTS/tts/datasets/dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import collections
import os
import random
Expand Down Expand Up @@ -34,6 +35,12 @@ def noise_augment_audio(wav):
return wav + (1.0 / 32768.0) * np.random.rand(*wav.shape)


def string2filename(string):
    """Turn an arbitrary string into a reversible, filename-friendly token.

    The text is UTF-8 encoded and then passed through URL-safe base64, whose
    alphabet uses ``-`` and ``_`` in place of ``+`` and ``/``, so the result
    contains no path separators. The original string can be recovered with
    ``base64.urlsafe_b64decode(...).decode("utf-8")``.

    Args:
        string (str): Text to encode.

    Returns:
        str: URL-safe base64 representation of ``string`` (may end with
        ``=`` padding characters).
    """
    # NOTE(review): base64 grows the input by roughly 4/3; very long inputs
    # could exceed filesystem name-length limits -- confirm for your dataset.
    raw_bytes = string.encode("utf-8")
    encoded = base64.urlsafe_b64encode(raw_bytes)
    # base64 output is plain ASCII, so the "ignore" handler never drops bytes.
    return encoded.decode("utf-8", "ignore")


class TTSDataset(Dataset):
def __init__(
self,
Expand Down Expand Up @@ -201,7 +208,7 @@ def get_phonemes(self, idx, text):
def get_f0(self, idx):
out_dict = self.f0_dataset[idx]
item = self.samples[idx]
assert item["audio_file"] == out_dict["audio_file"]
assert item["audio_unique_name"] == out_dict["audio_unique_name"]
return out_dict

@staticmethod
Expand Down Expand Up @@ -561,19 +568,18 @@ def __init__(

def __getitem__(self, index):
item = self.samples[index]
ids = self.compute_or_load(item["audio_file"], item["text"])
ids = self.compute_or_load(string2filename(item["audio_unique_name"]), item["text"])
ph_hat = self.tokenizer.ids_to_text(ids)
return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)}

def __len__(self):
return len(self.samples)

def compute_or_load(self, wav_file, text):
def compute_or_load(self, file_name, text):
"""Compute phonemes for the given text.
If the phonemes are already cached, load them from cache.
"""
file_name = os.path.splitext(os.path.basename(wav_file))[0]
file_ext = "_phoneme.npy"
cache_path = os.path.join(self.cache_path, file_name + file_ext)
try:
Expand Down Expand Up @@ -670,11 +676,11 @@ def __init__(

def __getitem__(self, idx):
item = self.samples[idx]
f0 = self.compute_or_load(item["audio_file"])
f0 = self.compute_or_load(item["audio_file"], string2filename(item["audio_unique_name"]))
if self.normalize_f0:
assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available"
f0 = self.normalize(f0)
return {"audio_file": item["audio_file"], "f0": f0}
return {"audio_unique_name": item["audio_unique_name"], "f0": f0}

def __len__(self):
return len(self.samples)
Expand Down Expand Up @@ -706,8 +712,7 @@ def get_pad_id(self):
return self.pad_id

@staticmethod
def create_pitch_file_path(wav_file, cache_path):
file_name = os.path.splitext(os.path.basename(wav_file))[0]
def create_pitch_file_path(file_name, cache_path):
pitch_file = os.path.join(cache_path, file_name + "_pitch.npy")
return pitch_file

Expand Down Expand Up @@ -745,26 +750,26 @@ def denormalize(self, pitch):
pitch[zero_idxs] = 0.0
return pitch

def compute_or_load(self, wav_file):
def compute_or_load(self, wav_file, audio_unique_name):
"""
compute pitch and return a numpy array of pitch values
"""
pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
pitch_file = self.create_pitch_file_path(audio_unique_name, self.cache_path)
if not os.path.exists(pitch_file):
pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file)
else:
pitch = np.load(pitch_file)
return pitch.astype(np.float32)

def collate_fn(self, batch):
audio_file = [item["audio_file"] for item in batch]
audio_unique_name = [item["audio_unique_name"] for item in batch]
f0s = [item["f0"] for item in batch]
f0_lens = [len(item["f0"]) for item in batch]
f0_lens_max = max(f0_lens)
f0s_torch = torch.LongTensor(len(f0s), f0_lens_max).fill_(self.get_pad_id())
for i, f0_len in enumerate(f0_lens):
f0s_torch[i, :f0_len] = torch.LongTensor(f0s[i])
return {"audio_file": audio_file, "f0": f0s_torch, "f0_lens": f0_lens}
return {"audio_unique_name": audio_unique_name, "f0": f0s_torch, "f0_lens": f0_lens}

def print_logs(self, level: int = 0) -> None:
indent = "\t" * level
Expand Down
Binary file modified tests/data/ljspeech/f0_cache/pitch_stats.npy
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 comments on commit d6ad9a0

Please sign in to comment.