Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speech2text #39

Merged
merged 3 commits into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions besser/bot/nlp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,14 @@
default value: ``0.4``
"""

# Opt-in switch for the Speech2Text feature: NLPEngine.initialize() only creates the HFSpeech2Text component when this
# property holds a (truthy) model name, and HFSpeech2Text passes that name to the Hugging Face ``from_pretrained`` calls.
NLP_STT_HF_MODEL = Property(SECTION_NLP, 'nlp.speech2text.hf.model', str, None)
"""
The name of the Hugging Face model for the HFSpeech2Text bot component. If none is provided, the component will not be
activated.

name: ``nlp.speech2text.hf.model``

type: ``str``

default value: ``None``
"""
3 changes: 3 additions & 0 deletions besser/bot/nlp/intent_classifier/simple_intent_classifier.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import TYPE_CHECKING

import keras
Expand All @@ -18,6 +19,8 @@
if TYPE_CHECKING:
from besser.bot.nlp.nlp_engine import NLPEngine

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


class SimpleIntentClassifier(IntentClassifier):
"""A Simple Intent Classifier.
Expand Down
26 changes: 25 additions & 1 deletion besser/bot/nlp/nlp_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from besser.bot.nlp.ner.ner import NER
from besser.bot.nlp.ner.simple_ner import SimpleNER
from besser.bot.nlp.preprocessing.pipelines import lang_map
from besser.bot.nlp.speech2text.hf_speech2text import HFSpeech2Text
from besser.bot.nlp.speech2text.speech2text import Speech2Text

