diff --git a/TTS/tts/utils/languages.py b/TTS/tts/utils/languages.py index dd16783381..12c811f0da 100644 --- a/TTS/tts/utils/languages.py +++ b/TTS/tts/utils/languages.py @@ -121,4 +121,4 @@ def get_language_balancer_weights(items: list): dataset_samples_weight = np.array([weight_language[l] for l in language_ids]) # normalize dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight) - return torch.from_numpy(dataset_samples_weight).double() + return torch.from_numpy(dataset_samples_weight).float() diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index 102593387a..b5f114203f 100644 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -421,4 +421,4 @@ def get_speaker_balancer_weights(items: list): dataset_samples_weight = np.array([weight_speaker[l] for l in speaker_ids]) # normalize dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight) - return torch.from_numpy(dataset_samples_weight).double() + return torch.from_numpy(dataset_samples_weight).float()