Skip to content

Commit

Permalink
Segmentation masks returned as 1 channel. Resolves #43.
Browse files Browse the repository at this point in the history
  • Loading branch information
stepjam committed Mar 9, 2020
1 parent fb9118d commit 9f3bf88
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 28 deletions.
17 changes: 14 additions & 3 deletions rlbench/backend/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from rlbench.backend.observation import Observation
from rlbench.backend.exceptions import (
WaypointError, BoundaryError, NoWaypointsError, DemoError)
from rlbench.backend.utils import rgb_handles_to_mask
from rlbench.demo import Demo
from rlbench.observation_config import ObservationConfig, CameraConfig
from rlbench.backend.task import Task
Expand Down Expand Up @@ -156,6 +157,13 @@ def get_observation(self) -> Observation:
rsc_ob = self._obs_config.right_shoulder_camera
wc_ob = self._obs_config.wrist_camera

lsc_mask_fn = (
rgb_handles_to_mask if lsc_ob.masks_as_one_channel else lambda x: x)
rsc_mask_fn = (
rgb_handles_to_mask if rsc_ob.masks_as_one_channel else lambda x: x)
wc_mask_fn = (
rgb_handles_to_mask if wc_ob.masks_as_one_channel else lambda x: x)

obs = Observation(
left_shoulder_rgb=(
lsc_ob.rgb_noise.apply(
Expand All @@ -181,13 +189,16 @@ def get_observation(self) -> Observation:
if wc_ob.depth else None),

left_shoulder_mask=(
self._cam_over_shoulder_left_mask.capture_rgb()
lsc_mask_fn(
self._cam_over_shoulder_left_mask.capture_rgb())
if lsc_ob.mask else None),
right_shoulder_mask=(
self._cam_over_shoulder_right_mask.capture_rgb()
rsc_mask_fn(
self._cam_over_shoulder_right_mask.capture_rgb())
if rsc_ob.mask else None),
wrist_mask=(
self._cam_wrist_mask.capture_rgb()
wc_mask_fn(
self._cam_wrist_mask.capture_rgb())
if wc_ob.mask else None),

joint_velocities=(
Expand Down
10 changes: 10 additions & 0 deletions rlbench/backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,3 +216,13 @@ def task_file_to_task_class(task_file):
mod = importlib.reload(mod)
task_class = getattr(mod, class_name)
return task_class


def rgb_handles_to_mask(rgb_coded_handles):
# rgb_coded_handles should be (w, h, c)
# Handle encoded as : handle = R + G * 256 + B * 256 * 256
rgb_coded_handles *= 255 # takes rgb range to 0 -> 255
rgb_coded_handles.astype(int)
return (rgb_coded_handles[:, :, 0] +
rgb_coded_handles[:, :, 1] * 256 +
rgb_coded_handles[:, :, 2] * 256 * 256)
4 changes: 3 additions & 1 deletion rlbench/observation_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@ def __init__(self,
depth_noise: NoiseModel=Identity(),
mask=True,
image_size=(128, 128),
render_mode=RenderMode.OPENGL3):
render_mode=RenderMode.OPENGL3,
masks_as_one_channel=True):
self.rgb = rgb
self.rgb_noise = rgb_noise
self.depth = depth
self.depth_noise = depth_noise
self.mask = mask
self.image_size = image_size
self.render_mode = render_mode
self.masks_as_one_channel = masks_as_one_channel

def set_all(self, value: bool):
self.rgb = value
Expand Down
54 changes: 30 additions & 24 deletions rlbench/task_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from rlbench.backend.scene import Scene
from rlbench.backend.task import Task
from rlbench.backend.const import *
from rlbench.backend.utils import image_to_float_array
from rlbench.backend.utils import image_to_float_array, rgb_handles_to_mask
from rlbench.backend.robot import Robot
import logging
from typing import List
Expand Down Expand Up @@ -374,48 +374,54 @@ def _get_stored_demos(self, amount: int, image_paths: bool) -> List[Demo]:
self._resize_if_needed(
Image.open(obs[i].left_shoulder_rgb),
obs_config.left_shoulder_camera.image_size))
if obs_config.right_shoulder_camera.rgb:
obs[i].right_shoulder_rgb = np.array(
self._resize_if_needed(Image.open(
obs[i].right_shoulder_rgb),
obs_config.right_shoulder_camera.image_size))
if obs_config.wrist_camera.rgb:
obs[i].wrist_rgb = np.array(
self._resize_if_needed(
Image.open(obs[i].wrist_rgb),
obs_config.wrist_camera.image_size))

if obs_config.left_shoulder_camera.depth:
obs[i].left_shoulder_depth = image_to_float_array(
self._resize_if_needed(
Image.open(obs[i].left_shoulder_depth),
obs_config.left_shoulder_camera.image_size),
DEPTH_SCALE)
if obs_config.left_shoulder_camera.mask:
obs[i].left_shoulder_mask = np.array(
self._resize_if_needed(Image.open(
obs[i].left_shoulder_mask),
obs_config.left_shoulder_camera.image_size))
if obs_config.right_shoulder_camera.rgb:
obs[i].right_shoulder_rgb = np.array(
self._resize_if_needed(Image.open(
obs[i].right_shoulder_rgb),
obs_config.right_shoulder_camera.image_size))
if obs_config.right_shoulder_camera.depth:
obs[i].right_shoulder_depth = image_to_float_array(
self._resize_if_needed(
Image.open(obs[i].right_shoulder_depth),
obs_config.right_shoulder_camera.image_size),
DEPTH_SCALE)
if obs_config.right_shoulder_camera.mask:
obs[i].right_shoulder_mask = np.array(
self._resize_if_needed(Image.open(
obs[i].right_shoulder_mask),
obs_config.right_shoulder_camera.image_size))
if obs_config.wrist_camera.rgb:
obs[i].wrist_rgb = np.array(
self._resize_if_needed(
Image.open(obs[i].wrist_rgb),
obs_config.wrist_camera.image_size))
if obs_config.wrist_camera.depth:
obs[i].wrist_depth = image_to_float_array(
self._resize_if_needed(
Image.open(obs[i].wrist_depth),
obs_config.wrist_camera.image_size), DEPTH_SCALE)
obs_config.wrist_camera.image_size),
DEPTH_SCALE)