if TYPE_CHECKING:
from besser.bot.core.bot import Bot
from besser.bot.core.state import State
Expand All @@ -29,12 +32,14 @@ class NLPEngine:
_intent_classifiers (dict[State, IntentClassifier]): The collection of Intent Classifiers of the NLPEngine.
There is one for each bot state (only states with transitions triggered by intent matching)
_ner (NER or None): The NER (Named Entity Recognition) system of the NLPEngine
_speech2text (Speech2Text or None): The Speech-to-Text System of the NLPEngine
"""

def __init__(self, bot: 'Bot'):
    # The bot that owns this engine; its properties and states are read when the engine is initialized
    self._bot: 'Bot' = bot
    # Starts empty; initialize() adds one classifier for each state that has intents
    self._intent_classifiers: dict['State', IntentClassifier] = {}
    # Optional components: they stay None until initialize() creates them. The annotations are spelled as unions
    # because ``X or None`` is a boolean expression (it evaluates to ``X``), not a type hint
    self._ner: NER | None = None
    self._speech2text: Speech2Text | None = None

@property
def ner(self):
Expand All @@ -44,12 +49,18 @@ def ner(self):
def initialize(self) -> None:
    """Initialize the NLPEngine.

    It performs, in this order:

    1. Converts the configured language to its ISO 639-1 code when it was given as a language name.
    2. Creates an intent classifier for every bot state that has intents and no classifier yet.
    3. Creates the NER.
    4. Creates the Speech2Text component, only if a Hugging Face Speech2Text model was configured.
    """
    language = self.get_property(nlp.NLP_LANGUAGE)
    # Set the language to ISO 639-1 format (e.g., 'english' => 'en').
    # The conversion must run a single time: once the property holds the ISO code, searching for it among the
    # language names again would normally fail, so the first match is applied and the search stops.
    for iso_code, language_name in lang_map.items():
        if language_name == language:
            self._bot.set_property(nlp.NLP_LANGUAGE, iso_code)
            break
    for state in self._bot.states:
        if state not in self._intent_classifiers and state.intents:
            self._intent_classifiers[state] = SimpleIntentClassifier(self, state)
    # TODO: Only instantiate the NER if asked (maybe a bot does not need NER), via bot properties
    self._ner = SimpleNER(self, self._bot)
    # The Speech2Text component is opt-in: it loads a model, so it is skipped when no model name is set
    if self.get_property(nlp.NLP_STT_HF_MODEL):
        self._speech2text = HFSpeech2Text(self)

def get_property(self, prop: Property) -> Any:
"""Get a NLP property's value from the NLPEngine's bot.
Expand Down Expand Up @@ -120,3 +131,16 @@ def get_best_intent_prediction(
if best_intent_prediction.score < intent_threshold:
return None
return best_intent_prediction

def speech2text(self, speech: bytes) -> str:
    """Transcribe a voice audio into its corresponding text representation.

    Args:
        speech (bytes): the recorded voice to be transcribed

    Returns:
        str: the speech transcription

    Raises:
        AttributeError: if the NLPEngine has no Speech2Text component (no Speech2Text model was configured)
    """
    if self._speech2text is None:
        # Same exception type that calling a method on None raised before, now with an actionable message
        raise AttributeError(
            "The NLPEngine has no Speech2Text component. Set the 'nlp.speech2text.hf.model' property to enable it."
        )
    # Transcribe only once: the transcription that is logged is the one that is returned
    text = self._speech2text.speech2text(speech)
    logging.info(f"[Speech2Text] Transcribed audio message: '{text}'")
    return text
Empty file.
54 changes: 54 additions & 0 deletions besser/bot/nlp/speech2text/hf_speech2text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import io
from typing import TYPE_CHECKING

import librosa
from transformers import AutoProcessor, TFAutoModelForSpeechSeq2Seq, logging

from besser.bot import nlp
from besser.bot.nlp.speech2text.speech2text import Speech2Text

if TYPE_CHECKING:
from besser.bot.nlp.nlp_engine import NLPEngine

logging.set_verbosity_error()


class HFSpeech2Text(Speech2Text):
    """A Hugging Face Speech2Text.

    It loads a Speech2Text Hugging Face model to perform the Speech2Text task.

    .. warning::

        Only tested with ``openai/whisper-*`` models

    Args:
        nlp_engine (NLPEngine): the NLPEngine that handles the NLP processes of the bot

    Attributes:
        _model_name (str): the Hugging Face model name
        _processor (): the model processor, used to extract the audio features and to decode the predicted ids
        _model (): the Speech2Text model
        _sampling_rate (int): the sampling rate of audio data, it must coincide with the sampling rate used to train the
            model
        _forced_decoder_ids (list): the decoder ids that set the language and the task of the generation
    """

    def __init__(self, nlp_engine: 'NLPEngine'):
        super().__init__(nlp_engine)
        self._model_name: str = self._nlp_engine.get_property(nlp.NLP_STT_HF_MODEL)
        # Both the processor and the model are loaded from the same Hugging Face model name
        self._processor = AutoProcessor.from_pretrained(self._model_name)
        self._model = TFAutoModelForSpeechSeq2Seq.from_pretrained(self._model_name)
        self._sampling_rate: int = 16000
        # Computed once here and reused in every transcription: the bot language with the "transcribe" task
        self._forced_decoder_ids = self._processor.get_decoder_prompt_ids(
            language=self._nlp_engine.get_property(nlp.NLP_LANGUAGE), task="transcribe"
        )

    def speech2text(self, speech: bytes) -> str:
        """Transcribe a voice audio into its corresponding text representation.

        Args:
            speech (bytes): the recorded voice to be transcribed, as the content of an audio file

        Returns:
            str: the speech transcription
        """
        wav_stream = io.BytesIO(speech)
        # The audio is loaded at the expected sampling rate, so the rate returned by librosa is not needed
        audio, _ = librosa.load(wav_stream, sr=self._sampling_rate)
        input_features = self._processor(audio, sampling_rate=self._sampling_rate, return_tensors="tf").input_features
        predicted_ids = self._model.generate(input_features, forced_decoder_ids=self._forced_decoder_ids)
        transcriptions = self._processor.batch_decode(predicted_ids, skip_special_tokens=True)
        # A single audio is processed per call, so only the first decoded transcription is returned
        return transcriptions[0]
37 changes: 37 additions & 0 deletions besser/bot/nlp/speech2text/speech2text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from besser.bot.nlp.nlp_engine import NLPEngine


class Speech2Text(ABC):
    """Abstract base class of the Speech2Text components.

    A Speech2Text component (also called STT, Automatic Speech Recognition or ASR) converts spoken language, i.e.
    audio speech signals, into written text. This task is known as transcribing.

    In a chatbot, it lets the users send voice messages: they are transcribed to written text, so the bot can
    process them like regular text messages.

    Args:
        nlp_engine (NLPEngine): the NLPEngine that handles the NLP processes of the bot

    Attributes:
        _nlp_engine (NLPEngine): the NLPEngine that handles the NLP processes of the bot
    """

    def __init__(self, nlp_engine: 'NLPEngine'):
        self._nlp_engine: 'NLPEngine' = nlp_engine

    @abstractmethod
    def speech2text(self, speech: bytes) -> str:
        """Transcribe a voice audio into its corresponding text representation.

        Every concrete Speech2Text component must implement this method.

        Args:
            speech (bytes): the recorded voice to be transcribed

        Returns:
            str: the speech transcription
        """
3 changes: 3 additions & 0 deletions besser/bot/platforms/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ class PayloadAction(Enum):
USER_MESSAGE = 'user_message'
"""PayloadAction: Indicates that the payload's purpose is to send a user message."""

USER_VOICE = 'user_voice'
"""PayloadAction: Indicates that the payload's purpose is to send a user audio."""

RESET = 'reset'
"""PayloadAction: Use the :class:`~besser.bot.platforms.websocket.websocket_platform.WebSocketPlatform` on this
bot.
Expand Down
4 changes: 2 additions & 2 deletions besser/bot/platforms/telegram/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""
The Telegram Bot token. Used to connect to the Telegram Bot

type: string
type: ``str``

default value: None
default value: ``None``
"""
12 changes: 12 additions & 0 deletions besser/bot/platforms/telegram/telegram_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,18 @@ async def reset(update: Update, context: ContextTypes.DEFAULT_TYPE):
reset_handler = CommandHandler('reset', reset)
self._handlers.append(reset_handler)

# Handler for voice messages: the audio is transcribed and then delivered to the bot as a text message
async def voice(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Download the received voice message, transcribe it and pass the resulting text to the bot."""
    chat_session = self._bot.get_or_create_session(str(update.effective_chat.id), self)
    telegram_file = await context.bot.get_file(update.message.voice.file_id)
    # The NLPEngine expects bytes, while the download returns a bytearray
    audio = bytes(await telegram_file.download_as_bytearray())
    # NOTE(review): the transcription is a synchronous call inside this coroutine; presumably it keeps the event
    # loop busy while the model runs -- confirm that this is acceptable
    transcription = self._bot.nlp_engine.speech2text(audio)
    self._bot.receive_message(chat_session.id, transcription)

voice_handler = MessageHandler(filters.VOICE, voice)
self._handlers.append(voice_handler)

@property
def telegram_app(self):
"""telegram.ext._application.Application: The Telegram app."""
Expand Down
11 changes: 11 additions & 0 deletions besser/bot/platforms/websocket/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,17 @@
default value: ``8765``
"""

# Forwarded as the ``max_size`` argument when the WebSocket platform creates its server. Voice messages travel as
# base64 strings inside a single payload; presumably that is why the limit is configurable -- confirm.
WEBSOCKET_MAX_SIZE = Property(SECTION_WEBSOCKET, 'websocket.max_size', int, None)
"""
WebSocket's maximum size of incoming messages, in bytes. :obj:`None` disables the limit.

name: ``websocket.max_size``

type: ``int``

default value: ``None``
"""

STREAMLIT_HOST = Property(SECTION_WEBSOCKET, 'streamlit.host', str, 'localhost')
"""
The Streamlit UI host address. If you are using our default UI, you must define its address where you can access and
Expand Down
23 changes: 20 additions & 3 deletions besser/bot/platforms/websocket/streamlit_ui.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import json
import queue
import sys
Expand All @@ -7,6 +8,7 @@
import pandas as pd
import streamlit as st
import websocket
from audio_recorder_streamlit import audio_recorder
from streamlit.runtime import Runtime
from streamlit.runtime.app_session import AppSession
from streamlit.runtime.scriptrunner import add_script_run_ctx, get_script_run_ctx
Expand Down Expand Up @@ -114,16 +116,31 @@ def on_pong(ws, data):
ws = st.session_state['websocket']

with st.sidebar:
reset_button = st.button(label="Reset bot")
if reset_button:

if reset_button := st.button(label="Reset bot"):
st.session_state['history'] = []
st.session_state['queue'] = queue.Queue()
payload = Payload(action=PayloadAction.RESET)
ws.send(json.dumps(payload, cls=PayloadEncoder))

if voice_bytes := audio_recorder(text=None, pause_threshold=2):
if 'last_voice_message' not in st.session_state or st.session_state['last_voice_message'] != voice_bytes:
st.session_state['last_voice_message'] = voice_bytes
# Encode the audio bytes to a base64 string
st.session_state.history.append((voice_bytes, 1))
voice_base64 = base64.b64encode(voice_bytes).decode('utf-8')
payload = Payload(action=PayloadAction.USER_VOICE, message=voice_base64)
try:
ws.send(json.dumps(payload, cls=PayloadEncoder))
except Exception as e:
st.error('Your message could not be sent. The connection is already closed')

for message in st.session_state['history']:
with st.chat_message(user_type[message[1]]):
st.write(message[0])
if isinstance(message[0], bytes):
st.audio(message[0], format="audio/wav")
else:
st.write(message[0])

first_message = True
while not st.session_state['queue'].empty():
Expand Down
13 changes: 12 additions & 1 deletion besser/bot/platforms/websocket/websocket_platform.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import inspect
import json
import logging
Expand Down Expand Up @@ -71,6 +72,11 @@ def message_handler(conn: ServerConnection) -> None:
payload: Payload = Payload.decode(payload_str)
if payload.action == PayloadAction.USER_MESSAGE.value:
self._bot.receive_message(session.id, payload.message)
elif payload.action == PayloadAction.USER_VOICE.value:
# Decode the base64 string to get audio bytes
audio_bytes = base64.b64decode(payload.message.encode('utf-8'))
message = self._bot.nlp_engine.speech2text(audio_bytes)
self._bot.receive_message(session.id, message)
elif payload.action == PayloadAction.RESET.value:
self._bot.reset(session.id)
except ConnectionClosedError:
Expand All @@ -85,7 +91,12 @@ def message_handler(conn: ServerConnection) -> None:
def initialize(self) -> None:
    """Read the WebSocket properties of the bot and create the WebSocket server.

    The host, the port and the maximum size of the incoming messages are taken from the bot properties.
    """
    self._host = self._bot.get_property(websocket.WEBSOCKET_HOST)
    self._port = self._bot.get_property(websocket.WEBSOCKET_PORT)
    # The server is created a single time: a second call for the same host and port would replace the first
    # server object without closing it
    self._websocket_server = serve(
        handler=self._message_handler,
        host=self._host,
        port=self._port,
        max_size=self._bot.get_property(websocket.WEBSOCKET_MAX_SIZE)
    )

def start(self) -> None:
if self._use_ui:
Expand Down
2 changes: 2 additions & 0 deletions docs/source/api/nlp.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@ nlp
nlp/number
nlp/pipelines
nlp/text_preprocessing
nlp/hf_speech2text
nlp/speech2text
8 changes: 8 additions & 0 deletions docs/source/api/nlp/hf_speech2text.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
hf_speech2text
==============

.. automodule:: besser.bot.nlp.speech2text.hf_speech2text
:members:
:private-members:
:undoc-members:
:show-inheritance:
8 changes: 8 additions & 0 deletions docs/source/api/nlp/speech2text.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
speech2text
===========

.. automodule:: besser.bot.nlp.speech2text.speech2text
:members:
:private-members:
:undoc-members:
:show-inheritance:
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
audio-recorder-streamlit==0.0.8
dateparser==1.1.8
keras==2.14.0
librosa==0.10.1
nltk==3.8.1
numpy==1.26.1
pandas==2.1.1
Expand All @@ -8,5 +10,6 @@ snowballstemmer==2.2.0
streamlit==1.27.2
tensorflow==2.14.0
text2num==2.5.0
transformers==4.34.1
websocket-client==1.6.4
websockets==11.0.3