diff --git a/common/speech/lasr_speech_recognition_msgs/CMakeLists.txt b/common/speech/lasr_speech_recognition_msgs/CMakeLists.txt
index ae6bac4d8..10e0472f1 100644
--- a/common/speech/lasr_speech_recognition_msgs/CMakeLists.txt
+++ b/common/speech/lasr_speech_recognition_msgs/CMakeLists.txt
@@ -7,7 +7,11 @@ project(lasr_speech_recognition_msgs)
## Find catkin macros and libraries
## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz)
## is used, also find other catkin packages
-find_package(catkin REQUIRED COMPONENTS message_generation)
+find_package(catkin REQUIRED COMPONENTS message_generation genmsg actionlib_msgs actionlib std_msgs)
+add_action_files(
+ DIRECTORY action
+ FILES TranscribeSpeech.action
+)
## System dependencies are found with CMake's conventions
# find_package(Boost REQUIRED COMPONENTS system)
@@ -63,8 +67,9 @@ add_service_files(
## Generate added messages and services with any dependencies listed here
generate_messages(
-# DEPENDENCIES
-# std_msgs # Or other packages containing msgs
+ DEPENDENCIES
+ std_msgs # Or other packages containing msgs
+ actionlib_msgs
)
################################################
diff --git a/common/speech/lasr_speech_recognition_msgs/action/TranscribeSpeech.action b/common/speech/lasr_speech_recognition_msgs/action/TranscribeSpeech.action
new file mode 100644
index 000000000..486b0cb19
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_msgs/action/TranscribeSpeech.action
@@ -0,0 +1,6 @@
+---
+#result definition
+string sequence
+---
+#feedback
+string sequence
\ No newline at end of file
diff --git a/common/speech/lasr_speech_recognition_msgs/package.xml b/common/speech/lasr_speech_recognition_msgs/package.xml
index e319ec5db..6f00b03f4 100644
--- a/common/speech/lasr_speech_recognition_msgs/package.xml
+++ b/common/speech/lasr_speech_recognition_msgs/package.xml
@@ -51,6 +51,8 @@
   <buildtool_depend>catkin</buildtool_depend>
   <build_depend>message_generation</build_depend>
   <exec_depend>message_runtime</exec_depend>
+  <build_depend>actionlib_msgs</build_depend>
+  <exec_depend>actionlib_msgs</exec_depend>
diff --git a/common/speech/lasr_speech_recognition_whisper/CMakeLists.txt b/common/speech/lasr_speech_recognition_whisper/CMakeLists.txt
index d2cce16f7..a11465954 100644
--- a/common/speech/lasr_speech_recognition_whisper/CMakeLists.txt
+++ b/common/speech/lasr_speech_recognition_whisper/CMakeLists.txt
@@ -7,7 +7,11 @@ project(lasr_speech_recognition_whisper)
## Find catkin macros and libraries
## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz)
## is used, also find other catkin packages
-find_package(catkin REQUIRED catkin_virtualenv)
+find_package(catkin REQUIRED catkin_virtualenv genmsg actionlib_msgs actionlib std_msgs)
+# add_action_files(
+# DIRECTORY action
+# FILES TranscribeSpeech.action
+# )
## System dependencies are found with CMake's conventions
# find_package(Boost REQUIRED COMPONENTS system)
@@ -70,7 +74,8 @@ catkin_generate_virtualenv(
## Generate added messages and services with any dependencies listed here
# generate_messages(
# DEPENDENCIES
-# std_msgs # Or other packages containing msgs
+# std_msgs
+# actionlib_msgs # Or other packages containing msgs
# )
################################################
@@ -162,8 +167,11 @@ include_directories(
catkin_install_python(PROGRAMS
nodes/simple_transcribe_microphone
nodes/transcribe_microphone
+ nodes/transcribe_microphone_server
scripts/list_microphones.py
scripts/test_microphones.py
+ scripts/repeat_after_me.py
+ scripts/test_speech_server.py
DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
)
diff --git a/common/speech/lasr_speech_recognition_whisper/doc/USAGE.md b/common/speech/lasr_speech_recognition_whisper/doc/USAGE.md
index 6c0649478..2bb966c13 100644
--- a/common/speech/lasr_speech_recognition_whisper/doc/USAGE.md
+++ b/common/speech/lasr_speech_recognition_whisper/doc/USAGE.md
@@ -26,3 +26,31 @@ Stop listening whenever:
```bash
rosservice call /whisper/stop_listening "{}"
```
+
+Run an actionlib server to transcribe speech from the microphone:
+
+```bash
+rosrun lasr_speech_recognition_whisper transcribe_microphone_server
+```
+
+The action result contains a single `string` field, `sequence`, holding the transcribed text.
+
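+For the default `transcribe_speech` action name, a minimal Python client looks like the following sketch (it mirrors `scripts/test_speech_server.py`):
+
+```python
+import rospy
+import actionlib
+from lasr_speech_recognition_msgs.msg import TranscribeSpeechAction, TranscribeSpeechGoal
+
+rospy.init_node("transcribe_speech_client")
+client = actionlib.SimpleActionClient("transcribe_speech", TranscribeSpeechAction)
+client.wait_for_server()
+
+# Send an empty goal; the server listens to the microphone and returns the transcription
+client.send_goal(TranscribeSpeechGoal())
+client.wait_for_result()
+print(client.get_result().sequence)
+```
+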
+Several command line configuration options exist, which can be viewed with:
+
+```bash
+rosrun lasr_speech_recognition_whisper transcribe_microphone_server --help
+```
+
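+For example, to run a smaller Whisper model on a specific microphone (the flags below are defined by the server's argument parser; the device index is only illustrative):
+
+```bash
+rosrun lasr_speech_recognition_whisper transcribe_microphone_server --model-name tiny.en --mic_device 2
+```
+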
+Get TIAGo to repeat the transcribed speech back using TTS; it will begin repeating after it hears "tiago, repeat ..." and stop once it hears "tiago, stop ...":
+
+```bash
+rosrun lasr_speech_recognition_whisper repeat_after_me.py
+```
+
+To listen continuously and print the transcribed speech in the command line (by repeatedly sending goals to the actionlib server), run the following script:
+
+```bash
+rosrun lasr_speech_recognition_whisper test_speech_server.py
+```
+
+
diff --git a/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server b/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server
new file mode 100644
index 000000000..95f3b1e36
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/nodes/transcribe_microphone_server
@@ -0,0 +1,351 @@
+#!/usr/bin/env python3
+
+import os
+import argparse
+from typing import Optional
+from dataclasses import dataclass
+from pathlib import Path
+from timeit import default_timer as timer
+
+import rospy
+import numpy as np
+import torch
+import actionlib
+import speech_recognition as sr # type: ignore
+import lasr_speech_recognition_msgs.msg # type: ignore
+from lasr_speech_recognition_whisper import load_model # type: ignore
+
+# Error handler to remove ALSA error messages taken from:
+# https://stackoverflow.com/questions/7088672/pyaudio-working-but-spits-out-error-messages-each-time/17673011#17673011
+
+from ctypes import *
+from contextlib import contextmanager
+
+ERROR_HANDLER_FUNC = CFUNCTYPE(None, c_char_p, c_int, c_char_p, c_int, c_char_p)
+
+
+def py_error_handler(filename, line, function, err, fmt):
+ pass
+
+
+c_error_handler = ERROR_HANDLER_FUNC(py_error_handler)
+
+
+@contextmanager
+def noalsaerr():
+ asound = cdll.LoadLibrary("libasound.so")
+ asound.snd_lib_error_set_handler(c_error_handler)
+ yield
+ asound.snd_lib_error_set_handler(None)
+
+
+@dataclass
+class speech_model_params:
+ """Class for storing speech recognition model parameters.
+
+ Args:
+ model_name (str, optional): Name of the speech recognition model. Defaults to "medium.en".
+ Must be a valid Whisper model name.
+ device (str, optional): Device to run the model on. Defaults to "cuda" if available, otherwise "cpu".
+ start_timeout (float): Max number of seconds of silence when starting listening before stopping. Defaults to 5.0.
+ end_timeout (Optional[float]): Max number of seconds of silence after starting listening before stopping. Defaults to None.
+ sample_rate (int): Sample rate of the microphone. Defaults to 16000Hz.
+ mic_device (Optional[str]): Microphone device index or name. Defaults to None.
+ timer_duration (Optional[int]): Duration of the timer for adjusting the microphone for ambient noise. Defaults to 20 seconds.
+ warmup (bool): Whether to warmup the model by running inference on a test file. Defaults to True.
+ """
+
+ model_name: str = "medium.en"
+ device: str = "cuda" if torch.cuda.is_available() else "cpu"
+ start_timeout: float = 5.0
+ end_timeout: Optional[float] = None
+ sample_rate: int = 16000
+ mic_device: Optional[str] = None
+ timer_duration: Optional[int] = 20
+ warmup: bool = True
+
+
+class TranscribeSpeechAction(object):
+ # create messages that are used to publish feedback/result
+ _feedback = lasr_speech_recognition_msgs.msg.TranscribeSpeechFeedback()
+ _result = lasr_speech_recognition_msgs.msg.TranscribeSpeechResult()
+
+ def __init__(
+ self,
+ action_name: str,
+ model_params: speech_model_params,
+ ) -> None:
+ """Starts an action server for transcribing speech.
+
+ Args:
+ action_name (str): Name of the action server.
+ """
+
+ self._action_name = action_name
+ self._model_params = model_params
+
+ with noalsaerr():
+ self._model = load_model(
+ self._model_params.model_name,
+ self._model_params.device,
+ self._model_params.warmup,
+ )
+ # Configure the speech recogniser object and adjust for ambient noise
+ self.recogniser = self._configure_recogniser()
+ # Setup the action server and register execution callback
+ self._action_server = actionlib.SimpleActionServer(
+ self._action_name,
+ lasr_speech_recognition_msgs.msg.TranscribeSpeechAction,
+ execute_cb=self.execute_cb,
+ auto_start=False,
+ )
+        self._action_server.register_preempt_callback(self.preempt_cb)
+ # Setup the timer for adjusting the microphone for ambient noise every x seconds
+ self._timer_duration = self._model_params.timer_duration
+ self._timer = rospy.Timer(
+ rospy.Duration(self._timer_duration), self._timer_cb
+ )
+ self._listening = False
+
+ self._action_server.start()
+
+ def _timer_cb(self, _) -> None:
+ """Adjusts the microphone for ambient noise, unless the action server is listening."""
+ if self._listening:
+ return
+ rospy.loginfo("Adjusting microphone for ambient noise...")
+ with noalsaerr():
+ with self._configure_microphone() as source:
+ self.recogniser.adjust_for_ambient_noise(source)
+
+ def _reset_timer(self) -> None:
+ """Resets the timer for adjusting the microphone for ambient noise."""
+ self._timer.shutdown()
+ self._timer = rospy.Timer(rospy.Duration(self._timer_duration), self._timer_cb)
+
+ def _configure_microphone(self) -> sr.Microphone:
+ """Configures the microphone for listening to speech based on the
+ microphone device index or name.
+
+ Returns: microphone object
+ """
+
+ if self._model_params.mic_device is None:
+ # If no microphone device is specified, use the system default microphone
+ return sr.Microphone(sample_rate=self._model_params.sample_rate)
+ elif self._model_params.mic_device.isdigit():
+ return sr.Microphone(
+ device_index=int(self._model_params.mic_device),
+ sample_rate=self._model_params.sample_rate,
+ )
+ else:
+ microphones = enumerate(sr.Microphone.list_microphone_names())
+ for index, name in microphones:
+ if self._model_params.mic_device in name:
+ return sr.Microphone(
+ device_index=index,
+ sample_rate=self._model_params.sample_rate,
+ )
+ raise ValueError(
+ f"Could not find microphone with name: {self._model_params.mic_device}"
+ )
+
+ def _configure_recogniser(self, ambient_adj: bool = True) -> sr.Recognizer:
+ """Configures the speech recogniser object.
+
+ Args:
+ ambient_adj (bool, optional): Whether to adjust for ambient noise. Defaults to True.
+
+ Returns:
+ sr.Recognizer: speech recogniser object.
+ """
+ self._listening = True
+ recogniser = sr.Recognizer()
+ if ambient_adj:
+ with self._configure_microphone() as source:
+ recogniser.adjust_for_ambient_noise(source)
+ self._listening = False
+ return recogniser
+
+    def preempt_cb(self) -> None:
+ """Callback for preempting the action server.
+
+ Sets server to preempted state.
+ """
+ preempted_str = f"{self._action_name} has been preempted"
+ rospy.loginfo(preempted_str)
+ self._result.sequence = preempted_str
+ self._action_server.set_preempted(result=self._result, text=preempted_str)
+
+ def execute_cb(self, goal) -> None:
+ """Callback for executing the action server.
+
+ Checks for preemption before listening and before and after transcribing, returning
+ if preemption is requested.
+
+ Args:
+ goal: UNUSED - actionlib requires a goal argument in the execute callback, but
+ this action server does not use a goal.
+ """
+ rospy.loginfo("Request Received")
+ if self._action_server.is_preempt_requested():
+ return
+ # Since we are about to listen, reset the timer for adjusting the microphone for ambient noise
+        # as the timer assumes self._timer_duration seconds of silence before adjusting
+ self._reset_timer()
+ with noalsaerr():
+ with self._configure_microphone() as src:
+ self._listening = True
+ wav_data = self.recogniser.listen(
+ src,
+ timeout=self._model_params.start_timeout,
+ phrase_time_limit=self._model_params.end_timeout,
+ ).get_wav_data()
+ # Magic number 32768.0 is the maximum value of a 16-bit signed integer
+ float_data = (
+ np.frombuffer(wav_data, dtype=np.int16).astype(np.float32, order="C")
+ / 32768.0
+ )
+
+ if self._action_server.is_preempt_requested():
+ self._listening = False
+ return
+
+ rospy.loginfo(f"Transcribing phrase with Whisper...")
+ transcription_start_time = timer()
+ # Cast to fp16 if using GPU
+ phrase = self._model.transcribe(
+ float_data,
+ fp16=self._model_params.device == "cuda",
+ )["text"]
+ transcription_end_time = timer()
+ rospy.loginfo(f"Transcription finished!")
+ rospy.loginfo(
+ f"Time taken: {transcription_end_time - transcription_start_time:.2f}s"
+ )
+
+ if self._action_server.is_preempt_requested():
+ self._listening = False
+ return
+
+ self._result.sequence = phrase
+ rospy.loginfo(f"Transcribed phrase: {phrase}")
+ rospy.loginfo(f"{self._action_name} has succeeded")
+ self._action_server.set_succeeded(self._result)
+
+ # Have this at the very end to not disrupt the action server
+ self._listening = False
+
+
+def parse_args() -> dict:
+    """Parses the command line arguments into a name: value dictionary.
+
+ Returns:
+ dict: Dictionary of name: value pairs of command line arguments.
+ """
+ parser = argparse.ArgumentParser(
+ description="Starts an action server for transcribing speech."
+ )
+
+ parser.add_argument(
+ "--action-name",
+ type=str,
+ default="transcribe_speech",
+ help="Name of the action server.",
+ )
+ parser.add_argument(
+ "--model-name",
+ type=str,
+ default="medium.en",
+ help="Name of the speech recognition model.",
+ )
+ parser.add_argument(
+ "--device",
+ type=str,
+ default="cuda" if torch.cuda.is_available() else "cpu",
+ help="Device to run the model on.",
+ )
+ parser.add_argument(
+ "--start_timeout",
+ type=float,
+ default=5.0,
+ help="Timeout for listening for the start of a phrase.",
+ )
+ parser.add_argument(
+ "--end_timeout",
+ type=float,
+ default=None,
+ help="Timeout for listening for the end of a phrase.",
+ )
+ parser.add_argument(
+ "--sample_rate",
+ type=int,
+ default=16000,
+ help="Sample rate of the microphone.",
+ )
+ parser.add_argument(
+ "--mic_device",
+ type=str,
+ default=None,
+ help="Microphone device index or name",
+ )
+ parser.add_argument(
+ "--timer_duration",
+ type=int,
+ default=20,
+ help="Number of seconds of silence before the ambient noise adjustment is called.",
+ )
+ parser.add_argument(
+ "--no_warmup",
+ action="store_true",
+ help="Disable warming up the model by running inference on a test file.",
+ )
+
+ return vars(parser.parse_args())
+
+
+def configure_model_params(config: dict) -> speech_model_params:
+ """Configures the speech model parameters based on the provided
+ command line parameters.
+
+ Args:
+ config (dict): Command line parameters parsed in dictionary form.
+
+ Returns:
+ speech_model_params: dataclass containing the speech model parameters
+ """
+ model_params = speech_model_params()
+ if config["model_name"]:
+ model_params.model_name = config["model_name"]
+ if config["device"]:
+ model_params.device = config["device"]
+ if config["start_timeout"]:
+ model_params.start_timeout = config["start_timeout"]
+ if config["end_timeout"]:
+ model_params.end_timeout = config["end_timeout"]
+ if config["sample_rate"]:
+ model_params.sample_rate = config["sample_rate"]
+ if config["mic_device"]:
+ model_params.mic_device = config["mic_device"]
+ if config["timer_duration"]:
+ model_params.timer_duration = config["timer_duration"]
+ if config["no_warmup"]:
+ model_params.warmup = False
+
+ return model_params
+
+
+def configure_whisper_cache() -> None:
+ """Configures the whisper cache directory."""
+ whisper_cache = os.path.join(str(Path.home()), ".cache", "whisper")
+ os.makedirs(whisper_cache, exist_ok=True)
+    # Environment variable required to run whisper locally
+ os.environ["TIKTOKEN_CACHE_DIR"] = whisper_cache
+
+
+if __name__ == "__main__":
+ configure_whisper_cache()
+ config = parse_args()
+ rospy.init_node(config["action_name"])
+ server = TranscribeSpeechAction(rospy.get_name(), configure_model_params(config))
+ rospy.spin()
diff --git a/common/speech/lasr_speech_recognition_whisper/package.xml b/common/speech/lasr_speech_recognition_whisper/package.xml
index 39935c089..4c6f49965 100644
--- a/common/speech/lasr_speech_recognition_whisper/package.xml
+++ b/common/speech/lasr_speech_recognition_whisper/package.xml
@@ -51,6 +51,10 @@
   <buildtool_depend>catkin</buildtool_depend>
   <build_depend>catkin_virtualenv</build_depend>
   <exec_depend>lasr_speech_recognition_msgs</exec_depend>
+  <build_depend>actionlib</build_depend>
+  <build_depend>actionlib_msgs</build_depend>
+  <exec_depend>actionlib</exec_depend>
+  <exec_depend>actionlib_msgs</exec_depend>
diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/repeat_after_me.py b/common/speech/lasr_speech_recognition_whisper/scripts/repeat_after_me.py
new file mode 100644
index 000000000..2e6b20622
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/scripts/repeat_after_me.py
@@ -0,0 +1,58 @@
+#!/usr/bin/env python3
+import rospy
+import actionlib
+from lasr_voice import Voice # type: ignore
+from lasr_speech_recognition_msgs.srv import TranscribeAudio, TranscribeAudioResponse # type: ignore
+from lasr_speech_recognition_msgs.msg import ( # type: ignore
+ TranscribeSpeechAction,
+ TranscribeSpeechGoal,
+)
+
+# import actionlib
+rospy.init_node("repeat")
+
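+# Toggle between the actionlib server (True) and the /whisper/transcribe_audio service interface (False)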
+USE_ACTIONLIB = True
+
+voice = Voice()
+
+
+if USE_ACTIONLIB:
+ client = actionlib.SimpleActionClient("transcribe_speech", TranscribeSpeechAction)
+ client.wait_for_server()
+ repeating = False
+ rospy.loginfo("Done waiting")
+ while not rospy.is_shutdown():
+ goal = TranscribeSpeechGoal()
+ client.send_goal(goal)
+ client.wait_for_result()
+ result = client.get_result()
+ text = result.sequence
+ print(text)
+ if "tiago" in text.lower().strip():
+ if "repeat" in text.lower().strip():
+ repeating = True
+ voice.sync_tts("Okay, I'll start repeating now.")
+ continue
+ elif "stop" in text.lower().strip():
+ repeating = False
+ voice.sync_tts("Okay, I'll stop repeating now.")
+ break
+ if repeating:
+ voice.sync_tts(f"I heard {text}")
+else:
+ transcribe = rospy.ServiceProxy("/whisper/transcribe_audio", TranscribeAudio)
+ repeating = False
+ while not rospy.is_shutdown():
+ text = transcribe().phrase
+ print(text)
+ if "tiago" in text.lower().strip():
+ if "repeat" in text.lower().strip():
+ repeating = True
+ voice.sync_tts("Okay, I'll start repeating now.")
+ continue
+ elif "stop" in text.lower().strip():
+ repeating = False
+ voice.sync_tts("Okay, I'll stop repeating now.")
+ break
+ if repeating:
+ voice.sync_tts(f"I heard {text}")
diff --git a/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py b/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py
new file mode 100644
index 000000000..fef16eb0c
--- /dev/null
+++ b/common/speech/lasr_speech_recognition_whisper/scripts/test_speech_server.py
@@ -0,0 +1,21 @@
+#!/usr/bin/env python3
+import rospy
+import actionlib
+from lasr_speech_recognition_msgs.srv import TranscribeAudio, TranscribeAudioResponse # type: ignore
+from lasr_speech_recognition_msgs.msg import ( # type: ignore
+ TranscribeSpeechAction,
+ TranscribeSpeechGoal,
+)
+
+
+rospy.init_node("test_speech_server")
+client = actionlib.SimpleActionClient("transcribe_speech", TranscribeSpeechAction)
+client.wait_for_server()
+rospy.loginfo("Done waiting")
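+# Repeatedly send empty goals to the action server and print each transcription result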
+while not rospy.is_shutdown():
+ goal = TranscribeSpeechGoal()
+ client.send_goal(goal)
+ client.wait_for_result()
+ result = client.get_result()
+ text = result.sequence
+ print(f"Transcribed Speech: {text}")
diff --git a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py
index d0ec731fc..42ec44785 100644
--- a/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py
+++ b/common/speech/lasr_speech_recognition_whisper/src/lasr_speech_recognition_whisper/cache.py
@@ -1,17 +1,43 @@
-import whisper
+import os
+import whisper # type: ignore
+import rospkg # type: ignore
import rospy
# Keep all loaded models in memory
MODEL_CACHE = {}
-def load_model(name: str, device: str = 'cpu'):
- '''
- Load a given Whisper model
- '''
+
+def load_model(
+ name: str, device: str = "cpu", load_test_file: bool = False
+) -> whisper.Whisper:
+ """Loads a whisper model from disk, or from cache if it has already been loaded.
+
+ Args:
+ name (str): Name of the whisper model. Must be the name of an official whisper
+ model, or the path to a model checkpoint.
+ device (str, optional): Pytorch device to put the model on. Defaults to 'cpu'.
+ load_test_file (bool, optional): Whether to run inference on a test audio file
+ after loading the model (if model is not in cache). Defaults to False. Test file
+ is assumed to be called "test.m4a" and be in the root of the package directory.
+
+ Returns:
+ whisper.Whisper: Whisper model instance
+ """
global MODEL_CACHE
if name not in MODEL_CACHE:
- rospy.loginfo(f'Load model {name}')
+ rospy.loginfo(f"Loading model {name}")
MODEL_CACHE[name] = whisper.load_model(name, device=device)
-
+        rospy.loginfo(f"Successfully loaded model {name} on {device}")
+ if load_test_file:
+ package_root = rospkg.RosPack().get_path("lasr_speech_recognition_whisper")
+ example_fp = os.path.join(package_root, "test.m4a")
+ rospy.loginfo(
+ "Running transcription on example file to ensure model is loaded..."
+ )
+            test_result: str = MODEL_CACHE[name].transcribe(
+                example_fp, fp16=device == "cuda"
+            )["text"]
+ rospy.loginfo(f"Transcription test result: {test_result}")
+
return MODEL_CACHE[name]