Add sentence splitting (#3227)
* Add sentence splitting

* update requirements

* update default args v2

* Add Spanish

* Fix return gpt_latents

* Update requirements

* Fix requirements
WeberJulian authored Nov 16, 2023
1 parent 3c2d5a9 commit 675f983
Showing 3 changed files with 194 additions and 118 deletions.
78 changes: 67 additions & 11 deletions TTS/tts/layers/xtts/tokenizer.py
@@ -1,17 +1,72 @@
import json
import os
import re
-from functools import cached_property
-
-import pypinyin
 import torch
+import pypinyin
+import textwrap
+
+from functools import cached_property
from hangul_romanize import Transliter
from hangul_romanize.rule import academic
from num2words import num2words
from tokenizers import Tokenizer

from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words

+from spacy.lang.en import English
+from spacy.lang.zh import Chinese
+from spacy.lang.ja import Japanese
+from spacy.lang.ar import Arabic
+from spacy.lang.es import Spanish
+
+
+def get_spacy_lang(lang):
+    if lang == "zh":
+        return Chinese()
+    elif lang == "ja":
+        return Japanese()
+    elif lang == "ar":
+        return Arabic()
+    elif lang == "es":
+        return Spanish()
+    else:
+        # For most languages, English does the job
+        return English()
+
+
+def split_sentence(text, lang, text_split_length=250):
+    """Split the input text into chunks of at most text_split_length characters, on sentence boundaries where possible."""
+    text_splits = []
+    if text_split_length is not None and len(text) >= text_split_length:
+        text_splits.append("")
+        nlp = get_spacy_lang(lang)
+        nlp.add_pipe("sentencizer")
+        doc = nlp(text)
+        for sentence in doc.sents:
+            if len(text_splits[-1]) + len(str(sentence)) <= text_split_length:
+                # if the last chunk plus the current sentence still fits within text_split_length,
+                # append the current sentence to the last chunk
+                text_splits[-1] += " " + str(sentence)
+                text_splits[-1] = text_splits[-1].lstrip()
+            elif len(str(sentence)) > text_split_length:
+                # if the current sentence alone is longer than text_split_length, hard-wrap it
+                for line in textwrap.wrap(
+                    str(sentence),
+                    width=text_split_length,
+                    drop_whitespace=True,
+                    break_on_hyphens=False,
+                    tabsize=1,
+                ):
+                    text_splits.append(str(line))
+            else:
+                text_splits.append(str(sentence))
+
+        if len(text_splits) > 1:
+            if text_splits[0] == "":
+                del text_splits[0]
+    else:
+        text_splits = [text.lstrip()]
+
+    return text_splits

_whitespace_re = re.compile(r"\s+")

# List of (regular expression, replacement) pairs for abbreviations:
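Note (editor's addition, not part of the commit): below is a minimal, hypothetical sketch of how the new split_sentence helper can be exercised on its own. It assumes spaCy is installed; the blank pipelines returned by get_spacy_lang need no pretrained model, only the rule-based "sentencizer" component. The sample text and the 100-character limit are invented for illustration.

# Illustrative usage of split_sentence (hypothetical values).
from TTS.tts.layers.xtts.tokenizer import split_sentence

text = (
    "XTTS streams audio sentence by sentence. "
    "Very long inputs are therefore split before synthesis. "
    "Sentences that fit together are merged into a single chunk, and a sentence "
    "longer than the limit is hard-wrapped with textwrap."
)
chunks = split_sentence(text, lang="en", text_split_length=100)
for i, chunk in enumerate(chunks):
    # chunks stay at (or very close to) the 100-character budget
    print(i, len(chunk), repr(chunk))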
@@ -464,7 +519,7 @@ def _expand_number(m, lang="en"):


 def expand_numbers_multilingual(text, lang="en"):
-    if lang == "zh" or lang == "zh-cn":
+    if lang == "zh":
         text = zh_num2words()(text)
     else:
         if lang in ["en", "ru"]:
@@ -525,7 +580,7 @@ def japanese_cleaners(text, katsu):
     return text


-def korean_cleaners(text):
+def korean_transliterate(text):
     r = Transliter(academic)
     return r.translit(text)
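Note (editor's addition, not part of the commit): the rename reflects what the function actually does, it romanizes Hangul rather than "cleaning" text. A tiny sketch using the same hangul_romanize calls shown above; the sample string is arbitrary and the printed romanization is only indicative.

# Illustrative standalone use of the hangul_romanize transliteration.
from hangul_romanize import Transliter
from hangul_romanize.rule import academic

r = Transliter(academic)
print(r.translit("안녕하세요"))  # roughly "annyeonghaseyo"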

@@ -546,7 +601,7 @@ def __init__(self, vocab_file=None):
             "it": 213,
             "pt": 203,
             "pl": 224,
-            "zh-cn": 82,
+            "zh": 82,
             "ar": 166,
             "cs": 186,
             "ru": 182,
@@ -571,19 +626,20 @@ def check_input_length(self, txt, lang):
             )

     def preprocess_text(self, txt, lang):
-        if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh-cn", "zh-cn"}:
+        if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh-cn", "ko"}:
             txt = multilingual_cleaners(txt, lang)
-            if lang in {"zh", "zh-cn"}:
+            if lang == "zh":
                 txt = chinese_transliterate(txt)
+            if lang == "ko":
+                txt = korean_transliterate(txt)
         elif lang == "ja":
             txt = japanese_cleaners(txt, self.katsu)
-        elif lang == "ko":
-            txt = korean_cleaners(txt)
         else:
             raise NotImplementedError(f"Language '{lang}' is not supported.")
         return txt

     def encode(self, txt, lang):
+        lang = lang.split("-")[0]  # remove the region
         self.check_input_length(txt, lang)
         txt = self.preprocess_text(txt, lang)
         txt = f"[{lang}]{txt}"
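Note (editor's addition, not part of the commit): encode() now strips the region suffix before the length check and preprocessing, so callers can keep passing region-qualified codes such as "zh-cn". A tiny sketch of that normalization, reusing only the expression added above; the example codes are arbitrary.

# Illustrative: the region suffix is dropped before preprocessing.
for code in ["zh-cn", "zh", "en"]:
    base = code.split("-")[0]  # same normalization as in encode()
    print(code, "->", base)    # "zh-cn" -> "zh", "zh" -> "zh", "en" -> "en"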
