diff --git a/source/extensions/omni.isaac.orbit/omni/isaac/orbit/envs/rl_task_env_cfg.py b/source/extensions/omni.isaac.orbit/omni/isaac/orbit/envs/rl_task_env_cfg.py
index 6646dda950..eb125f242a 100644
--- a/source/extensions/omni.isaac.orbit/omni/isaac/orbit/envs/rl_task_env_cfg.py
+++ b/source/extensions/omni.isaac.orbit/omni/isaac/orbit/envs/rl_task_env_cfg.py
@@ -21,6 +21,28 @@ class RLTaskEnvCfg(BaseEnvCfg):
     ui_window_class_type: type | None = RLTaskEnvWindow

     # general settings
+    is_finite_horizon: bool = False
+    """Whether the learning task is treated as a finite or infinite horizon problem for the agent.
+    Defaults to False, which means the task is treated as an infinite horizon problem.
+
+    This flag handles the subtleties of finite and infinite horizon tasks:
+
+    * **Finite horizon**: no penalty or bootstrapping value is required by the agent for
+      running out of time. However, the environment still needs to terminate the episode after the
+      time limit is reached.
+    * **Infinite horizon**: the agent needs to bootstrap the value of the state at the end of the episode.
+      This is done by sending a time-limit (or truncated) done signal to the agent, which triggers this
+      bootstrapping calculation.
+
+    If True, then the environment is treated as a finite horizon problem and no time-out (or truncated) done signal
+    is sent to the agent. If False, then the environment is treated as an infinite horizon problem and a time-out
+    (or truncated) done signal is sent to the agent.
+
+    Note:
+        The base :class:`RLTaskEnv` class does not use this flag directly. It is used by the environment
+        wrappers to determine what type of done signal to send to the corresponding learning agent.
+    """
+
     episode_length_s: float = MISSING
     """Duration of an episode (in seconds)."""

diff --git a/source/extensions/omni.isaac.orbit_tasks/config/extension.toml b/source/extensions/omni.isaac.orbit_tasks/config/extension.toml
index 0b45bac15f..fc93f61c2e 100644
--- a/source/extensions/omni.isaac.orbit_tasks/config/extension.toml
+++ b/source/extensions/omni.isaac.orbit_tasks/config/extension.toml
@@ -1,7 +1,7 @@
 [package]

 # Note: Semantic Versioning is used: https://semver.org/
-version = "0.5.3"
+version = "0.5.4"

 # Description
 title = "ORBIT Environments"

diff --git a/source/extensions/omni.isaac.orbit_tasks/docs/CHANGELOG.rst b/source/extensions/omni.isaac.orbit_tasks/docs/CHANGELOG.rst
index fecd7c9de2..1b923ddbf0 100644
--- a/source/extensions/omni.isaac.orbit_tasks/docs/CHANGELOG.rst
+++ b/source/extensions/omni.isaac.orbit_tasks/docs/CHANGELOG.rst
@@ -1,6 +1,18 @@
 Changelog
 ---------

+0.5.4 (2024-02-06)
+~~~~~~~~~~~~~~~~~~
+
+Added
+^^^^^
+
+* Added a check for the flag :attr:`omni.isaac.orbit.envs.RLTaskEnvCfg.is_finite_horizon`
+  in the RSL-RL and RL-Games wrappers to handle finite horizon tasks properly. Previously,
+  the wrappers always assumed the tasks to be infinite horizon and returned time-out
+  signals when the episode length was reached.
+
+
 0.5.3 (2023-11-16)
 ~~~~~~~~~~~~~~~~~~
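For context on the bootstrapping mentioned in the new `is_finite_horizon` docstring above: a minimal sketch (not part of this patch) of how an agent with value bootstrapping typically consumes the `time_outs` entry. All names here (`bootstrap_rewards`, `critic`) are illustrative, not Orbit or RSL-RL API:

```python
import torch


def bootstrap_rewards(
    critic: torch.nn.Module, obs: torch.Tensor, rewards: torch.Tensor, extras: dict, gamma: float = 0.99
) -> torch.Tensor:
    """Fold the discounted critic value back into the rewards on time-out terminations.

    For infinite horizon tasks, an episode cut by the time limit is not a true terminal
    state, so the return estimate continues with gamma * V(s). When `is_finite_horizon`
    is True, the wrappers omit "time_outs" and this correction is skipped entirely.
    """
    if "time_outs" in extras:
        # predicted state values, shape (num_envs,)
        values = critic(obs).squeeze(-1)
        # apply the correction only where the episode was truncated by the time limit
        rewards = rewards + gamma * values * extras["time_outs"].float()
    return rewards
```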
diff --git a/source/extensions/omni.isaac.orbit_tasks/omni/isaac/orbit_tasks/utils/wrappers/rl_games.py b/source/extensions/omni.isaac.orbit_tasks/omni/isaac/orbit_tasks/utils/wrappers/rl_games.py
index 31f1af9934..fe1d4a7a2d 100644
--- a/source/extensions/omni.isaac.orbit_tasks/omni/isaac/orbit_tasks/utils/wrappers/rl_games.py
+++ b/source/extensions/omni.isaac.orbit_tasks/omni/isaac/orbit_tasks/utils/wrappers/rl_games.py
@@ -240,8 +240,10 @@ def step(self, actions):  # noqa: D102
         obs_dict, rew, terminated, truncated, extras = self.env.step(actions)

         # move time out information to the extras dict
+        # this is only needed for infinite horizon tasks
         # note: only useful when `value_bootstrap` is True in the agent configuration
-        extras["time_outs"] = truncated.to(device=self._rl_device)
+        if not self.unwrapped.cfg.is_finite_horizon:
+            extras["time_outs"] = truncated.to(device=self._rl_device)
         # process observations and states
         obs_and_states = self._process_obs(obs_dict)
         # move buffers to rl-device
diff --git a/source/extensions/omni.isaac.orbit_tasks/omni/isaac/orbit_tasks/utils/wrappers/rsl_rl/vecenv_wrapper.py b/source/extensions/omni.isaac.orbit_tasks/omni/isaac/orbit_tasks/utils/wrappers/rsl_rl/vecenv_wrapper.py
index 7c963e28d4..a9955bb999 100644
--- a/source/extensions/omni.isaac.orbit_tasks/omni/isaac/orbit_tasks/utils/wrappers/rsl_rl/vecenv_wrapper.py
+++ b/source/extensions/omni.isaac.orbit_tasks/omni/isaac/orbit_tasks/utils/wrappers/rsl_rl/vecenv_wrapper.py
@@ -165,7 +165,9 @@ def step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch
         obs = obs_dict["policy"]
         extras["observations"] = obs_dict
         # move time out information to the extras dict
-        extras["time_outs"] = truncated
+        # this is only needed for infinite horizon tasks
+        if not self.unwrapped.cfg.is_finite_horizon:
+            extras["time_outs"] = truncated
         # return the step information
         return obs, rew, dones, extras
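Putting the two wrapper changes together, a minimal usage sketch assuming a registered task (the task name below is illustrative) and the helpers exercised by the test that follows:

```python
import gymnasium as gym
import torch

from omni.isaac.orbit.envs import RLTaskEnvCfg
from omni.isaac.orbit_tasks.utils import parse_env_cfg
from omni.isaac.orbit_tasks.utils.wrappers.rsl_rl import RslRlVecEnvWrapper

# parse the default configuration and opt into finite-horizon handling
env_cfg: RLTaskEnvCfg = parse_env_cfg("Isaac-Cartpole-v0", use_gpu=True, num_envs=16)
env_cfg.is_finite_horizon = True

# create and wrap the environment for rsl-rl
env = RslRlVecEnvWrapper(gym.make("Isaac-Cartpole-v0", cfg=env_cfg))

# reset and step; the wrapper should no longer report "time_outs"
with torch.inference_mode():
    obs, extras = env.reset()
    actions = torch.zeros(env.action_space.shape, device=env.unwrapped.device)
    obs, rew, dones, extras = env.step(actions)
assert "time_outs" not in extras
```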
horizon environment.") + + # close the environment + print(f">>> Closing environment: {task_name}") + env.close() + """ Helper functions. """