Skip to content

Commit

Permalink
Merge branch 'dev' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
erogol committed Sep 4, 2023
2 parents 530a893 + 40b5273 commit 33b5e87
Show file tree
Hide file tree
Showing 16 changed files with 157 additions and 34 deletions.
21 changes: 12 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -187,18 +187,21 @@ More details about the docker images (like GPU support) can be found [here](http

### 🐍 Python API

#### Running a multi-speaker and multi-lingual model

```python
import torch
from TTS.api import TTS

# Running a multi-speaker and multi-lingual model
# Get device
device = "cuda" if torch.cuda.is_available() else "cpu"

# List available 🐸TTS models and choose the first one
model_name = TTS.list_models()[0]
model_name = TTS().list_models()[0]
# Init TTS
tts = TTS(model_name)
tts = TTS(model_name).to(device)

# Run TTS

# ❗ Since this model is multi-speaker and multi-lingual, we must set the target speaker and the language
# Text to speech with a numpy output
wav = tts.tts("This is a test! This is also a test!!", speaker=tts.speakers[0], language=tts.languages[0])
Expand All @@ -210,13 +213,13 @@ tts.tts_to_file(text="Hello world!", speaker=tts.speakers[0], language=tts.langu

```python
# Init TTS with the target model name
tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False, gpu=False)
tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False).to(device)

# Run TTS
tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path=OUTPUT_PATH)

# Example voice cloning with YourTTS in English, French and Portuguese

tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=True)
tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False).to(device)
tts.tts_to_file("This is voice cloning.", speaker_wav="my/cloning/audio.wav", language="en", file_path="output.wav")
tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr-fr", file_path="output.wav")
tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt-br", file_path="output.wav")
Expand All @@ -227,7 +230,7 @@ tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav",
Converting the voice in `source_wav` to the voice of `target_wav`

```python
tts = TTS(model_name="voice_conversion_models/multilingual/vctk/freevc24", progress_bar=False, gpu=True)
tts = TTS(model_name="voice_conversion_models/multilingual/vctk/freevc24", progress_bar=False).to("cuda")
tts.voice_conversion_to_file(source_wav="my/source.wav", target_wav="my/target.wav", file_path="output.wav")
```

