Merge pull request #183 from idiap/openvoice
Add OpenVoice VC models
eginhard authored Dec 3, 2024
2 parents 98a372b + 3539e65 commit 9ae0b27
Showing 28 changed files with 671 additions and 65 deletions.
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
@@ -2,6 +2,8 @@ repos:
   - repo: "https://github.com/pre-commit/pre-commit-hooks"
     rev: v5.0.0
     hooks:
+      - id: check-json
+        files: "TTS/.models.json"
       - id: check-yaml
       - id: end-of-file-fixer
       - id: trailing-whitespace
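The new `check-json` hook only guards that the model registry parses as valid JSON. A standalone sketch of the same check (the file path comes from the hook config above):

```python
import json

# The invariant the check-json hook enforces: the registry must parse.
with open("TTS/.models.json", encoding="utf-8") as f:
    json.load(f)
print("TTS/.models.json is valid JSON")
```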
20 changes: 14 additions & 6 deletions README.md
@@ -1,13 +1,12 @@
-
 ## 🐸Coqui TTS News
 - 📣 Fork of the [original, unmaintained repository](https://github.com/coqui-ai/TTS). New PyPI package: [coqui-tts](https://pypi.org/project/coqui-tts)
+- 📣 [OpenVoice](https://github.com/myshell-ai/OpenVoice) models now available for voice conversion.
 - 📣 Prebuilt wheels are now also published for Mac and Windows (in addition to Linux as before) for easier installation across platforms.
-- 📣 ⓍTTSv2 is here with 16 languages and better performance across the board.
+- 📣 ⓍTTSv2 is here with 17 languages and better performance across the board. ⓍTTS can stream with <200ms latency.
 - 📣 ⓍTTS fine-tuning code is out. Check the [example recipes](https://github.com/idiap/coqui-ai-TTS/tree/dev/recipes/ljspeech).
-- 📣 ⓍTTS can now stream with <200ms latency.
 - 📣 ⓍTTS, our production TTS model that can speak 13 languages, is released [Blog Post](https://coqui.ai/blog/tts/open_xtts), [Demo](https://huggingface.co/spaces/coqui/xtts), [Docs](https://coqui-tts.readthedocs.io/en/latest/models/xtts.html)
 - 📣 [🐶Bark](https://github.com/suno-ai/bark) is now available for inference with unconstrained voice cloning. [Docs](https://coqui-tts.readthedocs.io/en/latest/models/bark.html)
-- 📣 You can use [~1100 Fairseq models](https://github.com/facebookresearch/fairseq/tree/main/examples/mms) with 🐸TTS.
+- 📣 You can use [Fairseq models in ~1100 languages](https://github.com/facebookresearch/fairseq/tree/main/examples/mms) with 🐸TTS.
 
 ## <img src="https://raw.githubusercontent.com/idiap/coqui-ai-TTS/main/images/coqui-log-green-TTS.png" height="56"/>
 
@@ -121,6 +120,7 @@ repository are also still a useful source of information.
 
 ### Voice Conversion
 - FreeVC: [paper](https://arxiv.org/abs/2210.15418)
+- OpenVoice: [technical report](https://arxiv.org/abs/2312.01479)
 
 You can also help us implement more models.
 
Expand Down Expand Up @@ -244,8 +244,14 @@ tts = TTS(model_name="voice_conversion_models/multilingual/vctk/freevc24", progr
tts.voice_conversion_to_file(source_wav="my/source.wav", target_wav="my/target.wav", file_path="output.wav")
```

#### Example voice cloning together with the voice conversion model.
This way, you can clone voices by using any model in 🐸TTS.
Other available voice conversion models:
- `voice_conversion_models/multilingual/multi-dataset/openvoice_v1`
- `voice_conversion_models/multilingual/multi-dataset/openvoice_v2`

#### Example voice cloning together with the default voice conversion model.

This way, you can clone voices by using any model in 🐸TTS. The FreeVC model is
used for voice conversion after synthesizing speech.

```python

Expand Down Expand Up @@ -412,4 +418,6 @@ $ tts --out_path output/path/speech.wav --model_name "<language>/<dataset>/<mode
|- (same)
|- vocoder/ (Vocoder models.)
|- (same)
|- vc/ (Voice conversion models.)
|- (same)
```
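A minimal usage sketch for the newly listed OpenVoice converters, mirroring the FreeVC example above (the wav paths are placeholders):

```python
from TTS.api import TTS

# Load one of the OpenVoice converters registered in this PR.
tts = TTS(
    model_name="voice_conversion_models/multilingual/multi-dataset/openvoice_v2",
    progress_bar=False,
)

# Re-voice source.wav with the speaker identity of target.wav.
tts.voice_conversion_to_file(
    source_wav="my/source.wav",
    target_wav="my/target.wav",
    file_path="output.wav",
)
```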
22 changes: 22 additions & 0 deletions TTS/.models.json
@@ -931,6 +931,28 @@
                 "license": "MIT",
                 "commit": null
             }
         },
+        "multi-dataset": {
+            "openvoice_v1": {
+                "hf_url": [
+                    "https://huggingface.co/myshell-ai/OpenVoice/resolve/main/checkpoints/converter/config.json",
+                    "https://huggingface.co/myshell-ai/OpenVoice/resolve/main/checkpoints/converter/checkpoint.pth"
+                ],
+                "description": "OpenVoice VC model from https://huggingface.co/myshell-ai/OpenVoice",
+                "author": "MyShell.ai",
+                "license": "MIT",
+                "commit": null
+            },
+            "openvoice_v2": {
+                "hf_url": [
+                    "https://huggingface.co/myshell-ai/OpenVoiceV2/resolve/main/converter/config.json",
+                    "https://huggingface.co/myshell-ai/OpenVoiceV2/resolve/main/converter/checkpoint.pth"
+                ],
+                "description": "OpenVoice VC model from https://huggingface.co/myshell-ai/OpenVoiceV2",
+                "author": "MyShell.ai",
+                "license": "MIT",
+                "commit": null
+            }
+        }
     }
 }
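For context, these `hf_url` entries are resolved by `TTS.utils.manage.ModelManager`, which downloads both files into the local model cache. A sketch of fetching the new model directly, assuming `download_model` keeps its usual three-value return:

```python
from TTS.utils.manage import ModelManager

manager = ModelManager()
# Resolves the hf_url list above and caches config.json + checkpoint.pth.
model_path, config_path, model_item = manager.download_model(
    "voice_conversion_models/multilingual/multi-dataset/openvoice_v2"
)
```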
14 changes: 9 additions & 5 deletions TTS/api.py
@@ -155,8 +155,10 @@ def load_vc_model_by_name(self, model_name: str, gpu: bool = False):
             gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
         """
         self.model_name = model_name
-        model_path, config_path, _, _, _ = self.download_model_by_name(model_name)
-        self.voice_converter = Synthesizer(vc_checkpoint=model_path, vc_config=config_path, use_cuda=gpu)
+        model_path, config_path, _, _, model_dir = self.download_model_by_name(model_name)
+        self.voice_converter = Synthesizer(
+            vc_checkpoint=model_path, vc_config=config_path, model_dir=model_dir, use_cuda=gpu
+        )
 
     def load_tts_model_by_name(self, model_name: str, gpu: bool = False):
         """Load one of 🐸TTS models by name.
@@ -355,15 +357,17 @@ def voice_conversion(
             target_wav (str):
                 Path to the target wav file.
         """
-        wav = self.voice_converter.voice_conversion(source_wav=source_wav, target_wav=target_wav)
-        return wav
+        if self.voice_converter is None:
+            msg = "The selected model does not support voice conversion."
+            raise RuntimeError(msg)
+        return self.voice_converter.voice_conversion(source_wav=source_wav, target_wav=target_wav)
 
     def voice_conversion_to_file(
         self,
         source_wav: str,
         target_wav: str,
         file_path: str = "output.wav",
-    ):
+    ) -> str:
         """Voice conversion with FreeVC. Convert source wav to target speaker.
 
         Args:
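With the new guard, calling `voice_conversion()` on an API object that never loaded a VC model fails fast with a clear message instead of an `AttributeError` on `None`. A hypothetical sketch of that behaviour (the model name is only an example):

```python
from TTS.api import TTS

# A plain TTS model: no voice converter gets loaded.
tts = TTS(model_name="tts_models/en/ljspeech/glow-tts", progress_bar=False)
try:
    tts.voice_conversion(source_wav="my/source.wav", target_wav="my/target.wav")
except RuntimeError as err:
    print(err)  # The selected model does not support voice conversion.
```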
24 changes: 12 additions & 12 deletions TTS/bin/synthesize.py
@@ -407,18 +407,18 @@ def main():
 
     # load models
     synthesizer = Synthesizer(
-        tts_path,
-        tts_config_path,
-        speakers_file_path,
-        language_ids_file_path,
-        vocoder_path,
-        vocoder_config_path,
-        encoder_path,
-        encoder_config_path,
-        vc_path,
-        vc_config_path,
-        model_dir,
-        args.voice_dir,
+        tts_checkpoint=tts_path,
+        tts_config_path=tts_config_path,
+        tts_speakers_file=speakers_file_path,
+        tts_languages_file=language_ids_file_path,
+        vocoder_checkpoint=vocoder_path,
+        vocoder_config=vocoder_config_path,
+        encoder_checkpoint=encoder_path,
+        encoder_config=encoder_config_path,
+        vc_checkpoint=vc_path,
+        vc_config=vc_config_path,
+        model_dir=model_dir,
+        voice_dir=args.voice_dir,
     ).to(device)
 
     # query speaker ids of a multi-speaker model.
4 changes: 2 additions & 2 deletions TTS/tts/layers/vits/networks.py
@@ -256,7 +256,7 @@ def __init__(
         )
         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
 
-    def forward(self, x, x_lengths, g=None):
+    def forward(self, x, x_lengths, g=None, tau=1.0):
         """
         Shapes:
             - x: :math:`[B, C, T]`
@@ -268,5 +268,5 @@ def forward(self, x, x_lengths, g=None):
         x = self.enc(x, x_mask, g=g)
         stats = self.proj(x) * x_mask
         mean, log_scale = torch.split(stats, self.out_channels, dim=1)
-        z = (mean + torch.randn_like(mean) * torch.exp(log_scale)) * x_mask
+        z = (mean + torch.randn_like(mean) * tau * torch.exp(log_scale)) * x_mask
         return z, mean, log_scale, x_mask
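For context, the new `tau` argument acts as a sampling temperature on the posterior encoder, presumably for the OpenVoice converter added in this PR: `tau=1.0` reproduces the old behaviour, while `tau=0.0` collapses the sample onto the predicted mean. A standalone sketch of the changed line, with illustrative shapes:

```python
import torch

mean = torch.zeros(1, 192, 50)       # [B, C, T], illustrative shape
log_scale = torch.zeros(1, 192, 50)  # std = exp(0) = 1

for tau in (1.0, 0.3, 0.0):
    z = mean + torch.randn_like(mean) * tau * torch.exp(log_scale)
    print(f"tau={tau}: std of z = {z.std():.3f}")  # ~1.0, ~0.3, 0.000
```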
2 changes: 1 addition & 1 deletion TTS/utils/manage.py
@@ -424,7 +424,7 @@ def _find_files(output_path: str) -> Tuple[str, str]:
     model_file = None
     config_file = None
     for file_name in os.listdir(output_path):
-        if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth"]:
+        if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth", "checkpoint.pth"]:
             model_file = os.path.join(output_path, file_name)
         elif file_name == "config.json":
             config_file = os.path.join(output_path, file_name)
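The extra filename matters because the OpenVoice checkpoints registered in `TTS/.models.json` above are published as `checkpoint.pth`. A small sketch that mirrors the updated lookup:

```python
import os

# Filenames accepted by the updated _find_files(); checkpoint.pth is new.
MODEL_FILE_NAMES = ("model_file.pth", "model_file.pth.tar", "model.pth", "checkpoint.pth")

def find_model_file(output_path: str):
    """Return the first recognized model file in output_path, else None."""
    for file_name in os.listdir(output_path):
        if file_name in MODEL_FILE_NAMES:
            return os.path.join(output_path, file_name)
    return None
```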
32 changes: 26 additions & 6 deletions TTS/utils/synthesizer.py
@@ -1,6 +1,7 @@
 import logging
 import os
 import time
+from pathlib import Path
 from typing import List
 
 import numpy as np
@@ -15,7 +16,9 @@
 from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.audio.numpy_transforms import save_wav
+from TTS.vc.configs.openvoice_config import OpenVoiceConfig
 from TTS.vc.models import setup_model as setup_vc_model
+from TTS.vc.models.openvoice import OpenVoice
 from TTS.vocoder.models import setup_model as setup_vocoder_model
 from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input
@@ -25,6 +28,7 @@
 class Synthesizer(nn.Module):
     def __init__(
         self,
+        *,
         tts_checkpoint: str = "",
         tts_config_path: str = "",
         tts_speakers_file: str = "",
@@ -91,23 +95,20 @@ def __init__(
 
         if tts_checkpoint:
             self._load_tts(tts_checkpoint, tts_config_path, use_cuda)
-            self.output_sample_rate = self.tts_config.audio["sample_rate"]
 
         if vocoder_checkpoint:
             self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda)
-            self.output_sample_rate = self.vocoder_config.audio["sample_rate"]
 
-        if vc_checkpoint:
+        if vc_checkpoint and model_dir is None:
             self._load_vc(vc_checkpoint, vc_config, use_cuda)
-            self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
 
         if model_dir:
             if "fairseq" in model_dir:
                 self._load_fairseq_from_dir(model_dir, use_cuda)
-                self.output_sample_rate = self.tts_config.audio["sample_rate"]
+            elif "openvoice" in model_dir:
+                self._load_openvoice_from_dir(Path(model_dir), use_cuda)
             else:
                 self._load_tts_from_dir(model_dir, use_cuda)
-                self.output_sample_rate = self.tts_config.audio["output_sample_rate"]
 
     @staticmethod
     def _get_segmenter(lang: str):
@@ -136,6 +137,7 @@ def _load_vc(self, vc_checkpoint: str, vc_config_path: str, use_cuda: bool) -> None:
         """
         # pylint: disable=global-statement
         self.vc_config = load_config(vc_config_path)
+        self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
         self.vc_model = setup_vc_model(config=self.vc_config)
         self.vc_model.load_checkpoint(self.vc_config, vc_checkpoint)
         if use_cuda:
@@ -150,16 +152,32 @@ def _load_fairseq_from_dir(self, model_dir: str, use_cuda: bool) -> None:
         self.tts_model = Vits.init_from_config(self.tts_config)
         self.tts_model.load_fairseq_checkpoint(self.tts_config, checkpoint_dir=model_dir, eval=True)
         self.tts_config = self.tts_model.config
+        self.output_sample_rate = self.tts_config.audio["sample_rate"]
         if use_cuda:
             self.tts_model.cuda()
 
+    def _load_openvoice_from_dir(self, checkpoint: Path, use_cuda: bool) -> None:
+        """Load the OpenVoice model from a directory.
+
+        We assume the model knows how to load itself from the directory and
+        there is a config.json file in the directory.
+        """
+        self.vc_config = OpenVoiceConfig()
+        self.vc_model = OpenVoice.init_from_config(self.vc_config)
+        self.vc_model.load_checkpoint(self.vc_config, checkpoint, eval=True)
+        self.vc_config = self.vc_model.config
+        self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
+        if use_cuda:
+            self.vc_model.cuda()
+
     def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None:
         """Load the TTS model from a directory.
 
         We assume the model knows how to load itself from the directory and there is a config.json file in the directory.
         """
         config = load_config(os.path.join(model_dir, "config.json"))
         self.tts_config = config
+        self.output_sample_rate = self.tts_config.audio["output_sample_rate"]
         self.tts_model = setup_tts_model(config)
         self.tts_model.load_checkpoint(config, checkpoint_dir=model_dir, eval=True)
         if use_cuda:
@@ -181,6 +199,7 @@ def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None:
         """
         # pylint: disable=global-statement
         self.tts_config = load_config(tts_config_path)
+        self.output_sample_rate = self.tts_config.audio["sample_rate"]
         if self.tts_config["use_phonemes"] and self.tts_config["phonemizer"] is None:
             raise ValueError("Phonemizer is not defined in the TTS config.")
 
@@ -218,6 +237,7 @@ def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None:
             use_cuda (bool): enable/disable CUDA use.
         """
         self.vocoder_config = load_config(model_config)
+        self.output_sample_rate = self.vocoder_config.audio["sample_rate"]
         self.vocoder_ap = AudioProcessor(**self.vocoder_config.audio)
         self.vocoder_model = setup_vocoder_model(self.vocoder_config)
         self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True)
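Taken together with the keyword-only `__init__` (note the `*` added above) and the routing on `model_dir`, an OpenVoice converter can now be loaded straight from a downloaded directory. A hedged sketch (the directory path is a placeholder; `voice_conversion()` is the same entry point `TTS/api.py` delegates to):

```python
from TTS.utils.synthesizer import Synthesizer

# A model_dir containing "openvoice" routes to _load_openvoice_from_dir().
synthesizer = Synthesizer(model_dir="/path/to/openvoice_v2", use_cuda=False)
wav = synthesizer.voice_conversion(source_wav="my/source.wav", target_wav="my/target.wav")
```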
2 changes: 1 addition & 1 deletion TTS/vc/configs/freevc_config.py
@@ -229,7 +229,7 @@ class FreeVCConfig(BaseVCConfig):
             If true, language embedding is used. Defaults to `False`.
 
         Note:
-            Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.
+            Check :class:`TTS.vc.configs.shared_configs.BaseVCConfig` for the inherited parameters.
 
         Example:
