JaxGCRL

Accelerating Goal-Conditioned RL Algorithms and Research

We provide blazingly fast goal-conditioned environments based on MJX and Brax for quick experimentation with goal-conditioned self-supervised reinforcement learning.

  • Blazing Fast Training - Train 10 million environment steps in 10 minutes on a single GPU, up to 22$\times$ faster than prior implementations.
  • Comprehensive Benchmarking - Includes 10+ diverse environments and multiple pre-implemented baselines for out-of-the-box evaluation.
  • Modular Implementation - Designed for clarity and scalability, allowing for easy modification of algorithms.

Installation 📂

Installing the benchmark takes a single step, using the conda environment.yml file:

conda env create -f environment.yml
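Then activate the environment. We assume here that the name: field in environment.yml is jaxgcrl; check the file for the actual name:

conda activate jaxgcrl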

Quick Start 🚀

To check whether the installation worked, run a test experiment using the ./scripts/train.sh script:

chmod +x ./scripts/train.sh; ./scripts/train.sh

Note

If you haven't configured wandb yet, you may be prompted to log in.

To run experiments of interest, edit scripts/train.sh; descriptions of all flags are in utils.py:create_parser(). Common flags you may want to change (see the example command after this list):

  • env=...: replace "ant" with any environment name. See utils.py:create_env() for names.
  • Removing --log_wandb: disables wandb logging if you don't want to use a wandb account.
  • --num_timesteps: shorter or longer runs.
  • --num_envs: based on how many environments your GPU memory allows.
  • --contrastive_loss_fn, --energy_fn, --h_dim, --n_hidden, etc.: algorithmic and architectural changes.
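For example, a single run might look like the following. Treat this as a sketch: the exact flag names, including whether the environment flag is spelled --env, are defined in utils.py:create_parser().

python training.py --env ant --num_timesteps 1000000 --num_envs 512 --log_wandb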

Environment Interaction

This section demonstrates how to interact with the environment using the reset and step functions. The environment returns a state object, which is a dataclass containing the following fields:

state.pipeline_state: current, internal state of the environment
state.obs: current observation
state.done: flag indicating if the agent reached the goal
state.metrics: agent performance metrics
state.info: additional info

The following code demonstrates how to interact with the environment:

import jax
from utils import create_env

key = jax.random.PRNGKey(0)

# Initialize the environment
env = create_env('ant')

# Use JIT compilation to make the environment's reset and step functions execute faster
jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)

NUM_STEPS = 1000

# Reset the environment and obtain the initial state
state = jit_env_reset(key)

# Simulate the environment for a fixed number of steps
for _ in range(NUM_STEPS):
    # Generate a random action
    key, key_act = jax.random.split(key, 2)
    random_action = jax.random.uniform(key_act, shape=(env.action_size,), minval=-1, maxval=1)
    
    # Perform an environment step with the generated action
    state = jit_env_step(state, random_action)
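Because the environments are pure JAX functions, they can also be batched with jax.vmap, which is how large --num_envs values are exploited. A minimal sketch, assuming the environment follows the standard Brax functional API:

import jax
from utils import create_env

env = create_env('ant')

# vmap maps reset/step over a leading batch axis; jit compiles the batched functions
batched_reset = jax.jit(jax.vmap(env.reset))
batched_step = jax.jit(jax.vmap(env.step))

NUM_ENVS = 1024
keys = jax.random.split(jax.random.PRNGKey(0), NUM_ENVS)
states = batched_reset(keys)  # a batch of initial states

actions = jax.random.uniform(
    jax.random.PRNGKey(1), shape=(NUM_ENVS, env.action_size), minval=-1, maxval=1
)
states = batched_step(states, actions)  # step all environments in parallel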

Wandb support 📈

We highly recommend using Wandb for tracking and visualizing your results. Enable Wandb logging with the --log_wandb flag. Additionally, you can organize experiments with the following flags:

  • --project_name
  • --group_name
  • --exp_name

Logging to W&B happens only when the --log_wandb flag is set; when it's not, metrics are logged to a CSV file.

  1. Run an exemplary sweep:
wandb sweep --project exemplary_sweep ./scripts/sweep.yml
  2. Then run the wandb agent with the command printed by the previous step:
wandb agent <previous_command_output>

Besides logging the metrics, we also render the final policy and save it to wandb artifacts.
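To retrieve a logged artifact programmatically, you can use the standard wandb API; the artifact path below is hypothetical, so substitute your own entity, project, and artifact name:

import wandb

api = wandb.Api()
artifact = api.artifact('my-entity/my-project/final-policy:latest')  # hypothetical path
artifact_dir = artifact.download()  # downloads the artifact files locally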

Environments 🌎

We currently support a number of continuous control environments:

  • Locomotion: Half-Cheetah, Ant, Humanoid
  • Locomotion + task: AntMaze, AntBall (AntSoccer), AntPush, HumanoidMaze
  • Simple arm: Reacher, Pusher, Pusher 2-object
  • Manipulation: Reach, Grasp, Push (easy/hard), Binpick (easy/hard)

| Environment | Env name | Code |
| --- | --- | --- |
| Reacher | reacher | link |
| Half Cheetah | cheetah | link |
| Pusher | pusher_easy, pusher_hard | link |
| Ant | ant | link |
| Ant Maze | ant_u_maze, ant_big_maze, ant_hardest_maze | link |
| Ant Soccer | ant_ball | link |
| Ant Push | ant_push | link |
| Humanoid | humanoid | link |
| Humanoid Maze | humanoid_u_maze, humanoid_big_maze, humanoid_hardest_maze | link |
| Arm Reach | arm_reach | link |
| Arm Grasp | arm_grasp | link |
| Arm Push | arm_push_easy, arm_push_hard | link |
| Arm Binpick | arm_binpick_easy, arm_binpick_hard | link |

To add a new environment: add an XML file to envs/assets, add a Python environment file in envs, and register the environment name in utils.py. A hedged sketch of a new environment file follows.
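This sketch assumes the repo's environments follow the standard Brax PipelineEnv interface; MyEnv, my_env.xml, and the observation layout are hypothetical, so adapt them to the existing environments in envs/.

import jax
import jax.numpy as jnp
from brax.envs.base import PipelineEnv, State
from brax.io import mjcf

class MyEnv(PipelineEnv):
    def __init__(self, **kwargs):
        # Load the MJCF model added under envs/assets (hypothetical file name)
        sys = mjcf.load('envs/assets/my_env.xml')
        super().__init__(sys=sys, backend='mjx', **kwargs)

    def reset(self, rng: jax.Array) -> State:
        # Start from the model's initial configuration with zero velocities
        pipeline_state = self.pipeline_init(self.sys.init_q, jnp.zeros(self.sys.qd_size()))
        obs = self._get_obs(pipeline_state)
        reward, done = jnp.zeros(2)
        return State(pipeline_state, obs, reward, done)

    def step(self, state: State, action: jax.Array) -> State:
        pipeline_state = self.pipeline_step(state.pipeline_state, action)
        obs = self._get_obs(pipeline_state)
        # Reward/termination logic (e.g., distance to goal) goes here
        return state.replace(pipeline_state=pipeline_state, obs=obs)

    def _get_obs(self, pipeline_state) -> jax.Array:
        # Simple observation: joint positions and velocities
        return jnp.concatenate([pipeline_state.q, pipeline_state.qd])

Finally, register the name (e.g., 'my_env') in utils.py:create_env() so that create_env('my_env') returns an instance of MyEnv.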

Baselines 🤖

We currently support the following algorithms:

| Algorithm | How to run | Code |
| --- | --- | --- |
| CRL | python training.py ... | link |
| SAC | python training_sac.py ... | link |
| SAC + HER | python training_sac.py ... --use_her | link |
| TD3 | python training_td3.py ... | link |
| TD3 + HER | python training_td3.py ... --use_her | link |
| PPO | python training_ppo.py ... | link |

Code Structure 📝

We summarize the most important elements of the code structure for users who want to understand the implementation or modify the code:


β”œβ”€β”€ src: Algorithm code (training, network, replay buffer, etc.)
β”‚   β”œβ”€β”€ train.py: Main file. Collects trajectories, trains networks, runs evaluations.
β”‚   β”œβ”€β”€ losses.py: Contains energy functions, and actor, critic, and alpha losses.
β”‚   β”œβ”€β”€ networks.py: Contains network definitions for policy, and encoders for the critic.
β”‚   β”œβ”€β”€ replay_buffer.py: Contains replay buffer, including logic for state, action, and goal sampling for training.
β”‚   └── evaluator.py: Runs evaluation and collects metrics.
β”œβ”€β”€ envs: Environments (python files and XMLs)
β”‚   β”œβ”€β”€ ant.py, humanoid.py, ...: Most environments are here
β”‚   β”œβ”€β”€ assets: Contains XMLs for environments
β”‚   └── manipulation: Contains all manipulation environments
β”œβ”€β”€ scripts/train.sh: Modify to choose environment and hyperparameters
β”œβ”€β”€ utils.py: Logic for script argument processing, rendering, environment names, etc.
└── training.py: Interface file that processes script arguments, calls train.py, initializes wandb, etc.
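For orientation, losses.py pairs an energy function with a contrastive objective (cf. the --contrastive_loss_fn and --energy_fn flags). Below is a generic sketch of an InfoNCE-style critic loss with an L2 energy function; it is an illustration, not the repo's exact code:

import jax.numpy as jnp
import optax

def l2_energy(sa_repr, g_repr):
    # energy[i, j] = -||phi(s_i, a_i) - psi(g_j)||_2; higher means closer
    diff = sa_repr[:, None, :] - g_repr[None, :, :]
    return -jnp.sqrt(jnp.sum(diff ** 2, axis=-1) + 1e-8)

def infonce_critic_loss(sa_repr, g_repr):
    # (B, B) logits with positives on the diagonal: the i-th state-action
    # embedding should match the i-th goal embedding
    logits = l2_energy(sa_repr, g_repr)
    labels = jnp.arange(logits.shape[0])
    return jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits, labels))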

To modify the architecture, edit networks.py.
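For instance, if --h_dim and --n_hidden control MLP width and depth (as the names suggest), a swapped-in encoder could look like the following Flax module. The class and field names here are illustrative, not the repo's actual definitions:

import flax.linen as nn

class MLPEncoder(nn.Module):
    h_dim: int = 256    # hidden layer width, cf. --h_dim
    n_hidden: int = 2   # number of hidden layers, cf. --n_hidden
    repr_dim: int = 64  # output representation size (illustrative)

    @nn.compact
    def __call__(self, x):
        for _ in range(self.n_hidden):
            x = nn.Dense(self.h_dim)(x)
            x = nn.swish(x)
        return nn.Dense(self.repr_dim)(x)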

Contributing 🏗️

Help us build JaxGCRL into the best possible tool for the GCRL community. Reach out to start contributing, or just open an Issue/PR!

  • Add Franka robot arm environments. [Done by SimpleGeometry]
  • Get around 70% success rate on Ant Big Maze task. [Done by RajGhugare19]
  • Add more complex versions of Ant Sokoban.
  • Integrate environments:
    • Overcooked
    • Hanabi
    • Rubik's cube
    • Sokoban

Citing JaxGCRL 📜

If you use JaxGCRL in your work, please cite us as follows:
@article{bortkiewicz2024accelerating,
  title   = {Accelerating Goal-Conditioned RL Algorithms and Research},
  author  = {Michał Bortkiewicz and Władek Pałucki and Vivek Myers and Tadeusz Dziarmaga and Tomasz Arczewski and Łukasz Kuciński and Benjamin Eysenbach},
  year    = {2024},
  journal = {arXiv preprint arXiv:2408.11052}
}

Questions ❓

If you have any questions, comments, or suggestions, please reach out to Michał Bortkiewicz (michalbortkiewicz8@gmail.com).

See Also 🙌

A number of other libraries inspired this work; we encourage you to take a look!

JAX-native algorithms:

  • Mava: JAX implementations of IPPO and MAPPO, two popular MARL algorithms.
  • PureJaxRL: JAX implementation of PPO, and demonstration of end-to-end JAX-based RL training.
  • Minimax: JAX implementations of autocurricula baselines for RL.
  • JaxIRL: JAX implementation of algorithms for inverse reinforcement learning.

JAX-native environments:

  • Gymnax: Implementations of classic RL tasks including classic control, bsuite and MinAtar.
  • Jumanji: A diverse set of environments ranging from simple games to NP-hard combinatorial problems.
  • Pgx: JAX implementations of classic board games, such as Chess, Go and Shogi.
  • Brax: A fully differentiable physics engine written in JAX, features continuous control tasks.
  • XLand-MiniGrid: Meta-RL gridworld environments inspired by XLand and MiniGrid.
  • Craftax: (Crafter + NetHack) in JAX.
  • JaxMARL: Multi-agent RL in JAX.