Expand Down Expand Up @@ -256,7 +259,7 @@ These models will follow the naming convention `coqui_studio/en/<studio_speaker_
# XTTS model
models = TTS(cs_api_model="XTTS").list_models()
# Init TTS with the target studio speaker
tts = TTS(model_name="coqui_studio/en/Torcull Diarmuid/coqui_studio", progress_bar=False, gpu=False)
tts = TTS(model_name="coqui_studio/en/Torcull Diarmuid/coqui_studio", progress_bar=False)
# Run TTS
tts.tts_to_file(text="This is a test.", file_path=OUTPUT_PATH)

Expand Down
2 changes: 1 addition & 1 deletion TTS/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.16.5
0.16.6
8 changes: 6 additions & 2 deletions TTS/bin/synthesize.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def main():
help="Output wav file path.",
)
parser.add_argument("--use_cuda", type=bool, help="Run model on CUDA.", default=False)
parser.add_argument("--device", type=str, help="Device to run model on.", default="cpu")
parser.add_argument(
"--vocoder_path",
type=str,
Expand Down Expand Up @@ -391,6 +392,10 @@ def main():
if args.encoder_path is not None:
encoder_path = args.encoder_path
encoder_config_path = args.encoder_config_path

device = args.device
if args.use_cuda:
device = "cuda"

# load models
synthesizer = Synthesizer(
Expand All @@ -406,8 +411,7 @@ def main():
vc_config_path,
model_dir,
args.voice_dir,
args.use_cuda,
)
).to(device)

# query speaker ids of a multi-speaker model.
if args.list_speaker_idxs:
Expand Down
Empty file.
34 changes: 34 additions & 0 deletions TTS/tts/utils/text/belarusian/phonemizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import os

finder = None


def init():
try:
import jpype
import jpype.imports
except ModuleNotFoundError:
raise ModuleNotFoundError("Belarusian phonemizer requires to install module 'jpype1' manually. Try `pip install jpype1`.")

try:
jar_path = os.environ["BEL_FANETYKA_JAR"]
except KeyError:
raise KeyError("You need to define 'BEL_FANETYKA_JAR' environment variable as path to the fanetyka.jar file")

jpype.startJVM(classpath=[jar_path])

# import the Java modules
from org.alex73.korpus.base import GrammarDB2, GrammarFinder

grammar_db = GrammarDB2.initializeFromJar()
global finder
finder = GrammarFinder(grammar_db)


def belarusian_text_to_phonemes(text: str) -> str:
# Initialize only on first run
if finder is None:
init()

from org.alex73.fanetyka.impl import FanetykaText
return str(FanetykaText(finder, text).ipa)
4 changes: 4 additions & 0 deletions TTS/tts/utils/text/phonemizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from TTS.tts.utils.text.phonemizers.bangla_phonemizer import BN_Phonemizer
from TTS.tts.utils.text.phonemizers.belarusian_phonemizer import BEL_Phonemizer
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
from TTS.tts.utils.text.phonemizers.espeak_wrapper import ESpeak
from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut
Expand Down Expand Up @@ -35,6 +36,7 @@
DEF_LANG_TO_PHONEMIZER["zh-cn"] = ZH_CN_Phonemizer.name()
DEF_LANG_TO_PHONEMIZER["ko-kr"] = KO_KR_Phonemizer.name()
DEF_LANG_TO_PHONEMIZER["bn"] = BN_Phonemizer.name()
DEF_LANG_TO_PHONEMIZER["be"] = BEL_Phonemizer.name()


# JA phonemizer has deal breaking dependencies like MeCab for some systems.
Expand Down Expand Up @@ -68,6 +70,8 @@ def get_phonemizer_by_name(name: str, **kwargs) -> BasePhonemizer:
return KO_KR_Phonemizer(**kwargs)
if name == "bn_phonemizer":
return BN_Phonemizer(**kwargs)
if name == "be_phonemizer":
return BEL_Phonemizer(**kwargs)
raise ValueError(f"Phonemizer {name} not found")


Expand Down
55 changes: 55 additions & 0 deletions TTS/tts/utils/text/phonemizers/belarusian_phonemizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from typing import Dict

from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
from TTS.tts.utils.text.belarusian.phonemizer import belarusian_text_to_phonemes

_DEF_BE_PUNCS = ",!." # TODO


class BEL_Phonemizer(BasePhonemizer):
"""🐸TTS be phonemizer using functions in `TTS.tts.utils.text.belarusian.phonemizer`
Args:
punctuations (str):
Set of characters to be treated as punctuation. Defaults to `_DEF_BE_PUNCS`.
keep_puncs (bool):
If True, keep the punctuations after phonemization. Defaults to False.
"""

language = "be"

def __init__(self, punctuations=_DEF_BE_PUNCS, keep_puncs=True, **kwargs): # pylint: disable=unused-argument
super().__init__(self.language, punctuations=punctuations, keep_puncs=keep_puncs)

@staticmethod
def name():
return "be_phonemizer"

@staticmethod
def phonemize_be(text: str, separator: str = "|") -> str: # pylint: disable=unused-argument
return belarusian_text_to_phonemes(text)

def _phonemize(self, text, separator):
return self.phonemize_be(text, separator)

@staticmethod
def supported_languages() -> Dict:
return {"be": "Belarusian"}

def version(self) -> str:
return "0.0.1"

def is_available(self) -> bool:
return True


if __name__ == "__main__":
txt = "тэст"
e = BEL_Phonemizer()
print(e.supported_languages())
print(e.version())
print(e.language)
print(e.name())
print(e.is_available())
print("`" + e.phonemize(txt) + "`")
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
furo
myst-parser == 2.0.0
sphinx == 7.0.1
sphinx == 7.2.5
sphinx_inline_tabs
sphinx_copybutton
linkify-it-py
2 changes: 1 addition & 1 deletion docs/source/docker_images.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ Start the container and get a shell inside it.
```bash
docker run --rm -it -p 5002:5002 --entrypoint /bin/bash ghcr.io/coqui-ai/tts-cpu
python3 TTS/server/server.py --list_models #To get the list of available models
python3 TTS/server/server.py --model_name tts_models/en/vctk/vits
python3 TTS/server/server.py --model_name tts_models/en/vctk/vits
```

### GPU version
Expand Down
4 changes: 1 addition & 3 deletions docs/source/implementing_a_new_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
There is also the `callback` interface by which you can manipulate both the model and the `Trainer` states. Callbacks give you
an infinite flexibility to add custom behaviours for your model and training routines.

For more details, see {ref}`BaseTTS <Base TTS Model>` and :obj:`TTS.utils.callbacks`.
For more details, see {ref}`BaseTTS <Base tts Model>` and :obj:`TTS.utils.callbacks`.

6. Optionally, define `MyModelArgs`.

Expand Down Expand Up @@ -204,5 +204,3 @@ class MyModel(BaseTTS):
pass

```


12 changes: 6 additions & 6 deletions docs/source/inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ You can run a multi-speaker and multi-lingual model in Python as
from TTS.api import TTS

