Skip to content

Commit

Permalink
replaced the model and predictor initialization, put "__main__"
Browse files Browse the repository at this point in the history
  • Loading branch information
Benteng Ma committed Apr 18, 2024
1 parent 44b17b2 commit 689fd0c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 21 deletions.
20 changes: 11 additions & 9 deletions common/vision/lasr_vision_feature_extraction/nodes/service
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from lasr_vision_msgs.srv import TorchFaceFeatureDetectionDescription, TorchFaceFeatureDetectionDescriptionRequest, TorchFaceFeatureDetectionDescriptionResponse
from lasr_vision_feature_extraction.categories_and_attributes import CategoriesAndAttributes, CelebAMaskHQCategoriesAndAttributes

from cv2_img import msg_to_cv2_img
from numpy2message import message2numpy

import numpy as np
import cv2
import torch
Expand All @@ -21,16 +22,17 @@ def detect(request: TorchFaceFeatureDetectionDescriptionRequest) -> TorchFaceFea
head_mask = message2numpy(head_mask_data, head_mask_shape, head_mask_dtype)
head_frame = lasr_vision_feature_extraction.extract_mask_region(full_frame, head_mask.astype(np.uint8), expand_x=0.4, expand_y=0.5)
torso_frame = lasr_vision_feature_extraction.extract_mask_region(full_frame, torso_mask.astype(np.uint8), expand_x=0.2, expand_y=0.0)

# class_pred, colour_pred = lasr_vision_feature_extraction.predict_frame(head_frame, torso_frame, full_frame, head_mask, torso_mask, lasr_vision_feature_extraction.model, lasr_vision_feature_extraction.thresholds_mask, lasr_vision_feature_extraction.erosion_iterations, lasr_vision_feature_extraction.dilation_iterations, lasr_vision_feature_extraction.thresholds_pred)
rst_str = lasr_vision_feature_extraction.predict_frame(head_frame, torso_frame, full_frame, head_mask, torso_mask,)

rst_str = lasr_vision_feature_extraction.predict_frame(head_frame, torso_frame, full_frame, head_mask, torso_mask, predictor=predictor)
response = TorchFaceFeatureDetectionDescriptionResponse()
response.description = rst_str
return response


rospy.init_node('torch_service')
rospy.Service('/torch/detect/face_features', TorchFaceFeatureDetectionDescription, detect)
rospy.loginfo('Torch service started')
rospy.spin()
if __name__ == '__main__':
# predictor will be global when inited, thus will be used within the function above.
model = lasr_vision_feature_extraction.load_face_classifier_model()
predictor = lasr_vision_feature_extraction.Predictor(model, torch.device('cpu'), CelebAMaskHQCategoriesAndAttributes)
rospy.init_node('torch_service')
rospy.Service('/torch/detect/face_features', TorchFaceFeatureDetectionDescription, detect)
rospy.loginfo('Torch service started')
rospy.spin()
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
from feature_extractor.modules import UNetWithResnetEncoder, MultiLabelResNet, CombinedModel # DeepLabV3PlusMobileNetV3, MultiLabelMobileNetV3Large, CombinedModelNoRegression
from feature_extractor.modules import UNetWithResnetEncoder, MultiLabelResNet, CombinedModel
from feature_extractor.helpers import load_torch_model, binary_erosion_dilation
from lasr_vision_feature_extraction.categories_and_attributes import CategoriesAndAttributes, CelebAMaskHQCategoriesAndAttributes
from lasr_vision_feature_extraction.image_with_masks_and_attributes import ImageWithMasksAndAttributes, ImageOfPerson

import numpy as np
import cv2
import torch
import rospy
import rospkg
import lasr_vision_feature_extraction
from os import path
# import matplotlib.pyplot as plt


class Predictor:
Expand Down Expand Up @@ -72,9 +69,6 @@ def load_face_classifier_model():
return model


model = load_face_classifier_model()


def pad_image_to_even_dims(image):
# Get the current shape of the image
height, width, _ = image.shape
Expand Down Expand Up @@ -121,17 +115,14 @@ def extract_mask_region(frame, mask, expand_x=0.5, expand_y=0.5):
return None


p = Predictor(model, torch.device('cpu'), CelebAMaskHQCategoriesAndAttributes)


def predict_frame(head_frame, torso_frame, full_frame, head_mask, torso_mask,):
def predict_frame(head_frame, torso_frame, full_frame, head_mask, torso_mask, predictor):
full_frame = cv2.cvtColor(full_frame, cv2.COLOR_BGR2RGB)
head_frame = cv2.cvtColor(head_frame, cv2.COLOR_BGR2RGB)
torso_frame = cv2.cvtColor(torso_frame, cv2.COLOR_BGR2RGB)

head_frame = pad_image_to_even_dims(head_frame)
torso_frame = pad_image_to_even_dims(torso_frame)

rst = ImageOfPerson.from_parent_instance(p.predict(head_frame))
rst = ImageOfPerson.from_parent_instance(predictor.predict(head_frame))

return rst.describe()

0 comments on commit 689fd0c

Please sign in to comment.