voiceassistant: keep punctuations when sending agent transcription #648

Merged (12 commits) on Sep 2, 2024
5 changes: 5 additions & 0 deletions .changeset/quick-owls-relax.md
@@ -0,0 +1,5 @@
+---
+"livekit-agents": patch
+---
+
+voiceassistant: keep punctuations when sending agent transcription
30 changes: 18 additions & 12 deletions livekit-agents/livekit/agents/tokenize/_basic_paragraph.py
@@ -1,12 +1,18 @@
-def split_paragraphs(text: str) -> list[str]:
-    sep = "\n\n"
-
-    paragraphs = text.split(sep)
-    new_paragraphs = []
-    for p in paragraphs:
-        p = p.strip()
-        if not p:
-            continue
-        new_paragraphs.append(p)
-
-    return new_paragraphs
+import re
+
+
+def split_paragraphs(text: str) -> list[tuple[str, int, int]]:
+    """
+    Split the text into paragraphs.
+    Returns a list of paragraphs with their start and end indices of the original text.
+    """
+    matches = re.finditer(r"\n{2,}", text)
+    paragraphs = []
+
+    for match in matches:
+        paragraph = match.group(0)
+        start_pos = match.start()
+        end_pos = match.end()
+        paragraphs.append((paragraph.strip(), start_pos, end_pos))
+
+    return paragraphs
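Note that re.finditer(r"\n{2,}", text) iterates over the blank-line separators themselves rather than the text between them. For readers who want the offset-tracking behaviour the docstring describes, a minimal sketch of a splitter that returns the text between separators with its offsets (hypothetical, not the code in this PR) could look like this:

import re


def split_paragraphs_with_offsets(text: str) -> list[tuple[str, int, int]]:
    # Collect the text between blank-line separators, together with the
    # start/end offsets of each (unstripped) chunk in the original string.
    paragraphs: list[tuple[str, int, int]] = []
    start = 0
    for match in re.finditer(r"\n{2,}", text):
        chunk = text[start : match.start()]
        if chunk.strip():
            paragraphs.append((chunk.strip(), start, match.start()))
        start = match.end()
    tail = text[start:]
    if tail.strip():
        paragraphs.append((tail.strip(), start, len(text)))
    return paragraphs


print(split_paragraphs_with_offsets("First paragraph.\n\nSecond one."))
# [('First paragraph.', 0, 16), ('Second one.', 18, 29)]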
43 changes: 27 additions & 16 deletions livekit-agents/livekit/agents/tokenize/_basic_sent.py
@@ -1,9 +1,13 @@
 import re
 
 
-# rule based segmentation from https://stackoverflow.com/a/31505798, works surprisingly well
-def split_sentences(text: str, min_sentence_len: int = 20) -> list[str]:
-    """the text can't contains substrings "<prd>" or "<stop>"""
+# rule based segmentation based on https://stackoverflow.com/a/31505798, works surprisingly well
+def split_sentences(
+    text: str, min_sentence_len: int = 20
+) -> list[tuple[str, int, int]]:
+    """
+    the text may not contain substrings "<prd>" or "<stop>"
+    """
     alphabets = r"([A-Za-z])"
     prefixes = r"(Mr|St|Mrs|Ms|Dr)[.]"
     suffixes = r"(Inc|Ltd|Jr|Sr|Co)"
@@ -14,12 +18,11 @@ def split_sentences(text: str, min_sentence_len: int = 20) -> list[str]:
     multiple_dots = r"\.{2,}"
 
     # fmt: off
-    text = " " + text + " "
     text = text.replace("\n"," ")
-    text = re.sub(prefixes,"\\1<prd>",text)
-    text = re.sub(websites,"<prd>\\1",text)
+    text = re.sub(prefixes,"\\1<prd>", text)
+    text = re.sub(websites,"<prd>\\1", text)
     text = re.sub(digits + "[.]" + digits,"\\1<prd>\\2",text)
-    #text = re.sub(multiple_dots, lambda match: "<prd>" * len(match.group(0)) + "<stop>", text)
+    # text = re.sub(multiple_dots, lambda match: "<prd>" * len(match.group(0)) + "<stop>", text)
     # TODO(theomonnom): need improvement for ""..." dots", check capital + next sentence should not be
     # small
     text = re.sub(multiple_dots, lambda match: "<prd>" * len(match.group(0)), text)
@@ -44,21 +47,29 @@ def split_sentences(text: str, min_sentence_len: int = 20) -> list[str]:
     text = text.replace("?","?<stop>")
     text = text.replace("!","!<stop>")
     text = text.replace("<prd>",".")
-    sentences = text.split("<stop>")
-    sentences = [s.strip() for s in sentences]
-    if sentences and not sentences[-1]:
-        sentences = sentences[:-1]
     # fmt: on
 
-    new_sentences = []
+    splitted_sentences = text.split("<stop>")
+    text = text.replace("<stop>", "")
+
+    sentences: list[tuple[str, int, int]] = []
+
     buff = ""
-    for sentence in sentences:
+    start_pos = 0
+    end_pos = 0
+    for match in splitted_sentences:
+        sentence = match.strip()
         if not sentence:
             continue
 
         buff += " " + sentence
+        end_pos += len(match)
         if len(buff) > min_sentence_len:
-            new_sentences.append(buff[1:])
+            sentences.append((buff[1:], start_pos, end_pos))
+            start_pos = end_pos
             buff = ""
 
     if buff:
-        new_sentences.append(buff[1:])
+        sentences.append((buff[1:], start_pos, len(text) - 1))
 
-    return new_sentences
+    return sentences
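For reference, the new return shape can be inspected directly. A quick sketch follows; _basic_sent is an internal module, and the offsets are computed against a lightly transformed working copy of the text (with sentences shorter than min_sentence_len merged into their neighbours), so treat them as approximate positions rather than exact indices into the caller's string:

from livekit.agents.tokenize import _basic_sent

# Each element is (sentence, start_pos, end_pos).
for sentence, start, end in _basic_sent.split_sentences(
    "Hello world. This sentence is long enough to be emitted on its own."
):
    print(repr(sentence), start, end)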
41 changes: 27 additions & 14 deletions livekit-agents/livekit/agents/tokenize/_basic_word.py
@@ -1,22 +1,35 @@
 import re
 
 
-def split_words(text: str, ignore_punctuation: bool = True) -> list[str]:
+def split_words(
+    text: str, ignore_punctuation: bool = True
+) -> list[tuple[str, int, int]]:
+    """
+    Split the text into words.
+    Returns a list of words with their start and end indices of the original text.
+    """
     # fmt: off
-    punctuations = [".", ",", "!", "?", ";", ":", "'", '"', "(", ")", "[", "]", "{", "}", "<", ">",
-                    "—"]
+    punctuations = ['!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '<', '=', '>',
+                    '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~', '±', '—', '‘', '’', '“', '”', '…']
+
     # fmt: on
 
-    if ignore_punctuation:
-        for p in punctuations:
-            # TODO(theomonnom): Ignore acronyms
-            text = text.replace(p, "")
+    matches = re.finditer(r"\S+", text)
+    words: list[tuple[str, int, int]] = []
+
+    for match in matches:
+        word = match.group(0)
+        start_pos = match.start()
+        end_pos = match.end()
+
+        if ignore_punctuation:
+            # TODO(theomonnom): acronyms passthrough
+            translation_table = str.maketrans("", "", "".join(punctuations))
+            word = word.translate(translation_table)
+
+        if not word:
+            continue
 
-    words = re.split("[ \n]+", text)
-    new_words = []
-    for word in words:
-        if not word:
-            continue  # ignore empty
-        new_words.append(word)
+        words.append((word, start_pos, end_pos))
 
-    return new_words
+    return words

Review thread on the new punctuations list:

Member: maybe overkill but for completeness' sake maybe we should take a slice of the unicode punctuations section? instead of checking for each of the characters we just take its code and check if it's 0x2000 <-> 0x206F, would be shorter too

Member Author: not sure how I can make this work with str.maketrans, not a big deal, let's keep stuff explicit :)
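On the review exchange above: str.maketrans can in fact consume a code-point range, since it accepts a dict keyed by Unicode ordinals. A sketch of that alternative follows (not what was merged; the PR keeps the explicit list):

import string

# U+2000-U+206F is the Unicode "General Punctuation" block (curly quotes,
# dashes, the ellipsis, and similar). ASCII punctuation is not part of that
# block, so it still has to be added explicitly via string.punctuation.
_delete_punct = str.maketrans(
    {cp: None for cp in [*range(0x2000, 0x2070), *map(ord, string.punctuation)]}
)


def strip_punctuation(word: str) -> str:
    return word.translate(_delete_punct)


print(strip_punctuation("“Hello,”"))  # Hello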
20 changes: 13 additions & 7 deletions livekit-agents/livekit/agents/tokenize/basic.py
@@ -45,9 +45,12 @@ def __init__(
         )
 
     def tokenize(self, text: str, *, language: str | None = None) -> list[str]:
-        return _basic_sent.split_sentences(
-            text, min_sentence_len=self._config.min_sentence_len
-        )
+        return [
+            tok[0]
+            for tok in _basic_sent.split_sentences(
+                text, min_sentence_len=self._config.min_sentence_len
+            )
+        ]
 
     def stream(self, *, language: str | None = None) -> tokenizer.SentenceStream:
         return token_stream.BufferedSentenceStream(
@@ -65,9 +68,12 @@ def __init__(self, *, ignore_punctuation: bool = True) -> None:
         self._ignore_punctuation = ignore_punctuation
 
     def tokenize(self, text: str, *, language: str | None = None) -> list[str]:
-        return _basic_word.split_words(
-            text, ignore_punctuation=self._ignore_punctuation
-        )
+        return [
+            tok[0]
+            for tok in _basic_word.split_words(
+                text, ignore_punctuation=self._ignore_punctuation
+            )
+        ]
 
     def stream(self, *, language: str | None = None) -> tokenizer.WordStream:
         return token_stream.BufferedWordStream(
@@ -84,4 +90,4 @@ def hyphenate_word(word: str) -> list[str]:
 
 
 def tokenize_paragraphs(text: str) -> list[str]:
-    return _basic_paragraph.split_paragraphs(text)
+    return [tok[0] for tok in _basic_paragraph.split_paragraphs(text)]
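The list-comprehension wrappers keep the public tokenize() return type as list[str]; the change the PR is actually after shows up in the ignore_punctuation flag. A small sketch (assuming tokenize.basic is reachable the same way the voice-assistant code uses it):

from livekit.agents import tokenize

keep = tokenize.basic.WordTokenizer(ignore_punctuation=False)
strip = tokenize.basic.WordTokenizer(ignore_punctuation=True)

print(keep.tokenize("Hello, world!"))   # expected: ['Hello,', 'world!']
print(strip.tokenize("Hello, world!"))  # expected: ['Hello', 'world']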
35 changes: 26 additions & 9 deletions livekit-agents/livekit/agents/tokenize/token_stream.py
@@ -1,16 +1,21 @@
 from __future__ import annotations
 
-from typing import Callable
+import typing
+from typing import Callable, Union
 
 from ..utils import aio, shortuuid
 from .tokenizer import SentenceStream, TokenData, WordStream
 
+# Tokenizers can either provide us with a list of tokens or a list of tokens along with their start and end indices.
+# If the start and end indices are not available, we attempt to locate the token within the text using str.find.
+TokenizeCallable = Callable[[str], Union[list[str], list[tuple[str, int, int]]]]
+
 
 class BufferedTokenStream:
     def __init__(
         self,
         *,
-        tokenize_fnc: Callable[[str], list[str]],
+        tokenize_fnc: TokenizeCallable,
         min_token_len: int,
         min_ctx_len: int,
     ) -> None:
@@ -23,6 +28,7 @@ def __init__(
         self._buf_tokens: list[str] = []  # <= min_token_len
         self._buf = ""
 
+    @typing.no_type_check
     def push_text(self, text: str) -> None:
         self._check_not_closed()
         self._buf += text
@@ -39,26 +45,37 @@ def push_text(self, text: str) -> None:
                 buf += " "
 
             tok = tokens.pop(0)
-            buf += tok
+            tok_text = tok
+            if isinstance(tok, tuple):
+                tok_text = tok[0]
+
+            buf += tok_text
             buf_toks.append(tok)
             if len(buf) >= self._min_token_len:
                 self._event_ch.send_nowait(
                     TokenData(token=buf, segment_id=self._current_segment_id)
                 )
 
-                for i, tok in enumerate(buf_toks):
-                    tok_i = self._buf.find(tok)
-                    self._buf = self._buf[tok_i + len(tok) :].lstrip()
+                if isinstance(tok, tuple):
+                    self._buf = self._buf[tok[2] :]
+                else:
+                    for i, tok in enumerate(buf_toks):
+                        tok_i = max(self._buf.find(tok), 0)
+                        self._buf = self._buf[tok_i + len(tok) :].lstrip()
 
                 buf_toks = []
                 buf = ""
 
+    @typing.no_type_check
     def flush(self) -> None:
         self._check_not_closed()
         if self._buf:
             tokens = self._tokenize_fnc(self._buf)
             if tokens:
-                buf = " ".join(tokens)
+                if isinstance(tokens[0], tuple):
+                    buf = " ".join([tok[0] for tok in tokens])
+                else:
+                    buf = " ".join(tokens)
             else:
                 buf = self._buf
 
@@ -92,7 +109,7 @@ class BufferedSentenceStream(BufferedTokenStream, SentenceStream):
     def __init__(
         self,
         *,
-        tokenizer: Callable[[str], list[str]],
+        tokenizer: TokenizeCallable,
        min_token_len: int,
         min_ctx_len: int,
     ) -> None:
@@ -107,7 +124,7 @@ class BufferedWordStream(BufferedTokenStream, WordStream):
    def __init__(
        self,
        *,
-        tokenizer: Callable[[str], list[str]],
+        tokenizer: TokenizeCallable,
        min_token_len: int,
        min_ctx_len: int,
    ) -> None:
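Why the isinstance(tok, tuple) branch matters: once a tokenizer strips punctuation (or otherwise rewrites a token), str.find can no longer reliably locate the token in the raw buffer, whereas an explicit end index always can. A standalone sketch of the two buffer-advance strategies, mirroring the branch above (illustrative only, not the library code):

def advance_buffer(buf: str, tok) -> str:
    # (token, start, end) tuples advance by the reported end index;
    # plain string tokens fall back to locating the text with str.find.
    if isinstance(tok, tuple):
        return buf[tok[2]:]
    tok_i = max(buf.find(tok), 0)
    return buf[tok_i + len(tok):].lstrip()


print(repr(advance_buffer("Hello, world!", ("Hello,", 0, 6))))  # ' world!'
print(repr(advance_buffer("Hello, world!", "Hello")))           # ', world!'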
@@ -93,7 +93,9 @@ class AssistantTranscriptionOptions:
     sentence_tokenizer: tokenize.SentenceTokenizer = tokenize.basic.SentenceTokenizer()
     """The tokenizer used to split the speech into sentences.
     This is used to decide when to mark a transcript as final for the agent transcription."""
-    word_tokenizer: tokenize.WordTokenizer = tokenize.basic.WordTokenizer()
+    word_tokenizer: tokenize.WordTokenizer = tokenize.basic.WordTokenizer(
+        ignore_punctuation=False
+    )
     """The tokenizer used to split the speech into words.
     This is used to simulate the "interim results" of the agent transcription."""
     hyphenate_word: Callable[[str], list[str]] = tokenize.basic.hyphenate_word
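With this new default, interim agent transcriptions keep their punctuation. If the old behaviour is preferred, the field can be overridden. A sketch follows; it assumes AssistantTranscriptionOptions is exported from livekit.agents.voice_assistant and behaves as a dataclass whose fields are constructor keywords, and how the options object is passed to the assistant is not shown in this diff:

from livekit.agents import tokenize
from livekit.agents.voice_assistant import AssistantTranscriptionOptions

# Revert to punctuation-less interim transcriptions.
opts = AssistantTranscriptionOptions(
    word_tokenizer=tokenize.basic.WordTokenizer(ignore_punctuation=True),
)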