diff --git a/jsk_perception/node_scripts/solidity_rag_merge.py b/jsk_perception/node_scripts/solidity_rag_merge.py index ba5e2a2b38..d72ef3f07f 100755 --- a/jsk_perception/node_scripts/solidity_rag_merge.py +++ b/jsk_perception/node_scripts/solidity_rag_merge.py @@ -165,39 +165,63 @@ def __init__(self): queue_size=5) def subscribe(self): - self.sub = message_filters.Subscriber('~input', Image) - self.sub_mask = message_filters.Subscriber('~input/mask', Image) - self.use_async = rospy.get_param('~approximate_sync', False) - rospy.loginfo('~approximate_sync: {}'.format(self.use_async)) - if self.use_async: - sync = message_filters.ApproximateTimeSynchronizer( - [self.sub, self.sub_mask], queue_size=1000, slop=0.1) + self.use_mask = rospy.get_param('~use_mask', True) + if self.use_mask: + self.sub = message_filters.Subscriber('~input', Image) + self.sub_mask = message_filters.Subscriber('~input/mask', Image) + self.use_async = rospy.get_param('~approximate_sync', False) + rospy.loginfo('~approximate_sync: {}'.format(self.use_async)) + if self.use_async: + sync = message_filters.ApproximateTimeSynchronizer( + [self.sub, self.sub_mask], queue_size=1000, slop=0.1) + else: + sync = message_filters.TimeSynchronizer( + [self.sub, self.sub_mask], queue_size=1000) + sync.registerCallback(self.sub_cb) + warn_no_remap('~input', '~input/mask') else: - sync = message_filters.TimeSynchronizer( - [self.sub, self.sub_mask], queue_size=1000) - sync.registerCallback(self._apply) - warn_no_remap('~input', '~input/mask') + self.sub = rospy.Subscriber('~input', Image, self.sub_img_cb, + queue_size=1) + warn_no_remap('~input') def unsubscribe(self): self.sub.unregister() - self.sub_mask.unregister() + if self.use_mask: + self.sub_mask.unregister() + + def sub_cb(self, imgmsg, maskmsg): + self._apply(imgmsg, maskmsg) + + def sub_img_cb(self, imgmsg): + self._apply(imgmsg, None) - def _apply(self, imgmsg, maskmsg): + def _apply(self, imgmsg, maskmsg=None): bridge = cv_bridge.CvBridge() img = bridge.imgmsg_to_cv2(imgmsg) if img.ndim == 2: img = gray2rgb(img) - mask = bridge.imgmsg_to_cv2(maskmsg, desired_encoding='mono8') - mask = mask.reshape(mask.shape[:2]) - # compute label - roi = closed_mask_roi(mask) - roi_labels = masked_slic(img=img[roi], mask=mask[roi], + + if maskmsg is None: + # compute label + mask = np.ones(img.shape[:2], dtype=np.uint8) + labels = masked_slic(img=img, mask=mask, n_segments=20, compactness=30) - if roi_labels is None: - return - labels = np.zeros(mask.shape, dtype=np.int32) + if labels is None: + return + labels = labels.astype(np.int32) + else: + mask = bridge.imgmsg_to_cv2(maskmsg, desired_encoding='mono8') + mask = mask.reshape(mask.shape[:2]) + # compute label + roi = closed_mask_roi(mask) + roi_labels = masked_slic(img=img[roi], mask=mask[roi], + n_segments=20, compactness=30) + if roi_labels is None: + return + labels = np.zeros(mask.shape, dtype=np.int32) + labels[roi] = roi_labels + # labels.fill(-1) # set bg_label - labels[roi] = roi_labels if self.is_debugging: # publish debug slic label slic_labelmsg = bridge.cv2_to_imgmsg(labels)