Fixes handling of time-out signal in RSL-RL and RL-Games wrapper (#375)
# Description

On termination of an episode, three conditions arise:

1. **bad** terminations (terminated dones): the agent gets a termination penalty
2. **timeout** terminations (truncated dones):
    * infinite-horizon: bootstrapping by the agent based on the terminal state
    * finite-horizon: no penalty or bootstrapping

Currently, we have not handled the last case, which leads to issues when
training RL tasks with a finite horizon (for instance, Nikita's agile
locomotion work).

This MR adds a flag to the `RLTaskEnvCfg` called `is_finite_horizon` that handles this case.
The flag is consumed by the env wrappers to decide how to treat the finite-horizon problem
(see the sketch below).
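
A minimal sketch of the wrapper-side gating that the diffs below implement (the helper name here is illustrative, not part of the codebase):

```python
import torch


def forward_time_outs(extras: dict, truncated: torch.Tensor, is_finite_horizon: bool) -> dict:
    """Forward the truncated (time-out) signal to the agent only for infinite-horizon tasks."""
    if not is_finite_horizon:
        # infinite horizon: the "time_outs" entry lets the agent bootstrap the value
        # of the state in which the episode was truncated
        extras["time_outs"] = truncated
    # finite horizon: no "time_outs" entry, so no bootstrapping is triggered
    return extras
```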

## Type of change

- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have run all the tests with `./orbit.sh --test` and they pass
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
Mayankm96 authored Feb 7, 2024
1 parent 93ec2c6 commit af4e801
Showing 6 changed files with 76 additions and 3 deletions.
@@ -21,6 +21,28 @@ class RLTaskEnvCfg(BaseEnvCfg):
ui_window_class_type: type | None = RLTaskEnvWindow

# general settings
is_finite_horizon: bool = False
"""Whether the learning task is treated as a finite or infinite horizon problem for the agent.
Defaults to False, which means the task is treated as an infinite horizon problem.
This flag handles the subtleties of finite and infinite horizon tasks:
* **Finite horizon**: no penalty or bootstrapping value is required by the agent for
running out of time. However, the environment still needs to terminate the episode after the
time limit is reached.
* **Infinite horizon**: the agent needs to bootstrap the value of the state at the end of the episode.
This is done by sending a time-limit (or truncated) done signal to the agent, which triggers this
bootstrapping calculation.
If True, then the environment is treated as a finite horizon problem and no time-out (or truncated) done signal
is sent to the agent. If False, then the environment is treated as an infinite horizon problem and a time-out
(or truncated) done signal is sent to the agent.
Note:
The base :class:`RLTaskEnv` class does not use this flag directly. It is used by the environment
wrappers to determine what type of done signal to send to the corresponding learning agent.
"""

episode_length_s: float = MISSING
"""Duration of an episode (in seconds)."""

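As a usage sketch (the task id below is just an example; `parse_env_cfg` is called the same way in the test added at the bottom of this diff, though its import path is assumed here), a finite-horizon task enables the flag on its parsed configuration before creating the environment:

```python
import gymnasium as gym

from omni.isaac.orbit.envs import RLTaskEnvCfg
from omni.isaac.orbit_tasks.utils import parse_env_cfg  # assumed import path

# parse the registered task's configuration ("Isaac-Cartpole-v0" is only an example id)
env_cfg: RLTaskEnvCfg = parse_env_cfg("Isaac-Cartpole-v0", use_gpu=True, num_envs=16)
# treat the task as finite horizon: the wrappers will not send a time-out signal to the agent
env_cfg.is_finite_horizon = True

env = gym.make("Isaac-Cartpole-v0", cfg=env_cfg)
```
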
@@ -1,7 +1,7 @@
[package]

# Note: Semantic Versioning is used: https://semver.org/
version = "0.5.3"
version = "0.5.4"

# Description
title = "ORBIT Environments"
12 changes: 12 additions & 0 deletions source/extensions/omni.isaac.orbit_tasks/docs/CHANGELOG.rst
@@ -1,6 +1,18 @@
Changelog
---------

0.5.4 (2024-02-06)
~~~~~~~~~~~~~~~~~~

Added
^^^^^

* Added a check for the flag :attr:`omni.isaac.orbit.envs.RLTaskEnvCfg.is_finite_horizon`
  in the RSL-RL and RL-Games wrappers to handle finite horizon tasks properly. Earlier,
  the wrappers always assumed the tasks to be infinite horizon and returned a
  time-out signal when the episode length was reached.


0.5.3 (2023-11-16)
~~~~~~~~~~~~~~~~~~

@@ -240,8 +240,10 @@ def step(self, actions): # noqa: D102
obs_dict, rew, terminated, truncated, extras = self.env.step(actions)

# move time out information to the extras dict
+# this is only needed for infinite horizon tasks
# note: only useful when `value_bootstrap` is True in the agent configuration
-extras["time_outs"] = truncated.to(device=self._rl_device)
+if not self.unwrapped.cfg.is_finite_horizon:
+    extras["time_outs"] = truncated.to(device=self._rl_device)
# process observations and states
obs_and_states = self._process_obs(obs_dict)
# move buffers to rl-device
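
For context, the bootstrapping that the `time_outs` signal triggers on the agent side looks roughly like the following (illustrative only, not the exact RL-Games implementation; `gamma` is an assumed discount factor):

```python
import torch


def bootstrap_timeout_rewards(
    rewards: torch.Tensor, values: torch.Tensor, time_outs: torch.Tensor, gamma: float = 0.99
) -> torch.Tensor:
    """Add the discounted critic value of the terminal state to rewards of timed-out steps."""
    return rewards + gamma * values * time_outs.float()
```

In a finite-horizon task this term must not be added, which is why the wrapper above omits the `time_outs` entry when `is_finite_horizon` is True.
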
@@ -165,7 +165,9 @@ def step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch
obs = obs_dict["policy"]
extras["observations"] = obs_dict
# move time out information to the extras dict
extras["time_outs"] = truncated
# this is only needed for infinite horizon tasks
if not self.unwrapped.cfg.is_finite_horizon:
extras["time_outs"] = truncated

# return the step information
return obs, rew, dones, extras
@@ -92,6 +92,41 @@ def test_random_actions(self):
print(f">>> Closing environment: {task_name}")
env.close()

def test_no_time_outs(self):
    """Check that environments with finite horizon do not send time-out signals."""
    for task_name in self.registered_tasks[0:5]:
        print(f">>> Running test for environment: {task_name}")
        # create a new stage
        omni.usd.get_context().new_stage()
        # parse configuration
        env_cfg: RLTaskEnvCfg = parse_env_cfg(task_name, use_gpu=self.use_gpu, num_envs=self.num_envs)
        # change to finite horizon
        env_cfg.is_finite_horizon = True

        # create environment
        env = gym.make(task_name, cfg=env_cfg)
        # wrap environment
        env = RslRlVecEnvWrapper(env)

        # reset environment
        _, extras = env.reset()
        # check signal
        self.assertNotIn("time_outs", extras, msg="Time-out signal found in finite horizon environment.")

        # simulate environment for 10 steps
        with torch.inference_mode():
            for _ in range(10):
                # sample actions from -1 to 1
                actions = 2 * torch.rand(env.action_space.shape, device=env.unwrapped.device) - 1
                # apply actions
                extras = env.step(actions)[-1]
                # check signals
                self.assertNotIn("time_outs", extras, msg="Time-out signal found in finite horizon environment.")

        # close the environment
        print(f">>> Closing environment: {task_name}")
        env.close()

"""
Helper functions.
"""
