diff --git a/common/speech/lasr_speech_recognition_msgs/CMakeLists.txt b/common/speech/lasr_speech_recognition_msgs/CMakeLists.txt
index ae6bac4d8..10e0472f1 100644
--- a/common/speech/lasr_speech_recognition_msgs/CMakeLists.txt
+++ b/common/speech/lasr_speech_recognition_msgs/CMakeLists.txt
@@ -7,7 +7,11 @@ project(lasr_speech_recognition_msgs)
 ## Find catkin macros and libraries
 ## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz)
 ## is used, also find other catkin packages
-find_package(catkin REQUIRED COMPONENTS message_generation)
+find_package(catkin REQUIRED COMPONENTS message_generation genmsg actionlib_msgs actionlib std_msgs)
+add_action_files(
+  DIRECTORY action
+  FILES TranscribeSpeech.action
+)

 ## System dependencies are found with CMake's conventions
 # find_package(Boost REQUIRED COMPONENTS system)
@@ -63,8 +67,9 @@ add_service_files(

 ## Generate added messages and services with any dependencies listed here
 generate_messages(
-#   DEPENDENCIES
-#   std_msgs  # Or other packages containing msgs
+  DEPENDENCIES
+  std_msgs  # Or other packages containing msgs
+  actionlib_msgs
 )

 ################################################
diff --git a/common/speech/lasr_speech_recognition_msgs/action/TranscribeSpeech.action b/common/speech/lasr_speech_recognition_msgs/action/TranscribeSpeech.action
new file mode 100644
index 000000000..486b0cb19
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_msgs/action/TranscribeSpeech.action
@@ -0,0 +1,6 @@
+---
+#result definition
+string sequence
+---
+#feedback
+string sequence
\ No newline at end of file
diff --git a/common/speech/lasr_speech_recognition_msgs/package.xml b/common/speech/lasr_speech_recognition_msgs/package.xml
index e319ec5db..6f00b03f4 100644
--- a/common/speech/lasr_speech_recognition_msgs/package.xml
+++ b/common/speech/lasr_speech_recognition_msgs/package.xml
@@ -51,6 +51,8 @@
   <buildtool_depend>catkin</buildtool_depend>
   <build_depend>message_generation</build_depend>
   <exec_depend>message_runtime</exec_depend>
+  <build_depend>actionlib_msgs</build_depend>
+  <exec_depend>actionlib_msgs</exec_depend>
diff --git a/common/speech/lasr_speech_recognition_whisper/CMakeLists.txt b/common/speech/lasr_speech_recognition_whisper/CMakeLists.txt
index d2cce16f7..a11465954 100644
--- a/common/speech/lasr_speech_recognition_whisper/CMakeLists.txt
+++ b/common/speech/lasr_speech_recognition_whisper/CMakeLists.txt
@@ -7,7 +7,11 @@ project(lasr_speech_recognition_whisper)
 ## Find catkin macros and libraries
 ## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz)
 ## is used, also find other catkin packages
-find_package(catkin REQUIRED catkin_virtualenv)
+find_package(catkin REQUIRED catkin_virtualenv genmsg actionlib_msgs actionlib std_msgs)
+# add_action_files(
+#   DIRECTORY action
+#   FILES TranscribeSpeech.action
+# )

 ## System dependencies are found with CMake's conventions
 # find_package(Boost REQUIRED COMPONENTS system)
@@ -70,7 +74,8 @@ catkin_generate_virtualenv(
 ## Generate added messages and services with any dependencies listed here
 # generate_messages(
 #   DEPENDENCIES
-#   std_msgs  # Or other packages containing msgs
+#   std_msgs
+#   actionlib_msgs  # Or other packages containing msgs
 # )

 ################################################
@@ -162,8 +167,11 @@ include_directories(

 catkin_install_python(PROGRAMS
   nodes/simple_transcribe_microphone
   nodes/transcribe_microphone
+  nodes/transcribe_microphone_server
   scripts/list_microphones.py
   scripts/test_microphones.py
+  scripts/repeat_after_me.py
+  scripts/test_speech_server.py
   DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
 )
diff --git a/common/speech/lasr_speech_recognition_whisper/doc/USAGE.md b/common/speech/lasr_speech_recognition_whisper/doc/USAGE.md
index 6c0649478..2bb966c13 100644
--- a/common/speech/lasr_speech_recognition_whisper/doc/USAGE.md
+++ b/common/speech/lasr_speech_recognition_whisper/doc/USAGE.md
@@ -26,3 +26,31 @@ Stop listening whenever:
 ```bash
 rosservice call /whisper/stop_listening "{}"
 ```
+
+Run an actionlib server to transcribe speech from the microphone:
+
+```bash
+rosrun lasr_speech_recognition_whisper transcribe_microphone_server
+```
+
+The result of each request is a `string` containing the transcribed text.
+
+Several command line configuration options exist, which can be viewed with:
+
+```bash
+rosrun lasr_speech_recognition_whisper transcribe_microphone_server --help
+```
+
+Get TIAGo to repeat the transcribed speech back using TTS; it will begin repeating after hearing "tiago, repeat ..." and stop once it hears "tiago, stop ...":
+
+```bash
+rosrun lasr_speech_recognition_whisper repeat_after_me.py
+```
+
+To continuously listen and print the transcribed speech in the terminal (by repeatedly sending requests to the actionlib server), run the following script:
+
+```bash
+rosrun lasr_speech_recognition_whisper test_speech_server.py
+```
+
+
diff --git a/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server b/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server
new file mode 100644
index 000000000..95f3b1e36
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server
@@ -0,0 +1,351 @@
+#!/usr/bin/env python3
+
+import os
+import argparse
+from typing import Optional
+from dataclasses import dataclass
+from pathlib import Path
+from timeit import default_timer as timer
+
+import rospy
+import numpy as np
+import torch
+import actionlib
+import speech_recognition as sr  # type: ignore
+import lasr_speech_recognition_msgs.msg  # type: ignore
+from lasr_speech_recognition_whisper import load_model  # type: ignore
+
+# Error handler to remove ALSA error messages taken from:
+# https://stackoverflow.com/questions/7088672/pyaudio-working-but-spits-out-error-messages-each-time/17673011#17673011
+
+from ctypes import *
+from contextlib import contextmanager
+
+ERROR_HANDLER_FUNC = CFUNCTYPE(None, c_char_p, c_int, c_char_p, c_int, c_char_p)
+
+
+def py_error_handler(filename, line, function, err, fmt):
+    pass
+
+
+c_error_handler = ERROR_HANDLER_FUNC(py_error_handler)
+
+
+@contextmanager
+def noalsaerr():
+    asound = cdll.LoadLibrary("libasound.so")
+    asound.snd_lib_error_set_handler(c_error_handler)
+    yield
+    asound.snd_lib_error_set_handler(None)
+
+
+@dataclass
+class speech_model_params:
+    """Class for storing speech recognition model parameters.
+
+    Args:
+        model_name (str, optional): Name of the speech recognition model. Defaults to "medium.en".
+            Must be a valid Whisper model name.
+        device (str, optional): Device to run the model on. Defaults to "cuda" if available, otherwise "cpu".
+        start_timeout (float): Max number of seconds of silence when starting listening before stopping. Defaults to 5.0.
+        end_timeout (Optional[float]): Max number of seconds of silence after starting listening before stopping. Defaults to None.
+        sample_rate (int): Sample rate of the microphone. Defaults to 16000 Hz.
+        mic_device (Optional[str]): Microphone device index or name. Defaults to None.
+        timer_duration (Optional[int]): Duration of the timer for adjusting the microphone for ambient noise. Defaults to 20 seconds.
+        warmup (bool): Whether to warm up the model by running inference on a test file. Defaults to True.
+    """
+
+    model_name: str = "medium.en"
+    device: str = "cuda" if torch.cuda.is_available() else "cpu"
+    start_timeout: float = 5.0
+    end_timeout: Optional[float] = None
+    sample_rate: int = 16000
+    mic_device: Optional[str] = None
+    timer_duration: Optional[int] = 20
+    warmup: bool = True
+
+
+class TranscribeSpeechAction(object):
+    # create messages that are used to publish feedback/result
+    _feedback = lasr_speech_recognition_msgs.msg.TranscribeSpeechFeedback()
+    _result = lasr_speech_recognition_msgs.msg.TranscribeSpeechResult()
+
+    def __init__(
+        self,
+        action_name: str,
+        model_params: speech_model_params,
+    ) -> None:
+        """Starts an action server for transcribing speech.
+
+        Args:
+            action_name (str): Name of the action server.
+        """
+
+        self._action_name = action_name
+        self._model_params = model_params
+
+        with noalsaerr():
+            self._model = load_model(
+                self._model_params.model_name,
+                self._model_params.device,
+                self._model_params.warmup,
+            )
+            # Configure the speech recogniser object and adjust for ambient noise
+            self.recogniser = self._configure_recogniser()
+        # Set up the action server and register the execution callback
+        self._action_server = actionlib.SimpleActionServer(
+            self._action_name,
+            lasr_speech_recognition_msgs.msg.TranscribeSpeechAction,
+            execute_cb=self.execute_cb,
+            auto_start=False,
+        )
+        self._action_server.register_preempt_callback(self.prempt_cb)
+        # Set up the timer for adjusting the microphone for ambient noise every x seconds
+        self._timer_duration = self._model_params.timer_duration
+        self._timer = rospy.Timer(
+            rospy.Duration(self._timer_duration), self._timer_cb
+        )
+        self._listening = False
+
+        self._action_server.start()
+
+    def _timer_cb(self, _) -> None:
+        """Adjusts the microphone for ambient noise, unless the action server is listening."""
+        if self._listening:
+            return
+        rospy.loginfo("Adjusting microphone for ambient noise...")
+        with noalsaerr():
+            with self._configure_microphone() as source:
+                self.recogniser.adjust_for_ambient_noise(source)
+
+    def _reset_timer(self) -> None:
+        """Resets the timer for adjusting the microphone for ambient noise."""
+        self._timer.shutdown()
+        self._timer = rospy.Timer(rospy.Duration(self._timer_duration), self._timer_cb)
+
+    def _configure_microphone(self) -> sr.Microphone:
+        """Configures the microphone for listening to speech based on the
+        microphone device index or name.
+
+        Returns: microphone object
+        """
+
+        if self._model_params.mic_device is None:
+            # If no microphone device is specified, use the system default microphone
+            return sr.Microphone(sample_rate=self._model_params.sample_rate)
+        elif self._model_params.mic_device.isdigit():
+            return sr.Microphone(
+                device_index=int(self._model_params.mic_device),
+                sample_rate=self._model_params.sample_rate,
+            )
+        else:
+            microphones = enumerate(sr.Microphone.list_microphone_names())
+            for index, name in microphones:
+                if self._model_params.mic_device in name:
+                    return sr.Microphone(
+                        device_index=index,
+                        sample_rate=self._model_params.sample_rate,
+                    )
+            raise ValueError(
+                f"Could not find microphone with name: {self._model_params.mic_device}"
+            )
+
+    def _configure_recogniser(self, ambient_adj: bool = True) -> sr.Recognizer:
+        """Configures the speech recogniser object.
+
+        Args:
+            ambient_adj (bool, optional): Whether to adjust for ambient noise. Defaults to True.
+
+        Returns:
+            sr.Recognizer: speech recogniser object.
+        """
+        self._listening = True
+        recogniser = sr.Recognizer()
+        if ambient_adj:
+            with self._configure_microphone() as source:
+                recogniser.adjust_for_ambient_noise(source)
+        self._listening = False
+        return recogniser
+
+    def prempt_cb(self) -> None:
+        """Callback for preempting the action server.
+
+        Sets server to preempted state.
+        """
+        preempted_str = f"{self._action_name} has been preempted"
+        rospy.loginfo(preempted_str)
+        self._result.sequence = preempted_str
+        self._action_server.set_preempted(result=self._result, text=preempted_str)
+
+    def execute_cb(self, goal) -> None:
+        """Callback for executing the action server.
+
+        Checks for preemption before listening and before and after transcribing, returning
+        if preemption is requested.
+
+        Args:
+            goal: UNUSED - actionlib requires a goal argument in the execute callback, but
+                this action server does not use a goal.
+        """
+        rospy.loginfo("Request received")
+        if self._action_server.is_preempt_requested():
+            return
+        # Since we are about to listen, reset the timer for adjusting the microphone for ambient noise,
+        # as this assumes self._timer_duration seconds of silence before adjusting
+        self._reset_timer()
+        with noalsaerr():
+            with self._configure_microphone() as src:
+                self._listening = True
+                wav_data = self.recogniser.listen(
+                    src,
+                    timeout=self._model_params.start_timeout,
+                    phrase_time_limit=self._model_params.end_timeout,
+                ).get_wav_data()
+        # Magic number 32768.0 is the maximum value of a 16-bit signed integer
+        float_data = (
+            np.frombuffer(wav_data, dtype=np.int16).astype(np.float32, order="C")
+            / 32768.0
+        )
+
+        if self._action_server.is_preempt_requested():
+            self._listening = False
+            return
+
+        rospy.loginfo("Transcribing phrase with Whisper...")
+        transcription_start_time = timer()
+        # Cast to fp16 if using GPU
+        phrase = self._model.transcribe(
+            float_data,
+            fp16=self._model_params.device == "cuda",
+        )["text"]
+        transcription_end_time = timer()
+        rospy.loginfo("Transcription finished!")
+        rospy.loginfo(
+            f"Time taken: {transcription_end_time - transcription_start_time:.2f}s"
+        )
+
+        if self._action_server.is_preempt_requested():
+            self._listening = False
+            return
+
+        self._result.sequence = phrase
+        rospy.loginfo(f"Transcribed phrase: {phrase}")
+        rospy.loginfo(f"{self._action_name} has succeeded")
+        self._action_server.set_succeeded(self._result)
+
+        # Have this at the very end to not disrupt the action server
+        self._listening = False
+
+
+def parse_args() -> dict:
+    """Parses the command line arguments into a name: value dictionary.
+
+    Returns:
+        dict: Dictionary of name: value pairs of command line arguments.
+    """
+    parser = argparse.ArgumentParser(
+        description="Starts an action server for transcribing speech."
+    )
+
+    parser.add_argument(
+        "--action-name",
+        type=str,
+        default="transcribe_speech",
+        help="Name of the action server.",
+    )
+    parser.add_argument(
+        "--model-name",
+        type=str,
+        default="medium.en",
+        help="Name of the speech recognition model.",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda" if torch.cuda.is_available() else "cpu",
+        help="Device to run the model on.",
+    )
+    parser.add_argument(
+        "--start_timeout",
+        type=float,
+        default=5.0,
+        help="Timeout for listening for the start of a phrase.",
+    )
+    parser.add_argument(
+        "--end_timeout",
+        type=float,
+        default=None,
+        help="Timeout for listening for the end of a phrase.",
+    )
+    parser.add_argument(
+        "--sample_rate",
+        type=int,
+        default=16000,
+        help="Sample rate of the microphone.",
+    )
+    parser.add_argument(
+        "--mic_device",
+        type=str,
+        default=None,
+        help="Microphone device index or name.",
+    )
+    parser.add_argument(
+        "--timer_duration",
+        type=int,
+        default=20,
+        help="Number of seconds of silence before the ambient noise adjustment is called.",
+    )
+    parser.add_argument(
+        "--no_warmup",
+        action="store_true",
+        help="Disable warming up the model by running inference on a test file.",
+    )
+
+    return vars(parser.parse_args())
+
+
+def configure_model_params(config: dict) -> speech_model_params:
+    """Configures the speech model parameters based on the provided
+    command line parameters.
+
+    Args:
+        config (dict): Command line parameters parsed in dictionary form.
+
+    Returns:
+        speech_model_params: dataclass containing the speech model parameters.
+    """
+    model_params = speech_model_params()
+    if config["model_name"]:
+        model_params.model_name = config["model_name"]
+    if config["device"]:
+        model_params.device = config["device"]
+    if config["start_timeout"]:
+        model_params.start_timeout = config["start_timeout"]
+    if config["end_timeout"]:
+        model_params.end_timeout = config["end_timeout"]
+    if config["sample_rate"]:
+        model_params.sample_rate = config["sample_rate"]
+    if config["mic_device"]:
+        model_params.mic_device = config["mic_device"]
+    if config["timer_duration"]:
+        model_params.timer_duration = config["timer_duration"]
+    if config["no_warmup"]:
+        model_params.warmup = False
+
+    return model_params
+
+
+def configure_whisper_cache() -> None:
+    """Configures the whisper cache directory."""
+    whisper_cache = os.path.join(str(Path.home()), ".cache", "whisper")
+    os.makedirs(whisper_cache, exist_ok=True)
+    # Environment variable required to run whisper locally
+    os.environ["TIKTOKEN_CACHE_DIR"] = whisper_cache
+
+
+if __name__ == "__main__":
+    configure_whisper_cache()
+    config = parse_args()
+    rospy.init_node(config["action_name"])
+    server = TranscribeSpeechAction(rospy.get_name(), configure_model_params(config))
+    rospy.spin()
diff --git a/common/speech/lasr_speech_recognition_whisper/package.xml b/common/speech/lasr_speech_recognition_whisper/package.xml
index 39935c089..4c6f49965 100644
--- a/common/speech/lasr_speech_recognition_whisper/package.xml
+++ b/common/speech/lasr_speech_recognition_whisper/package.xml
@@ -51,6 +51,10 @@
   <buildtool_depend>catkin</buildtool_depend>
   <build_depend>catkin_virtualenv</build_depend>
   <build_depend>lasr_speech_recognition_msgs</build_depend>
+  <build_depend>actionlib</build_depend>
+  <build_depend>actionlib_msgs</build_depend>
+  <exec_depend>actionlib</exec_depend>
+  <exec_depend>actionlib_msgs</exec_depend>
diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/repeat_after_me.py b/common/speech/lasr_speech_recognition_whisper/scripts/repeat_after_me.py
new file mode 100644
index 000000000..2e6b20622
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/scripts/repeat_after_me.py
@@ -0,0 +1,58 @@
+#!/usr/bin/env python3
+import rospy
+import actionlib
+from lasr_voice import Voice  # type: ignore
+from lasr_speech_recognition_msgs.srv import TranscribeAudio, TranscribeAudioResponse  # type: ignore
+from lasr_speech_recognition_msgs.msg import (  # type: ignore
+    TranscribeSpeechAction,
+    TranscribeSpeechGoal,
+)
+
+# import actionlib
+rospy.init_node("repeat")
+
+USE_ACTIONLIB = True
+
+voice = Voice()
+
+
+if USE_ACTIONLIB:
+    client = actionlib.SimpleActionClient("transcribe_speech", TranscribeSpeechAction)
+    client.wait_for_server()
+    repeating = False
+    rospy.loginfo("Done waiting")
+    while not rospy.is_shutdown():
+        goal = TranscribeSpeechGoal()
+        client.send_goal(goal)
+        client.wait_for_result()
+        result = client.get_result()
+        text = result.sequence
+        print(text)
+        if "tiago" in text.lower().strip():
+            if "repeat" in text.lower().strip():
+                repeating = True
+                voice.sync_tts("Okay, I'll start repeating now.")
+                continue
+            elif "stop" in text.lower().strip():
+                repeating = False
+                voice.sync_tts("Okay, I'll stop repeating now.")
+                break
+        if repeating:
+            voice.sync_tts(f"I heard {text}")
+else:
+    transcribe = rospy.ServiceProxy("/whisper/transcribe_audio", TranscribeAudio)
+    repeating = False
+    while not rospy.is_shutdown():
+        text = transcribe().phrase
+        print(text)
+        if "tiago" in text.lower().strip():
+            if "repeat" in text.lower().strip():
+                repeating = True
+                voice.sync_tts("Okay, I'll start repeating now.")
+                continue
+            elif "stop" in text.lower().strip():
+                repeating = False
+                voice.sync_tts("Okay, I'll stop repeating now.")
+                break
+        if repeating:
+            voice.sync_tts(f"I heard {text}")
diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py b/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py
new file mode 100644
index 000000000..fef16eb0c
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py
@@ -0,0 +1,21 @@
+#!/usr/bin/env python3
+import rospy
+import actionlib
+from lasr_speech_recognition_msgs.srv import TranscribeAudio, TranscribeAudioResponse  # type: ignore
+from lasr_speech_recognition_msgs.msg import (  # type: ignore
+    TranscribeSpeechAction,
+    TranscribeSpeechGoal,
+)
+
+
+rospy.init_node("test_speech_server")
+client = actionlib.SimpleActionClient("transcribe_speech", TranscribeSpeechAction)
+client.wait_for_server()
+rospy.loginfo("Done waiting")
+while not rospy.is_shutdown():
+    goal = TranscribeSpeechGoal()
+    client.send_goal(goal)
+    client.wait_for_result()
+    result = client.get_result()
+    text = result.sequence
+    print(f"Transcribed Speech: {text}")
diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py
index d0ec731fc..42ec44785 100644
--- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py
+++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py
@@ -1,17 +1,43 @@
-import whisper
+import os
+import whisper  # type: ignore
+import rospkg  # type: ignore
 import rospy

 # Keep all loaded models in memory
 MODEL_CACHE = {}

-def load_model(name: str, device: str = 'cpu'):
-    '''
-    Load a given Whisper model
-    '''
+
+def load_model(
+    name: str, device: str = "cpu", load_test_file: bool = False
+) -> whisper.Whisper:
+    """Loads a whisper model from disk, or from cache if it has already been loaded.
+
+    Args:
+        name (str): Name of the whisper model. Must be the name of an official whisper
+            model, or the path to a model checkpoint.
+        device (str, optional): PyTorch device to put the model on. Defaults to 'cpu'.
+        load_test_file (bool, optional): Whether to run inference on a test audio file
+            after loading the model (if model is not in cache). Defaults to False. Test file
+            is assumed to be called "test.m4a" and be in the root of the package directory.
+
+    Returns:
+        whisper.Whisper: Whisper model instance
+    """
     global MODEL_CACHE

     if name not in MODEL_CACHE:
-        rospy.loginfo(f'Load model {name}')
+        rospy.loginfo(f"Loading model {name}")
         MODEL_CACHE[name] = whisper.load_model(name, device=device)
-
+        rospy.loginfo(f"Successfully loaded model {name} on {device}")
+        if load_test_file:
+            package_root = rospkg.RosPack().get_path("lasr_speech_recognition_whisper")
+            example_fp = os.path.join(package_root, "test.m4a")
+            rospy.loginfo(
+                "Running transcription on example file to ensure model is loaded..."
+            )
+            test_result = MODEL_CACHE[name].transcribe(
+                example_fp, fp16=device == "cuda"
+            )
+            rospy.loginfo(f"Transcription test result: {test_result}")
+
     return MODEL_CACHE[name]
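
As a usage sketch (not part of this changeset), a client node can also bound how long it waits for a transcription and preempt the server on timeout, exercising the preemption path the server implements. It assumes the `transcribe_speech` server above is running under its default action name; the node name `transcribe_once` is illustrative only.

```python
#!/usr/bin/env python3
# Hypothetical client sketch: request one transcription, but cancel the goal
# (preempting the server) if no result arrives within 30 seconds.
import rospy
import actionlib
from lasr_speech_recognition_msgs.msg import (
    TranscribeSpeechAction,
    TranscribeSpeechGoal,
)

rospy.init_node("transcribe_once")
client = actionlib.SimpleActionClient("transcribe_speech", TranscribeSpeechAction)
client.wait_for_server()

client.send_goal(TranscribeSpeechGoal())
# Give the server up to 30 seconds to listen and transcribe.
if client.wait_for_result(rospy.Duration(30.0)):
    rospy.loginfo(f"Heard: {client.get_result().sequence}")
else:
    # Triggers the server's preempt callback, which sets a "preempted" result.
    client.cancel_goal()
    rospy.logwarn("Transcription timed out and the goal was cancelled")
```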