Skip to content

Commit

Permalink
merging
Browse files (browse the repository at this point in the history)
  • Loading branch information
tiago committed Feb 23, 2024
2 parents 29060c0 + 205b52e commit 1cb2e3b
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def forward(self, x_copy, x):
def __init__(self, num_classes, in_channels=3, freeze_bn=False, sigmoid=True):
super(UNetWithResnet18Encoder, self).__init__()
self.sigmoid = sigmoid
resnet18 = models.resnet18(pretrained=True)
resnet18 = models.resnet18(pretrained=False)

if in_channels != 3:
resnet18.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
Expand Down Expand Up @@ -258,7 +258,7 @@ def freeze_bn(self):


class MultiLabelResNet(nn.Module):
def __init__(self, num_labels, input_channels=3, sigmoid=True, pretrained=True,):
def __init__(self, num_labels, input_channels=3, sigmoid=True, pretrained=False,):
super(MultiLabelResNet, self).__init__()
self.model = models.resnet18(pretrained=pretrained)
self.sigmoid = sigmoid
Expand Down
9 changes: 5 additions & 4 deletions common/speech/lasr_speech/launch/speech.launch
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
<launch>
<arg name="matcher" default="by-index" />
<arg name="device_param" default="" />
<arg name="matcher" default="--mic_device" />
<arg name="device_param" default="13" />
<arg name="rasa_model" default=""/>
<node pkg="lasr_speech_recognition_whisper" type="simple_transcribe_microphone" name="whisper_service" output="screen" args="$(arg matcher) $(arg device_param)"/>
<include file = "$(find lasr_rasa)/launch/rasa.launch">
<node pkg="lasr_speech_recognition_whisper" type="transcribe_microphone_server" name="transcribe_microphone_server" output="screen" args="$(arg matcher) $(arg device_param)"/>

<include file = "$(find lasr_rasa)/launch/rasa.launch">
<arg name="model" value="$(arg rasa_model)"/>
</include>
<node pkg="lasr_speech" type="service" name="speech_service" output="screen"/>
Expand Down
22 changes: 18 additions & 4 deletions common/speech/lasr_speech/nodes/service
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,23 @@ from time import sleep
from multiprocessing import Process
import rospkg
import os
import actionlib

from lasr_speech_recognition_msgs.msg import ( # type: ignore
TranscribeSpeechAction,
TranscribeSpeechGoal,
)


class TranscribeAndParse:

def __init__(self):
rospy.wait_for_service("/lasr_rasa/parse")
self.rasa = rospy.ServiceProxy("/lasr_rasa/parse", Rasa)
self.transcribe_audio = rospy.ServiceProxy("/whisper/transcribe_audio", TranscribeAudio)
# self.transcribe_audio = rospy.ServiceProxy("/whisper/transcribe_audio", TranscribeAudio)
self.speech_client = actionlib.SimpleActionClient("transcribe_speech", TranscribeSpeechAction)
self.speech_client.wait_for_server()

# self.sound_data = soundfile.read(os.path.join(rospkg.RosPack().get_path("lasr_speech"), "sounds", "beep.wav"))[0]

# def play_sound(self):
Expand All @@ -34,9 +44,13 @@ class TranscribeAndParse:
def __call__(self, req):
# if req.play_sound:
# self.play_sound()
transcription = self.transcribe_audio()
rospy.loginfo(transcription)
rasa_response = self.rasa(transcription.phrase)
goal = TranscribeSpeechGoal()
self.speech_client.send_goal(goal)
self.speech_client.wait_for_result()
result = self.speech_client.get_result()
text = result.sequence
rospy.loginfo(text)
rasa_response = self.rasa(text)
rospy.loginfo(rasa_response)
return SpeechResponse(rasa_response.json_response, rasa_response.success)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,8 @@ def parse_args() -> dict:
help="Disable warming up the model by running inference on a test file.",
)

return vars(parser.parse_args())
args,unknown = parser.parse_known_args()
return vars(args)


def configure_model_params(config: dict) -> speech_model_params:
Expand Down Expand Up @@ -346,6 +347,6 @@ def configure_whisper_cache() -> None:
if __name__ == "__main__":
configure_whisper_cache()
config = parse_args()
rospy.init_node(config["action_name"])
server = TranscribeSpeechAction(rospy.get_name(), configure_model_params(config))
rospy.init_node("transcribe_speech_server")
server = TranscribeSpeechAction("transcribe_speech", configure_model_params(config))
rospy.spin()
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def load_model_cached(dataset: str) -> None:
model = loaded_models[dataset]
else:
if dataset == 'resnet50':
model = load_model(download_model(BodyPixModelPaths.RESNET50_FLOAT_STRIDE_16))
model = load_model('/home/rexy/.keras/tf-bodypix/3fe1b130a0f20e98340612c099b50c18--tfjs-models-savedmodel-bodypix-resnet50-float-model-stride16')
# model = load_model(download_model(BodyPixModelPaths.RESNET50_FLOAT_STRIDE_16))
elif dataset == 'mobilenet50':
model = load_model(download_model(BodyPixModelPaths.MOBILENET_FLOAT_50_STRIDE_16))
else:
Expand All @@ -44,6 +45,7 @@ def detect(request: BodyPixDetectionRequest, debug_publisher: rospy.Publisher |
# decode the image
rospy.loginfo('Decoding')
img = cv2_img.msg_to_cv2_img(request.image_raw)
# rospy.logwarn(str(type(img)))
img_height, img_width, _ = img.shape # Get image dimensions

# load model
Expand Down
1 change: 1 addition & 0 deletions tasks/receptionist/launch/setup.launch
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
<arg name="whisper_matcher" default="by-index" />
<!-- <arg name="whisper_device_param" default="9" />-->
<arg name="whisper_device_param" default="13" />

<arg name="rasa_model" default="$(find lasr_rasa)/assistants/receptionist/models"/>

<include file = "$(find lasr_speech)/launch/speech.launch">
Expand Down
18 changes: 10 additions & 8 deletions tasks/receptionist/src/receptionist/states/speakdescriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,17 @@ def __init__(self, default):
self.default = default

def execute(self, userdata):
# for person in userdata['people']:
# self.default.voice.speak('I see a person')

for person in userdata['people']:
self.default.voice.speak('I see a person')

# for feature in person['features']:
# if feature.label:
# if len(feature.colours) == 0:
# self.default.voice.speak(f'They have {feature.name}.')
# continue
for feature in person['features']:
if feature.label:
if len(feature.colours) == 0:
self.default.voice.speak(f'They have {feature.name}.')
continue

# self.default.voice.speak(f'They have {feature.name} and it has the colour {feature.colours[0]}')
self.default.voice.speak(f'They have {feature.name} and it has the colour {feature.colours[0]}')


return 'succeeded'

0 comments on commit 1cb2e3b

Please sign in to comment.