Skip to content

Commit 9f3bf88

Browse files
stepjam (Stephen James)
authored and committed
Segmentation masks returned as 1 channel. Resolves #43.
1 parent fb9118d commit 9f3bf88

File tree

5 files changed

+61
-28
lines changed

5 files changed

+61
-28
lines changed

rlbench/backend/scene.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from rlbench.backend.observation import Observation
88
from rlbench.backend.exceptions import (
99
WaypointError, BoundaryError, NoWaypointsError, DemoError)
10+
from rlbench.backend.utils import rgb_handles_to_mask
1011
from rlbench.demo import Demo
1112
from rlbench.observation_config import ObservationConfig, CameraConfig
1213
from rlbench.backend.task import Task
@@ -156,6 +157,13 @@ def get_observation(self) -> Observation:
156157
rsc_ob = self._obs_config.right_shoulder_camera
157158
wc_ob = self._obs_config.wrist_camera
158159

160+
lsc_mask_fn = (
161+
rgb_handles_to_mask if lsc_ob.masks_as_one_channel else lambda x: x)
162+
rsc_mask_fn = (
163+
rgb_handles_to_mask if rsc_ob.masks_as_one_channel else lambda x: x)
164+
wc_mask_fn = (
165+
rgb_handles_to_mask if wc_ob.masks_as_one_channel else lambda x: x)
166+
159167
obs = Observation(
160168
left_shoulder_rgb=(
161169
lsc_ob.rgb_noise.apply(
@@ -181,13 +189,16 @@ def get_observation(self) -> Observation:
181189
if wc_ob.depth else None),
182190

183191
left_shoulder_mask=(
184-
self._cam_over_shoulder_left_mask.capture_rgb()
192+
lsc_mask_fn(
193+
self._cam_over_shoulder_left_mask.capture_rgb())
185194
if lsc_ob.mask else None),
186195
right_shoulder_mask=(
187-
self._cam_over_shoulder_right_mask.capture_rgb()
196+
rsc_mask_fn(
197+
self._cam_over_shoulder_right_mask.capture_rgb())
188198
if rsc_ob.mask else None),
189199
wrist_mask=(
190-
self._cam_wrist_mask.capture_rgb()
200+
wc_mask_fn(
201+
self._cam_wrist_mask.capture_rgb())
191202
if wc_ob.mask else None),
192203

193204
joint_velocities=(

rlbench/backend/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,3 +216,13 @@ def task_file_to_task_class(task_file):
216216
mod = importlib.reload(mod)
217217
task_class = getattr(mod, class_name)
218218
return task_class
219+
220+
221+
def rgb_handles_to_mask(rgb_coded_handles):
    """Decode an RGB-coded handle image into a single-channel handle mask.

    Each pixel's handle is encoded across the colour channels as
    ``handle = R + G * 256 + B * 256 * 256``.

    Args:
        rgb_coded_handles: Array indexed as (w, h, c) with at least three
            channels. Floating-point input is scaled by 255 before decoding
            (i.e. it is treated as normalised to 0 -> 1); integer input is
            treated as already holding 0 -> 255 channel values.

    Returns:
        A 2-D integer array with one decoded handle per pixel.

    The input array is never modified.
    """
    # Local import: this file's import block is not visible from here.
    import numpy as np

    channels = np.asarray(rgb_coded_handles)
    if np.issubdtype(channels.dtype, np.floating):
        # Takes rgb range to 0 -> 255. Rounding (rather than truncating)
        # guards against float error such as 50.999999 decoding to 50.
        channels = np.rint(channels * 255)
    # Widen before combining so that G * 256 and B * 256 * 256 cannot
    # overflow a narrow dtype such as uint8. astype returns a new array,
    # so the result must be kept and the caller's data is left untouched.
    channels = channels.astype(np.int64)
    return (channels[:, :, 0] +
            channels[:, :, 1] * 256 +
            channels[:, :, 2] * 256 * 256)

rlbench/observation_config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,16 @@ def __init__(self,
1010
depth_noise: NoiseModel=Identity(),
1111
mask=True,
1212
image_size=(128, 128),
13-
render_mode=RenderMode.OPENGL3):
13+
render_mode=RenderMode.OPENGL3,
14+
masks_as_one_channel=True):
1415
self.rgb = rgb
1516
self.rgb_noise = rgb_noise
1617
self.depth = depth
1718
self.depth_noise = depth_noise
1819
self.mask = mask
1920
self.image_size = image_size
2021
self.render_mode = render_mode
22+
self.masks_as_one_channel = masks_as_one_channel
2123

2224
def set_all(self, value: bool):
2325
self.rgb = value

rlbench/task_environment.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from rlbench.backend.scene import Scene
1010
from rlbench.backend.task import Task
1111
from rlbench.backend.const import *
12-
from rlbench.backend.utils import image_to_float_array
12+
from rlbench.backend.utils import image_to_float_array, rgb_handles_to_mask
1313
from rlbench.backend.robot import Robot
1414
import logging
1515
from typing import List
@@ -374,48 +374,54 @@ def _get_stored_demos(self, amount: int, image_paths: bool) -> List[Demo]:
374374
self._resize_if_needed(
375375
Image.open(obs[i].left_shoulder_rgb),
376376
obs_config.left_shoulder_camera.image_size))
377+
if obs_config.right_shoulder_camera.rgb:
378+
obs[i].right_shoulder_rgb = np.array(
379+
self._resize_if_needed(Image.open(
380+
obs[i].right_shoulder_rgb),
381+
obs_config.right_shoulder_camera.image_size))
382+
if obs_config.wrist_camera.rgb:
383+
obs[i].wrist_rgb = np.array(
384+
self._resize_if_needed(
385+
Image.open(obs[i].wrist_rgb),
386+
obs_config.wrist_camera.image_size))
387+
377388
if obs_config.left_shoulder_camera.depth:
378389
obs[i].left_shoulder_depth = image_to_float_array(
379390
self._resize_if_needed(
380391
Image.open(obs[i].left_shoulder_depth),
381392
obs_config.left_shoulder_camera.image_size),
382393
DEPTH_SCALE)
383-
if obs_config.left_shoulder_camera.mask:
384-
obs[i].left_shoulder_mask = np.array(
385-
self._resize_if_needed(Image.open(
386-
obs[i].left_shoulder_mask),
387-
obs_config.left_shoulder_camera.image_size))
388-
if obs_config.right_shoulder_camera.rgb:
389-
obs[i].right_shoulder_rgb = np.array(
390-
self._resize_if_needed(Image.open(
391-
obs[i].right_shoulder_rgb),
392-
obs_config.right_shoulder_camera.image_size))
393394
if obs_config.right_shoulder_camera.depth:
394395
obs[i].right_shoulder_depth = image_to_float_array(
395396
self._resize_if_needed(
396397
Image.open(obs[i].right_shoulder_depth),
397398
obs_config.right_shoulder_camera.image_size),
398399
DEPTH_SCALE)
399-
if obs_config.right_shoulder_camera.mask:
400-
obs[i].right_shoulder_mask = np.array(
401-
self._resize_if_needed(Image.open(
402-
obs[i].right_shoulder_mask),
403-
obs_config.right_shoulder_camera.image_size))
404-
if obs_config.wrist_camera.rgb:
405-
obs[i].wrist_rgb = np.array(
406-
self._resize_if_needed(
407-
Image.open(obs[i].wrist_rgb),
408-
obs_config.wrist_camera.image_size))
409400
if obs_config.wrist_camera.depth:
410401
obs[i].wrist_depth = image_to_float_array(
411402
self._resize_if_needed(
412403
Image.open(obs[i].wrist_depth),
413-
obs_config.wrist_camera.image_size), DEPTH_SCALE)
404+
obs_config.wrist_camera.image_size),
405+
DEPTH_SCALE)
406+
407+
# Masks are stored as coded RGB images.
408+
# Here we transform them into 1 channel handles.
409+
if obs_config.left_shoulder_camera.mask:
410+
obs[i].left_shoulder_mask = rgb_handles_to_mask(
411+
np.array(self._resize_if_needed(Image.open(
412+
obs[i].left_shoulder_mask),
413+
obs_config.left_shoulder_camera.image_size)))
414+
if obs_config.right_shoulder_camera.mask:
415+
obs[i].right_shoulder_mask = rgb_handles_to_mask(
416+
np.array(self._resize_if_needed(Image.open(
417+
obs[i].right_shoulder_mask),
418+
obs_config.right_shoulder_camera.image_size)))
414419
if obs_config.wrist_camera.mask:
415-
obs[i].wrist_mask = np.array(
420+
obs[i].wrist_mask = rgb_handles_to_mask(np.array(
416421
self._resize_if_needed(Image.open(
417422
obs[i].wrist_mask),
418-
obs_config.wrist_camera.image_size))
423+
obs_config.wrist_camera.image_size)))
424+
419425
demos.append(obs)
420426
return demos
421427

tools/dataset_generator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,10 @@ def run(i, lock, task_index, variation_count, results, file_lock, tasks):
123123

124124
obs_config = ObservationConfig()
125125
obs_config.set_all(True)
126+
# We want to save the masks as rgb encodings.
127+
obs_config.left_shoulder_camera.masks_as_one_channel = False
128+
obs_config.right_shoulder_camera.masks_as_one_channel = False
129+
obs_config.wrist_camera.masks_as_one_channel = False
126130

127131
rlbench_env = Environment(
128132
action_mode=ActionMode(),

0 commit comments

Comments
 (0)