diff --git a/lerobot/common/envs/simxarm.py b/lerobot/common/envs/simxarm.py index 02ec115c7..d06126257 100644 --- a/lerobot/common/envs/simxarm.py +++ b/lerobot/common/envs/simxarm.py @@ -168,23 +168,36 @@ def _step(self, tensordict: TensorDict): def _make_spec(self): obs = {} if self.from_pixels: + image_shape = (3, self.image_size, self.image_size) + if self.num_prev_obs > 0: + image_shape = (self.num_prev_obs + 1, *image_shape) + obs["image"] = BoundedTensorSpec( low=0, high=255, - shape=(3, self.image_size, self.image_size), + shape=image_shape, dtype=torch.uint8, device=self.device, ) if not self.pixels_only: + state_shape = (len(self._env.robot_state),) + if self.num_prev_obs > 0: + state_shape = (self.num_prev_obs + 1, *state_shape) + obs["state"] = UnboundedContinuousTensorSpec( - shape=(len(self._env.robot_state),), + shape=state_shape, dtype=torch.float32, device=self.device, ) else: # TODO(rcadene): add observation_space achieved_goal and desired_goal? + state_shape = self._env.observation_space["observation"].shape + if self.num_prev_obs > 0: + state_shape = (self.num_prev_obs + 1, *state_shape) + obs["state"] = UnboundedContinuousTensorSpec( - shape=self._env.observation_space["observation"].shape, + # TODO: + shape=state_shape, dtype=torch.float32, device=self.device, )