Skip to content

Commit

Permalink
Replace the old model with the new one.
Browse files Browse the repository at this point in the history
  • Loading branch information
Benteng Ma committed Feb 23, 2024
1 parent 86e7a3e commit 3e384d1
Show file tree
Hide file tree
Showing 9 changed files with 494 additions and 207 deletions.
2 changes: 2 additions & 0 deletions common/vision/lasr_vision_msgs/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ add_message_files(
BodyPixMaskRequest.msg
ColourPrediction.msg
FeatureWithColour.msg
# Description.msg
)

## Generate services in the 'srv' folder
Expand All @@ -59,6 +60,7 @@ add_service_files(
YoloDetection.srv
BodyPixDetection.srv
TorchFaceFeatureDetection.srv
TorchFaceFeatureDetectionDescription.srv
)

## Generate actions in the 'action' folder
Expand Down
1 change: 1 addition & 0 deletions common/vision/lasr_vision_msgs/msg/Description.msg
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
string discription
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Image to run inference on
sensor_msgs/Image image_raw

uint8[] head_mask_data # For serialized array data
uint32[] head_mask_shape # To store the shape of the array
string head_mask_dtype # Data type of the array elements

uint8[] torso_mask_data
uint32[] torso_mask_shape
string torso_mask_dtype
---

# Detection result
string description
30 changes: 16 additions & 14 deletions common/vision/lasr_vision_torch/nodes/service
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from lasr_vision_msgs.srv import TorchFaceFeatureDetection, TorchFaceFeatureDetectionRequest, TorchFaceFeatureDetectionResponse
from lasr_vision_msgs.srv import TorchFaceFeatureDetection, TorchFaceFeatureDetectionRequest, TorchFaceFeatureDetectionResponse, TorchFaceFeatureDetectionDescriptionRequest, TorchFaceFeatureDetectionDescriptionResponse
from lasr_vision_msgs.msg import FeatureWithColour, ColourPrediction
from cv2_img import msg_to_cv2_img
from torch_module.helpers import binary_erosion_dilation, median_color_float
Expand All @@ -13,7 +13,7 @@ import lasr_vision_torch
from os import path


def detect(request: TorchFaceFeatureDetectionRequest) -> TorchFaceFeatureDetectionResponse:
def detect(request: TorchFaceFeatureDetectionDescriptionRequest) -> TorchFaceFeatureDetectionDescriptionRequest:
# decode the image
rospy.loginfo('Decoding')
full_frame = msg_to_cv2_img(request.image_raw)
Expand All @@ -24,20 +24,22 @@ def detect(request: TorchFaceFeatureDetectionRequest) -> TorchFaceFeatureDetecti
head_frame = lasr_vision_torch.extract_mask_region(full_frame, head_mask.astype(np.uint8), expand_x=0.4, expand_y=0.5)
torso_frame = lasr_vision_torch.extract_mask_region(full_frame, torso_mask.astype(np.uint8), expand_x=0.2, expand_y=0.0)

class_pred, colour_pred = lasr_vision_torch.predict_frame(head_frame, torso_frame, full_frame, head_mask, torso_mask, lasr_vision_torch.model, lasr_vision_torch.thresholds_mask, lasr_vision_torch.erosion_iterations, lasr_vision_torch.dilation_iterations, lasr_vision_torch.thresholds_pred)
# class_pred, colour_pred = lasr_vision_torch.predict_frame(head_frame, torso_frame, full_frame, head_mask, torso_mask, lasr_vision_torch.model, lasr_vision_torch.thresholds_mask, lasr_vision_torch.erosion_iterations, lasr_vision_torch.dilation_iterations, lasr_vision_torch.thresholds_pred)
rst_str = lasr_vision_torch.predict_frame(head_frame, torso_frame, full_frame, head_mask, torso_mask, lasr_vision_torch.model, lasr_vision_torch.thresholds_mask, lasr_vision_torch.erosion_iterations, lasr_vision_torch.dilation_iterations, lasr_vision_torch.thresholds_pred)

response = TorchFaceFeatureDetectionResponse()
response = TorchFaceFeatureDetectionDescriptionRequest()
response.description = rst_str
# response.detected_features = str(class_pred) + str(colour_pred)
response.detected_features = []
for c in ['hair', 'hat', 'glasses', 'cloth',]:
# colour_pred[c] = {k: v[0] for k, v in colour_pred[c].items()}
sorted_list = sorted(colour_pred[c].items(), key=lambda item: item[1], reverse=True)
# rospy.loginfo(str(sorted_list))
if len(sorted_list) > 3:
sorted_list = sorted_list[0:3]
sorted_list = [k for k, v in sorted_list]
# rospy.loginfo(str(colour_pred[c]))
response.detected_features.append(FeatureWithColour(c, class_pred[c], sorted_list))
# response.detected_features = []
# for c in ['hair', 'hat', 'glasses', 'cloth',]:
# # colour_pred[c] = {k: v[0] for k, v in colour_pred[c].items()}
# sorted_list = sorted(colour_pred[c].items(), key=lambda item: item[1], reverse=True)
# # rospy.loginfo(str(sorted_list))
# if len(sorted_list) > 3:
# sorted_list = sorted_list[0:3]
# sorted_list = [k for k, v in sorted_list]
# # rospy.loginfo(str(colour_pred[c]))
# response.detected_features.append(FeatureWithColour(c, class_pred[c], sorted_list))
return response


Expand Down
Loading

0 comments on commit 3e384d1

Please sign in to comment.