Small fix on simxarm
Cadene committed Mar 11, 2024
1 parent c94bef6 commit d6279da
Showing 1 changed file with 16 additions and 3 deletions.
lerobot/common/envs/simxarm.py: 19 changes (16 additions & 3 deletions)
@@ -168,23 +168,36 @@ def _step(self, tensordict: TensorDict):
     def _make_spec(self):
         obs = {}
         if self.from_pixels:
+            image_shape = (3, self.image_size, self.image_size)
+            if self.num_prev_obs > 0:
+                image_shape = (self.num_prev_obs + 1, *image_shape)
+
             obs["image"] = BoundedTensorSpec(
                 low=0,
                 high=255,
-                shape=(3, self.image_size, self.image_size),
+                shape=image_shape,
                 dtype=torch.uint8,
                 device=self.device,
             )
             if not self.pixels_only:
+                state_shape = (len(self._env.robot_state),)
+                if self.num_prev_obs > 0:
+                    state_shape = (self.num_prev_obs + 1, *state_shape)
+
                 obs["state"] = UnboundedContinuousTensorSpec(
-                    shape=(len(self._env.robot_state),),
+                    shape=state_shape,
                     dtype=torch.float32,
                     device=self.device,
                 )
         else:
             # TODO(rcadene): add observation_space achieved_goal and desired_goal?
+            state_shape = self._env.observation_space["observation"].shape
+            if self.num_prev_obs > 0:
+                state_shape = (self.num_prev_obs + 1, *state_shape)
+
             obs["state"] = UnboundedContinuousTensorSpec(
-                shape=self._env.observation_space["observation"].shape,
+                # TODO:
+                shape=state_shape,
                 dtype=torch.float32,
                 device=self.device,
             )
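The change prepends a frame-stacking dimension of size num_prev_obs + 1 to both the image and state observation shapes before the specs are built. Below is a minimal standalone sketch of the same shape logic; it assumes a torchrl version exposing BoundedTensorSpec and UnboundedContinuousTensorSpec (the classes used in the diff), and the concrete values for image_size, num_prev_obs, and the robot-state length are illustrative placeholders, not values from the actual environment.

# Minimal sketch of the frame-stacking shape logic introduced above.
# Assumes a torchrl version exposing BoundedTensorSpec / UnboundedContinuousTensorSpec;
# image_size, num_prev_obs and robot_state_dim are illustrative placeholders.
import torch
from torchrl.data import BoundedTensorSpec, UnboundedContinuousTensorSpec

image_size = 84
num_prev_obs = 3       # number of stacked previous observations
robot_state_dim = 4    # stands in for len(self._env.robot_state)

image_shape = (3, image_size, image_size)
state_shape = (robot_state_dim,)
if num_prev_obs > 0:
    # Prepend a time dimension holding the current frame plus num_prev_obs past frames.
    image_shape = (num_prev_obs + 1, *image_shape)
    state_shape = (num_prev_obs + 1, *state_shape)

image_spec = BoundedTensorSpec(low=0, high=255, shape=image_shape, dtype=torch.uint8)
state_spec = UnboundedContinuousTensorSpec(shape=state_shape, dtype=torch.float32)

print(image_spec.shape)  # torch.Size([4, 3, 84, 84])
print(state_spec.shape)  # torch.Size([4, 4])

With num_prev_obs = 0 the shapes stay (3, 84, 84) and (4,), matching the pre-change behaviour.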
