[arXiv'23] Official implementation of the paper "Diffusion Reward: Learning Rewards via Conditional Video Diffusion"

Diffusion Reward

[Project Website] [Paper] [Data] [Models]

This is the official PyTorch implementation of the paper "Diffusion Reward: Learning Rewards via Conditional Video Diffusion" by Tao Huang*, Guangqi Jiang*, Yanjie Ze, and Huazhe Xu.

🛠️ Installation Instructions

Clone this repository.

git clone https://github.com/TaoHuang13/diffusion_reward.git
cd diffusion_reward

Create and activate the conda environment, then install this package in editable mode.

conda env create -f conda_env.yml 
conda activate diffusion_reward
pip install -e .

Install extra dependencies.

  • Install PyTorch.
pip3 install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
  • Install MuJoCo 2.1.0 (mujoco210) and mujoco-py following the instructions here.

  • Install Adroit dependencies.

cd env_dependencies
pip install -e mj_envs/.
pip install -e mjrl/.
cd ..
  • Install MetaWorld following the instructions here.
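
After the steps above, a quick way to confirm the environment is complete is to check that Python can locate each package. The sketch below uses only the standard library. The import names (`diffusion_reward`, `mujoco_py`, `mj_envs`, `mjrl`, `metaworld`) are assumptions about how each package is imported, and `check_modules` and `installed_version` are hypothetical helpers, not part of the repository.

```python
"""Check that the packages from the installation steps can be found.

Sketch only: the module names below are assumed import names, not names
taken from the repository's code.
"""
import importlib.metadata
import importlib.util

# Assumed import names for the dependencies installed above.
EXPECTED_MODULES = [
    "diffusion_reward",  # this repository, installed with `pip install -e .`
    "torch",
    "torchvision",
    "mujoco_py",
    "mj_envs",
    "mjrl",
    "metaworld",
]


def check_modules(names):
    """Return {module_name: found?} for top-level modules.

    find_spec only locates a top-level module; it does not execute it.
    """
    return {name: importlib.util.find_spec(name) is not None for name in names}


def installed_version(dist_name):
    """Return the installed distribution version string, or None if absent."""
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    for name, found in check_modules(EXPECTED_MODULES).items():
        print(f"{name:20s} {'ok' if found else 'MISSING'}")
    # The installation step above pins torch==1.12.1+cu116.
    print("torch version:", installed_version("torch"))
```

Using `find_spec` instead of a real `import` keeps the check cheap; mujoco-py, for instance, may compile its extension the first time it is imported.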

💻 Reproducing Experimental Results

Download Video Demonstrations

| Domain | Tasks | Episodes | Size | Collection | Link |
|---|---|---|---|---|---|
| Adroit | 3 | 150 | 23.8M | VRL3 | Download |
| MetaWorld | 7 | 140 | 38.8M | Scripts | Download |

Download the datasets and place them in /video_dataset to reproduce the results in the paper.
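
One rough way to confirm that the downloads unpacked correctly is to compare on-disk sizes with the Size column above. This sketch assumes that /video_dataset means a `video_dataset` folder at the repository root and that each domain sits in its own subfolder; both are assumptions, and `summarize` is a hypothetical helper. Sizes are reported in units of 10^6 bytes, while the table does not say which megabyte convention it uses (or whether it refers to compressed archives), so expect only approximate agreement.

```python
"""Summarize the on-disk size of the downloaded video datasets.

Sketch under assumptions: each domain is unpacked into its own subfolder
(e.g. video_dataset/adroit, video_dataset/metaworld). The real layout
inside the archives may differ.
"""
import sys
from pathlib import Path


def folder_size_mb(folder):
    """Total size of all regular files under `folder`, in units of 10**6 bytes."""
    total = sum(p.stat().st_size for p in Path(folder).rglob("*") if p.is_file())
    return total / 1e6


def summarize(root):
    """Map each immediate subfolder of `root` to its size in MB."""
    root = Path(root)
    if not root.is_dir():
        raise FileNotFoundError(f"dataset root not found: {root}")
    return {d.name: folder_size_mb(d) for d in sorted(root.iterdir()) if d.is_dir()}


if __name__ == "__main__":
    # Pass the dataset root explicitly; defaults to ./video_dataset (assumed location).
    root = sys.argv[1] if len(sys.argv) > 1 else "video_dataset"
    for name, size in summarize(root).items():
        print(f"{name:12s} {size:8.1f} MB")
```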

Pretrain Reward Models

Train the VQGAN encoder.

bash scripts/run/codec_model/vqgan_${domain}.sh    # [adroit, metaworld]

Train the video models.

bash scripts/run/video_model/${video_model}_${domain}.sh    # [vqdiffusion, videogpt]_[adroit, metaworld]
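
The two pretraining commands above can be chained for one domain. This sketch builds the script paths from the patterns in the shell comments, validates the placeholder values against the listed options, and keeps the README's order (VQGAN encoder first, then the video model). `pretrain_commands` and `run_pretraining` are hypothetical helpers; the sketch defaults to a dry run that only prints the commands.

```python
"""Build (and optionally run) the two pretraining commands for one domain.

Sketch: script paths follow the patterns shown in the README; the helper
names are hypothetical and not part of the repository.
"""
import subprocess

DOMAINS = ("adroit", "metaworld")
VIDEO_MODELS = ("vqdiffusion", "videogpt")


def pretrain_commands(domain, video_model):
    """Return the codec-model command followed by the video-model command."""
    if domain not in DOMAINS:
        raise ValueError(f"unknown domain {domain!r}; expected one of {DOMAINS}")
    if video_model not in VIDEO_MODELS:
        raise ValueError(f"unknown video model {video_model!r}; expected one of {VIDEO_MODELS}")
    return [
        ["bash", f"scripts/run/codec_model/vqgan_{domain}.sh"],
        ["bash", f"scripts/run/video_model/{video_model}_{domain}.sh"],
    ]


def run_pretraining(domain, video_model, dry_run=True):
    """Run the commands in order from the repo root; check=True stops on the first failure."""
    for cmd in pretrain_commands(domain, video_model):
        print("$", " ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_pretraining("adroit", "vqdiffusion")  # dry run: prints the commands only
```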

(Optional) Download Pre-trained Models

We also provide the pre-trained reward models used in the paper (including Diffusion Reward and VIPER) for reproducing the results. Download the models together with their configuration files here, and place the folders in /exp_local.

Train RL with Pre-trained Rewards

Train DrQv2 with different rewards.

bash scripts/run/rl/drqv2_${domain}_${reward}.sh ${task}    # [adroit, metaworld]_[diffusion_reward, viper, viper_std, amp, rnd, raw_sparse_reward]
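
Comparing rewards means launching the same script pattern once per reward and task. This sketch builds those commands from the domain and reward names listed in the shell comment above. `rl_command` and `sweep` are hypothetical helpers, and `my_task` is a placeholder, not a real task name.

```python
"""Build DrQv2 launch commands for a sweep over reward types.

Sketch: domain and reward names are copied from the README's command comment;
rl_command and sweep are hypothetical helpers, and task names are placeholders.
"""
import itertools
import shlex

DOMAINS = ("adroit", "metaworld")
REWARDS = ("diffusion_reward", "viper", "viper_std", "amp", "rnd", "raw_sparse_reward")


def rl_command(domain, reward, task):
    """Return the argv list for scripts/run/rl/drqv2_<domain>_<reward>.sh <task>."""
    if domain not in DOMAINS:
        raise ValueError(f"unknown domain {domain!r}; expected one of {DOMAINS}")
    if reward not in REWARDS:
        raise ValueError(f"unknown reward {reward!r}; expected one of {REWARDS}")
    return ["bash", f"scripts/run/rl/drqv2_{domain}_{reward}.sh", task]


def sweep(domain, tasks, rewards=REWARDS):
    """One command per (reward, task) pair; the reward varies slowest."""
    return [rl_command(domain, r, t) for r, t in itertools.product(rewards, tasks)]


if __name__ == "__main__":
    # "my_task" is a hypothetical placeholder; substitute a real task name.
    for cmd in sweep("metaworld", ["my_task"], rewards=("diffusion_reward", "viper")):
        print(shlex.join(cmd))
```

The sketch only builds command strings; it does not check that a script actually exists for every domain and reward pair.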

Note that you need to log in to wandb to log experiments online. To log locally instead, turn wandb off in the configuration file here.
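
As an alternative that leaves the configuration file untouched, the sketch below launches a script with the standard `WANDB_MODE` environment variable set to `offline` or `disabled`. It assumes the repository calls `wandb.init()` without forcing its own mode, which has not been verified; if it does, use the configuration-file switch instead. `wandb_env` and `launch` are hypothetical helpers, and `my_task` is a placeholder task name.

```python
"""Launch a training script with Weights & Biases logging kept local.

Sketch under an assumption: the launched code calls wandb.init() without
forcing a mode, so the WANDB_MODE environment variable applies.
"""
import os
import subprocess


def wandb_env(mode="offline", base=None):
    """Return a copy of the environment (or `base`) with WANDB_MODE set."""
    if mode not in ("online", "offline", "disabled"):
        raise ValueError(f"mode not handled by this sketch: {mode!r}")
    env = dict(os.environ if base is None else base)  # copy; never mutate the input
    env["WANDB_MODE"] = mode
    return env


def launch(cmd, mode="offline", dry_run=True):
    """Print the command and, unless dry_run, run it with WANDB_MODE applied."""
    env = wandb_env(mode)
    print(f"WANDB_MODE={mode}", " ".join(cmd))
    if not dry_run:
        subprocess.run(cmd, env=env, check=True)


if __name__ == "__main__":
    # "my_task" is a hypothetical task name.
    launch(["bash", "scripts/run/rl/drqv2_adroit_diffusion_reward.sh", "my_task"])
```

Runs logged with `WANDB_MODE=offline` are written to a local `wandb/` directory and can be uploaded later with `wandb sync`.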

🧭 Code Navigation

diffusion_reward
  |- configs               # experiment configs 
  |    |- models           # configs of codec models and video models
  |    |- rl               # configs of rl 
  |
  |- envs                  # environments, wrappers, env maker
  |    |- adroit.py        # Adroit env
  |    |- metaworld.py     # MetaWorld env
  |    |- wrapper.py       # env wrapper and utils
  |
  |- models                # implements core codec models and video models
  |    |- codec_models     # image encoder, e.g., VQGAN
  |    |- video_models     # video prediction models, e.g., VQDiffusion and VideoGPT
  |    |- reward_models    # reward models, e.g., Diffusion Reward and VIPER
  |
  |- rl                    # implements core rl algorithms
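
To illustrate how the pieces in this tree could fit together, the sketch below shows an environment wrapper that replaces each step's reward with the score of a reward model that sees a short window of recent frames, loosely in the spirit of the conditional video models listed above. Everything here is hypothetical: `LearnedRewardWrapper`, `ToyEnv`, and `toy_reward_model` are illustrative stand-ins, not classes from `envs/wrapper.py` or `models/reward_models`, and the 4-tuple `step` API is an assumption that may not match the repository's interface.

```python
"""How a learned reward model could replace the environment reward.

Hypothetical sketch: all class and function names are illustrative, and the
(obs, reward, done, info) step API is assumed rather than taken from the repo.
"""
from collections import deque


class LearnedRewardWrapper:
    """Replace each step's reward with a model score over the last `history` frames."""

    def __init__(self, env, reward_model, history=3, keep_env_reward=False):
        self.env = env
        self.reward_model = reward_model
        self.frames = deque(maxlen=history)  # sliding context window
        self.keep_env_reward = keep_env_reward

    def reset(self):
        obs = self.env.reset()
        self.frames.clear()
        self.frames.append(obs)
        return obs

    def step(self, action):
        obs, env_reward, done, info = self.env.step(action)
        self.frames.append(obs)
        reward = self.reward_model(list(self.frames))
        if self.keep_env_reward:
            reward += env_reward  # optionally add the original (e.g. sparse) reward
        info = dict(info, env_reward=env_reward)  # keep the original for logging
        return obs, reward, done, info


class ToyEnv:
    """Stand-in environment whose observations are step counters."""

    def __init__(self, horizon=4):
        self.horizon, self.t = horizon, 0

    def reset(self):
        self.t = 0
        return self.t

    def step(self, action):
        self.t += 1
        return self.t, 0.0, self.t >= self.horizon, {}


def toy_reward_model(frames):
    """Stand-in scorer: mean of the frames in the context window."""
    return sum(frames) / len(frames)


if __name__ == "__main__":
    env = LearnedRewardWrapper(ToyEnv(), toy_reward_model, history=2)
    env.reset()
    done = False
    while not done:
        obs, r, done, info = env.step(action=None)
        print(obs, r)
```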

✉️ Contact

For any questions, please feel free to email taou.cs13@gmail.com or luccachiang@gmail.com.

🙏 Acknowledgement

Our code is built upon VQGAN, VQ-Diffusion, VIPER, AMP, RND, and DrQv2. We thank the authors of these projects for open-sourcing their code and for their great contributions to the community.

🏷️ License

This repository is released under the MIT license. See LICENSE for additional details.

📝 Citation

If you find our work useful, please consider citing:

@article{Huang2023DiffusionReward,
  title={Diffusion Reward: Learning Rewards via Conditional Video Diffusion},
  author={Tao Huang and Guangqi Jiang and Yanjie Ze and Huazhe Xu},
  journal={arxiv}, 
  year={2023},
}