# List available 🐸TTS models and choose the first one
model_name = TTS.list_models()[0]
model_name = TTS().list_models()[0]
# Init TTS
tts = TTS(model_name)
# Run TTS
Expand All @@ -132,15 +132,15 @@ tts.tts_to_file(text="Hello world!", speaker=tts.speakers[0], language=tts.langu

```python
# Init TTS with the target model name
tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False, gpu=False)
tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False)
# Run TTS
tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path=OUTPUT_PATH)
```

#### Example voice cloning with YourTTS in English, French and Portuguese:

```python
tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=True)
tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False).to("cuda")
tts.tts_to_file("This is voice cloning.", speaker_wav="my/cloning/audio.wav", language="en", file_path="output.wav")
tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr", file_path="output.wav")
tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="output.wav")
Expand All @@ -149,7 +149,7 @@ tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav",
#### Example voice conversion converting speaker of the `source_wav` to the speaker of the `target_wav`

```python
tts = TTS(model_name="voice_conversion_models/multilingual/vctk/freevc24", progress_bar=False, gpu=True)
tts = TTS(model_name="voice_conversion_models/multilingual/vctk/freevc24", progress_bar=False).to("cuda")
tts.voice_conversion_to_file(source_wav="my/source.wav", target_wav="my/target.wav", file_path="output.wav")
```

Expand Down Expand Up @@ -177,7 +177,7 @@ You should set the `COQUI_STUDIO_TOKEN` environment variable to use the API toke
# The name format is coqui_studio/en/<studio_speaker_name>/coqui_studio
models = TTS().list_models()
# Init TTS with the target studio speaker
tts = TTS(model_name="coqui_studio/en/Torcull Diarmuid/coqui_studio", progress_bar=False, gpu=False)
tts = TTS(model_name="coqui_studio/en/Torcull Diarmuid/coqui_studio", progress_bar=False)
# Run TTS
tts.tts_to_file(text="This is a test.", file_path=OUTPUT_PATH)
# Run TTS with emotion and speed control
Expand Down Expand Up @@ -222,7 +222,7 @@ You can find the list of language ISO codes [here](https://dl.fbaipublicfiles.co

```python
from TTS.api import TTS
api = TTS(model_name="tts_models/eng/fairseq/vits", gpu=True)
api = TTS(model_name="tts_models/eng/fairseq/vits").to("cuda")
api.tts_to_file("This is a test.", file_path="output.wav")

# TTS with on the fly voice conversion
Expand Down
6 changes: 3 additions & 3 deletions docs/source/main_classes/model_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@ Model API provides you a set of functions that easily make your model compatible
## Base TTS Model

```{eval-rst}
.. autoclass:: TTS.model.BaseModel
.. autoclass:: TTS.model.BaseTrainerModel
:members:
```

## Base `tts` Model
## Base tts Model

```{eval-rst}
.. autoclass:: TTS.tts.models.base_tts.BaseTTS
:members:
```

## Base `vocoder` Model
## Base vocoder Model

```{eval-rst}
.. autoclass:: TTS.vocoder.models.base_vocoder.BaseVocoder
Expand Down
6 changes: 0 additions & 6 deletions docs/source/models/bark.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,6 @@ tts --model_name tts_models/multilingual/multi-dataset/bark \
:members:
```

## BarkArgs
```{eval-rst}
.. autoclass:: TTS.tts.models.bark.BarkArgs
:members:
```

## Bark Model
```{eval-rst}
.. autoclass:: TTS.tts.models.bark.Bark
Expand Down
4 changes: 3 additions & 1 deletion recipes/bel-alex73/train_glowtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
output_path=output_path,
add_blank=True,
datasets=[dataset_config],
characters=characters,
# characters=characters,
enable_eos_bos_chars=True,
mixed_precision=False,
save_step=10000,
Expand All @@ -69,6 +69,8 @@
text_cleaner="no_cleaners",
audio=audio_config,
test_sentences=[],
use_phonemes=True,
phoneme_language="be",
)

if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
numpy==1.22.0;python_version<="3.10"
numpy==1.24.3;python_version>"3.10"
cython==0.29.30
scipy>=1.4.0
scipy>=1.11.2
torch>=1.7
torchaudio
soundfile
Expand Down
29 changes: 29 additions & 0 deletions tests/text_tests/test_belarusian_phonemizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import os
import warnings
import unittest

from TTS.tts.utils.text.belarusian.phonemizer import belarusian_text_to_phonemes

_TEST_CASES = """
Фанетычны канвертар/fanʲɛˈtɨt͡ʂnɨ kanˈvʲɛrtar
Гэтак мы працавалі/ˈɣɛtak ˈmɨ prat͡saˈvalʲi
"""


class TestText(unittest.TestCase):
def test_belarusian_text_to_phonemes(self):
try:
os.environ["BEL_FANETYKA_JAR"]
except KeyError:
warnings.warn(
"You need to define 'BEL_FANETYKA_JAR' environment variable as path to the fanetyka.jar file to test Belarusian phonemizer",
Warning)
return

for line in _TEST_CASES.strip().split("\n"):
text, phonemes = line.split("/")
self.assertEqual(belarusian_text_to_phonemes(text), phonemes)


if __name__ == "__main__":
unittest.main()

0 comments on commit 33b5e87

Please sign in to comment.