We provide blazingly fast goal-conditioned environments based on MJX and BRAX for quick experimentation with goal-conditioned self-supervised reinforcement learning.
- Blazing Fast Training - Train 10 million environment steps in 10 minutes on a single GPU, up to 22$\times$ faster than prior implementations.
- Comprehensive Benchmarking - Includes 10+ diverse environments and multiple pre-implemented baselines for out-of-the-box evaluation.
- Modular Implementation - Designed for clarity and scalability, allowing for easy modification of algorithms.
The entire process of installing the benchmark is just one step using the conda `environment.yml` file:

```bash
conda env create -f environment.yml
```
To check whether the installation worked, run a test experiment using the `./scripts/train.sh` file:

```bash
chmod +x ./scripts/train.sh; ./scripts/train.sh
```
Note: If you haven't configured `wandb` yet, you might be prompted to log in.
To run experiments of interest, change `scripts/train.sh`; descriptions of the flags are in `utils.py:create_parser()`. Common flags you may want to change (a programmatic example follows the list):
- `env=...`: replace "ant" with any environment name; see `utils.py:create_env()` for the names.
- Removing `--log_wandb`: omits logging, if you don't want to use a wandb account.
- `--num_timesteps`: shorter or longer runs.
- `--num_envs`: based on how many environments your GPU memory allows.
- `--contrastive_loss_fn`, `--energy_fn`, `--h_dim`, `--n_hidden`, etc.: algorithmic and architectural changes.
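If you prefer setting these flags from Python rather than editing the shell script, they can be parsed programmatically. A minimal sketch, assuming `create_parser()` returns a standard `argparse.ArgumentParser` and that the flag spellings below mirror those in `scripts/train.sh` (they are illustrative, not verified):

```python
from utils import create_parser

# Assumption: create_parser() builds a standard argparse.ArgumentParser whose
# flags mirror those used in scripts/train.sh (flag spellings are illustrative).
parser = create_parser()
args = parser.parse_args([
    "--env", "ant",                 # any name accepted by utils.py:create_env()
    "--num_timesteps", "10000000",  # length of the run
    "--num_envs", "1024",           # scale to your GPU memory
])
print(args)
```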
This section demonstrates how to interact with the environment using the `reset` and `step` functions. The environment returns a `state` object, which is a dataclass containing the following fields:
- `state.pipeline_state`: current internal state of the environment
- `state.obs`: current observation
- `state.done`: flag indicating whether the agent reached the goal
- `state.metrics`: agent performance metrics
- `state.info`: additional info
The following code demonstrates how to interact with the environment:
```python
import jax
from utils import create_env

key = jax.random.PRNGKey(0)

# Initialize the environment
env = create_env('ant')

# Use JIT compilation to make the environment's reset and step functions execute faster
jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)

NUM_STEPS = 1000

# Reset the environment and obtain the initial state
state = jit_env_reset(key)

# Simulate the environment for a fixed number of steps
for _ in range(NUM_STEPS):
    # Generate a random action
    key, key_act = jax.random.split(key, 2)
    random_action = jax.random.uniform(key_act, shape=(8,), minval=-1, maxval=1)

    # Perform an environment step with the generated action
    state = jit_env_step(state, random_action)
```
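Because the environments are pure JAX functions, many copies can also be stepped in parallel on a single device with `jax.vmap`, which is how large `--num_envs` values are realized in practice. A minimal sketch, assuming the same Brax-style `reset`/`step` API as above (the batch size and the 8-dimensional Ant action space are illustrative):

```python
import jax
from utils import create_env

NUM_ENVS = 64
ACTION_DIM = 8  # Ant's action dimension

env = create_env('ant')

# Vectorize reset/step over a batch of environments, then JIT-compile the result
batched_reset = jax.jit(jax.vmap(env.reset))
batched_step = jax.jit(jax.vmap(env.step))

key = jax.random.PRNGKey(0)
key, reset_key = jax.random.split(key)
states = batched_reset(jax.random.split(reset_key, NUM_ENVS))

for _ in range(100):
    key, act_key = jax.random.split(key)
    actions = jax.random.uniform(act_key, shape=(NUM_ENVS, ACTION_DIM), minval=-1, maxval=1)
    states = batched_step(states, actions)
```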
We highly recommend using Wandb for tracking and visualizing your results (Wandb support). Enable Wandb logging with the `--log_wandb` flag. Additionally, you can organize experiments with the following flags:
- `--project_name`
- `--group_name`
- `--exp_name`
Logging to W&B happens only when the `--log_wandb` flag is used; when it's not used, metrics are logged to a CSV file.
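As an illustration of how such a dispatch can look (a hypothetical sketch, not the repository's actual logging code; it assumes only the flag names listed above and the public `wandb` API):

```python
import csv
import wandb

def make_logger(args):
    """Log metrics to W&B when --log_wandb is set, otherwise to a CSV file."""
    if args.log_wandb:
        run = wandb.init(project=args.project_name, group=args.group_name, name=args.exp_name)
        return lambda metrics, step: run.log(metrics, step=step)

    csv_file = open("metrics.csv", "w", newline="")
    writer = None

    def log_csv(metrics, step):
        nonlocal writer
        if writer is None:  # write the header on the first call
            writer = csv.DictWriter(csv_file, fieldnames=["step", *metrics])
            writer.writeheader()
        writer.writerow({"step": step, **metrics})
        csv_file.flush()

    return log_csv
```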
- Run an exemplary sweep:

  ```bash
  wandb sweep --project exemplary_sweep ./scripts/sweep.yml
  ```

- Then run the wandb agent with the output of the previous command:

  ```bash
  wandb agent <previous_command_output>
  ```
Besides logging the metrics, we also render the final policy to `wandb` artifacts.
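For reference, logging a rendered rollout yourself follows the same pattern; a minimal sketch using the public `wandb` API (the `frames` array below is a placeholder for frames rendered from the final policy, not the repository's actual rendering code):

```python
import numpy as np
import wandb

# Placeholder: (T, H, W, 3) uint8 frames rendered from the final policy rollout
frames = np.zeros((100, 64, 64, 3), dtype=np.uint8)

run = wandb.init(project="exemplary_project")
# wandb.Video expects (T, C, H, W), so move the channel axis forward
run.log({"final_policy": wandb.Video(np.transpose(frames, (0, 3, 1, 2)), fps=30, format="gif")})
```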
We currently support a number of continuous control environments:
- Locomotion: Half-Cheetah, Ant, Humanoid
- Locomotion + task: AntMaze, AntBall (AntSoccer), AntPush, HumanoidMaze
- Simple arm: Reacher, Pusher, Pusher 2-object
- Manipulation: Reach, Grasp, Push (easy/hard), Binpick (easy/hard)
| Environment | Env name | Code |
| --- | --- | --- |
| Reacher | `reacher` | link |
| Half Cheetah | `cheetah` | link |
| Pusher | `pusher_easy`, `pusher_hard` | link |
| Ant | `ant` | link |
| Ant Maze | `ant_u_maze`, `ant_big_maze`, `ant_hardest_maze` | link |
| Ant Soccer | `ant_ball` | link |
| Ant Push | `ant_push` | link |
| Humanoid | `humanoid` | link |
| Humanoid Maze | `humanoid_u_maze`, `humanoid_big_maze`, `humanoid_hardest_maze` | link |
| Arm Reach | `arm_reach` | link |
| Arm Grasp | `arm_grasp` | link |
| Arm Push | `arm_push_easy`, `arm_push_hard` | link |
| Arm Binpick | `arm_binpick_easy`, `arm_binpick_hard` | link |
To add new environments: add an XML to `envs/assets`, add a Python environment file in `envs`, and register the environment name in `utils.py`.
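As a rough sketch of what such an environment file might look like, assuming the environments follow Brax's `PipelineEnv` pattern with an MJX backend (the XML path, observation, and reward/done logic below are placeholders, not an actual environment from this repository):

```python
import jax
from brax.envs.base import PipelineEnv, State
from brax.io import mjcf
from jax import numpy as jp


class MyNewEnv(PipelineEnv):
    """Skeleton environment; register its name in utils.py:create_env() to use it."""

    def __init__(self, backend="mjx", **kwargs):
        sys = mjcf.load("envs/assets/my_new_env.xml")  # placeholder XML in envs/assets
        super().__init__(sys=sys, backend=backend, **kwargs)

    def reset(self, rng: jax.Array) -> State:
        pipeline_state = self.pipeline_init(self.sys.init_q, jp.zeros(self.sys.qd_size()))
        obs = self._get_obs(pipeline_state)
        return State(pipeline_state, obs, reward=jp.zeros(()), done=jp.zeros(()), metrics={})

    def step(self, state: State, action: jax.Array) -> State:
        pipeline_state = self.pipeline_step(state.pipeline_state, action)
        obs = self._get_obs(pipeline_state)
        reward, done = jp.zeros(()), jp.zeros(())  # placeholder goal-reaching logic
        return state.replace(pipeline_state=pipeline_state, obs=obs, reward=reward, done=done)

    def _get_obs(self, pipeline_state) -> jax.Array:
        # Placeholder observation: joint positions and velocities
        return jp.concatenate([pipeline_state.q, pipeline_state.qd])
```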
We currently support the following algorithms:
| Algorithm | How to run | Code |
| --- | --- | --- |
| CRL | `python training.py ...` | link |
| SAC | `python training_sac.py ...` | link |
| SAC + HER | `python training_sac.py ... --use_her` | link |
| TD3 | `python training_td3.py ...` | link |
| TD3 + HER | `python training_td3.py ... --use_her` | link |
| PPO | `python training_ppo.py ...` | link |
We summarize the most important elements of the code structure, for users wanting to understand the implementation specifics or modify the code:
```
├── src: Algorithm code (training, network, replay buffer, etc.)
│   ├── train.py: Main file. Collects trajectories, trains networks, runs evaluations.
│   ├── losses.py: Contains energy functions, and actor, critic, and alpha losses.
│   ├── networks.py: Contains network definitions for the policy, and encoders for the critic.
│   ├── replay_buffer.py: Contains the replay buffer, including logic for state, action, and goal sampling for training.
│   └── evaluator.py: Runs evaluation and collects metrics.
├── envs: Environments (Python files and XMLs)
│   ├── ant.py, humanoid.py, ...: Most environments are here
│   ├── assets: Contains XMLs for environments
│   └── manipulation: Contains all manipulation environments
├── scripts/train.sh: Modify to choose environment and hyperparameters
├── utils.py: Logic for script argument processing, rendering, environment names, etc.
└── training.py: Interface file that processes script arguments, calls train.py, initializes wandb, etc.
```
To modify the network architecture, edit `networks.py`.
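For instance, a new encoder or policy trunk would typically be a small Flax module. A minimal sketch, assuming the networks are written with `flax.linen` and mirroring the `--h_dim` / `--n_hidden` flags (this module is illustrative and not part of the existing `networks.py`):

```python
import flax.linen as nn
import jax.numpy as jnp


class MLPEncoder(nn.Module):
    """Illustrative MLP encoder mirroring the --h_dim / --n_hidden flags."""
    h_dim: int = 256
    n_hidden: int = 2
    repr_dim: int = 64

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        for _ in range(self.n_hidden):
            x = nn.relu(nn.Dense(self.h_dim)(x))
        return nn.Dense(self.repr_dim)(x)
```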
Help us build JaxGCRL into the best possible tool for the GCRL community. Reach out and start contributing or just add an Issue/PR!
- Add Franka robot arm environments. [Done by SimpleGeometry]
- Get around 70% success rate on Ant Big Maze task. [Done by RajGhugare19]
- Add more complex versions of Ant Sokoban.
- Integrate environments:
- Overcooked
- Hanabi
- Rubik's cube
- Sokoban
```bibtex
@article{bortkiewicz2024accelerating,
  title   = {Accelerating Goal-Conditioned RL Algorithms and Research},
  author  = {Michał Bortkiewicz and Władek Pałucki and Vivek Myers and Tadeusz Dziarmaga and Tomasz Arczewski and Łukasz Kuciński and Benjamin Eysenbach},
  year    = {2024},
  journal = {arXiv preprint arXiv:2408.11052}
}
```
If you have any questions, comments, or suggestions, please reach out to Michał Bortkiewicz (michalbortkiewicz8@gmail.com).
A number of other libraries inspired this work; we encourage you to take a look!
JAX-native algorithms:
- Mava: JAX implementations of IPPO and MAPPO, two popular MARL algorithms.
- PureJaxRL: JAX implementation of PPO, and demonstration of end-to-end JAX-based RL training.
- Minimax: JAX implementations of autocurricula baselines for RL.
- JaxIRL: JAX implementation of algorithms for inverse reinforcement learning.
JAX-native environments:
- Gymnax: Implementations of classic RL tasks including classic control, bsuite and MinAtar.
- Jumanji: A diverse set of environments ranging from simple games to NP-hard combinatorial problems.
- Pgx: JAX implementations of classic board games, such as Chess, Go and Shogi.
- Brax: A fully differentiable physics engine written in JAX, features continuous control tasks.
- XLand-MiniGrid: Meta-RL gridworld environments inspired by XLand and MiniGrid.
- Craftax: (Crafter + NetHack) in JAX.
- JaxMARL: Multi-agent RL in Jax.