Skip to content

Commit

Permalink
fix: correctly strip/restore initial punctuation (#3336)
Browse files Browse the repository at this point in the history
* refactor(punctuation): remove orphan code for handling lone punctuation

The case of lone punctuation is already handled at the top of restore(). The
removed if statement could therefore never be reached and, had it been, would
have raised an AttributeError because the _punc_index named tuple doesn't have
the attribute `mark`.

* refactor(punctuation): remove unused argument

* fix(punctuation): correctly handle initial punctuation

Stripping and restoring initial punctuation didn't work correctly because the
string-splitting caused an additional empty string to be inserted in the text
list (because `".A".split(".")` => `["", "A"]`). Now, an initial empty string is
skipped and relevant test cases are added.

Fixes #3333
  • Loading branch information
eginhard authored Nov 30, 2023
1 parent 9328338 commit 39321d0
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
23 changes: 11 additions & 12 deletions TTS/tts/utils/text/punctuation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ class PuncPosition(Enum):
BEGIN = 0
END = 1
MIDDLE = 2
ALONE = 3


class Punctuation:
Expand Down Expand Up @@ -92,7 +91,7 @@ def _strip_to_restore(self, text):
return [text], []
# the text is only punctuations
if len(matches) == 1 and matches[0].group() == text:
return [], [_PUNC_IDX(text, PuncPosition.ALONE)]
return [], [_PUNC_IDX(text, PuncPosition.BEGIN)]
# build a punctuation map to be used later to restore punctuations
puncs = []
for match in matches:
Expand All @@ -107,11 +106,14 @@ def _strip_to_restore(self, text):
for idx, punc in enumerate(puncs):
split = text.split(punc.punc)
prefix, suffix = split[0], punc.punc.join(split[1:])
text = suffix
if prefix == "":
# We don't want to insert an empty string in case of initial punctuation
continue
splitted_text.append(prefix)
# if the text does not end with a punctuation, add it to the last item
if idx == len(puncs) - 1 and len(suffix) > 0:
splitted_text.append(suffix)
text = suffix
return splitted_text, puncs

@classmethod
Expand All @@ -127,10 +129,10 @@ def restore(cls, text, puncs):
['This is', 'example'], ['.', '!'] -> "This is. example!"
"""
return cls._restore(text, puncs, 0)
return cls._restore(text, puncs)

@classmethod
def _restore(cls, text, puncs, num): # pylint: disable=too-many-return-statements
def _restore(cls, text, puncs): # pylint: disable=too-many-return-statements
"""Auxiliary method for Punctuation.restore()"""
if not puncs:
return text
Expand All @@ -142,21 +144,18 @@ def _restore(cls, text, puncs, num): # pylint: disable=too-many-return-statemen
current = puncs[0]

if current.position == PuncPosition.BEGIN:
return cls._restore([current.punc + text[0]] + text[1:], puncs[1:], num)
return cls._restore([current.punc + text[0]] + text[1:], puncs[1:])

if current.position == PuncPosition.END:
return [text[0] + current.punc] + cls._restore(text[1:], puncs[1:], num + 1)

if current.position == PuncPosition.ALONE:
return [current.mark] + cls._restore(text, puncs[1:], num + 1)
return [text[0] + current.punc] + cls._restore(text[1:], puncs[1:])

# POSITION == MIDDLE
if len(text) == 1: # pragma: nocover
# a corner case where the final part of an intermediate
# mark (I) has not been phonemized
return cls._restore([text[0] + current.punc], puncs[1:], num)
return cls._restore([text[0] + current.punc], puncs[1:])

return cls._restore([text[0] + current.punc + text[1]] + text[2:], puncs[1:], num)
return cls._restore([text[0] + current.punc + text[1]] + text[2:], puncs[1:])


# if __name__ == "__main__":
Expand Down
5 changes: 5 additions & 0 deletions tests/text_tests/test_punctuation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ def setUp(self):
("This, is my text ... to be striped !! from text", "This is my text to be striped from text"),
("This, is my text ... to be striped from text?", "This is my text to be striped from text"),
("This, is my text to be striped from text", "This is my text to be striped from text"),
(".", ""),
(" . ", ""),
("!!! Attention !!!", "Attention"),
("!!! Attention !!! This is just a ... test.", "Attention This is just a test"),
("!!! Attention! This is just a ... test.", "Attention This is just a test"),
]

def test_get_set_puncs(self):
Expand Down

0 comments on commit 39321d0

Please sign in to comment.