Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sentence splitting #3227

Merged
merged 7 commits into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 67 additions & 11 deletions TTS/tts/layers/xtts/tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,72 @@
import json
import os
import re
from functools import cached_property

import pypinyin
import torch
import pypinyin
import textwrap

from functools import cached_property
from hangul_romanize import Transliter
from hangul_romanize.rule import academic
from num2words import num2words
from tokenizers import Tokenizer

from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words

from spacy.lang.en import English
from spacy.lang.zh import Chinese
from spacy.lang.ja import Japanese
from spacy.lang.ar import Arabic
from spacy.lang.es import Spanish


def get_spacy_lang(lang):
if lang == "zh":
return Chinese()
elif lang == "ja":
return Japanese()
elif lang == "ar":
return Arabic()
elif lang == "es":
return Spanish()
else:
# For most languages, Enlish does the job
return English()

def split_sentence(text, lang, text_split_length=250):
"""Preprocess the input text"""
text_splits = []
if text_split_length is not None and len(text) >= text_split_length:
text_splits.append("")
nlp = get_spacy_lang(lang)
nlp.add_pipe("sentencizer")
doc = nlp(text)
for sentence in doc.sents:
if len(text_splits[-1]) + len(str(sentence)) <= text_split_length:
# if the last sentence + the current sentence is less than the text_split_length
# then add the current sentence to the last sentence
text_splits[-1] += " " + str(sentence)
text_splits[-1] = text_splits[-1].lstrip()
elif len(str(sentence)) > text_split_length:
# if the current sentence is greater than the text_split_length
for line in textwrap.wrap(
str(sentence),
width=text_split_length,
drop_whitespace=True,
break_on_hyphens=False,
tabsize=1,
):
text_splits.append(str(line))
else:
text_splits.append(str(sentence))

if len(text_splits) > 1:
if text_splits[0] == "":
del text_splits[0]
else:
text_splits = [text.lstrip()]

return text_splits

_whitespace_re = re.compile(r"\s+")

# List of (regular expression, replacement) pairs for abbreviations:
Expand Down Expand Up @@ -464,7 +519,7 @@ def _expand_number(m, lang="en"):


def expand_numbers_multilingual(text, lang="en"):
if lang == "zh" or lang == "zh-cn":
if lang == "zh":
text = zh_num2words()(text)
else:
if lang in ["en", "ru"]:
Expand Down Expand Up @@ -525,7 +580,7 @@ def japanese_cleaners(text, katsu):
return text


def korean_cleaners(text):
def korean_transliterate(text):
r = Transliter(academic)
return r.translit(text)

Expand All @@ -546,7 +601,7 @@ def __init__(self, vocab_file=None):
"it": 213,
"pt": 203,
"pl": 224,
"zh-cn": 82,
"zh": 82,
"ar": 166,
"cs": 186,
"ru": 182,
Expand All @@ -571,19 +626,20 @@ def check_input_length(self, txt, lang):
)

def preprocess_text(self, txt, lang):
if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh-cn", "zh-cn"}:
if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh-cn", "ko"}:
txt = multilingual_cleaners(txt, lang)
if lang in {"zh", "zh-cn"}:
if lang == "zh":
txt = chinese_transliterate(txt)
if lang == "ko":
txt = korean_transliterate(txt)
elif lang == "ja":
txt = japanese_cleaners(txt, self.katsu)
elif lang == "ko":
txt = korean_cleaners(txt)
else:
raise NotImplementedError(f"Language '{lang}' is not supported.")
return txt

def encode(self, txt, lang):
lang = lang.split("-")[0] # remove the region
self.check_input_length(txt, lang)
txt = self.preprocess_text(txt, lang)
txt = f"[{lang}]{txt}"
Expand Down
Loading
Loading