This repository contains the official code implementation of the paper: Expert Proximity as Surrogate Rewards for Single Demonstration Imitation Learning
The code is based on @stanl1y's reinforcement learning framework, which is available at stanl1y/RL_framework.
Note: The "neighbor model" in the codebase refers to the transition discriminator in our paper.
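For intuition, the transition discriminator can be viewed as a binary classifier over state pairs. Below is a conceptual PyTorch sketch, not the repository's actual architecture; the layer sizes, and conditioning on (s, s') alone rather than on actions, are assumptions:

```python
import torch
import torch.nn as nn

class TransitionDiscriminator(nn.Module):
    """Conceptual sketch of a transition discriminator ("neighbor model"):
    a binary classifier scoring whether (s, s') is a valid one-step
    transition. Layer sizes are illustrative, not the codebase's."""

    def __init__(self, state_dim: int, hidden_dim: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 * state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),  # logit for "is (s, s') adjacent?"
        )

    def forward(self, s: torch.Tensor, s_next: torch.Tensor) -> torch.Tensor:
        return self.net(torch.cat([s, s_next], dim=-1))
```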
TDIL enables the agent to learn from a single demonstration and achieve expert-level performance. The following video shows the HalfCheetah-v3 environment. The left side is the expert demonstration, and the right side is the learned policy.
(Video: hc.mov)
Clone this repo with:

```sh
git clone https://github.com/stanl1y/tdil.git
cd tdil
```
Install Docker and nvidia-docker, and then run:

```sh
# assume the current directory is the root of this repository
docker build -t j3soon/tdil .
docker run --rm -it --gpus all --ipc=host -v $(pwd):/workspace j3soon/tdil
```

Then install the Python dependencies:

```sh
pip install -r requirements.txt
```
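After installation, you can sanity-check that the GPU is visible. A minimal check, assuming PyTorch is among the installed dependencies:

```python
import torch

# Sanity check that the GPU is visible inside the container
# (assumes PyTorch is installed via requirements.txt).
print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))
```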
The expert data for performing Imitation Learning (IL) is provided for reproducibility. It is stored in the `saved_expert_transition/` folder. The expert data is generated by a pretrained SAC agent and is stored in dictionary format with the following keys: `"states"`, `"actions"`, `"rewards"`, `"next_states"`, and `"dones"`. The value of each key is a NumPy array.
The provided single demonstrations (the numbers correspond to the `episode_num1_*` suffixes in the data file names):

- Hopper-v3: 4114
- Walker2d-v3: 6123
- Ant-v3: 6561
- HalfCheetah-v3: 15251
- Humanoid-v3: 5855
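To inspect a demonstration file, here is a minimal sketch, assuming the data is a pickled dictionary; the file name and extension below are illustrative, and the loader may need to be swapped (e.g., `np.load` / `torch.load`) to match the actual serialization:

```python
import pickle
import numpy as np

# Minimal sketch for inspecting a demonstration file; path and
# extension are illustrative, adjust to the actual files on disk.
with open("saved_expert_transition/sac/episode_num1_4114.pkl", "rb") as f:
    data = pickle.load(f)

for key in ("states", "actions", "rewards", "next_states", "dones"):
    arr = np.asarray(data[key])
    print(f"{key}: shape={arr.shape}, dtype={arr.dtype}")
```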
Please note that TDIL (Transition Discriminator-based Imitation Learning) is our proposed method. Train it with the following commands:
- Hopper-v3:

  ```sh
  python main.py --main_stage neighborhood_il --main_task neighborhood_sac --env Hopper-v3 --wrapper basic --total_timesteps 3000000 --data_name sac/episode_num1_4114
  ```

  To fix alpha, add: `--no_update_alpha --log_alpha_init -4.6`
- Walker2d-v3:

  ```sh
  python main.py --main_stage neighborhood_il --main_task neighborhood_sac --env Walker2d-v3 --wrapper basic --total_timesteps 3000000 --data_name sac/episode_num1_6123
  ```

  To fix alpha, add: `--no_update_alpha --log_alpha_init -1.2`
- Ant-v3:

  ```sh
  python main.py --main_stage neighborhood_il --main_task neighborhood_sac --env Ant-v3 --wrapper basic --total_timesteps 3000000 --data_name sac/episode_num1_6561 --terminate_when_unhealthy
  ```

  To fix alpha, add: `--no_update_alpha --log_alpha_init -1.9`
- HalfCheetah-v3:

  ```sh
  python main.py --main_stage neighborhood_il --main_task neighborhood_sac --env HalfCheetah-v3 --wrapper basic --total_timesteps 3000000 --data_name sac/episode_num1_15251
  ```

  To fix alpha, add: `--no_update_alpha --log_alpha_init 0.4`
- Humanoid-v3:

  ```sh
  python main.py --main_stage neighborhood_il --main_task neighborhood_sac --env Humanoid-v3 --wrapper basic --total_timesteps 3000000 --data_name sac/episode_num1_5855 --terminate_when_unhealthy
  ```

  To fix alpha, add: `--no_update_alpha --log_alpha_init -0.6`
To train the variant that disables behavior cloning and uses the discriminator (i.e., with `--no_bc --beta 0.9 --use_discriminator`), run:

- Hopper-v3:

  ```sh
  python main.py --main_stage neighborhood_il --main_task neighborhood_sac --env Hopper-v3 --wrapper basic --total_timesteps 3000000 --data_name sac/episode_num1_4114 --no_bc --beta 0.9 --use_discriminator
  ```

- Walker2d-v3:

  ```sh
  python main.py --main_stage neighborhood_il --main_task neighborhood_sac --env Walker2d-v3 --wrapper basic --total_timesteps 3000000 --data_name sac/episode_num1_6123 --no_bc --beta 0.9 --use_discriminator
  ```

- Ant-v3:

  ```sh
  python main.py --main_stage neighborhood_il --main_task neighborhood_sac --env Ant-v3 --wrapper basic --total_timesteps 3000000 --data_name sac/episode_num1_6561 --terminate_when_unhealthy --no_bc --beta 0.9 --use_discriminator
  ```

- HalfCheetah-v3:

  ```sh
  python main.py --main_stage neighborhood_il --main_task neighborhood_sac --env HalfCheetah-v3 --wrapper basic --total_timesteps 3000000 --data_name sac/episode_num1_15251 --no_bc --beta 0.9 --use_discriminator
  ```

- Humanoid-v3:

  ```sh
  python main.py --main_stage neighborhood_il --main_task neighborhood_sac --env Humanoid-v3 --wrapper basic --total_timesteps 3000000 --data_name sac/episode_num1_5855 --terminate_when_unhealthy --no_bc --beta 0.9 --use_discriminator
  ```
Additional flags:

- To use a custom log name, add: `--log_name <custom_name>`
- To disable behavior cloning (BC), add: `--no_bc`
- To disable hard negative sampling, add: `--no_hard_negative_sampling`
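The commands above can also be launched programmatically, e.g., to sweep several environments in sequence. A minimal sketch using Python's `subprocess`; the per-environment arguments are copied from the commands above, and the log names are illustrative:

```python
import subprocess

# Minimal sketch for launching several of the training commands above
# in sequence; log names are illustrative.
RUNS = {
    "Hopper-v3": ["--data_name", "sac/episode_num1_4114"],
    "Walker2d-v3": ["--data_name", "sac/episode_num1_6123"],
    "Ant-v3": ["--data_name", "sac/episode_num1_6561", "--terminate_when_unhealthy"],
}

for env_name, extra_args in RUNS.items():
    cmd = [
        "python", "main.py",
        "--main_stage", "neighborhood_il",
        "--main_task", "neighborhood_sac",
        "--env", env_name,
        "--wrapper", "basic",
        "--total_timesteps", "3000000",
        *extra_args,
        "--log_name", f"tdil_{env_name}",
    ]
    subprocess.run(cmd, check=True)
```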
For the toy maze environment, run:

```sh
python main.py --main_stage neighborhood_il --main_task neighborhood_dsac --env Maze-v6 --episodes 300 --policy_threshold_ratio 0.5 --neighbor_model_alpha 0.1 --gamma 0.8
```
The `policy_threshold_ratio` hyperparameter filters out state-action pairs that are too close to the expert (i.e., whose expert proximity is too high) when training the policy. This is needed because the toy maze is a commutative type of environment, meaning the agent can move back to previously visited states.
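As a rough illustration of the idea, not the repository's implementation, such a filter can be expressed as a boolean mask over proximity-based rewards:

```python
import numpy as np

# Rough illustration (not the actual code): drop transitions whose expert
# proximity exceeds a fixed ratio of the maximum proximity, keeping the
# rest for policy training.
def proximity_mask(proximity: np.ndarray, threshold_ratio: float = 0.5) -> np.ndarray:
    """Boolean mask that keeps transitions below the proximity threshold."""
    threshold = threshold_ratio * proximity.max()
    return proximity < threshold

proximity = np.array([0.10, 0.40, 0.90, 0.95])
print(proximity_mask(proximity, threshold_ratio=0.5))  # [ True  True False False]
```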
For the AdroitHandDoor-v1 environment, run:

```sh
python main.py --main_stage neighborhood_il --main_task neighborhood_sac --env AdroitHandDoor-v1 --wrapper gymnasium --total_timesteps 1000000 --data_name dapg/episode_num1_3019 --max_episode_steps 200 --no_hard_negative_sampling --policy_threshold_ratio 0.005 --ood
```
The `--ood` argument makes the agent test on out-of-distribution (OOD) states. More specifically, at the beginning of testing, the agent first takes a few timesteps of random actions, and then starts to take actions based on the learned policy.
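A conceptual sketch of this OOD evaluation protocol, not the repository's code; the warm-up length and the gymnasium-based loop are assumptions, and `AdroitHandDoor-v1` additionally requires `gymnasium-robotics`:

```python
import gymnasium as gym

RANDOM_WARMUP_STEPS = 10  # illustrative; the actual value may differ

def act(obs, env):
    """Stand-in for the learned policy; replace with the trained agent."""
    return env.action_space.sample()

env = gym.make("AdroitHandDoor-v1", max_episode_steps=200)
obs, _ = env.reset()
done, t, total_reward = False, 0, 0.0
while not done:
    if t < RANDOM_WARMUP_STEPS:
        # Random warm-up actions push the agent into OOD states.
        action = env.action_space.sample()
    else:
        # Afterwards, follow the learned policy.
        action = act(obs, env)
    obs, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated
    total_reward += reward
    t += 1
print("episode return:", total_reward)
```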