# Masks are stored as coded RGB images.
# Here we transform them into 1 channel handles.
if obs_config.left_shoulder_camera.mask:
obs[i].left_shoulder_mask = rgb_handles_to_mask(
np.array(self._resize_if_needed(Image.open(
obs[i].left_shoulder_mask),
obs_config.left_shoulder_camera.image_size)))
if obs_config.right_shoulder_camera.mask:
obs[i].right_shoulder_mask = rgb_handles_to_mask(
np.array(self._resize_if_needed(Image.open(
obs[i].right_shoulder_mask),
obs_config.right_shoulder_camera.image_size)))
if obs_config.wrist_camera.mask:
obs[i].wrist_mask = np.array(
obs[i].wrist_mask = rgb_handles_to_mask(np.array(
self._resize_if_needed(Image.open(
obs[i].wrist_mask),
obs_config.wrist_camera.image_size))
obs_config.wrist_camera.image_size)))

demos.append(obs)
return demos

Expand Down
4 changes: 4 additions & 0 deletions tools/dataset_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ def run(i, lock, task_index, variation_count, results, file_lock, tasks):

obs_config = ObservationConfig()
obs_config.set_all(True)
# We want to save the masks as rgb encodings.
obs_config.left_shoulder_camera.masks_as_one_channel = False
obs_config.right_shoulder_camera.masks_as_one_channel = False
obs_config.wrist_camera.masks_as_one_channel = False

rlbench_env = Environment(
action_mode=ActionMode(),
Expand Down

0 comments on commit 9f3bf88

Please sign in to comment.