From 40cd1b13972871dacf43390d34cc6151011a1fe8 Mon Sep 17 00:00:00 2001 From: Benteng Ma Date: Fri, 19 Apr 2024 11:32:33 +0100 Subject: [PATCH] Cleaned load model method, restart to use downloaded model. --- .../src/lasr_vision_bodypix/bodypix.py | 26 ++----------------- 1 file changed, 2 insertions(+), 24 deletions(-) diff --git a/common/vision/lasr_vision_bodypix/src/lasr_vision_bodypix/bodypix.py b/common/vision/lasr_vision_bodypix/src/lasr_vision_bodypix/bodypix.py index b366b7af4..b1cab9f83 100644 --- a/common/vision/lasr_vision_bodypix/src/lasr_vision_bodypix/bodypix.py +++ b/common/vision/lasr_vision_bodypix/src/lasr_vision_bodypix/bodypix.py @@ -13,7 +13,6 @@ from lasr_vision_msgs.srv import BodyPixDetectionRequest, BodyPixDetectionResponse import rospkg -from os import path # model cache loaded_models = {} @@ -23,27 +22,20 @@ def load_model_cached(dataset: str) -> None: ''' Load a model into cache ''' - model = None if dataset in loaded_models: model = loaded_models[dataset] else: if dataset == 'resnet50': -# name = download_model(BodyPixModelPaths.RESNET50_FLOAT_STRIDE_16) -# rospy.logwarn(name) /home/bentengma/keras_model/tf-bodypix/tfjs-models-savedmodel-bodypix-resnet50-float-model-stride16 -# model = load_model(name) - model = load_model(path.join(r.get_path("lasr_vision_bodypix"), "models", "keras_model", "tf-bodypix", "tfjs-models-savedmodel-bodypix-resnet50-float-model-stride16")) - # model = load_model(download_model(BodyPixModelPaths.RESNET50_FLOAT_STRIDE_16)) + name = download_model(BodyPixModelPaths.RESNET50_FLOAT_STRIDE_16) + model = load_model(name) elif dataset == 'mobilenet50': name = download_model(BodyPixModelPaths.MOBILENET_FLOAT_50_STRIDE_16) - rospy.logwarn(name) model = load_model(name) else: model = load_model(dataset) - rospy.loginfo(f'Loaded {dataset} model') loaded_models[dataset] = model - return model def detect(request: BodyPixDetectionRequest, debug_publisher: rospy.Publisher | None) -> BodyPixDetectionResponse: @@ -54,8 +46,6 @@ def detect(request: BodyPixDetectionRequest, debug_publisher: rospy.Publisher | # decode the image rospy.loginfo('Decoding') img = cv2_img.msg_to_cv2_img(request.image_raw) - # rospy.logwarn(str(type(img))) - img_height, img_width, _ = img.shape # Get image dimensions # load model rospy.loginfo('Loading model') @@ -92,27 +82,16 @@ def detect(request: BodyPixDetectionRequest, debug_publisher: rospy.Publisher | right_shoulder = right_shoulder_keypoint.position neck_x = (left_shoulder.x + right_shoulder.x) / 2 neck_y = (left_shoulder.y + right_shoulder.y) / 2 - neck_score = (left_shoulder_keypoint.score + right_shoulder_keypoint.score) / 2 # Optional: average score elif left_shoulder_keypoint: # Only left shoulder detected, use it as neck coordinate left_shoulder = left_shoulder_keypoint.position neck_x = left_shoulder.x neck_y = left_shoulder.y - neck_score = left_shoulder_keypoint.score elif right_shoulder_keypoint: # Only right shoulder detected, use it as neck coordinate right_shoulder = right_shoulder_keypoint.position neck_x = right_shoulder.x neck_y = right_shoulder.y - neck_score = right_shoulder_keypoint.score - - # # Convert neck coordinates to relative positions (0-1) - # rel_neck_x = neck_x / img_width - # rel_neck_y = neck_y / img_height - - # pose = BodyPixPose() - # pose.coord = np.array([rel_neck_x, rel_neck_y]).astype(np.float32) - # neck_coordinates.append(pose) pose = BodyPixPose() pose.coord = np.array([neck_x, neck_y]).astype(np.int32) @@ -130,7 +109,6 @@ def detect(request: BodyPixDetectionRequest, debug_publisher: rospy.Publisher | keypoints_color=(255, 100, 100), skeleton_color=(100, 100, 255), ) - debug_publisher.publish(cv2_img.cv2_img_to_msg(coloured_mask)) response = BodyPixDetectionResponse()