Skip to content

Commit

Permalink
Cleaned load model method, restart to use downloaded model.
Browse files Browse the repository at this point in the history
  • Loading branch information
Benteng Ma committed Apr 19, 2024
1 parent fbe6212 commit 40cd1b1
Showing 1 changed file with 2 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from lasr_vision_msgs.srv import BodyPixDetectionRequest, BodyPixDetectionResponse

import rospkg
from os import path

# model cache
loaded_models = {}
Expand All @@ -23,27 +22,20 @@ def load_model_cached(dataset: str) -> None:
'''
Load a model into cache
'''

model = None
if dataset in loaded_models:
model = loaded_models[dataset]
else:
if dataset == 'resnet50':
# name = download_model(BodyPixModelPaths.RESNET50_FLOAT_STRIDE_16)
# rospy.logwarn(name) /home/bentengma/keras_model/tf-bodypix/tfjs-models-savedmodel-bodypix-resnet50-float-model-stride16
# model = load_model(name)
model = load_model(path.join(r.get_path("lasr_vision_bodypix"), "models", "keras_model", "tf-bodypix", "tfjs-models-savedmodel-bodypix-resnet50-float-model-stride16"))
# model = load_model(download_model(BodyPixModelPaths.RESNET50_FLOAT_STRIDE_16))
name = download_model(BodyPixModelPaths.RESNET50_FLOAT_STRIDE_16)
model = load_model(name)
elif dataset == 'mobilenet50':
name = download_model(BodyPixModelPaths.MOBILENET_FLOAT_50_STRIDE_16)
rospy.logwarn(name)
model = load_model(name)
else:
model = load_model(dataset)

rospy.loginfo(f'Loaded {dataset} model')
loaded_models[dataset] = model

return model

def detect(request: BodyPixDetectionRequest, debug_publisher: rospy.Publisher | None) -> BodyPixDetectionResponse:
Expand All @@ -54,8 +46,6 @@ def detect(request: BodyPixDetectionRequest, debug_publisher: rospy.Publisher |
# decode the image
rospy.loginfo('Decoding')
img = cv2_img.msg_to_cv2_img(request.image_raw)
# rospy.logwarn(str(type(img)))
img_height, img_width, _ = img.shape # Get image dimensions

# load model
rospy.loginfo('Loading model')
Expand Down Expand Up @@ -92,27 +82,16 @@ def detect(request: BodyPixDetectionRequest, debug_publisher: rospy.Publisher |
right_shoulder = right_shoulder_keypoint.position
neck_x = (left_shoulder.x + right_shoulder.x) / 2
neck_y = (left_shoulder.y + right_shoulder.y) / 2
neck_score = (left_shoulder_keypoint.score + right_shoulder_keypoint.score) / 2 # Optional: average score
elif left_shoulder_keypoint:
# Only left shoulder detected, use it as neck coordinate
left_shoulder = left_shoulder_keypoint.position
neck_x = left_shoulder.x
neck_y = left_shoulder.y
neck_score = left_shoulder_keypoint.score
elif right_shoulder_keypoint:
# Only right shoulder detected, use it as neck coordinate
right_shoulder = right_shoulder_keypoint.position
neck_x = right_shoulder.x
neck_y = right_shoulder.y
neck_score = right_shoulder_keypoint.score

# # Convert neck coordinates to relative positions (0-1)
# rel_neck_x = neck_x / img_width
# rel_neck_y = neck_y / img_height

# pose = BodyPixPose()
# pose.coord = np.array([rel_neck_x, rel_neck_y]).astype(np.float32)
# neck_coordinates.append(pose)

pose = BodyPixPose()
pose.coord = np.array([neck_x, neck_y]).astype(np.int32)
Expand All @@ -130,7 +109,6 @@ def detect(request: BodyPixDetectionRequest, debug_publisher: rospy.Publisher |
keypoints_color=(255, 100, 100),
skeleton_color=(100, 100, 255),
)

debug_publisher.publish(cv2_img.cv2_img_to_msg(coloured_mask))

response = BodyPixDetectionResponse()
Expand Down

0 comments on commit 40cd1b1

Please sign in to comment.